from __future__ import absolute_import, division, print_function
from collections import OrderedDict
import torch.nn as nn
from torch.distributions import biject_to, constraints, transform_to
from torch.nn import Parameter
import pyro
import pyro.distributions as dist
from pyro.contrib import autoname
from pyro.distributions.util import eye_like
def _get_independent_support(dist_instance):
    """
    Returns the support of ``dist_instance``, unwrapping a single
    :class:`~pyro.distributions.Independent` layer if present.
    """
    # XXX Should we treat the case dist_instance is Independent(Independent(Normal))?
    base = (dist_instance.base_dist
            if isinstance(dist_instance, dist.Independent)
            else dist_instance)
    return base.support
class Parameterized(nn.Module):
    """
    A wrapper of :class:`torch.nn.Module` whose parameters can be set
    constraints, set priors.

    Under the hood, we move parameters to a buffer store and create "root"
    parameters which are used to generate that parameter's value. For example,
    if we set a constraint to a parameter, an "unconstrained" parameter will be
    created, and the constrained value will be transformed from that
    "unconstrained" parameter.

    By default, when we set a prior to a parameter, an auto Delta guide will be
    created. We can use the method :meth:`autoguide` to setup other auto guides.

    To fix a parameter to a specific value, it is enough to turn off its "root"
    parameters' ``requires_grad`` flags.

    Example::

        >>> class Linear(Parameterized):
        ...     def __init__(self, a, b):
        ...         super(Linear, self).__init__()
        ...         self.a = Parameter(a)
        ...         self.b = Parameter(b)
        ...
        ...     def forward(self, x):
        ...         return self.a * x + self.b
        ...
        >>> linear = Linear(torch.tensor(1.), torch.tensor(0.))
        >>> linear.set_constraint("a", constraints.positive)
        >>> linear.set_prior("b", dist.Normal(0, 1))
        >>> linear.autoguide("b", dist.Normal)
        >>> assert "a_unconstrained" in dict(linear.named_parameters())
        >>> assert "b_loc" in dict(linear.named_parameters())
        >>> assert "b_scale_unconstrained" in dict(linear.named_parameters())
        >>> assert "a" in dict(linear.named_buffers())
        >>> assert "b" in dict(linear.named_buffers())
        >>> assert "b_scale" in dict(linear.named_buffers())

    Note that by default, data of a parameter is a float :class:`torch.Tensor`
    (unless we use :func:`torch.set_default_tensor_type` to change default
    tensor type). To cast these parameters to a correct data type or GPU device,
    we can call methods such as :meth:`~torch.nn.Module.double` or
    :meth:`~torch.nn.Module.cuda`. See :class:`torch.nn.Module` for more
    information.
    """
    def __init__(self):
        super(Parameterized, self).__init__()
        self._constraints = OrderedDict()
        self._priors = OrderedDict()
        self._guides = OrderedDict()
        self._mode = None

    def set_constraint(self, name, constraint):
        """
        Sets the constraint of an existing parameter.

        :param str name: Name of the parameter.
        :param ~constraints.Constraint constraint: A PyTorch constraint. See
            :mod:`torch.distributions.constraints` for a list of constraints.
        """
        if constraint in [constraints.real, constraints.real_vector]:
            # setting a "real" constraint means removing any previous
            # constraint machinery for this parameter
            if name in self._constraints:  # delete previous constraints
                self._constraints.pop(name, None)
                self._parameters.pop("{}_unconstrained".format(name))
                if name not in self._priors:
                    # no prior -> no guide
                    # so we can move param back from buffer
                    p = Parameter(self._buffers.pop(name).detach())
                    self.register_parameter(name, p)
            return
        if name in self._priors:
            raise ValueError("Parameter {} already has a prior. Can not set a constraint for it."
                             .format(name))

        if name in self._parameters:
            p = self._parameters.pop(name)
        elif name in self._buffers:
            p = self._buffers[name]
        else:
            raise ValueError("There is no parameter with name: {}".format(name))

        # the "root" parameter lives in unconstrained space; the constrained
        # value is re-derived from it and cached in the buffer store
        p_unconstrained = Parameter(transform_to(constraint).inv(p).detach())
        self.register_parameter("{}_unconstrained".format(name), p_unconstrained)
        # due to precision issue, we might get f(f^-1(x)) != x
        # so it is necessary to transform back
        p = transform_to(constraint)(p_unconstrained)
        self.register_buffer(name, p.detach())
        self._constraints[name] = constraint

    def set_prior(self, name, prior):
        """
        Sets the prior of an existing parameter.

        :param str name: Name of the parameter.
        :param ~pyro.distributions.distribution.Distribution prior: A Pyro prior
            distribution.
        """
        if name in self._parameters:
            # move param to _buffers
            p = self._parameters.pop(name)
            self.register_buffer(name, p)
        elif name not in self._buffers:
            raise ValueError("There is no parameter with name: {}".format(name))

        self._priors[name] = prior
        # remove the constraint and its unconstrained parameter
        self.set_constraint(name, constraints.real)
        # a Delta guide is installed by default; users may call autoguide()
        # afterwards to replace it
        self.autoguide(name, dist.Delta)

    def autoguide(self, name, dist_constructor):
        """
        Sets an autoguide for an existing parameter with name ``name`` (mimic
        the behavior of module :mod:`pyro.contrib.autoguide`).

        .. note:: `dist_constructor` should be one of
            :class:`~pyro.distributions.Delta`,
            :class:`~pyro.distributions.Normal`, and
            :class:`~pyro.distributions.MultivariateNormal`. More distribution
            constructor will be supported in the future if needed.

        :param str name: Name of the parameter.
        :param dist_constructor: A
            :class:`~pyro.distributions.distribution.Distribution` constructor.
        """
        if name not in self._priors:
            raise ValueError("There is no prior for parameter: {}".format(name))

        if dist_constructor not in [dist.Delta, dist.Normal, dist.MultivariateNormal]:
            raise NotImplementedError("Unsupported distribution type: {}"
                                      .format(dist_constructor))

        if name in self._guides:
            # delete previous guide's parameters
            dist_args = self._guides[name][1]
            for arg in dist_args:
                arg_name = "{}_{}".format(name, arg)
                if arg_name in self._constraints:
                    # delete its unconstrained parameter
                    self.set_constraint(arg_name, constraints.real)
                delattr(self, arg_name)

        # TODO: create a new argument `autoguide_args` to store other args for other
        # constructors. For example, in LowRankMVN, we need argument `rank`.
        p = self._buffers[name]
        if dist_constructor is dist.Delta:
            p_map = Parameter(p.detach())
            self.register_parameter("{}_map".format(name), p_map)
            # the MAP point itself must satisfy the prior's support
            self.set_constraint("{}_map".format(name), _get_independent_support(self._priors[name]))
            dist_args = {"map"}
        elif dist_constructor is dist.Normal:
            # mean-field guide parameterized in unconstrained space
            loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
            scale = Parameter(loc.new_ones(loc.shape))
            self.register_parameter("{}_loc".format(name), loc)
            self.register_parameter("{}_scale".format(name), scale)
            dist_args = {"loc", "scale"}
        elif dist_constructor is dist.MultivariateNormal:
            loc = Parameter(biject_to(self._priors[name].support).inv(p).detach())
            identity = eye_like(loc, loc.size(-1))
            scale_tril = Parameter(identity.repeat(loc.shape[:-1] + (1, 1)))
            self.register_parameter("{}_loc".format(name), loc)
            self.register_parameter("{}_scale_tril".format(name), scale_tril)
            dist_args = {"loc", "scale_tril"}
        else:
            raise NotImplementedError

        if dist_constructor is not dist.Delta:
            # each arg has a constraint, so we set constraints for them
            for arg in dist_args:
                self.set_constraint("{}_{}".format(name, arg),
                                    dist_constructor.arg_constraints[arg])

        self._guides[name] = (dist_constructor, dist_args)

    def set_mode(self, mode):
        """
        Sets ``mode`` of this object to be able to use its parameters in
        stochastic functions. If ``mode="model"``, a parameter will get its
        value from its prior. If ``mode="guide"``, the value will be drawn from
        its guide.

        .. note:: This method automatically sets ``mode`` for submodules which
            belong to :class:`Parameterized` class.

        :param str mode: Either "model" or "guide".
        """
        with autoname.name_count():
            for module in self.modules():
                if isinstance(module, Parameterized):
                    module.mode = mode

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode
        # We should get buffer values for constrained params first
        # otherwise, autoguide will use the old buffer for `scale` or `scale_tril`
        for name in self._constraints:
            if name not in self._priors:
                self._register_param(name)
        for name in self._priors:
            self._register_param(name)

    def _sample_from_guide(self, name):
        """Draws a value for parameter ``name`` from its autoguide."""
        dist_constructor, dist_args = self._guides[name]

        if dist_constructor is dist.Delta:
            p_map = getattr(self, "{}_map".format(name))
            return pyro.sample(name, dist.Delta(p_map, event_dim=p_map.dim()))

        # create guide
        dist_args = {arg: getattr(self, "{}_{}".format(name, arg)) for arg in dist_args}
        guide = dist_constructor(**dist_args)

        # no need to do transforms when support is real (for mean field ELBO)
        if _get_independent_support(self._priors[name]) is constraints.real:
            return pyro.sample(name, guide.to_event())

        # otherwise, we do inference in unconstrained space and transform the value
        # back to original space
        # TODO: move this logic to contrib.autoguide or somewhere else
        unconstrained_value = pyro.sample("{}_latent".format(name), guide.to_event(),
                                          infer={"is_auxiliary": True})
        transform = biject_to(self._priors[name].support)
        value = transform(unconstrained_value)
        log_density = transform.inv.log_abs_det_jacobian(value, unconstrained_value)
        return pyro.sample(name, dist.Delta(value, log_density.sum(), event_dim=value.dim()))

    def _register_param(self, name):
        """
        In "model" mode, lifts the parameter with name ``name`` to a random
        sample using a predefined prior (from :meth:`set_prior` method). In
        "guide" mode, we use the guide generated from :meth:`autoguide`.

        :param str name: Name of the parameter.
        """
        if name in self._priors:
            with autoname.scope(prefix=self._get_name()):
                if self.mode == "model":
                    p = pyro.sample(name, self._priors[name])
                else:
                    p = self._sample_from_guide(name)
        elif name in self._constraints:
            # refresh the cached constrained value from its "root" parameter
            p_unconstrained = self._parameters["{}_unconstrained".format(name)]
            p = transform_to(self._constraints[name])(p_unconstrained)
        self.register_buffer(name, p)