
Trains a pigauto model: a gated ensemble of a phylogenetic baseline and an attention-based graph neural network correction, implemented as an internal torch module (ResidualPhyloDAE; "Residual" here refers to the ResNet-style skip connections in the GNN layers, not to a statistical residual). For continuous, count, and ordinal traits the baseline is Brownian motion (phylogenetic correlation matrix); for binary and categorical traits it is phylogenetic label propagation. Supports all five trait types via a unified latent space.

Usage

fit_pigauto(
  data,
  tree,
  splits = NULL,
  graph = NULL,
  baseline = NULL,
  hidden_dim = 64L,
  k_eigen = "auto",
  n_gnn_layers = 2L,
  gate_cap = 0.8,
  use_attention = TRUE,
  use_transformer_blocks = TRUE,
  n_heads = 4L,
  ffn_mult = 4L,
  use_trait_attention = FALSE,
  n_trait_heads = 2L,
  trait_embed_dim = 32L,
  dropout = 0.1,
  lr = 0.003,
  weight_decay = 1e-04,
  epochs = 3000L,
  corruption_rate = 0.55,
  corruption_start = 0.2,
  corruption_ramp = 500L,
  refine_steps = 8L,
  lambda_shrink = 0.03,
  lambda_gate = 0.01,
  warmup_epochs = 200L,
  edge_dropout = 0.1,
  eval_every = 100L,
  patience = 10L,
  clip_norm = 1,
  conformal_method = c("split", "bootstrap"),
  conformal_bootstrap_B = 500L,
  conformal_split_val = FALSE,
  gate_method = c("cv_folds", "median_splits", "single_split"),
  gate_splits_B = 31L,
  gate_cv_folds = 5L,
  safety_floor = TRUE,
  phylo_signal_gate = TRUE,
  phylo_signal_threshold = 0.2,
  phylo_signal_method = c("lambda", "blomberg_k"),
  min_val_cells = 20L,
  verbose = TRUE,
  seed = 1L
)

Arguments

data

object of class "pigauto_data".

tree

object of class "phylo".

splits

list (output of make_missing_splits) or NULL.

graph

list (output of build_phylo_graph) or NULL.

baseline

list (output of fit_baseline) or NULL.

hidden_dim

integer. Hidden layer width (default 64).

k_eigen

integer or "auto". Number of spectral node features (default "auto"; pass an integer such as 8 to fix it).

n_gnn_layers

integer. Number of graph message-passing layers (default 2). Each layer has its own learnable alpha gate, layer normalisation, and ResNet-style skip connection.

gate_cap

numeric. Upper bound for the per-column blend gate (default 0.8). Safety comes from regularisation, not the cap.

use_attention

logical. Use attention in the GNN layers (default TRUE).

use_transformer_blocks

logical. Replace the legacy attention stack with pre-norm transformer-encoder blocks (multi-head attention + FFN + two residual skips). Default TRUE. Set FALSE to reproduce pre-v0.9.0 fits (single-head attention with a learnable alpha gate per layer).

n_heads

integer. Number of attention heads when use_transformer_blocks = TRUE (default 4). Each head learns its own phylogenetic bandwidth (B2 rate-aware attention).

ffn_mult

integer. Feed-forward width multiplier inside each transformer block (default 4, giving an inner width of hidden_dim * ffn_mult, i.e. 256 at the default hidden_dim = 64).

use_trait_attention

logical. Opt-in within-row cross-trait self-attention (B3, v0.9.3). When TRUE (default FALSE), the model builds per-trait tokens from each row's latent values (linear projection + learnable positional embedding), applies one multi-head self-attention block over the trait sequence, mean-pools to a trait_embed_dim feature, and concatenates it alongside (x, coords, covs) at the encoder input. Intended for trait sets with strong within-row functional coupling that the joint MVN / threshold-joint baseline cannot capture (e.g. nonlinear cross-trait structure). On the BIEN n=2000 plant bench it did not improve pooled RMSE (the cross-trait covariance Σ is already captured by the joint baseline); it is kept as an opt-in for datasets where it may help. The default FALSE preserves v0.9.2 behaviour exactly.

n_trait_heads

integer. Number of attention heads in the within-row self-attention block when use_trait_attention = TRUE. Default 2. Ignored when use_trait_attention = FALSE.

trait_embed_dim

integer. Embedding dimension per trait token in the within-row self-attention block. Default 32. Ignored when use_trait_attention = FALSE.

dropout

numeric. Dropout rate (default 0.10).

lr

numeric. AdamW learning rate (default 0.003).

weight_decay

numeric. AdamW weight decay (default 1e-4).

epochs

integer. Maximum training epochs (default 3000).

corruption_rate

numeric. Final corruption fraction if corruption_ramp > 0; otherwise the fixed corruption rate per epoch (default 0.55).

corruption_start

numeric. Initial corruption fraction for the curriculum schedule (default 0.20). Ignored if corruption_ramp = 0.

corruption_ramp

integer. Epochs over which corruption linearly ramps from corruption_start to corruption_rate (default 500). Set to 0 for fixed corruption.
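The curriculum defined by corruption_rate, corruption_start and corruption_ramp can be sketched as below. This is a minimal illustration, assuming epochs are counted from 0 and that the corrupted cells are drawn uniformly without replacement from the observed cells; corruption_at_epoch() and pick_corrupted() are hypothetical helpers, not pigauto functions.

```r
# Hypothetical helpers sketching the corruption curriculum; not pigauto API.
# Assumes epochs are counted from 0 and the ramp is linear.
corruption_at_epoch <- function(epoch, corruption_rate = 0.55,
                                corruption_start = 0.20,
                                corruption_ramp = 500L) {
  if (corruption_ramp <= 0) return(corruption_rate)  # fixed corruption
  progress <- min(max(epoch / corruption_ramp, 0), 1) # fraction of ramp done
  corruption_start + progress * (corruption_rate - corruption_start)
}

# Sample which observed cells to mask this epoch (uniform, no replacement).
pick_corrupted <- function(observed, rate) {
  idx <- which(observed)
  n_corrupt <- round(rate * length(idx))
  idx[sample.int(length(idx), n_corrupt)]
}

sapply(c(0, 250, 500, 1000), corruption_at_epoch)
# approximately 0.200 0.375 0.550 0.550
```

After the ramp ends the fraction stays at corruption_rate; with corruption_ramp = 0 it is corruption_rate from the first epoch.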

refine_steps

integer. Number of iterative refinement steps at inference (default 8).

lambda_shrink

numeric. Weight on the shrinkage penalty ||delta - baseline||^2 that keeps the GNN correction close to the phylogenetic baseline (default 0.03).

lambda_gate

numeric. Weight on the gate regularisation penalty that pushes learnable gates toward zero. Prevents gates from staying open when the GNN provides no useful correction (default 0.01).

warmup_epochs

integer. Linear learning-rate warmup over the first N epochs (default 200). After warmup, a cosine schedule decays the LR to 1e-5.
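A sketch of the warmup-then-cosine schedule, assuming 0-indexed epochs, a warmup that reaches lr at the last warmup epoch, and a standard half-cosine from lr down to 1e-5 over the remaining epochs. lr_at_epoch() is hypothetical; pigauto's exact warmup starting value and step granularity may differ.

```r
# Hypothetical helper: learning rate at a 0-indexed epoch under linear
# warmup followed by cosine decay to lr_min (an assumed formulation).
lr_at_epoch <- function(epoch, lr = 0.003, warmup_epochs = 200L,
                        epochs = 3000L, lr_min = 1e-5) {
  if (epoch < warmup_epochs) {
    return(lr * (epoch + 1) / warmup_epochs)          # ramps up to lr
  }
  span <- max(1, epochs - warmup_epochs)
  progress <- min((epoch - warmup_epochs) / span, 1)  # 0 -> 1 after warmup
  lr_min + 0.5 * (lr - lr_min) * (1 + cos(pi * progress))
}

# Full schedule for the default 3000 epochs.
schedule <- vapply(0:2999, lr_at_epoch, numeric(1))
```

Halfway through the cosine phase the rate is the midpoint (lr + 1e-5) / 2; early stopping (patience) can end training before the schedule reaches its floor.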

edge_dropout

numeric. Fraction of adjacency edges randomly zeroed each training epoch for graph regularisation (default 0.1). Set to 0 to disable.
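Edge dropout can be pictured as follows. This sketch assumes a symmetric adjacency matrix, drops each undirected edge independently with probability p, and does not rescale the surviving weights; drop_edges() is a hypothetical helper and pigauto's internal torch version may differ in these details.

```r
# Hypothetical sketch: zero a random fraction p of the undirected edges of a
# symmetric adjacency matrix A, keeping the result symmetric. Surviving
# weights are not rescaled (an assumption).
drop_edges <- function(A, p = 0.1) {
  if (p <= 0) return(A)
  upper <- which(upper.tri(A) & A != 0)   # each undirected edge counted once
  dropped <- upper[runif(length(upper)) < p]
  A[dropped] <- 0
  A[lower.tri(A)] <- t(A)[lower.tri(A)]   # mirror upper triangle onto lower
  A
}
```

Sampling on the upper triangle and mirroring it keeps an edge and its reverse either both present or both removed in a given epoch.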

eval_every

integer. Evaluate on the validation set every N epochs (default 100).

patience

integer. Early-stopping patience, in evaluation cycles (default 10). With eval_every = 100 this allows 1000 epochs without validation improvement.

clip_norm

numeric. Gradient clip norm (default 1.0).

conformal_method

character. How the conformal prediction score is estimated from validation residuals. "split" (default, backward-compatible) takes a single sample quantile; "bootstrap" draws conformal_bootstrap_B bootstrap resamples of the validation residuals and averages the per-resample quantiles. Empirically the bootstrap variant reduces the conformal-score variance across seeds by ~30% at small n_val, but on the simulation design used to evaluate it (n=150 species, 35% trait_MAR) this did not translate into better 95% coverage across 10 seeds. Shipped as an opt-in experimental option; defaults to "split".
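The two options can be sketched in base R. The sketch assumes the score is the finite-sample split-conformal quantile of the absolute residuals (rank ceiling((n + 1) * (1 - alpha)), clamped to the largest residual when that rank exceeds n); conformal_quantile() and conformal_score() are hypothetical names, not pigauto functions.

```r
# Hypothetical sketch of the two conformal_method options. Assumes the score
# is the finite-sample split-conformal quantile of |residual|, clamped to the
# largest residual when the required rank exceeds n_val.
conformal_quantile <- function(res, alpha = 0.05) {
  a <- sort(abs(res))
  k <- ceiling((length(a) + 1) * (1 - alpha))
  a[min(k, length(a))]
}

conformal_score <- function(res, method = c("split", "bootstrap"),
                            B = 500L, alpha = 0.05) {
  method <- match.arg(method)
  if (method == "split") return(conformal_quantile(res, alpha))
  # Bootstrap: resample the validation residuals B times, average quantiles.
  q <- vapply(seq_len(B), function(b) {
    conformal_quantile(res[sample.int(length(res), replace = TRUE)], alpha)
  }, numeric(1))
  mean(q)
}
```

Averaging over resamples smooths the score, but at small n_val every resample is still limited to the same few residual values, which is consistent with the variance reduction not improving coverage.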

conformal_bootstrap_B

integer. Bootstrap resamples used when conformal_method = "bootstrap"; default 500. Ignored otherwise.

conformal_split_val

logical. Default FALSE (pre-2026-04-28 single-set behaviour, retained because forcing the split regresses the AVONET300 / OVR-categorical / BIEN safety-floor smoke benches by 2-26%). When TRUE, the validation set is split per latent column into a calibration half (used to pick the calibrated blend gate) and a conformal half (used to compute the conformal residual quantile). This restores split-conformal exchangeability — without the split, the gate is selected to minimise residual MSE on the very cells whose residuals drive the conformal quantile, producing systematic undercoverage (most visible at small n_val; see the coverage_investigation memo). Use this when accurate 95% coverage matters more than the bench-grade RMSE; the split only activates per-column when that column has at least 2 * min_val_cells val cells (smaller columns silently fall back to the single-set path to keep calibrate_gates()'s half-A / half-B cross-check stable).
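The per-column split can be sketched as below, assuming the halves are drawn at random and that a column below 2 * min_val_cells simply reuses all its validation cells for both roles. split_val_cells() is a hypothetical helper, not part of pigauto.

```r
# Hypothetical sketch of conformal_split_val = TRUE for one latent column:
# split that column's validation cells into a calibration half (gate choice)
# and a conformal half (residual quantile). Columns with fewer than
# 2 * min_val_cells cells fall back to using all cells for both roles.
split_val_cells <- function(val_idx, min_val_cells = 20L) {
  n <- length(val_idx)
  if (n < 2 * min_val_cells) {
    return(list(calib = val_idx, conformal = val_idx, split = FALSE))
  }
  perm <- val_idx[sample.int(n)]          # random order (assumption)
  half <- floor(n / 2)
  list(calib = perm[seq_len(half)], conformal = perm[(half + 1):n],
       split = TRUE)
}
```

Because the two halves are disjoint, the cells that drive the conformal quantile were never used to pick the gate, which is what restores exchangeability.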

gate_method

character. How the per-trait calibrated gate is chosen. "cv_folds" (the default, listed first in the usage signature; added 2026-04-30) partitions the validation cells into gate_cv_folds (default 5) deterministic, non-overlapping folds and runs the grid search + half-B verification once per fold (training set = K-1 folds, held-out = the remaining fold), taking the componentwise median of the K winning weight vectors. It uses larger training sets per split ((K-1)/K vs 1/2 for "median_splits") and has a standard cross-validation interpretation; it was motivated by the open val→test drift observed on 4/32 binary cells in the discrete-bench memo. "median_splits" repeats the single-split procedure for gate_splits_B random splits and takes the median best_g. It slightly reduces gate bimodality at small n_val (SD 0.406 → 0.360 across 10 seeds on the evaluation sim) with a small coverage-SD improvement (0.094 → 0.086), at negligible runtime cost (B cheap grid searches). "single_split" (backward-compatible) runs the grid search on a single random half-A / half-B split of the validation rows.

gate_splits_B

integer. Random splits used when gate_method = "median_splits"; default 31 (odd so the median is well-defined).

gate_cv_folds

integer. Number of CV folds when gate_method = "cv_folds"; default 5, must be >= 2. Capped at n_val per trait so each fold has at least 1 cell. When effective K < 2 (e.g. n_val = 1), the code falls back to a single split.
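A simplified sketch of the "cv_folds" and "median_splits" gate choices for a single column. It assumes a 1-D gate grid on [0, gate_cap] scored by MSE and omits pigauto's half-B verification step and the 3-way simplex; best_gate(), gate_cv_folds() and gate_median_splits() are hypothetical helpers.

```r
# Hypothetical 1-D grid search for the blend gate g in [0, gate_cap].
best_gate <- function(y, mu, delta, gate_cap = 0.8, n_grid = 41L) {
  grid <- seq(0, gate_cap, length.out = n_grid)
  mse <- vapply(grid, function(g) mean((y - ((1 - g) * mu + g * delta))^2),
                numeric(1))
  grid[which.min(mse)]                 # ties resolve to the smaller gate
}

# Simplified "cv_folds": fit the gate on K-1 folds, take the median over folds.
gate_cv_folds <- function(y, mu, delta, K = 5L, gate_cap = 0.8) {
  n_val <- length(y)
  K <- min(K, n_val)                   # cap so each fold has >= 1 cell
  if (K < 2) {
    return(best_gate(y, mu, delta, gate_cap))  # stand-in for single-split fallback
  }
  fold <- rep_len(seq_len(K), n_val)   # deterministic, non-overlapping folds
  g <- vapply(seq_len(K), function(k) {
    train <- fold != k                 # held-out fold k is not used here
    best_gate(y[train], mu[train], delta[train], gate_cap)
  }, numeric(1))
  median(g)
}

# Simplified "median_splits": median over B random half-splits.
gate_median_splits <- function(y, mu, delta, B = 31L, gate_cap = 0.8) {
  g <- vapply(seq_len(B), function(b) {
    half <- sample.int(length(y), floor(length(y) / 2))
    best_gate(y[half], mu[half], delta[half], gate_cap)
  }, numeric(1))
  median(g)
}
```

Each cv_folds training set holds (K-1)/K of the cells instead of half, so the per-fold winner is less noisy than a half-split winner at the same n_val.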

safety_floor

logical. When TRUE (default), post-training calibration searches a 3-way simplex of BM, GNN, and grand-mean candidates. Because the grand-mean corner is always in the grid, the selected candidate cannot be worse than that corner on the validation cells under the calibration metric. When FALSE, the v0.9.1 1-D calibration is used exactly (r_MEAN = 0).
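The safety-floor guarantee follows from the grid containing the grand-mean corner. The sketch below assumes a simplex grid with step 0.1 and MSE as the calibration metric; simplex_grid() and safety_floor_weights() are hypothetical helpers, and pigauto's actual grid and metric may differ.

```r
# Hypothetical 3-way simplex grid over (BM, GNN, grand-mean) weights.
simplex_grid <- function(step = 0.1) {
  s <- seq(0, 1, by = step)
  g <- expand.grid(bm = s, gnn = s)
  g <- g[g$bm + g$gnn <= 1 + 1e-9, ]     # stay inside the simplex
  g$mean <- pmax(0, 1 - g$bm - g$gnn)    # remaining weight on the grand mean
  g
}

# Pick the weight vector with the lowest validation MSE (assumed metric).
safety_floor_weights <- function(y, mu, delta, grand_mean, step = 0.1) {
  W <- simplex_grid(step)
  mse <- apply(W, 1, function(w) {
    pred <- w[["bm"]] * mu + w[["gnn"]] * delta + w[["mean"]] * grand_mean
    mean((y - pred)^2)
  })
  W[which.min(mse), ]
}
```

Since (bm = 0, gnn = 0, mean = 1) is one of the candidates, the minimum over the grid can never exceed the grand-mean MSE on the same validation cells.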

phylo_signal_gate

logical. When TRUE (default since v0.9.1.9003), compute per-trait Pagel's \(\lambda\) on training-observed cells before fitting; for traits with \(\lambda\) < phylo_signal_threshold, force (r_cal_bm = 0, r_cal_gnn = 0, r_cal_mean = 1) directly and skip BM + GNN training on those traits. Requires the phytools package; when phytools is absent, the fit falls back to safety-floor-only behaviour (effectively phylo_signal_gate = FALSE).

phylo_signal_threshold

numeric, default 0.2. Traits with Pagel's \(\lambda\) below this value are routed to the grand-mean corner of the safety-floor simplex.

phylo_signal_method

character. Currently only "lambda" is fully implemented. The reserved "blomberg_k" path returns Blomberg's K via phytools::phylosig() but applies the same phylo_signal_threshold, which is NOT dimensionally comparable to \(\lambda\); users selecting K must supply a K-appropriate threshold.

min_val_cells

integer. Warn at fit time if any trait has fewer than min_val_cells validation cells available for gate calibration and conformal-score estimation. Default 20. Below 20 cells the 95% conformal quantile collapses to max(val_residuals), and below about 10 gate calibration becomes essentially a coin flip between 0 and gate_cap. The recommended operational target is n_val >= 20-30 per trait; achieve this by increasing missing_frac or collecting more species. See Calibration at small n below.

verbose

logical. Print training progress (default TRUE).

seed

integer. Random seed (default 1).

Value

An object of class "pigauto_fit".

Details

Blend formulation: The prediction is \(\hat{x} = (1-r)\mu + r\delta\), where \(\mu\) is the phylogenetic baseline (BM, or label propagation for binary and categorical traits), \(\delta\) is the model's direct prediction, and \(r = \sigma(\rho) \times \mathrm{gate\_cap}\) is a per-column learnable gate bounded in \((0, \mathrm{gate\_cap})\). When \(r = 0\), the prediction collapses to the baseline. The shrinkage penalty (lambda_shrink) pulls \(\delta\) toward \(\mu\) and the gate penalty (lambda_gate) pushes \(r\) toward zero, so the model defaults to the baseline unless the GNN's correction demonstrably helps on the validation set.
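The blend formula maps directly onto a few lines of R. This is an illustration only, assuming a species × latent-column matrix layout with one gate logit rho per column; blend_prediction() is a hypothetical helper, not the internal torch module.

```r
# Hypothetical sketch of x_hat = (1 - r) * mu + r * delta with a per-column
# gate r = sigmoid(rho) * gate_cap. mu and delta are n x p matrices,
# rho has length p (one logit per latent column).
blend_prediction <- function(mu, delta, rho, gate_cap = 0.8) {
  r <- plogis(rho) * gate_cap                        # in (0, gate_cap)
  R <- matrix(r, nrow(mu), ncol(mu), byrow = TRUE)   # repeat gate down rows
  (1 - R) * mu + R * delta
}

mu <- matrix(c(1, 1, 2, 2), 2, 2)
delta <- matrix(c(3, 3, 4, 4), 2, 2)
blend_prediction(mu, delta, rho = c(0, -Inf))
# column 1 uses r = 0.4 (value 1.8); column 2 has r = 0 and keeps mu (value 2)
```

Building R with byrow = TRUE avoids R's column-major recycling, which would otherwise apply the gates row by row instead of column by column.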

Training objective (per epoch):

  1. A random subset of observed cells (its size set by corruption_start, corruption_rate, and corruption_ramp) is replaced by a learnable mask token.

  2. The model predicts \(\delta\) from graph context.

  3. Loss = type-specific reconstruction on corrupted cells + lambda_shrink * MSE(\(\delta - \mu\)) + lambda_gate * MSE(\(r\)).

The gate penalty on \(r\) is necessary because when \(\delta = \mu\) (the BM-optimal solution for observed cells), the prediction equals \(\mu\) for every \(r\): the shrinkage loss is zero and the reconstruction loss no longer depends on \(r\), leaving no gradient to close the gate. The explicit penalty ensures gates default toward zero when the GNN correction provides no benefit.
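The argument can be checked numerically. This sketch assumes a single continuous column, a scalar gate r, and MSE reconstruction on the corrupted cells; objective() is a hypothetical helper mirroring the three loss terms listed above.

```r
# Hypothetical sketch of the per-epoch objective for one continuous column.
# r is a scalar gate; corrupted is a logical mask of the masked cells.
objective <- function(r, y, mu, delta, corrupted,
                      lambda_shrink = 0.03, lambda_gate = 0.01) {
  x_hat <- (1 - r) * mu + r * delta                    # blended prediction
  recon <- mean((y[corrupted] - x_hat[corrupted])^2)   # MSE on masked cells
  shrink <- mean((delta - mu)^2)                       # keep delta near mu
  gate <- mean(r^2)                                    # push r toward zero
  recon + lambda_shrink * shrink + lambda_gate * gate
}

y <- c(1.0, 2.0, 3.0); mu <- c(1.5, 2.0, 2.5); mask <- c(TRUE, TRUE, TRUE)
# With delta == mu and no gate penalty, the objective is flat in r:
objective(0.1, y, mu, mu, mask, lambda_gate = 0)
objective(0.7, y, mu, mu, mask, lambda_gate = 0)
```

Both calls return about 0.1667. Turning lambda_gate back on makes the larger gate strictly more expensive, which supplies the missing gradient.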

Type-specific losses:

  • continuous, count, ordinal: MSE

  • binary: BCE with logits

  • categorical: cross-entropy over the K latent columns
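The three loss families can be written in base R for reference. This is a sketch under the usual textbook definitions (stable log-sum-exp forms), assuming one logit per binary cell and an n × K logit matrix per categorical trait; the function names are hypothetical and do not correspond to pigauto internals.

```r
# Continuous / count / ordinal latent columns: mean squared error.
mse_loss <- function(y, pred) mean((y - pred)^2)

# Binary: numerically stable BCE with logits,
# max(z, 0) - z * y + log(1 + exp(-|z|)).
bce_with_logits <- function(y, logit) {
  mean(pmax(logit, 0) - logit * y + log1p(exp(-abs(logit))))
}

# Categorical: cross-entropy over K latent columns (one logit per category).
# logits is an n x K matrix; y_class holds integer labels in 1..K.
cross_entropy <- function(y_class, logits) {
  m <- apply(logits, 1, max)                          # row-wise max
  log_sum_exp <- m + log(rowSums(exp(logits - m)))    # stable normaliser
  picked <- logits[cbind(seq_len(nrow(logits)), y_class)]
  mean(log_sum_exp - picked)
}
```

Subtracting the row maximum before exponentiating keeps exp() from overflowing for large logits without changing the result.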

Calibration at small n

pigauto's 95% interval half-width for each trait is the empirical quantile of \(|y - \hat y|\) on held-out validation cells at the finite-sample split-conformal level \((1-\alpha)(1 + 1/n_{val})\), slightly above \(1-\alpha\). When the number of validation cells per trait (n_val) is large (\(\gtrsim 30\)), split-conformal gives near-exact marginal coverage under mild exchangeability assumptions.

At small n_val (\(<\) 20, and especially \(<\) 10) two things degrade at once:

  • The conformal quantile clamps to max(residuals) because the required quantile level exceeds 1. The score is the single largest val residual and has substantial sampling variance across fits.

  • The gate calibration's half-A / half-B split ends up with only a few cells per half, so the grid-search winner is essentially random between 0 and gate_cap.
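Where the clamp starts can be tabulated directly. This sketch assumes the standard split-conformal rank ceiling((n_val + 1) * (1 - alpha)) among the sorted absolute residuals; conformal_rank() and clamps_to_max() are hypothetical helpers.

```r
# Rank of the conformal score among the sorted |residuals| (assumed rule).
conformal_rank <- function(n_val, alpha = 0.05) {
  ceiling((n_val + 1) * (1 - alpha))
}

# TRUE when the required quantile level rank / n_val exceeds 1 (clamp to max).
clamps_to_max <- function(n_val, alpha = 0.05) {
  conformal_rank(n_val, alpha) > n_val
}

n_val <- c(5, 10, 18, 30, 100)
data.frame(n_val,
           rank = conformal_rank(n_val),
           level = conformal_rank(n_val) / n_val,
           clamped = clamps_to_max(n_val))
```

Under this rule, at alpha = 0.05 the level exceeds 1 only for n_val <= 18. For n_val up to 38 the rank still equals n_val, so the score is the largest residual, and from n_val = 39 onward it is an interior order statistic. This is consistent with the n_val >= 20-30 recommendation.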

Empirically (pigauto simulation harness, n=150 species, 35% trait_MAR, 10 random seeds), default behaviour produces 92% mean coverage, with per-fit coverage ranging over [0.73, 1.00]. The mean is close to the 95% target, but individual fits vary widely. Setting gate_method = "median_splits" or conformal_method = "bootstrap" reduces the respective estimator variance but empirically does not meaningfully narrow the coverage distribution. Treat the 95% level as an average over fits rather than a per-fit guarantee in the small-n_val regime; increase missing_frac (more held-out data for calibration) or collect more species if tight per-fit coverage is required.

The min_val_cells warning fires at fit time whenever any trait falls below this threshold, so you know when to interpret the intervals cautiously.