# Source code for pyro.distributions.torch

from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints, kl_divergence, register_kl

from pyro.distributions.torch_distribution import IndependentConstraint, TorchDistributionMixin
from pyro.distributions.util import sum_rightmost


class MultivariateNormal(torch.distributions.MultivariateNormal, TorchDistributionMixin):
    # Pyro wrapper around torch.distributions.MultivariateNormal, adding
    # TorchDistributionMixin. (Any class docstring written here would be
    # replaced at import time by the module-level wrapper loop, so the
    # documentation lives in comments.)

    # Real-valued support with one rightmost dimension reinterpreted via
    # IndependentConstraint (presumably marking each sample vector as a single
    # event -- confirm against pyro.distributions.torch_distribution).
    support = IndependentConstraint(constraints.real, 1)  # TODO move upstream
class Independent(torch.distributions.Independent, TorchDistributionMixin):
    # Pyro wrapper around torch.distributions.Independent, adding
    # TorchDistributionMixin. Support and argument validation are both
    # derived from the wrapped ``base_dist``.

    @constraints.dependent_property
    def support(self):
        """Support of ``base_dist`` wrapped in :class:`IndependentConstraint`.

        The rightmost ``reinterpreted_batch_ndims`` dimensions are passed to
        ``IndependentConstraint``. This is a dependent property because the
        result depends on this instance's ``base_dist``.
        """
        return IndependentConstraint(self.base_dist.support, self.reinterpreted_batch_ndims)

    @property
    def _validate_args(self):
        """Read the validation flag from ``base_dist`` so both always agree."""
        return self.base_dist._validate_args

    @_validate_args.setter
    def _validate_args(self, value):
        """Write the validation flag through to ``base_dist``.

        This mutates the wrapped distribution object, so any other wrapper
        sharing the same ``base_dist`` observes the change as well.
        """
        self.base_dist._validate_args = value
@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    """Compute ``KL(p || q)`` for two :class:`Independent` distributions.

    The KL divergence between the two base distributions is computed first.
    It is then reduced with ``sum_rightmost`` over
    ``reinterpreted_batch_ndims`` dimensions when that count is nonzero.

    :param Independent p: the first distribution.
    :param Independent q: the second distribution.
    :returns: the KL divergence tensor.
    :raises NotImplementedError: if ``p`` and ``q`` reinterpret a different
        number of batch dimensions.
    """
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError(
            'KL between Independent distributions with different '
            'reinterpreted_batch_ndims ({} vs {}) is not supported'.format(
                p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims))
    kl = kl_divergence(p.base_dist, q.base_dist)
    if p.reinterpreted_batch_ndims:
        # Fold the reinterpreted dimensions into the result; skipped when the
        # count is 0 since there is nothing to reduce.
        kl = sum_rightmost(kl, p.reinterpreted_batch_ndims)
    return kl


# Programmatically load all distributions from PyTorch.
#
# For every Distribution subclass exported by torch.distributions (excluding
# the Distribution base class itself), expose a same-named class from this
# module. A class already defined in this module is reused. Otherwise a new
# subclass mixing in TorchDistributionMixin is synthesized.
__all__ = []
for _name, _Dist in torch.distributions.__dict__.items():
    # Skip non-classes, classes that are not Distributions, and the base class.
    if not isinstance(_Dist, type):
        continue
    if not issubclass(_Dist, torch.distributions.Distribution):
        continue
    if _Dist is torch.distributions.Distribution:
        continue

    # globals() rather than locals(): at module scope both return the module
    # namespace, but the Python docs say the locals() mapping should not be
    # modified, whereas writing through globals() is well defined.
    try:
        _PyroDist = globals()[_name]
    except KeyError:
        _PyroDist = type(_name, (_Dist, TorchDistributionMixin), {})
        _PyroDist.__module__ = __name__
        globals()[_name] = _PyroDist

    # Note: this also replaces the docstring of hand-written wrappers.
    _PyroDist.__doc__ = '''
    Wraps :class:`{}.{}` with
    :class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`.
    '''.format(_Dist.__module__, _Dist.__name__)

    __all__.append(_name)


# Create sphinx documentation: one titled autoclass section per exported name.
__doc__ = '\n\n'.join([
    '''
    {0}
    ----------------------------------------------------------------
    .. autoclass:: pyro.distributions.{0}
    '''.format(_name)
    for _name in sorted(__all__)
])