Source code for gluonts.mx.distribution.binned

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Optional, Tuple

import mxnet as mx
import numpy as np
from mxnet import gluon

from gluonts.core.component import validated
from gluonts.mx import Tensor

from .distribution import MAX_SUPPORT_VAL, Distribution, _sample_multiple, getF
from .distribution_output import DistributionOutput


class Binned(Distribution):
    r"""
    A binned distribution defined by a set of bins via
    bin centers and bin probabilities.

    Parameters
    ----------
    bin_log_probs
        Tensor containing log probabilities of the bins, of shape
        `(*batch_shape, num_bins)`.
    bin_centers
        Tensor containing the bin centers, of shape
        `(*batch_shape, num_bins)`.
    label_smoothing
        The label smoothing weight, a real number in `[0, 1)`. Default
        `None`. If not `None`, the loss of the distribution is "label
        smoothed" cross-entropy. For example, instead of computing the
        cross-entropy loss between the estimated bin probabilities and a
        hard label (one-hot encoding) `[1, 0, 0]`, a soft label of
        `[0.9, 0.05, 0.05]` is taken as the ground truth (when
        `label_smoothing=0.15`). See (Müller et al., 2019) [MKH19]_ for
        further reference.
    """

    is_reparameterizable = False

    @validated()
    def __init__(
        self,
        bin_log_probs: Tensor,
        bin_centers: Tensor,
        label_smoothing: Optional[float] = None,
    ) -> None:
        self.bin_centers = bin_centers
        self.bin_log_probs = bin_log_probs
        self._bin_probs = None
        self.bin_edges = Binned._compute_edges(self.F, bin_centers)
        self.label_smoothing = label_smoothing

    @property
    def F(self):
        return getF(self.bin_log_probs)

    @property
    def support_min_max(self) -> Tuple[Tensor, Tensor]:
        F = self.F
        return (
            F.broadcast_minimum(
                F.zeros(self.batch_shape),
                F.sign(F.min(self.bin_centers, axis=-1)),
            )
            * MAX_SUPPORT_VAL,
            F.broadcast_maximum(
                F.zeros(self.batch_shape),
                F.sign(F.max(self.bin_centers, axis=-1)),
            )
            * MAX_SUPPORT_VAL,
        )

    @staticmethod
    def _compute_edges(F, bin_centers: Tensor) -> Tensor:
        r"""
        Computes the edges of the bins based on the centers. The first and
        last edge are set to :math:`-10^{10}` and :math:`10^{10}`,
        respectively.

        Parameters
        ----------
        F
        bin_centers
            Tensor of shape `(*batch_shape, num_bins)`.

        Returns
        -------
        Tensor
            Tensor of shape `(*batch_shape, num_bins + 1)`.
        """
        low = (
            F.zeros_like(bin_centers.slice_axis(axis=-1, begin=0, end=1))
            - 1.0e10
        )
        high = (
            F.zeros_like(bin_centers.slice_axis(axis=-1, begin=0, end=1))
            + 1.0e10
        )

        means = (
            F.broadcast_add(
                bin_centers.slice_axis(axis=-1, begin=1, end=None),
                bin_centers.slice_axis(axis=-1, begin=0, end=-1),
            )
            / 2.0
        )

        return F.concat(low, means, high, dim=-1)

    @property
    def bin_probs(self):
        if self._bin_probs is None:
            self._bin_probs = self.bin_log_probs.exp()
        return self._bin_probs

    @property
    def batch_shape(self) -> Tuple:
        return self.bin_log_probs.shape[:-1]

    @property
    def event_shape(self) -> Tuple:
        return ()

    @property
    def event_dim(self) -> int:
        return 0

    @property
    def mean(self):
        F = self.F
        return F.broadcast_mul(self.bin_probs, self.bin_centers).sum(axis=-1)

    @property
    def stddev(self):
        F = self.F
        ex2 = F.broadcast_mul(self.bin_probs, self.bin_centers.square()).sum(
            axis=-1
        )
        return F.broadcast_minus(ex2, self.mean.square()).sqrt()

    def _get_mask(self, x):
        F = self.F
        # TODO: replace this once mxnet has searchsorted
        left_edges = self.bin_edges.slice_axis(axis=-1, begin=0, end=-1)
        right_edges = self.bin_edges.slice_axis(axis=-1, begin=1, end=None)
        mask = F.broadcast_mul(
            F.broadcast_lesser_equal(left_edges, x),
            F.broadcast_lesser(x, right_edges),
        )
        return mask

    @staticmethod
    def _smooth_mask(F, mask, alpha):
        return F.broadcast_add(
            F.broadcast_mul(mask, F.broadcast_sub(F.ones_like(alpha), alpha)),
            F.broadcast_mul(F.softmax(F.ones_like(mask)), alpha),
        )
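    # Illustrative sketch (not part of the library source): a Binned
    # distribution over three bins. `_compute_edges` pads the midpoints
    # between consecutive centers with -1e10 and 1e10, so centers
    # [-1, 0, 1] yield edges [-1e10, -0.5, 0.5, 1e10].
    #
    #     >>> import mxnet as mx
    #     >>> probs = mx.nd.array([[0.2, 0.5, 0.3]])
    #     >>> centers = mx.nd.array([[-1.0, 0.0, 1.0]])
    #     >>> d = Binned(probs.log(), centers)
    #     >>> d.mean  # 0.2 * (-1) + 0.5 * 0 + 0.3 * 1 = 0.1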
    def smooth_ce_loss(self, x):
        """
        Cross-entropy loss with a "smooth" label.
        """
        assert self.label_smoothing is not None
        F = self.F
        x = x.expand_dims(axis=-1)
        mask = self._get_mask(x)
        alpha = F.full(shape=(1,), val=self.label_smoothing)
        smooth_mask = self._smooth_mask(F, mask, alpha)
        return -F.broadcast_mul(self.bin_log_probs, smooth_mask).sum(axis=-1)
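    # Worked numbers for the smoothing above (illustrative, assuming three
    # bins and label_smoothing=0.15): `_smooth_mask` computes
    # mask * (1 - alpha) + alpha / num_bins, so the one-hot mask [1, 0, 0]
    # becomes [0.85 + 0.05, 0.05, 0.05] = [0.9, 0.05, 0.05], matching the
    # example in the class docstring.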
    def log_prob(self, x):
        F = self.F
        x = x.expand_dims(axis=-1)
        mask = self._get_mask(x)
        return F.broadcast_mul(self.bin_log_probs, mask).sum(axis=-1)
    def cdf(self, x: Tensor) -> Tensor:
        F = self.F
        x = x.expand_dims(axis=-1)
        mask = F.broadcast_lesser_equal(self.bin_centers, x)
        return F.broadcast_mul(self.bin_probs, mask).sum(axis=-1)
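    # Illustrative check (values assumed, not from the library): with centers
    # [-1, 0, 1] and probabilities [0.2, 0.5, 0.3], cdf(0.5) sums the mass of
    # every bin whose center is <= 0.5, i.e. 0.2 + 0.5 = 0.7.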
    def loss(self, x: Tensor) -> Tensor:
        return (
            self.smooth_ce_loss(x)
            if self.label_smoothing
            else -self.log_prob(x)
        )
    def quantile(self, level: Tensor) -> Tensor:
        F = self.F

        # self.bin_probs.shape = (batch_shape, num_bins)
        probs = self.bin_probs.transpose()  # (num_bins, batch_shape.T)

        # (batch_shape)
        zeros_batch_size = F.zeros_like(
            F.slice_axis(self.bin_probs, axis=-1, begin=0, end=1).squeeze(
                axis=-1
            )
        )

        level = level.expand_dims(axis=0)
        # cdf shape (batch_size.T, levels)
        zeros_cdf = F.broadcast_add(
            zeros_batch_size.transpose().expand_dims(axis=-1),
            level.zeros_like(),
        )
        start_state = (zeros_cdf, zeros_cdf.astype("int32"))

        def step(p, state):
            cdf, idx = state
            cdf = F.broadcast_add(cdf, p.expand_dims(axis=-1))
            idx = F.where(F.broadcast_greater(cdf, level), idx, idx + 1)
            return zeros_batch_size, (cdf, idx)

        _, states = F.contrib.foreach(step, probs, start_state)
        _, idx = states

        # idx.shape = (batch.T, levels)
        # centers.shape = (batch, num_bins)
        #
        # expand centers to shape -> (levels, batch, num_bins)
        # so we can use pick with idx.T.shape = (levels, batch)
        #
        # zeros_cdf.shape (batch.T, levels)
        centers_expanded = F.broadcast_add(
            self.bin_centers.transpose().expand_dims(axis=-1),
            zeros_cdf.expand_dims(axis=0),
        ).transpose()

        # centers_expanded.shape = (levels, batch, num_bins)
        # idx.shape (batch.T, levels)
        a = centers_expanded.pick(idx.transpose(), axis=-1)
        return a
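    # Illustrative trace of the scan above (values assumed, not from the
    # library): for bin probabilities [0.2, 0.5, 0.3], centers [-1, 0, 1],
    # and level 0.6, the running cdf is (0.2, 0.7, 1.0); `idx` stops
    # advancing at the first bin where cdf > level, here bin 1, so
    # quantile(0.6) returns center 0.0.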
    def sample(self, num_samples=None, dtype=np.float32):
        def s(bin_probs):
            F = self.F
            indices = F.sample_multinomial(bin_probs)
            if num_samples is None:
                return self.bin_centers.pick(indices, -1).reshape_like(
                    F.zeros_like(indices.astype("float32"))
                )
            else:
                return F.repeat(
                    F.expand_dims(self.bin_centers, axis=0),
                    repeats=num_samples,
                    axis=0,
                ).pick(indices, -1)

        return _sample_multiple(s, self.bin_probs, num_samples=num_samples)
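    # Usage sketch (illustrative): each sample draws a bin index from the
    # categorical distribution over bins and returns that bin's center, so
    # sampled values always lie on the bin-center grid.
    #
    #     >>> s = d.sample(num_samples=100)  # `d` as in the sketch above
    #     >>> # each entry of `s` is one of -1.0, 0.0, 1.0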
    @property
    def args(self) -> List:
        return [self.bin_log_probs, self.bin_centers]
class BinnedArgs(gluon.HybridBlock):
    def __init__(
        self, num_bins: int, bin_centers: mx.nd.NDArray, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.num_bins = num_bins
        with self.name_scope():
            self.bin_centers = self.params.get_constant(
                "bin_centers", bin_centers
            )

            # needs to be named self.proj for consistency with the
            # ArgProj class and the inference tests
            self.proj = gluon.nn.HybridSequential()
            self.proj.add(
                gluon.nn.Dense(
                    self.num_bins,
                    prefix="binproj",
                    flatten=False,
                    weight_initializer=mx.init.Xavier(),
                )
            )
            self.proj.add(gluon.nn.HybridLambda("log_softmax"))
    def hybrid_forward(
        self, F, x: Tensor, bin_centers: Tensor
    ) -> Tuple[Tensor, Tensor]:
        ps = self.proj(x)
        reshaped_probs = ps.reshape(shape=(-2, -1, self.num_bins), reverse=1)
        bin_centers = F.broadcast_add(bin_centers, ps.zeros_like())
        return reshaped_probs, bin_centers
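# Illustrative shape walk-through for BinnedArgs (shapes assumed, not from
# the library): for input `x` of shape (batch, seq_len, hidden), the Dense
# layer with `flatten=False` produces (batch, seq_len, num_bins) scores,
# `log_softmax` normalizes them over the last axis, and the constant
# `bin_centers` of shape (num_bins,) is broadcast to the same shape so that
# both outputs line up per time step.
#
#     >>> proj = BinnedArgs(3, mx.nd.array([-1.0, 0.0, 1.0]))
#     >>> proj.initialize()
#     >>> log_probs, centers = proj(mx.nd.random.normal(shape=(4, 7, 16)))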
class BinnedOutput(DistributionOutput):
    distr_cls: type = Binned

    @validated()
    def __init__(
        self,
        bin_centers: mx.nd.NDArray,
        label_smoothing: Optional[float] = None,
    ) -> None:
        assert label_smoothing is None or 0 <= label_smoothing < 1, (
            "Smoothing factor should be less than 1 and greater than or"
            " equal to 0."
        )
        super().__init__()
        assert len(bin_centers.shape) == 1
        self.bin_centers = bin_centers
        self.num_bins = self.bin_centers.shape[0]
        self.label_smoothing = label_smoothing
    def get_args_proj(self, *args, **kwargs) -> gluon.nn.HybridBlock:
        return BinnedArgs(self.num_bins, self.bin_centers)
    @staticmethod
    def _scale_bin_centers(F, bin_centers, loc=None, scale=None):
        if scale is not None:
            bin_centers = F.broadcast_mul(
                bin_centers, scale.expand_dims(axis=-1)
            )
        if loc is not None:
            bin_centers = F.broadcast_add(
                bin_centers, loc.expand_dims(axis=-1)
            )
        return bin_centers
    def distribution(self, args, loc=None, scale=None) -> Binned:
        # args[0] holds the bin log-probabilities, args[1] the bin centers
        probs = args[0]
        bin_centers = args[1]
        F = getF(probs)

        bin_centers = F.broadcast_mul(bin_centers, F.ones_like(probs))
        bin_centers = self._scale_bin_centers(
            F, bin_centers, loc=loc, scale=scale
        )

        return Binned(
            probs, bin_centers, label_smoothing=self.label_smoothing
        )
    @property
    def event_shape(self) -> Tuple:
        return ()
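
# Illustrative end-to-end sketch (values assumed, not part of the library
# source): build a BinnedOutput, then construct a scaled distribution from
# projected arguments. With scale=2, `_scale_bin_centers` maps the centers
# [-1, 0, 1] to [-2, 0, 2].
#
#     >>> out = BinnedOutput(bin_centers=mx.nd.array([-1.0, 0.0, 1.0]))
#     >>> log_probs = mx.nd.array([[0.2, 0.5, 0.3]]).log()
#     >>> centers = mx.nd.array([[-1.0, 0.0, 1.0]])
#     >>> d = out.distribution([log_probs, centers], scale=mx.nd.array([2.0]))
#     >>> d.bin_centers  # [[-2.0, 0.0, 2.0]]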