
GNN architecture and the math behind pigauto
This article opens up pigauto’s engine room. It is intended for readers who want to know exactly what the package does between `impute(traits, tree, covariates)` and the returned `pigauto_result`. We cover the data flow, the gated ensemble formula, the phylogenetic baseline (joint MVN / threshold-joint / OVR), the GNN’s transformer blocks, the optional within-row cross-trait attention added in v0.9.3, and the calibration and uncertainty-quantification machinery on top.
We assume familiarity with maximum-likelihood Brownian motion (BM) on trees and with standard transformer attention. We do not assume familiarity with the package’s internals.
1. End-to-end data flow
`impute()` chains six internal stages. Tensor shapes use $n$ for observations ($n = m$ species in single-obs mode, $n > m$ when each species has multiple measurements), $p$ for the latent trait dimension (which expands each categorical trait with $K$ classes to $K$ one-hot columns), $k$ for spectral Laplacian features, $c$ for covariate dimensions, and $T$ for the posterior tree count in tree-uncertainty pooling.
```
INPUT:
  traits      data.frame   n × p_traits   (mixed types, NAs allowed)
  tree        phylo        m tips
  covariates  data.frame   n × c          (optional)

STEP 1: preprocess_traits()
  X_scaled        n × p   per-type encoding + z-score
  trait_map       list    descriptors per trait (type, levels, mean, sd)
  obs_to_species  n       only if n > m (multi-obs mode)

STEP 2: build_phylo_graph()
  coords  m × k   Laplacian eigenvectors
  adj     m × m   Gaussian kernel on cophenetic distance
  D_sq    m × m   squared cophenetic distance (B2)

STEP 3: fit_baseline()   [FIXED, NOT TRAINED]
  mu  m × p   phylogenetic baseline (BM / LP / joint)
  se  m × p   conditional-MVN standard errors

STEP 4: fit_pigauto() trains ResidualPhyloDAE on a DAE objective;
        produces a torch nn_module + per-column gate

STEP 5: calibrate_gates() picks per-trait r_cal on held-out val cells

STEP 6: predict.pigauto_fit() applies the blend:
        pred = (1 - r_cal) · mu + r_cal · delta_GNN
```
The two essential ideas are:
- A simple, well-understood phylogenetic baseline does most of the work, and the GNN only contributes when calibrated to do so.
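- The blend gate is a safety floor: at $r_{\text{cal}} = 0$ the prediction collapses to the baseline $\mu$. Because calibration can always choose $r_{\text{cal}} = 0$, the worst case on the calibration cells is bounded at “as good as the baseline”, regardless of how poorly the GNN trains.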
2. The phylogenetic baseline
The baseline is a per-trait-type dispatcher inside `fit_baseline()`. For each trait $j$ you get $\mu_{ij}$ (the baseline posterior mean for missing cell $(i, j)$) and $s_{ij}$ (its posterior SD). All later inference is conditional on this.
2.1 Continuous, count, ordinal, proportion: Brownian motion
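For one continuous trait $y$ with observed entries $y_o$ and unobserved entries $y_u$, under BM on the species tree

$$y \sim \mathcal{N}\!\left(\beta\mathbf{1},\ \sigma^2 C\right),$$

where $C$ is the phylogenetic correlation matrix. The conditional posterior is the closed-form GLS solution:

$$\hat\beta = \frac{\mathbf{1}^\top C_{oo}^{-1} y_o}{\mathbf{1}^\top C_{oo}^{-1} \mathbf{1}}, \qquad \mathbb{E}[y_u \mid y_o] = \hat\beta\mathbf{1} + C_{uo} C_{oo}^{-1} (y_o - \hat\beta\mathbf{1}),$$

$$\operatorname{Var}[y_u \mid y_o] = \sigma^2 \left(C_{uu} - C_{uo} C_{oo}^{-1} C_{ou}\right).$$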
R/bm_internal.R::bm_impute_col() implements this
directly. Count traits are log1p-transformed, ordinal traits are coerced
to integers and z-scored, proportions are logit-z-scored. The same
conditional formula is used in each transformed space.
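A minimal base-R sketch of the conditional-MVN step, for readers who want to see the algebra run. It is not `bm_impute_col()`: it assumes the correlation matrix `C` is supplied directly (for a real tree it could come from `ape::vcv(tree, corr = TRUE)`) and estimates the BM rate $\sigma^2$ by plain ML on the observed tips, which may differ from pigauto’s estimator. The function name `bm_conditional_sketch` is hypothetical.

```r
# Hypothetical sketch of the BM conditional-MVN step (not bm_impute_col()).
# y: trait vector with NA for missing tips; C: phylogenetic correlation matrix.
bm_conditional_sketch <- function(y, C) {
  o <- which(!is.na(y))                          # observed tips
  u <- which(is.na(y))                           # unobserved tips
  one <- rep(1, length(o))
  Coo_inv <- solve(C[o, o, drop = FALSE])
  # GLS estimate of the phylogenetic mean (root state)
  beta <- drop((t(one) %*% Coo_inv %*% y[o]) / (t(one) %*% Coo_inv %*% one))
  r <- y[o] - beta
  sigma2 <- drop(t(r) %*% Coo_inv %*% r) / length(o)   # ML estimate of the BM rate
  Cuo <- C[u, o, drop = FALSE]
  mean_u <- beta + drop(Cuo %*% Coo_inv %*% r)          # conditional mean
  V_u <- sigma2 * (C[u, u, drop = FALSE] - Cuo %*% Coo_inv %*% t(Cuo))
  list(mean = mean_u, se = sqrt(pmax(diag(V_u), 0)), beta = beta, sigma2 = sigma2)
}

# Three tips: 1 and 2 are close relatives, 3 is distant
C <- matrix(c(1.0, 0.9, 0.0,
              0.9, 1.0, 0.0,
              0.0, 0.0, 1.0), nrow = 3, byrow = TRUE)
fit <- bm_conditional_sketch(c(2, NA, 0), C)
fit$mean
fit$se
```

Species 2 is imputed at 1.9 rather than at the GLS mean of 1, because it shares most of its history with species 1 (observed at 2); its SE shrinks to $\sqrt{1 - 0.9^2} \approx 0.44$ on the fitted BM scale.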
2.2 Binary and categorical: phylogenetic label propagation (LP)
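Discrete traits use a softer phylogenetic prior: each species’ class probability is a kernel-weighted average over the other species with observed labels,

$$\hat p_{ik} = \frac{\sum_{j \in \mathcal{O},\, j \ne i} A_{ij}\, \mathbb{1}[y_j = k]}{\sum_{j \in \mathcal{O},\, j \ne i} A_{ij}},$$

where $\mathcal{O}$ is the set of species with observed labels and $A$ is the same Gaussian kernel on cophenetic distance used in the GNN’s adjacency. For categorical traits with $K$ classes this returns $K$ log-probability columns; binary traits return one logit column.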
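The kernel-weighted average can be sketched in a few lines of base R. This is a sketch under stated assumptions, not pigauto’s implementation: the `bandwidth` argument and the name `lp_sketch` are hypothetical, and the function returns probabilities (take `log()` of the result to get log-probability columns).

```r
# Hypothetical sketch of phylogenetic label propagation (not pigauto's code).
# labels: factor with NA for species whose class is missing
# D: cophenetic distance matrix; bandwidth: assumed Gaussian-kernel scale
lp_sketch <- function(labels, D, bandwidth) {
  A <- exp(-D^2 / (2 * bandwidth^2))     # Gaussian kernel on cophenetic distance
  diag(A) <- 0                           # a species does not vote for itself
  obs <- !is.na(labels)
  K <- nlevels(labels)
  # one-hot matrix of the observed labels (n_obs x K)
  onehot <- outer(as.integer(labels[obs]), seq_len(K), "==") * 1
  W <- A[, obs, drop = FALSE]            # weights from every species to observed ones
  P <- (W %*% onehot) / rowSums(W)       # kernel-weighted class frequencies
  colnames(P) <- levels(labels)
  P
}

D <- matrix(c(0.0, 0.1, 2.0,
              0.1, 0.0, 2.0,
              2.0, 2.0, 0.0), nrow = 3, byrow = TRUE)
labels <- factor(c("a", NA, "b"))
P <- lp_sketch(labels, D, bandwidth = 1)
P[2, ]   # species 2 leans strongly towards its close relative's class "a"
```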
2.3 Joint MVN baseline (Phase 2)
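When Rphylopars is installed and at least two continuous-like latent columns exist, pigauto upgrades to a joint multivariate-BM baseline. Stack the $q$ BM-eligible columns into $Y \in \mathbb{R}^{m \times q}$. Under joint BM,

$$\operatorname{vec}(Y) \sim \mathcal{N}\!\left(\beta \otimes \mathbf{1}_m,\ \Sigma \otimes C\right),$$

where $\Sigma$ is the $q \times q$ evolutionary trait covariance and $C$ is the phylogenetic correlation matrix from §2.1. `Rphylopars::phylopars()` fits $\Sigma$ jointly across traits, and the conditional posterior is the GLS solution with the Kronecker covariance. This captures cross-trait phylogenetic correlation that the per-column path loses. The bench in `script/bench_joint_baseline.R` showed a 33.7% RMSE improvement on simulated correlated BM data.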
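To see why the joint path helps, the sketch below builds the Kronecker covariance for two species and two traits and compares the conditional variance of one missing cell with and without the other trait. It assumes $\Sigma$ is known (phylopars estimates it) and the means are known; `cond_var` is a hypothetical helper, not a pigauto function.

```r
C     <- matrix(c(1.0, 0.5,
                  0.5, 1.0), 2)        # phylogenetic correlation, 2 species
Sigma <- matrix(c(1.0, 0.8,
                  0.8, 1.0), 2)        # trait covariance (assumed known here)
K <- kronecker(Sigma, C)               # Cov(vec(Y)) with trait columns stacked

# Conditional variance of element `miss` given elements `obs` of a zero-mean MVN
cond_var <- function(S, miss, obs) {
  drop(S[miss, miss] -
         S[miss, obs, drop = FALSE] %*% solve(S[obs, obs, drop = FALSE]) %*%
         S[obs, miss, drop = FALSE])
}

# vec(Y) = (Y[1,1], Y[2,1], Y[1,2], Y[2,2]); impute Y[1,2] (species 1, trait 2)
v_joint  <- cond_var(K, miss = 3, obs = c(1, 2, 4))   # uses both traits
v_single <- cond_var(C, miss = 1, obs = 2)            # trait 2 alone
c(joint = v_joint, single = v_single)
```

Observing trait 1 on the same species cuts the conditional variance from 0.75 to 0.27, which equals $(1 - 0.8^2)(1 - 0.5^2)$ for this separable covariance; a per-column BM fit cannot use that information.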
2.4 Threshold-joint baseline (Phase 3)
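To bring binary traits inside the joint MVN, pigauto’s `fit_joint_threshold_baseline()` uses the Wright–Falconer liability model: each observed binary cell $y_{ij} \in \{0, 1\}$ is replaced by the posterior mean of an underlying continuous liability $\ell_{ij}$, truncated by $y_{ij} = 1 \iff \ell_{ij} > 0$. With a Gaussian prior $\ell_{ij} \sim \mathcal{N}(\mu_\ell, \sigma_\ell^2)$ this is the standard truncated-normal mean:

$$\mathbb{E}[\ell_{ij} \mid y_{ij} = 1] = \mu_\ell + \sigma_\ell \frac{\phi(\mu_\ell/\sigma_\ell)}{\Phi(\mu_\ell/\sigma_\ell)}, \qquad \mathbb{E}[\ell_{ij} \mid y_{ij} = 0] = \mu_\ell - \sigma_\ell \frac{\phi(\mu_\ell/\sigma_\ell)}{1 - \Phi(\mu_\ell/\sigma_\ell)}.$$

The resulting matrix is then handed to `phylopars()` exactly like the continuous case. Binary posteriors are decoded back to probabilities via $\hat p_{ij} = \Phi(\hat\mu_{ij} / \hat s_{ij})$, where $\hat\mu_{ij}$ and $\hat s_{ij}$ are the posterior mean and SD of the liability, clipped away from 0 and 1. Ordinal liability uses $K - 1$ interval cuts (B3 ordinal).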
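A minimal sketch of the two liability steps (encode an observed binary state as a truncated-normal mean, and decode a liability posterior back to a probability). The prior parameters, the decoding rule $\Phi(\hat\mu/\hat s)$ and the clipping constant `eps` are assumptions for illustration; the function names are hypothetical.

```r
# Hypothetical sketch of the liability step (not pigauto's code).
# Posterior mean of a latent liability l ~ N(mu, sigma^2) given the
# observed binary state: y = 1 means l > 0, y = 0 means l <= 0.
liability_mean <- function(y, mu = 0, sigma = 1) {
  z <- mu / sigma
  ifelse(y == 1,
         mu + sigma * dnorm(z) / pnorm(z),                       # E[l | l > 0]
         mu - sigma * dnorm(z) / pnorm(z, lower.tail = FALSE))   # E[l | l <= 0]
}

# Decode a liability posterior N(mu_hat, se_hat^2) back to P(y = 1),
# clipped away from 0 and 1 (eps is an illustrative choice)
decode_prob <- function(mu_hat, se_hat, eps = 1e-3) {
  p <- pnorm(mu_hat / se_hat)
  pmin(pmax(p, eps), 1 - eps)
}

liability_mean(c(1, 0, 1))                 # +/- sqrt(2 / pi) under a N(0, 1) prior
decode_prob(c(-0.5, 0, 3), c(1, 1, 0.1))
```

Using `pnorm(z, lower.tail = FALSE)` instead of `1 - pnorm(z)` keeps the lower-truncation branch accurate when `z` is large.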
2.5 OVR categorical (Phase 6)
The single-fit approach to categorical liability ($K$ columns in one `phylopars` call) is rank-deficient and unstable: the $K$ one-hot columns sum to 1, so their covariance is singular. pigauto instead runs $K$ independent threshold-joint fits (class $k$ vs. the rest) and renormalises the resulting probabilities into a row-stochastic distribution, $\tilde p_{ik} = \hat p_{ik} / \sum_{k'} \hat p_{ik'}$. This is BACE’s strategy, and it lifts AVONET Trophic.Level accuracy from ~42% to ~72%.
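The OVR decode-and-renormalise step is small enough to show directly. The liability posteriors below are invented for illustration, the class names `k1`..`k3` are placeholders, and decoding each one-vs-rest fit with `pnorm(mu / se)` is an assumption carried over from the binary case.

```r
# Hypothetical OVR liability posteriors: 2 species x 3 classes (not real data)
mu_hat <- matrix(c( 0.8, -0.5, -1.0,
                   -0.2,  0.3, -0.1), nrow = 2, byrow = TRUE)
se_hat <- matrix(1, nrow = 2, ncol = 3)

p_ovr <- pnorm(mu_hat / se_hat)     # P(class k vs. rest), decoded independently
p     <- p_ovr / rowSums(p_ovr)     # renormalise each row to sum to 1
colnames(p) <- c("k1", "k2", "k3")
round(p, 3)
max.col(p)                          # predicted class per species
```

The independent fits do not sum to 1 on their own, which is why the final division by `rowSums()` is needed before the probabilities are reported.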
2.6 Multi-obs aggregation (Phase 10 + B1 soft)
When each species has multiple observations, baselines run at species level. Phase 10 aggregates obs→species with type-aware rules: mean for continuous, modal class (or argmax of the one-hot mean) for discrete. B1 (v0.9.0) adds an opt-in soft path that preserves evidence strength: a species observed as class 1 in 6 of 10 rows keeps the mean of its one-hot rows (weight 0.6 on class 1, the remaining 0.4 on the other observed classes) instead of collapsing to hard class 1.
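The difference between hard (Phase 10) and soft (B1) aggregation for one categorical trait can be sketched with base R’s `rowsum()`. The data are invented and the code is an illustration of the rule described above, not pigauto’s aggregation function.

```r
# Hypothetical observations: species A has 10 rows, species B has 3
obs_species <- rep(c("A", "B"), times = c(10, 3))
obs_class   <- factor(c(rep("c1", 6), rep("c2", 4),   # A: 6 x c1, 4 x c2
                        "c2", "c2", "c3"))            # B: 2 x c2, 1 x c3

onehot <- outer(as.integer(obs_class), seq_len(nlevels(obs_class)), "==") * 1
colnames(onehot) <- levels(obs_class)

counts <- rowsum(onehot, obs_species)          # per-species class counts
soft   <- counts / rowSums(counts)             # soft path: mean of one-hot rows
hard   <- levels(obs_class)[max.col(counts, ties.method = "first")]  # modal class

soft
hard
```

Species A ends up as (0.6, 0.4, 0) on the soft path but as plain `c1` on the hard path, which discards the 4 dissenting observations.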
3. The GNN: ResidualPhyloDAE
The GNN’s job is to learn an additive correction on top of the phylogenetic baseline. It is a denoising autoencoder (DAE) trained with masked-cell reconstruction loss.
Name note. The torch class is called `ResidualPhyloDAE` because its internal blocks use ResNet-style residual skip connections. The network output `delta` is not a statistical residual: it is a full per-cell prediction, blended externally with the baseline $\mu$ via the per-trait gate.
3.1 Encoder
Input tensors:
- `x`: current trait latent matrix (missing cells replaced by a learnable mask token).
- `coords`: species-level spectral Laplacian features.
- `covs`: `[baseline_mu | NA-mask | user_covs]` (the user covariates plus a per-cell mask indicator and the baseline prediction).
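If `use_trait_attention = TRUE` (new in v0.9.3, see §3.5), a pooled trait-context feature of dimension `trait_embed_dim` is also concatenated. The encoder is a two-layer MLP,

$$h = \mathrm{enc}_2\!\left(\mathrm{ReLU}\!\left(\mathrm{enc}_1\!\left([\,x \,\|\, \text{coords} \,\|\, \text{covs}\,]\right)\right)\right),$$

producing $h \in \mathbb{R}^{n \times d}$ with hidden dim $d$ (default 64).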
In multi-obs mode, $h$ is averaged across observations of the same species (`scatter_mean`) to produce a species-level hidden state for the graph message passing, then broadcast back to observation level afterwards.
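The encode, species-average and broadcast steps can be mimicked with plain matrices. This sketch assumes random weights, no biases, and omits `coords`, `covs` and the mask token; `rowsum()` divided by species counts plays the role of `scatter_mean`. It illustrates the data movement only, not the trained torch module.

```r
set.seed(1)
relu <- function(z) pmax(z, 0)               # keeps the matrix dimensions

n <- 5; p <- 3; d <- 4                       # 5 observations, 3 latent traits, hidden dim 4
x <- matrix(rnorm(n * p), n, p)              # observation-level latent traits
obs_to_species <- c("sp1", "sp1", "sp2", "sp3", "sp3")

# Two-layer MLP encoder with random weights (biases, coords and covs omitted)
W1 <- matrix(rnorm(p * d), p, d)
W2 <- matrix(rnorm(d * d), d, d)
h  <- relu(x %*% W1) %*% W2                  # n x d hidden states

# scatter_mean: average the hidden states of each species' observations
h_species <- rowsum(h, obs_to_species) / as.vector(table(obs_to_species))

# (graph message passing would update h_species here)

h_back <- h_species[obs_to_species, , drop = FALSE]   # broadcast back to n rows
dim(h_species)
```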
3.2 Graph Transformer Block (Phase 9 + B2)
The default path stacks $L$ pre-norm transformer encoder blocks (`n_gnn_layers`, default 2). Each block has:
- Multi-head attention (`n_heads`, default 4) over the $m$ species, with a per-head learnable phylogenetic bias added to the attention scores. With $D^2$ the squared cophenetic-distance matrix and $\sigma_h$ a learned bandwidth for head $h$, the bias is $B^{(h)}_{ij} = -D^2_{ij} / (2\sigma_h^2)$, so head $h$ attends with weights $\operatorname{softmax}\!\left(Q_h K_h^\top / \sqrt{d_h} + B^{(h)}\right)$. One head can attend tightly (fast-evolving traits), another broadly (conserved traits).
- A position-wise FFN with hidden width $r d$ (expansion factor $r$, default 4). Its output linear layer is initialised to zero so the block is ≈ identity at training step 0, which preserves the gate-closed-at-init safety.
- Layer norm before each sub-block, residual skips after each.
- Optional per-layer covariate injection (when user covariates are present): `cov_h = cov_encoder(user_covs)` is added inside the block via a learnable projection, so covariate features are visible at every depth, not just the encoder.
Legacy single-head attention is retained behind
use_transformer_blocks = FALSE for reconstructing
pre-v0.9.0 fits.
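A single-head sketch of attention with an additive phylogenetic bias. It assumes the bias takes the Gaussian log-kernel form $-D^2_{ij}/(2\sigma^2)$; pigauto’s exact parameterisation may differ, and `phylo_attention` is a hypothetical name. Zero queries and keys isolate the bias term so the effect of the bandwidth is visible.

```r
# Hypothetical single-head attention with an additive phylogenetic bias
phylo_attention <- function(Q, K, V, D2, sigma) {
  d_head <- ncol(Q)
  scores <- Q %*% t(K) / sqrt(d_head) - D2 / (2 * sigma^2)   # content + phylo bias
  scores <- scores - apply(scores, 1, max)                   # row max for stability
  w <- exp(scores)
  w <- w / rowSums(w)                                        # row-wise softmax
  list(weights = w, out = w %*% V)
}

D <- matrix(c(0.0, 0.2, 1.0,
              0.2, 0.0, 1.0,
              1.0, 1.0, 0.0), nrow = 3, byrow = TRUE)
D2 <- D^2
Z  <- matrix(0, 3, 2)          # zero queries/keys: weights come from the bias alone
V  <- diag(3)

tight <- phylo_attention(Z, Z, V, D2, sigma = 0.1)$weights   # short-range head
broad <- phylo_attention(Z, Z, V, D2, sigma = 10)$weights    # long-range head
round(tight[1, ], 3)
round(broad[1, ], 3)
```

The tight head puts almost all of species 1’s attention on itself and its sister species, while the broad head spreads it nearly uniformly, which is the “one head tight, another broad” behaviour described above.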
3.3 Decoder and the gate
A symmetric two-layer MLP maps the species-broadcast hidden state back to the latent space: `delta = dec2(ReLU(dec1(h)))`, of shape $n \times p$.

The per-column blend gate $g_j \in [0, \text{gate\_cap}]$ is a bounded function of a learnable per-column parameter $\theta_j$. The model output is

$$\hat{x}_{ij} = (1 - g_j)\,\mu_{ij} + g_j\,\delta_{ij} + [\text{cov\_linear}(\text{user\_covs})]_{ij},$$

where `cov_linear` is a small direct linear regression on user covariates (added outside the blend; it gives the GNN a “linear shortcut” on covariates so it does not have to learn linear covariate effects through nonlinear layers).

The gate is initialised so the GNN contribution starts negligible: continuous columns start at an effective gate of about 0.135 × `gate_cap`, and discrete columns start fully closed.
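A minimal sketch of the blended output for given gate values. It takes the gates $g_j$ as inputs rather than computing them from $\theta_j$ (that mapping is not reproduced here), and `blend_output` is a hypothetical name.

```r
# Hypothetical sketch of the per-column blend (not pigauto's code).
# mu, delta: n x p matrices; g: length-p gate vector in [0, gate_cap];
# cov_lin: n x p output of the covariate shortcut (0 when there are no covariates)
blend_output <- function(mu, delta, g, cov_lin = 0) {
  G <- matrix(g, nrow(mu), ncol(mu), byrow = TRUE)   # same gate down each column
  (1 - G) * mu + G * delta + cov_lin
}

mu    <- matrix(c(0.1, 0.4, -0.3, 1.2), 2)   # baseline: 2 cells x 2 traits
delta <- matrix(c(0.5, 0.0,  0.2, 0.9), 2)   # GNN prediction
out <- blend_output(mu, delta, g = c(0, 0.8))   # trait 1 closed, trait 2 at 0.8
out
```

With $g_1 = 0$ the first column is exactly the baseline, however far `delta` is from it; that is the safety-floor property in its simplest form.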
3.4 Loss and three safety regularisations
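For each training batch the loss is

$$\mathcal{L} = \mathcal{L}_{\text{rec}} + \lambda_{\text{shrink}}\,\mathcal{L}_{\text{shrink}} + \lambda_{\text{gate}}\,\mathcal{L}_{\text{gate}},$$

where:
- $\mathcal{L}_{\text{rec}}$ dispatches by trait type: MSE for continuous / count / ordinal / proportion, BCE for binary / zi-gate, cross-entropy for categorical, MSE on CLR for multi-proportion.
- $\lambda_{\text{shrink}}$ (default 0.03) weights a penalty on `delta` drifting away from the baseline.
- $\lambda_{\text{gate}}$ (default 0.01) weights a penalty that actively pushes the gate toward zero. Without it, $\theta_j$ gets no gradient once `delta` ≈ μ: the blended output $(1 - g_j)\mu + g_j\delta$ then no longer depends on $g_j$, so the reconstruction loss is flat in $g_j$, and the shrinkage penalty is ≈ 0. The gate would stay at its init. The explicit penalty guarantees the gate defaults toward baseline-only.

Together with the architectural cap ($g_j \le$ `gate_cap`, 0.8 by default), these three give a strong inductive bias toward the phylogenetic baseline, which is useful when the GNN has nothing extra to learn on a given dataset.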
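The “no gradient for the gate” argument can be checked numerically with finite differences. The penalty shapes below (mean squared distance for shrinkage, a term linear in $g$ for the gate) are assumptions chosen for illustration; only the default weights 0.03 and 0.01 come from the text.

```r
# Hypothetical sketch: why the gate needs its own penalty
blend    <- function(mu, delta, g) (1 - g) * mu + g * delta
rec_loss <- function(g, mu, delta, y) mean((blend(mu, delta, g) - y)^2)
total_loss <- function(g, mu, delta, y, lambda_shrink = 0.03, lambda_gate = 0.01) {
  rec_loss(g, mu, delta, y) +
    lambda_shrink * mean((delta - mu)^2) +   # shrinkage toward baseline (assumed form)
    lambda_gate * g                          # gate penalty (assumed linear in g)
}
# Central finite-difference derivative with respect to the gate g
d_dg <- function(f, g, eps = 1e-6) (f(g + eps) - f(g - eps)) / (2 * eps)

mu <- c(0.2, -1.0, 0.5); y <- c(0.3, -0.8, 0.4); g0 <- 0.1
grad_rec_flat   <- d_dg(function(g) rec_loss(g, mu, delta = mu, y), g0)
grad_total_flat <- d_dg(function(g) total_loss(g, mu, delta = mu, y), g0)
grad_rec_info   <- d_dg(function(g) rec_loss(g, mu, delta = mu + 0.3, y), g0)
c(rec_flat = grad_rec_flat, total_flat = grad_total_flat, rec_info = grad_rec_info)
```

When `delta` equals the baseline the reconstruction gradient is essentially zero and only the penalty (0.01) moves the gate; once `delta` differs from μ, reconstruction alone supplies a signal.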
3.5 Within-row cross-trait attention (B3, v0.9.3, opt-in)
The encoder above mixes all trait columns into a single hidden vector via a linear projection, which loses per-trait identity. When `use_trait_attention = TRUE`, the model additionally builds a per-trait token sequence

$$t_{ij} = \phi(x_{ij}) + e_j, \qquad j = 1, \dots, p,$$

with $\phi$ a learned embedding of the scalar cell value and $e_j$ a learnable positional embedding per trait column (dim `trait_embed_dim`, default 32). One multi-head self-attention block (`n_trait_heads`, default 2) mixes these $p$ tokens within each row $i$, followed by a mean-pool to a single `trait_embed_dim`-dim feature. That feature is concatenated alongside $[\,x \,\|\, \text{coords} \,\|\, \text{covs}\,]$ at the encoder input.
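A single-row, single-head sketch of the token-attention-pool pipeline. Assumptions: $\phi$ is a linear scalar-to-vector map, the mask token is replaced by 0, weights are random, and only one head is used (pigauto defaults to two). None of the names below are pigauto functions.

```r
set.seed(42)
p <- 4; d_t <- 3                               # 4 trait columns, token dim 3
x_row <- c(0.5, NA, -1.2, 2.0)                 # one row of latent traits
x_row[is.na(x_row)] <- 0                       # stand-in for the learnable mask token

w_val  <- rnorm(d_t)                           # phi: scalar -> d_t (assumed linear)
E      <- matrix(rnorm(p * d_t), p, d_t)       # positional embedding e_j per trait
tokens <- outer(x_row, w_val) + E              # p x d_t token sequence t_ij

# One self-attention head over the p tokens of this row
Wq <- matrix(rnorm(d_t^2), d_t)
Wk <- matrix(rnorm(d_t^2), d_t)
Wv <- matrix(rnorm(d_t^2), d_t)
scores <- (tokens %*% Wq) %*% t(tokens %*% Wk) / sqrt(d_t)
att <- exp(scores - apply(scores, 1, max))
att <- att / rowSums(att)                      # row-wise softmax: p x p
mixed <- att %*% (tokens %*% Wv)               # each token now sees every trait

trait_context <- colMeans(mixed)               # mean-pool to one d_t vector
trait_context
```

The positional embeddings are what keep trait identity after mixing: without `E`, two traits with the same cell value would produce identical tokens.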
The mechanism is intended for trait sets with strong within-row functional coupling that the joint MVN baseline cannot already capture (non-Gaussian, nonlinear, or non-monotone cross-trait structure). On the BIEN plant bench it did not improve pooled RMSE: the joint MVN / threshold-joint path already encodes the linear cross-trait covariance $\Sigma$ for the dominant trait types, so a second cross-trait mechanism is redundant there. The flag ships as opt-in because it may help on datasets where the linear-Gaussian assumption is too restrictive (e.g. functional traits with phase-transition or interaction structure).
Backward-compatible: saved fits without the field default to
use_trait_attention = FALSE and reconstruct
identically.
4. Calibration: making the gate a real safety floor
After training, pigauto does not ship the gate $g_j$ that the optimiser converged to. Instead it overrides it with a per-trait calibrated gate $r_{\text{cal},j}$ chosen on held-out validation cells.

`calibrate_gates()` runs a per-trait grid search over candidate values of $r$, minimising validation-set reconstruction loss (MSE for continuous; 0/1 loss for discrete, with an absolute cell floor and a split-validation cross-check that prevents the GNN from harming baseline accuracy). The output is a single scalar $r_{\text{cal},j}$ per latent column, stored on the fit and used at prediction time.

This is the second layer of safety: even if training pushes the learnable gate wide open, calibration on held-out cells can close it back down. In practice, on datasets with strong phylogenetic signal where the GNN cannot improve on BM, $r_{\text{cal},j} = 0$ and the prediction is exactly the baseline.
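A minimal sketch of the continuous-trait case of gate calibration. The grid (0 to 0.8 in steps of 0.05) is an assumption, and the discrete-trait safeguards (0/1 loss, cell floor, split-validation check) are omitted; `calibrate_gate_sketch` is a hypothetical name.

```r
# Hypothetical sketch of per-trait gate calibration for a continuous column.
# y_val, mu_val, delta_val: truth, baseline and GNN prediction on validation cells.
calibrate_gate_sketch <- function(y_val, mu_val, delta_val,
                                  grid = seq(0, 0.8, by = 0.05)) {
  loss <- vapply(grid, function(r) {
    mean(((1 - r) * mu_val + r * delta_val - y_val)^2)   # validation MSE of the blend
  }, numeric(1))
  grid[which.min(loss)]                                  # first minimiser wins ties
}

y  <- c(1, 2, 3, 4)
mu <- y + c(0.1, -0.1, 0.1, -0.1)                        # good baseline
r_bad  <- calibrate_gate_sketch(y, mu, delta_val = y + c(1, -1, 1, -1))  # poor GNN
r_good <- calibrate_gate_sketch(y, mu, delta_val = y)                    # perfect GNN
c(r_bad = r_bad, r_good = r_good)
```

Because 0 is in the grid, a GNN that only adds error is switched off entirely, while a genuinely better GNN is opened up to the top of the grid.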
5. Uncertainty quantification
pigauto exposes three distinct uncertainty mechanisms; they answer different questions and must not be conflated.
| Mechanism | Source | Validity |
|---|---|---|
| `pred$se` (continuous / count / ordinal / proportion) | Conditional-MVN SD from the BM baseline, delta-method back-transformed. | Exact under BM; model-dependent. |
| `pred$se` (binary / categorical) | A probability-based uncertainty score, not a Gaussian SE. | Use for ranking/reporting; do not plug into Rubin’s rules. |
| `pred$conformal_lower`, `pred$conformal_upper` | Split-conformal residual quantile on the validation set: $\hat q_j$, the $\lceil (n_{\text{val}} + 1)(1 - \alpha) \rceil$-th smallest absolute validation residual, with $\alpha = 0.05$. | Distribution-free 95% marginal coverage, assuming validation and prediction cells are exchangeable. |
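The split-conformal quantile takes only a few lines. This is the textbook absolute-residual version, offered as a sketch; pigauto’s internals may differ in details such as the scale on which residuals are taken. `conformal_q` is a hypothetical name and the residuals are invented.

```r
# Hypothetical sketch of the split-conformal residual quantile
conformal_q <- function(residuals, alpha = 0.05) {
  r <- sort(abs(residuals))
  n <- length(r)
  k <- ceiling((n + 1) * (1 - alpha))   # finite-sample corrected rank
  if (k > n) return(Inf)                # too few validation cells for this alpha
  r[k]
}

val_resid <- c(0.3, -0.1, 0.4, -0.6, 0.2, 0.05, -0.25, 0.15, -0.35, 0.5,
               -0.2, 0.1, 0.45, -0.05, 0.3, -0.4, 0.25, -0.15, 0.55)  # 19 invented residuals
q    <- conformal_q(val_resid)
pred <- 2.0
c(lower = pred - q, upper = pred + q)
```

With 19 validation cells and $\alpha = 0.05$ the quantile is the largest absolute residual; with 18 or fewer cells no finite quantile can certify 95% coverage, so the sketch returns `Inf`.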
For multiple-imputation workflows, `multi_impute()` exposes two draw methods. `"conformal"` (the default) samples missing cells around the point prediction with a spread set by the split-conformal quantile on the transformed scale, so the draws are calibrated against actual residuals. `"mc_dropout"` runs $M$ stochastic GNN passes in training mode (dropout active) on top of BM-draw inputs; its draws are wider than the conformal ones but reflect model uncertainty when the gate is fully open.
For pooled inference, `pool_mi()` applies Rubin’s (1987) rules to the $M$ downstream fits and returns a tidy data.frame with `estimate`, `std.error`, `df`, `fmi`, and `riv` per term.
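Rubin’s rules for one coefficient, as a sketch. It uses Rubin’s original large-sample degrees of freedom and a common `fmi` formula; `pool_mi()` may apply a small-sample correction (such as Barnard–Rubin) instead, so treat `rubin_pool` (a hypothetical name) as illustrative only.

```r
# Hypothetical sketch of Rubin's (1987) rules for a single coefficient.
# est, se: estimate and standard error from each of the M fits.
rubin_pool <- function(est, se) {
  M    <- length(est)
  qbar <- mean(est)                          # pooled estimate
  W    <- mean(se^2)                         # within-imputation variance
  B    <- var(est)                           # between-imputation variance
  Tvar <- W + (1 + 1 / M) * B                # total variance
  riv  <- (1 + 1 / M) * B / W                # relative increase in variance
  df   <- (M - 1) * (1 + 1 / riv)^2          # Rubin's original df
  fmi  <- (riv + 2 / (df + 3)) / (riv + 1)   # fraction of missing information
  c(estimate = qbar, std.error = sqrt(Tvar), df = df, riv = riv, fmi = fmi)
}

res <- rubin_pool(est = c(1, 2, 3), se = c(1, 1, 1))
res
```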
6. Tree uncertainty: a two-step workflow
When the species tree itself is uncertain, the Nakagawa & de Villemereuil (2019, Syst. Biol.) algorithm pools across $T$ posterior trees. `multi_impute_trees(traits, trees, m_per_tree = 5L)` performs step 1: a full pigauto fit per tree, producing $T \times$ `m_per_tree` completed datasets, each tagged with the tree index that produced it. Step 2 is the user’s responsibility: refit the downstream comparative model using the same tree that produced each dataset, then pool all $T \times$ `m_per_tree` fits via `pool_mi()`:
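```r
library(pigauto)

# Step 1: one pigauto fit per posterior tree, m_per_tree completed datasets each
mi <- multi_impute_trees(traits, trees, m_per_tree = 5L)

# Step 2: refit the downstream model on each dataset with the tree that produced it
fits <- Map(function(dat, t_idx) {
  dat$species <- rownames(dat)
  nlme::gls(y ~ x,
            correlation = ape::corBrownian(phy = trees[[t_idx]],
                                           form = ~species),
            data = dat, method = "ML")
}, mi$datasets, mi$tree_index)

# Pool all T x m_per_tree fits with Rubin's rules
pool_mi(fits)
```

Compute is linear in $T$. The 2019 paper’s relative-efficiency index typically converges before …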
7. Putting it together: a worked predictive equation
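A clean end-to-end summary of what `predict.pigauto_fit()` returns for one continuous trait $j$ on one missing cell $(i, j)$ in a single-tree, single-imputation call:

$$\hat y_{ij} = (1 - r_{\text{cal},j})\,\mu_{ij} + r_{\text{cal},j}\,\delta_{ij},$$

with $\mu_{ij}$ from the joint MVN (or per-column BM fallback), $\delta_{ij}$ from the GNN (transformer blocks + optional within-row cross-trait attention), and $r_{\text{cal},j}$ from validation calibration. The conformal interval is

$$\left[\,\hat y_{ij} - \hat q_j,\ \hat y_{ij} + \hat q_j\,\right],$$

where $\hat q_j$ is the split-conformal residual quantile for trait $j$, giving $1 - \alpha$ (95% by default) marginal coverage on the original scale.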
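A worked numeric instance of the predictive equation. All four inputs are invented, and the back-transform assumes a plain z-scored continuous trait with mean 10 and SD 2; log- or logit-scaled traits would map the interval endpoints through their own inverse transform instead.

```r
# Hypothetical numbers for one missing continuous cell (latent z-scale)
mu_ij    <- 1.20   # baseline posterior mean
delta_ij <- 1.80   # GNN prediction
r_cal    <- 0.25   # calibrated gate for trait j
q_j      <- 0.50   # split-conformal quantile for trait j

y_hat    <- (1 - r_cal) * mu_ij + r_cal * delta_ij   # 0.75 * 1.2 + 0.25 * 1.8
interval <- c(y_hat - q_j, y_hat + q_j)

# Back-transform from z-score to the original scale (assumed mean 10, sd 2)
orig_scale <- c(pred = 10 + 2 * y_hat,
                lower = 10 + 2 * interval[1],
                upper = 10 + 2 * interval[2])
orig_scale
```

Because the z-score back-transform is monotone, the interval endpoints can be transformed individually without losing the coverage guarantee.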
8. Where to look in the code
| Concept | File |
|---|---|
| BM kernel + conditional MVN | R/bm_internal.R |
| Joint MVN baseline | R/joint_mvn_baseline.R |
| Threshold-joint (binary + ordinal) | R/joint_threshold_baseline.R, R/liability.R |
| OVR categorical | R/ovr_categorical.R |
| Graph Transformer Block (B2) | R/graph_transformer_block.R |
| ResidualPhyloDAE + B3 trait attention | R/model_residual_dae.R |
| Training loop, gate calibration | R/fit_pigauto.R, R/fit_helpers.R |
| Prediction, conformal intervals | R/predict_pigauto.R |
| Multi-imputation pooling | R/multi_impute.R, R/pool_mi.R |
| Tree uncertainty workflow | R/multi_impute_trees.R |
References
- Felsenstein, J. (1985). Phylogenies and the comparative method. The American Naturalist.
- Pagel, M. (1999). Inferring the historical patterns of biological evolution. Nature.
- Bruggeman, J., Heringa, J., & Brandt, B. W. (2009). PhyloPars: estimation of missing parameter values using phylogeny. Nucleic Acids Research.
- Goolsby, E. W., Bruggeman, J., & Ané, C. (2017). Rphylopars: fast multivariate phylogenetic comparative methods for missing data and within-species variation. Methods in Ecology and Evolution.
- Wright, S. (1934). An analysis of variability in number of digits in an inbred strain of guinea pigs. Genetics. (Liability model.)
- Vaswani, A., et al. (2017). Attention is all you need. NeurIPS.
- Ying, C., et al. (2021). Do Transformers Really Perform Bad for Graph Representation? NeurIPS. (Graphormer; multi-scale phylogenetic attention bias.)
- Rubin, D. B. (1987). Multiple Imputation for Nonresponse in Surveys.
- Nakagawa, S., & Freckleton, R. P. (2008). Missing inaction: the dangers of ignoring missing data. Trends in Ecology & Evolution.
- Nakagawa, S., & Freckleton, R. P. (2011). Model averaging, missing data and multiple imputation. Behavioral Ecology and Sociobiology.
- Nakagawa, S., & de Villemereuil, P. (2019). A general method for simultaneously accounting for phylogenetic and species sampling uncertainty via Rubin’s rules in comparative analysis. Systematic Biology, 68(4), 632–641.
- Vovk, V., Gammerman, A., & Shafer, G. (2005). Algorithmic Learning in a Random World. (Conformal prediction.)