# Source code for pyro.distributions.empirical

from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints

from pyro.distributions.torch import Categorical
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import copy_docs_from


@copy_docs_from(TorchDistribution)
class Empirical(TorchDistribution):
    r"""
    Empirical distribution associated with the sampled data.

    The tensor ``samples`` is laid out as
    ``batch_shape + (num_samples,) + event_shape`` and ``log_weights`` as
    ``batch_shape + (num_samples,)``; the last dimension of ``log_weights``
    is the one that is aggregated over (sampled from / averaged over).

    :param torch.Tensor samples: samples from the empirical distribution.
    :param torch.Tensor log_weights: log weights corresponding to the
        samples (required by the constructor). The shape of ``log_weights``
        must match the leftmost shape of ``samples``.
    :param validate_args: forwarded unchanged to the parent constructor.
    :raises ValueError: if the shape of ``log_weights`` is not a prefix of
        the shape of ``samples``.
    """

    arg_constraints = {}
    support = constraints.real
    has_enumerate_support = True

    def __init__(self, samples, log_weights, validate_args=None):
        self._samples = samples
        self._log_weights = log_weights
        sample_shape, weight_shape = samples.size(), log_weights.size()
        # A prefix comparison is sufficient: when ``log_weights`` has more
        # dimensions than ``samples`` the slice is shorter than
        # ``weight_shape`` and the two can never compare equal.
        if weight_shape != sample_shape[:len(weight_shape)]:
            raise ValueError("The shape of ``log_weights`` ({}) must match "
                             "the leftmost shape of ``samples`` ({})".format(weight_shape, sample_shape))
        # The last dimension of ``log_weights`` indexes the individual samples.
        self._aggregation_dim = log_weights.dim() - 1
        event_shape = sample_shape[len(weight_shape):]
        self._categorical = Categorical(logits=self._log_weights)
        super(Empirical, self).__init__(batch_shape=weight_shape[:-1],
                                        event_shape=event_shape,
                                        validate_args=validate_args)

    @property
    def sample_size(self):
        """
        Number of samples that constitute the empirical distribution.

        Computed as the total number of elements in ``log_weights``.

        :return int: number of samples collected.
        """
        return self._log_weights.numel()

    def sample(self, sample_shape=torch.Size()):
        """
        Draws stored samples according to the normalized weights.

        :param torch.Size sample_shape: shape of the batch of draws.
        :return: tensor of shape ``sample_shape + batch_shape + event_shape``.
        :rtype: torch.Tensor
        """
        # Indices into the aggregation dimension: sample_shape + batch_shape.
        sample_idx = self._categorical.sample(sample_shape)
        agg_dim = self._aggregation_dim
        # Move the aggregation dimension to the front:
        # batch_shape + (N,) + event_shape -> (N,) + batch_shape + event_shape.
        samples = self._samples.unsqueeze(0).transpose(0, agg_dim + 1).squeeze(agg_dim + 1)
        # Plain ``samples[sample_idx]`` would index the first batch dimension
        # for batched distributions; gather instead picks, for every batch
        # element, the drawn entry along the aggregation dimension.
        gather_idx = sample_idx.reshape((-1,) + tuple(self.batch_shape) + (1,) * len(self.event_shape))
        gather_idx = gather_idx.expand((-1,) + tuple(samples.shape[1:]))
        out_shape = tuple(sample_idx.shape) + tuple(self.event_shape)
        return samples.gather(0, gather_idx).reshape(out_shape)

    def log_prob(self, value):
        """
        Returns the log of the probability mass function evaluated at ``value``.
        Note that this currently only supports scoring values with empty
        ``sample_shape``.

        :param torch.Tensor value: scalar or tensor value to be scored.
        :raises ValueError: if argument validation is enabled and
            ``value.shape`` differs from ``batch_shape + event_shape``.
        """
        if self._validate_args:
            if value.shape != self.batch_shape + self.event_shape:
                raise ValueError("``value.shape`` must be {}".format(self.batch_shape + self.event_shape))
        if self.batch_shape:
            # Insert a singleton aggregation dim so ``value`` broadcasts
            # against every stored sample of its batch element.
            value = value.unsqueeze(self._aggregation_dim)
        selection_mask = self._samples.eq(value)
        # Get a mask for all entries in the ``weights`` tensor
        # that correspond to ``value``: a sample matches only if
        # all of its event entries match.
        for _ in range(len(self.event_shape)):
            selection_mask = selection_mask.min(dim=-1)[0]
        selection_mask = selection_mask.type(self._categorical.probs.type())
        # Total probability mass of the matching samples.
        return (self._categorical.probs * selection_mask).sum(dim=-1).log()

    def _weighted_mean(self, value, keepdim=False):
        """
        Weighted average of ``value`` over the aggregation dimension.

        :param torch.Tensor value: tensor whose leftmost shape matches
            ``log_weights``.
        :param bool keepdim: whether to retain the aggregation dimension.
        """
        # Pad the weights with trailing singleton dims so they broadcast
        # against the event dimensions of ``value``.
        weights = self._log_weights.reshape(self._log_weights.size() +
                                            torch.Size([1] * (value.dim() - self._log_weights.dim())))
        dim = self._aggregation_dim
        # Subtract the max log weight before exponentiating for numerical
        # stability; the common factor cancels in the ratio below.
        max_weight = weights.max(dim=dim, keepdim=True)[0]
        relative_probs = (weights - max_weight).exp()
        return (value * relative_probs).sum(dim=dim, keepdim=keepdim) / relative_probs.sum(dim=dim, keepdim=keepdim)

    def _raise_if_discrete(self, statistic):
        """
        Raises ``ValueError`` when the stored samples are ``torch.int32`` or
        ``torch.int64``; ``statistic`` is the capitalized name used in the
        error message.
        """
        if self._samples.dtype in (torch.int32, torch.int64):
            raise ValueError("{} for discrete empirical distribution undefined. ".format(statistic) +
                             "Consider converting samples to ``torch.float32`` " +
                             "or ``torch.float64``. If these are samples from a " +
                             "`Categorical` distribution, consider converting to a " +
                             "`OneHotCategorical` distribution.")

    @property
    def event_shape(self):
        return self._event_shape

    @property
    def mean(self):
        self._raise_if_discrete("Mean")
        return self._weighted_mean(self._samples)

    @property
    def variance(self):
        self._raise_if_discrete("Variance")
        # Re-insert the aggregation dim so the mean broadcasts over samples.
        mean = self.mean.unsqueeze(self._aggregation_dim)
        deviation_squared = torch.pow(self._samples - mean, 2)
        return self._weighted_mean(deviation_squared)

    @property
    def log_weights(self):
        return self._log_weights

    def enumerate_support(self, expand=True):
        # ``expand`` is ignored: the stored samples are returned as-is.
        # NOTE(review): for a non-empty ``batch_shape`` the aggregation dim
        # is not the leftmost dim of the returned tensor -- confirm that
        # callers only enumerate unbatched instances.
        return self._samples