class SVI(model, guide, optim, loss, loss_and_grads=None, num_samples=10, num_steps=0, **kwargs)[source]

Bases: pyro.infer.abstract_infer.TracePosterior

  • model – the model (callable containing Pyro primitives)
  • guide – the guide (callable containing Pyro primitives)
  • optim (pyro.optim.PyroOptim) – a wrapper for a PyTorch optimizer
  • loss (pyro.infer.elbo.ELBO) – an instance of a subclass of ELBO. Pyro provides three built-in losses: Trace_ELBO, TraceGraph_ELBO, and TraceEnum_ELBO. See the ELBO docs to learn how to implement a custom loss.
  • num_samples – the number of samples for Monte Carlo posterior approximation
  • num_steps – the number of optimization steps to take in run()

A unified interface for stochastic variational inference in Pyro. The most commonly used loss is loss=Trace_ELBO(). See the tutorial SVI Part I for a discussion.

evaluate_loss(*args, **kwargs)[source]
Returns:estimate of the loss
Return type:float

Evaluate the loss function. Any args or kwargs are passed to the model and guide.

run(*args, **kwargs)[source]
step(*args, **kwargs)[source]
Returns:estimate of the loss
Return type:float

Take a gradient step on the loss function (and any auxiliary loss functions generated under the hood by loss_and_grads). Any args or kwargs are passed to the model and guide.


class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: object

ELBO is the top-level interface for stochastic variational inference via optimization of the evidence lower bound.

Most users will not interact with this base class ELBO directly; instead they will create instances of derived classes: Trace_ELBO, TraceGraph_ELBO, or TraceEnum_ELBO.

  • num_particles – The number of particles/samples used to form the ELBO (gradient) estimators.
  • max_plate_nesting (int) – Optional bound on max number of nested pyro.plate() contexts. This is only required when enumerating over sample sites in parallel, e.g. if a site sets infer={"enumerate": "parallel"}. If omitted, ELBO may guess a valid value by running the (model, guide) pair once; however, this guess may be incorrect if the model or guide structure is dynamic.
  • vectorize_particles (bool) – Whether to vectorize the ELBO computation over num_particles. Defaults to False. This requires static structure in model and guide.
  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that pyro.infer.traceenum_elbo.TraceEnum_ELBO is used iff there are enumerated sample sites.
  • ignore_jit_warnings (bool) – Flag to ignore warnings from the JIT tracer. When True, all torch.jit.TracerWarning warnings will be ignored.
  • retain_graph (bool) – Whether to retain autograd graph during an SVI step. Defaults to None (False).
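To illustrate what num_particles controls, here is a minimal pure-Python sketch of a Monte Carlo ELBO estimator for a toy conjugate model (z ~ Normal(0, 1), x | z ~ Normal(z, 1)). The model and the helper names log_normal and elbo_estimate are assumptions made for illustration; this is not how Pyro computes the estimate internally.

```python
import math
import random


def log_normal(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    return (-0.5 * math.log(2 * math.pi * scale ** 2)
            - (x - loc) ** 2 / (2 * scale ** 2))


def elbo_estimate(x, q_loc, q_scale, num_particles, rng):
    """Average of log p(x, z) - log q(z) over num_particles draws z ~ q."""
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(q_loc, q_scale)
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        total += log_joint - log_normal(z, q_loc, q_scale)
    return total / num_particles


rng = random.Random(0)
x = 1.2
# For this model the evidence is available in closed form: x ~ Normal(0, sqrt(2)).
log_evidence = log_normal(x, 0.0, math.sqrt(2.0))

# A guide equal to the exact posterior Normal(x / 2, sqrt(1 / 2)) makes the bound tight.
tight = elbo_estimate(x, x / 2, math.sqrt(0.5), num_particles=1, rng=rng)
# A mismatched guide (here the prior) gives a strictly looser bound on average.
loose = elbo_estimate(x, 0.0, 1.0, num_particles=1000, rng=rng)
```

More particles reduce the variance of the estimate but do not change its expectation.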


[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and [2]. There are no restrictions on the dependency structure of the model or the guide. The gradient estimator includes partial Rao-Blackwellization for reducing the variance of the estimator when non-reparameterizable random variables are present. The Rao-Blackwellization is partial in that it only uses conditional independence information that is marked by plate contexts. For more fine-grained Rao-Blackwellization, see TraceGraph_ELBO.


[1] Automated Variational Inference in Probabilistic Programming,
David Wingate, Theo Weber
[2] Black Box Variational Inference,
Rajesh Ranganath, Sean Gerrish, David M. Blei
loss(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

differentiable_loss(model, guide, *args, **kwargs)[source]

Computes the surrogate loss that can be differentiated with autograd to produce gradient estimates for the model and guide parameters.

loss_and_grads(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. The estimators are formed from num_particles many samples.

class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.trace_elbo.Trace_ELBO

Like Trace_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
loss_and_surrogate_loss(model, guide, *args, **kwargs)[source]
differentiable_loss(model, guide, *args, **kwargs)[source]
loss_and_grads(model, guide, *args, **kwargs)[source]
class TraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.elbo.ELBO

A TraceGraph implementation of ELBO-based SVI. The gradient estimator is constructed along the lines of reference [1] specialized to the case of the ELBO. It supports arbitrary dependency structure for the model and guide as well as baselines for non-reparameterizable random variables. Where possible, conditional dependency information as recorded in the Trace is used to reduce the variance of the gradient estimator. In particular, two kinds of conditional dependency information are used to reduce variance:

  • the sequential order of samples (z is sampled after y => y does not depend on z)
  • plate generators
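As a rough illustration of why baselines help with non-reparameterizable variables, the sketch below compares a score-function gradient estimator with and without a constant baseline for a single Bernoulli variable. The function score_function_grads, the test function f, and the constant baseline are assumptions chosen for illustration; baselines in Pyro can take other forms.

```python
import random
import statistics


def score_function_grads(f, p, baseline, num_samples, rng):
    """Per-sample estimates of d/dp E[f(z)] for z ~ Bernoulli(p)."""
    grads = []
    for _ in range(num_samples):
        z = 1 if rng.random() < p else 0
        # Derivative of log Bernoulli(z; p) with respect to p.
        score = 1.0 / p if z == 1 else -1.0 / (1.0 - p)
        # Subtracting a constant keeps the estimator unbiased, since E[score] = 0.
        grads.append((f(z) - baseline) * score)
    return grads


def f(z):
    return 10.0 + z


p = 0.3
rng = random.Random(0)
# The exact gradient is f(1) - f(0) = 1.
plain = score_function_grads(f, p, baseline=0.0, num_samples=20000, rng=rng)
# Baseline set to E[f(z)] = 10 + p.
centered = score_function_grads(f, p, baseline=10.3, num_samples=20000, rng=rng)

plain_var = statistics.variance(plain)
centered_var = statistics.variance(centered)
```

Both estimators target the same gradient, but the centered one has far lower variance because the large constant offset in f no longer multiplies the score.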


[1] Gradient Estimation Using Stochastic Computation Graphs,
John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel
[2] Neural Variational Inference and Learning in Belief Networks,
Andriy Mnih, Karol Gregor
loss(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. The estimators are formed from num_particles many samples. If baselines are present, a baseline loss is also constructed and differentiated.

class JitTraceGraph_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.tracegraph_elbo.TraceGraph_ELBO

Like TraceGraph_ELBO but uses torch.jit.trace() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
loss_and_grads(model, guide, *args, **kwargs)[source]
class BackwardSampleMessenger(enum_trace, guide_trace)[source]

Bases: pyro.poutine.messenger.Messenger

Implements forward filtering / backward sampling for sampling from the joint posterior distribution.

class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.elbo.ELBO

A trace implementation of ELBO-based SVI that supports

  • exhaustive enumeration over discrete sample sites, and
  • local parallel sampling over any sample site.

To enumerate over a sample site in the guide, mark the site with either infer={'enumerate': 'sequential'} or infer={'enumerate': 'parallel'}. To configure all guide sites at once, use config_enumerate(). To enumerate over a sample site in the model, mark the site infer={'enumerate': 'parallel'} and ensure the site does not appear in the guide.

This assumes restricted dependency structure on the model and guide: variables outside of a plate can never depend on variables inside that plate.
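The idea behind exhaustive enumeration can be sketched in plain Python: instead of sampling a discrete site, every value is scored and the site is summed out exactly. The two-component mixture and the helper names below are assumptions for illustration and do not reflect Pyro's internal implementation.

```python
import math


def log_normal(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    return (-0.5 * math.log(2 * math.pi * scale ** 2)
            - (x - loc) ** 2 / (2 * scale ** 2))


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def enumerate_site(x, weights, locs, scale=1.0):
    """Sum out a discrete assignment z exactly instead of sampling it."""
    # log p(z = k) + log p(x | z = k) for every value k of the discrete site.
    log_joint = [math.log(w) + log_normal(x, loc, scale)
                 for w, loc in zip(weights, locs)]
    log_marginal = logsumexp(log_joint)
    # Posterior marginal over z, obtained by normalizing the joint.
    posterior = [math.exp(lj - log_marginal) for lj in log_joint]
    return log_marginal, posterior


# x = 0.5 is equally far from both component means, so the posterior
# over z equals the prior weights.
log_marginal, posterior = enumerate_site(0.5, weights=[0.4, 0.6], locs=[-1.0, 2.0])
```

Because the discrete site is marginalized exactly, it contributes no sampling variance to the estimate.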

loss(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Estimates the ELBO using num_particles many samples (particles).

differentiable_loss(model, guide, *args, **kwargs)[source]
Returns:a differentiable estimate of the ELBO
Return type:torch.Tensor
Raises:ValueError – if the ELBO is not differentiable (e.g. is identically zero)

Estimates a differentiable ELBO using num_particles many samples (particles). The result should be infinitely differentiable (as long as underlying derivatives have been implemented).

loss_and_grads(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Estimates the ELBO using num_particles many samples (particles). Performs backward on the ELBO of each particle.

compute_marginals(model, guide, *args, **kwargs)[source]

Computes marginal distributions at each model-enumerated sample site.

Returns:a dict mapping site name to marginal Distribution object
Return type:OrderedDict
sample_posterior(model, guide, *args, **kwargs)[source]

Sample from the joint posterior distribution of all model-enumerated sites given all observations.

class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.traceenum_elbo.TraceEnum_ELBO

Like TraceEnum_ELBO but uses pyro.ops.jit.compile() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
differentiable_loss(model, guide, *args, **kwargs)[source]
loss_and_grads(model, guide, *args, **kwargs)[source]
class TraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.trace_elbo.Trace_ELBO

A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in Pyro that uses analytic KL divergences when those are available.

In contrast to, e.g., TraceGraph_ELBO and Trace_ELBO this estimator places restrictions on the dependency structure of the model and guide. In particular it assumes that the guide has a mean-field structure, i.e. that it factorizes across the different latent variables present in the guide. It also assumes that all of the latent variables in the guide are reparameterized. This latter condition is satisfied for, e.g., the Normal distribution but is not satisfied for, e.g., the Categorical distribution.


This estimator may give incorrect results if the mean-field condition is not satisfied.
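To show what using an analytic KL divergence buys, the sketch below compares the closed-form KL between two Normal distributions with a Monte Carlo estimate of the same quantity. The helper names are assumptions for illustration; Pyro's own KL computations are not shown here.

```python
import math
import random


def log_normal(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    return (-0.5 * math.log(2 * math.pi * scale ** 2)
            - (x - loc) ** 2 / (2 * scale ** 2))


def kl_normal(loc_q, scale_q, loc_p, scale_p):
    """Closed-form KL(Normal(loc_q, scale_q) || Normal(loc_p, scale_p))."""
    var_ratio = (scale_q / scale_p) ** 2
    mean_term = ((loc_q - loc_p) / scale_p) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - math.log(var_ratio))


def kl_monte_carlo(loc_q, scale_q, loc_p, scale_p, num_particles, rng):
    """Sample-based estimate of the same KL, averaged over draws z ~ q."""
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(loc_q, scale_q)
        total += log_normal(z, loc_q, scale_q) - log_normal(z, loc_p, scale_p)
    return total / num_particles


rng = random.Random(0)
analytic = kl_normal(1.0, 0.5, 0.0, 1.0)
sampled = kl_monte_carlo(1.0, 0.5, 0.0, 1.0, num_particles=5000, rng=rng)
```

The analytic value has no sampling noise, which is the motivation for using it whenever a closed form is available.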

Note for advanced users:

The mean field condition is a sufficient but not necessary condition for this estimator to be correct. The precise condition is that for every latent variable z in the guide, its parents in the model must not include any latent variables that are descendants of z in the guide. Here ‘parents in the model’ and ‘descendants in the guide’ are with respect to the corresponding (statistical) dependency structure. For example, this condition is always satisfied if the model and guide have identical dependency structures.

loss(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

class JitTraceMeanField_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, retain_graph=None)[source]

Bases: pyro.infer.trace_mean_field_elbo.TraceMeanField_ELBO

Like TraceMeanField_ELBO but uses pyro.ops.jit.trace() to compile loss_and_grads().

This works only for a limited set of models:

  • Models must have static structure.
  • Models must not depend on any global data (except the param store).
  • All model inputs that are tensors must be passed in via *args.
  • All model inputs that are not tensors must be passed in via **kwargs, and compilation will be triggered once per unique **kwargs.
differentiable_loss(model, guide, *args, **kwargs)[source]
loss_and_grads(model, guide, *args, **kwargs)[source]
class RenyiELBO(alpha=0, num_particles=2, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True)[source]

Bases: pyro.infer.elbo.ELBO

An implementation of Renyi’s \(\alpha\)-divergence variational inference following reference [1].

In order for the objective to be a strict lower bound, we require \(\alpha \ge 0\). Note, however, that according to reference [1], depending on the dataset, \(\alpha < 0\) might give better results. In the special case \(\alpha = 0\), the objective function is that of the importance weighted autoencoder derived in reference [2].


Setting \(\alpha < 1\) gives a better bound than the usual ELBO. For \(\alpha = 1\), it is better to use the Trace_ELBO class because it helps reduce the variance of the gradient estimates.
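As a sketch of how the bound depends on \(\alpha\), the function below estimates \(\frac{1}{1-\alpha} \log \frac{1}{K} \sum_k w_k^{1-\alpha}\) from K log importance weights \(\log w_k = \log p(x, z_k) - \log q(z_k)\). This form of the estimator follows reference [1]; the function name renyi_bound and the example weights are assumptions for illustration.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def renyi_bound(log_weights, alpha):
    """Estimate (1 / (1 - alpha)) * log mean_k w_k ** (1 - alpha), alpha != 1."""
    if alpha == 1:
        raise ValueError("alpha must differ from 1; use the standard ELBO instead")
    scaled = [(1.0 - alpha) * lw for lw in log_weights]
    return (logsumexp(scaled) - math.log(len(scaled))) / (1.0 - alpha)


# Hypothetical log importance weights from four particles.
log_weights = [-1.0, -2.5, -0.3, -1.7]
iwae = renyi_bound(log_weights, alpha=0.0)   # importance weighted bound
half = renyi_bound(log_weights, alpha=0.5)
elbo = sum(log_weights) / len(log_weights)   # the alpha -> 1 limit
```

On any fixed set of particles the estimate is non-increasing in \(\alpha\), so values below 1 give a bound at least as tight as the standard ELBO estimate.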


Mini-batch training is not supported yet.

  • alpha (float) – The order of \(\alpha\)-divergence. Here \(\alpha \neq 1\). Default is 0.
  • num_particles – The number of particles/samples used to form the objective (gradient) estimator. Default is 2.
  • max_plate_nesting (int) – Bound on max number of nested pyro.plate() contexts. Default is infinity.
  • strict_enumeration_warning (bool) – Whether to warn about possible misuse of enumeration, i.e. that TraceEnum_ELBO is used iff there are enumerated sample sites.


[1] Renyi Divergence Variational Inference,
Yingzhen Li, Richard E. Turner
[2] Importance Weighted Autoencoders,
Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
loss(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Evaluates the ELBO with an estimator that uses num_particles many samples/particles.

loss_and_grads(model, guide, *args, **kwargs)[source]
Returns:an estimate of the ELBO
Return type:float

Computes the ELBO as well as the surrogate ELBO that is used to form the gradient estimator. Performs backward on the latter. The estimators are formed from num_particles many samples.

logsumexp(input, dim, keepdim=False, out=None)

Returns the log of summed exponentials of each row of the input tensor in the given dimension dim. The computation is numerically stabilized.

For summation index \(j\) given by dim and other indices \(i\), the result is

\[\text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij})\]

If keepdim is True, the output tensor is of the same size as input except in the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed (see torch.squeeze()), resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).

  • input (Tensor) – the input tensor
  • dim (int or tuple of ints) – the dimension or dimensions to reduce
  • keepdim (bool) – whether the output tensor has dim retained or not
  • out (Tensor, optional) – the output tensor

>>> a = torch.randn(3, 3)
>>> torch.logsumexp(a, 1)
tensor([ 0.8442,  1.4322,  0.8711])
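The numerical stabilization mentioned above is commonly done by factoring out the maximum before exponentiating. The pure-Python sketch below shows that trick for a flat list of floats; it illustrates the idea only and is not the torch implementation.

```python
import math


def logsumexp(values):
    """log(sum(exp(v) for v in values)), computed by factoring out the maximum."""
    m = max(values)
    if math.isinf(m):
        # All entries are -inf, or some entry is +inf: the result is m itself.
        return m
    # exp(v - m) <= 1 for every v, so the sum cannot overflow.
    return m + math.log(sum(math.exp(v - m) for v in values))


small = logsumexp([0.0, 1.0, 2.0])
# A naive math.log(math.exp(1000.0) + math.exp(1000.0)) would overflow here.
large = logsumexp([1000.0, 1000.0])
```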


class Importance(model, guide=None, num_samples=None)[source]

Bases: pyro.infer.abstract_infer.TracePosterior

  • model – probabilistic model defined as a function
  • guide – guide used for sampling defined as a function
  • num_samples – number of samples to draw from the guide (default 10)

This method performs posterior inference by importance sampling using the guide as the proposal distribution. If no guide is provided, it defaults to proposing from the model’s prior.


Compute (Importance Sampling) Effective Sample Size (ESS).


Estimator of the normalizing constant of the target distribution. (mean of the unnormalized weights)


Compute the normalized importance weights.
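The three quantities described above can be sketched directly from a list of log importance weights. The helper names and the ESS formula used here (one over the sum of squared normalized weights) are assumptions based on the standard importance sampling definitions, not a copy of Pyro's code.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def normalized_weights(log_weights):
    """Importance weights rescaled to sum to one."""
    log_norm = logsumexp(log_weights)
    return [math.exp(lw - log_norm) for lw in log_weights]


def effective_sample_size(log_weights):
    """ESS = 1 / sum of squared normalized weights."""
    return 1.0 / sum(w ** 2 for w in normalized_weights(log_weights))


def log_normalizer(log_weights):
    """Log of the mean of the unnormalized weights."""
    return logsumexp(log_weights) - math.log(len(log_weights))


uniform_ess = effective_sample_size([0.0, 0.0, 0.0, 0.0])
degenerate_ess = effective_sample_size([0.0, -50.0, -50.0])
```

Equal weights give an ESS equal to the number of samples, while a single dominant weight drives the ESS toward one.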

Discrete Inference

infer_discrete(fn=None, first_available_dim=None, temperature=1)[source]

A poutine that samples discrete sites marked with site["infer"]["enumerate"] = "parallel" from the posterior, conditioned on observations.


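import torch
import pyro
import pyro.distributions as dist
from pyro.infer import infer_discrete
@infer_discrete(first_available_dim=-1, temperature=0)
def viterbi_decoder(data, hidden_dim=10):
    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
    means = torch.arange(float(hidden_dim))
    states = [0]
    for t in pyro.markov(range(len(data))):
        states.append(pyro.sample("states_{}".format(t),
                                  dist.Categorical(transition[states[-1]]),
                                  infer={"enumerate": "parallel"}))
        pyro.sample("obs_{}".format(t),
                    dist.Normal(means[states[-1]], 1.), obs=data[t])
    return states  # returns maximum likelihood states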
  • fn – a stochastic function (callable containing Pyro primitive calls)
  • first_available_dim (int) – The first tensor dimension (counting from the right) that is available for parallel enumeration. This dimension and all dimensions to its left may be used internally by Pyro. This should be a negative integer.
  • temperature (int) – Either 1 (sample via forward-filter backward-sample) or 0 (optimize via Viterbi-like MAP inference). Defaults to 1 (sample).
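For intuition about the temperature=0 case, here is a plain-Python sketch of Viterbi decoding for a small hidden Markov model with Normal emissions, checked against brute-force enumeration. The function names, the three-state model, and the data are assumptions for illustration; this is not Pyro's implementation.

```python
import itertools
import math


def log_normal(x, loc, scale=1.0):
    """Log density of Normal(loc, scale) evaluated at x."""
    return (-0.5 * math.log(2 * math.pi * scale ** 2)
            - (x - loc) ** 2 / (2 * scale ** 2))


def path_log_prob(path, data, log_transition, means, initial_state=0):
    """Joint log-probability of a hidden state path and the observations."""
    total, prev = 0.0, initial_state
    for state, x in zip(path, data):
        total += log_transition[prev][state] + log_normal(x, means[state])
        prev = state
    return total


def viterbi(data, log_transition, means, initial_state=0):
    """Most probable hidden state path given non-empty data (MAP decoding)."""
    num_states = len(means)
    # best[s]: best log-probability of any path ending in state s so far.
    best = [log_transition[initial_state][s] + log_normal(data[0], means[s])
            for s in range(num_states)]
    backpointers = []
    for x in data[1:]:
        pointers, new_best = [], []
        for s in range(num_states):
            prev = max(range(num_states),
                       key=lambda r: best[r] + log_transition[r][s])
            pointers.append(prev)
            new_best.append(best[prev] + log_transition[prev][s]
                            + log_normal(x, means[s]))
        backpointers.append(pointers)
        best = new_best
    # Trace the best path backwards from the best final state.
    state = max(range(num_states), key=lambda s: best[s])
    path = [state]
    for pointers in reversed(backpointers):
        state = pointers[state]
        path.append(state)
    return path[::-1]


hidden_dim = 3
transition = [[0.3 / hidden_dim + (0.7 if i == j else 0.0)
               for j in range(hidden_dim)] for i in range(hidden_dim)]
log_transition = [[math.log(p) for p in row] for row in transition]
means = [float(s) for s in range(hidden_dim)]
data = [0.1, 0.2, 1.9, 2.1, 2.0, 0.9]

path = viterbi(data, log_transition, means)
# Brute force over all 3 ** 6 paths; feasible only for tiny problems.
brute_force = max(itertools.product(range(hidden_dim), repeat=len(data)),
                  key=lambda p: path_log_prob(p, data, log_transition, means))
```

With temperature=1 the maximization over previous states would be replaced by sampling, giving a draw from the joint posterior rather than its mode.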

Inference Utilities

class EmpiricalMarginal(trace_posterior, sites=None, validate_args=None)[source]

Bases: pyro.distributions.empirical.Empirical

Marginal distribution over a single site (or multiple, provided they have the same shape) from the TracePosterior’s model.

Note: If multiple sites are specified, they must have the same tensor shape. Samples from each site will be stacked and stored within a single tensor. See Empirical. To hold the marginal distribution of sites having different shapes, use Marginals instead.
  • trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.
  • sites (list) – optional list of sites for which we need to generate the marginal distribution.
class Marginals(trace_posterior, sites=None, validate_args=None)[source]

Bases: object

Holds the marginal distribution over one or more sites from the TracePosterior’s model. This is a convenience container class, which can be extended by TracePosterior subclasses, e.g. for implementing diagnostics.

  • trace_posterior (TracePosterior) – a TracePosterior instance representing a Monte Carlo posterior.
  • sites (list) – optional list of sites for which we need to generate the marginal distribution.
class TracePosterior(num_chains=1)[source]

Bases: object

Abstract TracePosterior object from which posterior inference algorithms inherit. When run, collects a bag of execution traces from the approximate posterior. This is designed to be used by other utility classes like EmpiricalMarginal, that need access to the collected execution traces.


Computes information criterion of the model. Currently, returns only “Widely Applicable/Watanabe-Akaike Information Criterion” (WAIC) and the corresponding effective number of parameters.


[1] Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC, Aki Vehtari, Andrew Gelman, and Jonah Gabry

Parameters:pointwise (bool) – a flag to decide if we want to get a vectorized WAIC or not. When pointwise=False, returns the sum.
Returns:a dictionary containing values of WAIC and its effective number of parameters.
Return type:OrderedDict
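A minimal sketch of a WAIC computation in the spirit of reference [1], taking a matrix of log-likelihoods with one row per posterior sample and one column per observation. The function name waic, the tuple return value, and the use of the unbiased sample variance are assumptions for illustration and may differ from Pyro's implementation.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def waic(log_likelihoods, pointwise=False):
    """log_likelihoods[s][i]: log-likelihood of observation i under sample s.

    Requires at least two posterior samples.
    """
    num_samples = len(log_likelihoods)
    num_obs = len(log_likelihoods[0])
    waic_values, p_values = [], []
    for i in range(num_obs):
        column = [log_likelihoods[s][i] for s in range(num_samples)]
        # Log pointwise predictive density: log of the mean likelihood.
        lpd = logsumexp(column) - math.log(num_samples)
        # Effective number of parameters: variance of the log-likelihood.
        mean = sum(column) / num_samples
        p_waic = sum((v - mean) ** 2 for v in column) / (num_samples - 1)
        waic_values.append(-2.0 * (lpd - p_waic))
        p_values.append(p_waic)
    if pointwise:
        return waic_values, p_values
    return sum(waic_values), sum(p_values)


# Two posterior samples, two observations, no variation across samples.
total_waic, total_p = waic([[-1.0, -2.0], [-1.0, -2.0]])
per_point_waic, per_point_p = waic([[-1.0, -2.0], [-1.5, -2.5]], pointwise=True)
```

When the log-likelihood does not vary across posterior samples, the effective number of parameters is zero and WAIC reduces to minus twice the summed log-likelihood.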
run(*args, **kwargs)[source]

Calls self._traces to populate execution traces from a stochastic Pyro model.

  • args – optional args taken by self._traces.
  • kwargs – optional keyword args taken by self._traces.
class TracePredictive(model, posterior, num_samples)[source]

Bases: pyro.infer.abstract_infer.TracePosterior

Generates and holds traces from the posterior predictive distribution, given model execution traces from the approximate posterior. This is achieved by constraining latent sites to randomly sampled parameter values from the model execution traces and running the model forward to generate traces with new response (“_RETURN”) sites.

  • model – arbitrary Python callable containing Pyro primitives.
  • posterior (TracePosterior) – trace posterior instance holding samples from the model’s approximate posterior.
  • num_samples (int) – number of samples to generate.