
Overview

Real comparative datasets contain many kinds of traits: body mass (continuous), clutch size (count), migratory status (yes/no), diet type (carnivore/herbivore/omnivore), threat status (LC < VU < EN). pigauto handles all of them in a single model — you do not need separate imputation runs for different column types.

| Type | R class | Example | Notes |
|---|---|---|---|
| Continuous | numeric | Body mass, wing length | Auto-detected |
| Count | integer | Clutch size, litter size | Auto-detected |
| Binary | factor (2 levels) | Migratory yes/no | Auto-detected |
| Categorical | factor (>2 levels) | Diet type, lifestyle | Auto-detected |
| Ordinal | ordered | Threat status (LC < VU < EN) | Auto-detected |
| Proportion | numeric (0–1) | Habitat cover, diet fraction | Requires trait_types = "proportion" override |
| ZI count | integer (zero-inflated) | Parasite load, rare behaviour | Requires trait_types = "zi_count" override; experimental, accuracy more variable than other types |
| Multi-proportion | K numeric columns summing to 1 | Diet composition, plumage-colour fractions, microbiome relative abundances | Requires multi_proportion_groups = list(<name> = c("col1", ..., "colK")); encoded via centred log-ratio (CLR) + per-component z-score |

The first five types in this table are auto-detected from each column's R class; proportion, zi_count, and multi_proportion must be declared explicitly via trait_types or multi_proportion_groups. All eight types share one latent space: the phylogenetic baseline and the GNN correction both operate in it, and type-specific logic appears only at encoding, loss computation, and decoding.
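The multi-proportion encoding relies on the centred log-ratio transform. The base-R sketch below applies CLR followed by a per-component z-score, then inverts both steps. It illustrates the general technique only: clr() is a hypothetical helper, the zero-replacement constant eps and the diet component names are assumptions, and pigauto's own encoder may handle zeros differently.

```r
# Hypothetical sketch of CLR + per-component z-score (not pigauto code).
clr <- function(P, eps = 1e-6) {
  P <- pmax(P, eps)          # replace zeros before taking logs (assumed rule)
  P <- P / rowSums(P)        # re-close each row onto the simplex
  L <- log(P)
  L - rowMeans(L)            # centre by each row's mean log (geometric mean)
}

# Three species, three diet components (made-up values; rows sum to 1)
diet <- rbind(c(0.6, 0.3, 0.1),
              c(0.2, 0.5, 0.3),
              c(0.1, 0.1, 0.8))
colnames(diet) <- c("insects", "seeds", "fruit")

Z  <- clr(diet)              # CLR coordinates: each row sums to 0
Zs <- scale(Z)               # z-score each component (column)

# Inverse: undo the z-score, then map CLR coordinates back to the simplex
Zb   <- sweep(Zs, 2, attr(Zs, "scaled:scale"), "*")
Zb   <- sweep(Zb, 2, attr(Zs, "scaled:center"), "+")
back <- exp(Zb) / rowSums(exp(Zb))
round(back, 3)               # recovers the original proportions
```

Because CLR coordinates sum to zero within each row, the K columns of a group are not independent; the inverse softmax step restores a valid composition regardless.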

Synthetic example

library(pigauto)
library(ape)

set.seed(42)
n <- 60
tree <- rtree(n)

traits <- data.frame(
  row.names = tree$tip.label,
  mass      = exp(rnorm(n, 3, 0.5)),
  clutch    = as.integer(rpois(n, 3) + 1L),
  migr      = factor(sample(c("no", "yes"), n, replace = TRUE)),
  diet      = factor(sample(c("herb", "carn", "omni"), n, replace = TRUE)),
  threat    = ordered(sample(c("LC", "VU", "EN"), n, replace = TRUE),
                      levels = c("LC", "VU", "EN"))
)

Preprocessing

preprocess_traits() auto-detects column types from R classes:

pd <- preprocess_traits(traits, tree)
print(pd)
#> pigauto_data
#>   Species: 60 
#>   Traits:  5 
#>   Types:   binary=1, categorical=1, continuous=1, count=1, ordinal=1 
#>   Latent columns: 7
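The auto-detection rules from the overview table can be written as a small base-R function. detect_trait_type() is a hypothetical helper for illustration, not part of pigauto. Two details matter: ordered factors must be checked before plain factors, because an ordered factor is also a factor; and a numeric column in [0, 1] still falls through to "continuous", which is why proportions need an explicit override.

```r
# Hypothetical re-implementation of the class-based detection rules.
detect_trait_type <- function(x) {
  if (is.ordered(x)) return("ordinal")              # must precede is.factor()
  if (is.factor(x))  return(if (nlevels(x) == 2L) "binary" else "categorical")
  if (is.integer(x)) return("count")
  if (is.numeric(x)) return("continuous")           # proportions land here too
  stop("unsupported column class: ", class(x)[1])
}

# A tiny data frame with the same column classes as `traits`
toy <- data.frame(
  mass   = c(20.1, 15.3, 31.0),
  clutch = c(3L, 4L, 2L),
  migr   = factor(c("no", "yes", "no")),
  diet   = factor(c("herb", "carn", "omni")),
  threat = ordered(c("LC", "EN", "VU"), levels = c("LC", "VU", "EN"))
)
types <- vapply(toy, detect_trait_type, character(1))
table(types)   # one trait of each type, as in print(pd)
```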

The trait_map records each trait’s type, levels, latent column indices, and normalisation parameters:

str(pd$trait_map, max.level = 1)
#> List of 5
#>  $ mass  :List of 9
#>  $ clutch:List of 9
#>  $ migr  :List of 8
#>  $ diet  :List of 8
#>  $ threat:List of 8
pd$trait_map$diet
#> $name
#> [1] "diet"
#> 
#> $type
#> [1] "categorical"
#> 
#> $n_latent
#> [1] 3
#> 
#> $latent_cols
#> [1] 4 5 6
#> 
#> $levels
#> [1] "carn" "herb" "omni"
#> 
#> $log_transform
#> [1] FALSE
#> 
#> $mean
#> [1] NA
#> 
#> $sd
#> [1] NA
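The latent_cols entries follow from each trait's width. Assuming traits are laid out left to right in data-frame order, and that mass, clutch, migr, and threat each take a single latent column (the only widths consistent with diet using 3 of the 7 latent columns reported above), the indices can be reproduced with a cumulative sum. This is a sketch of the bookkeeping, not pigauto's code.

```r
# Per-trait latent widths consistent with print(pd) and pd$trait_map$diet
n_latent <- c(mass = 1L, clutch = 1L, migr = 1L, diet = 3L, threat = 1L)

ends   <- cumsum(n_latent)          # last latent column of each trait
starts <- ends - n_latent + 1L      # first latent column of each trait
latent_cols <- Map(seq.int, starts, ends)

latent_cols$diet   # 4 5 6, matching $latent_cols above
sum(n_latent)      # 7 latent columns in total
```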

Creating splits

When trait_map is supplied, make_missing_splits() operates at the original-trait level. For categorical traits, all K one-hot columns are held out together:

spl <- make_missing_splits(pd$X_scaled, missing_frac = 0.20,
                           seed = 1, trait_map = pd$trait_map)
cat("Val cells (latent):", length(spl$val_idx), "\n")
#> Val cells (latent): 21
cat("Test cells (latent):", length(spl$test_idx), "\n")
#> Test cells (latent): 57
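The latent-cell counts follow from holding out whole traits: each held-out diet cell removes three latent entries, every other cell removes one. The sketch below expands (species, trait) cells to latent indices with a hypothetical helper, expand_to_latent(), assuming column-major linear indexing into the 60 x 7 latent matrix. It then checks the totals against the per-trait n values reported later by evaluate_imputation().

```r
latent_cols <- list(mass = 1L, clutch = 2L, migr = 3L, diet = 4:6, threat = 7L)

# Hypothetical helper: map held-out (row, trait) cells to column-major
# linear indices of an n x 7 latent matrix, covering every latent column
# of the trait (so a diet cell yields all three one-hot columns).
expand_to_latent <- function(rows, traits, latent_cols, n) {
  unlist(Map(function(i, tr) (latent_cols[[tr]] - 1L) * n + i, rows, traits))
}

idx <- expand_to_latent(rows = c(5L, 12L), traits = c("mass", "diet"),
                        latent_cols = latent_cols, n = 60L)
idx   # 5 192 252 312: one index for mass, three for diet

# Trait-level held-out counts (from the evaluation output) times widths
val_n  <- c(mass = 2, clutch = 4, migr = 3, diet = 3, threat = 3)
test_n <- c(mass = 13, clutch = 10, migr = 10, diet = 6, threat = 6)
sum(val_n  * lengths(latent_cols))   # 21 latent validation cells
sum(test_n * lengths(latent_cols))   # 57 latent test cells
```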

Baseline fitting

The baseline uses a phylogenetic conditional multivariate normal (MVN) model for continuous-family latent columns, and label-propagation or threshold/liability candidates for discrete-family columns:

bl <- fit_baseline(pd, tree, splits = spl)
dim(bl$mu)
#> [1] 60  7
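For intuition about the continuous-family baseline, here is the textbook conditional MVN prediction under Brownian motion on the four-tip tree ((A:1,B:1):1,(C:1,D:1):1), with the trait mean assumed known and equal to 0. Species B is missing and is predicted from A, C, and D. This illustrates the general formula only; it is not fit_baseline()'s code.

```r
# BM covariance: diagonal = root-to-tip depth (2), off-diagonal = shared path
tips <- c("A", "B", "C", "D")
C <- matrix(c(2, 1, 0, 0,
              1, 2, 0, 0,
              0, 0, 2, 1,
              0, 0, 1, 2), nrow = 4, byrow = TRUE,
            dimnames = list(tips, tips))

y  <- c(A = 1, B = NA, C = -1, D = 0)   # B is the cell to impute
mu <- 0                                  # assumed known (z-scored) mean

m <- is.na(y)
o <- !m
# Regression weights C_mo %*% C_oo^{-1}
w <- C[m, o, drop = FALSE] %*% solve(C[o, o])

cond_mean <- as.numeric(mu + w %*% (y[o] - mu))
cond_var  <- as.numeric(C[m, m] - w %*% C[o, m, drop = FALSE])
c(mean = cond_mean, var = cond_var)
```

Only B's sister A carries information here: the weights are (0.5, 0, 0), so the conditional mean is 0.5 and the conditional variance drops from 2 to 1.5.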

Training

fit_pigauto() uses type-specific losses and trait-level corruption masking automatically when a trait_map is present:

fit <- fit_pigauto(
  pd, tree,
  splits = spl,
  epochs = 200L,
  eval_every = 50L,
  patience = 5L,
  verbose = FALSE,
  seed = 1
)
#> Warning: Small validation set for 5 trait(s): mass (n=2), clutch (n=4), migr
#> (n=3), diet (n=3), threat (n=3). Calibrated gate and conformal scores will be
#> noisy; coverage may deviate from the 95%% target. See `?fit_pigauto` under
#> 'Calibration at small n' for smoothing options.
print(fit)
#> pigauto_fit
#>   Species : 60 
#>   Traits  : 5 -- mass, clutch, migr, diet, threat 
#>   Types   : binary=1, categorical=1, continuous=1, count=1, ordinal=1 
#>   Architecture: hidden_dim = 64 | k_eigen = 4 
#>   Best val loss : 1.3449 
#>   Test loss     : 1.2059 
#>   Gate calibration: yes
#>   Conformal scores: 3 traits
#> 
#> Phylogenetic signal (lambda, threshold 0.20):
#>   gated (BM skipped, using grand mean): mass (lambda=0.00), clutch (lambda=0.00), migr (lambda=0.00), diet (lambda=0.00), threat (lambda=0.00)
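As a rough picture of what "type-specific losses" can mean, the sketch below writes out standard losses for three types: squared error for a continuous column, binary cross-entropy for a binary column, and softmax cross-entropy for a categorical one-hot block. The function names are hypothetical, and pigauto's actual losses, weights, and treatment of count, ordinal, and proportion columns are not shown.

```r
# Squared error on the (z-scored) latent scale
loss_gaussian <- function(pred, target) mean((pred - target)^2)

# Binary cross-entropy from a logit; y coded 0/1
loss_bce <- function(logit, y) {
  p <- 1 / (1 + exp(-logit))
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

# Softmax cross-entropy: one row of logits per cell, y_idx = true class index
loss_softmax_ce <- function(logits, y_idx) {
  lse   <- apply(logits, 1, function(z) max(z) + log(sum(exp(z - max(z)))))
  log_p <- logits - lse                     # row-wise log-softmax
  -mean(log_p[cbind(seq_len(nrow(logits)), y_idx)])
}

loss_gaussian(c(0, 1), c(0, 3))             # 2
loss_bce(0, 1)                              # log(2)
loss_softmax_ce(matrix(0, 2, 3), c(1L, 3L)) # log(3): uniform over 3 classes
```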

Prediction and decoding

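predict() decodes latent predictions back to the original trait types. The synthetic traits were simulated independently of the tree, so every trait's lambda in the fit above is 0 and the baseline is gated to the grand mean; this is why the imputed values are the same for every species: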

pred <- predict(fit, return_se = TRUE)
head(pred$imputed)
#>       mass clutch migr diet threat
#> t58 17.909      4  yes carn     VU
#> t8  17.909      4  yes carn     VU
#> t36 17.909      4  yes carn     VU
#> t4  17.909      4  yes carn     VU
#> t22 17.909      4  yes carn     VU
#> t18 17.909      4  yes carn     VU

For binary and categorical traits, class probabilities are available:

# Binary: probability of "yes"
head(pred$probabilities$migr)
#> [1] 0.5319149 0.5319149 0.5319149 0.5319149 0.5319149 0.5319149

# Categorical: probability of each diet class
head(pred$probabilities$diet)
#>          carn      herb      omni
#> t58 0.4558713 0.2464169 0.2977118
#> t8  0.4558713 0.2464169 0.2977118
#> t36 0.4558713 0.2464169 0.2977118
#> t4  0.4558713 0.2464169 0.2977118
#> t22 0.4558713 0.2464169 0.2977118
#> t18 0.4558713 0.2464169 0.2977118
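The decoding step can be sketched with simple rules: undo the z-score (and log) for continuous traits, threshold the binary probability, and take the most probable class for categorical traits. The helpers, the 0.5 threshold, and the tie-breaking are assumptions for illustration; the field names mean, sd, and log_transform mirror pd$trait_map, but the numeric values passed to decode_continuous() are made up.

```r
decode_continuous <- function(z, mean, sd, log_transform) {
  x <- z * sd + mean                   # undo the z-score
  if (log_transform) exp(x) else x     # undo the log if one was applied
}
decode_binary <- function(p_second, levels) {
  # p_second = probability of the second level (here "yes")
  factor(ifelse(p_second >= 0.5, levels[2], levels[1]), levels = levels)
}
decode_categorical <- function(prob) {
  # most probable class in each row
  factor(colnames(prob)[max.col(prob, ties.method = "first")],
         levels = colnames(prob))
}

# Probabilities copied from the t58 row printed above
p_migr <- 0.5319149
p_diet <- matrix(c(0.4558713, 0.2464169, 0.2977118), nrow = 1,
                 dimnames = list("t58", c("carn", "herb", "omni")))

decode_binary(p_migr, levels = c("no", "yes"))  # yes, as in pred$imputed
decode_categorical(p_diet)                      # carn, as in pred$imputed
decode_continuous(0, mean = log(20), sd = 0.5, log_transform = TRUE)  # 20
```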

The SE matrix provides type-appropriate uncertainty:

head(pred$se)
#>     mass clutch      migr      diet    threat
#> t58    0      0 0.4680851 0.5441287 0.4998422
#> t8     0      0 0.4680851 0.5441287 0.4735485
#> t36    0      0 0.4680851 0.5441287 0.4902633
#> t4     0      0 0.4680851 0.5441287 0.4746017
#> t22    0      0 0.4680851 0.5441287 0.4892949
#> t18    0      0 0.4680851 0.5441287 0.4497299
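In the rows printed above, the binary and categorical SE values coincide with one minus the probability of the predicted class. This is an observation about these numbers rather than a documented definition (the ordinal threat column does not follow it, since its SE varies while the imputed level does not), but it can be checked directly:

```r
# Probabilities copied from the t58 row of pred$probabilities
p_migr <- 0.5319149                                   # P(migr = "yes")
p_diet <- c(carn = 0.4558713, herb = 0.2464169, omni = 0.2977118)

# One minus the probability of the most probable class
se_migr <- 1 - max(p_migr, 1 - p_migr)
se_diet <- 1 - max(p_diet)
c(migr = se_migr, diet = se_diet)   # 0.4680851 0.5441287, as in pred$se
```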

Evaluation

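evaluate_imputation() dispatches type-specific metrics. Because the predictions in this example are constant across species, correlation metrics are undefined: stats::cor() warns about a zero standard deviation, and pearson_r and spearman_rho are reported as NA: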

ev <- evaluate_imputation(pred, pd$X_scaled, spl)
#> Warning in stats::cor(t_j[ok], p_j[ok]): the standard deviation is zero
#> Warning in stats::cor(t_j[ok], p_j[ok]): the standard deviation is zero
#> Warning in stats::cor(t_j[ok], p_j[ok], method = "spearman"): the standard
#> deviation is zero
#> Warning in stats::cor(t_j[ok], p_j[ok]): the standard deviation is zero
#> Warning in stats::cor(t_j[ok], p_j[ok]): the standard deviation is zero
#> Warning in stats::cor(t_j[ok], p_j[ok], method = "spearman"): the standard
#> deviation is zero
print(ev)
#>    split  trait        type  n      rmse pearson_r coverage_95       mae
#> 1    val   mass  continuous  2 0.6996585        NA           1        NA
#> 2    val clutch       count  4 1.0938335        NA          NA 0.9784612
#> 3    val   migr      binary  3        NA        NA          NA        NA
#> 4    val   diet categorical  3        NA        NA          NA        NA
#> 5    val threat     ordinal  3 0.6770739        NA          NA        NA
#> 6   test   mass  continuous 13 0.9332707        NA           1        NA
#> 7   test clutch       count 10 0.7248662        NA          NA 0.6104951
#> 8   test   migr      binary 10        NA        NA          NA        NA
#> 9   test   diet categorical  6        NA        NA          NA        NA
#> 10  test threat     ordinal  6 1.1046670        NA          NA        NA
#>    spearman_rho  accuracy     brier zero_accuracy aitchison rmse_clr
#> 1            NA        NA        NA            NA        NA       NA
#> 2            NA        NA        NA            NA        NA       NA
#> 3            NA 0.3333333 0.2616569            NA        NA       NA
#> 4            NA 0.3333333        NA            NA        NA       NA
#> 5            NA        NA        NA            NA        NA       NA
#> 6            NA        NA        NA            NA        NA       NA
#> 7            NA        NA        NA            NA        NA       NA
#> 8            NA 0.5000000 0.2510186            NA        NA       NA
#> 9            NA 0.1666667        NA            NA        NA       NA
#> 10           NA        NA        NA            NA        NA       NA
#>    simplex_mae
#> 1           NA
#> 2           NA
#> 3           NA
#> 4           NA
#> 5           NA
#> 6           NA
#> 7           NA
#> 8           NA
#> 9           NA
#> 10          NA
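The columns above correspond to standard metrics. The sketch below writes several of them out in base R; the functions are hypothetical re-implementations, not evaluate_imputation() itself. Assuming the constant "yes" prediction seen in the rows above, accuracy = 1/3 on three validation migr cells implies one "yes" and two "no", and under that assumption the standard Brier score reproduces the printed 0.2616569.

```r
rmse     <- function(truth, pred) sqrt(mean((truth - pred)^2))
accuracy <- function(truth, pred) mean(truth == pred)
brier    <- function(truth01, p) mean((p - truth01)^2)      # binary, truth 0/1
coverage <- function(truth, lo, hi) mean(truth >= lo & truth <= hi)

p_yes <- 0.5319149                      # constant P(yes) printed above
truth <- c("yes", "no", "no")           # assumed validation labels
accuracy(truth, rep("yes", 3))                     # 0.3333333
brier(as.integer(truth == "yes"), rep(p_yes, 3))   # about 0.2616569

# Toy values for the remaining metrics
rmse(c(1, 2, 3), c(1, 2, 5))                            # sqrt(4/3)
coverage(c(1, 2, 3), lo = c(0, 0, 0), hi = c(2, 2, 2))  # 2/3
```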

Multiple imputation for downstream inference

The recommended workflow uses multi_impute() to generate M complete datasets from the model’s calibrated uncertainty distribution, then pools downstream coefficients with Rubin’s rules via with_imputations() and pool_mi().

draws_method = "conformal" is the default: missing cells are drawn from Normal distributions centred on the point estimate with width set by the split-conformal calibration score. The alternative draws_method = "mc_dropout" runs M stochastic GNN forward passes with dropout active and BM posterior draws as the blend baseline; it is available for comparison.
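A minimal sketch of conformal-width draws, assuming the width is taken from the standard finite-sample split-conformal quantile of absolute validation residuals and converted to a Normal standard deviation as q / qnorm(0.975). Both helpers are hypothetical, and that conversion is an assumption rather than pigauto's documented rule. The last line shows why small validation sets (as flagged by the fit_pigauto() warning) are a problem: with too few calibration cells the 95% quantile is unbounded.

```r
conformal_q <- function(abs_resid, alpha = 0.05) {
  n <- length(abs_resid)
  k <- ceiling((n + 1) * (1 - alpha))   # finite-sample split-conformal rank
  if (k > n) return(Inf)                # too few calibration cells
  sort(abs_resid)[k]
}
draw_conformal <- function(point, q, m, alpha = 0.05) {
  sd_equiv <- q / qnorm(1 - alpha / 2)  # Normal sd with central half-width q
  # one column per imputation; row i is centred on point[i]
  matrix(rnorm(length(point) * m, mean = point, sd = sd_equiv), ncol = m)
}

set.seed(1)
q <- conformal_q(abs_resid = (1:39) / 10)   # 38th smallest of 39 scores: 3.8
draws <- draw_conformal(point = c(2.1, 3.4, 0.7), q = q, m = 20)
dim(draws)                                  # 3 missing cells x 20 imputations
conformal_q(c(0.3, 0.8))                    # Inf: n = 2 is too small for 95%
```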

# Generate M = 20 stochastic complete datasets
mi <- multi_impute(fit$data$X_original, tree, m = 20L)

# Fit a downstream model to each
fits <- with_imputations(mi, function(d) {
  glm(migr ~ log(mass) + diet, data = d, family = binomial)
})

# Pool with Rubin's rules
pool_mi(fits)

pool_mi() returns a tidy data.frame with estimate, std.error, p.value, df (Barnard-Rubin degrees of freedom), fmi (fraction of missing information), and riv (relative increase in variance) per coefficient.
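The quantities pool_mi() reports follow Rubin's rules. The sketch below applies the textbook formulas to one coefficient, with the Barnard-Rubin degrees of freedom when a complete-data df is supplied. pool_rubin() is a hypothetical helper written for illustration, and pool_mi() may differ in detail (for example, in the default complete-data df).

```r
pool_rubin <- function(est, se, df_com = Inf) {
  m      <- length(est)
  qbar   <- mean(est)                    # pooled estimate
  W      <- mean(se^2)                   # within-imputation variance
  B      <- var(est)                     # between-imputation variance
  total  <- W + (1 + 1 / m) * B          # total variance
  riv    <- (1 + 1 / m) * B / W          # relative increase in variance
  lambda <- (1 + 1 / m) * B / total      # share of variance due to missingness
  df_old <- (m - 1) / lambda^2           # classic Rubin df
  df <- if (is.finite(df_com)) {         # Barnard-Rubin adjustment
    df_obs <- (df_com + 1) / (df_com + 3) * df_com * (1 - lambda)
    df_old * df_obs / (df_old + df_obs)
  } else df_old
  fmi <- (riv + 2 / (df + 3)) / (riv + 1)  # fraction of missing information
  c(estimate = qbar, std.error = sqrt(total), df = df, riv = riv, fmi = fmi)
}

# Three imputations of one coefficient (made-up numbers)
pool_rubin(est = c(1.0, 1.2, 0.8), se = c(0.5, 0.5, 0.5), df_com = 50)
```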

For reference, the lower-level predict(fit, n_imputations = 5L) interface returns the requested number of complete datasets (here 5) directly, without downstream pooling: