# The Variables DAG

**Module:** `leaspy.variables.dag` `document: leaspy.variables.dag.VariablesDAG`

## Why does Leaspy need a DAG?

A Leaspy model is not a simple function $f(x; \theta)$. It has **many variables** of different natures (data, parameters, latent variables, derived quantities), and they form a **dependency chain**: some variables can only be computed once others are known.

For example, in the logistic model:

- You cannot compute the **reparametrized time** `rt` until you know the patient's time-shift `tau` and acceleration `alpha`.
- You cannot compute `alpha` until you sample the latent variable `xi`.
- You cannot compute the **model output** until you have `rt`, the metric, and the population parameter `g`.

The **VariablesDAG** is the data structure that encodes all of these relationships. It answers two critical questions at runtime:

1. **Forward propagation**: "Given that I just assigned a value to variable X, which downstream variables need to be (re)computed?"
2. **Classification**: "Which variables are parameters? Which are individual latent variables? Which are derived?"

Without it, every algorithm would need model-specific hard-coded logic. With it, the MCMC-SAEM algorithm can remain **generic**: it proposes a new value for a latent variable, and the DAG ensures consistency propagates automatically.

## What is a DAG, concretely?

A **Directed Acyclic Graph** is a set of **nodes** (variables) connected by **directed edges** (dependencies), with no cycles.

- **Directed**: each edge has a direction: "A is needed to compute B" (A → B).
- **Acyclic**: there is no circular chain like A → B → C → A.

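The definition above can be made concrete with a small, self-contained sketch. It does not use Leaspy: the `direct_ancestors` dictionary, its simplified edge list, and the `descendants` helper are illustrative assumptions that borrow variable names from the logistic-model example. The standard-library `graphlib.TopologicalSorter` provides the topological order and rejects cyclic graphs.

```python
from graphlib import CycleError, TopologicalSorter

# Toy dependency map (assumption, simplified from the logistic model):
# each variable maps to the set of variables it directly depends on.
direct_ancestors = {
    "xi": set(),
    "tau": set(),
    "t": set(),
    "g": set(),
    "alpha": {"xi"},
    "rt": {"t", "tau", "alpha"},
    "model": {"rt", "g"},
}

# Topological order: every variable appears after all of its inputs.
order = list(TopologicalSorter(direct_ancestors).static_order())


def descendants(name: str) -> set:
    """Return every variable that depends on `name`, directly or transitively."""
    found = set()
    frontier = [name]
    while frontier:
        current = frontier.pop()
        for child, parents in direct_ancestors.items():
            if current in parents and child not in found:
                found.add(child)
                frontier.append(child)
    return found


# A circular chain A -> B -> C -> A is not a DAG and is rejected.
try:
    list(TopologicalSorter({"A": {"C"}, "B": {"A"}, "C": {"B"}}).static_order())
    has_cycle = False
except CycleError:
    has_cycle = True

print(order)
print(sorted(descendants("xi")))  # ['alpha', 'model', 'rt']
print(sorted(descendants("g")))   # ['model']
print(has_cycle)                  # True
```

The two questions a DAG answers cheaply are visible here: `order` gives a safe evaluation sequence, and `descendants` gives the set of values to recompute after a change.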
In code, `VariablesDAG` is a frozen dataclass that stores:

| Attribute | What it holds |
|---|---|
| `variables` | A mapping from variable name → variable specification object |
| `direct_ancestors` | For each variable, the set of variables it **directly depends on** |
| `direct_children` | (precomputed) For each variable, the set of variables that depend on it |
| `sorted_variables_names` | All variable names in **topological order** (roots first, leaves last) |
| `sorted_children` | For each variable, **all** downstream descendants (transitive closure) |
| `sorted_ancestors` | For each variable, **all** upstream ancestors (transitive closure) |
| `sorted_variables_by_type` | Variables grouped by their Python type (`ModelParameter`, `IndividualLatentVariable`, etc.) |

The topological sort guarantees that when computing values top-to-bottom, every variable's inputs are already available.

## The Logistic Model's DAG

Below is the complete dependency graph for a multivariate `LogisticModel` with sources. Each node is a variable in the model; each arrow means "this variable is needed to compute that one". The graph is organized into three conceptual sections (*Temporal Variability*, *Geometrical Model*, and *Spatial Variability*) that merge at the top into the observation model and the negative log-likelihood.

```{image} ../../../_static/images/DAG_Multivariate.drawio.png
:alt: DAG of the multivariate logistic model
:align: center
:class: dag-zoomable
:width: 100%
```

**Legend** (variable types):

- **Input data**: observed values provided by the dataset
- **Observational Model**: likelihood computation
- **Linked / Derived**: deterministically computed from parents
- **Individual latent variables** $z_i$: sampled per patient (E-step)
- **Population latent variables** $z_{pop}$: sampled at population level (E-step)
- **Model parameters** $\theta$: estimated during optimization (M-step)
- **Hyperparameters**: fixed priors, not learned

### Section-by-section breakdown

Each tab below isolates one section of the diagram, shows its sub-graph, and explains every variable.

`````{tabs}

````{tab} Temporal Variability

**Temporal Variability** governs *when* each patient's disease trajectory is positioned on the time axis. It introduces two individual latent variables, a time-shift $\tau_i$ and an acceleration factor $\xi_i$, that together define a patient-specific time reparametrization.

```{image} ../../../_static/images/dag_temporal.png
:alt: DAG of temporal variability
:align: center
:width: 100%
```

```{table} Variables
:class: dag-var-table

| Variable & Type | Description |
|---|---|
| $\overline{\xi}$ (Hyperparameter) | Mean of the acceleration factor distribution. Fixed at 0.0, so the prior median of $\alpha_i = \exp(\xi_i)$ is 1 (meaning "average speed"). Code name: `xi_mean`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| $\sigma_\xi$ (Model parameter) | Standard deviation of $\xi_i$. **Estimated** during the M-step; controls how much acceleration varies across patients. Code name: `xi_std`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\xi_i$ (Individual latent) | Acceleration factor (log-scale) for patient $i$. Sampled from $\mathcal{N}(\overline{\xi},\; \sigma_\xi^2)$. Code name: `xi`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\alpha_i$ (Linked) | Individual acceleration: $\alpha_i = \exp(\xi_i)$. Deterministic transform. Code name: `alpha`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\overline{\tau}$ (Model parameter) | Population mean of the time-shift. **Estimated**; represents the "average age of onset". Code name: `tau_mean`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\sigma_\tau$ (Model parameter) | Standard deviation of $\tau_i$. **Estimated**; controls how spread out disease onset ages are. Code name: `tau_std`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\tau_i$ (Individual latent) | Time-shift for patient $i$. Sampled from $\mathcal{N}(\overline{\tau},\; \sigma_\tau^2)$. Code name: `tau`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $t_{ij}$ (Input data) | Observed timepoints (visit ages). Code name: `t`. *Origin/Definition: `McmcSaemCompatibleModel.get_variables_specs()`* |
| $\psi_i(t_{ij})$ (Linked) | **Reparametrized time**: $\psi_i(t_{ij}) = \alpha_i \cdot (t_{ij} - \tau_i)$. This is the patient-specific "disease clock". Code name: `rt`. *Origin: `TimeReparametrizedModel.get_variables_specs()`. Definition: `TimeReparametrizedModel.time_reparametrization`* |
```

````

````{tab} Geometrical Model

**Geometrical Model** defines the shape of the disease trajectory on the Riemannian manifold. It introduces population-level parameters $g_k$ (position) and $v_k$ (velocity) for each feature $k$, and combines them with the reparametrized time to produce the per-feature trajectory $\gamma_k$.

```{image} ../../../_static/images/dag_geometrical.png
:alt: DAG of geometrical model
:align: center
:width: 100%
```

```{table} Variables
:class: dag-var-table

| Variable & Type | Description |
|---|---|
| $\overline{\log(g_k)}$ (Model parameter) | Mean of the log-position prior. **Estimated**; determines where each feature's sigmoid is centered (midpoint value). Code name: `log_g_mean`. *Origin/Definition: `LogisticModel.get_variables_specs()`* |
| $\sigma_{\log(g_k)}$ (Hyperparameter) | Standard deviation of $\log(g_k)$. Fixed at 0.01 to keep $g_k$ close to its mean. Code name: `log_g_std`. *Origin/Definition: `LogisticModel.get_variables_specs()`* |
| $\log(g_k)$ (Population latent) | Log-position for feature $k$. Sampled from $\mathcal{N}(\overline{\log(g_k)},\; \sigma_{\log(g_k)}^2)$. Code name: `log_g`. *Origin/Definition: `LogisticModel.get_variables_specs()`* |
| $g_k$ (Linked) | **Position parameter**: $g_k = \exp(\log(g_k))$. Controls the midpoint of the sigmoid for feature $k$. Also used to compute the metric tensor. Code name: `g`. *Origin/Definition: `LogisticModel.get_variables_specs()`* |
| metric (Linked) | **Metric tensor**: $(g_k + 1)^2 / g_k$. Encodes the Riemannian geometry on the logistic manifold. Code name: `metric`. *Origin: `RiemanianManifoldModel.get_variables_specs()`. Definition: `LogisticModel.metric`* |
| $\overline{\log(v_k)}$ (Model parameter) | Mean of the log-velocity prior. **Estimated**; determines the speed of progression per feature. Code name: `log_v0_mean`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| $\sigma_{\log(v_k)}$ (Hyperparameter) | Standard deviation of $\log(v_k)$. Fixed at 0.01. Code name: `log_v0_std`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| $\log(v_k)$ (Population latent) | Log-velocity for feature $k$. Sampled from $\mathcal{N}(\overline{\log(v_k)},\; \sigma_{\log(v_k)}^2)$. Code name: `log_v0`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| $v_k$ (Linked) | **Velocity parameter**: $v_k = \exp(\log(v_k))$. Controls the rate of progression along the manifold per feature. Code name: `v0`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| $\psi_i(t_{ij})$ (Linked) | Reparametrized time from the *Temporal Variability* section (input). Code name: `rt`. *Origin: `TimeReparametrizedModel.get_variables_specs()`. Definition: `TimeReparametrizedModel.time_reparametrization`* |
| $\gamma_k$ (Linked) | **Geometric trajectory**: the per-feature curve. Combines `rt`, `metric`, `v0`, and `g` via parallel transport on the manifold. Code name: `model` (when no sources). *Origin: `RiemanianManifoldModel.get_variables_specs()`. Definition: `LogisticModel.model_with_sources`* |
```

````

````{tab} Spatial Variability

**Spatial Variability** captures how individual patients deviate from the average *across features*. While temporal variability shifts trajectories in time, spatial variability shifts them across the feature space, allowing, for example, one patient to have faster decline in memory but slower decline in motor skills.

```{image} ../../../_static/images/dag_spatial.png
:alt: DAG of spatial variability
:align: center
:width: 100%
```

```{table} Variables
:class: dag-var-table

| Variable & Type | Description |
|---|---|
| $\overline{s}$ (Hyperparameter) | Mean of the sources distribution. Fixed at 0 (centered prior). Code name: `sources_mean`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\sigma_s$ (Hyperparameter) | Standard deviation of the sources distribution. Fixed at 1.0 (standard normal prior). Code name: `sources_std`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $s_{il}$ (Individual latent) | **Source component** $l$ for patient $i$. Sampled from $\mathcal{N}(\overline{s},\; \sigma_s^2)$. The individual coordinates in the low-dimensional source space. Code name: `sources`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\overline{\beta_{ml}}$ (Model parameter) | Mean of the mixing coefficients. **Estimated**; encodes how source dimensions map to feature differences. Shape: $(d-1) \times N_s$. Code name: `betas_mean`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\sigma_\beta$ (Hyperparameter) | Standard deviation of $\beta_{ml}$. Fixed at 0.01. Code name: `betas_std`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\beta_{ml}$ (Population latent) | **Mixing coefficients**. Sampled from $\mathcal{N}(\overline{\beta_{ml}},\; \sigma_\beta^2)$. They define the mapping from source space to feature space. Code name: `betas`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $v_k$ (Linked) | **Velocity parameter**. Input from the *Geometrical Model* section. Used to compute the orthonormal basis $B$. Code name: `v0`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| metric² (Linked) | Square of the metric tensor. Needed for the orthonormal basis computation. Code name: `metric_sqr`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| $B$ (Linked) | **Orthonormal basis** of the tangent space (perpendicular to $v_0$), computed from `v0` and `metric_sqr` via `OrthoBasis`. Code name: `orthonormal_basis`. *Origin/Definition: `RiemanianManifoldModel.get_variables_specs()`* |
| $A$ (Linked) | **Mixing matrix**: $A = (B \cdot \beta)^T$. Maps from the source space ($N_s$ dimensions) to the full feature space ($d$ dimensions). Code name: `mixing_matrix`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $w_{ik}$ (Linked) | **Space shift** for patient $i$, feature $k$: $w_i = s_i \cdot A$. The individual deviation applied to each feature's trajectory. Code name: `space_shifts`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
```

````

````{tab} Negative Log-Likelihood

**Negative Log-Likelihood** is where all three branches converge. The model prediction $\eta_k$ for each feature is compared against the observed data $y_{ijk}$ under a noise model parameterized by the noise standard deviation.

```{image} ../../../_static/images/dag_nll.png
:alt: DAG of negative log-likelihood
:align: center
:width: 100%
```

```{table} Variables
:class: dag-var-table

| Variable & Type | Description |
|---|---|
| $\gamma_k$ (Linked) | Geometric trajectory from the *Geometrical Model* section. Code name: `model`. *Origin: `RiemanianManifoldModel.get_variables_specs()`. Definition: `LogisticModel.model_with_sources`* |
| $w_{ik}$ (Linked) | Space shift from the *Spatial Variability* section. Code name: `space_shifts`. *Origin/Definition: `TimeReparametrizedModel.get_variables_specs()`* |
| $\eta_k$ (Linked) | **Final model prediction** for patient $i$, feature $k$. Combines the geometric trajectory with the spatial shift: computed by `model_with_sources(rt, space_shifts, metric, v0, g)`. Code name: `model` (with sources). *Origin: `RiemanianManifoldModel.get_variables_specs()`. Definition: `LogisticModel.model_with_sources`* |
| $y_{ijk}$ (Input data) | **Observed measurement**. Retrieved from the `Dataset` via the observation model's getter. Code name: `y`. *Origin/Definition: `ObservationModel.get_variables_specs()`* |
| noise_std (Model parameter) | **Noise standard deviation**. **Estimated**; controls how much measurement noise is expected. Code name: `noise_std`. *Origin: `FullGaussianObservationModel` (factory method). Definition: `GaussianObservationModel` (standard deviation)* |
| $-\log p(y \mid z, \theta)$ (Observational Model) | **Negative log-likelihood**: computed in two stages, `nll_attach_ind` (per-individual NLL via `SumDim` over features/timepoints) then `nll_attach` (total NLL via `SumDim` over individuals). Under Gaussian noise: $\sum_{i,j,k} \frac{(y_{ijk} - \eta_{ijk})^2}{2 \cdot \text{noise\_std}_k^2} + \text{const}$. Code name: `nll_attach`. *Origin/Definition: `ObservationModel.get_variables_specs()`* |
```

````

`````

### Propagation example

When the MCMC algorithm proposes a new value for $\tau_i$:

- $\psi_i(t_{ij})$ must be recomputed (depends on $\tau_i$)
- $\gamma_k(\psi_i(t_{ij}))$ must be recomputed (depends on $\psi_i$)
- $\eta_k$ must be recomputed (depends on $\gamma_k$)
- The likelihood must be recomputed (depends on $\eta_k$)
- But $g_k$, $v_k$, $w_{ik}$ **remain unchanged**: they are on different branches

The DAG encodes exactly this: `sorted_children["tau"]` lists every node downstream of `tau` (the transitive closure of its children), and the `State` invalidates their cached values.

## How `VariablesDAG` is built

Each model class in the inheritance chain contributes its own variables via `get_variables_specs()`. These contributions are **accumulated using `super()`**: each class calls `super().get_variables_specs()` first, then adds its own variables on top:

| Class | Contributes (mathematical notation → code name) |
|---|---|
| `McmcSaemCompatibleModel` | $t_{ij}$ (`t`), $y_{ijk}$ (`y`), $-\log p(...)$ (`nll_attach`) |
| `TimeReparametrizedModel` | $\tau_i$, $\xi_i$, $\alpha_i$ (`alpha`), $\psi_i$ (`rt`), and if multivariate: $s_{il}$ (`sources`), $\beta_{ml}$ (`betas`), $w_{ik}$ (`space_shifts`), $A$ (`mixing_matrix`) |
| `RiemanianManifoldModel` | $\log(v_k)$ (`log_v0`), $v_k$ (`v0`), metric, $B$ (`orthonormal_basis`), $\eta_k$ (`model`) |
| `LogisticModel` | $\log(g_k)$ (`log_g`), $g_k$ (`g`), logistic-specific metric formula |

Once all specs are collected into a single dictionary, `VariablesDAG.from_dict(specs)` analyzes the function signatures of `LinkedVariable` callables to infer edges automatically. For example, `LinkedVariable(Exp("log_g"))` tells the DAG: "I depend on `log_g`".

## Relation to `State`

The `VariablesDAG` is a **static blueprint**: it describes what variables exist and how they relate. It does not hold values.

The `State` object holds the **runtime values** for each variable.
It keeps a reference to the DAG so it can propagate updates correctly. When you write `state["tau"] = new_value`, the `State` uses the DAG's `sorted_children["tau"]` to invalidate all downstream caches (`rt`, `model`, `nll_attach_ind`, `nll_attach`).

See [StatefulModel](../logistic/StatefulModel) for how the `State` is created and managed.
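
The division of labour between a static graph and a value-holding state can be sketched in a few lines of plain Python. This is not Leaspy's implementation: `specs`, `infer_ancestors`, `all_descendants`, and `ToyState` are hypothetical names, and the standard-library `inspect.signature` stands in for the signature analysis that `VariablesDAG.from_dict` performs on `LinkedVariable` callables.

```python
import inspect
import math
from typing import Any, Callable, Dict, Optional, Set

# Hypothetical toy specs: roots have no function (None); derived variables
# declare their parents through their parameter names.
specs: Dict[str, Optional[Callable[..., Any]]] = {
    "t": None,
    "tau": None,
    "xi": None,
    "alpha": lambda xi: math.exp(xi),
    "rt": lambda t, tau, alpha: alpha * (t - tau),
}


def infer_ancestors(specs: Dict[str, Optional[Callable[..., Any]]]) -> Dict[str, Set[str]]:
    """Infer direct dependencies from each callable's parameter names."""
    return {
        name: set(inspect.signature(fn).parameters) if fn is not None else set()
        for name, fn in specs.items()
    }


def all_descendants(ancestors: Dict[str, Set[str]], name: str) -> Set[str]:
    """Collect every variable downstream of `name` (transitive closure)."""
    found: Set[str] = set()
    frontier = [name]
    while frontier:
        current = frontier.pop()
        for child, parents in ancestors.items():
            if current in parents and child not in found:
                found.add(child)
                frontier.append(child)
    return found


class ToyState:
    """Holds values; uses the static dependency map to invalidate caches."""

    def __init__(self, specs: Dict[str, Optional[Callable[..., Any]]]) -> None:
        self.specs = specs
        self.ancestors = infer_ancestors(specs)  # static blueprint
        self.values: Dict[str, Any] = {}         # runtime values and caches

    def __setitem__(self, name: str, value: Any) -> None:
        # Assigning a variable drops every cached value that depends on it.
        for child in all_descendants(self.ancestors, name):
            self.values.pop(child, None)
        self.values[name] = value

    def __getitem__(self, name: str) -> Any:
        if name not in self.values:
            fn = self.specs[name]
            if fn is None:
                raise KeyError(f"root variable {name!r} has no value yet")
            # Recompute lazily from the (possibly recomputed) parents.
            kwargs = {parent: self[parent] for parent in self.ancestors[name]}
            self.values[name] = fn(**kwargs)
        return self.values[name]


state = ToyState(specs)
state["t"], state["tau"], state["xi"] = 75.0, 70.0, 0.0
print(state["rt"])               # 5.0, i.e. exp(0) * (75 - 70)
state["tau"] = 72.0              # invalidates rt, leaves alpha cached
print("rt" in state.values)      # False
print("alpha" in state.values)   # True
print(state["rt"])               # 3.0
```

Keeping the dependency map separate from the values is what lets a single graph serve many states: the map is computed once from the specs, while each state only stores and invalidates values.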