MMM.compute_mean_contributions_over_time
- MMM.compute_mean_contributions_over_time()
Posterior-mean counterfactual contributions as a DataFrame.
Convenience wrapper around compute_counterfactual_contributions_dataset() that averages over (chain, draw) and returns a flat pd.DataFrame. Each column answers a counterfactual question: "how much would the predicted \(\hat y(t)\) decrease if we removed this component?"
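The averaging and flattening step can be pictured with plain NumPy and pandas. This is a minimal sketch, not the library's implementation: it assumes the per-draw contributions are already available as an array with dims (chain, draw, date, component), and the values, dates and component names below are made-up placeholders.

```python
import numpy as np
import pandas as pd

# Hypothetical posterior samples of per-component contributions,
# shaped (chain, draw, date, component). Values are random placeholders.
rng = np.random.default_rng(0)
n_chain, n_draw = 2, 100
dates = pd.date_range("2024-01-01", periods=4, freq="W")
components = ["tv", "radio", "intercept"]  # hypothetical column labels
samples = rng.normal(
    loc=1.0,
    scale=0.1,
    size=(n_chain, n_draw, len(dates), len(components)),
)

# Average over the (chain, draw) axes -> shape (date, component).
posterior_mean = samples.mean(axis=(0, 1))

# Flatten into a wide DataFrame: one row per date, one column per component.
contributions_df = pd.DataFrame(posterior_mean, columns=components)
contributions_df.insert(0, "date", dates)
print(contributions_df)
```

In a multidimensional model, extra dimensions such as geo would appear as additional identifier columns next to date, giving one row per (date, geo) pair.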
Formally, for component \(j\) with value \(v_j(t)\) in the linear predictor:
\[\text{contribution}_j(t) = \mathbb{E}\bigl[\text{inv}(\mu(t)) \cdot s\bigr] - \mathbb{E}\bigl[\text{inv}(\mu(t) - v_j(t)) \cdot s\bigr]\]
where \(\mu(t)\) is the full linear predictor, \(\text{inv}\) is the inverse link function, \(s\) is target_scale, and the expectations are taken over the posterior (chain, draw) samples.
For identity-link (additive) models this reduces to \(\mathbb{E}[v_j(t)] \cdot s\), and the columns sum exactly to \(\hat y(t)\).
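Spelling out the identity-link case, assuming (as the column list below suggests) that the components together make up the linear predictor, \(\mu(t) = \sum_j v_j(t)\), and that \(\hat y(t)\) denotes the posterior-mean prediction on the original scale:
\[\begin{aligned}
\text{contribution}_j(t) &= \mathbb{E}\bigl[\mu(t) \cdot s\bigr] - \mathbb{E}\bigl[(\mu(t) - v_j(t)) \cdot s\bigr] = \mathbb{E}\bigl[v_j(t)\bigr] \cdot s, \\
\sum_j \text{contribution}_j(t) &= \mathbb{E}\Bigl[\sum_j v_j(t)\Bigr] \cdot s = \mathbb{E}\bigl[\mu(t)\bigr] \cdot s = \hat y(t).
\end{aligned}\]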
For log-link (multiplicative) models this computes a genuine per-component counterfactual. Because an interaction effect is counted once by every component that participates in it, the columns generally do not sum to \(\hat y(t)\) and can exceed it. This is an expected property of per-component counterfactuals in a multiplicative model, not a defect.
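The formula can be checked numerically on a toy example. This sketch assumes two hypothetical components whose draws \(v_1, v_2\) make up the whole linear predictor, a made-up target_scale \(s = 100\), and 4000 flattened (chain, draw) samples; none of these values come from a real model. With the identity link the contributions equal \(\mathbb{E}[v_j] \cdot s\) and sum to \(\hat y(t)\). With the log link, writing \(a = e^{v_1}\) and \(b = e^{v_2}\), removing either component also removes the cross term \((a-1)(b-1)s\), so the contributions sum to \(\hat y(t) - s + \mathbb{E}[(a-1)(b-1)] \cdot s\); with these particular values that exceeds \(\hat y(t)\).

```python
import numpy as np

rng = np.random.default_rng(42)
n_samples = 4000   # flattened (chain, draw) samples
s = 100.0          # hypothetical target_scale

# Hypothetical posterior draws of two components; together they form mu.
v = rng.normal(loc=[1.0, 0.9], scale=0.05, size=(n_samples, 2))
mu = v.sum(axis=1)


def mean_contributions(inv_link):
    """contribution_j = E[inv(mu) * s] - E[inv(mu - v_j) * s]."""
    return np.array([
        np.mean(inv_link(mu) * s) - np.mean(inv_link(mu - v[:, j]) * s)
        for j in range(v.shape[1])
    ])


def identity(x):
    return x


# Identity link: contributions reduce to E[v_j] * s and sum to y_hat.
c_id = mean_contributions(identity)
y_hat_id = np.mean(mu * s)

# Log link: each component's contribution also contains the cross term.
c_log = mean_contributions(np.exp)
y_hat_log = np.mean(np.exp(mu) * s)
interaction = np.mean((np.exp(v[:, 0]) - 1) * (np.exp(v[:, 1]) - 1) * s)

print(c_id.sum(), y_hat_id)                    # equal up to rounding
print(c_log.sum(), y_hat_log - s + interaction)  # equal up to rounding
```

The unattributed \(s\) term is the prediction when every component is zero (\(e^0 \cdot s\)), which no single-component removal can reach.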
This method does not require add_original_scale_contribution_variable() to have been called.
- Returns:
pd.DataFrame
Wide-format DataFrame with one row per observation (date × extra dims). Columns include:
  - date: the date coordinate
  - extra dimension columns (e.g. geo) when the model is multidimensional
  - one column per channel (named after the channel coordinate labels)
  - one column per control variable (if present)
  - yearly_seasonality (if yearly seasonality is enabled)
  - intercept
- Raises:
ValueError
If the model has not been fitted (no idata).
See also
compute_counterfactual_contributions_dataset
Full posterior as an xr.Dataset (retains chain/draw dims).
add_original_scale_contribution_variable
Pre-compute original-scale deterministics inside the model graph.
MMMIDataWrapper.get_contributions
Full posterior contributions as an xr.Dataset.
Examples
mmm.fit(X, y)  # must be fitted first; otherwise a ValueError is raised
contributions_df = mmm.compute_mean_contributions_over_time()