# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, List, Tuple
import numpy as np
from gluonts.core.component import validated
from gluonts.mx import Tensor
from .distribution import Distribution, _sample_multiple, getF, softplus
from .distribution_output import DistributionOutput
class LogitNormal(Distribution):
    r"""
    The logit-normal distribution.

    A variable ``x`` in ``(0, 1)`` follows this distribution when
    ``logit(x) = log(x) - log(1 - x)`` is normally distributed with
    location ``mu`` and scale ``sigma``.

    Parameters
    ----------
    mu
        Tensor containing the location, of shape
        `(*batch_shape, *event_shape)`.
    sigma
        Tensor indicating the scale, of shape `(*batch_shape, *event_shape)`.
    """

    @validated()
    def __init__(self, mu: Tensor, sigma: Tensor) -> None:
        super().__init__()
        self.mu = mu
        self.sigma = sigma

    @property
    def F(self):
        """Backend handle obtained from ``getF(self.mu)``; used for all
        tensor operations in this class."""
        return getF(self.mu)

    @property
    def batch_shape(self) -> Tuple:
        """Shape of ``mu``."""
        return self.mu.shape

    @property
    def event_shape(self) -> Tuple:
        """Empty tuple: each event is a scalar."""
        return ()

    @property
    def event_dim(self) -> int:
        """Number of event dimensions (always 0)."""
        return 0

    def log_prob(self, x: Tensor) -> Tensor:
        """
        Compute the log-density at ``x``.

        ``x`` is clipped to ``[1e-3, 1 - 1e-3]`` before evaluation so that
        ``log(x)`` and ``log(1 - x)`` stay finite.

        Parameters
        ----------
        x
            Tensor of points at which to evaluate the density.

        Returns
        -------
        Tensor
            ``-(log(sigma) + log(sqrt(2*pi)) + log(x) + log(1 - x)
            + (logit(x) - mu)^2 / (2 * sigma^2))``.
        """
        F = self.F
        x_clip = 1e-3
        x = F.clip(x, x_clip, 1 - x_clip)

        # log(sqrt(2 * pi)) as a plain Python float: no tensor needs to be
        # allocated for a constant, and scalar arithmetic is already used
        # elsewhere in this expression (e.g. ``1 - x``).
        log_sqrt_two_pi = float(0.5 * np.log(2 * np.pi))

        log_x = F.log(x)
        log_one_minus_x = F.log(1 - x)
        logit_x = log_x - log_one_minus_x

        log_prob = -1.0 * (
            F.log(self.sigma)
            + log_sqrt_two_pi
            + log_x
            + log_one_minus_x
            + ((logit_x - self.mu) ** 2 / (2 * (self.sigma**2)))
        )
        return log_prob

    def sample(self, num_samples=None, dtype=np.float32):
        """
        Draw samples by inverse-transform sampling.

        Uniform levels are drawn from ``[1e-3, 1 - 1e-3]`` and mapped through
        :meth:`quantile`.

        Parameters
        ----------
        num_samples
            Forwarded to ``_sample_multiple``.
        dtype
            Currently not used by this implementation.

        Returns
        -------
        Tensor
            The result of ``_sample_multiple`` applied to the sampler.
        """

        def s(mu):
            F = self.F
            q_min = 1e-3
            q_max = 1 - q_min
            # Bounds take the shape of ``mu``; the limits themselves are
            # plain scalars rather than 1-element tensors.
            ones = F.ones_like(mu)
            levels = F.sample_uniform(ones * q_min, ones * q_max)
            return self.quantile(levels)

        return _sample_multiple(s, self.mu, num_samples=num_samples)

    def quantile(self, level: Tensor) -> Tensor:
        """
        Compute the quantile function at ``level``.

        Parameters
        ----------
        level
            Tensor of quantile levels.

        Returns
        -------
        Tensor
            ``sigmoid(mu + sigma * sqrt(2) * erfinv(2 * level - 1))``.
        """
        F = self.F
        sqrt_two = float(np.sqrt(2.0))
        logit_q = self.mu + (self.sigma * sqrt_two * F.erfinv(2 * level - 1))
        # sigmoid(z) equals exp(z) / (1 + exp(z)) mathematically, but does
        # not produce inf / inf = NaN when exp(z) overflows for large z.
        return F.sigmoid(logit_q)

    @property
    def args(self) -> List:
        """Distribution parameters as ``[mu, sigma]``."""
        return [self.mu, self.sigma]
class LogitNormalOutput(DistributionOutput):
    """
    Distribution output that constructs :class:`LogitNormal` instances.

    ``args_dim`` declares one value each for ``mu`` and ``sigma``.
    """

    args_dim: Dict[str, int] = {"mu": 1, "sigma": 1}
    distr_cls: type = LogitNormal

    @classmethod
    def domain_map(cls, F, mu, sigma):
        """
        Map raw arguments to the parameters of the distribution.

        ``sigma`` is passed through ``softplus`` and floored at
        ``cls.eps()``; both outputs have their last axis squeezed away.
        """
        softplus_sigma = softplus(F, sigma)
        floored_sigma = F.maximum(softplus_sigma, cls.eps())
        squeezed_mu = mu.squeeze(axis=-1)
        squeezed_sigma = floored_sigma.squeeze(axis=-1)
        return squeezed_mu, squeezed_sigma

    @property
    def event_shape(self) -> Tuple:
        """Empty tuple: each event is a scalar."""
        return ()

    def distribution(
        self, distr_args, loc=None, scale=None, **kwargs
    ) -> Distribution:
        """
        Build the distribution from ``distr_args``.

        ``loc``, ``scale`` and any extra keyword arguments are accepted but
        not used here.
        """
        build = self.distr_cls
        return build(*distr_args)