import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm

print(f"Runing on PyMC3 v{pm.__version__}")
Runing on PyMC3 v3.11.0
%load_ext watermark
az.style.use("arviz-darkgrid")

Model comparison

To demonstrate the use of model comparison criteria in PyMC3, we implement the 8 schools example from Section 5.5 of Gelman et al. (2003), which attempts to infer the effects of coaching on SAT scores of students from 8 schools. Below, we fit a pooled model, which assumes a single fixed effect across all schools, and a hierarchical model that allows for a random effect that partially pools the data.
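As a compact summary of the code below (all normal distributions are written as mean and standard deviation, matching the PyMC3 parameterization, and the hierarchical model is shown in the non-centered form used later):

$$
\begin{aligned}
\text{Pooled:}\quad & y_j \sim \mathcal{N}(\mu, \sigma_j), && \mu \sim \mathcal{N}(0, 10^6) \\
\text{Hierarchical:}\quad & y_j \sim \mathcal{N}(\theta_j, \sigma_j), && \theta_j = \mu + \tau \, \eta_j, \\
& \eta_j \sim \mathcal{N}(0, 1), && \mu \sim \mathcal{N}(0, 10), \quad \tau \sim \mathrm{HalfNormal}(10)
\end{aligned}
$$

Here $y_j$ and $\sigma_j$ are the observed treatment effect and its standard error for school $j$.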

The data include the observed treatment effects and associated standard deviations in the 8 schools.

J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])

Pooled model

with pm.Model() as pooled:
    # A single common effect for all schools, with a very vague prior
    mu = pm.Normal("mu", 0, sigma=1e6)

    obs = pm.Normal("obs", mu, sigma=sigma, observed=y)

    trace_p = pm.sample(2000, return_inferencedata=True)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu]
100.00% [6000/6000 00:10<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 31 seconds.
The acceptance probability does not match the target. It is 0.8786140471023212, but should be close to 0.8. Try to increase the number of tuning steps.
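The warning about the acceptance probability is harmless here; as the message suggests, it can usually be silenced by running more tuning iterations. A minimal sketch (not run in this notebook):

with pooled:
    trace_p = pm.sample(2000, tune=2000, return_inferencedata=True)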
az.plot_trace(trace_p);
[Figure: trace plot of mu for the pooled model]

Hierarchical model

with pm.Model() as hierarchical:
    # Non-centered parameterization: sample standardized offsets eta
    # and build the school-level effects theta from them
    eta = pm.Normal("eta", 0, 1, shape=J)
    mu = pm.Normal("mu", 0, sigma=10)
    tau = pm.HalfNormal("tau", 10)

    theta = pm.Deterministic("theta", mu + tau * eta)

    obs = pm.Normal("obs", theta, sigma=sigma, observed=y)

    trace_h = pm.sample(2000, target_accept=0.9, return_inferencedata=True)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [tau, mu, eta]
100.00% [6000/6000 00:18<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 2_000 draw iterations (2_000 + 4_000 draws total) took 34 seconds.
az.plot_trace(trace_h, var_names="mu");
[Figure: trace plot of mu for the hierarchical model]
az.plot_forest(trace_h, var_names="theta");
[Figure: forest plot of the school-level effects theta]

Leave-one-out Cross-validation (LOO)

LOO cross-validation is an estimate of the out-of-sample predictive fit. In cross-validation, the data are repeatedly partitioned into training and holdout sets, iteratively fitting the model with the former and evaluating the fit with the latter. Vehtari et al. (2016) introduced an efficient way to compute LOO from MCMC samples without re-fitting the model. The approximation is based on importance sampling, with the importance weights stabilized by a method known as Pareto-smoothed importance sampling (PSIS).
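PSIS also returns a Pareto shape parameter k for every observation, which doubles as a diagnostic: values above roughly 0.7 indicate that the approximation may be unreliable for that observation. A minimal sketch of how to inspect these values with ArviZ:

# Pointwise LOO for the hierarchical model, keeping the per-observation results
hierarchical_loo_pw = az.loo(trace_h, pointwise=True)
print(hierarchical_loo_pw.pareto_k)  # one k per school; k > ~0.7 is a warning sign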

Widely-applicable Information Criterion (WAIC)

WAIC (Watanabe 2010) is a fully Bayesian criterion for estimating out-of-sample expectation, using the computed log pointwise posterior predictive density (LPPD) and correcting for the effective number of parameters to adjust for overfitting.

By default ArviZ uses LOO, but WAIC is also available.
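The calls for WAIC mirror those for LOO shown below; a minimal sketch using az.waic:

# WAIC for both models; printing the result shows the estimate, its standard error,
# and the effective number of parameters
print(az.waic(trace_p))
print(az.waic(trace_h))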

pooled_loo = az.loo(trace_p)

pooled_loo.loo
-30.569563398379955
hierarchical_loo = az.loo(trace_h)

hierarchical_loo.loo
-30.754274702021085

ArviZ includes two convenience functions to help compare LOO for different models. The first of these functions is compare, which computes LOO (or WAIC) from a dictionary of traces and returns a DataFrame.

df_comp_loo = az.compare({"hierarchical": trace_h, "pooled": trace_p})
df_comp_loo
/Users/CloudChaoszero/opt/anaconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/arviz/stats/stats.py:146: UserWarning: The default method used to estimate the weights for each model,has changed from BB-pseudo-BMA to stacking
  warnings.warn(
              rank        loo     p_loo     d_loo  weight        se     dse  warning loo_scale
pooled           0 -30.569563  0.680583  0.000000     1.0  1.105191  0.0000    False       log
hierarchical     1 -30.754275  1.113869  0.184711     0.0  1.045108  0.2397    False       log

We have many columns, so let’s check out their meaning one by one:

  1. The index is the names of the models, taken from the keys of the dictionary passed to compare().

  2. rank, the ranking of the models, starting from 0 (the best model) up to the number of models minus one.

  3. loo, the values of LOO (or WAIC). The DataFrame is always sorted from best LOO/WAIC to worst.

  4. p_loo, the value of the penalization term. We can roughly think of this value as the estimated effective number of parameters (but do not take that too seriously).

  5. d_loo, the difference between the value of LOO/WAIC for the top-ranked model and that of each other model. For this reason the top-ranked model always has a value of 0.

  6. weight, the weights assigned to each model. These weights can be loosely interpreted as the probability of each model being true (among the compared models) given the data. By default the weights are computed with stacking, as the warning above notes.

  7. se, the standard error of the LOO/WAIC estimates, computed from the pointwise values. The standard error can be useful to assess the uncertainty of the LOO/WAIC estimates.

  8. dse, the standard error of the difference in LOO/WAIC between each model and the top-ranked model. In the same way that we can compute a standard error for each value of LOO/WAIC, we can compute a standard error for the difference between two values. Note that the two quantities are not necessarily the same, because the uncertainty about LOO/WAIC is correlated between models. This quantity is always 0 for the top-ranked model.

  9. warning, if True, the computation of LOO/WAIC may not be reliable.

  10. loo_scale, the scale of the reported values. The default is the log scale, as shown in the table above. Other options are deviance – the log-score multiplied by -2 (this reverts the order: a lower LOO/WAIC is better) – and negative-log – the log-score multiplied by -1 (as with the deviance scale, a lower value is better). How to change these defaults is sketched right after this list.
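These defaults can be changed through the ic, method and scale arguments of compare; a minimal sketch (BB-pseudo-BMA is the alternative weighting method mentioned in the warning above):

df_comp_waic = az.compare(
    {"hierarchical": trace_h, "pooled": trace_p},
    ic="waic",               # rank by WAIC instead of the default LOO
    method="BB-pseudo-BMA",  # Bayesian-bootstrap pseudo-BMA weights instead of stacking
    scale="deviance",        # report on the deviance scale, where lower is better
)
df_comp_waic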

The second convenience function takes the output of compare and produces a summary plot in the style of the one used in the book Statistical Rethinking by Richard McElreath (check also this port of the examples in the book to PyMC3).

az.plot_compare(df_comp_loo, insample_dev=False);
[Figure: model comparison plot of the LOO values and their standard errors]

The empty circles represent the values of LOO, and the black error bars associated with them are the standard errors of LOO.

The value of the highest LOO, i.e. the best estimated model, is also indicated with a vertical dashed grey line to ease comparison with the other LOO values.

For all models except the top-ranked one, we also get a triangle indicating the difference in LOO between that model and the top model, and a grey error bar indicating the standard error of that difference.

Interpretation

Though we might expect the hierarchical model to outperform the complete-pooling model, there is little to choose between them in this case: both models give very similar values of the information criteria. This becomes even clearer when we take into account the uncertainty (in terms of standard errors) of the LOO and WAIC estimates.
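A quick way to see where (if anywhere) the two models differ is to compare their predictive accuracy observation by observation. A minimal sketch using az.plot_elpd, which plots the pointwise ELPD differences between the two models:

az.plot_elpd({"hierarchical": trace_h, "pooled": trace_p});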

References

Gelman, A., Hwang, J., & Vehtari, A. (2014). Understanding predictive information criteria for Bayesian models. Statistics and Computing, 24(6), 997–1016.

Vehtari, A., Gelman, A., & Gabry, J. (2016). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing.

%watermark -n -u -v -iv -w
Last updated: Sun Feb 07 2021

Python implementation: CPython
Python version       : 3.8.6
IPython version      : 7.20.0

pymc3     : 3.11.0
matplotlib: None
arviz     : 0.11.0
numpy     : 1.20.0

Watermark: 2.1.0