Source code for pyro.infer.mcmc.hmc

from __future__ import absolute_import, division, print_function

import math
from collections import OrderedDict

import torch
from torch.distributions import biject_to, constraints

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.util import eye_like
from pyro.infer import config_enumerate
from pyro.infer.mcmc.adaptation import WarmupAdapter
from pyro.infer.mcmc.trace_kernel import TraceKernel
from pyro.infer.mcmc.util import TraceEinsumEvaluator
from pyro.ops.integrator import velocity_verlet
from pyro.poutine.subsample_messenger import _Subsample
from pyro.util import optional, torch_isinf, torch_isnan, ignore_jit_warnings


class HMC(TraceKernel):
    r"""
    Simple Hamiltonian Monte Carlo kernel, where ``step_size`` and ``num_steps``
    need to be explicitly specified by the user.

    **References**

    [1] `MCMC Using Hamiltonian Dynamics`,
    Radford M. Neal

    :param model: Python callable containing Pyro primitives.
    :param float step_size: Determines the size of a single step taken by the
        verlet integrator while computing the trajectory using Hamiltonian
        dynamics. If not specified, it will be set to 1.
    :param float trajectory_length: Length of a MCMC trajectory. If not
        specified, it will be set to ``step_size x num_steps``. In case
        ``num_steps`` is not specified, it will be set to :math:`2\pi`.
    :param int num_steps: The number of discrete steps over which to simulate
        Hamiltonian dynamics. The state at the end of the trajectory is
        returned as the proposal. This value is always equal to
        ``max(1, int(trajectory_length / step_size))``.
    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
        during warm-up phase using Dual Averaging scheme.
    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
        matrix during warm-up phase using Welford scheme.
    :param bool full_mass: A flag to decide if mass matrix is dense or diagonal.
    :param dict transforms: Optional dictionary that specifies a transform
        for a sample site with constrained support to unconstrained space. The
        transform should be invertible, and implement `log_abs_det_jacobian`.
        If not specified and the model has sites with constrained support,
        automatic transformations will be applied, as specified in
        :mod:`torch.distributions.constraint_registry`.
    :param int max_plate_nesting: Optional bound on max number of nested
        :func:`pyro.plate` contexts. This is required if model contains
        discrete sample sites that can be enumerated over in parallel.
    :param bool jit_compile: Optional parameter denoting whether to use
        the PyTorch JIT to trace the log density computation, and use this
        optimized executable trace in the integrator.
    :param bool ignore_jit_warnings: Flag to ignore warnings from the JIT
        tracer when ``jit_compile=True``. Default is False.

    .. note:: Internally, the mass matrix will be ordered according to the order
        of the names of latent variables, not the order of their appearance in
        the model.

    Example:

        >>> true_coefs = torch.tensor([1., 2., 3.])
        >>> data = torch.randn(2000, 3)
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
        >>>
        >>> def model(data):
        ...     coefs_mean = torch.zeros(dim)
        ...     coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
        ...     y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
        ...     return y
        >>>
        >>> hmc_kernel = HMC(model, step_size=0.0855, num_steps=4)
        >>> mcmc_run = MCMC(hmc_kernel, num_samples=500, warmup_steps=100).run(data)
        >>> posterior = mcmc_run.marginal('beta').empirical['beta']
        >>> posterior.mean  # doctest: +SKIP
        tensor([ 0.9819,  1.9258,  2.9737])
    """

    def __init__(self,
                 model,
                 step_size=1,
                 trajectory_length=None,
                 num_steps=None,
                 adapt_step_size=True,
                 adapt_mass_matrix=True,
                 full_mass=False,
                 transforms=None,
                 max_plate_nesting=None,
                 jit_compile=False,
                 ignore_jit_warnings=False):
        self.model = model
        self.max_plate_nesting = max_plate_nesting
        # An explicit trajectory length wins; otherwise derive it from
        # `num_steps`, and finally fall back to the default.
        if trajectory_length is not None:
            self.trajectory_length = trajectory_length
        elif num_steps is not None:
            self.trajectory_length = step_size * num_steps
        else:
            self.trajectory_length = 2 * math.pi  # from Stan
        self.adapt_step_size = adapt_step_size
        self._jit_compile = jit_compile
        self._ignore_jit_warnings = ignore_jit_warnings
        self.full_mass = full_mass
        self._target_accept_prob = 0.8  # from Stan

        # The following parameter is used in find_reasonable_step_size method.
        # In NUTS paper, this threshold is set to a fixed log(0.5).
        # After https://github.com/stan-dev/stan/pull/356, it is set to a fixed log(0.8).
        self._direction_threshold = math.log(0.8)  # from Stan

        # number of tries to get a valid initial trace
        self._max_tries_initial_trace = 100

        self.transforms = {} if transforms is None else transforms
        # Transforms are derived from site supports only when the user
        # did not supply any.
        self._automatic_transform_enabled = transforms is None
        self._reset()
        self._adapter = WarmupAdapter(step_size,
                                      adapt_step_size=adapt_step_size,
                                      adapt_mass_matrix=adapt_mass_matrix,
                                      is_diag_mass=not full_mass)
        super(HMC, self).__init__()

    def _get_trace(self, z):
        """
        Replay the model against the prototype trace with the values in ``z``.

        :param dict z: site name -> value (in the model's constrained space).
        :return: the trace recorded while replaying the model.

        .. note:: The node values of ``self.initial_trace`` are overwritten
            in place with the entries of ``z``.
        """
        z_trace = self.initial_trace
        for name, value in z.items():
            z_trace.nodes[name]["value"] = value
        trace_poutine = poutine.trace(poutine.replay(self.model, trace=z_trace))
        trace_poutine(*self._args, **self._kwargs)
        return trace_poutine.trace

    @staticmethod
    def _iter_latent_nodes(trace):
        """
        Yield ``(name, node)`` for the stochastic nodes of ``trace`` in sorted
        name order, skipping sites whose distribution has enumerable support
        and subsampling sites.
        """
        for name, node in sorted(trace.iter_stochastic_nodes(), key=lambda x: x[0]):
            if not (node["fn"].has_enumerate_support or isinstance(node["fn"], _Subsample)):
                yield (name, node)

    def _compute_trace_log_prob(self, model_trace):
        """Return the log probability of ``model_trace`` from the configured evaluator."""
        return self._trace_prob_evaluator.log_prob(model_trace)

    def _kinetic_energy(self, r):
        """
        Return ``0.5 * r^T M^{-1} r`` where ``M^{-1}`` is the inverse mass matrix.

        :param dict r: site name -> momentum tensor; flattened and concatenated
            in sorted name order.
        """
        # TODO: revert to `torch.dot` in pytorch==1.0
        # See: https://github.com/uber/pyro/issues/1458
        r_flat = torch.cat([r[site_name].reshape(-1) for site_name in sorted(r)])
        if self.full_mass:
            return 0.5 * (r_flat * (self.inverse_mass_matrix.matmul(r_flat))).sum()
        else:
            return 0.5 * (self.inverse_mass_matrix * (r_flat ** 2)).sum()

    def _potential_energy(self, z):
        """
        Return the negative trace log probability at ``z`` plus the summed
        ``log_abs_det_jacobian`` of every registered transform.

        :param dict z: site name -> value in unconstrained space.
        """
        if self._jit_compile:
            return self._potential_energy_jit(z)
        # Since the model is specified in the constrained space, transform the
        # unconstrained R.V.s `z` to the constrained space.
        z_constrained = z.copy()
        for name, transform in self.transforms.items():
            z_constrained[name] = transform.inv(z_constrained[name])
        trace = self._get_trace(z_constrained)
        potential_energy = -self._compute_trace_log_prob(trace)
        # adjust by the jacobian for this transformation.
        for name, transform in self.transforms.items():
            potential_energy += transform.log_abs_det_jacobian(z_constrained[name], z[name]).sum()
        return potential_energy

    def _potential_energy_jit(self, z):
        """
        JIT-traced variant of :meth:`_potential_energy`.

        The traced function is built on first use and stored in
        ``self._compiled_potential_fn``; subsequent calls reuse it. Values are
        passed positionally in sorted site-name order.
        """
        names, vals = zip(*sorted(z.items()))
        if self._compiled_potential_fn:
            return self._compiled_potential_fn(*vals)

        def compiled(*zi):
            z_constrained = list(zi)
            # transform to constrained space.
            for i, name in enumerate(names):
                if name in self.transforms:
                    transform = self.transforms[name]
                    z_constrained[i] = transform.inv(z_constrained[i])
            z_constrained = dict(zip(names, z_constrained))
            trace = self._get_trace(z_constrained)
            potential_energy = -self._compute_trace_log_prob(trace)
            # adjust by the jacobian for this transformation.
            for i, name in enumerate(names):
                if name in self.transforms:
                    transform = self.transforms[name]
                    potential_energy += transform.log_abs_det_jacobian(z_constrained[name], zi[i]).sum()
            return potential_energy

        with pyro.validation_enabled(False), optional(ignore_jit_warnings(), self._ignore_jit_warnings):
            self._compiled_potential_fn = torch.jit.trace(compiled, vals, check_trace=False)
        return self._compiled_potential_fn(*vals)

    def _energy(self, z, r):
        """Return the total energy: kinetic energy of ``r`` plus potential energy of ``z``."""
        return self._kinetic_energy(r) + self._potential_energy(z)

    def _reset(self):
        """Restore all per-run state (counters, cached values, model arguments) to its defaults."""
        self._t = 0
        self._accept_cnt = 0
        self._r_shapes = {}
        self._r_numels = {}
        self._args = None
        self._compiled_potential_fn = None
        self._kwargs = None
        self._initial_trace = None
        self._has_enumerable_sites = False
        self._trace_prob_evaluator = None
        self._potential_energy_last = None
        self._z_grads_last = None
        self._warmup_steps = None

    def _find_reasonable_step_size(self, z):
        """
        Search for a step size whose single-step acceptance probability
        crosses the direction threshold, by repeatedly doubling or halving
        the current step size.

        :param dict z: site name -> value in unconstrained space.
        :return: the step size at which the direction first flips.
        """
        step_size = self.step_size

        # We are going to find a step_size which make accept_prob (Metropolis correction)
        # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
        # then we have to decrease step_size; otherwise, increase step_size.
        r, _ = self._sample_r(name="r_presample")
        energy_current = self._energy(z, r)

        def probe_direction(candidate_step_size):
            # Take one integrator step of the given size from (z, r) and
            # classify the resulting energy change.
            _, r_new, _, potential_energy = velocity_verlet(
                z, r, self._potential_energy, self.inverse_mass_matrix, candidate_step_size)
            energy_new = potential_energy + self._kinetic_energy(r_new)
            delta_energy = energy_new - energy_current
            # direction=1 means keep increasing step_size, otherwise decreasing step_size.
            # Note that the direction is -1 if delta_energy is `NaN` which may be the
            # case for a diverging trajectory (e.g. in the case of evaluating log prob
            # of a value simulated using a large step size for a constrained sample site).
            return 1 if self._direction_threshold < -delta_energy else -1

        direction = probe_direction(step_size)

        # define scale for step_size: 2 for increasing, 1/2 for decreasing
        step_size_scale = 2 ** direction
        direction_new = direction
        # keep scale step_size until accept_prob crosses its target
        # TODO: make thresholds for too small step_size or too large step_size
        while direction_new == direction:
            step_size = step_size_scale * step_size
            direction_new = probe_direction(step_size)
        return step_size

    def _guess_max_plate_nesting(self):
        """
        Guesses max_plate_nesting by running the model once
        without enumeration. This optimistically assumes static model
        structure.
        """
        with poutine.block():
            model_trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
        sites = [site
                 for site in model_trace.nodes.values()
                 if site["type"] == "sample"]

        dims = [frame.dim
                for site in sites
                for frame in site["cond_indep_stack"]
                if frame.vectorized]
        self.max_plate_nesting = -min(dims) if dims else 0

    def _configure_adaptation(self):
        """
        Configure the warmup adapter, first searching for an initial step size
        when step size adaptation is enabled.
        """
        initial_step_size = None
        if self.adapt_step_size:
            z = {name: node["value"].detach() for name, node in self._iter_latent_nodes(self.initial_trace)}
            for name, transform in self.transforms.items():
                z[name] = transform(z[name])
            with pyro.validation_enabled(False):
                initial_step_size = self._find_reasonable_step_size(z)

        self._adapter.configure(self._warmup_steps, initial_step_size)

    def _sample_r(self, name):
        """
        Sample a momentum vector from the adapter's ``r_dist``.

        :param str name: name of the Pyro sample site used for the draw.
        :return: tuple ``(r, r_flat)`` where ``r`` maps each site name to its
            slice of ``r_flat`` reshaped to that site's shape.
        """
        r_dist = self._adapter.r_dist
        r_flat = pyro.sample(name, r_dist)
        r = {}
        pos = 0
        # Slices are taken in sorted site-name order. The loop variable is
        # deliberately distinct from the `name` parameter.
        for site_name in sorted(self._r_shapes):
            next_pos = pos + self._r_numels[site_name]
            r[site_name] = r_flat[pos:next_pos].reshape(self._r_shapes[site_name])
            pos = next_pos
        assert pos == r_flat.size(0)
        return r, r_flat

    @property
    def inverse_mass_matrix(self):
        """Inverse mass matrix currently held by the warmup adapter."""
        return self._adapter.inverse_mass_matrix

    @property
    def step_size(self):
        """Step size currently held by the warmup adapter."""
        return self._adapter.step_size

    @property
    def num_steps(self):
        """Number of integrator steps: ``int(trajectory_length / step_size)``, but at least 1."""
        return max(1, int(self.trajectory_length / self.step_size))

    @property
    def initial_trace(self):
        """
        Find a valid trace to initiate the MCMC sampler. This is also used as a
        prototype trace to inter-convert between Pyro's trace object and dict
        object used by the integrator.

        :raises ValueError: if no trace with a finite log probability is found
            within ``self._max_tries_initial_trace`` attempts.
        """
        if self._initial_trace:
            return self._initial_trace
        trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
        for _ in range(self._max_tries_initial_trace):
            trace_log_prob_sum = self._compute_trace_log_prob(trace)
            if not torch_isnan(trace_log_prob_sum) and not torch_isinf(trace_log_prob_sum):
                self._initial_trace = trace
                return trace
            # Retry with the same unpacked arguments as the first attempt.
            trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
        raise ValueError("Model specification seems incorrect - cannot find a valid trace.")

    @initial_trace.setter
    def initial_trace(self, trace):
        self._initial_trace = trace

    def _initialize_model_properties(self):
        """
        Wrap the model for enumeration, collect per-site transforms and
        momentum shapes, build the trace log-prob evaluator and set the
        initial (identity) inverse mass matrix on the adapter.

        :raises ValueError: if the model has no latent site that HMC can
            sample (every stochastic site is enumerable or a subsample site).
        """
        if self.max_plate_nesting is None:
            self._guess_max_plate_nesting()
        # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
        # No-op if model does not have any discrete latents.
        self.model = poutine.enum(config_enumerate(self.model),
                                  first_available_dim=-1 - self.max_plate_nesting)
        if self._automatic_transform_enabled:
            self.transforms = {}
        trace = poutine.trace(self.model).get_trace(*self._args, **self._kwargs)
        # Last site value seen; used below as the prototype for the mass matrix.
        site_value = None
        for name, node in trace.iter_stochastic_nodes():
            if isinstance(node["fn"], _Subsample):
                continue
            if node["fn"].has_enumerate_support:
                self._has_enumerable_sites = True
                continue
            site_value = node["value"]
            if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
                self.transforms[name] = biject_to(node["fn"].support).inv
                site_value = self.transforms[name](node["value"])
            self._r_shapes[name] = site_value.shape
            self._r_numels[name] = site_value.numel()
        self._trace_prob_evaluator = TraceEinsumEvaluator(trace,
                                                          self._has_enumerable_sites,
                                                          self.max_plate_nesting)
        if site_value is None:
            # Without a prototype tensor the mass matrix cannot be built; fail
            # with an explicit message instead of an unbound-variable error.
            raise ValueError("HMC requires at least one latent sample site that is not "
                             "enumerated; none was found in the model.")
        mass_matrix_size = sum(self._r_numels.values())
        if self.full_mass:
            initial_mass_matrix = eye_like(site_value, mass_matrix_size)
        else:
            initial_mass_matrix = site_value.new_ones(mass_matrix_size)
        self._adapter.inverse_mass_matrix = initial_mass_matrix

    def setup(self, warmup_steps, *args, **kwargs):
        """
        Store the warmup length and model arguments, then initialize model
        properties and configure adaptation.

        :param int warmup_steps: number of warmup iterations.
        :param args: positional arguments forwarded to the model.
        :param kwargs: keyword arguments forwarded to the model.
        """
        self._warmup_steps = warmup_steps
        self._args = args
        self._kwargs = kwargs
        self._initialize_model_properties()
        self._configure_adaptation()

    def cleanup(self):
        """Reset all per-run state."""
        self._reset()

    def _cache(self, potential_energy, z_grads):
        """Remember the potential energy and gradients for the next :meth:`sample` call."""
        self._potential_energy_last = potential_energy
        self._z_grads_last = z_grads

    def _fetch_from_cache(self):
        """Return the ``(potential_energy, z_grads)`` pair stored by :meth:`_cache`."""
        return self._potential_energy_last, self._z_grads_last

    def sample(self, trace):
        """
        Run one HMC transition starting from ``trace``.

        :param trace: trace holding the current values of the latent sites.
        :return: a trace replayed with the (constrained) values of the
            accepted proposal, or of the current state if it was rejected.
        """
        z = {name: node["value"].detach() for name, node in self._iter_latent_nodes(trace)}
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])

        r, _ = self._sample_r(name="r_t={}".format(self._t))

        # NOTE(review): nothing in this class calls `_cache`, so unless a
        # subclass does, both values are None here and the current energy is
        # recomputed below.
        potential_energy, z_grads = self._fetch_from_cache()
        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(z, r,
                                                                              self._potential_energy,
                                                                              self.inverse_mass_matrix,
                                                                              self.step_size,
                                                                              self.num_steps,
                                                                              z_grads=z_grads)
            # apply Metropolis correction.
            energy_proposal = self._kinetic_energy(r_new) + potential_energy_new
            energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
                else self._energy(z, r)
        delta_energy = energy_proposal - energy_current
        # Set accept prob to 0.0 if delta_energy is `NaN` which may be
        # the case for a diverging trajectory when using a large step size.
        if torch_isnan(delta_energy):
            accept_prob = delta_energy.new_tensor(0.0)
        else:
            accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(torch.zeros(1), torch.ones(1)))
        if rand < accept_prob:
            self._accept_cnt += 1
            z = z_new

        if self._t < self._warmup_steps:
            self._adapter.step(self._t, z, accept_prob)

        self._t += 1

        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)

    def diagnostics(self):
        """
        Return an ordered mapping with the formatted step size and acceptance
        rate. The rate is reported as ``nan`` before the first transition.
        """
        # Guard against division by zero when no `sample` call has happened yet.
        accept_rate = self._accept_cnt / self._t if self._t else float("nan")
        return OrderedDict([
            ("step size", "{:.2e}".format(self.step_size)),
            ("acc. rate", "{:.3f}".format(accept_rate))
        ])