# Source code for gluonts.mx.distribution.transformed_distribution

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Any, List, Optional, Tuple

import mxnet as mx
import numpy as np
from mxnet import autograd

from gluonts.core.component import validated
from gluonts.mx import Tensor

from . import bijection as bij
from .distribution import Distribution, _index_tensor, getF
from .bijection import AffineTransformation


class TransformedDistribution(Distribution):
    r"""
    A distribution obtained by applying a sequence of transformations on top
    of a base distribution.

    Parameters
    ----------
    base_distribution
        The distribution to which the transformations are applied.
    transforms
        Bijections applied in order: a sample is ``t_n(...t_1(base_sample))``.
    """

    @validated()
    def __init__(
        self, base_distribution: Distribution, transforms: List[bij.Bijection]
    ) -> None:
        self.base_distribution = base_distribution
        self.transforms = transforms
        # reparameterizability is inherited from the base distribution
        # (bijections themselves are deterministic)
        self.is_reparameterizable = self.base_distribution.is_reparameterizable

        # use these to cache shapes and avoid recomputing all steps
        # the reason we cannot do the computations here directly
        # is that this constructor would fail in mx.symbol mode
        self._event_dim: Optional[int] = None
        self._event_shape: Optional[Tuple] = None
        self._batch_shape: Optional[Tuple] = None

    @property
    def F(self):
        # Delegate the function space (mx.nd vs mx.sym) to the base.
        return self.base_distribution.F

    @property
    def support_min_max(self) -> Tuple[Tensor, Tensor]:
        """
        Lower/upper bound of the support, obtained by pushing the base
        distribution's bounds through each transform.
        """
        F = self.F
        lb, ub = self.base_distribution.support_min_max
        for t in self.transforms:
            _lb = t.f(lb)
            _ub = t.f(ub)
            # A decreasing transform swaps the bounds, so re-sort them.
            lb = F.minimum(_lb, _ub)
            ub = F.maximum(_lb, _ub)
        return lb, ub

    def _slice_bijection(
        self, trans: bij.Bijection, item: Any
    ) -> bij.Bijection:
        """
        Return a copy of ``trans`` whose parameters are indexed by ``item``;
        parameter-free bijections are returned unchanged.
        """
        # local import to avoid a circular import at module load time
        from .box_cox_transform import BoxCoxTransform

        if isinstance(trans, bij.AffineTransformation):
            loc = (
                _index_tensor(trans.loc, item)
                if trans.loc is not None
                else None
            )
            scale = (
                _index_tensor(trans.scale, item)
                if trans.scale is not None
                else None
            )
            return bij.AffineTransformation(loc=loc, scale=scale)
        elif isinstance(trans, BoxCoxTransform):
            return BoxCoxTransform(
                _index_tensor(trans.lambda_1, item),
                _index_tensor(trans.lambda_2, item),
            )
        elif isinstance(trans, bij.InverseBijection):
            # slice the wrapped bijection and re-wrap it
            return bij.InverseBijection(
                self._slice_bijection(trans._bijection, item)
            )
        else:
            # assumes any other bijection carries no sliceable parameters
            return trans

    def __getitem__(self, item):
        # Slice both the base distribution and every transform's parameters.
        bd_slice = self.base_distribution[item]
        trans_slice = [self._slice_bijection(t, item) for t in self.transforms]
        return TransformedDistribution(bd_slice, trans_slice)

    @property
    def event_dim(self):
        # The event dimension is the largest among the base distribution
        # and all transforms; computed lazily and cached (see __init__).
        if self._event_dim is None:
            self._event_dim = max(
                [self.base_distribution.event_dim]
                + [t.event_dim for t in self.transforms]
            )
        assert isinstance(self._event_dim, int)
        return self._event_dim

    @property
    def batch_shape(self) -> Tuple:
        # Concatenate the base's batch+event shape, then keep everything
        # except the trailing event_dim axes.
        if self._batch_shape is None:
            shape = (
                self.base_distribution.batch_shape
                + self.base_distribution.event_shape
            )
            self._batch_shape = shape[: len(shape) - self.event_dim]
        assert isinstance(self._batch_shape, tuple)
        return self._batch_shape

    @property
    def event_shape(self) -> Tuple:
        # Complement of batch_shape: the trailing event_dim axes.
        if self._event_shape is None:
            shape = (
                self.base_distribution.batch_shape
                + self.base_distribution.event_shape
            )
            self._event_shape = shape[len(shape) - self.event_dim :]
        assert isinstance(self._event_shape, tuple)
        return self._event_shape

    def sample(
        self, num_samples: Optional[int] = None, dtype=np.float32
    ) -> Tensor:
        """
        Draw (non-differentiable) samples: sample from the base distribution,
        then apply each transform in order.
        """
        with autograd.pause():
            s = self.base_distribution.sample(
                num_samples=num_samples, dtype=dtype
            )
            for t in self.transforms:
                s = t.f(s)
        return s

    def sample_rep(
        self, num_samples: Optional[int] = None, dtype=float
    ) -> Tensor:
        """
        Draw reparameterized (differentiable) samples via the base
        distribution's ``sample_rep``, then apply each transform in order.
        """
        s = self.base_distribution.sample_rep(
            num_samples=num_samples, dtype=dtype
        )
        for t in self.transforms:
            s = t.f(s)
        return s

    def log_prob(self, y: Tensor) -> Tensor:
        """
        Log-density of ``y`` via the change-of-variables formula: invert the
        transforms from last to first, subtracting each log |det Jacobian|,
        then add the base distribution's log-density of the pre-image.
        """
        F = getF(y)
        lp = 0.0
        x = y
        for t in self.transforms[::-1]:
            x = t.f_inv(y)
            ladj = t.log_abs_det_jac(x, y)
            # reduce the Jacobian term over the extra trailing axes so it
            # matches this distribution's event_dim
            lp = lp - sum_trailing_axes(F, ladj, self.event_dim - t.event_dim)
            # the pre-image becomes the input for the next (earlier) transform
            y = x
        return self.base_distribution.log_prob(x) + lp

    def cdf(self, y: Tensor) -> Tensor:
        """
        CDF of ``y``: invert all transforms, then evaluate the base CDF,
        reflecting around 0.5 when the overall transformation is decreasing.
        """
        x = y
        sign = 1.0
        for t in self.transforms[::-1]:
            x = t.f_inv(x)
            # track whether the composed transform is increasing or decreasing
            sign = sign * t.sign
        f = self.base_distribution.cdf(x)
        return sign * (f - 0.5) + 0.5

    def quantile(self, level: Tensor) -> Tensor:
        """
        Quantile at ``level``: take the base distribution's quantile at
        ``level`` (or ``1 - level`` where the composed transform is
        decreasing), then apply the transforms in order.
        """
        F = getF(level)

        sign = 1.0
        for t in self.transforms:
            sign = sign * t.sign

        if not isinstance(sign, (mx.nd.NDArray, mx.sym.Symbol)):
            # scalar sign: pick the level (or its complement) up front
            level = level if sign > 0 else (1.0 - level)
            q = self.base_distribution.quantile(level)
        else:
            # tensor-valued sign: compute both candidates and select
            # element-wise, broadcasting sign over the quantile axes
            # level.shape = (#levels,)
            # q_pos.shape = (#levels, batch_size, ...)
            # sign.shape = (batch_size, ...)
            q_pos = self.base_distribution.quantile(level)
            q_neg = self.base_distribution.quantile(1.0 - level)
            cond = F.broadcast_greater(sign, sign.zeros_like())
            cond = F.broadcast_add(cond, q_pos.zeros_like())
            q = F.where(cond, q_pos, q_neg)

        for t in self.transforms:
            q = t.f(q)
        return q
