from __future__ import absolute_import, division, print_function
from collections import namedtuple
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions.util import logsumexp
from pyro.infer.mcmc.hmc import HMC
from pyro.ops.integrator import velocity_verlet
from pyro.util import optional, torch_isnan
# sum_accept_probs and num_proposals are used to calculate
# the statistic accept_prob for the Dual Averaging scheme;
# z_left_grads and z_right_grads are kept to avoid recomputing
# gradients at the left and right leaves;
# r_sum is used to check the turning condition;
# z_proposal_pe and z_proposal_grads cache the potential energy
# and its gradient for the proposal trace;
# weight is the number of valid points when we use slice sampling,
# and the log-sum of (unnormalized) probabilities of valid points
# when we use multinomial sampling.
_TreeInfo = namedtuple("TreeInfo", ["z_left", "r_left", "z_left_grads",
"z_right", "r_right", "z_right_grads",
"z_proposal", "z_proposal_pe", "z_proposal_grads",
"r_sum", "weight", "turning", "diverging",
"sum_accept_probs", "num_proposals"])
class NUTS(HMC):
"""
No-U-Turn Sampler kernel, which provides an efficient and convenient way
to run Hamiltonian Monte Carlo. The number of steps taken by the
integrator is dynamically adjusted on each call to ``sample`` to ensure
an optimal length for the Hamiltonian trajectory [1]. As such, the samples
generated will typically have lower autocorrelation than those generated
by the :class:`~pyro.infer.mcmc.HMC` kernel. Optionally, the NUTS kernel
also provides the ability to adapt step size during the warmup phase.
Refer to the `baseball example <https://github.com/uber/pyro/blob/dev/examples/baseball.py>`_
to see how to do Bayesian inference in Pyro using NUTS.
**References**
[1] `The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo`,
Matthew D. Hoffman and Andrew Gelman.
[2] `A Conceptual Introduction to Hamiltonian Monte Carlo`, Michael Betancourt
[3] `Slice Sampling`, Radford M. Neal
:param model: Python callable containing Pyro primitives.
:param float step_size: Determines the size of a single step taken by the
verlet integrator while computing the trajectory using Hamiltonian
dynamics. If not specified, it will be set to 1.
:param bool adapt_step_size: A flag to decide if we want to adapt step_size
during the warm-up phase using the Dual Averaging scheme.
:param bool adapt_mass_matrix: A flag to decide if we want to adapt the mass
matrix during the warm-up phase using the Welford scheme.
:param bool full_mass: A flag to decide if mass matrix is dense or diagonal.
:param bool use_multinomial_sampling: A flag to decide if we want to sample
candidates along the trajectory using "multinomial sampling" or
"slice sampling". Slice sampling is used in the original NUTS paper [1],
while multinomial sampling is suggested in [2]. By default, this flag is
set to True. If it is set to `False`, NUTS uses slice sampling.
:param dict transforms: Optional dictionary that specifies a transform
for a sample site with constrained support to unconstrained space. The
transform should be invertible, and implement `log_abs_det_jacobian`.
If not specified and the model has sites with constrained support,
automatic transformations will be applied, as specified in
:mod:`torch.distributions.constraint_registry`.
:param int max_plate_nesting: Optional bound on the maximum number of nested
:func:`pyro.plate` contexts. This is required if the model contains
discrete sample sites that can be enumerated over in parallel.
:param bool jit_compile: Optional parameter denoting whether to use
the PyTorch JIT to trace the log density computation, and use this
optimized executable trace in the integrator.
:param bool ignore_jit_warnings: Flag to ignore warnings from the JIT
tracer when ``jit_compile=True``. Default is False.
Example:
>>> import torch
>>> import pyro
>>> import pyro.distributions as dist
>>> from pyro.infer.mcmc import MCMC, NUTS
>>>
>>> true_coefs = torch.tensor([1., 2., 3.])
>>> data = torch.randn(2000, 3)
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
>>>
>>> def model(data):
... coefs_mean = torch.zeros(dim)
... coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
... y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
... return y
>>>
>>> nuts_kernel = NUTS(model, adapt_step_size=True)
>>> mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=300).run(data)
>>> posterior = mcmc_run.marginal('beta').empirical['beta']
>>> posterior.mean # doctest: +SKIP
tensor([ 0.9221, 1.9464, 2.9228])
"""
def __init__(self,
model,
step_size=1,
adapt_step_size=True,
adapt_mass_matrix=True,
full_mass=False,
use_multinomial_sampling=True,
transforms=None,
max_plate_nesting=None,
jit_compile=False,
ignore_jit_warnings=False):
super(NUTS, self).__init__(model,
step_size,
adapt_step_size=adapt_step_size,
adapt_mass_matrix=adapt_mass_matrix,
full_mass=full_mass,
transforms=transforms,
max_plate_nesting=max_plate_nesting,
jit_compile=jit_compile,
ignore_jit_warnings=ignore_jit_warnings)
self.use_multinomial_sampling = use_multinomial_sampling
self._max_tree_depth = 10 # from Stan
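# With the depth capped at 10, one doubling loop in `sample` builds trees of
# depth 0, 1, ..., 10, i.e. at most 2^0 + 2^1 + ... + 2^10 = 2047 leapfrog
# steps per sample when no stop condition triggers earlier.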
# There are three conditions to stop the doubling process:
# + The tree is becoming too big.
# + The trajectory is making a U-turn.
# + The probability of the new states is becoming negligible: p(z, r) << u,
# where u is the "slice" variable introduced in the `self.sample(...)` method.
# Denoting E_p = -log p(z, r) and E_u = -log u, the third condition is
# equivalent to sliced_energy := E_p - E_u > some constant =: max_sliced_energy.
# This also explains the notion of "diverging" in the implementation:
# when the energy E_p diverges from E_u too much, we stop doubling.
# Here, as suggested in [1], we set max_sliced_energy = 1000.
self._max_sliced_energy = 1000
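# A worked example of the check in `_build_basetree` below: with
# log_slice = -10.5, a step landing on a state with energy_new = 990.0 gives
# sliced_energy = 990.0 + (-10.5) = 979.5 <= 1000, so doubling may continue;
# energy_new = 1012.0 gives sliced_energy = 1001.5 > 1000, and the
# transition is flagged as diverging.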
def _is_turning(self, r_left, r_right, r_sum):
# We follow the strategy in Section A.4.2 of [2] for this implementation.
r_left_flat = torch.cat([r_left[site_name].reshape(-1) for site_name in sorted(r_left)])
r_right_flat = torch.cat([r_right[site_name].reshape(-1) for site_name in sorted(r_right)])
# TODO: change to torch.dot for pytorch 1.0
if self.full_mass:
if (((r_sum - r_left_flat) * (self.inverse_mass_matrix.matmul(r_left_flat)))
.sum() > 0 and
((r_sum - r_right_flat) * (self.inverse_mass_matrix.matmul(r_right_flat)))
.sum() > 0):
return False
else:
if ((self.inverse_mass_matrix * (r_sum - r_left_flat) * r_left_flat).sum() > 0 and
(self.inverse_mass_matrix * (r_sum - r_right_flat) * r_right_flat).sum() > 0):
return False
return True
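# The diagonal-mass branch of `_is_turning` above can be read as two dot
# products; a minimal sketch, assuming 1-D tensors `inv_m`, `r_sum`,
# `r_left_flat`, `r_right_flat` (hypothetical local names):
#
#     still_moving_left = (inv_m * (r_sum - r_left_flat)).dot(r_left_flat) > 0
#     still_moving_right = (inv_m * (r_sum - r_right_flat)).dot(r_right_flat) > 0
#     turning = not (still_moving_left and still_moving_right)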
def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
step_size = self.step_size if direction == 1 else -self.step_size
z_new, r_new, z_grads, potential_energy = velocity_verlet(
z, r, self._potential_energy, self.inverse_mass_matrix, step_size, z_grads=z_grads)
r_new_flat = torch.cat([r_new[site_name].reshape(-1) for site_name in sorted(r_new)])
energy_new = potential_energy + self._kinetic_energy(r_new)
# handle the NaN case
energy_new = energy_new.new_tensor(float("inf")) if torch_isnan(energy_new) else energy_new
sliced_energy = energy_new + log_slice
diverging = (sliced_energy > self._max_sliced_energy)
delta_energy = energy_new - energy_current
accept_prob = (-delta_energy).exp().clamp(max=1.0)
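# `accept_prob` is the Metropolis statistic min(1, exp(-(energy_new -
# energy_current))) for this single leapfrog step; summed over all leaves
# (sum_accept_probs) and divided by num_proposals, it yields the average
# statistic used by Dual Averaging to adapt the step size during warmup.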
if self.use_multinomial_sampling:
tree_weight = -sliced_energy
else:
# As a part of the slice sampling process (see below), along the trajectory
# we eliminate states for which p(z, r) < u, equivalently sliced_energy > 0.
# Due to this elimination (and the stop-doubling conditions),
# the weight of the binary tree might not equal 2^tree_depth.
tree_weight = (sliced_energy.new_ones(()) if sliced_energy <= 0
else sliced_energy.new_zeros(()))
return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new, potential_energy,
z_grads, r_new_flat, tree_weight, False, diverging, accept_prob, 1)
def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
if tree_depth == 0:
return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current)
# build the first half of tree
half_tree = self._build_tree(z, r, z_grads, log_slice,
direction, tree_depth-1, energy_current)
z_proposal = half_tree.z_proposal
z_proposal_pe = half_tree.z_proposal_pe
z_proposal_grads = half_tree.z_proposal_grads
# Check the conditions to stop doubling. If the first half tree meets one,
# there is no need to build the other half.
if half_tree.turning or half_tree.diverging:
return half_tree
# Else, build remaining half of tree.
# If we are going to the right, start from the right leaf of the first half.
if direction == 1:
z = half_tree.z_right
r = half_tree.r_right
z_grads = half_tree.z_right_grads
else: # otherwise, start from the left leaf of the first half
z = half_tree.z_left
r = half_tree.r_left
z_grads = half_tree.z_left_grads
other_half_tree = self._build_tree(z, r, z_grads, log_slice,
direction, tree_depth-1, energy_current)
if self.use_multinomial_sampling:
tree_weight = logsumexp(torch.stack([half_tree.weight, other_half_tree.weight]), dim=0)
else:
tree_weight = half_tree.weight + other_half_tree.weight
sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs
num_proposals = half_tree.num_proposals + other_half_tree.num_proposals
r_sum = half_tree.r_sum + other_half_tree.r_sum
# The probability that the proposal belongs to each half of the tree
# is computed from the weights of the two halves.
if self.use_multinomial_sampling:
other_half_tree_prob = (other_half_tree.weight - tree_weight).exp()
else:
# For the special case where the weights of both halves are 0,
# we choose the proposal from the first half
# (either is fine, because the probability of picking it at the end is 0!).
other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0
else tree_weight.new_zeros(()))
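# For example, under slice sampling, if the first half contains 3 valid
# points and the other half contains 1, the proposal is swapped to the
# other half with probability 1 / (3 + 1) = 0.25.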
is_other_half_tree = pyro.sample("is_other_half_tree",
dist.Bernoulli(probs=other_half_tree_prob))
if is_other_half_tree == 1:
z_proposal = other_half_tree.z_proposal
z_proposal_pe = other_half_tree.z_proposal_pe
z_proposal_grads = other_half_tree.z_proposal_grads
# leaves of the full tree are determined by the direction
if direction == 1:
z_left = half_tree.z_left
r_left = half_tree.r_left
z_left_grads = half_tree.z_left_grads
z_right = other_half_tree.z_right
r_right = other_half_tree.r_right
z_right_grads = other_half_tree.z_right_grads
else:
z_left = other_half_tree.z_left
r_left = other_half_tree.r_left
z_left_grads = other_half_tree.z_left_grads
z_right = half_tree.z_right
r_right = half_tree.r_right
z_right_grads = half_tree.z_right_grads
# We have already checked if the first half tree is turning. Now, we check
# if the other half tree or the full tree is turning.
turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum)
# Divergence is checked on the second half tree (the first half has already been checked).
diverging = other_half_tree.diverging
return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal,
z_proposal_pe, z_proposal_grads, r_sum, tree_weight, turning, diverging,
sum_accept_probs, num_proposals)
def sample(self, trace):
z = {name: node["value"].detach() for name, node in self._iter_latent_nodes(trace)}
potential_energy, z_grads = self._fetch_from_cache()
# automatically transform `z` to unconstrained space, if needed.
for name, transform in self.transforms.items():
z[name] = transform(z[name])
r, r_flat = self._sample_r(name="r_t={}".format(self._t))
energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
else self._energy(z, r)
# Ideally, following a symplectic integrator trajectory, the energy is constant.
# In that case, we could sample the proposal uniformly, and there would be no
# need for a "slice" variable. In practice, however, numerical error makes the
# energy only approximately constant along the trajectory. To deal with this,
# as in [1], we introduce an auxiliary "slice" variable (denoted by u).
# The sampling process goes as follows:
# first sample u from the initial state (z_0, r_0) according to
#   u ~ Uniform(0, p(z_0, r_0)),
# then sample a state (z, r) from the integrator trajectory according to
#   (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
#
# For more information about the slice sampling method, see [3].
# For the version of NUTS which uses multinomial sampling instead of slice
# sampling, see [2].
if self.use_multinomial_sampling:
log_slice = -energy_current
else:
# Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
# sample log_slice directly using `energy`, so as to avoid potential underflow or
# overflow issues ([2]).
slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
dist.Exponential(energy_current.new_tensor(1.)))
log_slice = -energy_current - slice_exp_term
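# Why this works: if e ~ Exponential(1), then u := exp(-energy_current - e)
# satisfies P(u <= t) = P(e >= -energy_current - log(t)) =
# exp(energy_current + log(t)) = t / exp(-energy_current) for
# t in (0, exp(-energy_current)), i.e. u ~ Uniform(0, exp(-energy_current)),
# which is exactly the slice variable of [1] expressed in log space.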
z_left = z_right = z
r_left = r_right = r
z_left_grads = z_right_grads = z_grads
accepted = False
r_sum = r_flat
if self.use_multinomial_sampling:
tree_weight = energy_current.new_zeros(())
else:
tree_weight = energy_current.new_ones(())
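# At this point the tree is just the single initial state: its multinomial
# weight is log(1) = 0 and its slice-sampling weight is 1 valid point
# (the initial state always satisfies p(z_0, r_0) >= u by construction).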
# Temporarily disable distributions args checking as
# NaNs are expected during step size adaptation.
with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
# doubling process, stop when turning or diverging
for tree_depth in range(self._max_tree_depth + 1):
direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
dist.Bernoulli(probs=torch.ones(1) * 0.5))
direction = int(direction.item())
if direction == 1: # go to the right, start from the right leaf of current tree
new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
direction, tree_depth, energy_current)
# update leaf for the next doubling process
z_right = new_tree.z_right
r_right = new_tree.r_right
z_right_grads = new_tree.z_right_grads
else: # go to the left, start from the left leaf of current tree
new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
direction, tree_depth, energy_current)
z_left = new_tree.z_left
r_left = new_tree.r_left
z_left_grads = new_tree.z_left_grads
if new_tree.turning or new_tree.diverging: # stop doubling
break
if self.use_multinomial_sampling:
new_tree_prob = (new_tree.weight - tree_weight).exp()
else:
new_tree_prob = new_tree.weight / tree_weight
rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
dist.Uniform(new_tree_prob.new_tensor(0.),
new_tree_prob.new_tensor(1.)))
if rand < new_tree_prob:
accepted = True
z = new_tree.z_proposal
self._cache(new_tree.z_proposal_pe, new_tree.z_proposal_grads)
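# Following the biased progressive sampling scheme of [2], new_tree_prob
# may exceed 1 when the new subtree outweighs the current tree, in which
# case `rand < new_tree_prob` always accepts; this favors proposals far
# from the initial state and improves mixing.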
r_sum = r_sum + new_tree.r_sum
if self._is_turning(r_left, r_right, r_sum): # stop doubling
break
else: # update tree_weight
if self.use_multinomial_sampling:
tree_weight = logsumexp(torch.stack([tree_weight, new_tree.weight]), dim=0)
else:
tree_weight = tree_weight + new_tree.weight
if self._t < self._warmup_steps:
accept_prob = new_tree.sum_accept_probs / new_tree.num_proposals
self._adapter.step(self._t, z, accept_prob)
if accepted:
self._accept_cnt += 1
self._t += 1
# get trace with the constrained values for `z`.
for name, transform in self.transforms.items():
z[name] = transform.inv(z[name])
return self._get_trace(z)