The Variables DAG#

Module: leaspy.variables.dag

document: leaspy.variables.dag.VariablesDAG

Why does Leaspy need a DAG?#

A Leaspy model is not a simple function \(f(x; \theta)\). It has many variables of different kinds (data, parameters, latent variables, derived quantities), and they form a dependency graph: some variables can only be computed once others are known.

For example, in the logistic model:

  • You cannot compute the reparametrized time rt until you know the patient’s time-shift tau and acceleration alpha.

  • You cannot compute alpha until you sample the latent variable xi.

  • You cannot compute the model output until you have rt, the metric, and the population parameter g.

The VariablesDAG is the data structure that encodes all of these relationships. It answers two critical questions at runtime:

  1. Forward propagation: “Given that I just assigned a value to variable X, which downstream variables need to be (re)computed?”

  2. Classification: “Which variables are parameters? Which are individual latent variables? Which are derived?”

Without it, every algorithm would need model-specific hard-coded logic. With it, the MCMC-SAEM algorithm can remain generic: it proposes a new value for a latent variable, and the DAG ensures consistency propagates automatically.

What is a DAG, concretely?#

A Directed Acyclic Graph is a set of nodes (variables) connected by directed edges (dependencies), with no cycles.

  • Directed: each edge has a direction: “A is needed to compute B” (A → B).

  • Acyclic: there is no circular chain like A → B → C → A.
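As an illustration only, the two properties can be checked with Python's standard-library `graphlib`. The graph below is a hand-written toy subset of the logistic model (an assumption made for this sketch, not something exported by Leaspy), and `is_acyclic` is a hypothetical helper.

```python
from graphlib import CycleError, TopologicalSorter

# Toy subset of the logistic model: each key maps to the set of
# variables it directly depends on (its predecessors).
direct_ancestors = {
    "xi": set(),
    "tau": set(),
    "t": set(),
    "alpha": {"xi"},
    "rt": {"t", "tau", "alpha"},
}

# static_order() yields every node after all of its predecessors.
order = list(TopologicalSorter(direct_ancestors).static_order())
position = {name: index for index, name in enumerate(order)}


def is_acyclic(graph):
    """Return True if the predecessor mapping contains no cycle."""
    try:
        list(TopologicalSorter(graph).static_order())
        return True
    except CycleError:
        return False


# A circular chain A -> B -> C -> A is rejected.
cyclic = {"a": {"c"}, "b": {"a"}, "c": {"b"}}
```

Whatever order the sorter picks among independent roots, `xi` always precedes `alpha`, and `alpha` always precedes `rt`.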

In code, VariablesDAG is a frozen dataclass that stores:

| Attribute | What it holds |
| --- | --- |
| variables | A mapping from variable name → variable specification object |
| direct_ancestors | For each variable, the set of variables it directly depends on |
| direct_children | (precomputed) For each variable, the set of variables that depend on it |
| sorted_variables_names | All variable names in topological order (roots first, leaves last) |
| sorted_children | For each variable, all downstream descendants (transitive closure) |
| sorted_ancestors | For each variable, all upstream ancestors (transitive closure) |
| sorted_variables_by_type | Variables grouped by their Python type (ModelParameter, IndividualLatentVariable, etc.) |

The topological sort guarantees that when variables are computed in sorted order (roots first), every variable’s inputs are already available.
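The idea of a frozen dataclass that precomputes derived attributes can be sketched as follows. `ToyDAG` is a hypothetical stand-in written for this page, not Leaspy's implementation: it stores plain string labels where the real class stores specification objects, and it derives only two of the attributes listed above.

```python
from dataclasses import FrozenInstanceError, dataclass, field


@dataclass(frozen=True)
class ToyDAG:
    """Hypothetical, simplified stand-in for VariablesDAG."""

    variables: dict          # name -> type label (a plain string here)
    direct_ancestors: dict   # name -> set of names it depends on
    direct_children: dict = field(init=False)
    sorted_variables_by_type: dict = field(init=False)

    def __post_init__(self):
        # Invert the "depends on" relation to get "is needed by".
        children = {name: set() for name in self.variables}
        for name, parents in self.direct_ancestors.items():
            for parent in parents:
                children[parent].add(name)
        # Group variable names by their type label.
        by_type = {}
        for name, kind in self.variables.items():
            by_type.setdefault(kind, []).append(name)
        # A frozen dataclass blocks normal assignment, so go through object.
        object.__setattr__(
            self, "direct_children",
            {name: frozenset(kids) for name, kids in children.items()},
        )
        object.__setattr__(self, "sorted_variables_by_type", by_type)


dag = ToyDAG(
    variables={
        "t": "input_data",
        "xi": "individual_latent",
        "tau": "individual_latent",
        "alpha": "linked",
        "rt": "linked",
    },
    direct_ancestors={"alpha": {"xi"}, "rt": {"t", "tau", "alpha"}},
)

try:
    dag.variables = {}
    was_mutable = True
except FrozenInstanceError:
    was_mutable = False
```

Freezing the instance matches the "static blueprint" role of the DAG: once built, the graph structure cannot be reassigned by accident.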

The Logistic Model’s DAG#

Below is the complete dependency graph for a multivariate LogisticModel with sources. Each node is a variable in the model; each arrow means “this variable is needed to compute that one”. The graph is organized into three conceptual sections — Temporal Variability, Geometrical Model, and Spatial Variability — that merge at the top into the observation model and the negative log-likelihood.

[Figure: DAG of the multivariate logistic model]

Legend (variable types):

  • Input data — observed values provided by the dataset

  • Observational Model — likelihood computation

  • Linked / Derived — deterministically computed from parents

  • Individual latent variables \(z_i\) — sampled per patient (E-step)

  • Population latent variables \(z_{pop}\) — sampled at population level (E-step)

  • Model parameters \(\theta\) — estimated during optimization (M-step)

  • Hyperparameters — fixed priors, not learned

Section-by-section breakdown#

The breakdown below isolates a section of the diagram, shows its sub-graph, and explains every variable.

Temporal Variability governs when each patient’s disease trajectory is positioned on the time axis. It introduces two individual latent variables — a time-shift \(\tau_i\) and an acceleration factor \(\xi_i\) — that together define a patient-specific time reparametrization.

[Figure: DAG of temporal variability]

Variables#

| Variable | Type | Description | Code name | Origin / Definition |
| --- | --- | --- | --- | --- |
| \(\overline{\xi}\) | Hyperparameter | Mean of the acceleration factor distribution. Fixed at 0.0, so the prior median of \(\alpha_i = \exp(\xi_i)\) is 1 (meaning “average speed”). | xi_mean | RiemanianManifoldModel.get_variables_specs() |
| \(\sigma_\xi\) | Model parameter | Standard deviation of \(\xi_i\). Estimated during the M-step; controls how much acceleration varies across patients. | xi_std | TimeReparametrizedModel.get_variables_specs() |
| \(\xi_i\) | Individual latent | Acceleration factor (log-scale) for patient \(i\). Sampled from \(\mathcal{N}(\overline{\xi},\; \sigma_\xi^2)\). | xi | TimeReparametrizedModel.get_variables_specs() |
| \(\alpha_i\) | Linked | Individual acceleration: \(\alpha_i = \exp(\xi_i)\). Deterministic transform. | alpha | TimeReparametrizedModel.get_variables_specs() |
| \(\overline{\tau}\) | Model parameter | Population mean of the time-shift. Estimated; represents the “average age of onset”. | tau_mean | TimeReparametrizedModel.get_variables_specs() |
| \(\sigma_\tau\) | Model parameter | Standard deviation of \(\tau_i\). Estimated; controls how spread out disease onset ages are. | tau_std | TimeReparametrizedModel.get_variables_specs() |
| \(\tau_i\) | Individual latent | Time-shift for patient \(i\). Sampled from \(\mathcal{N}(\overline{\tau},\; \sigma_\tau^2)\). | tau | TimeReparametrizedModel.get_variables_specs() |
| \(t_{ij}\) | Input data | Observed timepoints (visit ages). | t | McmcSaemCompatibleModel.get_variables_specs() |
| \(\psi_i(t_{ij})\) | Linked | Reparametrized time: \(\psi_i(t_{ij}) = \alpha_i \cdot (t_{ij} - \tau_i)\). This is the patient-specific “disease clock”. | rt | Origin: TimeReparametrizedModel.get_variables_specs(). Definition: TimeReparametrizedModel.time_reparametrization |
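To make the two linked variables concrete, here is a small numerical sketch of \(\alpha_i = \exp(\xi_i)\) and \(\psi_i(t_{ij}) = \alpha_i \cdot (t_{ij} - \tau_i)\). The function `reparametrize_time` and the patient values are hypothetical and chosen for readability; the sketch uses plain Python floats rather than the library's own data structures.

```python
import math


def reparametrize_time(t, tau, alpha):
    """Apply psi_i(t_ij) = alpha_i * (t_ij - tau_i) to each visit age."""
    return [alpha * (t_ij - tau) for t_ij in t]


# Hypothetical patient: progresses about twice as fast as average,
# with a time-shift of 70 years.
xi = math.log(2.0)
tau = 70.0
alpha = math.exp(xi)           # linked variable, recomputed from xi

visits = [68.0, 70.0, 73.0]    # observed ages t_ij
rt = reparametrize_time(visits, tau, alpha)
```

For this patient the disease clock reads roughly -4, 0 and 6: visits before \(\tau_i\) map to negative reparametrized times, and the three years after \(\tau_i\) count double.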

Propagation example#

When the MCMC algorithm proposes a new value for \(\tau_i\):

  • \(\psi_i(t_{ij})\) must be recomputed (depends on \(\tau_i\))

  • \(\gamma_k(\psi_i(t_{ij}))\) must be recomputed (depends on \(\psi_i\))

  • \(\eta_k\) must be recomputed (depends on \(\gamma_k\))

  • The likelihood must be recomputed (depends on \(\eta_k\))

  • But \(g_k\), \(v_k\), \(w_{ik}\) remain unchanged — they are on different branches

The DAG encodes exactly this: sorted_children["tau"] returns the transitive closure of all downstream nodes, and the State invalidates their cached values.
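The propagation above can be reproduced on a toy graph. The edges below are a simplified assumption based on the bullets (the intermediate \(\gamma_k\) node is folded into `model`), and `downstream` is a hypothetical helper rather than a Leaspy function.

```python
from graphlib import TopologicalSorter

# Simplified toy graph: each key maps to the variables it depends on.
direct_ancestors = {
    "alpha": {"xi"},
    "rt": {"t", "tau", "alpha"},
    "model": {"rt", "g", "v0", "space_shifts"},
    "nll_attach": {"y", "model"},
}


def downstream(graph, start):
    """Return every descendant of `start`, listed in topological order."""
    # Invert the predecessor mapping to get direct children.
    children = {}
    for node, parents in graph.items():
        for parent in parents:
            children.setdefault(parent, set()).add(node)
    # Depth-first walk to collect the transitive closure.
    seen, stack = set(), [start]
    while stack:
        for child in children.get(stack.pop(), ()):
            if child not in seen:
                seen.add(child)
                stack.append(child)
    # Report the descendants in an order that respects dependencies.
    order = TopologicalSorter(graph).static_order()
    return [name for name in order if name in seen]


affected = downstream(direct_ancestors, "tau")
```

Here `affected` contains `rt`, `model` and `nll_attach`, in that order, while `g`, `v0` and `space_shifts` are absent because they sit on other branches.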

How VariablesDAG is built#

Each model class in the inheritance chain contributes its own variables via get_variables_specs(). These contributions are accumulated using super() — each class calls super().get_variables_specs() first, then adds its own variables on top:

| Class | Contributes (mathematical notation → code name) |
| --- | --- |
| McmcSaemCompatibleModel | \(t_{ij}\) (t), \(y_{ijk}\) (y), \(-\log p(...)\) (nll_attach) |
| TimeReparametrizedModel | \(\tau_i\) (tau), \(\xi_i\) (xi), \(\alpha_i\) (alpha), \(\psi_i\) (rt), and if multivariate: \(s_{il}\) (sources), \(\beta_{ml}\) (betas), \(w_{ik}\) (space_shifts), \(A\) (mixing_matrix) |
| RiemanianManifoldModel | \(\log(v_k)\) (log_v0), \(v_k\) (v0), metric, \(B\) (orthonormal_basis), \(\eta_k\) (model) |
| LogisticModel | \(\log(g_k)\) (log_g), \(g_k\) (g), logistic-specific metric formula |

Once all specs are collected into a single dictionary, VariablesDAG.from_dict(specs) analyzes the function signatures of LinkedVariable callables to infer edges automatically. For example, LinkedVariable(Exp("log_g")) tells the DAG: “I depend on log_g”.
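The accumulation and edge-inference pattern can be sketched with the standard `inspect` module. Everything here is hypothetical: `BaseToyModel`, `TimeReparamToyModel` and `infer_direct_ancestors` are invented for illustration, plain callables stand in for `LinkedVariable` objects, and the sketch makes no claim about how `VariablesDAG.from_dict` is actually implemented.

```python
import inspect
import math


class BaseToyModel:
    def get_variables_specs(self):
        # Root variables carry no callable, hence no ancestors.
        return {"t": None}


class TimeReparamToyModel(BaseToyModel):
    def get_variables_specs(self):
        # Start from the parent's variables, then add this class's own.
        specs = super().get_variables_specs()
        specs.update(
            xi=None,
            tau=None,
            alpha=lambda *, xi: math.exp(xi),
            rt=lambda *, t, tau, alpha: alpha * (t - tau),
        )
        return specs


def infer_direct_ancestors(specs):
    """Read each callable's parameter names as the names of its parents."""
    return {
        name: frozenset(inspect.signature(fn).parameters)
        if callable(fn)
        else frozenset()
        for name, fn in specs.items()
    }


specs = TimeReparamToyModel().get_variables_specs()
ancestors = infer_direct_ancestors(specs)
```

Because the parameter names double as variable names, adding a new derived variable only requires writing its function; no edge has to be declared by hand.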

Relation to State#

The VariablesDAG is a static blueprint: it describes what variables exist and how they relate. It does not hold values.

The State object holds the runtime values for each variable. It keeps a reference to the DAG so it can propagate updates correctly. When you write state["tau"] = new_value, the State uses the DAG’s sorted_children["tau"] to invalidate all downstream caches (rt, model, nll_attach).
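A minimal sketch of this invalidate-on-write behaviour, assuming a hypothetical `ToyState` class with lazy recomputation (Leaspy's actual State has its own API and is not reproduced here):

```python
class ToyState:
    """Hypothetical value store that drops stale downstream caches."""

    def __init__(self, sorted_children, compute):
        self.sorted_children = sorted_children  # name -> downstream names
        self.compute = compute                  # name -> (function, input names)
        self.values = {}

    def __setitem__(self, name, value):
        self.values[name] = value
        # Any cached value downstream of `name` is now out of date.
        for child in self.sorted_children.get(name, ()):
            self.values.pop(child, None)

    def __getitem__(self, name):
        # Recompute lazily when the cached value is missing.
        if name not in self.values:
            fn, inputs = self.compute[name]
            self.values[name] = fn(*(self[i] for i in inputs))
        return self.values[name]


state = ToyState(
    sorted_children={"tau": ["rt"], "alpha": ["rt"], "t": ["rt"]},
    compute={
        "rt": (lambda alpha, t, tau: alpha * (t - tau), ("alpha", "t", "tau")),
    },
)
state["alpha"], state["t"], state["tau"] = 2.0, 73.0, 70.0
first = state["rt"]    # computed and cached

state["tau"] = 71.0    # invalidates the cached rt
second = state["rt"]   # recomputed with the new tau
```

The caller never asks for a recomputation explicitly: assigning `tau` is enough, which is the property that lets a sampler stay model-agnostic.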

See StatefulModel for how the State is created and managed.