Source code for gluonts.torch.distributions.binned_uniforms
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Optional, Tuple, cast
import torch
import torch.nn.functional as F
from torch.distributions import Distribution, constraints
from gluonts.core.component import validated
from .distribution_output import DistributionOutput
class BinnedUniforms(Distribution):
r"""
Binned uniforms distribution.
Args:
bins_lower_bound (float): The lower bound of the bin edges
bins_upper_bound (float): The upper bound of the bin edges
numb_bins (int): The number of equidistance bins to allocate between
`bins_lower_bound` and `bins_upper_bound`. Default value is 100.
logits (tensor): the logits defining the probability of each bins.
These are softmaxed. The tensor is of shape (*batch_shape,)
validate_args (bool) from the pytorch Distribution class
"""
arg_constraints = {"logits": constraints.real}
support = constraints.real
has_rsample = False
def __init__(
self,
bins_lower_bound: float,
bins_upper_bound: float,
logits: torch.Tensor,
numb_bins: int = 100,
validate_args: Optional[bool] = None,
):
        assert bins_lower_bound < bins_upper_bound, (
            f"bins_lower_bound {bins_lower_bound} needs to be less than "
            f"bins_upper_bound {bins_upper_bound}"
        )
assert (
logits.shape[-1] == numb_bins
), "The distribution requires one logit per bin."
        self.logits = logits
device = logits.device
        super().__init__(
batch_shape=logits.shape[:-1],
event_shape=logits.shape[-1:],
validate_args=validate_args,
)
self.numb_bins = numb_bins
        # Create the bin locations.
        # Bin locations are placed uniformly between bins_lower_bound and
        # bins_upper_bound, though more complex methods could be used.
self.bin_min = bins_lower_bound
self.bin_max = bins_upper_bound
        self.bin_edges = torch.linspace(
            self.bin_min, self.bin_max, self.numb_bins + 1, device=device
        )
        self.bin_widths = self.bin_edges[1:] - self.bin_edges[:-1]
        self.bin_centres = (self.bin_edges[1:] + self.bin_edges[:-1]) * 0.5
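        # For example, bins_lower_bound=-2.0, bins_upper_bound=2.0 and
        # numb_bins=4 give bin_edges=[-2, -1, 0, 1, 2],
        # bin_widths=[1, 1, 1, 1] and bin_centres=[-1.5, -0.5, 0.5, 1.5].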
@property
def mode(self):
"""
Returns the mode of the distribution.
mode.shape : (*batch_shape,)
"""
bins_prob = self.bins_prob
values_max, index_max = torch.max(bins_prob, dim=-1)
indicator_max = values_max.unsqueeze(dim=-1) == bins_prob
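        # indicator_max.shape: (*batch_shape, numb_bins). Note that if
        # several bins tie for the maximum probability, their centres are
        # summed below.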
        bin_centres = self.bin_centres.unsqueeze(dim=0)
        # bin_centres.shape: (1, numb_bins)
batch_shape_extended = self.bins_prob[..., 0:1].shape
bin_centres = bin_centres.repeat(batch_shape_extended)
mode = (bin_centres * indicator_max).sum(dim=-1)
return mode
@property
def median(self):
"""
Returns the median of the distribution.
median.shape : (*batch_shape,)
"""
        return self.icdf(torch.tensor(0.5, device=self.logits.device))
@property
def mean(self):
"""
Returns the mean of the distribution.
mean.shape : (*batch_shape,)
"""
batch_shape_extended = self.bins_prob[..., 0:1].shape
bin_centres = self.bin_centres.unsqueeze(dim=0)
# bin_centres.shape: (1, numb_bins)
bin_centres = bin_centres.repeat(batch_shape_extended)
# bin_centres.shape: (*batch_shape, numb_bins)
        # The mean is the probability-weighted sum of the bin centres.
        return (bin_centres * self.bins_prob).sum(dim=-1)
@property
def bins_prob(self):
"""
Returns the probability of the observed point to be in each of the bins
bins_prob.shape: (*batch_shape, event_shape).
event_shape is numb_bins
"""
bins_prob = self.log_bins_prob.exp()
return bins_prob
@property
def log_bins_prob(self):
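        """
        Returns the log-probability of each bin, i.e. the log_softmax of
        the logits along the bin dimension.

        log_bins_prob.shape: (*batch_shape, numb_bins)
        """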
return F.log_softmax(self.logits, dim=-1) # log_softmax along bins
    def log_prob(self, x):
        """
        Log probability for a tensor of datapoints `x`.

        `x` is expected to have shape (*batch_shape). The returned value
        is the log of the probability mass of the bin that each point
        falls into, not a continuous density.
        """
for i in range(0, len(x.shape)):
assert (
x.shape[i] == self.batch_shape[i]
), "We expect the input to be a tensor of size batch_shape"
return self.log_binned_p(x)
    def log_binned_p(self, x):
        """
        Log probability for a tensor of datapoints `x`.

        `x` is expected to have shape (*batch_shape).
        """
one_hot_bin_indicator = self.get_one_hot_bin_indicator(
x, in_float=True
)
# one_hot_bin_indicator.shape: (*batch_shape, numb_bins)
logp = (one_hot_bin_indicator * self.log_bins_prob).sum(dim=-1)
# logp.shape: (*batch_shape)
return logp
    def pdf(self, x):
        """
        Probability for a tensor of data points `x`.

        `x` is expected to have shape (*batch_shape).
        """
return torch.exp(self.log_prob(x))
    def get_one_hot_bin_indicator(self, x, in_float=False):
        """
        Returns a one-hot encoding of the bin that each point of `x` falls
        into.

        `x` is expected to have shape (*batch_shape), which can be, for
        example, (), (32,) or (32, 168).
        """
for i in range(0, len(x.shape)):
assert (
x.shape[i] == self.batch_shape[i]
), "We expect the input to be a tensor of size batch_shape"
numb_dim_batch_shape = len(x.shape)
x_copy = x
x = x.unsqueeze(dim=-1)
# x.shape: (*batch_shape, 1)
        upper_edges = self.bin_edges[1:]
        lower_edges = self.bin_edges[:-1]
        for i in range(0, numb_dim_batch_shape):
            upper_edges = upper_edges.unsqueeze(dim=0)
            lower_edges = lower_edges.unsqueeze(dim=0)
        # upper_edges.shape, lower_edges.shape: (1, ..., numb_bins)
        one_hot_bin_indicator = ((lower_edges <= x) * (x < upper_edges)).long()
# one_hot_bin_indicator.shape: [*batch_shape, numb_bins]
# This handles if x falls outside of [self.bin_min, self.bin_max]
        is_higher_than_last_edge = x_copy >= self.bin_edges[..., -1]
        # is_higher_than_last_edge.shape: (*batch_shape)
        is_lower_than_first_edge = x_copy <= self.bin_edges[..., 0]
        # is_lower_than_first_edge.shape: (*batch_shape)
one_hot_bin_indicator[..., -1][is_higher_than_last_edge] = 1
one_hot_bin_indicator[..., 0][is_lower_than_first_edge] = 1
if not in_float:
return one_hot_bin_indicator == 1 # booleans
else:
return one_hot_bin_indicator.float() # floats
    def icdf(self, quantiles):
        """
        Inverse cdf for a tensor of quantiles `quantiles`.

        `quantiles` is expected to have shape (*batch_shape), with values
        in (0.0, 1.0). This is the function to be called from the outside.
        """
assert (quantiles >= 0.0).all(), "quantiles must be between (0.0, 1.0)"
assert (quantiles <= 1.0).all(), "quantiles must be between (0.0, 1.0)"
# If given a single value as quantile, we put it to batch size
        if len(quantiles.shape) == 0 or (
            len(quantiles.shape) == 1 and quantiles.shape[0] == 1
        ):
batch_shape = self.bins_prob[..., 0].shape
quantiles = quantiles.repeat(batch_shape)
for i in range(0, len(quantiles.shape)):
assert quantiles.shape[i] == self.batch_shape[i], (
"We expect the quantile to be either a single float or a "
"tensor of size batch_shape"
)
return self._inverse_cdf(quantiles)
def _inverse_cdf(self, quantiles):
"""
Inverse cdf of a tensor of quantile `quantiles` 'quantiles' is of shape
(*batch_shape) with values between (0.0, 1.0)
"""
return self._icdf_binned(quantiles)
def _icdf_binned(self, quantiles):
"""
Inverse cdf of a tensor of quantile `quantiles` 'quantiles' is of shape
(*batch_shape) with values between (0.0, 1.0)
"""
quantiles = quantiles.unsqueeze(dim=-1)
# quantiles.shape: (*batch_shape, 1)
batch_shape_extended = quantiles.shape
bins_prob = self.bins_prob
# For each bin we get the cdf up to the bin (lower) and the cdf
# including the bin (upper)
incomplete_cdf_upper = bins_prob.cumsum(dim=-1)
# incomplete_cdf_upper.shape: (*batch_shape, numb_bins)
incomplete_cdf_lower = torch.zeros_like(incomplete_cdf_upper)
incomplete_cdf_lower[..., 1:] = incomplete_cdf_upper[..., :-1]
# incomplete_cdf_lower.shape: (*batch_shape, numb_bins)
one_hot_bin_indicator = (incomplete_cdf_lower <= quantiles) * (
quantiles < incomplete_cdf_upper
)
# one_hot_bin_indicator.shape: (*batch_shape, numb_bins)
# Handling the quantile equal to 1.0
higher_than_last = quantiles[..., 0] >= incomplete_cdf_upper[..., -1]
one_hot_bin_indicator[..., -1][higher_than_last] = True
upper_edges = self.bin_edges[1:].unsqueeze(dim=0)
# upper_edges.shape: (1, numb_bins)
upper_edges = upper_edges.repeat(batch_shape_extended)
# upper_edges.shape: (*batch_shape, numb_bins)
lower_edges = self.bin_edges[:-1].unsqueeze(dim=0)
# lower_edges.shape: (1, numb_bins)
lower_edges = lower_edges.repeat(batch_shape_extended)
# lower_edges.shape: (*batch_shape, numb_bins)
bin_width = upper_edges[one_hot_bin_indicator].view(
batch_shape_extended
) - lower_edges[one_hot_bin_indicator].view(batch_shape_extended)
# bin_width.shape: (*batch_shape)
prob_bin = bins_prob[one_hot_bin_indicator].view(batch_shape_extended)
# prob_bin.shape: (*batch_shape)
prob_left = quantiles.view(
batch_shape_extended
) - incomplete_cdf_lower[one_hot_bin_indicator].view(
batch_shape_extended
)
# prob_left.shape: (*batch_shape)
bin_lower_edge = lower_edges[one_hot_bin_indicator].view(
batch_shape_extended
)
# bin_lower_edge.shape: (*batch_shape)
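        # Invert the within-bin uniform cdf: start at the bin's lower edge
        # and move a fraction (prob_left / prob_bin) of the bin width into
        # the bin.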
result_icdf = bin_width * prob_left / prob_bin + bin_lower_edge
return result_icdf.squeeze(dim=-1)
    def cdf(self, x):
        """
        Cumulative distribution function evaluated at a tensor of data
        points `x`.

        `x` is expected to have shape (*batch_shape).
        """
for i in range(0, len(x.shape)):
assert (
x.shape[i] == self.batch_shape[i]
), "We expect the input to be a tensor of size batch_shape"
return self._cdf_binned(x)
def _cdf_binned(self, x):
"""
Cumulative density tensor for a tensor of data points `x`.
'x' is expected to be of shape (*batch_shape). The cdf is composed of 2
parts:
- the cdf up to the bin
- the cdf within the bin that the point falls into (modeled with a uniform distribution)
"""
bins_prob = self.bins_prob
batch_shape_extended = bins_prob[..., 0:1].shape
# Get the location of points in the bins
one_hot_bin_indicator = self.get_one_hot_bin_indicator(x)
# one_hot_bin_indicator.shape: (*batch_shape, numb_bins)
        # Get the cdf over the bins, i.e. the probability mass strictly
        # below each bin's lower edge
        incomplete_cdf = bins_prob.cumsum(dim=-1) - bins_prob
        # incomplete_cdf.shape: (*batch_shape, numb_bins)
cdf_up_to_bin = (
(incomplete_cdf * one_hot_bin_indicator)
.sum(dim=-1)
.unsqueeze(dim=-1)
)
        # cdf_up_to_bin.shape: (*batch_shape, 1)
# Prepare to select the edges of the bins that the points fall into
upper_edges = self.bin_edges[1:].unsqueeze(dim=0)
        # upper_edges.shape: (1, numb_bins)
upper_edges = upper_edges.repeat(batch_shape_extended)
# upper_edges.shape: (*batch_shape,numb_bins)
lower_edges = self.bin_edges[:-1].unsqueeze(dim=0)
lower_edges = lower_edges.repeat(batch_shape_extended)
# lower_edges.shape: (*batch_shape,numb_bins)
# With the edges and the point value we can get the cdf within the
# bin given that they are uniform
# distributions, and weight it by the probability of the bin
bin_width = upper_edges[one_hot_bin_indicator].view(
batch_shape_extended
) - lower_edges[one_hot_bin_indicator].view(batch_shape_extended)
# bin_width.shape: (*batch_shape,1)
dist_in_bin = x.unsqueeze(dim=-1) - lower_edges[
one_hot_bin_indicator
].view(batch_shape_extended)
# dist_in_bin.shape: (*batch_shape,1)
dist_in_bin = torch.max(
torch.min(dist_in_bin, bin_width), torch.zeros_like(dist_in_bin)
)
# this is for points falling outside the bins
cdf_in_bin = (
bins_prob[one_hot_bin_indicator].view(batch_shape_extended)
* dist_in_bin
/ bin_width
)
# cdf_in_bin.shape: (*batch_shape,1)
return (
(cdf_in_bin + cdf_up_to_bin)
.reshape(batch_shape_extended)
.squeeze(dim=-1)
)
    def sample(self, sample_shape=torch.Size()):
        """
        Returns samples from the distribution.

        Returns:
            samples of shape (*sample_shape, *batch_shape)
        """
        device = self.logits.device
        if len(sample_shape) == 0:
            quantiles = torch.rand(self.batch_shape, device=device)
            samples = self.icdf(quantiles)
        else:
            # Draw each sample independently via the inverse cdf, then
            # reshape the leading dimension into sample_shape.
            numb_samples = int(torch.Size(sample_shape).numel())
            samples = torch.zeros(
                [numb_samples] + list(self.batch_shape), device=device
            )
            for i in range(numb_samples):
                quantiles = torch.rand(self.batch_shape, device=device)
                samples[i, ...] = self.icdf(quantiles)
            samples = samples.view(
                list(sample_shape) + list(self.batch_shape)
            )
        return samples
    def rsample(self, sample_shape=torch.Size()):
"""
We do not have an implementation for the reparameterization trick yet.
"""
raise NotImplementedError
    def entropy(self):
"""
We do not have an implementation of the entropy yet.
"""
raise NotImplementedError
    def enumerate_support(self, expand=True):
        """
        Not defined, as this is a real-valued distribution.
        """
raise NotImplementedError
class BinnedUniformsOutput(DistributionOutput):
distr_cls: type = BinnedUniforms
@validated()
def __init__(
self,
bins_lower_bound: float,
bins_upper_bound: float,
num_bins: int,
) -> None:
        super().__init__()
assert (
isinstance(num_bins, int) and num_bins > 1
), "num_bins should be an integer and greater than 1"
        assert bins_lower_bound < bins_upper_bound, (
            f"bins_lower_bound {bins_lower_bound} needs to be less than "
            f"bins_upper_bound {bins_upper_bound}"
        )
self.num_bins = num_bins
self.bins_lower_bound = bins_lower_bound
self.bins_upper_bound = bins_upper_bound
self.args_dim = cast(
Dict[str, int],
{"logits": num_bins},
)
    @classmethod
def domain_map(cls, logits: torch.Tensor) -> torch.Tensor: # type: ignore
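        # Map the raw network outputs to non-negative values; the
        # distribution itself softmaxes them into bin probabilities.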
logits = torch.abs(logits)
return logits
    def distribution(
self,
distr_args,
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> BinnedUniforms:
return self.distr_cls(
self.bins_lower_bound,
self.bins_upper_bound,
distr_args,
self.num_bins,
)
@property
def event_shape(self) -> Tuple:
return ()
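

if __name__ == "__main__":
    # A minimal usage sketch (not part of the library API): build a
    # BinnedUniforms over [-2.0, 2.0] with 10 bins and uniform logits,
    # then query a few of its statistics. The shapes follow the
    # docstrings above; the values here are illustrative only.
    logits = torch.zeros(4, 10)  # batch_shape=(4,), one logit per bin
    distr = BinnedUniforms(
        bins_lower_bound=-2.0,
        bins_upper_bound=2.0,
        logits=logits,
        numb_bins=10,
    )
    print(distr.median.shape)  # torch.Size([4])
    print(distr.cdf(torch.zeros(4)))  # ~0.5 for this symmetric case
    print(distr.sample(torch.Size([100])).shape)  # torch.Size([100, 4])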