
This article opens up pigauto’s engine room. It is intended for readers who want to know exactly what the package does between impute(traits, tree, covariates) and the returned pigauto_result. We cover the data flow, the gated ensemble formula, the phylogenetic baseline (joint MVN / threshold-joint / OVR), the GNN’s transformer blocks, the optional within-row cross-trait attention added in v0.9.3, and the calibration / uncertainty quantification machinery on top.

We assume familiarity with maximum-likelihood Brownian motion (BM) on trees and with standard transformer attention. We do not assume familiarity with the package’s internals.

1. End-to-end data flow

impute() chains six internal stages. Tensor shapes use $n$ for observations (= species in single-obs mode, > species when each species has multiple measurements), $m$ for species (tree tips), $p$ for the latent trait dimension (which expands each categorical trait to $K$ one-hot columns), $k$ for spectral Laplacian features, $c$ for covariate dimensions, and $T$ for the posterior tree count in tree-uncertainty pooling.

```text
INPUT:
  traits      data.frame   n × p_traits     (mixed types, NAs allowed)
  tree        phylo        m tips
  covariates  data.frame   n × c            (optional)

STEP 1: preprocess_traits()
  X_scaled       n × p           per-type encoding + z-score
  trait_map      list            descriptors per trait (type, levels, mean, sd)
  obs_to_species n               only if n > m (multi-obs mode)

STEP 2: build_phylo_graph()
  coords  m × k                  Laplacian eigenvectors
  adj     m × m                  Gaussian kernel on cophenetic distance
  D_sq    m × m                  squared cophenetic distance (B2)

STEP 3: fit_baseline()  [FIXED, NOT TRAINED]
  mu  m × p                      phylogenetic baseline (BM / LP / joint)
  se  m × p                      conditional-MVN standard errors

STEP 4: fit_pigauto() trains ResidualPhyloDAE on a DAE objective
  produces a torch nn_module + per-column gate

STEP 5: calibrate_gates() picks per-trait r_cal on held-out val cells

STEP 6: predict.pigauto_fit() applies the blend:
  pred = (1 - r_cal) · mu + r_cal · delta_GNN   (+ cov_linear(u) when covariates are supplied)
```

The two essential ideas are:

  1. A simple, well-understood phylogenetic baseline does most of the work, and the GNN only contributes when calibrated to do so.
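  2. The blend gate $r_\text{cal}$ is a safety floor: at $r_\text{cal} = 0$ the prediction collapses to the baseline (plus the linear covariate term of §3.3, when covariates are supplied). Because calibration is always free to choose $r_\text{cal} = 0$, the worst case on held-out cells is roughly “as good as the baseline”, however poorly the GNN trains.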

2. The phylogenetic baseline

The baseline is a per-trait-type dispatcher inside fit_baseline(). For each trait you get $\mu_i$ (the baseline’s conditional mean for missing cell $i$) and $\sigma_i$ (its conditional SD). All later stages condition on these.

2.1 Continuous, count, ordinal, proportion: Brownian motion

For one continuous trait $y$ with $y_O$ observed and $y_M$ missing, $y \sim \mathcal{N}(\beta \mathbf{1},\, \sigma^2 \mathbf{R})$ under BM on the species tree, where $\mathbf{R} = \mathrm{cov2cor}(\mathrm{vcv}(\text{tree}))$ is the phylogenetic correlation matrix. The conditional posterior is the closed-form GLS solution:

$$
\hat{\mu}_M = \beta \mathbf{1}_M + \mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} (y_O - \beta \mathbf{1}_O), \qquad \hat{\sigma}^2_{M \mid O} = \sigma^2 \bigl(1 - \mathrm{diag}(\mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} \mathbf{R}_{OM})\bigr).
$$

R/bm_internal.R::bm_impute_col() implements this directly. Count traits are log1p-transformed, ordinal traits are coerced to integers and z-scored, proportions are logit-z-scored. The same conditional formula is used in each transformed space.
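A minimal base-R sketch of the conditional formula above for a single trait may help. It assumes the correlation matrix $\mathbf{R}$ is already available (in practice it would come from cov2cor(ape::vcv(tree))). It estimates $\beta$ by GLS and $\sigma^2$ by maximum likelihood on the observed tips; both estimator choices are assumptions. bm_condition() is a hypothetical helper for illustration, not pigauto’s bm_impute_col().

```r
# Conditional-MVN imputation of one BM trait (illustrative sketch, not pigauto code).
# y: named numeric vector with NAs; R: phylogenetic correlation matrix.
bm_condition <- function(y, R) {
  obs <- !is.na(y)
  R_OO_inv <- solve(R[obs, obs, drop = FALSE])
  R_MO <- R[!obs, obs, drop = FALSE]
  one_O <- rep(1, sum(obs))

  # GLS estimate of the root mean beta from the observed tips
  beta <- drop(crossprod(one_O, R_OO_inv %*% y[obs]) /
               crossprod(one_O, R_OO_inv %*% one_O))
  resid <- y[obs] - beta
  # ML estimate of the BM variance sigma^2 (divides by the number observed)
  sigma2 <- drop(crossprod(resid, R_OO_inv %*% resid)) / sum(obs)

  K <- R_MO %*% R_OO_inv                       # R_MO R_OO^{-1}
  mu_M <- beta + drop(K %*% resid)             # conditional mean
  var_M <- sigma2 * (1 - rowSums(K * R_MO))    # 1 - diag(R_MO R_OO^{-1} R_OM)
  list(beta = beta, sigma2 = sigma2, mu = mu_M, se = sqrt(var_M))
}

# Four-tip tree ((A:1,B:1):1,(C:1,D:1):1): sister tips share half their path.
R <- matrix(c(1, .5, 0, 0,
              .5, 1, 0, 0,
              0, 0, 1, .5,
              0, 0, .5, 1), 4, 4, dimnames = list(LETTERS[1:4], LETTERS[1:4]))
y <- c(A = 1, B = NA, C = -1, D = -0.5)
fit <- bm_condition(y, R)
fit$mu   # B takes half of its sister A's deviation from beta (beta = 0 here)
```

Because $\mathbf{R}$ is a correlation matrix, every diagonal entry is 1. That is why the variance formula reduces to $\sigma^2(1 - \cdot)$ instead of carrying a separate $\mathbf{R}_{MM}$ term.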

2.2 Binary and categorical: phylogenetic label propagation (LP)

Discrete traits use a softer phylogenetic prior. Each species’ class probability is a kernel-weighted vote over the species where the trait is observed, $\Pr(y_i = k) \propto \sum_{j} \mathbf{A}_{ij} \, \mathbb{1}(y_j = k)$, with $\mathbf{A}$ the same Gaussian kernel used in the GNN’s adjacency. For categorical traits with $K$ classes this returns $K$ log-probability columns; binary traits return one logit column.
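A base-R sketch of this label-propagation prior for one discrete trait follows. The kernel bandwidth and the choice to zero the diagonal (so a species does not vote for itself) are assumptions of this sketch. lp_probs() is a hypothetical helper, not pigauto’s internal function.

```r
# Phylogenetic label propagation (illustrative sketch).
# y: factor with NAs, one entry per species; D: cophenetic distance matrix.
lp_probs <- function(y, D, bandwidth) {
  A <- exp(-D^2 / (2 * bandwidth^2))   # Gaussian kernel on cophenetic distance
  diag(A) <- 0                          # assumption: no self-vote
  obs <- !is.na(y)
  # one-hot indicator matrix 1(y_j = k) for the observed species
  onehot <- 1 * outer(as.integer(y[obs]), seq_along(levels(y)), "==")
  scores <- A[, obs, drop = FALSE] %*% onehot   # kernel-weighted class votes
  probs <- scores / rowSums(scores)             # normalise each row to sum to 1
  dimnames(probs) <- list(rownames(D), levels(y))
  probs
}

D <- matrix(c(0, 2, 4, 4,
              2, 0, 4, 4,
              4, 4, 0, 2,
              4, 4, 2, 0), 4, 4, dimnames = list(LETTERS[1:4], LETTERS[1:4]))
y <- factor(c("x", NA, "y", "y"), levels = c("x", "y"))
p <- lp_probs(y, D, bandwidth = 2)
p["B", ]   # B leans towards its sister A's class "x"
```

Taking log(p) gives the log-probability columns described above.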

2.3 Joint MVN baseline (Phase 2)

When Rphylopars is installed and at least two continuous-like latent columns exist, pigauto upgrades to a joint multivariate-BM baseline. Stack the $p$ BM-eligible columns into the $m \times p$ matrix $\mathbf{Y}$. Under joint BM, $\mathrm{vec}(\mathbf{Y}) \sim \mathcal{N}\bigl(\mathrm{vec}(\mathbf{1}\boldsymbol{\beta}^\top),\, \boldsymbol{\Sigma} \otimes \mathbf{R}\bigr)$. Rphylopars::phylopars() fits $\hat{\boldsymbol{\Sigma}}$ jointly across traits, and the conditional posterior is the GLS solution with the Kronecker covariance. This captures cross-trait phylogenetic correlation that the per-column path loses. The bench in script/bench_joint_baseline.R showed a 33.7% RMSE improvement on simulated correlated BM data.
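A base-R sketch of conditioning under the Kronecker covariance is below. It treats $\boldsymbol{\Sigma}$ and $\boldsymbol{\beta}$ as known; Rphylopars estimates them, but here they are simply supplied. joint_bm_impute() is a hypothetical helper. The example uses a star phylogeny ($\mathbf{R} = \mathbf{I}$), so the only information about the missing cell comes from the correlated second trait. A per-column BM fit cannot use that information.

```r
# Joint multivariate-BM conditional mean (illustrative sketch).
# Y: m x p trait matrix with NAs; R: m x m; Sigma: p x p; beta: length p.
joint_bm_impute <- function(Y, R, Sigma, beta) {
  m <- nrow(Y)
  V <- kronecker(Sigma, R)                 # cov of vec(Y), columns stacked
  mean_vec <- rep(beta, each = m)          # vec(1 beta^T)
  y <- as.vector(Y)
  o <- !is.na(y)
  # conditional-MVN mean for every missing cell at once
  y[!o] <- mean_vec[!o] + drop(V[!o, o, drop = FALSE] %*%
                                 solve(V[o, o], y[o] - mean_vec[o]))
  matrix(y, nrow = m, dimnames = dimnames(Y))
}

Y <- cbind(trait1 = c(NA, 0.5, -0.3),
           trait2 = c(1.0, 0.2, -0.1))
Sigma <- matrix(c(1, 0.8, 0.8, 1), 2, 2)   # strong cross-trait correlation
Y_hat <- joint_bm_impute(Y, R = diag(3), Sigma = Sigma, beta = c(0, 0))
Y_hat[1, "trait1"]   # 0.8 * 1.0, borrowed from species 1's trait2 value
```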

2.4 Threshold-joint baseline (Phase 3)

To bring binary traits inside the joint MVN, pigauto’s fit_joint_threshold_baseline() uses the Wright–Falconer liability model. Each observed binary cell $y \in \{0, 1\}$ is replaced by the posterior mean of an underlying continuous liability $L$, truncated to the side of the threshold indicated by $y$. With prior $L \sim \mathcal{N}(0, 1)$ this is the standard truncated-normal mean. The resulting matrix is then handed to phylopars() exactly like the continuous case. Binary posteriors are decoded back to probabilities via $p = \Phi\bigl(\mu_L / \sqrt{1 + \sigma_L^2}\bigr)$ and clipped to $[0.01, 0.99]$. Ordinal liability uses $K-1$ interval cuts (B3 ordinal).
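A base-R sketch of the two liability steps follows. It assumes the binary threshold sits at 0 on the liability scale, which the text implies but does not state, and the helper names are hypothetical. The decoding step is the standard probit-normal marginal: it integrates a unit-variance probit link over a normal posterior $\mathcal{N}(\mu_L, \sigma_L^2)$, which is where $\sqrt{1 + \sigma_L^2}$ comes from.

```r
# Wright-Falconer liability encoding and decoding (illustrative sketch).
# Assumption: threshold at 0 and prior L ~ N(0, 1).
liability_mean <- function(y) {
  e_above <- dnorm(0) / (1 - pnorm(0))   # E[L | L > 0] = sqrt(2 / pi)
  e_below <- -dnorm(0) / pnorm(0)        # E[L | L < 0] = -sqrt(2 / pi)
  ifelse(y == 1, e_above, e_below)       # NA cells stay NA
}

# Posterior liability N(mu_L, sigma2_L) -> Pr(y = 1), clipped to [lo, hi].
decode_binary <- function(mu_L, sigma2_L, lo = 0.01, hi = 0.99) {
  p <- pnorm(mu_L / sqrt(1 + sigma2_L))
  pmin(pmax(p, lo), hi)
}

liability_mean(c(0, 1, NA))   # approx -0.798, 0.798, NA
decode_binary(mu_L = c(-5, 0, 0.5), sigma2_L = c(0.2, 1, 0.3))
```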

2.5 OVR categorical (Phase 6)

The single-fit approach to categorical liability ($K$ columns in one phylopars() call) is rank-deficient and unstable, because the $K$ one-hot columns always sum to one and are therefore linearly dependent. pigauto instead runs $K$ independent threshold-joint fits (class $k$ vs. the rest) and renormalises the resulting $K$ probabilities into a row-stochastic distribution. This is BACE’s strategy, and it lifts AVONET Trophic.Level accuracy from ~42% to ~72%.
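The final OVR step can be sketched in base R. The sketch assumes each of the $K$ one-vs-rest threshold fits returns a posterior liability mean and variance per species. The probit decoding and clipping mirror §2.4. ovr_probs() is a hypothetical helper, not the code in R/ovr_categorical.R.

```r
# One-vs-rest categorical decoding (illustrative sketch).
# mu_L, sigma2_L: species x K matrices; column k comes from the "class k vs rest" fit.
ovr_probs <- function(mu_L, sigma2_L) {
  p <- pnorm(mu_L / sqrt(1 + sigma2_L))   # Pr(class k vs rest); keeps matrix dims
  p <- pmin(pmax(p, 0.01), 0.99)          # clip as for binary traits
  p / rowSums(p)                          # renormalise each row to sum to 1
}

mu_L <- rbind(sp1 = c(1, -1, -1),         # clear favourite: class 1
              sp2 = c(0,  0,  0))         # no information: uniform
sigma2_L <- matrix(0.5, nrow = 2, ncol = 3)
P <- ovr_probs(mu_L, sigma2_L)
P
```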

2.6 Multi-obs aggregation (Phase 10 + B1 soft)

When each species has multiple observations, baselines run at species level. Phase 10 aggregates obs→species with type-aware rules: mean for continuous traits, modal class (or argmax one-hot) for discrete ones. B1 (v0.9.0) adds an opt-in soft path that preserves evidence strength. A species observed as class 1 in 6/10 rows uses the convex combination $p \cdot E[L \mid L > 0] + (1 - p) \cdot E[L \mid L < 0]$ with $p = 0.6$, instead of collapsing to hard class 1.
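A base-R sketch of species-level aggregation follows, applying the hard (Phase 10) and soft (B1) rules to a binary trait. It assumes the same $\mathcal{N}(0, 1)$ liability with a threshold at 0 as §2.4. The column names and data are made up for illustration.

```r
# Observation-level data: sp1 has 10 rows (class 1 in 6), sp2 has 3 rows.
obs <- data.frame(
  species = rep(c("sp1", "sp2"), times = c(10, 3)),
  mass    = c(1:10, 4, 5, 6),
  diet    = c(rep(1, 6), rep(0, 4), 0, 0, 1)
)

# Phase 10 hard rules: mean for continuous, modal class for discrete.
mass_sp <- tapply(obs$mass, obs$species, mean)
diet_hard <- tapply(obs$diet, obs$species,
                    function(v) as.integer(names(which.max(table(v)))))

# B1 soft rule: convex combination of truncated-normal liability means.
p1 <- tapply(obs$diet, obs$species, mean)       # share of rows in class 1
e_above <- dnorm(0) / (1 - pnorm(0))            # E[L | L > 0]
e_below <- -dnorm(0) / pnorm(0)                 # E[L | L < 0]
diet_soft <- p1 * e_above + (1 - p1) * e_below

cbind(mass_sp, diet_hard, diet_soft)
```

The hard rule codes sp1 as a full class-1 species. The soft liability (about 0.16) records that the evidence was only 6 to 4.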

3. The GNN: ResidualPhyloDAE

The GNN’s job is to learn an additive correction on top of the phylogenetic baseline. It is a denoising autoencoder (DAE) trained with masked-cell reconstruction loss.

Name note. The torch class is called ResidualPhyloDAE because its internal blocks use ResNet-style residual skip connections. The network output delta is not a statistical residual $y - \mu$. It is a full per-cell prediction, blended externally with $\mu$ via the per-trait gate.

3.1 Encoder

Input tensors:

  • x $\in \mathbb{R}^{n \times p}$: current trait latent matrix (missing cells replaced by a learnable mask token).
  • coords $\in \mathbb{R}^{m \times k}$: species-level spectral Laplacian features.
  • covs $\in \mathbb{R}^{n \times c}$: [baseline_mu | NA-mask | user_covs] (the baseline prediction, a per-cell mask indicator, and the user covariates).

If use_trait_attention = TRUE (new in v0.9.3, see §3.5), a pooled trait-context feature of dimension trait_embed_dim is also concatenated. The encoder is a two-layer MLP:

$$
h = \mathrm{ReLU}\bigl(\mathbf{W}_2 \, \mathrm{Dropout}(\mathrm{ReLU}(\mathbf{W}_1 \, \mathrm{concat}(x, \text{coords}, \text{covs}, \ldots)))\bigr),
$$

producing $h \in \mathbb{R}^{n \times h_d}$ with hidden dimension $h_d$ (default 64).

In multi-obs mode, $h$ is averaged across observations of the same species (scatter_mean). This produces a species-level hidden state $h_\text{species} \in \mathbb{R}^{m \times h_d}$ for graph message passing, which is then broadcast back to observation level.
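The species pooling and broadcast can be sketched with base R’s rowsum() and integer indexing. pigauto performs this step on torch tensors; the base-R scatter_mean() below is a hypothetical stand-in. It assumes every species index 1..m has at least one observation row.

```r
# Mean of observation-level hidden states within each species (sketch).
scatter_mean <- function(h, obs_to_species, m) {
  stopifnot(all(seq_len(m) %in% obs_to_species))   # every species observed
  sums <- rowsum(h, obs_to_species)                # m x h_d, rows sorted 1..m
  sums / tabulate(obs_to_species, nbins = m)       # divide row j by its count
}

h <- cbind(c(1, 3, 5, 10),        # 4 observations x 2 hidden dims
           c(2, 4, 6, 20))
obs_to_species <- c(1L, 1L, 1L, 2L)
h_species <- scatter_mean(h, obs_to_species, m = 2)   # 2 x 2 species states
h_back <- h_species[obs_to_species, , drop = FALSE]   # broadcast back to 4 x 2
```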

3.2 Graph Transformer Block (Phase 9 + B2)

The default path stacks $L$ pre-norm transformer encoder blocks (n_gnn_layers, default 2). Each block has:

  • Multi-head attention (n_heads, default 4) over the $m$ species, with a per-head learnable phylogenetic bias added to the attention scores. With $\mathbf{D}^2$ the squared cophenetic-distance matrix and $\beta_h = \mathrm{softplus}(\mathrm{log\_bw}_h)$ a learned bandwidth per head, the bias is $\mathbf{B}_h = -\mathbf{D}^2 / (2\beta_h^2)$. One head can attend tightly (fast-evolving traits), another broadly (conserved traits).
  • A position-wise FFN of width $h_d \cdot \mathtt{ffn\_mult}$ (ffn_mult default 4). Its output linear layer is initialised to zero, so the FFN contributes nothing at training step 0 and the block starts close to the identity. This preserves the gate-closed-at-init safety.
  • Layer norm before each sub-block, residual skips after each.
  • Optional per-layer covariate injection (when user covariates are present): cov_h = cov_encoder(user_covs) is added inside the block via a learnable projection, so covariate features are visible at every depth, not just the encoder.
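A single-head, base-R sketch of scaled dot-product attention with the phylogenetic bias $\mathbf{B}_h = -\mathbf{D}^2/(2\beta_h^2)$ is below. Q and K are set to zero so that only the bias shapes the weights, which makes the effect of the per-head bandwidth visible. The learned projections, multiple heads, and FFN of the real block are omitted. phylo_attention_weights() is a hypothetical helper.

```r
softplus <- function(x) log1p(exp(x))

# Row-wise softmax with max-subtraction for numerical stability.
softmax_rows <- function(S) {
  E <- exp(S - apply(S, 1, max))   # subtracts row i's max from row i
  E / rowSums(E)
}

phylo_attention_weights <- function(Q, K, D_sq, log_bw) {
  beta_h <- softplus(log_bw)                          # learned bandwidth > 0
  scores <- Q %*% t(K) / sqrt(ncol(Q)) - D_sq / (2 * beta_h^2)
  softmax_rows(scores)
}

D <- matrix(c(0, 1, 3,
              1, 0, 3,
              3, 3, 0), 3, 3)          # species 1 and 2 are close relatives
Q <- K <- matrix(0, nrow = 3, ncol = 2) # no content signal: bias only
tight <- phylo_attention_weights(Q, K, D^2, log_bw = -1)  # small bandwidth
broad <- phylo_attention_weights(Q, K, D^2, log_bw = 3)   # large bandwidth
round(rbind(tight = tight[1, ], broad = broad[1, ]), 3)
```

With the small bandwidth, species 1 puts almost all of its weight on itself and its sister. With the large bandwidth, the distant species 3 still receives a substantial share.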

Legacy single-head attention is retained behind use_transformer_blocks = FALSE for reconstructing pre-v0.9.0 fits.

3.3 Decoder and the gate

A symmetric two-layer MLP maps the species-broadcast hidden state back to the latent space: delta = dec2(ReLU(dec1(h))), of shape $n \times p$.

The per-column blend gate is

$$
r = \sigma(\rho) \cdot \mathtt{gate\_cap}, \qquad r \in (0, \mathtt{gate\_cap}),
$$

with $\rho \in \mathbb{R}^p$ a learnable per-column parameter. The model output is

$$
\hat{x} = (1 - r) \cdot \mu + r \cdot \mathrm{delta} + \mathrm{cov\_linear}(u),
$$

where $u$ denotes the user covariates and cov_linear is a small direct linear regression on them. It is added outside the blend and gives the GNN a “linear shortcut” on covariates, so the network does not have to learn linear covariate effects through nonlinear layers.

The gate is initialised so that the GNN contribution starts small. Continuous columns use $\rho_\text{init} = -1$, giving an effective gate of $\sigma(-1) \cdot \mathtt{gate\_cap} \approx 0.269 \times \mathtt{gate\_cap}$. Discrete columns start with the gate effectively closed ($r \approx 0$).
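A few lines of base R make the gate arithmetic concrete. The sketch assumes gate_cap = 0.8 (the default cap quoted in §3.4) and the $\rho_\text{init} = -1$ stated above. plogis() is base R’s logistic function $\sigma$, and gate() is a hypothetical helper.

```r
gate_cap <- 0.8
gate <- function(rho, cap = gate_cap) plogis(rho) * cap   # r = sigma(rho) * cap

gate(-1)            # continuous-column init: about 0.215 (= 0.269 * 0.8)
gate(c(-6, 0, 6))   # approaches 0 and gate_cap but stays strictly between them
```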

3.4 Loss and three safety regularisations

For each training batch:

$$
\mathcal{L} = \underbrace{\mathcal{L}_\text{type}(\hat{x}, y_\text{true})}_{\text{type-aware reconstruction}} + \lambda_\text{shrink} \cdot \mathrm{MSE}(\mathrm{delta} - \mu) + \lambda_\text{gate} \cdot \mathrm{MSE}(r).
$$

  • $\mathcal{L}_\text{type}$ dispatches by trait type: MSE for continuous / count / ordinal / proportion, BCE for binary and zero-inflation (zi) gates, cross-entropy for categorical, and MSE on CLR for multi-proportion.
  • $\lambda_\text{shrink}$ (default 0.03) penalises delta drifting away from the baseline.
  • $\lambda_\text{gate}$ (default 0.01) actively pushes the gate toward zero. Without this term, $\rho$ receives no gradient when $\mathrm{delta} \approx \mu$, for two reasons. The output’s sensitivity to the gate, $\partial \hat{x} / \partial r = \mathrm{delta} - \mu$, vanishes, and the shrinkage term does not depend on $\rho$ at all. The gate would therefore stay at its initial value. The explicit penalty makes the gate default toward baseline-only.

Together with the architectural cap (gate_cap ≤ 0.8 by default), the shrinkage and gate penalties give a strong inductive bias toward the phylogenetic baseline. This is useful when the GNN has nothing extra to learn on a given dataset.
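A continuous-trait sketch of this loss is below, using a single scalar gate and treating all cells as training targets; the real loop masks cells and dispatches by type. It interprets $\mathrm{MSE}(r)$ as the mean of $r^2$, which is an assumption. A central finite difference shows that when delta equals $\mu$, the gradient on $\rho$ comes only from the gate penalty. That gradient is positive, so gradient descent lowers $\rho$ and closes the gate. dae_loss() and num_grad() are hypothetical helpers.

```r
dae_loss <- function(rho, mu, delta, y, gate_cap = 0.8,
                     lambda_shrink = 0.03, lambda_gate = 0.01) {
  r <- plogis(rho) * gate_cap
  x_hat <- (1 - r) * mu + r * delta
  mean((x_hat - y)^2) +                      # reconstruction (MSE)
    lambda_shrink * mean((delta - mu)^2) +   # keep delta near the baseline
    lambda_gate * mean(r^2)                  # pull the gate toward zero
}
num_grad <- function(f, x, eps = 1e-6) (f(x + eps) - f(x - eps)) / (2 * eps)

mu    <- c(0.2, -0.4, 1.1)
y     <- c(0.5, -0.1, 0.9)
delta <- mu                                  # GNN adds nothing new

g_with    <- num_grad(function(rho) dae_loss(rho, mu, delta, y), x = -1)
g_without <- num_grad(function(rho) dae_loss(rho, mu, delta, y,
                                             lambda_gate = 0), x = -1)
c(with_penalty = g_with, without_penalty = g_without)
```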

3.5 Within-row cross-trait attention (B3, v0.9.3, opt-in)

The encoder above mixes all trait columns into a single hidden vector via a linear projection, which loses per-trait identity. When use_trait_attention = TRUE, the model additionally builds a per-trait token sequence:

$$
\text{tokens}_{i,j} = \mathbf{W}_v \, x_{i,j} + \mathbf{e}_j, \qquad j = 1, \ldots, p,
$$

with $\mathbf{e}_j \in \mathbb{R}^{e_d}$ a learnable positional embedding per trait column (dimension trait_embed_dim, default 32). One multi-head self-attention block (n_trait_heads, default 2) mixes these $p$ tokens within each row $i$, followed by a mean-pool to a single $e_d$-dimensional feature. That feature is concatenated alongside $(x, \text{coords}, \text{covs})$ at the encoder input.
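A single-head, single-row base-R sketch of the token construction and mean-pooled self-attention follows. The head count, weight shapes, value projection Wval, and random initialisation are assumptions, and trait_context() is hypothetical. The check at the end shows why the positional embeddings $\mathbf{e}_j$ matter. Without them, attention followed by mean-pooling is invariant to permuting the traits, so trait identity is lost.

```r
softmax_rows <- function(S) {
  E <- exp(S - apply(S, 1, max))
  E / rowSums(E)
}

# x_row: length-p latent values for one row i; W_v: length-e_d embedding weights;
# E: p x e_d positional embeddings; Wq, Wk, Wval: e_d x e_d projections (assumed).
trait_context <- function(x_row, W_v, E, Wq, Wk, Wval) {
  tokens <- outer(x_row, W_v) + E                # token_j = W_v * x_ij + e_j
  Q <- tokens %*% Wq
  K <- tokens %*% Wk
  A <- softmax_rows(Q %*% t(K) / sqrt(ncol(Q)))  # p x p attention over traits
  colMeans(A %*% (tokens %*% Wval))              # mean-pool to an e_d vector
}

set.seed(1)
p <- 4; e_d <- 8
W_v  <- rnorm(e_d)
E    <- matrix(rnorm(p * e_d, sd = 0.5), p, e_d)
Wq   <- matrix(rnorm(e_d^2, sd = 0.3), e_d)
Wk   <- matrix(rnorm(e_d^2, sd = 0.3), e_d)
Wval <- matrix(rnorm(e_d^2, sd = 0.3), e_d)
x_row <- c(0.3, -1.2, 0.8, 0.1)
perm  <- c(2, 4, 1, 3)
E0    <- matrix(0, p, e_d)                       # no positional embeddings

same_without_e <- isTRUE(all.equal(trait_context(x_row, W_v, E0, Wq, Wk, Wval),
                                   trait_context(x_row[perm], W_v, E0, Wq, Wk, Wval)))
same_with_e <- isTRUE(all.equal(trait_context(x_row, W_v, E, Wq, Wk, Wval),
                                trait_context(x_row[perm], W_v, E, Wq, Wk, Wval)))
c(without_e = same_without_e, with_e = same_with_e)
```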

The mechanism is intended for trait sets with strong within-row functional coupling that the joint MVN baseline cannot already capture (non-Gaussian, nonlinear, or non-monotone cross-trait structure). On the BIEN plant bench it did not improve pooled RMSE. The joint MVN/threshold-joint path already encodes $\boldsymbol{\Sigma}$ for the dominant trait types, so a second cross-trait mechanism is redundant there. The flag ships as opt-in because it may help on datasets where the linear $\boldsymbol{\Sigma}$ assumption is too restrictive (e.g. functional traits with phase-transition or interaction structure).

Backward-compatible: saved fits without the field default to use_trait_attention = FALSE and reconstruct identically.

4. Calibration: making the gate a real safety floor

After training, pigauto does not ship the gate $\sigma(\rho)$ that the optimiser converged to. Instead it overrides $\rho$ with a per-trait calibrated gate $r_\text{cal}$ chosen on held-out validation cells.

calibrate_gates() runs a per-trait grid search over $r \in [0, \mathtt{gate\_cap}]$ that minimises validation-set reconstruction loss. Continuous traits use MSE. Discrete traits use 0/1 loss, with an absolute cell floor and a split-validation cross-check that prevents the GNN from harming baseline accuracy. The output is a single scalar $r_\text{cal}$ per latent column, stored on the fit and used at prediction time.
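For a continuous column, the grid search can be sketched in a few lines of base R. The grid resolution is an assumption, and the discrete-trait safeguards (cell floor, split-validation cross-check) are not shown. calibrate_gate_mse() is a hypothetical helper, not calibrate_gates().

```r
# Pick r in [0, gate_cap] minimising the validation MSE of the blended prediction.
calibrate_gate_mse <- function(mu_val, delta_val, y_val,
                               gate_cap = 0.8, n_grid = 41) {
  grid <- seq(0, gate_cap, length.out = n_grid)
  val_mse <- vapply(grid, function(r) {
    mean(((1 - r) * mu_val + r * delta_val - y_val)^2)
  }, numeric(1))
  grid[which.min(val_mse)]   # ties resolve to the smallest r (closest to baseline)
}

y_val  <- c(1, 2, 3, 4)
mu_val <- y_val + 0.5                                    # biased baseline
calibrate_gate_mse(mu_val, delta_val = y_val, y_val)     # helpful GNN: opens to the cap
calibrate_gate_mse(mu_val, delta_val = y_val + 3, y_val) # harmful GNN: r_cal = 0
```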

This is the second layer of safety. Even if training pushes the learnable gate wide open, calibration on held-out cells can close it again. In practice, on datasets with strong phylogenetic signal where the GNN cannot improve on BM, $r_\text{cal} \approx 0$ and the prediction is essentially the baseline.

5. Uncertainty quantification

pigauto exposes three distinct uncertainty mechanisms; they answer different questions and must not be conflated.

| Mechanism | Source | Validity |
|---|---|---|
| `pred$se` (continuous / count / ordinal / proportion) | Conditional-MVN SD from the BM baseline, delta-method back-transformed. | Exact if the BM model holds; otherwise model-dependent. |
| `pred$se` (binary / categorical) | $\min(p, 1-p)$ for binary, $1 - \max_k p_k$ for categorical: an uncertainty score, not a Gaussian SE. | Use for ranking/reporting; do not plug into Rubin’s rules. |
| `pred$conformal_lower`, `pred$conformal_upper` | Split-conformal residual quantile on the validation set: $s = q_{(1-\alpha)}\bigl(\lvert y - \hat{y} \rvert_{\text{val}}\bigr)$. | Distribution-free marginal coverage of at least $1 - \alpha$ (95% by default), assuming validation and test cells are exchangeable; no other model assumptions needed. |

For multiple-imputation workflows, multi_impute() exposes two draw methods. "conformal" (the default) samples missing cells on the transformed scale from a normal distribution with mean $\mu$ and standard deviation $s/1.96$, so that $\pm s$ spans about 95% of the draws. This calibrates the spread against actual residuals. "mc_dropout" runs $M$ stochastic GNN passes in training mode (dropout active) on top of BM-draw inputs. It is wider than conformal but reflects model uncertainty when the gate is fully open.
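A base-R sketch of the conformal interval and the "conformal" draws for one trait on the transformed scale is below. It uses the usual finite-sample split-conformal rank $\lceil (n+1)(1-\alpha) \rceil$. Whether pigauto applies this correction is not stated above, so treat it as an assumption. The validation data, the predicted value, and $M = 20$ are made up, and conformal_halfwidth() is a hypothetical helper.

```r
# Split-conformal half-width s from validation residuals.
conformal_halfwidth <- function(y_val, yhat_val, alpha = 0.05) {
  res <- sort(abs(y_val - yhat_val))
  n <- length(res)
  k <- ceiling((n + 1) * (1 - alpha))   # finite-sample conformal rank
  if (k > n) return(Inf)                # too few validation cells for this alpha
  res[k]
}

y_val    <- (1:49) / 10                 # toy validation truths
yhat_val <- rep(0, 49)                  # toy validation predictions
s <- conformal_halfwidth(y_val, yhat_val)        # 48th smallest |residual| = 4.8

mu_cell <- 1.3                                    # prediction for one missing cell
c(lower = mu_cell - s, upper = mu_cell + s)       # conformal interval
set.seed(42)
draws <- rnorm(20, mean = mu_cell, sd = s / 1.96) # M = 20 "conformal" MI draws
```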

For pooled inference, pool_mi() applies Rubin’s (1987) rules to $M$ downstream fits and returns a tidy data.frame with estimate, std.error, df, fmi, and riv per term.
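Rubin’s rules for a single coefficient fit in a few lines of base R. This sketch uses the classic Rubin (1987) degrees of freedom and the matching fraction-of-missing-information formula. pool_mi() may use a small-sample variant, and rubin_pool() is a hypothetical helper.

```r
# Pool one coefficient across M imputations (classic Rubin 1987 rules).
rubin_pool <- function(est, se) {
  M <- length(est)
  q_bar <- mean(est)                       # pooled point estimate
  W <- mean(se^2)                          # within-imputation variance
  B <- var(est)                            # between-imputation variance
  total <- W + (1 + 1 / M) * B             # total variance
  riv <- (1 + 1 / M) * B / W               # relative increase in variance
  df <- (M - 1) * (1 + 1 / riv)^2          # Rubin (1987) degrees of freedom
  fmi <- (riv + 2 / (df + 3)) / (riv + 1)  # fraction of missing information
  data.frame(estimate = q_bar, std.error = sqrt(total),
             df = df, fmi = fmi, riv = riv)
}

out <- rubin_pool(est = c(1.0, 1.2, 0.8, 1.1, 0.9), se = rep(0.2, 5))
out
```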

6. Tree uncertainty: a two-step workflow

When the species tree itself is uncertain, the Nakagawa & de Villemereuil (2019, Syst. Biol.) algorithm pools across $T$ posterior trees. multi_impute_trees(traits, trees, m_per_tree = 5L) performs step 1: a full pigauto fit per tree, producing $T \times M$ completed datasets, each tagged with the index of the tree that produced it. Step 2 is the user’s responsibility. Refit the downstream comparative model on each dataset using the same tree that produced it, then pool all $T \times M$ fits via pool_mi():

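```r
# Step 1: T trees x 5 imputations each (traits and trees supplied by the user)
mi <- multi_impute_trees(traits, trees, m_per_tree = 5L)

# Step 2: refit each completed dataset on the tree that produced it
fits <- Map(function(dat, t_idx) {
  dat$species <- rownames(dat)   # rownames are assumed to be tip labels
  nlme::gls(y ~ x,               # y, x: placeholder model terms
            correlation = ape::corBrownian(phy = trees[[t_idx]],
                                           form = ~species),
            data = dat, method = "ML")
}, mi$datasets, mi$tree_index)

pool_mi(fits)                    # Rubin's rules across all T x M fits
```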

Computational cost is linear in $T$. The 2019 paper’s relative-efficiency index typically converges before $T = 50$.

7. Putting it together: a worked predictive equation

A clean end-to-end summary of what predict.pigauto_fit() returns for one continuous trait on one missing cell $i$ in a single-tree, single-imputation call:

$$
\hat{y}_i = \underbrace{(1 - r_\text{cal}) \cdot \mu_i^{\text{BM/joint-MVN}}}_{\text{baseline contribution}} + \underbrace{r_\text{cal} \cdot \mathrm{delta}_i^{\text{GNN}}}_{\text{GNN contribution}} + \underbrace{\mathrm{cov\_linear}(u_i)}_{\text{linear cov shortcut}},
$$

with $\mu_i$ from the joint MVN (or per-column BM fallback), $\mathrm{delta}_i$ from the GNN (transformer blocks + optional within-row cross-trait attention), and $r_\text{cal}$ from validation calibration. The conformal interval is

$$
\hat{y}_i \pm s_t,
$$

where $s_t$ is the split-conformal residual quantile for trait $t$, giving $\ge 95\%$ marginal coverage on the original scale.
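A worked numeric instance of the two formulas above, using made-up values for one cell: baseline 2.0, GNN output 2.6, calibrated gate 0.25, covariate shortcut 0.1, and conformal half-width 0.9. predict_cell() is a hypothetical helper. Setting r_cal = 0 shows the safety floor.

```r
predict_cell <- function(mu, delta, r_cal, cov_term, s_t) {
  y_hat <- (1 - r_cal) * mu + r_cal * delta + cov_term
  c(estimate = y_hat, lower = y_hat - s_t, upper = y_hat + s_t)
}

predict_cell(mu = 2.0, delta = 2.6, r_cal = 0.25, cov_term = 0.1, s_t = 0.9)
# 0.75 * 2.0 + 0.25 * 2.6 + 0.1 = 2.25, interval [1.35, 3.15]
predict_cell(mu = 2.0, delta = 2.6, r_cal = 0, cov_term = 0.1, s_t = 0.9)
# gate closed: baseline plus the covariate term only, estimate = 2.1
```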

8. Where to look in the code

| Concept | File |
|---|---|
| BM kernel + conditional MVN | R/bm_internal.R |
| Joint MVN baseline | R/joint_mvn_baseline.R |
| Threshold-joint (binary + ordinal) | R/joint_threshold_baseline.R, R/liability.R |
| OVR categorical | R/ovr_categorical.R |
| Graph Transformer Block (B2) | R/graph_transformer_block.R |
| ResidualPhyloDAE + B3 trait attention | R/model_residual_dae.R |
| Training loop, gate calibration | R/fit_pigauto.R, R/fit_helpers.R |
| Prediction, conformal intervals | R/predict_pigauto.R |
| Multi-imputation pooling | R/multi_impute.R, R/pool_mi.R |
| Tree uncertainty workflow | R/multi_impute_trees.R |

References

  • Felsenstein, J. (1985). Phylogenies and the comparative method. The American Naturalist.
  • Pagel, M. (1999). Inferring the historical patterns of biological evolution. Nature.
  • Bruggeman, J., Heringa, J., & Brandt, B. W. (2009). PhyloPars: estimation of missing parameter values using phylogeny. Nucleic Acids Research.
  • Goolsby, E. W., Bruggeman, J., & Ané, C. (2017). Rphylopars: fast multivariate phylogenetic comparative methods for missing data and within-species variation. Methods in Ecology and Evolution.
  • Wright, S. (1934). An analysis of variability in number of digits in an inbred strain of guinea pigs. Genetics. (Liability model.)
  • Vaswani, A. et al. (2017). Attention is all you need. NeurIPS.
  • Ying, C. et al. (2021). Do Transformers Really Perform Bad for Graph Representation? NeurIPS. (Graphormer; multi-scale phylogenetic attention bias.)
  • Rubin, D. B. (1987). Multiple Imputation for Nonresponse in Surveys.
  • Nakagawa, S., & Freckleton, R. P. (2008). Missing inaction: the dangers of ignoring missing data. Trends in Ecology & Evolution.
  • Nakagawa, S., & Freckleton, R. P. (2011). Model averaging, missing data and multiple imputation. Behavioral Ecology and Sociobiology.
  • Nakagawa, S., & de Villemereuil, P. (2019). A general method for simultaneously accounting for phylogenetic and species sampling uncertainty via Rubin’s rules in comparative analysis. Systematic Biology 68(4): 632–641.
  • Vovk, V., Gammerman, A., & Shafer, G. (2005). Algorithmic Learning in a Random World. (Conformal prediction.)