Runs a single forward pass through the fitted model and returns imputed trait values back-transformed to the original scale. Supports all trait types (continuous, binary, categorical, ordinal, count, proportion, zero-inflated count, multi-proportion) and MC dropout for multiple imputation (when n_imputations > 1). The fitted model is a gated ensemble of a phylogenetic baseline and a graph neural network correction; prediction is the per-trait blend (1 - r_cal) * baseline + r_cal * delta_GNN.
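
As a rough illustration of the per-trait blend, the sketch below applies the formula to small made-up matrices. The names baseline, delta_gnn and r_cal are placeholders chosen for this example; they are not claimed to be fields of a "pigauto_fit" object, and the gate is assumed to lie in [0, 1].

```r
# Hypothetical latent-scale inputs: 3 species x 2 traits (values are made up).
# matrix() fills column by column, so each row of numbers below is one trait.
baseline  <- matrix(c(0.2, -0.1, 0.5,
                      1.0,  0.8, 1.2), nrow = 3)
delta_gnn <- matrix(c(0.4,  0.0, 0.3,
                      0.9,  1.1, 1.0), nrow = 3)

# One calibrated gate per trait (assumed to lie in [0, 1]).
r_cal <- c(0, 0.25)

# Column j is blended as (1 - r_cal[j]) * baseline + r_cal[j] * delta_gnn.
blend <- sweep(baseline, 2, 1 - r_cal, `*`) + sweep(delta_gnn, 2, r_cal, `*`)

blend[, 1]  # gate is 0, so this column equals the baseline
blend[, 2]  # gate is 0.25, so this column moves a quarter of the way to delta_gnn
```

A gate of zero leaves the phylogenetic baseline untouched, which is why a trait with r_cal = 0 is imputed by the baseline alone.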

Usage

# S3 method for class 'pigauto_fit'
predict(
  object,
  newdata = NULL,
  return_se = TRUE,
  n_imputations = 1L,
  baseline_override = NULL,
  pool_method = c("median", "mean", "mode"),
  clamp_outliers = FALSE,
  clamp_factor = 5,
  match_observed = c("none", "pmm"),
  pmm_K = 5L,
  ...
)

Arguments

object

object of class "pigauto_fit".

newdata

NULL (use the training data) or a "pigauto_data" object for new species.

return_se

logical. Compute standard errors? (default TRUE).

n_imputations

integer. Number of stochastic imputation draws, each combining a BM posterior sample with GNN dropout (default 1L). Set to a larger value, e.g. 10 or 20, for proper multiple imputation with between-imputation variance.

baseline_override

optional list(mu, se) with the same shape as object$baseline. When supplied, predictions use this baseline instead of the one saved in the fit. Used internally by multi_impute_trees() to reuse a trained GNN across posterior trees. Most users can ignore this. Default NULL (use the fit's own baseline).

pool_method

character. How to pool the M decoded draws when n_imputations > 1. "median" (default) takes the per-cell median for count / proportion / zi_count magnitude traits, which is robust to single dropout-noisy draws amplified by non-linear decoders (expm1 / plogis). "mean" restores the pre-v0.9.2 arithmetic-mean pooling. Pooling for continuous / ordinal / binary / categorical traits is unchanged (linear / probability-averaged). See issue #40.
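
The motivation for median pooling can be seen in a toy example. The sketch below uses made-up latent draws for a single count cell and skips the z-score step for brevity; it is not pigauto's pooling code, only an illustration of how one noisy draw behaves after a non-linear decoder.

```r
# Made-up latent draws (log1p scale) for one missing count cell, M = 5.
# The last draw stands in for a single dropout-noisy pass.
latent_draws <- c(2.0, 2.1, 1.9, 2.0, 5.0)

# Non-linear decoder for counts: expm1() inverts log1p().
decoded <- expm1(latent_draws)

pooled_mean   <- mean(decoded)    # pulled upward by the one extreme draw
pooled_median <- median(decoded)  # stays with the bulk of the draws

round(c(mean = pooled_mean, median = pooled_median), 2)
```

Because expm1() stretches the upper tail, a single outlying latent draw dominates the arithmetic mean but leaves the median unchanged.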

clamp_outliers

logical. Phase G (v0.9.1.9011+). When TRUE, post-back-transform predictions for log-transformed continuous, count, and zi_count magnitude traits are capped at tm$obs_max * clamp_factor (tm$obs_max is the observed maximum on the original scale, set at preprocess time). Targets tail-extrapolation modes amplified by exp() / expm1() back-transforms. Default FALSE preserves v0.9.1 behaviour exactly.

clamp_factor

numeric scalar (>= 1). Multiplicative factor on the observed maximum used by clamp_outliers. Default 5. Ignored when clamp_outliers = FALSE.
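
The capping rule described for clamp_outliers and clamp_factor amounts to an element-wise minimum. A minimal sketch with made-up predictions follows; obs_max here is a plain variable standing in for tm$obs_max, not a value read from a real fit.

```r
# Hypothetical back-transformed predictions for one count trait.
pred <- c(12, 40, 95, 2600)

obs_max      <- 100  # stand-in for tm$obs_max (observed maximum, original scale)
clamp_factor <- 5    # the documented default

# Cap every prediction at obs_max * clamp_factor; smaller values pass through.
cap     <- obs_max * clamp_factor
clamped <- pmin(pred, cap)

clamped
```

Only the last value exceeds the cap of 500, so it is the only one changed.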

match_observed

character, one of c("none", "pmm"). Phase G' (v0.9.1.9012+). When "pmm", uses Predictive Mean Matching (Little 1988; van Buuren's mice package) on the at-risk types (log-transformed continuous, count, zi_count magnitude, proportion). For each missing cell, it finds the pmm_K observed cells whose own predictions are closest to the missing cell's prediction, samples one of them, and returns its observed value. Imputed values are therefore guaranteed to lie within the observed data range, so no extrapolation is possible by construction.

When to use: PMM is a niche feature in pigauto. The package already provides conformal prediction intervals (calibrated against held-out residuals) and multi_impute(draws_method = "conformal") for multi-imputation workflows – those give Rubin's-rules honest standard errors without donor-mismatch noise. PMM is only worth enabling for: (a) methodological comparison against mice / equivalent packages, or (b) workflows that specifically require imputed values to come from the observed data pool. For tail safety on single-imputation point estimates, prefer clamp_outliers = TRUE. For honest MI standard errors, prefer multi_impute(draws_method = "conformal").

The Phase G' acceptance bench (useful/MEMO_2026-05-01_phase_g_prime_results.md) confirmed PMM does not strictly improve point-estimate RMSE over the no-PMM default: it wins on extrapolating cells (e.g. AVONET Casuarius) but loses on cells where the GNN's prediction is already accurate (donor-mismatch noise).

For discrete-class types (binary / categorical / ordinal / multi_proportion) and continuous traits that are not log-transformed, match_observed is a no-op. Default "none" preserves v0.9.1.9011 behaviour exactly.
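
The donor-matching step of PMM can be sketched for a single trait as follows. This is a generic illustration of the technique rather than pigauto's internal implementation; the function name pmm_one_trait and all values are hypothetical.

```r
# Minimal PMM sketch for one trait (generic technique, hypothetical helper).
pmm_one_trait <- function(pred_obs, y_obs, pred_mis, K = 5L) {
  K <- min(K, length(y_obs))
  vapply(pred_mis, function(p) {
    # The K observed cells whose own predictions are closest to p.
    donors <- order(abs(pred_obs - p))[seq_len(K)]
    # Sample one donor and return its observed value.
    y_obs[donors[sample.int(K, 1L)]]
  }, numeric(1))
}

set.seed(1)
y_obs    <- c(3, 5, 8, 13, 21, 34)              # observed values (made up)
pred_obs <- c(3.5, 4.8, 9.1, 12.0, 20.2, 30.0)  # predictions for those cells
pred_mis <- c(4.0, 250)                         # predictions for two missing cells

imputed <- pmm_one_trait(pred_obs, y_obs, pred_mis, K = 3L)
imputed
```

The second missing cell has an extreme prediction of 250, yet its imputed value is drawn from the three largest observed values, which is the sense in which PMM cannot extrapolate.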

pmm_K

integer scalar (>= 1). Donor pool size for PMM. Default 5L (mice convention). Ignored when match_observed = "none".

...

ignored.

Value

A list of class "pigauto_pred" with:

imputed

data.frame of imputed values in original scale with proper R types (numeric, integer, factor, ordered).

imputed_latent

Numeric matrix (n x p_latent) of predictions in latent scale.

se

Numeric matrix (n x n_original_traits) of per-cell uncertainty. Continuous/count/ordinal/proportion: SE in original scale (BM conditional SD, delta-method back-transformed). Binary: min(p, 1-p), the probability of being wrong (0 = certain, 0.5 = maximally uncertain); not a Gaussian SE. Categorical: 1 - max(p_k), the margin from certainty; not a Gaussian SE. NULL if return_se = FALSE.

probabilities

Named list. Binary traits: numeric probability vector. Categorical traits: n x K probability matrix. Other types: not present.

imputed_datasets

List of M data.frames when n_imputations > 1; NULL otherwise.

trait_map

Trait map from the fitted model.

species_names

Character vector.

trait_names

Character vector.

n_imputations

Integer, number of imputations performed.
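
The binary and categorical uncertainty measures listed under se are simple functions of the predicted probabilities. The sketch below computes them from made-up probabilities; the variable names are illustrative and do not correspond to elements of a "pigauto_pred" object.

```r
# Binary trait: predicted probabilities for three species (made up).
p <- c(0.02, 0.50, 0.90)
binary_uncertainty <- pmin(p, 1 - p)   # 0 = certain, 0.5 = maximally uncertain

# Categorical trait: n x K probability matrix, rows sum to 1 (made up).
probs <- rbind(c(0.90, 0.05, 0.05),
               c(0.40, 0.35, 0.25))
categorical_uncertainty <- 1 - apply(probs, 1, max)  # margin from certainty

binary_uncertainty
categorical_uncertainty
```

Neither quantity is a standard error in the Gaussian sense, so they should not be used to build normal-theory confidence intervals.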

Details

When n_imputations > 1, each imputation m draws a BM posterior sample t_BM_draw on the latent scale for originally-missing cells, normally distributed with mean BM_mu and standard deviation BM_se (BM_se = 0 for observed cells, so they are never perturbed). The model runs in train mode (GNN dropout active) using t_BM_draw as input. The final blend is (1 - r_cal) * t_BM_draw + r_cal * GNN_delta(t_BM_draw): when the calibrated gate is zero the imputation is a pure BM posterior draw; when r_cal > 0 both BM draws and GNN dropout contribute variance.

Point estimates are the mean (continuous) or mode (binary, categorical, ordinal) across passes; count, proportion and zi_count magnitude traits are pooled according to pool_method (median by default). The M complete datasets are returned in imputed_datasets for Rubin's-rules pooling. For the user-facing multiple-imputation workflow, prefer multi_impute(), which offers draws_method = "conformal" (calibrated, narrower) or "mc_dropout" (BM posterior draws + GNN dropout, wider).
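
The draw-and-blend step described above can be sketched on the latent scale for one trait. The GNN is replaced by a placeholder function that adds small noise to mimic dropout, so this only illustrates the structure of a single imputation draw; gnn_delta, one_draw and all numbers are assumptions made for the example.

```r
set.seed(42)

bm_mu <- c(0.3, -1.2, 0.8, 0.0)   # latent-scale BM means (made up)
bm_se <- c(0.0,  0.5, 0.0, 0.7)   # 0 marks observed cells

# Placeholder for the GNN pass in train mode: input plus small noise
# mimicking dropout. The real network is not reproduced here.
gnn_delta <- function(x) x + rnorm(length(x), sd = 0.05)

one_draw <- function(r_cal) {
  # rnorm() with sd = 0 returns the mean, so observed cells stay fixed.
  t_bm_draw <- rnorm(length(bm_mu), mean = bm_mu, sd = bm_se)
  blend <- (1 - r_cal) * t_bm_draw + r_cal * gnn_delta(t_bm_draw)
  list(t_bm_draw = t_bm_draw, blend = blend)
}

pure_bm <- one_draw(r_cal = 0)    # gate closed: blend equals the BM draw
mixed   <- one_draw(r_cal = 0.3)  # gate open: dropout noise also contributes
```

With the gate closed, the only source of between-imputation variance is the BM posterior draw on originally-missing cells.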

Decoding per type:

continuous

reverse z-score, then exp() if log-transformed

binary

sigmoid(latent) to probability, round to 0/1

count

reverse z-score of log1p, expm1(), round, clip >= 0

ordinal

reverse z-score, round to nearest valid integer level

categorical

softmax() over K latent columns, argmax
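
The decoding rules listed above can be written as small helper functions. This is a sketch under stated assumptions: the z-score parameters mu and sigma are made up, ordinal levels are assumed to be coded 1..n_levels, and the decode_* names are hypothetical rather than functions exported by the package.

```r
# Hypothetical per-trait z-score parameters saved at preprocess time.
mu <- 1.5; sigma <- 0.5

# continuous: reverse z-score, then exp() if the trait was log-transformed
decode_continuous <- function(z, log_transformed = FALSE) {
  x <- z * sigma + mu
  if (log_transformed) exp(x) else x
}

# binary: sigmoid to probability, then round to 0/1
decode_binary <- function(latent) as.integer(round(plogis(latent)))

# count: reverse z-score of log1p, expm1(), round, clip at 0
decode_count <- function(z) pmax(round(expm1(z * sigma + mu)), 0)

# ordinal: reverse z-score, round, keep within the valid levels 1..n_levels
decode_ordinal <- function(z, n_levels) {
  pmin(pmax(round(z * sigma + mu), 1), n_levels)
}

# categorical: softmax over K latent columns, then argmax per row
decode_categorical <- function(latent) {
  e <- exp(latent - apply(latent, 1, max))  # subtract row max for stability
  probs <- e / rowSums(e)
  max.col(probs, ties.method = "first")
}

decode_count(c(-10, 0, 1))
decode_categorical(rbind(c(0.1, 2, -1), c(3, 0, 0)))
```

The clip in decode_count() matters for strongly negative latent values, where expm1() drops below zero before rounding.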

Examples

if (FALSE) { # \dontrun{
pred <- predict(fit, return_se = TRUE)
pred$imputed        # data.frame, original scale
pred$se             # matrix, uncertainty
pred$probabilities  # list of prob vectors/matrices

# Multiple imputation (BM posterior draws + GNN dropout)
pred10 <- predict(fit, n_imputations = 10)
pred10$imputed_datasets  # 10 complete data.frames
} # }
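
Once an analysis has been run on each element of imputed_datasets, the per-imputation results can be combined with Rubin's rules. The sketch below uses made-up estimates and variances in place of real model output, so it shows only the pooling arithmetic; q_hat and u_hat are illustrative names.

```r
# Per-imputation estimates and their squared standard errors from M = 5
# analyses (made-up numbers standing in for, e.g., a regression slope
# fitted to each imputed dataset).
q_hat <- c(0.52, 0.47, 0.55, 0.50, 0.46)
u_hat <- c(0.010, 0.012, 0.011, 0.009, 0.013)
M <- length(q_hat)

q_bar   <- mean(q_hat)               # pooled point estimate
u_bar   <- mean(u_hat)               # within-imputation variance
b       <- var(q_hat)                # between-imputation variance
t_total <- u_bar + (1 + 1 / M) * b   # total variance (Rubin's rules)

c(estimate = q_bar, se = sqrt(t_total))
```

The between-imputation term is what single imputation omits, which is why pooled standard errors from multiple imputation are typically wider.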