Automatic Guide Generation

The pyro.contrib.autoguide module provides algorithms to automatically generate guides from simple models, for use in SVI. For example, to generate a mean field Gaussian guide:

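from pyro.contrib.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model():
    ...

guide = AutoDiagonalNormal(model)  # a mean field guide
svi = SVI(model, guide, Adam({'lr': 1e-3}), Trace_ELBO())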

Automatic guides can also be combined using pyro.poutine.block() and AutoGuideList.

AutoGuide

class AutoGuide(model, prefix='auto')

Bases: object

Base class for automatic guides.

Derived classes must implement the __call__() method.

Auto guides can be used individually or combined in an AutoGuideList object.

Parameters:
  • model (callable) – a pyro model
  • prefix (str) – a prefix prepended to the names of all internal param sites
__call__(*args, **kwargs)

A guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict
sample_latent(**kwargs)

Samples an encoded latent given the same *args, **kwargs as the base model.

AutoGuideList

class AutoGuideList(model, prefix='auto')

Bases: pyro.contrib.autoguide.AutoGuide

Container class to combine multiple automatic guides.

Example usage:

guide = AutoGuideList(model)
guide.add(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
svi = SVI(model, guide, optim, Trace_ELBO())
Parameters:
  • model (callable) – a Pyro model
  • prefix (str) – a prefix prepended to the names of all internal param sites
__call__(*args, **kwargs)

A composite guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
add(part)

Add an automatic guide for part of the model. The guide should have been created by blocking the model to restrict it to a subset of sample sites. No two parts should operate on the same sample site.

Parameters: part (AutoGuide or callable) – a partial guide to add
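The rule that no two parts may operate on the same sample site can be illustrated with a small pure-Python sketch. It assumes each part is simply a callable returning a dict of site values and that the container merges these dicts; ToyGuideList and the toy parts are hypothetical and are not Pyro's implementation.

```python
from typing import Any, Callable, Dict, List

# Assumption: a "part" is any callable returning {site_name: value}.
PartialGuide = Callable[..., Dict[str, Any]]


class ToyGuideList:
    """Hypothetical stand-in illustrating how partial guides compose."""

    def __init__(self) -> None:
        self.parts: List[PartialGuide] = []

    def add(self, part: PartialGuide) -> None:
        self.parts.append(part)

    def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        result: Dict[str, Any] = {}
        for part in self.parts:
            values = part(*args, **kwargs)
            overlap = result.keys() & values.keys()
            if overlap:
                # Two parts claimed the same sample site: the composite is ambiguous.
                raise ValueError(f"sites handled by more than one part: {sorted(overlap)}")
            result.update(values)
        return result


guide = ToyGuideList()
guide.add(lambda *a, **k: {"loc": 0.3, "scale": 1.2})  # e.g. a continuous part
guide.add(lambda *a, **k: {"assignment": 1})           # e.g. a discrete part
print(guide())  # {'loc': 0.3, 'scale': 1.2, 'assignment': 1}
```

Raising on overlap is just one way to enforce the rule in this toy; the sketch does not claim that Pyro detects overlapping parts this way.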
median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict

AutoCallable

class AutoCallable(model, guide, median=lambda *args, **kwargs: {})

Bases: pyro.contrib.autoguide.AutoGuide

AutoGuide wrapper for simple callable guides.

This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:

def my_local_guide(*args, **kwargs):
    ...

guide = AutoGuideList(model)
guide.add(AutoDelta(poutine.block(model, expose=['my_global_param'])))
guide.add(my_local_guide)  # automatically wrapped in an AutoCallable

To also specify a median callable, wrap the guide explicitly:

def my_local_median(*args, **kwargs):
    ...

guide.add(AutoCallable(model, my_local_guide, my_local_median))

For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.

Parameters:
  • model (callable) – a Pyro model
  • guide (callable) – a Pyro guide (typically over only part of the model)
  • median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
__call__(*args, **kwargs)

AutoDelta

class AutoDelta(model, prefix='auto')

Bases: pyro.contrib.autoguide.AutoGuide

This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Note: This class does MAP inference in constrained space.

Usage:

guide = AutoDelta(model)
svi = SVI(model, guide, ...)

By default, latent variables are randomly initialized by the model. To change this default behavior, the user should call pyro.param() before beginning inference, with "auto_" prefixed to the targeted sample site names. For example, for sample sites named "level" and "concentration", initialize via:

pyro.param("auto_level", torch.tensor([-1., 0., 1.]))
pyro.param("auto_concentration", torch.ones(k),
           constraint=constraints.positive)
__call__(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict
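Because a Delta guide puts all its mass on one point, the ELBO reduces to the log joint \(\log p(x, z)\) at that point (the guide's own log-probability contributes nothing that depends on the location), so SVI with AutoDelta amounts to maximizing the log joint. The stdlib-only sketch below illustrates this for a hypothetical one-dimensional model (\(\mu \sim \mathcal{N}(0, 1)\), \(x_i \sim \mathcal{N}(\mu, 1)\)) using plain gradient ascent instead of Pyro and SVI, and checks the result against the closed-form posterior mode.

```python
# Hypothetical toy model: mu ~ Normal(0, 1); x_i ~ Normal(mu, 1) for each observation.
data = [1.2, 0.8, 1.5, 0.9]


def grad_log_joint(mu: float) -> float:
    # d/dmu [log N(mu | 0, 1) + sum_i log N(x_i | mu, 1)]
    return -mu + sum(x - mu for x in data)


def map_estimate(lr: float = 0.1, steps: int = 200) -> float:
    mu = 0.0  # the point-mass location: the only free parameter of a Delta guide
    for _ in range(steps):
        mu += lr * grad_log_joint(mu)  # gradient ascent on log p(x, mu)
    return mu


mu_map = map_estimate()
closed_form = sum(data) / (len(data) + 1)  # posterior mode for this conjugate model
print(mu_map, closed_form)
```

The step size is chosen so that plain gradient ascent converges for this particular model; in practice the optimizer passed to SVI plays this role.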

AutoContinuous

class AutoContinuous(model, prefix='auto')

Bases: pyro.contrib.autoguide.AutoGuide

Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].

Each derived class implements its own get_posterior() method.

Assumes model structure and latent dimension are fixed, and all latent variables are continuous.

Parameters: model (callable) – a Pyro model

Reference:

[1] ‘Automatic Differentiation Variational Inference’, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
__call__(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
get_posterior(*args, **kwargs)

Returns the posterior distribution.

median(*args, **kwargs)

Returns the posterior median value of each latent variable.

Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)

Returns posterior quantiles of each latent variable. Example:

print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict
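As a sketch of what quantiles computes for a Gaussian guide, the example below assumes a diagonal Normal over unconstrained coordinates, takes Normal quantiles per coordinate with Python's statistics.NormalDist, and maps them into each site's support. The site names, parameter values, and the use of exp as the map for a positive site are illustrative assumptions; because that map is monotone increasing, it sends quantiles to quantiles.

```python
import math
from statistics import NormalDist

# Hypothetical diagonal-Normal parameters over two unconstrained coordinates:
# "level" is real-valued; "concentration" is positive and represented by its log.
loc = {"level": 0.5, "concentration": 0.0}
scale = {"level": 2.0, "concentration": 0.25}
to_constrained = {"level": lambda u: u, "concentration": math.exp}


def quantiles(qs):
    out = {}
    for site in loc:
        marginal = NormalDist(mu=loc[site], sigma=scale[site])
        # Quantiles are taken in unconstrained space, then mapped into the support.
        out[site] = [to_constrained[site](marginal.inv_cdf(q)) for q in qs]
    return out


result = quantiles([0.05, 0.5, 0.95])
print(result)
```

In this sketch the 0.5 quantile is the median, so the positive site's median is exp of its unconstrained loc.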
sample_latent(*args, **kwargs)

Samples an encoded latent given the same *args, **kwargs as the base model.
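The encoded latent is a single flat vector in unconstrained space covering all latent sites. The sketch below shows one way such a vector could be split into per-site values and mapped into each site's support; the SITES table, its sizes, and its transforms are hypothetical, and Pyro itself works with tensors and constraint transforms rather than Python lists.

```python
import math
from typing import Dict, List

# Hypothetical latent sites: name -> (number of elements, unconstrained-to-constrained map).
SITES = {
    "level": (3, lambda u: u),       # real-valued vector of length 3
    "concentration": (1, math.exp),  # positive scalar, represented by its log
}


def unpack_latent(latent: List[float]) -> Dict[str, List[float]]:
    """Split one flat unconstrained vector into per-site constrained values."""
    expected = sum(size for size, _ in SITES.values())
    if len(latent) != expected:
        raise ValueError(f"expected {expected} coordinates, got {len(latent)}")
    values, pos = {}, 0
    for name, (size, transform) in SITES.items():
        chunk = latent[pos:pos + size]
        values[name] = [transform(u) for u in chunk]
        pos += size
    return values


encoded = [-1.0, 0.0, 1.0, math.log(2.0)]  # e.g. one draw from the guide's posterior
print(unpack_latent(encoded))
```

Keeping one flat vector is what lets a single multivariate distribution (diagonal, full-rank, low-rank, or flow-based) describe all sites jointly.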

AutoMultivariateNormal

class AutoMultivariateNormal(model, prefix='auto')

Bases: pyro.contrib.autoguide.AutoContinuous

This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoMultivariateNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized to zero and the Cholesky factor is initialized to the identity. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:

latent_dim = 10
pyro.param("auto_loc", torch.randn(latent_dim))
pyro.param("auto_scale_tril", torch.tril(torch.rand(latent_dim, latent_dim)),
           constraint=constraints.lower_cholesky)
get_posterior(*args, **kwargs)

Returns a MultivariateNormal posterior distribution.
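The posterior can be sampled by reparameterization, \(z = \mu + L\varepsilon\) with \(\varepsilon \sim \mathcal{N}(0, I)\) and \(L\) the lower-triangular Cholesky factor, which gives covariance \(L L^\top\). The pure-Python sketch below illustrates this; the helper names and the example factor are assumptions, and a real guide would use tensors (for example torch.distributions.MultivariateNormal with scale_tril).

```python
import random
from typing import List

Matrix = List[List[float]]


def sample_mvn(loc: List[float], scale_tril: Matrix, rng: random.Random) -> List[float]:
    """Reparameterized draw z = loc + L @ eps with eps ~ Normal(0, I)."""
    eps = [rng.gauss(0.0, 1.0) for _ in loc]
    return [
        loc[i] + sum(scale_tril[i][j] * eps[j] for j in range(i + 1))  # L is lower triangular
        for i in range(len(loc))
    ]


def covariance(scale_tril: Matrix) -> Matrix:
    """Covariance implied by the Cholesky factor: L @ L^T."""
    n = len(scale_tril)
    return [
        [sum(scale_tril[i][k] * scale_tril[j][k] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


L = [[1.0, 0.0], [0.5, 2.0]]
print(covariance(L))  # [[1.0, 0.5], [0.5, 4.25]]
```

Parameterizing by \(L\) rather than the covariance itself keeps the covariance positive semi-definite by construction; the lower_cholesky constraint additionally keeps the diagonal of \(L\) positive.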

AutoDiagonalNormal

class AutoDiagonalNormal(model, prefix='auto')

Bases: pyro.contrib.autoguide.AutoContinuous

This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, ...)

By default the mean vector is initialized to zero and the scale vector is initialized to ones (i.e. an identity covariance). To change this default behavior the user should call pyro.param() before beginning inference, e.g.:

latent_dim = 10
pyro.param("auto_loc", torch.randn(latent_dim))
pyro.param("auto_scale", torch.ones(latent_dim),
           constraint=constraints.positive)
get_posterior(*args, **kwargs)

Returns a diagonal Normal posterior distribution.

AutoLowRankMultivariateNormal

class AutoLowRankMultivariateNormal(model, prefix='auto', rank=1)

Bases: pyro.contrib.autoguide.AutoContinuous

This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoLowRankMultivariateNormal(model, rank=10)
svi = SVI(model, guide, ...)

By default, cov_diag is initialized to 1/2 and cov_factor is initialized randomly such that cov_factor.matmul(cov_factor.t()) is, in expectation, half the identity matrix. To change this default behavior, the user should call pyro.param() before beginning inference, e.g.:

latent_dim = 10
pyro.param("auto_loc", torch.randn(latent_dim))
rank = 2  # must match the rank passed to AutoLowRankMultivariateNormal
pyro.param("auto_cov_factor", torch.randn(latent_dim, rank))
pyro.param("auto_cov_diag", torch.randn(latent_dim).exp(),
           constraint=constraints.positive)
Parameters:
  • model (callable) – a generative model
  • rank (int) – the rank of the low-rank part of the covariance matrix
  • prefix (str) – a prefix prepended to the names of all internal param sites
get_posterior(*args, **kwargs)

Returns a LowRankMultivariateNormal posterior distribution.
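The default initialization described above can be checked numerically. This sketch assumes cov_factor entries are drawn independently from \(\mathcal{N}(0, 1/(2\,\mathrm{rank}))\), one choice that makes cov_factor.matmul(cov_factor.t()) equal to half the identity in expectation; with cov_diag set to 1/2, the initial covariance is then the identity in expectation. The helpers are hypothetical and use plain Python lists.

```python
import random
from typing import List, Tuple

Matrix = List[List[float]]


def init_low_rank(latent_dim: int, rank: int, rng: random.Random) -> Tuple[Matrix, List[float]]:
    """Assumed initializer: cov_diag = 1/2, cov_factor entries ~ Normal(0, 1 / (2 * rank))."""
    std = (0.5 / rank) ** 0.5
    cov_factor = [[rng.gauss(0.0, std) for _ in range(rank)] for _ in range(latent_dim)]
    cov_diag = [0.5] * latent_dim
    return cov_factor, cov_diag


def low_rank_covariance(cov_factor: Matrix, cov_diag: List[float]) -> Matrix:
    """Covariance of the guide: cov_factor @ cov_factor^T + diag(cov_diag)."""
    n = len(cov_diag)
    return [
        [
            sum(a * b for a, b in zip(cov_factor[i], cov_factor[j]))
            + (cov_diag[i] if i == j else 0.0)
            for j in range(n)
        ]
        for i in range(n)
    ]


# Average cov_factor @ cov_factor^T over many random initializations.
rng = random.Random(0)
latent_dim, rank, trials = 3, 2, 5000
avg = [[0.0] * latent_dim for _ in range(latent_dim)]
for _ in range(trials):
    factor, _diag = init_low_rank(latent_dim, rank, rng)
    ff_t = low_rank_covariance(factor, [0.0] * latent_dim)  # zero diagonal isolates F F^T
    for i in range(latent_dim):
        for j in range(latent_dim):
            avg[i][j] += ff_t[i][j] / trials
print(avg)  # roughly 0.5 on the diagonal and roughly 0 elsewhere
```

The low-rank form needs only latent_dim * (rank + 1) covariance parameters instead of the latent_dim * (latent_dim + 1) / 2 of a full Cholesky factor.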

AutoIAFNormal

class AutoIAFNormal(model, hidden_dim=None, prefix='auto')

Bases: pyro.contrib.autoguide.AutoContinuous

This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an InverseAutoregressiveFlow to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.

Usage:

guide = AutoIAFNormal(model, hidden_dim=latent_dim)
svi = SVI(model, guide, ...)
Parameters:
  • model (callable) – a generative model
  • hidden_dim (int) – number of hidden dimensions in the IAF
  • prefix (str) – a prefix prepended to the names of all internal param sites
get_posterior(*args, **kwargs)

Returns a diagonal Normal posterior distribution transformed by InverseAutoregressiveFlow.
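An inverse autoregressive flow applies an affine map whose shift and scale for coordinate i depend only on the earlier coordinates, \(y_i = \mu_i(x_{<i}) + \sigma_i(x_{<i})\, x_i\). The Jacobian is then lower triangular, so \(\log|\det \partial y/\partial x| = \sum_i \log \sigma_i\); sampling is a single pass, while inversion is sequential. In the sketch below, toy_autoregressive is a hypothetical stand-in for the learned autoregressive network a real IAF uses.

```python
import math
from typing import List, Tuple


def toy_autoregressive(prefix: List[float]) -> Tuple[float, float]:
    """Hypothetical network: (mu_i, log_sigma_i) computed from x[:i] only."""
    s = sum(prefix)
    return 0.5 * s, 0.1 * math.tanh(s)


def iaf_forward(x: List[float], net=toy_autoregressive) -> Tuple[List[float], float]:
    """y_i = mu_i(x_<i) + sigma_i(x_<i) * x_i; also returns log|det dy/dx|."""
    y, log_det = [], 0.0
    for i, xi in enumerate(x):
        mu, log_sigma = net(x[:i])
        y.append(mu + math.exp(log_sigma) * xi)
        log_det += log_sigma  # triangular Jacobian: log-det is the sum of log sigma_i
    return y, log_det


def iaf_inverse(y: List[float], net=toy_autoregressive) -> List[float]:
    """Inversion is sequential: recovering x_i requires the already-recovered x_<i."""
    x: List[float] = []
    for yi in y:
        mu, log_sigma = net(x)
        x.append((yi - mu) / math.exp(log_sigma))
    return x


x = [0.3, -1.2, 0.7]  # e.g. a draw from the base diagonal Normal
y, log_det = iaf_forward(x)
print(y, log_det)
```

The cheap forward pass with a known log-determinant is what makes this direction convenient for a guide, which samples and scores its own draws.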

AutoLaplaceApproximation

class AutoLaplaceApproximation(model, prefix='auto')

Bases: pyro.contrib.autoguide.AutoContinuous

Laplace approximation (quadratic approximation) approximates the posterior \(\log p(z \mid x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of \(z\).

Usage:

delta_guide = AutoLaplaceApproximation(model)
svi = SVI(model, delta_guide, ...)
# ...then train the delta_guide...
guide = delta_guide.laplace_approximation()

By default the mean vector is initialized to zero. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:

latent_dim = 10
pyro.param("auto_loc", torch.randn(latent_dim))
get_posterior(*args, **kwargs)

Returns a Delta posterior distribution for MAP inference.

laplace_approximation(*args, **kwargs)

Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by the Laplace approximation.
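The Laplace recipe (find the MAP point, then invert the Hessian of \(-\log p(x, z)\) there) can be sketched in one dimension. The model here is a hypothetical conjugate one (\(\mu \sim \mathcal{N}(0, 1)\), \(x_i \sim \mathcal{N}(\mu, 1)\)), the MAP point is taken in closed form, and the second derivative is estimated by central finite differences; a real implementation would typically use automatic differentiation instead.

```python
# Hypothetical toy model: mu ~ Normal(0, 1); x_i ~ Normal(mu, 1).
data = [1.2, 0.8, 1.5, 0.9]


def neg_log_joint(mu: float) -> float:
    # -log p(x, mu), dropping additive constants
    return 0.5 * mu ** 2 + 0.5 * sum((x - mu) ** 2 for x in data)


def second_derivative(f, z: float, h: float = 1e-4) -> float:
    # Central finite-difference estimate of f''(z)
    return (f(z + h) - 2.0 * f(z) + f(z - h)) / h ** 2


z_map = sum(data) / (len(data) + 1)  # MAP point (closed form for this model)
hessian = second_derivative(neg_log_joint, z_map)
laplace_var = 1.0 / hessian          # covariance = inverse Hessian at the MAP point
print(z_map, laplace_var)
```

Because this log joint is quadratic in \(\mu\), the Laplace approximation coincides with the exact posterior \(\mathcal{N}(0.88, 0.2)\) here; for non-Gaussian posteriors it is only a local approximation around the MAP point.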

AutoDiscreteParallel

class AutoDiscreteParallel(model, prefix='auto')

Bases: pyro.contrib.autoguide.AutoGuide

A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model.

__call__(*args, **kwargs)

An automatic guide with the same *args, **kwargs as the base model.

Returns: A dict mapping sample site name to sampled value.
Return type: dict
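A mean-field discrete guide gives each discrete site its own independent categorical distribution, so its joint log-probability is a sum of per-site terms. The sketch below assumes one vector of logits per site; the site names and logit values are hypothetical stand-ins for parameters that would be learned during inference.

```python
import math
import random
from typing import Dict, List

# Hypothetical learned parameters: one vector of logits per discrete site.
logits: Dict[str, List[float]] = {
    "assignment_0": [0.0, 1.0, -1.0],
    "assignment_1": [2.0, 0.0, 0.0],
}


def softmax(v: List[float]) -> List[float]:
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


def sample(rng: random.Random) -> Dict[str, int]:
    # Mean field: each site is drawn independently from its own categorical.
    return {site: rng.choices(range(len(v)), weights=softmax(v))[0] for site, v in logits.items()}


def log_prob(values: Dict[str, int]) -> float:
    # The joint log-probability factorizes into a sum over sites.
    return sum(math.log(softmax(logits[s])[k]) for s, k in values.items())


draw = sample(random.Random(0))
print(draw, log_prob(draw))
```

The factorization ignores any dependence between sites, which is what makes the guide cheap to sample and score.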