Distributions

PyTorch Distributions

Most distributions in Pyro are thin wrappers around PyTorch distributions. For details on the PyTorch distribution interface, see torch.distributions.distribution.Distribution. For differences between the Pyro and PyTorch interfaces, see TorchDistributionMixin.

Bernoulli

class Bernoulli(probs=None, logits=None, validate_args=None)

Wraps torch.distributions.bernoulli.Bernoulli with TorchDistributionMixin.

Beta

class Beta(concentration1, concentration0, validate_args=None)

Wraps torch.distributions.beta.Beta with TorchDistributionMixin.

Binomial

class Binomial(total_count=1, probs=None, logits=None, validate_args=None)

Wraps torch.distributions.binomial.Binomial with TorchDistributionMixin.

Categorical

class Categorical(probs=None, logits=None, validate_args=None)

Wraps torch.distributions.categorical.Categorical with TorchDistributionMixin.

Cauchy

class Cauchy(loc, scale, validate_args=None)

Wraps torch.distributions.cauchy.Cauchy with TorchDistributionMixin.

Chi2

class Chi2(df, validate_args=None)

Wraps torch.distributions.chi2.Chi2 with TorchDistributionMixin.

Dirichlet

class Dirichlet(concentration, validate_args=None)

Wraps torch.distributions.dirichlet.Dirichlet with TorchDistributionMixin.

Exponential

class Exponential(rate, validate_args=None)

Wraps torch.distributions.exponential.Exponential with TorchDistributionMixin.

ExponentialFamily

class ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)

Wraps torch.distributions.exp_family.ExponentialFamily with TorchDistributionMixin.

FisherSnedecor

class FisherSnedecor(df1, df2, validate_args=None)

Wraps torch.distributions.fishersnedecor.FisherSnedecor with TorchDistributionMixin.

Gamma

class Gamma(concentration, rate, validate_args=None)

Wraps torch.distributions.gamma.Gamma with TorchDistributionMixin.

Geometric

class Geometric(probs=None, logits=None, validate_args=None)

Wraps torch.distributions.geometric.Geometric with TorchDistributionMixin.

Gumbel

class Gumbel(loc, scale, validate_args=None)

Wraps torch.distributions.gumbel.Gumbel with TorchDistributionMixin.

HalfCauchy

class HalfCauchy(scale, validate_args=None)

Wraps torch.distributions.half_cauchy.HalfCauchy with TorchDistributionMixin.

HalfNormal

class HalfNormal(scale, validate_args=None)

Wraps torch.distributions.half_normal.HalfNormal with TorchDistributionMixin.

Independent

class Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)[source]

Wraps torch.distributions.independent.Independent with TorchDistributionMixin.

Laplace

class Laplace(loc, scale, validate_args=None)

Wraps torch.distributions.laplace.Laplace with TorchDistributionMixin.

LogNormal

class LogNormal(loc, scale, validate_args=None)

Wraps torch.distributions.log_normal.LogNormal with TorchDistributionMixin.

LogisticNormal

class LogisticNormal(loc, scale, validate_args=None)

Wraps torch.distributions.logistic_normal.LogisticNormal with TorchDistributionMixin.

LowRankMultivariateNormal

class LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)

Wraps torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal with TorchDistributionMixin.

Multinomial

class Multinomial(total_count=1, probs=None, logits=None, validate_args=None)

Wraps torch.distributions.multinomial.Multinomial with TorchDistributionMixin.

MultivariateNormal

class MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]

Wraps torch.distributions.multivariate_normal.MultivariateNormal with TorchDistributionMixin.

NegativeBinomial

class NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)

Wraps torch.distributions.negative_binomial.NegativeBinomial with TorchDistributionMixin.

Normal

class Normal(loc, scale, validate_args=None)

Wraps torch.distributions.normal.Normal with TorchDistributionMixin.

OneHotCategorical

class OneHotCategorical(probs=None, logits=None, validate_args=None)

Wraps torch.distributions.one_hot_categorical.OneHotCategorical with TorchDistributionMixin.

Pareto

class Pareto(scale, alpha, validate_args=None)

Wraps torch.distributions.pareto.Pareto with TorchDistributionMixin.

Poisson

class Poisson(rate, validate_args=None)

Wraps torch.distributions.poisson.Poisson with TorchDistributionMixin.

RelaxedBernoulli

class RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)

Wraps torch.distributions.relaxed_bernoulli.RelaxedBernoulli with TorchDistributionMixin.

RelaxedOneHotCategorical

class RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)

Wraps torch.distributions.relaxed_categorical.RelaxedOneHotCategorical with TorchDistributionMixin.

StudentT

class StudentT(df, loc=0.0, scale=1.0, validate_args=None)

Wraps torch.distributions.studentT.StudentT with TorchDistributionMixin.

TransformedDistribution

class TransformedDistribution(base_distribution, transforms, validate_args=None)

Wraps torch.distributions.transformed_distribution.TransformedDistribution with TorchDistributionMixin.

Uniform

class Uniform(low, high, validate_args=None)

Wraps torch.distributions.uniform.Uniform with TorchDistributionMixin.

Weibull

class Weibull(scale, concentration, validate_args=None)

Wraps torch.distributions.weibull.Weibull with TorchDistributionMixin.

Pyro Distributions

Abstract Distribution

class Distribution[source]

Bases: object

Base class for parameterized probability distributions.

Distributions in Pyro are stochastic function objects with sample() and log_prob() methods; a distribution instance is a stochastic function with fixed parameters:

d = dist.Bernoulli(param)
x = d()                                # Draws a random sample.
p = d.log_prob(x)                      # Evaluates log probability of x.

Implementing New Distributions:

Derived classes must implement the methods: sample(), log_prob().

Examples:

Take a look at the examples to see how they interact with inference algorithms.

__call__(*args, **kwargs)[source]

Samples a random value (just an alias for .sample(*args, **kwargs)).

For tensor distributions, the returned tensor should have the same .shape as the parameters.

Returns:A random value.
Return type:torch.Tensor
enumerate_support(expand=True)[source]

Returns a representation of the parametrized distribution’s support, along the first dimension. This is implemented only by discrete distributions.

Note that this returns support values of all the batched RVs in lock-step, rather than the full cartesian product.

Parameters:expand (bool) – whether to expand the result to a tensor of shape (n,) + batch_shape + event_shape. If false, the return value has unexpanded shape (n,) + (1,)*len(batch_shape) + event_shape which can be broadcasted to the full shape.
Returns:An iterator over the distribution’s discrete support.
Return type:iterator
has_enumerate_support = False
has_rsample = False
log_prob(x, *args, **kwargs)[source]

Evaluates log probability densities for each of a batch of samples.

Parameters:x (torch.Tensor) – A single value or a batch of values batched along axis 0.
Returns:log probability densities as a one-dimensional Tensor with same batch size as value and params. The shape of the result should be self.batch_size.
Return type:torch.Tensor
sample(*args, **kwargs)[source]

Samples a random value.

For tensor distributions, the returned tensor should have the same .shape as the parameters, unless otherwise noted.

Parameters:sample_shape (torch.Size) – the size of the iid batch to be drawn from the distribution.
Returns:A random value or batch of random values (if parameters are batched). The shape of the result should be self.shape().
Return type:torch.Tensor
score_parts(x, *args, **kwargs)[source]

Computes ingredients for stochastic gradient estimators of ELBO.

The default implementation is correct both for non-reparameterized and for fully reparameterized distributions. Partially reparameterized distributions should override this method to compute correct .score_function and .entropy_term parts.

Parameters:x (torch.Tensor) – A single value or batch of values.
Returns:A ScoreParts object containing parts of the ELBO estimator.
Return type:ScoreParts

TorchDistributionMixin

class TorchDistributionMixin[source]

Bases: pyro.distributions.distribution.Distribution

Mixin to provide Pyro compatibility for PyTorch distributions.

You should instead use TorchDistribution for new distribution classes.

This is mainly useful for wrapping existing PyTorch distributions for use in Pyro. Derived classes must first inherit from torch.distributions.distribution.Distribution and then inherit from TorchDistributionMixin.

__call__(sample_shape=torch.Size([]))[source]

Samples a random value.

This is reparameterized whenever possible, calling rsample() for reparameterized distributions and sample() for non-reparameterized distributions.

Parameters:sample_shape (torch.Size) – the size of the iid batch to be drawn from the distribution.
Returns:A random value or batch of random values (if parameters are batched). The shape of the result should be self.shape().
Return type:torch.Tensor
event_dim
Returns:Number of dimensions of individual events.
Return type:int
shape(sample_shape=torch.Size([]))[source]

The tensor shape of samples from this distribution.

Samples are of shape:

d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape
Parameters:sample_shape (torch.Size) – the size of the iid batch to be drawn from the distribution.
Returns:Tensor shape of samples.
Return type:torch.Size
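For example (a brief sketch):

import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(3, 4), torch.ones(3, 4))   # batch_shape == (3, 4), event_shape == ()
sample_shape = torch.Size([2])
x = d.sample(sample_shape)
assert d.shape(sample_shape) == torch.Size([2, 3, 4])
assert x.shape == d.shape(sample_shape)
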
expand_by(sample_shape)[source]

Expands a distribution by adding sample_shape to the left side of its batch_shape.

To expand internal dims of self.batch_shape from 1 to something larger, use expand() instead.

Parameters:sample_shape (torch.Size) – The size of the iid batch to be drawn from the distribution.
Returns:An expanded version of this distribution.
Return type:ReshapedDistribution
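For example (a brief sketch, assuming torch and pyro.distributions as dist are imported):

d = dist.Normal(torch.zeros(3), torch.ones(3))   # batch_shape == (3,)
d2 = d.expand_by(torch.Size([5]))                # batch_shape == (5, 3)
assert d2.batch_shape == torch.Size([5, 3])
assert d2.sample().shape == torch.Size([5, 3])
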
reshape(sample_shape=None, extra_event_dims=None)[source]
to_event(reinterpreted_batch_ndims=None)[source]

Reinterprets the n rightmost dimensions of this distribution's batch_shape as event dims, adding them to the left side of event_shape.

Example:

>>> [d1.batch_shape, d1.event_shape]
[torch.Size([2, 3]), torch.Size([4, 5])]
>>> d2 = d1.to_event(1)
>>> [d2.batch_shape, d2.event_shape]
[torch.Size([2]), torch.Size([3, 4, 5])]
>>> d3 = d1.to_event(2)
>>> [d3.batch_shape, d3.event_shape]
[torch.Size([]), torch.Size([2, 3, 4, 5])]
Parameters:reinterpreted_batch_ndims (int) – The number of batch dimensions to reinterpret as event dimensions.
Returns:A reshaped version of this distribution.
Return type:pyro.distributions.torch.Independent
independent(reinterpreted_batch_ndims=None)[source]
mask(mask)[source]

Masks a distribution by a zero-one tensor that is broadcastable to the distribution's batch_shape.

Parameters:mask (torch.Tensor) – A zero-one valued float tensor.
Returns:A masked copy of this distribution.
Return type:MaskedDistribution
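For example (an illustrative sketch, assuming torch and pyro.distributions as dist are imported), masked entries contribute zero to log_prob:

d = dist.Normal(torch.zeros(4), torch.ones(4))   # batch_shape == (4,)
mask = torch.tensor([1., 1., 0., 1.])            # zero-one valued, broadcastable to batch_shape
masked = d.mask(mask)
x = masked.sample()
log_p = masked.log_prob(x)                       # the masked entry is exactly 0.
assert log_p.shape == torch.Size([4])
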

TorchDistribution

class TorchDistribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution, pyro.distributions.torch_distribution.TorchDistributionMixin

Base class for PyTorch-compatible distributions with Pyro support.

This should be the base class for almost all new Pyro distributions.

Note

Parameters and data should be of type Tensor and all methods return type Tensor unless otherwise noted.

Tensor Shapes:

TorchDistributions provide a method .shape() for the tensor shape of samples:

x = d.sample(sample_shape)
assert x.shape == d.shape(sample_shape)

Pyro follows the same distribution shape semantics as PyTorch. It distinguishes between three different roles for tensor shapes of samples:

  • sample shape corresponds to the shape of the iid samples drawn from the distribution. This is taken as an argument by the distribution’s sample method.
  • batch shape corresponds to non-identical (independent) parameterizations of the distribution, inferred from the distribution’s parameter shapes. This is fixed for a distribution instance.
  • event shape corresponds to the event dimensions of the distribution, which is fixed for a distribution class. These are collapsed when we try to score a sample from the distribution via d.log_prob(x).

These shapes are related by the equation:

assert d.shape(sample_shape) == sample_shape + d.batch_shape + d.event_shape

Distributions provide a vectorized log_prob() method that evaluates the log probability density of each event in a batch independently, returning a tensor of shape sample_shape + d.batch_shape:

x = d.sample(sample_shape)
assert x.shape == d.shape(sample_shape)
log_p = d.log_prob(x)
assert log_p.shape == sample_shape + d.batch_shape

Implementing New Distributions:

Derived classes must implement the methods sample() (or rsample() if .has_rsample == True) and log_prob(), and must implement the properties batch_shape, and event_shape. Discrete classes may also implement the enumerate_support() method to improve gradient estimates and set .has_enumerate_support = True.
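As an illustration (a hedged sketch, not part of the Pyro library; the class name and details are hypothetical), a minimal reparameterized exponential distribution implemented via inverse-CDF sampling might look like this:

import torch
from torch.distributions import constraints
from pyro.distributions import TorchDistribution

class MyExponential(TorchDistribution):
    arg_constraints = {'rate': constraints.positive}
    support = constraints.positive
    has_rsample = True

    def __init__(self, rate, validate_args=None):
        self.rate = rate
        super().__init__(batch_shape=rate.shape, validate_args=validate_args)

    def rsample(self, sample_shape=torch.Size()):
        shape = self.shape(sample_shape)
        u = torch.rand(shape)
        return -torch.log1p(-u) / self.rate      # inverse CDF of Exponential(rate)

    def log_prob(self, value):
        return self.rate.log() - self.rate * value

d = MyExponential(torch.ones(3))
x = d.rsample(torch.Size([2]))                   # shape (2, 3)
assert d.log_prob(x).shape == torch.Size([2, 3])
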

AVFMultivariateNormal

class AVFMultivariateNormal(loc, scale_tril, control_var)[source]

Bases: pyro.distributions.torch.MultivariateNormal

Multivariate normal (Gaussian) distribution with transport equation inspired control variates (adaptive velocity fields).

A distribution over vectors in which all the elements have a joint Gaussian density.

Parameters:
  • loc (torch.Tensor) – D-dimensional mean vector.
  • scale_tril (torch.Tensor) – Cholesky of Covariance matrix; D x D matrix.
  • control_var (torch.Tensor) – 2 x L x D tensor that parameterizes the control variate; L is an arbitrary positive integer. This parameter needs to be learned (i.e. adapted) to achieve lower variance gradients. In a typical use case this parameter will be adapted concurrently with the loc and scale_tril that define the distribution.

Example usage:

# D, loc and scale_tril are assumed to be defined elsewhere.
control_var = torch.full((2, 1, D), 0.1, requires_grad=True)
opt_cv = torch.optim.Adam([control_var], lr=0.1, betas=(0.5, 0.999))

for _ in range(1000):
    d = AVFMultivariateNormal(loc, scale_tril, control_var)
    z = d.rsample()
    cost = torch.pow(z, 2.0).sum()
    cost.backward()
    opt_cv.step()
    opt_cv.zero_grad()
arg_constraints = {'control_var': Real(), 'loc': Real(), 'scale_tril': LowerTriangular()}
rsample(sample_shape=torch.Size([]))[source]

Delta

class Delta(v, log_density=0.0, event_dim=0, validate_args=None)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

Degenerate discrete distribution (a single point).

Discrete distribution that assigns probability one to the single element in its support. A Delta distribution parameterized by a random choice should not be used with MCMC-based inference, as doing so produces incorrect results.

Parameters:
  • v (torch.Tensor) – The single support element.
  • log_density (torch.Tensor) – An optional density for this Delta. This is useful to keep the class of Delta distributions closed under differentiable transformation.
  • event_dim (int) – Optional event dimension, defaults to zero.
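Example usage (a brief sketch, assuming torch and pyro.distributions as dist are imported):

v = torch.tensor([1.5, -0.3])
d = dist.Delta(v)                     # batch_shape == (2,), event_shape == ()
x = d.sample()
assert (x == v).all()
assert (d.log_prob(v) == 0.).all()    # log(1) at the support point
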
arg_constraints = {'log_density': Real(), 'v': Real()}
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(x)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
support = Real()
variance

Empirical

class Empirical(samples, log_weights, validate_args=None)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

Empirical distribution associated with the sampled data.

Parameters:
  • samples (torch.Tensor) – samples from the empirical distribution.
  • log_weights (torch.Tensor) – log weights (optional) corresponding to the samples. The leftmost shape of log_weights must match that of samples.
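Example usage (a hedged sketch, assuming torch and pyro.distributions as dist are imported):

samples = torch.randn(100, 3)         # 100 draws of a 3-dimensional value
log_weights = torch.zeros(100)        # uniform weights
emp = dist.Empirical(samples, log_weights)
x = emp.sample()                      # one of the 100 rows, shape (3,)
assert emp.sample_size == 100
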
arg_constraints = {}
enumerate_support(expand=True)[source]

See pyro.distributions.torch_distribution.TorchDistribution.enumerate_support()

event_shape

See pyro.distributions.torch_distribution.TorchDistribution.event_shape()

has_enumerate_support = True
log_prob(value)[source]

Returns the log of the probability mass function evaluated at value. Note that this currently only supports scoring values with empty sample_shape.

Parameters:value (torch.Tensor) – scalar or tensor value to be scored.
log_weights
mean

See pyro.distributions.torch_distribution.TorchDistribution.mean()

sample(sample_shape=torch.Size([]))[source]

See pyro.distributions.torch_distribution.TorchDistribution.sample()

sample_size

Number of samples that constitute the empirical distribution.

Returns:number of samples collected.
Return type:int
support = Real()
variance

See pyro.distributions.torch_distribution.TorchDistribution.variance()

GaussianScaleMixture

class GaussianScaleMixture(coord_scale, component_logits, component_scale)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

Mixture of Normal distributions with zero mean and diagonal covariance matrices.

That is, this distribution is a mixture with K components, where each component distribution is a D-dimensional Normal distribution with zero mean and a D-dimensional diagonal covariance matrix. The K different covariance matrices are controlled by the parameters coord_scale and component_scale. That is, the covariance matrix of the k’th component is given by

Sigma_ii = (component_scale_k * coord_scale_i) ** 2 (i = 1, …, D)

where component_scale_k is a positive scale factor and coord_scale_i are positive scale parameters shared between all K components. The mixture weights are controlled by a K-dimensional vector of softmax logits, component_logits. This distribution implements pathwise derivatives for samples from the distribution. This distribution does not currently support batched parameters.

See reference [1] for details on the implementations of the pathwise derivative. Please consider citing this reference if you use the pathwise derivative in your research.

[1] Pathwise Derivatives for Multivariate Distributions, Martin Jankowiak & Theofanis Karaletsos. arXiv:1806.01856

Note that this distribution supports both even and odd dimensions, but the former case should have somewhat higher precision, since it does not use any erfs in the backward call. Also note that this distribution does not support D = 1.

Parameters:
  • coord_scale (torch.Tensor) – D-dimensional vector of scales
  • component_logits (torch.Tensor) – K-dimensional vector of logits
  • component_scale (torch.Tensor) – K-dimensional vector of scale multipliers
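Example usage (a hedged sketch with K = 2 components in D = 4 dimensions, assuming torch and pyro.distributions as dist are imported):

coord_scale = torch.ones(4)                   # D-dimensional scales
component_logits = torch.zeros(2)             # K-dimensional logits
component_scale = torch.tensor([0.5, 2.0])    # K-dimensional scale multipliers
gsm = dist.GaussianScaleMixture(coord_scale, component_logits, component_scale)
z = gsm.rsample()                             # shape (4,)
log_p = gsm.log_prob(z)
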
arg_constraints = {'component_logits': Real(), 'component_scale': GreaterThan(lower_bound=0.0), 'coord_scale': GreaterThan(lower_bound=0.0)}
has_rsample = True
log_prob(value)[source]
rsample(sample_shape=torch.Size([]))[source]

MaskedMixture

class MaskedMixture(mask, component0, component1, validate_args=None)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

A masked deterministic mixture of two distributions.

This is useful when the mask is sampled from another distribution, possibly correlated across the batch. Often the mask can be marginalized out via enumeration.

Example:

change_point = pyro.sample("change_point",
                           dist.Categorical(torch.ones(len(data) + 1)),
                           infer={'enumerate': 'parallel'})
mask = torch.arange(len(data), dtype=torch.long) >= change_point
with pyro.plate("data", len(data)):
    pyro.sample("obs", MaskedMixture(mask, dist1, dist2), obs=data)
Parameters:
  • mask (torch.Tensor) – A tensor of zeros and ones (or booleans), broadcastable to the batch shape, toggling between component0 and component1.
  • component0 (TorchDistribution) – A distribution used for batch elements where the mask is zero.
  • component1 (TorchDistribution) – A distribution used for batch elements where the mask is one.
arg_constraints = {}
expand(batch_shape)[source]
has_rsample
log_prob(value)[source]
mean[source]
rsample(sample_shape=torch.Size([]))[source]
sample(sample_shape=torch.Size([]))[source]
support
variance[source]

MixtureOfDiagNormalsSharedCovariance

class MixtureOfDiagNormalsSharedCovariance(locs, coord_scale, component_logits)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

Mixture of Normal distributions with diagonal covariance matrices.

That is, this distribution is a mixture with K components, where each component distribution is a D-dimensional Normal distribution with a D-dimensional mean parameter loc and a D-dimensional diagonal covariance matrix specified by a scale parameter coord_scale. The K different component means are gathered into the parameter locs and the scale parameter is shared between all K components. The mixture weights are controlled by a K-dimensional vector of softmax logits, component_logits. This distribution implements pathwise derivatives for samples from the distribution.

See reference [1] for details on the implementations of the pathwise derivative. Please consider citing this reference if you use the pathwise derivative in your research. Note that this distribution does not support dimension D = 1.

[1] Pathwise Derivatives for Multivariate Distributions, Martin Jankowiak & Theofanis Karaletsos. arXiv:1806.01856

Parameters:
  • locs (torch.Tensor) – K x D mean matrix
  • coord_scale (torch.Tensor) – shared D-dimensional scale vector
  • component_logits (torch.Tensor) – K-dimensional vector of softmax logits
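Example usage (a hedged sketch with K = 3 components in D = 2 dimensions, assuming torch and pyro.distributions as dist are imported):

locs = torch.randn(3, 2)              # K x D mean matrix
coord_scale = torch.ones(2)           # shared D-dimensional scale vector
component_logits = torch.zeros(3)     # K-dimensional softmax logits
mix = dist.MixtureOfDiagNormalsSharedCovariance(locs, coord_scale, component_logits)
z = mix.rsample()                     # shape (2,)
log_p = mix.log_prob(z)
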
arg_constraints = {'component_logits': Real(), 'coord_scale': GreaterThan(lower_bound=0.0), 'locs': Real()}
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
rsample(sample_shape=torch.Size([]))[source]

OMTMultivariateNormal

class OMTMultivariateNormal(loc, scale_tril)[source]

Bases: pyro.distributions.torch.MultivariateNormal

Multivariate normal (Gaussian) distribution with OMT gradients w.r.t. both parameters. Note the gradient computation w.r.t. the Cholesky factor has cost O(D^3), although the resulting gradient variance is generally expected to be lower.

A distribution over vectors in which all the elements have a joint Gaussian density.

Parameters:
  • loc (torch.Tensor) – D-dimensional mean vector.
  • scale_tril (torch.Tensor) – Cholesky of Covariance matrix; D x D matrix.
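Example usage (a hedged sketch for D = 3, assuming torch is imported and OMTMultivariateNormal is imported from pyro.distributions):

loc = torch.zeros(3)
scale_tril = torch.eye(3).requires_grad_()
d = OMTMultivariateNormal(loc, scale_tril)
z = d.rsample()
z.pow(2.0).sum().backward()           # OMT gradients flow back to scale_tril
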
arg_constraints = {'loc': Real(), 'scale_tril': LowerTriangular()}
rsample(sample_shape=torch.Size([]))[source]

RelaxedBernoulliStraightThrough

class RelaxedBernoulliStraightThrough(temperature, probs=None, logits=None, validate_args=None)[source]

Bases: pyro.distributions.torch.RelaxedBernoulli

An implementation of RelaxedBernoulli with a straight-through gradient estimator.

This distribution has the following properties:

  • The samples returned by the rsample() method are discrete/quantized.
  • The log_prob() method returns the log probability of the relaxed/unquantized sample using the GumbelSoftmax distribution.
  • In the backward pass the gradient of the sample with respect to the parameters of the distribution uses the relaxed/unquantized sample.

References:

[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables,
Chris J. Maddison, Andriy Mnih, Yee Whye Teh
[2] Categorical Reparameterization with Gumbel-Softmax,
Eric Jang, Shixiang Gu, Ben Poole
log_prob(value)[source]

See pyro.distributions.torch.RelaxedBernoulli.log_prob()

rsample(sample_shape=torch.Size([]))[source]

See pyro.distributions.torch.RelaxedBernoulli.rsample()
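Example usage (a hedged sketch, assuming torch and pyro.distributions as dist are imported):

temperature = torch.tensor(0.5)
probs = torch.tensor([0.2, 0.8], requires_grad=True)
d = dist.RelaxedBernoulliStraightThrough(temperature, probs=probs)
z = d.rsample()                       # quantized: entries are exactly 0. or 1.
d.log_prob(z).sum().backward()        # scored via the relaxed sample; gradients reach probs
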

RelaxedOneHotCategoricalStraightThrough

class RelaxedOneHotCategoricalStraightThrough(temperature, probs=None, logits=None, validate_args=None)[source]

Bases: pyro.distributions.torch.RelaxedOneHotCategorical

An implementation of RelaxedOneHotCategorical with a straight-through gradient estimator.

This distribution has the following properties:

  • The samples returned by the rsample() method are discrete/quantized.
  • The log_prob() method returns the log probability of the relaxed/unquantized sample using the GumbelSoftmax distribution.
  • In the backward pass the gradient of the sample with respect to the parameters of the distribution uses the relaxed/unquantized sample.

References:

[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables,
Chris J. Maddison, Andriy Mnih, Yee Whye Teh
[2] Categorical Reparameterization with Gumbel-Softmax,
Eric Jang, Shixiang Gu, Ben Poole
log_prob(value)[source]

See pyro.distributions.torch.RelaxedOneHotCategorical.log_prob()

rsample(sample_shape=torch.Size([]))[source]

See pyro.distributions.torch.RelaxedOneHotCategorical.rsample()

Rejector

class Rejector(propose, log_prob_accept, log_scale)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

Rejection sampled distribution given an acceptance rate function.

Parameters:
  • propose (Distribution) – A proposal distribution that samples batched proposals via propose(). rsample() supports a sample_shape arg only if propose() supports a sample_shape arg.
  • log_prob_accept (callable) – A callable that inputs a batch of proposals and returns a batch of log acceptance probabilities.
  • log_scale – Total log probability of acceptance.
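For illustration (a hedged sketch, not from the Pyro documentation): rejection sampling a Normal(0, 1/sqrt(2)) by proposing from Normal(0, 1) and accepting with probability exp(-x**2 / 2); the total acceptance probability is 1/sqrt(2), so log_scale = -0.5 * log(2).

import math
import torch
import pyro.distributions as dist

propose = dist.Normal(torch.zeros(10), torch.ones(10))    # 10 parallel proposal chains
log_prob_accept = lambda x: -0.5 * x.pow(2.0)             # log acceptance probability, always <= 0
log_scale = torch.tensor(-0.5 * math.log(2.0))            # log of the total acceptance probability
d = dist.Rejector(propose, log_prob_accept, log_scale)
x = d.rsample()
log_p = d.log_prob(x)
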
has_rsample = True
log_prob(x)[source]
rsample(sample_shape=torch.Size([]))[source]
score_parts(x)[source]

VonMises

class VonMises(loc, concentration, validate_args=None)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

A circular von Mises distribution.

This implementation uses polar coordinates. The loc and value args can be any real number (to facilitate unconstrained optimization), but are interpreted as angles modulo 2 pi.

See VonMises3D for a 3D cartesian coordinate cousin of this distribution.

Currently only log_prob() is implemented.

Parameters:
  • loc (torch.Tensor) – an angle in radians.
  • concentration (torch.Tensor) – concentration parameter (positive).
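Example usage (a hedged sketch; only log_prob() is available, assuming torch and pyro.distributions as dist are imported):

import math

d = dist.VonMises(loc=torch.tensor(0.0), concentration=torch.tensor(4.0))
angles = torch.tensor([0.0, math.pi / 2, math.pi])
log_p = d.log_prob(angles)            # density is highest at the location, 0.0
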
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'loc': Real()}
expand(batch_shape)[source]
log_prob(value)[source]
support = Real()

VonMises3D

class VonMises3D(concentration, validate_args=None)[source]

Bases: pyro.distributions.torch_distribution.TorchDistribution

Spherical von Mises distribution.

This implementation combines the direction parameter and concentration parameter into a single combined parameter that contains both direction and magnitude. The value arg is represented in cartesian coordinates: it must be a normalized 3-vector that lies on the 2-sphere.

See VonMises for a 2D polar coordinate cousin of this distribution.

Currently only log_prob() is implemented.

Parameters:concentration (torch.Tensor) – A combined location-and-concentration vector. The direction of this vector is the location, and its magnitude is the concentration.
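Example usage (a hedged sketch; only log_prob() is available, assuming torch and pyro.distributions as dist are imported):

concentration = 10.0 * torch.tensor([0.0, 0.0, 1.0])   # direction +z, magnitude (concentration) 10
d = dist.VonMises3D(concentration)
value = torch.tensor([0.0, 0.0, 1.0])                   # must be a unit 3-vector
log_p = d.log_prob(value)
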
arg_constraints = {'concentration': Real()}
expand(batch_shape)[source]
log_prob(value)[source]
support = Real()

Transformed Distributions

InverseAutoregressiveFlow

class InverseAutoregressiveFlow(autoregressive_nn, log_scale_min_clip=-5.0, log_scale_max_clip=3.0)[source]

Bases: pyro.distributions.torch_transform.TransformModule

An implementation of Inverse Autoregressive Flow, using Eq (10) from Kingma et al., 2016,

\(\mathbf{y} = \mu_t + \sigma_t\odot\mathbf{x}\)

where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t>0\).

Together with TransformedDistribution this provides a way to create richer variational approximations.

Example usage:

>>> from pyro.nn import AutoRegressiveNN
>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> iaf = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40]))
>>> iaf_module = pyro.module("my_iaf", iaf)
>>> iaf_dist = dist.TransformedDistribution(base_dist, [iaf])
>>> iaf_dist.sample()  # doctest: +SKIP
    tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
            0.1389, -0.4629,  0.0986])

The inverse of the Bijector is required when, e.g., scoring the log density of a sample with TransformedDistribution. This implementation caches the inverse of the Bijector when its forward operation is called, e.g., when sampling from TransformedDistribution. However, if the cached value isn’t available, either because it was already popped from the cache, or an arbitrary value is being scored, it will calculate it manually. Note that this is an operation that scales as O(D) where D is the input dimension, and so should be avoided for large dimensional uses. So in general, it is cheap to sample from IAF and score a value that was sampled by IAF, but expensive to score an arbitrary value.

Parameters:
  • autoregressive_nn (nn.Module) – an autoregressive neural network whose forward call returns a real-valued mean and logit-scale as a tuple
  • log_scale_min_clip (float) – The minimum value for clipping the log(scale) from the autoregressive NN
  • log_scale_max_clip (float) – The maximum value for clipping the log(scale) from the autoregressive NN

References:

1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934] Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling

2. Variational Inference with Normalizing Flows [arXiv:1505.05770] Danilo Jimenez Rezende, Shakir Mohamed

3. MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle

codomain = Real()
log_abs_det_jacobian(x, y)[source]

Calculates the elementwise determinant of the log jacobian

InverseAutoregressiveFlowStable

class InverseAutoregressiveFlowStable(autoregressive_nn, sigmoid_bias=2.0)[source]

Bases: pyro.distributions.torch_transform.TransformModule

An implementation of an Inverse Autoregressive Flow, using Eqs (13)/(14) from Kingma et al., 2016,

\(\mathbf{y} = \sigma_t\odot\mathbf{x} + (1-\sigma_t)\odot\mu_t\)

where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, \(\mu_t,\sigma_t\) are calculated from an autoregressive network on \(\mathbf{x}\), and \(\sigma_t\) is restricted to \([0,1]\).

This variant of IAF is claimed by the authors to be more numerically stable than one using Eq (10), although in practice it leads to a restriction on the distributions that can be represented, presumably since the input is restricted to rescaling by a number on \([0,1]\).

Example usage:

>>> from pyro.nn import AutoRegressiveNN
>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> iaf = InverseAutoregressiveFlowStable(AutoRegressiveNN(10, [40]))
>>> iaf_module = pyro.module("my_iaf", iaf)
>>> iaf_dist = dist.TransformedDistribution(base_dist, [iaf])
>>> iaf_dist.sample()  # doctest: +SKIP
    tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
            0.1389, -0.4629,  0.0986])

See InverseAutoregressiveFlow docs for a discussion of the running cost.

Parameters:
  • autoregressive_nn (nn.Module) – an autoregressive neural network whose forward call returns a real-valued mean and logit-scale as a tuple
  • sigmoid_bias (float) – bias on the hidden units fed into the sigmoid; default 2.0

References:

1. Improving Variational Inference with Inverse Autoregressive Flow [arXiv:1606.04934] Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling

2. Variational Inference with Normalizing Flows [arXiv:1505.05770] Danilo Jimenez Rezende, Shakir Mohamed

3. MADE: Masked Autoencoder for Distribution Estimation [arXiv:1502.03509] Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle

codomain = Real()
log_abs_det_jacobian(x, y)[source]

Calculates the elementwise determinant of the log jacobian

PermuteTransform

class PermuteTransform(permutation)[source]

Bases: torch.distributions.transforms.Transform

A bijection that reorders the input dimensions, that is, multiplies the input by a permutation matrix. This is useful in between InverseAutoregressiveFlow transforms to increase the flexibility of the resulting distribution and stabilize learning. Whilst not being an autoregressive transform, the log absolute determinant of the Jacobian is easily calculable as 0. Note that reordering the input dimension between two layers of InverseAutoregressiveFlow is not equivalent to reordering the dimension inside the MADE networks that those IAFs use; using a PermuteTransform results in a distribution with more flexibility.

Example usage:

>>> from pyro.nn import AutoRegressiveNN
>>> from pyro.distributions import InverseAutoregressiveFlow, PermuteTransform
>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> iaf1 = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40]))
>>> ff = PermuteTransform(torch.randperm(10, dtype=torch.long))
>>> iaf2 = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40]))
>>> iaf_dist = dist.TransformedDistribution(base_dist, [iaf1, ff, iaf2])
>>> iaf_dist.sample()  # doctest: +SKIP
    tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
            0.1389, -0.4629,  0.0986])
Parameters:permutation (torch.LongTensor) – a permutation ordering that is applied to the inputs.
bijective = True
codomain = Real()
inv_permutation[source]
log_abs_det_jacobian(x, y)[source]

Calculates the elementwise determinant of the log Jacobian, i.e. log(abs([dy_0/dx_0, …, dy_{N-1}/dx_{N-1}])). Note that this type of transform is not autoregressive, so the log Jacobian is not the sum of the previous expression. However, it turns out it’s always 0 (since the determinant is -1 or +1), and so returning a vector of zeros works.

PlanarFlow

class PlanarFlow(input_dim)[source]

Bases: pyro.distributions.torch_transform.TransformModule

A ‘planar’ normalizing flow that uses the transformation

\(\mathbf{y} = \mathbf{x} + \mathbf{u}\tanh(\mathbf{w}^T\mathbf{x}+b)\)

where \(\mathbf{x}\) are the inputs, \(\mathbf{y}\) are the outputs, and the learnable parameters are \(b\in\mathbb{R}\), \(\mathbf{u}\in\mathbb{R}^D\), \(\mathbf{w}\in\mathbb{R}^D\) for input dimension \(D\). For this to be an invertible transformation, the condition \(\mathbf{w}^T\mathbf{u}>-1\) is enforced.

Together with TransformedDistribution this provides a way to create richer variational approximations.

Example usage:

>>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10))
>>> plf = PlanarFlow(10)
>>> plf_module = pyro.module("my_plf", plf)
>>> plf_dist = dist.TransformedDistribution(base_dist, [plf])
>>> plf_dist.sample()  # doctest: +SKIP
    tensor([-0.4071, -0.5030,  0.7924, -0.2366, -0.2387, -0.1417,  0.0868,
            0.1389, -0.4629,  0.0986])

The inverse of this transform does not possess an analytical solution and is left unimplemented. However, the inverse is cached when the forward operation is called during sampling, and so samples drawn using planar flow can be scored.

Parameters:input_dim – the dimension of the input (and output) variable.

References:

Variational Inference with Normalizing Flows [arXiv:1505.05770] Danilo Jimenez Rezende, Shakir Mohamed

codomain = Real()
log_abs_det_jacobian(x, y)[source]

Calculates the elementwise determinant of the log jacobian

reset_parameters()[source]
u_hat()[source]

TransformModule

class TransformModule[source]

Bases: torch.distributions.transforms.Transform, torch.nn.modules.module.Module

Transforms with learnable parameters such as normalizing flows should inherit from this class rather than Transform so they are also a subclass of nn.Module and inherit all the useful methods of that class.
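As an illustration (a hedged sketch, not a class provided by Pyro; the class name and details are hypothetical, assuming torch and pyro.distributions as dist are imported), a learnable elementwise affine transform built on TransformModule registers its parameters through nn.Module and can be composed with TransformedDistribution:

import torch
import torch.nn as nn
from torch.distributions import constraints
from pyro.distributions.torch_transform import TransformModule

class LearnableAffine(TransformModule):

    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def __init__(self, dim):
        super().__init__(cache_size=1)
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def _call(self, x):
        return x * self.log_scale.exp() + self.shift

    def _inverse(self, y):
        return (y - self.shift) * torch.exp(-self.log_scale)

    def log_abs_det_jacobian(self, x, y):
        # elementwise |dy/dx| = exp(log_scale)
        return self.log_scale.expand_as(x)

base_dist = dist.Normal(torch.zeros(5), torch.ones(5))
transform = LearnableAffine(5)
flow_dist = dist.TransformedDistribution(base_dist, [transform])
x = flow_dist.sample()
log_p = flow_dist.log_prob(x)         # uses _inverse() and log_abs_det_jacobian()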