Automatic Guide Generation
The pyro.contrib.autoguide module provides algorithms to automatically generate guides from simple models, for use in SVI.
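For example, to generate a mean field Gaussian guide:
    from pyro.contrib.autoguide import AutoDiagonalNormal
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    def model():
        ...

    guide = AutoDiagonalNormal(model)  # a mean field guide
    svi = SVI(model, guide, Adam({'lr': 1e-3}), Trace_ELBO())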
Automatic guides can also be combined using pyro.poutine.block() and AutoGuideList.
AutoGuide
class AutoGuide(model, prefix='auto')
Bases: object
Base class for automatic guides.
Derived classes must implement the __call__() method.
Auto guides can be used individually or combined in an AutoGuideList object.
Parameters:
- model (callable) – a Pyro model
- prefix (str) – a prefix that will be prepended to the names of all internal param sites
__call__(*args, **kwargs)
A guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
AutoGuideList
class AutoGuideList(model, prefix='auto')
Bases: pyro.contrib.autoguide.AutoGuide
Container class to combine multiple automatic guides.
Example usage:
    guide = AutoGuideList(model)
    guide.add(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))
    guide.add(AutoDiscreteParallel(poutine.block(model, expose=["assignment"])))
    svi = SVI(model, guide, optim, Trace_ELBO())
Parameters:
- model (callable) – a Pyro model
- prefix (str) – a prefix that will be prepended to the names of all internal param sites
__call__(*args, **kwargs)
A composite guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
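The container pattern can be illustrated without Pyro. The following is a minimal pure-Python sketch, not Pyro's implementation: `ToyGuideList`, `continuous_part`, and `discrete_part` are hypothetical names, and the sketch only shows how a composite guide could call each part with the same arguments and merge the resulting site-to-value dicts.

```python
import random


class ToyGuideList:
    """Hypothetical stand-in for a guide container (not Pyro's AutoGuideList)."""

    def __init__(self):
        self.parts = []

    def add(self, part):
        # Each part is a callable returning {site_name: sampled_value}.
        self.parts.append(part)

    def __call__(self, *args, **kwargs):
        result = {}
        for part in self.parts:
            samples = part(*args, **kwargs)
            overlap = result.keys() & samples.keys()
            if overlap:
                raise ValueError(f"sites covered twice: {sorted(overlap)}")
            result.update(samples)
        return result


def continuous_part(data):
    # Toy guide over a single continuous site.
    return {"loc": random.gauss(0.0, 1.0)}


def discrete_part(data):
    # Toy guide over a discrete site: one cluster index per data point.
    return {"assignment": [random.randrange(2) for _ in data]}


guide = ToyGuideList()
guide.add(continuous_part)
guide.add(discrete_part)
samples = guide([0.1, 0.2, 0.3])
print(sorted(samples))
```

Raising on overlapping site names is a design choice of this sketch; it mirrors the idea that each part should cover a disjoint subset of the model's sites, which is what the hide/expose arguments in the usage example arrange.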
AutoCallable
class AutoCallable(model, guide, median=<function AutoCallable.<lambda>>)
Bases: pyro.contrib.autoguide.AutoGuide
AutoGuide wrapper for simple callable guides.
This is used internally for composing autoguides with custom user-defined guides that are simple callables, e.g.:
    def my_local_guide(*args, **kwargs):
        ...

    guide = AutoGuideList(model)
    guide.add(AutoDelta(poutine.block(model, expose=['my_global_param'])))
    guide.add(my_local_guide)  # automatically wrapped in an AutoCallable
To specify a median callable, you can instead:
    def my_local_median(*args, **kwargs):
        ...

    guide.add(AutoCallable(model, my_local_guide, my_local_median))
For more complex guides that need e.g. access to plates, users should instead subclass AutoGuide.
Parameters:
- model (callable) – a Pyro model
- guide (callable) – a Pyro guide (typically over only part of the model)
- median (callable) – an optional callable returning a dict mapping sample site name to computed median tensor.
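As a rough illustration of the wrapping idea (a pure-Python sketch, not Pyro's source), the hypothetical `ToyCallable` below stores a guide callable together with a median callable. The empty-dict default for `median` is an assumption made for this sketch.

```python
class ToyCallable:
    """Hypothetical wrapper pairing a guide callable with a median callable."""

    def __init__(self, guide, median=lambda *args, **kwargs: {}):
        self._guide = guide
        self.median = median  # assumed default: report no medians

    def __call__(self, *args, **kwargs):
        # Delegate to the wrapped guide; it returns {site_name: value}.
        return self._guide(*args, **kwargs)


def my_local_guide(data):
    # Toy guide: a deterministic value keeps the example reproducible.
    return {"local_scale": 2.0}


def my_local_median(data):
    return {"local_scale": 2.0}


plain = ToyCallable(my_local_guide)
with_median = ToyCallable(my_local_guide, my_local_median)
print(plain([1.0]), plain.median([1.0]))
print(with_median.median([1.0]))
```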
AutoDelta
class AutoDelta(model, prefix='auto')
Bases: pyro.contrib.autoguide.AutoGuide
This implementation of AutoGuide uses Delta distributions to construct a MAP guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Note: This class does MAP inference in constrained space.
Usage:
    guide = AutoDelta(model)
    svi = SVI(model, guide, ...)
By default latent variables are randomly initialized by the model. To change this default behavior the user should call pyro.param() before beginning inference, with "auto_" prefixed to the targeted sample site names, e.g. for sample sites named “level” and “concentration”, initialize via:
    pyro.param("auto_level", torch.tensor([-1., 0., 1.]))
    pyro.param("auto_concentration", torch.ones(k), constraint=constraints.positive)
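To make "MAP guide" concrete, here is a small pure-Python sketch that is independent of Pyro. It assumes a toy model chosen for this example (a standard normal prior on one latent `z` and unit-variance normal observations) and finds the point estimate that a Delta guide would concentrate on, using gradient ascent on the log joint.

```python
def grad_log_joint(z, data):
    # d/dz [ log N(z | 0, 1) + sum_i log N(x_i | z, 1) ]
    return -z + sum(x - z for x in data)


def map_estimate(data, z_init=0.0, lr=0.05, steps=500):
    # Plain gradient ascent on the log joint density.
    z = z_init
    for _ in range(steps):
        z += lr * grad_log_joint(z, data)
    return z


data = [1.2, 0.8, 1.0, 1.4]
z_map = map_estimate(data)
closed_form = sum(data) / (len(data) + 1)  # posterior mode for this toy model
print(z_map, closed_form)
```

In this toy model the latent is unconstrained, so constrained and unconstrained MAP estimates coincide; for constrained sites the two generally differ, which is why the note about constrained space matters.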
AutoContinuous
class AutoContinuous(model, prefix='auto')
Bases: pyro.contrib.autoguide.AutoGuide
Base class for implementations of continuous-valued Automatic Differentiation Variational Inference [1].
Each derived class implements its own get_posterior() method.
Assumes model structure and latent dimension are fixed, and all latent variables are continuous.
Parameters: model (callable) – a Pyro model
Reference:
[1] ‘Automatic Differentiation Variational Inference’, Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M. Blei
__call__(*args, **kwargs)
An automatic guide with the same *args, **kwargs as the base model.
Returns: A dict mapping sample site name to sampled value.
Return type: dict
median(*args, **kwargs)
Returns the posterior median value of each latent variable.
Returns: A dict mapping sample site name to median tensor.
Return type: dict
quantiles(quantiles, *args, **kwargs)
Returns posterior quantiles of each latent variable. Example:
    print(guide.quantiles([0.05, 0.5, 0.95]))
Parameters: quantiles (torch.Tensor or list) – A list of requested quantiles between 0 and 1.
Returns: A dict mapping sample site name to a list of quantile values.
Return type: dict
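One way to picture median() and quantiles() for a Gaussian guide is to compute Normal quantiles in unconstrained space and push them through the site's constraining transform; this works because a monotone increasing transform preserves quantiles. The sketch below is pure Python with made-up `loc` and `scale` values and an assumed `exp` transform for a positive-valued site. It is an illustration of the idea, not Pyro's code.

```python
import math
from statistics import NormalDist


def constrained_quantiles(loc, scale, probs, transform=math.exp):
    # Quantiles of the unconstrained Normal, mapped into constrained space.
    base = NormalDist(mu=loc, sigma=scale)
    return [transform(base.inv_cdf(p)) for p in probs]


# Hypothetical fitted values for a positive-valued site.
q05, q50, q95 = constrained_quantiles(loc=0.0, scale=0.5, probs=[0.05, 0.5, 0.95])
print(q05, q50, q95)
```

Note that the same shortcut does not apply to the mean: the transform of the unconstrained mean is generally not the constrained mean.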
AutoMultivariateNormal
class AutoMultivariateNormal(model, prefix='auto')
Bases: pyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a Cholesky factorization of a Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoMultivariateNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the Cholesky factor is initialized to the identity. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
    pyro.param("auto_scale_tril", torch.tril(torch.rand(latent_dim, latent_dim)), constraint=constraints.lower_cholesky)
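The role of the Cholesky factor can be sketched in a few lines of pure Python: a sample is `loc + L @ eps` with `eps` standard normal, so the implied covariance is `L @ L.T`. The 2-D values below are made up for illustration, and the helper names are hypothetical.

```python
import random


def matvec(L, v):
    return [sum(L[i][j] * v[j] for j in range(len(v))) for i in range(len(L))]


def sample_mvn(loc, scale_tril, rng):
    # Reparameterized draw: z = loc + L @ eps, eps ~ N(0, I).
    eps = [rng.gauss(0.0, 1.0) for _ in loc]
    return [m + d for m, d in zip(loc, matvec(scale_tril, eps))]


def covariance(scale_tril):
    # Implied covariance L @ L^T.
    n = len(scale_tril)
    return [[sum(scale_tril[i][k] * scale_tril[j][k] for k in range(n))
             for j in range(n)] for i in range(n)]


loc = [0.0, 1.0]
scale_tril = [[1.0, 0.0],
              [0.5, 2.0]]  # lower triangular with a positive diagonal
z = sample_mvn(loc, scale_tril, random.Random(0))
cov = covariance(scale_tril)
print(z, cov)
```

The off-diagonal entry of `scale_tril` is what lets this guide capture correlation between latent coordinates.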
AutoDiagonalNormal
class AutoDiagonalNormal(model, prefix='auto')
Bases: pyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a Normal distribution with a diagonal covariance matrix to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, ...)
By default the mean vector is initialized to zero and the scale vector is initialized to ones (an identity covariance). To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
    pyro.param("auto_scale", torch.ones(latent_dim), constraint=constraints.positive)
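A diagonal covariance means the guide treats the latent coordinates as independent, so its joint log density is a sum of one-dimensional Normal terms. The pure-Python sketch below illustrates this factorization with hypothetical helper names; it is not Pyro's implementation.

```python
import math


def normal_log_prob(x, loc, scale):
    # Log density of a univariate Normal(loc, scale) at x.
    return (-0.5 * math.log(2 * math.pi) - math.log(scale)
            - 0.5 * ((x - loc) / scale) ** 2)


def diag_normal_log_prob(z, loc, scale):
    # Independent coordinates: the joint log density is a sum of 1-D terms.
    return sum(normal_log_prob(zi, m, s) for zi, m, s in zip(z, loc, scale))


latent_dim = 3
loc = [0.0] * latent_dim    # zero mean, as in the default initialization
scale = [1.0] * latent_dim  # unit scale
log_q = diag_normal_log_prob([0.0, 0.0, 0.0], loc, scale)
print(log_q)
```

The trade-off relative to a full covariance is that the parameter count grows linearly with the latent dimension, at the cost of ignoring posterior correlations.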
AutoLowRankMultivariateNormal
class AutoLowRankMultivariateNormal(model, prefix='auto', rank=1)
Bases: pyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a low rank plus diagonal Multivariate Normal distribution to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoLowRankMultivariateNormal(model, rank=10)
    svi = SVI(model, guide, ...)
By default the cov_diag is initialized to 1/2 and the cov_factor is initialized randomly such that cov_factor.matmul(cov_factor.t()) is, on average, half the identity matrix. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    rank = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
    pyro.param("auto_cov_factor", torch.randn(latent_dim, rank))
    pyro.param("auto_cov_diag", torch.randn(latent_dim).exp(), constraint=constraints.positive)
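Parameters:
- model (callable) – a Pyro model
- prefix (str) – a prefix that will be prepended to the names of all internal param sites
- rank (int) – the rank of the low-rank part of the covariance, i.e. the number of columns of cov_factor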
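The "low rank plus diagonal" structure can be sketched in pure Python: the covariance is `W @ W.T + diag(D)`, where `W` plays the role of cov_factor with shape (latent_dim, rank) and `D` plays the role of cov_diag. The helper name and the sizes below are made up for illustration.

```python
import random


def low_rank_covariance(cov_factor, cov_diag):
    # Covariance = W @ W^T + diag(D), with W of shape (latent_dim, rank).
    n = len(cov_factor)
    rank = len(cov_factor[0])
    cov = [[sum(cov_factor[i][k] * cov_factor[j][k] for k in range(rank))
            for j in range(n)] for i in range(n)]
    for i in range(n):
        cov[i][i] += cov_diag[i]
    return cov


latent_dim, rank = 4, 2
rng = random.Random(0)
cov_factor = [[rng.gauss(0.0, 1.0) for _ in range(rank)] for _ in range(latent_dim)]
cov_diag = [0.5] * latent_dim
cov = low_rank_covariance(cov_factor, cov_diag)

# Parameter counts: low-rank-plus-diagonal versus a full Cholesky factor.
n_params_low_rank = latent_dim * rank + latent_dim
n_params_full = latent_dim * (latent_dim + 1) // 2
print(n_params_low_rank, n_params_full)
```

For these tiny sizes the low-rank form actually uses more parameters; the savings appear when latent_dim is much larger than rank, since the count grows linearly rather than quadratically in latent_dim.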
AutoIAFNormal
class AutoIAFNormal(model, hidden_dim=None, prefix='auto')
Bases: pyro.contrib.autoguide.AutoContinuous
This implementation of AutoContinuous uses a Diagonal Normal distribution transformed via an InverseAutoregressiveFlow to construct a guide over the entire latent space. The guide does not depend on the model’s *args, **kwargs.
Usage:
    guide = AutoIAFNormal(model, hidden_dim=latent_dim)
    svi = SVI(model, guide, ...)
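Parameters:
- model (callable) – a Pyro model
- hidden_dim (int) – the hidden dimension of the InverseAutoregressiveFlow (defaults to None)
- prefix (str) – a prefix that will be prepended to the names of all internal param sites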
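For intuition only, the sketch below applies one affine autoregressive step to a base sample: each output coordinate is shifted and scaled using only earlier input coordinates, so the Jacobian is triangular and its log-determinant is the sum of the log scales. The `shift_and_log_scale` function is a hypothetical stand-in for a learned autoregressive network, and Pyro's InverseAutoregressiveFlow may parameterize the step differently.

```python
import math


def shift_and_log_scale(prefix):
    # Hypothetical stand-in for a learned autoregressive network:
    # it depends only on the coordinates before the current one.
    s = sum(prefix)
    return 0.5 * s, 0.1 * math.tanh(s)


def autoregressive_step(x):
    y, log_det = [], 0.0
    for i, xi in enumerate(x):
        shift, log_scale = shift_and_log_scale(x[:i])
        y.append(shift + math.exp(log_scale) * xi)
        log_det += log_scale  # triangular Jacobian: sum of log diagonal entries
    return y, log_det


x = [0.3, -1.2, 0.7]  # pretend draw from the diagonal Normal base distribution
y, log_det = autoregressive_step(x)
print(y, log_det)
```

Because every shift and scale is computed from the input `x`, all coordinates of `y` can be computed in parallel in a real implementation; the loop here is only for readability.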
AutoLaplaceApproximation
class AutoLaplaceApproximation(model, prefix='auto')
Bases: pyro.contrib.autoguide.AutoContinuous
Laplace approximation (quadratic approximation) approximates the posterior \(\log p(z | x)\) by a multivariate normal distribution in the unconstrained space. Under the hood, it uses Delta distributions to construct a MAP guide over the entire (unconstrained) latent space. Its covariance is given by the inverse of the Hessian of \(-\log p(x, z)\) at the MAP point of z.
Usage:
    delta_guide = AutoLaplaceApproximation(model)
    svi = SVI(model, delta_guide, ...)
    # ...then train the delta_guide...
    guide = delta_guide.laplace_approximation()
By default the mean vector is initialized to zero. To change this default behavior the user should call pyro.param() before beginning inference, e.g.:
    latent_dim = 10
    pyro.param("auto_loc", torch.randn(latent_dim))
laplace_approximation(*args, **kwargs)
Returns an AutoMultivariateNormal instance whose posterior’s loc and scale_tril are given by Laplace approximation.
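The two stages of a Laplace approximation (find the MAP point, then invert the Hessian of the negative log joint there) can be sketched in one dimension with pure Python. The toy model (standard normal prior, unit-variance normal likelihood), the finite-difference derivatives, and the function names are all assumptions of this sketch, not Pyro's implementation.

```python
def neg_log_joint(z, data):
    # -log p(x, z) up to a constant, for z ~ N(0, 1) and x_i ~ N(z, 1).
    return 0.5 * z ** 2 + 0.5 * sum((x - z) ** 2 for x in data)


def derivative(f, z, h=1e-3):
    # Central finite difference for the first derivative.
    return (f(z + h) - f(z - h)) / (2 * h)


def second_derivative(f, z, h=1e-3):
    # Central finite difference for the second derivative.
    return (f(z + h) - 2 * f(z) + f(z - h)) / h ** 2


def laplace_approximation(data, lr=0.05, steps=500):
    def f(z):
        return neg_log_joint(z, data)

    z = 0.0
    for _ in range(steps):  # gradient descent to the MAP point
        z -= lr * derivative(f, z)
    variance = 1.0 / second_derivative(f, z)  # inverse Hessian at the MAP
    return z, variance


data = [1.2, 0.8, 1.0, 1.4]
loc, variance = laplace_approximation(data)
print(loc, variance)
```

For this conjugate toy model the posterior is exactly Normal with mean `sum(data) / (n + 1)` and variance `1 / (n + 1)`, so the approximation is exact; for non-Gaussian posteriors it is only a local quadratic fit around the mode.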
AutoDiscreteParallel
class AutoDiscreteParallel(model, prefix='auto')
Bases: pyro.contrib.autoguide.AutoGuide
A discrete mean-field guide that learns a latent discrete distribution for each discrete site in the model.