Source code for pymc.step_methods.hmc.nuts

#   Copyright 2024 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from __future__ import annotations

from collections import namedtuple

import numpy as np

from pymc.math import logbern
from pymc.pytensorf import floatX
from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.vartypes import continuous_types

__all__ = ["NUTS"]


[docs] class NUTS(BaseHMC): r"""A sampler for continuous variables based on Hamiltonian mechanics. NUTS automatically tunes the step size and the number of steps per sample. A detailed description can be found at [1], "Algorithm 6: Efficient No-U-Turn Sampler with Dual Averaging". NUTS provides a number of statistics that can be accessed with `trace.get_sampler_stats`: - `mean_tree_accept`: The mean acceptance probability for the tree that generated this sample. The mean of these values across all samples but the burn-in should be approximately `target_accept` (the default for this is 0.8). - `diverging`: Whether the trajectory for this sample diverged. If there are any divergences after burnin, this indicates that the results might not be reliable. Reparametrization can often help, but you can also try to increase `target_accept` to something like 0.9 or 0.95. - `energy`: The energy at the point in phase-space where the sample was accepted. This can be used to identify posteriors with problematically long tails. See below for an example. - `energy_change`: The difference in energy between the start and the end of the trajectory. For a perfect integrator this would always be zero. - `max_energy_change`: The maximum difference in energy along the whole trajectory. - `depth`: The depth of the tree that was used to generate this sample - `tree_size`: The number of leafs of the sampling tree, when the sample was accepted. This is usually a bit less than `2 ** depth`. If the tree size is large, the sampler is using a lot of leapfrog steps to find the next sample. This can for example happen if there are strong correlations in the posterior, if the posterior has long tails, if there are regions of high curvature ("funnels"), or if the variance estimates in the mass matrix are inaccurate. Reparametrisation of the model or estimating the posterior variances from past samples might help. - `tune`: This is `True`, if step size adaptation was turned on when this sample was generated. - `step_size`: The step size used for this sample. - `step_size_bar`: The current best known step-size. After the tuning samples, the step size is set to this value. This should converge during tuning. - `model_logp`: The model log-likelihood for this sample. - `process_time_diff`: The time it took to draw the sample, as defined by the python standard library `time.process_time`. This counts all the CPU time, including worker processes in BLAS and OpenMP. - `perf_counter_diff`: The time it took to draw the sample, as defined by the python standard library `time.perf_counter` (wall time). - `perf_counter_start`: The value of `time.perf_counter` at the beginning of the computation of the draw. - `index_in_trajectory`: This is usually only interesting for debugging purposes. This indicates the position of the posterior draw in the trajectory. Eg a -4 would indicate that the draw was the result of the fourth leapfrog step in negative direction. - `largest_eigval` and `smallest_eigval`: Experimental statistics for some mass matrix adaptation algorithms. This is nan if it is not used. References ---------- .. [1] Hoffman, Matthew D., & Gelman, Andrew. (2011). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. """ name = "nuts" default_blocked = True stats_dtypes_shapes = { "depth": (np.int64, []), "step_size": (np.float64, []), "tune": (bool, []), "mean_tree_accept": (np.float64, []), "step_size_bar": (np.float64, []), "tree_size": (np.float64, []), "diverging": (bool, []), "energy_error": (np.float64, []), "energy": (np.float64, []), "max_energy_error": (np.float64, []), "model_logp": (np.float64, []), "process_time_diff": (np.float64, []), "perf_counter_diff": (np.float64, []), "perf_counter_start": (np.float64, []), "largest_eigval": (np.float64, []), "smallest_eigval": (np.float64, []), "index_in_trajectory": (np.int64, []), "reached_max_treedepth": (bool, []), "warning": (SamplerWarning, None), }
[docs] def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs): r"""Set up the No-U-Turn sampler. Parameters ---------- vars: list, default=None List of value variables. If None, all continuous RVs from the model are included. Emax: float, default 1000 Maximum energy change allowed during leapfrog steps. Larger deviations will abort the integration. target_accept: float, default .8 Adapt the step size such that the average acceptance probability across the trajectories are close to target_accept. Higher values for target_accept lead to smaller step sizes. Setting this to higher values like 0.9 or 0.99 can help with sampling from difficult posteriors. Valid values are between 0 and 1 (exclusive). step_scale: float, default 0.25 Size of steps to take, automatically scaled down by `1/n**(1/4)`. If step size adaptation is switched off, the resulting step size is used. If adaptation is enabled, it is used as initial guess. gamma: float, default .05 k: float, default .75 Parameter for dual averaging for step size adaptation. Values between 0.5 and 1 (exclusive) are admissible. Higher values correspond to slower adaptation. t0: int, default 10 Parameter for dual averaging. Higher values slow initial adaptation. adapt_step_size: bool, default=True Whether step size adaptation should be enabled. If this is disabled, `k`, `t0`, `gamma` and `target_accept` are ignored. max_treedepth: int, default=10 The maximum tree depth. Trajectories are stopped when this depth is reached. early_max_treedepth: int, default=8 The maximum tree depth during the first 200 tuning samples. scaling: array_like, ndim = {1,2} The inverse mass, or precision matrix. One dimensional arrays are interpreted as diagonal matrices. If `is_cov` is set to True, this will be interpreted as the mass or covariance matrix. is_cov: bool, default=False Treat the scaling as mass or covariance matrix. potential: Potential, optional An object that represents the Hamiltonian with methods `velocity`, `energy`, and `random` methods. It can be specified instead of the scaling matrix. model: pymc.Model The model kwargs: passed to BaseHMC Notes ----- The step size adaptation stops when `self.tune` is set to False. This is usually achieved by setting the `tune` parameter if `pm.sample` to the desired number of tuning steps. """ super().__init__(vars, **kwargs) self.max_treedepth = max_treedepth self.early_max_treedepth = early_max_treedepth self._reached_max_treedepth = 0
def _hamiltonian_step(self, start, p0, step_size): if self.tune and self.iter_count < 200: max_treedepth = self.early_max_treedepth else: max_treedepth = self.max_treedepth tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax) reached_max_treedepth = False for _ in range(max_treedepth): direction = logbern(np.log(0.5)) * 2 - 1 divergence_info, turning = tree.extend(direction) if divergence_info or turning: break else: # no-break reached_max_treedepth = not self.tune stats = tree.stats() accept_stat = stats["mean_tree_accept"] stats["reached_max_treedepth"] = reached_max_treedepth return HMCStepData(tree.proposal, accept_stat, divergence_info, stats)
[docs] @staticmethod def competence(var, has_grad): """Check how appropriate this class is for sampling a random variable.""" if var.dtype in continuous_types and has_grad: return Competence.PREFERRED return Competence.INCOMPATIBLE
# A proposal for the next position Proposal = namedtuple("Proposal", "q, q_grad, energy, logp, index_in_trajectory") # A subtree of the binary tree built by nuts. Subtree = namedtuple( "Subtree", "left, right, p_sum, proposal, log_size", ) class _Tree: def __init__( self, ndim: int, integrator: integration.CpuLeapfrogIntegrator, start: State, step_size: float, Emax: float, ): """Binary tree from the NUTS algorithm. Parameters ---------- leapfrog: function A function that performs a single leapfrog step. start: integration.State The starting point of the trajectory. step_size: float The step size to use in this tree Emax: float The maximum energy change to accept before aborting the transition as diverging. """ self.ndim = ndim self.integrator = integrator self.start = start self.step_size = step_size self.Emax = Emax self.start_energy = start.energy self.left = self.right = start self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0) self.depth = 0 self.log_size = 0.0 self.log_accept_sum = -np.inf self.mean_tree_accept = 0.0 self.n_proposals = 0 self.p_sum = start.p.data.copy() self.max_energy_change = 0.0 def extend(self, direction): """Double the treesize by extending the tree in the given direction. If direction is larger than 0, extend it to the right, otherwise extend it to the left. Return a tuple `(diverging, turning)` of type (DivergenceInfo, bool). `diverging` indicates, that the tree extension was aborted because the energy change exceeded `self.Emax`. `turning` indicates that the tree extension was stopped because the termination criterior was reached (the trajectory is turning back). """ if direction > 0: tree, diverging, turning = self._build_subtree( self.right, self.depth, floatX(np.asarray(self.step_size)) ) leftmost_begin, leftmost_end = self.left, self.right rightmost_begin, rightmost_end = tree.left, tree.right leftmost_p_sum = self.p_sum.copy() rightmost_p_sum = tree.p_sum self.right = tree.right else: tree, diverging, turning = self._build_subtree( self.left, self.depth, floatX(np.asarray(-self.step_size)) ) leftmost_begin, leftmost_end = tree.right, tree.left rightmost_begin, rightmost_end = self.left, self.right leftmost_p_sum = tree.p_sum rightmost_p_sum = self.p_sum.copy() self.left = tree.right self.depth += 1 if diverging or turning: return diverging, turning size1, size2 = self.log_size, tree.log_size if logbern(size2 - size1): self.proposal = tree.proposal self.log_size = np.logaddexp(self.log_size, tree.log_size) self.p_sum[:] += tree.p_sum # Additional turning check only when tree depth > 0 to avoid redundant work if self.depth > 0: left, right = self.left, self.right p_sum = self.p_sum turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) p_sum1 = leftmost_p_sum + rightmost_begin.p.data turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0) p_sum2 = leftmost_end.p.data + rightmost_p_sum turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0) turning = turning | turning1 | turning2 return diverging, turning def _single_step(self, left: State, epsilon: float): """Perform a leapfrog step and handle error cases.""" right: State | None error: IntegrationError | None error_msg: str | None try: right = self.integrator.step(epsilon, left) except IntegrationError as err: error_msg = str(err) error = err right = None else: assert right is not None # since there was no IntegrationError # h - H0 energy_change = right.energy - self.start_energy if np.isnan(energy_change): energy_change = np.inf self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change)) if np.abs(energy_change) > np.abs(self.max_energy_change): self.max_energy_change = energy_change if energy_change < self.Emax: # Acceptance statistic # e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)}) # Saturated Metropolis accept probability with Boltzmann weight log_size = -energy_change proposal = Proposal( right.q.data, right.q_grad, right.energy, right.model_logp, right.index_in_trajectory, ) tree = Subtree(right, right, right.p.data, proposal, log_size) return tree, None, False else: error_msg = f"Energy change in leapfrog step is too large: {energy_change}." error = None finally: self.n_proposals += 1 tree = Subtree(None, None, None, None, -np.inf) divergence_info = DivergenceInfo(error_msg, error, left, right) return tree, divergence_info, False def _build_subtree(self, left, depth, epsilon): if depth == 0: return self._single_step(left, epsilon) tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon) if diverging or turning: return tree1, diverging, turning tree2, diverging, turning = self._build_subtree(tree1.right, depth - 1, epsilon) left, right = tree1.left, tree2.right if not (diverging or turning): p_sum = tree1.p_sum + tree2.p_sum turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0) # Additional U turn check only when depth > 1 to avoid redundant work. if depth - 1 > 0: p_sum1 = tree1.p_sum + tree2.left.p.data turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0) p_sum2 = tree1.right.p.data + tree2.p_sum turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0) turning = turning | turning1 | turning2 log_size = np.logaddexp(tree1.log_size, tree2.log_size) if logbern(tree2.log_size - log_size): proposal = tree2.proposal else: proposal = tree1.proposal else: p_sum = tree1.p_sum log_size = tree1.log_size proposal = tree1.proposal tree = Subtree(left, right, p_sum, proposal, log_size) return tree, diverging, turning def stats(self): self.mean_tree_accept = np.exp(self.log_accept_sum) / self.n_proposals return { "depth": self.depth, "mean_tree_accept": self.mean_tree_accept, "energy_error": self.proposal.energy - self.start.energy, "energy": self.proposal.energy, "tree_size": self.n_proposals, "max_energy_error": self.max_energy_change, "model_logp": self.proposal.logp, "index_in_trajectory": self.proposal.index_in_trajectory, }