Censored Data Models

[1]:
from copy import copy

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import seaborn as sns

from numpy.random import default_rng
[2]:
%config InlineBackend.figure_format = 'retina'
rng = default_rng(1234)
az.style.use("arviz-darkgrid")

This example notebook on Bayesian survival analysis touches on censored data. Censoring is a form of missing-data problem in which observations greater than a certain threshold are clipped down to that threshold, observations less than a certain threshold are clipped up to that threshold, or both. These are called right, left and interval censoring, respectively. In this example notebook we consider interval censoring.
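As a minimal sketch, assuming NumPy, the three kinds of clipping described above map onto `np.minimum`, `np.maximum` and `np.clip`. The array and thresholds here are illustrative values only.

```python
import numpy as np

x = np.array([-2.0, 4.0, 9.0, 21.0])
low, high = 3.0, 16.0

# Right censoring: values above `high` are clipped down to `high`
right = np.minimum(x, high)  # values: -2, 4, 9, 16
# Left censoring: values below `low` are clipped up to `low`
left = np.maximum(x, low)  # values: 3, 4, 9, 21
# Censoring on both sides: values are clipped into [low, high]
both = np.clip(x, low, high)  # values: 3, 4, 9, 16
```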

Censored data arises in many modelling problems. Two common examples are:

  1. Survival analysis: when studying the effect of a certain medical treatment on survival times, it is impossible to prolong the study until all subjects have died. At the end of the study, the only data collected for many patients is that they were still alive for a time period \(T\) after the treatment was administered: in reality, their true survival times are greater than \(T\).

  2. Sensor saturation: a sensor might have a limited range and the upper and lower limits would simply be the highest and lowest values a sensor can report. For instance, many mercury thermometers only report a very narrow range of temperatures.
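For the survival-analysis case, here is a minimal sketch of how right-censored records arise. It rests on illustrative assumptions made up for this example: exponentially distributed survival times, a 36-month study, and the variable names below.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical true survival times (in months) for 8 patients
true_times = rng.exponential(scale=24.0, size=8)
# Hypothetical study length: follow-up stops after T months
T = 36.0

# What is actually recorded: the follow-up time, capped at T ...
observed_times = np.minimum(true_times, T)
# ... and whether death was observed (True) or the record is right-censored (False)
event_observed = true_times <= T
```

For every patient still alive at the end of the study, the record only says "survived at least \(T\)", which is exactly the right-censoring situation described in item 1.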

This example notebook presents two different ways of dealing with censored data in PyMC3:

  1. An imputed censored model, which represents censored data as parameters and imputes plausible values for every censored observation. Because of this imputation, the model can generate plausible sets of the values that were censored. Each censored element introduces a random variable.

  2. An unimputed censored model, where the censored data are integrated out and accounted for only through the log-likelihood. This method scales better to large amounts of censored data and converges more quickly.

To establish a baseline we compare to an uncensored model of the uncensored data.

[3]:
# Produce normally distributed samples
size = 500
true_mu = 13.0
true_sigma = 5.0
samples = rng.normal(true_mu, true_sigma, size)

# Set censoring limits
low = 3.0
high = 16.0


def censor(x, low, high):
    x = copy(x)
    x[x <= low] = low
    x[x >= high] = high
    return x


# Censor samples
censored = censor(samples, low, high)
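The `censor` helper clips from both sides and leaves its input untouched because it works on a copy. Assuming NumPy, `np.clip` gives the same result and also returns a new array. The self-contained check below regenerates the same samples, using the seed and parameters from the cells above.

```python
from copy import copy

import numpy as np


def censor(x, low, high):
    # Work on a copy so the uncensored samples are preserved
    x = copy(x)
    x[x <= low] = low
    x[x >= high] = high
    return x


rng = np.random.default_rng(1234)
samples = rng.normal(13.0, 5.0, 500)
low, high = 3.0, 16.0

censored = censor(samples, low, high)
# np.clip limits values to [low, high] and returns a new array
clipped = np.clip(samples, low, high)
```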
[4]:
# Visualize uncensored and censored data
_, ax = plt.subplots(figsize=(10, 3))
edges = np.linspace(-5, 35, 30)
ax.hist(samples, bins=edges, density=True, histtype="stepfilled", alpha=0.2, label="Uncensored")
ax.hist(censored, bins=edges, density=True, histtype="stepfilled", alpha=0.2, label="Censored")
for x in [low, high]:
    ax.axvline(x=x, c="k", ls="--")
ax.legend();
[Figure: histograms of the uncensored and censored samples; dashed lines mark the censoring limits low and high]

Uncensored Model

[5]:
def uncensored_model(data):
    with pm.Model() as model:
        mu = pm.Normal("mu", mu=((high - low) / 2) + low, sigma=(high - low))
        sigma = pm.HalfNormal("sigma", sigma=(high - low) / 2.0)
        observed = pm.Normal("observed", mu=mu, sigma=sigma, observed=data)
    return model

We expect that running the uncensored model on the uncensored data will give reasonable estimates of the mean and variance.

[6]:
uncensored_model_1 = uncensored_model(samples)
with uncensored_model_1:
    trace = pm.sample(tune=1000, return_inferencedata=True)
    az.plot_posterior(trace, ref_val=[true_mu, true_sigma], round_to=3);
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, mu]
100.00% [8000/8000 00:02<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds.
[Figure: posterior distributions of mu and sigma for the uncensored model fit to the uncensored data, with the true values as reference]

And that is exactly what we find.

The problem, however, is that in censored-data contexts we do not have access to the true values. If we were to use the same uncensored model on the censored data, we would expect our parameter estimates to be biased. Computing point estimates of the mean and std shows that, for this particular dataset and these censoring bounds, we are likely to underestimate both.
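The downward bias in the mean can also be checked analytically. For \(X \sim \mathcal{N}(\mu, \sigma)\) clipped to [low, high], with \(a = (\text{low} - \mu)/\sigma\) and \(b = (\text{high} - \mu)/\sigma\), the expected value is \(\text{low}\,\Phi(a) + \text{high}\,(1 - \Phi(b)) + \mu(\Phi(b) - \Phi(a)) + \sigma(\varphi(a) - \varphi(b))\). The sketch below uses only the standard library's `statistics.NormalDist`; the function name `censored_normal_mean` is hypothetical.

```python
from statistics import NormalDist


def censored_normal_mean(mu, sigma, low, high):
    """Expected value of a Normal(mu, sigma) variable clipped to [low, high]."""
    std = NormalDist()  # standard normal
    a = (low - mu) / sigma
    b = (high - mu) / sigma
    mass_left = std.cdf(a)  # probability of being clipped up to `low`
    mass_right = 1.0 - std.cdf(b)  # probability of being clipped down to `high`
    mass_mid = std.cdf(b) - std.cdf(a)
    # E[X; low < X < high] for a normal variable
    mid_part = mu * mass_mid + sigma * (std.pdf(a) - std.pdf(b))
    return low * mass_left + high * mass_right + mid_part


# With the settings used in this notebook: roughly 12.2, below the true mean of 13.0
expected = censored_normal_mean(13.0, 5.0, 3.0, 16.0)
```

Because the upper limit sits much closer to the true mean than the lower one, far more mass is pulled down than pushed up, which is why the censored sample mean falls below `true_mu`.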

[7]:
np.mean(censored), np.std(censored)
[7]:
(12.320820690475099, 3.7617563701385888)
[8]:
uncensored_model_2 = uncensored_model(censored)
with uncensored_model_2:
    trace = pm.sample(tune=1000, return_inferencedata=True)
    az.plot_posterior(trace, ref_val=[true_mu, true_sigma], round_to=3);
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, mu]
100.00% [8000/8000 00:02<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 11 seconds.
[Figure: posterior distributions of mu and sigma for the uncensored model fit to the censored data, with the true values as reference]

The figure above confirms this.

Censored data models

The models below show two approaches to dealing with censored data. First, we need to do a bit of data pre-processing to count the number of observations that are left- or right-censored. We also need to extract just the non-censored data that we observe.

[9]:
n_right_censored = sum(censored >= high)
n_left_censored = sum(censored <= low)
n_observed = len(censored) - n_right_censored - n_left_censored
uncensored = censored[(censored > low) & (censored < high)]
assert len(uncensored) == n_observed

Model 1 - Imputed Censored Model of Censored Data

In this model, we impute the censored values from the same distribution as the uncensored data. Sampling from the posterior generates possible uncensored data sets.

This model makes use of PyMC3’s bounded variables.
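Conditional on `mu` and `sigma`, a bounded variable such as `right_censored` follows a normal distribution truncated to the censored region. The stdlib-only sketch below shows what drawing such imputed values looks like for fixed, illustrative `mu` and `sigma`. The helper name `draw_right_censored` is hypothetical, and it uses inverse-CDF sampling. This is not how PyMC3 samples the model, since NUTS updates the censored variables jointly with `mu` and `sigma`.

```python
import random
from statistics import NormalDist


def draw_right_censored(mu, sigma, high, n, seed=0):
    """Draw n values from Normal(mu, sigma) restricted to [high, inf) via the inverse CDF."""
    dist = NormalDist(mu, sigma)
    rng = random.Random(seed)
    p_high = dist.cdf(high)  # probability mass below the censoring limit
    draws = []
    for _ in range(n):
        # Uniform draw on [p_high, 1), mapped back through the inverse CDF
        p = p_high + rng.random() * (1.0 - p_high)
        p = min(p, 1.0 - 1e-12)  # guard: inv_cdf is undefined at exactly 1
        draws.append(dist.inv_cdf(p))
    return draws


# Five plausible "true" values for observations that were clipped to high = 16
imputed = draw_right_censored(mu=13.0, sigma=5.0, high=16.0, n=5)
```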

[10]:
with pm.Model() as imputed_censored_model:
    mu = pm.Normal("mu", mu=((high - low) / 2) + low, sigma=(high - low))
    sigma = pm.HalfNormal("sigma", sigma=(high - low) / 2.0)
    right_censored = pm.Bound(pm.Normal, lower=high)(
        "right_censored", mu=mu, sigma=sigma, shape=n_right_censored
    )
    left_censored = pm.Bound(pm.Normal, upper=low)(
        "left_censored", mu=mu, sigma=sigma, shape=n_left_censored
    )
    observed = pm.Normal("observed", mu=mu, sigma=sigma, observed=uncensored, shape=n_observed)
[11]:
with imputed_censored_model:
    trace = pm.sample(return_inferencedata=True)
    az.plot_posterior(trace, var_names=["mu", "sigma"], ref_val=[true_mu, true_sigma], round_to=3);
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [left_censored, right_censored, sigma, mu]
100.00% [8000/8000 00:06<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 15 seconds.
[Figure: posterior distributions of mu and sigma for the imputed censored model, with the true values as reference]

We can see that the bias in the estimates of the mean and variance (present in the uncensored model) has been largely removed.

Model 2 - Unimputed Censored Model of Censored Data

In this model, we do not impute censored data, but instead integrate them out through the likelihood.

The implementations of the likelihoods are non-trivial. See the Stan manual (section 11.3 on censored data) and the original PyMC3 issue on GitHub for more information.
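Written out, the quantity the unimputed model targets is the log-likelihood below. Here \(y_i\) are the \(n_{\text{obs}}\) uncensored observations, \(n_L\) and \(n_R\) count the left- and right-censored ones, \(L\) and \(U\) are the lower and upper censoring limits (`low` and `high`), and \(\Phi\) is the standard normal CDF:

```latex
\log p(y \mid \mu, \sigma)
  = \sum_{i=1}^{n_{\text{obs}}} \log \mathcal{N}(y_i \mid \mu, \sigma)
  + n_L \log \Phi\!\left(\frac{L - \mu}{\sigma}\right)
  + n_R \log\!\left[1 - \Phi\!\left(\frac{U - \mu}{\sigma}\right)\right]
```

The first term is the ordinary normal likelihood of the observed data. Each left-censored point contributes \(P(X \le L)\) and each right-censored point contributes \(P(X \ge U)\). These last two terms are what the log CDF and log complementary CDF compute, multiplied by the corresponding counts.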

This model makes use of PyMC3’s `Potential` (https://docs.pymc.io/api/model.html#pymc3.model.Potential).

[12]:
# Import the log CDF and log complementary CDF of the normal distribution from PyMC3
from pymc3.distributions.dist_math import normal_lccdf, normal_lcdf


# Helper functions for unimputed censored model
def left_censored_likelihood(mu, sigma, n_left_censored, lower_bound):
    """Log-likelihood of left-censored data: each point contributes log P(X <= lower_bound)."""
    return n_left_censored * normal_lcdf(mu, sigma, lower_bound)


def right_censored_likelihood(mu, sigma, n_right_censored, upper_bound):
    """Log-likelihood of right-censored data: each point contributes log P(X >= upper_bound)."""
    return n_right_censored * normal_lccdf(mu, sigma, upper_bound)
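The helpers multiply a single log-CDF term by a count. This works because every left-censored observation sits at the same limit `low`, and every right-censored one sits at `high`, so all points in a group contribute the same term. The stdlib-only sketch below evaluates the full censored log-likelihood for fixed parameter values, using `statistics.NormalDist` in place of the PyMC3 tensor functions. The function name `censored_normal_loglik` is hypothetical.

```python
import math
from statistics import NormalDist


def censored_normal_loglik(data, mu, sigma, low, high):
    """Log-likelihood of normal data clipped to [low, high], with censored points integrated out."""
    dist = NormalDist(mu, sigma)
    n_left = sum(1 for x in data if x <= low)
    n_right = sum(1 for x in data if x >= high)
    observed = [x for x in data if low < x < high]

    # Uncensored points contribute their log density
    ll = sum(math.log(dist.pdf(x)) for x in observed)
    # Each left-censored point contributes log P(X <= low)
    ll += n_left * math.log(dist.cdf(low))
    # Each right-censored point contributes log P(X >= high); 1 - cdf is the naive form
    ll += n_right * math.log(1.0 - dist.cdf(high))
    return ll
```

Computing `1 - cdf` directly loses precision far in the upper tail. A dedicated log complementary CDF such as `normal_lccdf` exists to avoid that problem.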
[13]:
with pm.Model() as unimputed_censored_model:
    mu = pm.Normal("mu", mu=0.0, sigma=(high - low) / 2.0)
    sigma = pm.HalfNormal("sigma", sigma=(high - low) / 2.0)
    observed = pm.Normal(
        "observed",
        mu=mu,
        sigma=sigma,
        observed=uncensored,
    )
    left_censored = pm.Potential(
        "left_censored", left_censored_likelihood(mu, sigma, n_left_censored, low)
    )
    right_censored = pm.Potential(
        "right_censored", right_censored_likelihood(mu, sigma, n_right_censored, high)
    )

Sampling

[14]:
with unimputed_censored_model:
    trace = pm.sample(tune=1000, return_inferencedata=True)
    az.plot_posterior(trace, var_names=["mu", "sigma"], ref_val=[true_mu, true_sigma], round_to=3);
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
WARNING (theano.tensor.opt): Optimization Warning: The Op erfcx does not provide a C implementation. As well as being potentially slow, this also disables loop fusion.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, mu]
100.00% [8000/8000 00:02<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 11 seconds.
[Figure: posterior distributions of mu and sigma for the unimputed censored model, with the true values as reference]

Again, the bias in the estimates of the mean and variance (present in the uncensored model) has been largely removed.

Discussion

As we can see, both censored models appear to capture the mean and variance of the underlying distribution as well as the uncensored model! In addition, the imputed censored model is capable of generating data sets of censored values (sample from the posteriors of left_censored and right_censored to generate them), while the unimputed censored model scales much better with more censored data, and converges faster.

Authors

[15]:
%load_ext watermark
%watermark -n -u -v -iv -w -p theano,xarray
Last updated: Sat May 22 2021

Python implementation: CPython
Python version       : 3.8.5
IPython version      : 7.20.0

theano: 1.1.2
xarray: 0.16.2

arviz     : 0.11.0
matplotlib: 3.3.2
pymc3     : 3.11.1
numpy     : 1.19.2
seaborn   : 0.11.1

Watermark: 2.1.0