class AffineTransformedDistribution(TransformedDistribution):
    """
    A distribution transformed by the affine map ``y = loc + scale * x``.

    Parameters
    ----------
    base_distribution
        The distribution being shifted and scaled.
    loc
        Additive shift; ``None`` means no shift (0).
    scale
        Multiplicative factor; ``None`` means no scaling (1).
    """

    @validated()
    def __init__(
        self,
        base_distribution: Distribution,
        loc: Optional[Tensor] = None,
        scale: Optional[Tensor] = None,
    ) -> None:
        super().__init__(base_distribution, [AffineTransformation(loc, scale)])

        # Store identity-transform defaults so the moment formulas below
        # work unconditionally.
        if loc is None:
            self.loc = 0
        else:
            self.loc = loc
        if scale is None:
            self.scale = 1
        else:
            self.scale = scale

    @property
    def mean(self) -> Tensor:
        # E[loc + scale * X] = loc + scale * E[X]
        scaled_mean = self.base_distribution.mean * self.scale
        return scaled_mean + self.loc

    @property
    def stddev(self) -> Tensor:
        # the shift does not affect the spread
        return self.base_distribution.stddev * self.scale

    @property
    def variance(self) -> Tensor:
        # TODO: cover the multivariate case here too
        # Var[loc + scale * X] = scale^2 * Var[X]
        return self.base_distribution.variance * (self.scale * self.scale)
# TODO: crps
def sum_trailing_axes(F, x: Tensor, k: int) -> Tensor:
    """
    Collapse the last ``k`` axes of ``x`` by summation.

    Parameters
    ----------
    F
        The function space providing ``sum`` (e.g. ``mx.nd`` or ``mx.sym``).
    x
        Tensor whose trailing axes should be summed out.
    k
        Number of trailing axes to reduce; ``k == 0`` returns ``x`` as-is.

    Returns
    -------
    Tensor
        ``x`` with its last ``k`` axes summed away.
    """
    reduced = x
    # Reduce one axis per step: axis=-1 always names the current last axis.
    for _ in range(k):
        reduced = F.sum(reduced, axis=-1)
    return reduced