# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import math
from numbers import Number
from typing import Dict, Optional, Tuple, Union
import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all
import torch.nn.functional as F
from .distribution_output import DistributionOutput
from gluonts.core.component import validated
from torch.distributions import Distribution
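
# Constants cached for the closed-form pdf/cdf expressions below;
# CONST_LOG_SQRT_2PI_E = 0.5 * log(2 * pi * e) is the differential entropy of
# a standard normal distribution.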
CONST_SQRT_2 = math.sqrt(2)
CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi)
CONST_INV_SQRT_2 = 1 / math.sqrt(2)
CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI)
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)
class TruncatedNormal(Distribution):
"""
Implements a Truncated Normal distribution with location scaling.

Location scaling prevents the location from being "too far" from 0, which
ultimately leads to numerically unstable samples and poor gradient
computation (e.g. gradient explosion). In practice, the location is computed
according to

.. math::
    loc = tanh(loc / upscale) * upscale.

This behaviour can be disabled by switching off the ``tanh_loc`` parameter
(see below).

Parameters
----------
loc:
    normal distribution location parameter
scale:
    normal distribution sigma parameter (square root of variance)
min:
    minimum value of the distribution. Default = -1.0
max:
    maximum value of the distribution. Default = 1.0
upscale:
    scaling factor. Default = 5.0
tanh_loc:
    if ``True``, the above formula is used for the location scaling,
    otherwise the raw value is kept. Default is ``False``.

References
----------
- https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

Notes
-----
This implementation is strongly based on:

- https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
- https://github.com/toshas/torch_truncnorm
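
Examples
--------
A minimal, illustrative usage sketch (parameter values are arbitrary):

>>> import torch
>>> loc, scale = torch.zeros(3), torch.ones(3)
>>> d = TruncatedNormal(loc=loc, scale=scale, min=-1.0, max=1.0)
>>> samples = d.rsample()          # differentiable samples in [min, max]
>>> log_p = d.log_prob(samples)    # elementwise log-density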
"""
arg_constraints = {
"loc": constraints.real,
"scale": constraints.greater_than(1e-6),
}
has_rsample = True
eps = 1e-6
def __init__(
self,
loc: torch.Tensor,
scale: torch.Tensor,
min: Union[torch.Tensor, float] = -1.0,
max: Union[torch.Tensor, float] = 1.0,
upscale: Union[torch.Tensor, float] = 5.0,
tanh_loc: bool = False,
):
scale = scale.clamp_min(self.eps)
if tanh_loc:
loc = (loc / upscale).tanh() * upscale
loc = loc + (max - min) / 2 + min
self.min = min
self.max = max
self.upscale = upscale
self.loc, self.scale, a, b = broadcast_all(
loc, scale, self.min, self.max
)
self._non_std_a = a
self._non_std_b = b
self.a = (a - self.loc) / self.scale
self.b = (b - self.loc) / self.scale
if isinstance(a, Number) and isinstance(b, Number):
batch_shape = torch.Size()
else:
batch_shape = self.a.size()
super(TruncatedNormal, self).__init__(batch_shape)
if self.a.dtype != self.b.dtype:
raise ValueError("Truncation bounds types are different")
if any(
(self.a >= self.b)
.view(
-1,
)
.tolist()
):
raise ValueError("Incorrect truncation range")
eps = self.eps
self._dtype_min_gt_0 = eps
self._dtype_max_lt_1 = 1 - eps
self._little_phi_a = self._little_phi(self.a)
self._little_phi_b = self._little_phi(self.b)
self._big_phi_a = self._big_phi(self.a)
self._big_phi_b = self._big_phi(self.b)
self._Z = (self._big_phi_b - self._big_phi_a).clamp(eps, 1 - eps)
self._log_Z = self._Z.log()
little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
self._lpbb_m_lpaa_d_Z = (
self._little_phi_b * little_phi_coeff_b
- self._little_phi_a * little_phi_coeff_a
) / self._Z
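# Standard truncated-normal moments on [a, b], with Z = Phi(b) - Phi(a),
# phi the standard normal pdf and Phi its cdf:
#   E[X]   = (phi(a) - phi(b)) / Z
#   Var[X] = 1 + (a * phi(a) - b * phi(b)) / Z - ((phi(a) - phi(b)) / Z) ** 2
#   H[X]   = 0.5 * log(2 * pi * e) + log(Z) + (a * phi(a) - b * phi(b)) / (2 * Z)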
self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
self._variance = (
1
- self._lpbb_m_lpaa_d_Z
- ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2
)
self._entropy = (
CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z
)
self._log_scale = self.scale.log()
self._mean_non_std = self._mean * self.scale + self.loc
self._variance_non_std = self._variance * self.scale**2
self._entropy_non_std = self._entropy + self._log_scale
@constraints.dependent_property
def support(self):
return constraints.interval(self.a, self.b)
@property
def mean(self):
return self._mean_non_std
@property
def variance(self):
return self._variance_non_std
@property
def entropy(self):
return self._entropy_non_std
@staticmethod
def _little_phi(x):
return (-(x**2) * 0.5).exp() * CONST_INV_SQRT_2PI
def _big_phi(self, x):
phi = 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())
return phi.clamp(self.eps, 1 - self.eps)
@staticmethod
def _inv_big_phi(x):
return CONST_SQRT_2 * (2 * x - 1).erfinv()
def cdf_truncated_standard_normal(self, value):
return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)
def icdf_truncated_standard_normal(self, value):
y = self._big_phi_a + value * self._Z
y = y.clamp(self.eps, 1 - self.eps)
return self._inv_big_phi(y)
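# Log-density of the standardized truncated normal on [a, b]:
#   log f(x) = -0.5 * x ** 2 - log(sqrt(2 * pi)) - log(Z)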
def log_prob_truncated_standard_normal(self, value):
return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
def _to_std_rv(self, value):
return (value - self.loc) / self.scale
def _from_std_rv(self, value):
return value * self.scale + self.loc
def cdf(self, value):
return self.cdf_truncated_standard_normal(self._to_std_rv(value))
def icdf(self, value):
sample = self._from_std_rv(self.icdf_truncated_standard_normal(value))
# clamp data but keep gradients
sample_clip = torch.stack(
[sample.detach(), self._non_std_a.detach().expand_as(sample)], 0
).max(0)[0]
sample_clip = torch.stack(
[sample_clip, self._non_std_b.detach().expand_as(sample)], 0
).min(0)[0]
sample.data.copy_(sample_clip)
return sample
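# log_prob clamps the value just inside the truncation bounds, standardizes it,
# and applies the change of variables x -> (x - loc) / scale, which contributes
# the -log(scale) term.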
def log_prob(self, value):
a = self._non_std_a + self._dtype_min_gt_0
a = a.expand_as(value)
b = self._non_std_b - self._dtype_min_gt_0
b = b.expand_as(value)
value = torch.min(torch.stack([value, b], -1), dim=-1)[0]
value = torch.max(torch.stack([value, a], -1), dim=-1)[0]
value = self._to_std_rv(value)
return self.log_prob_truncated_standard_normal(value) - self._log_scale
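# Sampling uses the inverse-transform method: draw u ~ Uniform(eps, 1 - eps)
# and map it through the differentiable inverse CDF, so gradients flow to
# loc and scale (hence has_rsample = True).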
def rsample(self, sample_shape=None):
if sample_shape is None:
sample_shape = torch.Size([])
shape = self._extended_shape(sample_shape)
p = torch.empty(shape, device=self.a.device).uniform_(
self._dtype_min_gt_0, self._dtype_max_lt_1
)
return self.icdf(p)
class TruncatedNormalOutput(DistributionOutput):
distr_cls: type = TruncatedNormal
@validated()
def __init__(
self,
min: float = -1.0,
max: float = 1.0,
upscale: float = 5.0,
tanh_loc: bool = False,
) -> None:
assert min < max, "max must be strictly greater than min"
super().__init__(self)
self.min = min
self.max = max
self.upscale = upscale
self.tanh_loc = tanh_loc
self.args_dim: Dict[str, int] = {
"loc": 1,
"scale": 1,
}
@classmethod
def domain_map( # type: ignore
cls,
loc: torch.Tensor,
scale: torch.Tensor,
):
scale = F.softplus(scale)
return (
loc.squeeze(-1),
scale.squeeze(-1),
)
# Overwrites the parent class method: in addition to the projected tensor
# arguments (loc, scale), the constant float and boolean parameters
# (min, max, upscale, tanh_loc) are passed to the distribution.
def distribution(
self,
distr_args,
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> Distribution:
(loc, scale) = distr_args
assert isinstance(loc, torch.Tensor)
assert isinstance(scale, torch.Tensor)
return TruncatedNormal(
loc=loc,
scale=scale,
upscale=self.upscale,
min=self.min,
max=self.max,
tanh_loc=self.tanh_loc,
)
@property
def event_shape(self) -> Tuple:
return ()
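

# Minimal, illustrative usage sketch of the two classes above (assumes a
# trailing feature dimension of size 1 for the raw loc/scale arguments, as
# expected by ``domain_map``; values are arbitrary):
if __name__ == "__main__":
    raw_loc = torch.zeros(8, 1)
    raw_scale = torch.ones(8, 1)
    distr_output = TruncatedNormalOutput(min=-1.0, max=1.0)
    # domain_map applies softplus to the scale and squeezes the trailing dim.
    distr_args = distr_output.domain_map(raw_loc, raw_scale)
    distr = distr_output.distribution(distr_args)
    samples = distr.rsample()         # shape (8,), values within [-1.0, 1.0]
    log_p = distr.log_prob(samples)   # shape (8,)
    print(samples.shape, log_p.shape)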