Trains a pigauto model: a gated ensemble of a phylogenetic baseline
and an attention-based graph neural network correction, implemented
as an internal torch module (ResidualPhyloDAE; "Residual" here
refers to the ResNet-style skip connections in the GNN layers, not
to a statistical residual). For continuous, count, and ordinal traits
the baseline is Brownian motion (phylogenetic correlation matrix); for binary
and categorical traits it is phylogenetic label propagation. Supports
all five trait types via a unified latent space.
Usage
fit_pigauto(
data,
tree,
splits = NULL,
graph = NULL,
baseline = NULL,
hidden_dim = 64L,
k_eigen = "auto",
n_gnn_layers = 2L,
gate_cap = 0.8,
use_attention = TRUE,
use_transformer_blocks = TRUE,
n_heads = 4L,
ffn_mult = 4L,
use_trait_attention = FALSE,
n_trait_heads = 2L,
trait_embed_dim = 32L,
dropout = 0.1,
lr = 0.003,
weight_decay = 1e-04,
epochs = 3000L,
corruption_rate = 0.55,
corruption_start = 0.2,
corruption_ramp = 500L,
refine_steps = 8L,
lambda_shrink = 0.03,
lambda_gate = 0.01,
warmup_epochs = 200L,
edge_dropout = 0.1,
eval_every = 100L,
patience = 10L,
clip_norm = 1,
conformal_method = c("split", "bootstrap"),
conformal_bootstrap_B = 500L,
conformal_split_val = FALSE,
gate_method = c("cv_folds", "median_splits", "single_split"),
gate_splits_B = 31L,
gate_cv_folds = 5L,
safety_floor = TRUE,
phylo_signal_gate = TRUE,
phylo_signal_threshold = 0.2,
phylo_signal_method = c("lambda", "blomberg_k"),
min_val_cells = 20L,
verbose = TRUE,
seed = 1L
)
Arguments
- data
object of class "pigauto_data".
- tree
object of class "phylo".
- splits
list (output of make_missing_splits) or NULL.
- graph
list (output of build_phylo_graph) or NULL.
- baseline
list (output of fit_baseline) or NULL.
- hidden_dim
integer. Hidden layer width (default 64).
- k_eigen
integer or "auto". Number of spectral node features (default "auto").
- n_gnn_layers
integer. Number of graph message-passing layers (default 2). Each layer has its own learnable alpha gate, layer normalisation, and ResNet-style skip connection.
- gate_cap
numeric. Upper bound for the per-column blend gate (default 0.8). Safety comes from regularisation, not the cap.
- use_attention
logical. Use attention in the GNN layers (default TRUE).
- use_transformer_blocks
logical. Replace the legacy attention stack with pre-norm transformer-encoder blocks (multi-head attention + FFN + two residual skips). Default TRUE. Set FALSE to reproduce pre-v0.9.0 fits (single-head attention with a learnable alpha gate per layer).
- n_heads
integer. Number of attention heads when use_transformer_blocks = TRUE (default 4). Each head learns its own phylogenetic bandwidth (B2 rate-aware attention).
- ffn_mult
integer. Feed-forward width multiplier inside each transformer block (default 4, giving hidden_dim * 4).
- use_trait_attention
logical. Opt-in within-row cross-trait self-attention (B3, v0.9.3). When TRUE (default FALSE), the model builds per-trait tokens from each row's latent values (linear projection + learnable positional embedding), applies one multi-head self-attention block over the trait sequence, mean-pools to a trait_embed_dim feature, and concatenates it alongside (x, coords, covs) at the encoder input. Intended for trait sets with strong within-row functional coupling that the joint MVN / threshold-joint baseline cannot capture (e.g. nonlinear cross-trait structure). On the BIEN n=2000 plant bench it did not improve pooled RMSE (Σ is already captured by the joint baseline); kept as an opt-in for datasets where it may help. Default FALSE preserves v0.9.2 behaviour exactly.
- n_trait_heads
integer. Number of attention heads in the within-row self-attention block when use_trait_attention = TRUE. Default 2. Ignored when use_trait_attention = FALSE.
- trait_embed_dim
integer. Embedding dim per trait token in the within-row self-attention block. Default 32. Ignored when use_trait_attention = FALSE.
- dropout
numeric. Dropout rate (default 0.10).
- lr
numeric. AdamW learning rate (default 0.003).
- weight_decay
numeric. AdamW weight decay (default 1e-4).
- epochs
integer. Maximum training epochs (default 3000).
- corruption_rate
numeric. Final corruption fraction if corruption_ramp > 0; otherwise the fixed corruption rate per epoch (default 0.55).
- corruption_start
numeric. Initial corruption fraction for the curriculum schedule (default 0.20). Ignored if corruption_ramp = 0.
- corruption_ramp
integer. Epochs over which corruption linearly ramps from corruption_start to corruption_rate (default 500). Set to 0 for fixed corruption.
- refine_steps
integer. Iterative refinement steps at inference (default 8).
- lambda_shrink
numeric. Weight on the shrinkage penalty ||delta - baseline||^2 that keeps the GNN correction close to the phylogenetic baseline (default 0.03).
- lambda_gate
numeric. Weight on the gate regularisation penalty that pushes learnable gates toward zero. Prevents gates from staying open when the GNN provides no useful correction (default 0.01).
- warmup_epochs
integer. Linear learning-rate warmup over the first N epochs (default 200). After warmup, a cosine schedule decays the LR to 1e-5.
- edge_dropout
numeric. Fraction of adjacency edges randomly zeroed each training epoch for graph regularisation (default 0.1). Set to 0 to disable.
- eval_every
integer. Evaluate on val every N epochs (default 100).
- patience
integer. Early-stopping patience in eval cycles (default 10).
- clip_norm
numeric. Gradient clip norm (default 1.0).
- conformal_method
character. How the conformal prediction score is estimated from validation residuals. "split" (default, backward-compatible) takes a single sample quantile; "bootstrap" takes conformal_bootstrap_B bootstrap resamples of the val residuals and averages the per-resample quantiles. Empirically the bootstrap variant reduces the conformal-score variance across seeds by ~30% at small n_val, but on the simulation design used to evaluate it (n=150 species, 35% trait_MAR) the reduction did not translate into better 95% coverage across 10 seeds. Shipped as an opt-in experimental knob; defaults to "split".
- conformal_bootstrap_B
integer. Bootstrap resamples used when conformal_method = "bootstrap"; default 500. Ignored otherwise.
- conformal_split_val
logical. Default FALSE (pre-2026-04-28 single-set behaviour, retained because forcing the split regresses the AVONET300 / OVR-categorical / BIEN safety-floor smoke benches by 2-26%). When TRUE, the validation set is split per latent column into a calibration half (used to pick the calibrated blend gate) and a conformal half (used to compute the conformal residual quantile). This restores split-conformal exchangeability: without the split, the gate is selected to minimise residual MSE on the very cells whose residuals drive the conformal quantile, producing systematic undercoverage (most visible at small n_val; see the coverage_investigation memo). Use this when accurate 95% coverage matters more than the bench-grade RMSE; the split only activates per-column when that column has at least 2 * min_val_cells val cells (smaller columns silently fall back to the single-set path to keep calibrate_gates()'s half-A / half-B cross-check stable).
- gate_method
character. How the per-trait calibrated gate is chosen. "cv_folds" (default, the first choice in the usage signature; added 2026-04-30) partitions val cells into gate_cv_folds (default 5) deterministic non-overlapping folds and runs the grid + half-B-verify procedure once per fold (training set = K-1 folds, held-out = remaining fold), taking the componentwise median of the K winning weight vectors. It uses larger training sets per split ((K-1)/K vs 1/2 in "median_splits") and has a standard cross-validation interpretation; it was motivated by the open val→test drift observed on 4/32 binary cells in the discrete-bench memo. "single_split" (the earlier, backward-compatible behaviour) runs the grid search on a single random half-A / half-B split of the val rows. "median_splits" repeats the whole procedure for gate_splits_B random splits and takes the median best_g; it slightly reduces gate bimodality at small n_val (SD 0.406 → 0.360 across 10 seeds on the evaluation sim) with a small coverage-SD improvement (0.094 → 0.086), at negligible runtime cost (B × cheap grid searches).
- gate_splits_B
integer. Random splits used when gate_method = "median_splits"; default 31 (odd so the median is well-defined).
- gate_cv_folds
integer. Number of CV folds when gate_method = "cv_folds"; default 5, must be >= 2. Capped at n_val per trait so each fold has at least 1 cell. When effective K < 2 (e.g. n_val = 1), the code falls back to a single split.
- safety_floor
logical. When TRUE (default), post-training calibration searches a 3-way simplex of BM, GNN, and grand-mean candidates. Because the grand-mean corner is always in the grid, the selected candidate cannot be worse than that corner on the validation cells under the calibration metric. When FALSE, the v0.9.1 1-D calibration is used exactly (r_MEAN = 0).
- phylo_signal_gate
logical. When TRUE (default since v0.9.1.9003), compute per-trait Pagel's \(\lambda\) on training-observed cells before fitting; for traits with lambda < phylo_signal_threshold, force (r_cal_bm = 0, r_cal_gnn = 0, r_cal_mean = 1) directly and skip BM + GNN training on those traits. Requires the phytools package. Falls back to safety-floor-only behaviour (phylo_signal_gate = FALSE effective) when phytools is absent.
- phylo_signal_threshold
numeric, default 0.2. Traits with Pagel's \(\lambda\) below this value are routed to the grand-mean corner of the safety-floor simplex.
- phylo_signal_method
character, currently only "lambda" is fully implemented. The reserved "blomberg_k" path returns Blomberg's K via phytools::phylosig() but uses the same threshold, which is NOT dimensionally comparable; users selecting K must supply a K-appropriate threshold.
- min_val_cells
integer. Warn at fit time if any trait has fewer than min_val_cells validation cells available for gate calibration and conformal-score estimation. Default 20, the lower end of the recommended operational target of n_val >= 20-30 per trait. Below that, and especially below 10, the conformal quantile collapses to max(val_residuals) and gate calibration becomes essentially a coin flip between 0 and gate_cap. To reach the target, increase missing_frac or collect more species. See Calibration at small n below.
- verbose
logical. Print training progress (default TRUE).
- seed
integer. Random seed (default 1).
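The corruption curriculum (corruption_start, corruption_rate, corruption_ramp) and the learning-rate schedule (lr, warmup_epochs) described in the arguments can be sketched in base R as follows. This is an illustrative sketch, not pigauto's internal code: the function names corruption_at and lr_at are hypothetical, the warmup is assumed to start from zero, and the cosine decay is assumed to run from the end of warmup to the final epoch.

```r
# Hypothetical helpers illustrating the documented schedules.
corruption_at <- function(epoch, corruption_rate = 0.55, corruption_start = 0.2,
                          corruption_ramp = 500L) {
  if (corruption_ramp <= 0) return(corruption_rate)   # fixed corruption
  frac <- min(epoch / corruption_ramp, 1)             # linear ramp, then hold
  corruption_start + frac * (corruption_rate - corruption_start)
}

lr_at <- function(epoch, lr = 0.003, warmup_epochs = 200L, epochs = 3000L,
                  lr_min = 1e-5) {
  # Linear warmup (assumption: starts at 0 and reaches lr at warmup_epochs).
  if (epoch <= warmup_epochs) return(lr * epoch / warmup_epochs)
  # Cosine decay from lr down to lr_min over the remaining epochs (assumption).
  progress <- (epoch - warmup_epochs) / (epochs - warmup_epochs)
  lr_min + 0.5 * (lr - lr_min) * (1 + cos(pi * progress))
}

corruption_path <- sapply(c(0, 250, 500, 1000), corruption_at)
lr_path <- sapply(c(100, 200, 3000), lr_at)
```

With the defaults, the corruption fraction is halfway between 0.20 and 0.55 at epoch 250 and stays at 0.55 from epoch 500 onward.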
Details
Blend formulation: The prediction is \(\hat{x} = (1-r)\mu + r\delta\), where \(\mu\) is the phylogenetic baseline (BM for continuous, count, and ordinal traits), \(\delta\) is the model's direct prediction, and \(r = \sigma(\rho) \times \mathrm{gate\_cap}\) is a per-column learnable gate bounded in \((0, \mathrm{gate\_cap})\). As \(r \to 0\), the prediction collapses to the baseline. The shrinkage penalty keeps \(\delta\) close to \(\mu\) and the gate penalty pushes \(r\) toward zero, so the model defaults to the baseline unless the GNN's correction demonstrably helps on the validation set.
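A minimal base-R sketch of the blend, assuming mu and delta are n x p latent matrices and rho is one unconstrained gate parameter per column. The helper name blend_prediction and the example values are hypothetical; the actual computation lives in the internal torch module.

```r
sigmoid <- function(z) 1 / (1 + exp(-z))

# Hypothetical helper: per-column gated blend of baseline and GNN prediction.
blend_prediction <- function(mu, delta, rho, gate_cap = 0.8) {
  r <- sigmoid(rho) * gate_cap                 # per-column gate in (0, gate_cap)
  sweep(mu, 2, 1 - r, "*") + sweep(delta, 2, r, "*")
}

mu    <- matrix(c(1, 2, 3, 4), nrow = 2)       # made-up baseline, columns = traits
delta <- matrix(c(2, 2, 5, 4), nrow = 2)       # made-up GNN prediction
rho   <- c(-20, 0)                             # gate 1 almost closed, gate 2 at half the cap
x_hat <- blend_prediction(mu, delta, rho)
```

Column 1 stays essentially at the baseline because its gate is almost closed; column 2 moves 40% of the way from mu to delta.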
Training objective (per epoch):
A random subset of observed cells is corrupted with a learnable mask token.
The model predicts \(\delta\) from graph context.
Loss = type-specific reconstruction on corrupted cells + lambda_shrink * MSE(\(\delta - \mu\)) + lambda_gate * MSE(\(r\)).
The gate penalty on \(r\) is necessary because when \(\delta = \mu\) (the BM-optimal solution for observed cells), the reconstruction and shrinkage losses both equal zero regardless of \(r\), leaving no gradient to close the gate. The explicit penalty ensures gates default toward zero when the GNN correction provides no benefit.
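The argument for the gate penalty can be checked numerically. The sketch below is a toy for a single continuous trait with made-up values, not the package's training loop, and the function name total_loss is hypothetical. It sets delta equal to mu and evaluates the loss at several gate values, with and without the lambda_gate term.

```r
# Hypothetical helper: loss for one continuous trait at a fixed gate value r.
total_loss <- function(y, mu, delta, r, lambda_shrink = 0.03, lambda_gate = 0.01) {
  x_hat  <- (1 - r) * mu + r * delta
  recon  <- mean((x_hat - y)^2)     # MSE reconstruction
  shrink <- mean((delta - mu)^2)    # shrinkage toward the baseline
  gate   <- mean(r^2)               # gate penalty
  c(without_gate = recon + lambda_shrink * shrink,
    with_gate    = recon + lambda_shrink * shrink + lambda_gate * gate)
}

y     <- c(0.5, 1.5, 2.5)
mu    <- y          # baseline reproduces the observed cells
delta <- mu         # GNN adds no correction
losses <- sapply(c(0.1, 0.4, 0.8), function(r) total_loss(y, mu, delta, r))
```

The without_gate row is the same for every r, so nothing pushes the gate down; the with_gate row grows with r, which is the gradient that closes the gate.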
Type-specific losses:
- continuous/count/ordinal
MSE
- binary
BCE with logits
- categorical
cross-entropy over K latent columns
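For reference, the three loss families can be written in a few lines of base R. These are textbook definitions offered as a sketch; the function names are hypothetical and pigauto's torch implementation may differ in reduction and masking details.

```r
# Continuous / count / ordinal: mean squared error.
mse_loss <- function(pred, y) mean((pred - y)^2)

# Binary: cross-entropy on logits, in the numerically stable form
# max(z, 0) - z * y + log(1 + exp(-|z|)).
bce_with_logits <- function(logit, y) {
  mean(pmax(logit, 0) - logit * y + log1p(exp(-abs(logit))))
}

# Categorical: cross-entropy over K latent columns.
# logits is an n x K matrix; label holds integer classes in 1..K.
cross_entropy <- function(logits, label) {
  m <- apply(logits, 1, max)
  log_z <- m + log(rowSums(exp(logits - m)))   # row-wise log-sum-exp
  mean(log_z - logits[cbind(seq_along(label), label)])
}

l_mse <- mse_loss(c(1, 2), c(1, 4))
l_bce <- bce_with_logits(0, 1)
l_ce  <- cross_entropy(matrix(0, nrow = 2, ncol = 3), c(1L, 3L))
```

With uninformative logits of zero, the binary loss equals log(2) and the three-class loss equals log(3), which is a quick sanity check on both formulas.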
Calibration at small n
pigauto's 95% interval half-width for each trait is the empirical (1 - alpha)
quantile of \(|y - \hat y|\) on held-out validation cells. When the
number of validation cells per trait (n_val) is large
(\(\gtrsim 30\)), split-conformal gives near-exact marginal coverage
under mild exchangeability assumptions.
At small n_val (\(<\) 20, and especially \(<\) 10) two
things degrade at once:
- The conformal quantile clamps to max(residuals) because the required quantile level exceeds 1. The score is the single largest val residual and has substantial sampling variance across fits.
- The gate calibration's half-A / half-B split ends up with only a few cells per half, so the grid-search winner is essentially random between 0 and gate_cap.
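The clamping in the first point follows from the usual finite-sample split-conformal rank. The sketch below assumes the standard rule ceiling((n + 1) * (1 - alpha)); whether pigauto uses exactly this rank is an assumption, and conformal_halfwidth is a hypothetical name.

```r
# Hypothetical helper: split-conformal half-width from absolute val residuals.
conformal_halfwidth <- function(abs_resid, alpha = 0.05) {
  n <- length(abs_resid)
  k <- ceiling((n + 1) * (1 - alpha))    # finite-sample-corrected rank
  if (k > n) return(max(abs_resid))      # quantile level k / n exceeds 1: clamp
  sort(abs_resid)[k]
}

hw_small <- conformal_halfwidth(1:10)    # n_val = 10: rank 11 > 10, clamps to the max
hw_large <- conformal_halfwidth(1:40)    # n_val = 40: rank 39 is available
```

At n_val = 10 the half-width is just the largest residual, so one unusual validation cell sets the interval width for the whole trait.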
Empirically (pigauto simulation harness, n=150 species, 35%
trait_MAR, 10 random seeds), default behaviour produces 92% mean
coverage with per-fit coverage ranging [0.73, 1.00]. The mean
is close to the 95% nominal level, but individual fits vary widely.
gate_method = "median_splits" and conformal_method =
"bootstrap" reduce their respective estimator variances but
empirically do not meaningfully narrow the coverage distribution.
Treat the 95% intervals as approximate in this
regime; increase missing_frac (more held-out data for
calibration) or collect more species if tight per-fit coverage is
required.
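As a rough illustration of what conformal_method = "bootstrap" averages, the sketch below resamples the absolute validation residuals with replacement, takes the conformal quantile of each resample, and returns the mean. The rank rule, the clamping, the residual values, and the function name bootstrap_conformal_score are assumptions for illustration, not the package's code.

```r
# Hypothetical helper: bootstrap-averaged conformal score.
bootstrap_conformal_score <- function(abs_resid, alpha = 0.05, B = 500L, seed = 1L) {
  set.seed(seed)
  n <- length(abs_resid)
  k <- min(ceiling((n + 1) * (1 - alpha)), n)   # clamp the rank at n
  per_resample <- vapply(seq_len(B), function(b) {
    idx <- sample.int(n, n, replace = TRUE)     # one bootstrap resample
    sort(abs_resid[idx])[k]                     # its conformal quantile
  }, numeric(1))
  mean(per_resample)                            # average across resamples
}

abs_resid <- abs(c(0.3, -1.2, 0.8, -0.1, 2.5, 0.6, -0.9, 1.7, 0.2, -0.4, 1.1, 0.5))
score <- bootstrap_conformal_score(abs_resid)
```

Averaging smooths the dependence on the single largest residual, but in the clamped regime every resample quantile is at most the sample maximum, so the averaged score cannot exceed the split score.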
The min_val_cells warning fires at fit time whenever any trait
falls below this threshold, so you know when to interpret the
intervals cautiously.
