# Source code for pyro.distributions.relaxed_straight_through

from __future__ import absolute_import, division, print_function

import torch

from pyro.distributions.torch import RelaxedOneHotCategorical, RelaxedBernoulli
from pyro.distributions.util import copy_docs_from
from torch.distributions.utils import clamp_probs

@copy_docs_from(RelaxedOneHotCategorical)
class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical):
    """
    An implementation of
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    with a straight-through gradient estimator.

    This distribution has the following properties:

    - The samples returned by the :meth:`rsample` method are
      discrete/quantized (one-hot).
    - The :meth:`log_prob` method returns the log probability of the
      relaxed/unquantized sample using the GumbelSoftmax distribution.
    - In the backward pass the gradient of the sample with respect to the
      parameters of the distribution uses the relaxed/unquantized sample.

    References:

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
        Variables, Chris J. Maddison, Andriy Mnih, Yee Whye Teh
    [2] Categorical Reparameterization with Gumbel-Softmax,
        Eric Jang, Shixiang Gu, Ben Poole
    """

    def rsample(self, sample_shape=torch.Size()):
        # Draw a relaxed (continuous) sample from the parent distribution and
        # pass it through clamp_probs -- presumably to keep it strictly inside
        # (0, 1) so that log_prob of the relaxed value stays finite; confirm.
        parent = super(RelaxedOneHotCategoricalStraightThrough, self)
        relaxed = clamp_probs(parent.rsample(sample_shape))
        # Forward: one-hot of the largest entry. Backward: QuantizeCategorical
        # hands the incoming gradient straight to the relaxed sample.
        return QuantizeCategorical.apply(relaxed)

    def log_prob(self, value):
        # Samples produced by rsample carry their relaxed counterpart in the
        # `_unquantize` attribute (set by QuantizeCategorical.forward); score
        # that instead. Any other value is scored as given.
        relaxed = getattr(value, '_unquantize', value)
        parent = super(RelaxedOneHotCategoricalStraightThrough, self)
        return parent.log_prob(relaxed)
class QuantizeCategorical(torch.autograd.Function):
    """
    Straight-through one-hot quantization of a relaxed categorical sample.

    Forward: returns a tensor of the same shape as ``soft_value`` that is
    zero everywhere except for a 1 at the position of the (first reported)
    maximum along the last dimension. The relaxed input is attached to the
    result as ``_unquantize`` so that ``log_prob`` can score it.

    Backward: the incoming gradient is returned unchanged, i.e. the
    quantization is treated as the identity for gradient purposes.
    """

    @staticmethod
    def forward(ctx, soft_value):
        # keepdim=True leaves a trailing size-1 dimension, which is exactly
        # the index shape scatter_ needs along dim -1 (no manual unsqueeze).
        argmax = soft_value.max(-1, keepdim=True)[1]
        hard_value = soft_value.new_zeros(soft_value.shape)
        # Keep a handle on the relaxed sample for log_prob.
        hard_value._unquantize = soft_value
        return hard_value.scatter_(-1, argmax, 1)

    @staticmethod
    def backward(ctx, grad):
        # Straight-through estimator: pass the gradient through untouched.
        return grad
@copy_docs_from(RelaxedBernoulli)
class RelaxedBernoulliStraightThrough(RelaxedBernoulli):
    """
    An implementation of
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli`
    with a straight-through gradient estimator.

    This distribution has the following properties:

    - The samples returned by the :meth:`rsample` method are
      discrete/quantized (rounded).
    - The :meth:`log_prob` method returns the log probability of the
      relaxed/unquantized sample under the parent
      :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli`
      distribution.
    - In the backward pass the gradient of the sample with respect to the
      parameters of the distribution uses the relaxed/unquantized sample.

    References:

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
        Variables, Chris J. Maddison, Andriy Mnih, Yee Whye Teh
    [2] Categorical Reparameterization with Gumbel-Softmax,
        Eric Jang, Shixiang Gu, Ben Poole
    """

    def rsample(self, sample_shape=torch.Size()):
        # Draw a relaxed (continuous) sample from the parent distribution and
        # pass it through clamp_probs -- presumably to keep it strictly inside
        # (0, 1) so that log_prob of the relaxed value stays finite; confirm.
        parent = super(RelaxedBernoulliStraightThrough, self)
        relaxed = clamp_probs(parent.rsample(sample_shape))
        # Forward: rounded sample. Backward: QuantizeBernoulli hands the
        # incoming gradient straight to the relaxed sample.
        return QuantizeBernoulli.apply(relaxed)

    def log_prob(self, value):
        # Samples produced by rsample carry their relaxed counterpart in the
        # `_unquantize` attribute (set by QuantizeBernoulli.forward); score
        # that instead. Any other value is scored as given.
        relaxed = getattr(value, '_unquantize', value)
        parent = super(RelaxedBernoulliStraightThrough, self)
        return parent.log_prob(relaxed)
class QuantizeBernoulli(torch.autograd.Function):
    """
    Straight-through rounding of a relaxed Bernoulli sample.

    Forward: rounds ``soft_value`` element-wise (assumes values lie in
    [0, 1], as relaxed Bernoulli samples do, so the result is 0/1) and
    attaches the relaxed input to the result as ``_unquantize`` so that
    ``log_prob`` can score it.

    Backward: the incoming gradient is returned unchanged.
    """

    @staticmethod
    def forward(ctx, soft_value):
        quantized = torch.round(soft_value)
        # Keep a handle on the relaxed sample for log_prob.
        quantized._unquantize = soft_value
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: rounding is treated as the identity.
        return grad_output