# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, List, Optional, Tuple
import numpy as np
from gluonts.core.component import validated
from gluonts.mx import Tensor
from gluonts.mx.util import cumsum
from .bijection import AffineTransformation, Bijection
from .distribution import Distribution, getF
from .distribution_output import DistributionOutput
from .transformed_distribution import TransformedDistribution
class ISQF(Distribution):
    r"""
    Incremental (Spline) Quantile Function distribution.

    Implements the quantile-function parametrization from the paper
    ``Learning Quantile Functions without Quantile Crossing for
    Distribution-free Time Series Forecasting`` by Park, Robinson, Aubet,
    Kan, Gasthaus, Wang.

    The quantile function is assembled from three kinds of pieces:
    a logarithmic left tail below the first quantile knot, one piecewise
    linear spline between each pair of consecutive quantile knots, and a
    logarithmic right tail above the last quantile knot.

    Parameters
    ----------
    spline_knots, spline_heights
        Tensor parametrizing the x-positions (y-positions) of the spline
        knots. Shape: (*batch_shape, (num_qk-1), num_pieces)
    beta_l, beta_r
        Tensor containing the non-negative learnable parameter of the
        left (right) tail. Shape: (*batch_shape,)
    qk_y, qk_x
        Tensor containing the increasing y-positions (x-positions) of the
        quantile knots. Shape: (*batch_shape, num_qk)
    num_qk
        Number of quantile knots.
    num_pieces
        Number of linear pieces in each spline.
    tol
        Tolerance added to the spline spacings for numerical stability.
    """

    is_reparameterizable = False

    @validated()
    def __init__(
        self,
        spline_knots: Tensor,
        spline_heights: Tensor,
        beta_l: Tensor,
        beta_r: Tensor,
        qk_y: Tensor,
        qk_x: Tensor,
        num_qk: int,
        num_pieces: int,
        tol: float = 1e-4,
    ) -> None:
        self.num_qk = num_qk
        self.num_pieces = num_pieces
        self.spline_knots = spline_knots
        self.spline_heights = spline_heights
        self.beta_l = beta_l
        self.beta_r = beta_r
        # The full y-knot tensor is kept because ``self.qk_y`` is replaced
        # below by the slice that drops the last knot.
        self.qk_y_all = qk_y
        self.tol = tol

        # ``self.F`` inspects ``beta_l``, so it must be read after the
        # assignments above.
        F = self.F

        # Quantile knots (qk): left endpoints, right endpoints and the two
        # outermost knots, for the x- and the y-positions.
        x_parts = ISQF.parametrize_qk(F, qk_x)
        self.qk_x, self.qk_x_plus, self.qk_x_l, self.qk_x_r = x_parts
        y_parts = ISQF.parametrize_qk(F, qk_y)
        self.qk_y, self.qk_y_plus, self.qk_y_l, self.qk_y_r = y_parts

        # Spline knots (sk): positions and spacings inside every knot
        # interval, first for the heights, then for the x-positions.
        self.sk_y, self.delta_sk_y = ISQF.parametrize_spline(
            F,
            self.spline_heights,
            self.qk_y,
            self.qk_y_plus,
            self.num_pieces,
            self.tol,
        )
        self.sk_x, self.delta_sk_x = ISQF.parametrize_spline(
            F,
            self.spline_knots,
            self.qk_x,
            self.qk_x_plus,
            self.num_pieces,
            self.tol,
        )

        # Right end of every spline piece: the next spline knot, or the
        # right quantile knot for the last piece of each interval.
        interval_end = F.expand_dims(self.qk_x_plus, axis=-1)
        if self.num_pieces > 1:
            self.sk_x_plus = F.concat(
                F.slice_axis(self.sk_x, axis=-1, begin=1, end=None),
                interval_end,
                dim=-1,
            )
        else:
            self.sk_x_plus = interval_end

        # Tail coefficients. The right tail reuses the left-tail formula
        # with beta negated and the level mirrored to 1 - qk_x_r.
        self.tail_al, self.tail_bl = ISQF.parametrize_tail(
            F, self.beta_l, self.qk_x_l, self.qk_y_l
        )
        self.tail_ar, self.tail_br = ISQF.parametrize_tail(
            F, -self.beta_r, 1 - self.qk_x_r, self.qk_y_r
        )

    @property
    def F(self):
        """Backend module returned by ``getF`` for ``beta_l``."""
        return getF(self.beta_l)

    @property
    def args(self) -> List:
        """
        Tensor arguments of the distribution, as a list.

        Holds the spline parameters, both tail parameters and the full
        y-knot tensor; the x-knot tensor is not part of the list.
        """
        return [
            self.spline_knots,
            self.spline_heights,
            self.beta_l,
            self.beta_r,
            self.qk_y_all,
        ]

    @staticmethod
    def _pick(F, tensor: Tensor, index: int) -> Tensor:
        """Return entry ``index`` of the last axis, with that axis removed."""
        return F.slice_axis(
            tensor, axis=-1, begin=index, end=index + 1
        ).squeeze(-1)

    @staticmethod
    def parametrize_qk(
        F, quantile_knots: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""
        Split the positions of the num_qk quantile knots into the pieces
        used by the rest of the class.

        Parameters
        ----------
        quantile_knots
            x or y positions of the quantile knots
            shape: (*batch_shape, num_qk)

        Returns
        -------
        qk
            x or y positions of the quantile knots (qk),
            with index=1, ..., num_qk-1,
            shape: (*batch_shape, num_qk-1)
        qk_plus
            x or y positions of the quantile knots (qk),
            with index=2, ..., num_qk,
            shape: (*batch_shape, num_qk-1)
        qk_l
            x or y positions of the left-most quantile knot (qk),
            shape: (*batch_shape)
        qk_r
            x or y positions of the right-most quantile knot (qk),
            shape: (*batch_shape)
        """
        # All knots except the last one / except the first one.
        lower = F.slice_axis(quantile_knots, axis=-1, begin=0, end=-1)
        upper = F.slice_axis(quantile_knots, axis=-1, begin=1, end=None)

        # Outermost knots, with the singleton knot axis dropped.
        first = F.slice_axis(quantile_knots, axis=-1, begin=0, end=1)
        last = F.slice_axis(quantile_knots, axis=-1, begin=-1, end=None)

        return lower, upper, first.squeeze(axis=-1), last.squeeze(axis=-1)

    @staticmethod
    def parametrize_spline(
        F,
        spline_knots: Tensor,
        qk: Tensor,
        qk_plus: Tensor,
        num_pieces: int,
        tol: float = 1e-4,
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Turn the unconstrained spline parameters into x or y positions of
        the spline knots.

        Parameters
        ----------
        spline_knots
            variable that parameterizes the spline knot positions
        qk
            x or y positions of the quantile knots (qk),
            with index=1, ..., num_qk-1,
            shape: (*batch_shape, num_qk-1)
        qk_plus
            x or y positions of the quantile knots (qk),
            with index=2, ..., num_qk,
            shape: (*batch_shape, num_qk-1)
        num_pieces
            number of spline knot pieces
        tol
            tolerance hyperparameter for numerical stability

        Returns
        -------
        sk
            x or y positions of the spline knots (sk),
            shape: (*batch_shape, num_qk-1, num_pieces)
        delta_sk
            difference of x or y positions of the spline knots (sk),
            shape: (*batch_shape, num_qk-1, num_pieces)
        """
        # Relative spacings come from a softmax, so they lie in [0, 1] and
        # sum to 1. ``tol`` keeps each spacing away from zero (the spline
        # CRPS divides by it); dividing by (1 + num_pieces * tol) restores
        # the sum-to-1 property afterwards.
        rel_spacing = (F.softmax(spline_knots) + tol) / (
            1 + num_pieces * tol
        )

        # TODO: update to mxnet cumsum when it supports axis=-1
        rel_start = cumsum(F, rel_spacing, exclusive=True)

        left = F.expand_dims(qk, axis=-1)
        right = F.expand_dims(qk_plus, axis=-1)

        # Rescale the relative quantities to the interval [qk, qk_plus].
        sk = F.broadcast_add(F.broadcast_mul(rel_start, (right - left)), left)
        delta_sk = F.broadcast_mul(rel_spacing, (right - left))
        return sk, delta_sk

    @staticmethod
    def parametrize_tail(
        F, beta: Tensor, qk_x: Tensor, qk_y: Tensor
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Compute the coefficients of one exponential tail.

        The tails are given by

            q(alpha)
            = a_l log(alpha) + b_l if left tail
            = a_r log(1-alpha) + b_r if right tail

        where

            a_l=1/beta_l, b_l=-a_l*log(qk_x_l)+q(qk_x_l)
            a_r=1/beta_r, b_r=a_r*log(1-qk_x_r)+q(qk_x_r)

        Parameters
        ----------
        beta
            parameterizes the left or right tail, shape: (*batch_shape,)
        qk_x
            left- or right-most x-positions of the quantile knots,
            shape: (*batch_shape,)
        qk_y
            left- or right-most y-positions of the quantile knots,
            shape: (*batch_shape,)

        Returns
        -------
        tail_a
            a_l or a_r as described above
        tail_b
            b_l or b_r as described above
        """
        slope = 1 / beta
        # Offset chosen so that the tail evaluates to qk_y at level qk_x.
        offset = -slope * F.log(qk_x) + qk_y
        return slope, offset

    def quantile(self, input_alpha: Tensor) -> Tensor:
        """
        Evaluate the quantile function at ``input_alpha``.

        Thin wrapper around ``quantile_internal`` with ``axis=0``.
        """
        return self.quantile_internal(input_alpha, axis=0)

    def quantile_internal(
        self, alpha: Tensor, axis: Optional[int] = None
    ) -> Tensor:
        r"""
        Evaluates the quantile function at the quantile levels alpha.

        Parameters
        ----------
        alpha
            Tensor of shape = (*batch_shape,) if axis=None, or containing an
            additional axis on the specified position, otherwise
        axis
            Index of the axis containing the different quantile levels which
            are to be computed.
            See the comment at the top of the method body for details.

        Returns
        -------
        Tensor
            Quantiles tensor, of the same shape as alpha
        """
        # Shape conventions shared by quantile_internal, quantile_spline
        # and quantile_tail.
        #
        # Stored parameters:
        #   tail parameters (tail_al, tail_ar, tail_bl, tail_br):
        #       (*batch_shape,)
        #   spline parameters (sk_x, sk_x_plus, sk_y, sk_y_plus):
        #       (*batch_shape, num_qk-1, num_pieces)
        #   quantile knot parameters (qk_x, qk_x_plus, qk_y, qk_y_plus):
        #       (*batch_shape, num_qk-1)
        #
        # axis=None - passed at inference when num_samples is None
        #   alpha has shape (*batch_shape,) and is expanded to
        #   (*batch_shape, 1, 1); the parameters need no reshaping.
        #
        # axis=0 - passed at inference when num_samples is not None
        #   alpha has shape (num_samples, *batch_shape) and is expanded to
        #   (num_samples, *batch_shape, 1, 1). Every parameter gets a new
        #   leading axis so that it broadcasts against num_samples.
        #
        # axis=-2 - passed at training when the quantiles are evaluated at
        #   the spline knots in order to compute alpha_tilde. This case is
        #   only meant for quantile_spline: alpha has shape
        #   (*batch_shape, num_qk-1, num_pieces) and is expanded to
        #   (*batch_shape, num_qk-1, num_pieces, 1); the spline parameters
        #   become (*batch_shape, num_qk-1, 1, num_pieces) and the quantile
        #   knot parameters (*batch_shape, num_qk-1, 1).
        F = self.F
        left_knot = self.qk_x_l
        lower, upper = self.qk_x, self.qk_x_plus

        if axis is not None:
            left_knot = F.expand_dims(left_knot, axis=axis)
            lower = F.expand_dims(lower, axis=axis)
            upper = F.expand_dims(upper, axis=axis)

        # Start from the tails: left tail below the first knot, right tail
        # everywhere else. Levels inside a knot interval are overwritten
        # in the loop below.
        quantile = F.where(
            F.broadcast_lesser(alpha, left_knot),
            self.quantile_tail(alpha, axis=axis, left_tail=True),
            self.quantile_tail(alpha, axis=axis, left_tail=False),
        )

        spline_val = self.quantile_spline(alpha, axis=axis)

        for idx in range(self.num_qk - 1):
            # Interval membership is half-open: lower <= alpha < upper.
            inside = F.broadcast_logical_and(
                F.broadcast_lesser_equal(ISQF._pick(F, lower, idx), alpha),
                F.broadcast_lesser(alpha, ISQF._pick(F, upper, idx)),
            )
            quantile = F.where(
                inside, ISQF._pick(F, spline_val, idx), quantile
            )

        return quantile

    def quantile_spline(
        self,
        alpha: Tensor,
        axis: Optional[int] = None,
    ) -> Tensor:
        r"""
        Evaluates every spline at the quantile levels contained in alpha.

        Parameters
        ----------
        alpha
            Input quantile levels
        axis
            Axis along which to expand
            For details of the alpha shape and axis,
            refer to the description in quantile_internal

        Returns
        -------
        Tensor
            Quantiles tensor
            with shape
            = (*batch_shape, num_qk-1) if axis = None
            = (1, *batch_shape, num_qk-1) if axis = 0
            = (*batch_shape, num_qk-1, num_pieces) if axis = -2
            (NOTE(review): for axis = 0 the leading dimension presumably
            follows alpha, i.e. num_samples - confirm)
        """
        F = self.F
        base = self.qk_y
        piece_start = self.sk_x
        piece_width = self.delta_sk_x
        piece_height = self.delta_sk_y

        if axis is not None:
            # The knot heights have one axis fewer than the spline
            # parameters, so for any non-zero axis they are expanded at the
            # end instead.
            base = F.expand_dims(base, axis=0 if axis == 0 else -1)
            piece_start = F.expand_dims(piece_start, axis=axis)
            piece_width = F.expand_dims(piece_width, axis=axis)
            piece_height = F.expand_dims(piece_height, axis=axis)

        # alpha gets two trailing axes for axis in (None, 0) and a single
        # one otherwise (see the shape notes in quantile_internal).
        if axis is None or axis == 0:
            alpha = F.expand_dims(alpha, axis=-1)
        alpha = F.expand_dims(alpha, axis=-1)

        # Fraction of each piece covered by alpha, clipped to [0, 1].
        covered = F.broadcast_div(
            F.broadcast_sub(alpha, piece_start), piece_width
        )
        covered = F.maximum(
            F.minimum(covered, F.ones_like(covered)),
            F.zeros_like(covered),
        )

        # Left knot height plus the covered share of every piece height.
        return F.broadcast_add(
            base,
            F.sum(
                F.broadcast_mul(covered, piece_height),
                axis=-1,
                keepdims=False,
            ),
        )

    def quantile_tail(
        self,
        alpha: Tensor,
        axis: Optional[int] = None,
        left_tail: bool = True,
    ) -> Tensor:
        r"""
        Evaluates one tail function at the quantile levels contained in
        alpha.

        Parameters
        ----------
        alpha
            Input quantile levels
        axis
            Axis along which to expand
            For details of the alpha shape and axis,
            refer to the description in quantile_internal
        left_tail
            If True, compute the quantile for the left tail
            Otherwise, compute the quantile for the right tail

        Returns
        -------
        Tensor
            Quantiles tensor, of the same shape as alpha
        """
        F = self.F

        if left_tail:
            slope, offset = self.tail_al, self.tail_bl
        else:
            slope, offset = self.tail_ar, self.tail_br
            # The right tail is a function of log(1 - alpha).
            alpha = 1 - alpha

        if axis is not None:
            slope = F.expand_dims(slope, axis=axis)
            offset = F.expand_dims(offset, axis=axis)

        return F.broadcast_add(F.broadcast_mul(slope, F.log(alpha)), offset)

    def cdf_spline(self, z: Tensor) -> Tensor:
        r"""
        For observations z and the spline on each knot interval
        [qk_x[k], qk_x[k+1]], computes the quantile level alpha_tilde,
        which is the inverse of the spline at z, clipped to the interval:

            alpha_tilde
            = q^{-1}(z) if the result lies within [qk_x[k], qk_x[k+1]]
            = qk_x[k] if it would fall below qk_x[k]
            = qk_x[k+1] if it would exceed qk_x[k+1]

        Parameters
        ----------
        z
            Observation, shape = (*batch_shape,)

        Returns
        -------
        alpha_tilde
            Corresponding quantile level, shape = (*batch_shape, num_qk-1)
        """
        F = self.F
        qk_x, qk_x_plus = self.qk_x, self.qk_x_plus
        qk_y, qk_y_plus = self.qk_y, self.qk_y_plus

        z_col = F.expand_dims(z, axis=-1)

        if self.num_pieces > 1:
            sk_x = self.sk_x
            delta_sk_x, delta_sk_y = self.delta_sk_x, self.delta_sk_y

            qk_y_col = F.expand_dims(qk_y, axis=-1)
            z_col_twice = F.expand_dims(z_col, axis=-1)

            # Quantile values at the spline knots themselves.
            knot_values = self.quantile_spline(sk_x, axis=-2)

            # Compute \sum_{s=0}^{s_0-1} \Delta sk_y[s],
            # where \Delta sk_y[s] = (sk_y[s+1]-sk_y[s])
            #
            # ``started`` marks the knots whose value is below z;
            # ``finished`` is the same mask shifted by one knot and padded
            # with a zero at the end.
            started = F.broadcast_lesser(knot_values, z_col_twice)
            finished = F.concat(
                F.slice_axis(started, axis=-1, begin=1, end=None),
                F.zeros_like(qk_y_col),
                dim=-1,
            )
            height_below = F.sum(
                F.broadcast_mul(finished, delta_sk_y),
                axis=-1,
                keepdims=False,
            )

            # Difference of the two masks; selects the piece s_0 in the
            # sums below.
            current = started - finished

            # Compute (sk_x[s_0+1]-sk_x[s_0])/(sk_y[s_0+1]-sk_y[s_0])
            inverse_slope = F.sum(
                (current * delta_sk_x) / delta_sk_y,
                axis=-1,
                keepdims=False,
            )

            # Compute sk_x_{s_0}
            current_start = F.sum(current * sk_x, axis=-1, keepdims=False)

            # Compute alpha_tilde
            alpha_tilde = (
                current_start
                + (F.broadcast_sub(z_col, qk_y) - height_below)
                * inverse_slope
            )
        else:
            # num_pieces=1, ISQF reduces to IQF
            alpha_tilde = qk_x + F.broadcast_sub(z_col, qk_y) / (
                qk_y_plus - qk_y
            ) * (qk_x_plus - qk_x)

        # Clip the level to its own knot interval.
        return F.broadcast_minimum(
            F.broadcast_maximum(alpha_tilde, qk_x), qk_x_plus
        )

    def cdf_tail(self, z: Tensor, left_tail: bool = True) -> Tensor:
        r"""
        Computes the quantile level alpha_tilde such that

            alpha_tilde
            = q^{-1}(z) if z is in the tail region
            = qk_x_l or qk_x_r if z is in the non-tail region

        Parameters
        ----------
        z
            Observation, shape = (*batch_shape,)
        left_tail
            If True, compute alpha_tilde for the left tail
            Otherwise, compute alpha_tilde for the right tail

        Returns
        -------
        alpha_tilde
            Corresponding quantile level, shape = (*batch_shape,)
        """
        F = self.F

        if left_tail:
            slope, offset = self.tail_al, self.tail_bl
            boundary = self.qk_x_l
        else:
            slope, offset = self.tail_ar, self.tail_br
            # The right tail works with the mirrored level 1 - alpha.
            boundary = 1 - self.qk_x_r

        # Invert the tail in log space and cap it at the boundary level.
        log_level = F.minimum((z - offset) / slope, F.log(boundary))
        level = F.exp(log_level)

        if left_tail:
            return level
        return 1 - level

    def crps_tail(self, z: Tensor, left_tail: bool = True) -> Tensor:
        r"""
        Compute CRPS in analytical form for left/right tails.

        Parameters
        ----------
        z
            Observation to evaluate. shape = (*batch_shape,)
        left_tail
            If True, compute CRPS for the left tail
            Otherwise, compute CRPS for the right tail

        Returns
        -------
        Tensor
            Tensor containing the CRPS, of the same shape as z
        """
        F = self.F
        alpha_tilde = self.cdf_tail(z, left_tail=left_tail)

        if left_tail:
            tail_a, tail_b = self.tail_al, self.tail_bl
            qk_x, qk_y = self.qk_x_l, self.qk_y_l
            log_qk_x = F.log(qk_x)

            term1 = (z - tail_b) * (qk_x**2 - 2 * qk_x + 2 * alpha_tilde)
            term2 = qk_x**2 * tail_a * (-log_qk_x + 0.5)

            # Extra contribution that only applies when z lies in the
            # left tail (z < qk_y).
            inside_tail = qk_x * tail_a * (log_qk_x - 1) + alpha_tilde * (
                -z + tail_b + tail_a
            )
            term2 = term2 + 2 * F.where(
                z < qk_y,
                inside_tail,
                F.zeros_like(qk_x),
            )
        else:
            tail_a, tail_b = self.tail_ar, self.tail_br
            qk_x, qk_y = self.qk_x_r, self.qk_y_r
            log_upper_mass = F.log(1 - qk_x)

            term1 = (z - tail_b) * (-1 - qk_x**2 + 2 * alpha_tilde)
            term2 = tail_a * (
                -0.5 * (qk_x + 1) ** 2
                + (qk_x**2 - 1) * log_upper_mass
                + 2 * alpha_tilde
            )

            # The last term depends on whether z lies in the right tail
            # (z > qk_y) or not.
            term2 = term2 + 2 * F.where(
                z > qk_y,
                (1 - alpha_tilde) * (z - tail_b),
                tail_a * (1 - qk_x) * log_upper_mass,
            )

        return term1 + term2

    def crps_spline(self, z: Tensor) -> Tensor:
        r"""
        Compute CRPS in analytical form for the spline

        Parameters
        ----------
        z
            Observation to evaluate. shape = (*batch_shape,)

        Returns
        -------
        Tensor
            Tensor containing the CRPS, of the same shape as z
        """
        F = self.F
        qk_x, qk_x_plus, qk_y = self.qk_x, self.qk_x_plus, self.qk_y
        sk_x, sk_x_plus = self.sk_x, self.sk_x_plus
        delta_sk_x, delta_sk_y = self.delta_sk_x, self.delta_sk_y

        z_col = F.expand_dims(z, axis=-1)
        qk_x_plus_col = F.expand_dims(qk_x_plus, axis=-1)

        alpha_tilde = self.cdf_spline(z)
        alpha_tilde_col = F.expand_dims(alpha_tilde, axis=-1)

        # alpha_tilde clipped to every individual spline piece.
        r = F.broadcast_minimum(
            F.broadcast_maximum(alpha_tilde_col, sk_x), sk_x_plus
        )

        coeff1 = (
            -2 / 3 * sk_x_plus**3
            + sk_x * sk_x_plus**2
            + sk_x_plus**2
            - (1 / 3) * sk_x**3
            - 2 * sk_x * sk_x_plus
            - r**2
            + 2 * sk_x * r
        )

        coeff2 = F.broadcast_add(
            -2 * F.broadcast_maximum(alpha_tilde_col, sk_x_plus)
            + sk_x_plus**2,
            2 * qk_x_plus_col - qk_x_plus_col**2,
        )

        # Per-interval contributions; the two sums run over the pieces.
        knot_term = (qk_x_plus**2 - qk_x**2) * F.broadcast_sub(z_col, qk_y)
        cross_term = (
            2
            * F.broadcast_sub(qk_x_plus, alpha_tilde)
            * F.broadcast_sub(qk_y, z_col)
        )
        slope_term = F.sum(
            (delta_sk_y / delta_sk_x) * coeff1, axis=-1, keepdims=False
        )
        height_term = F.sum(delta_sk_y * coeff2, axis=-1, keepdims=False)

        result = knot_term + cross_term + slope_term + height_term

        # Add up the contributions of all knot intervals.
        return F.sum(result, axis=-1, keepdims=False)

    def loss(self, z: Tensor) -> Tensor:
        """Return the loss for observation ``z``, which is its CRPS."""
        return self.crps(z)

    def crps(self, z: Tensor) -> Tensor:
        r"""
        Compute CRPS in analytical form

        Parameters
        ----------
        z
            Observation to evaluate. Shape = (*batch_shape,)

        Returns
        -------
        Tensor
            Tensor containing the CRPS, of the same shape as z
        """
        left = self.crps_tail(z, left_tail=True)
        right = self.crps_tail(z, left_tail=False)
        return left + right + self.crps_spline(z)

    def cdf(self, z: Tensor) -> Tensor:
        r"""
        Computes the quantile level alpha_tilde such that
        q(alpha_tilde) = z

        Parameters
        ----------
        z
            Tensor of shape = (*batch_shape,)

        Returns
        -------
        alpha_tilde
            Tensor of shape = (*batch_shape,)
        """
        F = self.F
        lower, upper = self.qk_y, self.qk_y_plus

        # Start from the tails: left tail below the first knot height,
        # right tail everywhere else. Observations inside a knot interval
        # are overwritten in the loop below.
        alpha_tilde = F.where(
            z < self.qk_y_l,
            self.cdf_tail(z, left_tail=True),
            self.cdf_tail(z, left_tail=False),
        )

        spline_alpha_tilde = self.cdf_spline(z)

        for idx in range(self.num_qk - 1):
            # Interval membership is half-open: lower <= z < upper.
            inside = F.broadcast_logical_and(
                ISQF._pick(F, lower, idx) <= z,
                z < ISQF._pick(F, upper, idx),
            )
            alpha_tilde = F.where(
                inside,
                ISQF._pick(F, spline_alpha_tilde, idx),
                alpha_tilde,
            )

        return alpha_tilde

    def sample(
        self, num_samples: Optional[int] = None, dtype=np.float32
    ) -> Tensor:
        r"""
        Function used to draw random samples

        Random quantile levels are drawn and pushed through ``quantile``.

        Parameters
        ----------
        num_samples
            number of samples
        dtype
            data type (accepted for interface compatibility; the body does
            not use it)

        Returns
        -------
        Tensor
            Tensor of shape (*batch_shape,) if num_samples = None
            else (num_samples, *batch_shape)
        """
        F = self.F

        # The random levels take the shape of ``beta_l``, i.e.
        # (*batch_shape,), or (num_samples, *batch_shape) when a sample
        # count is given.
        if num_samples is None:
            template = self.beta_l
        else:
            template = self.beta_l.expand_dims(axis=0).repeat(
                axis=0, repeats=num_samples
            )

        alpha = F.random.uniform_like(data=template)
        draws = self.quantile(alpha)

        if num_samples is None:
            # ``quantile`` evaluates with axis=0; drop that axis again.
            draws = F.squeeze(draws, axis=0)

        return draws

    @property
    def batch_shape(self) -> Tuple:
        """Shape of ``beta_l``."""
        return self.beta_l.shape

    @property
    def event_shape(self) -> Tuple:
        """Empty tuple."""
        return ()

    @property
    def event_dim(self) -> int:
        """Always 0."""
        return 0
class ISQFOutput(DistributionOutput):
    r"""
    DistributionOutput class for the Incremental (Spline) Quantile Function.

    Parameters
    ----------
    num_pieces
        number of spline pieces for each spline
        ISQF reduces to IQF when num_pieces = 1
    qk_x
        List containing the x-positions of quantile knots; it is stored
        sorted in increasing order
    tol
        tolerance for numerical safeguarding
    """

    distr_cls: type = ISQF

    @validated()
    def __init__(
        self, num_pieces: int, qk_x: List[float], tol: float = 1e-4
    ) -> None:
        # ISQF reduces to IQF when num_pieces = 1

        # ``super()`` already binds the instance, so it must not be passed
        # again as a positional argument.
        super().__init__()

        assert (
            isinstance(num_pieces, int) and num_pieces > 0
        ), "num_pieces should be an integer and greater than 0"

        self.num_pieces = num_pieces
        self.qk_x = sorted(qk_x)
        self.num_qk = len(qk_x)
        self.tol = tol

        # Size of the last axis of each distribution argument. The spline
        # parameters are flat here and reshaped to
        # (num_qk-1, num_pieces) in ``reshape_spline_args``.
        self.args_dim: Dict[str, int] = {
            "spline_knots": (self.num_qk - 1) * num_pieces,
            "spline_heights": (self.num_qk - 1) * num_pieces,
            "beta_l": 1,
            "beta_r": 1,
            "quantile_knots": self.num_qk,
        }

    @classmethod
    def domain_map(
        cls,
        F,
        *args: Tensor,
        tol: float = 1e-4,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        Map the raw arguments onto the valid parameter domain.

        The inputs of this function are specified by self.args_dim, in
        this order:

        knots, heights:
            parameterizing the x-/ y-positions of the spline knots,
            shape = (*batch_shape, (num_qk-1)*num_pieces)
        beta_l, beta_r:
            parameterizing the left/right tail, shape = (*batch_shape, 1)
        q:
            parameterizing the y-positions of the quantile knots,
            shape = (*batch_shape, num_qk)

        The spline parameters are returned unchanged, the tail parameters
        are made positive with their trailing axis removed, and the
        quantile knots are turned into increasing y-positions.

        Raises
        ------
        ValueError
            If ``args`` does not contain exactly five tensors.
        """
        try:
            spline_knots, spline_heights, beta_l, beta_r, quantile_knots = args
        except ValueError as err:
            raise ValueError(
                "Failed to unpack args of domain_map. Double check your input."
            ) from err

        # Add tol to prevent the y-distance of
        # two quantile knots from being too small
        #
        # Because in this case the spline knots could be squeezed together
        # and cause overflow in spline CRPS computation
        #
        # The first entry is kept as is; the remaining entries become
        # positive increments that the cumulative sum below turns into
        # increasing positions.
        first_knot = F.slice_axis(quantile_knots, axis=-1, begin=0, end=1)
        increments = (
            F.abs(F.slice_axis(quantile_knots, axis=-1, begin=1, end=None))
            + tol
        )
        qk_y = F.concat(first_knot, increments, dim=-1)

        # TODO: update to mxnet cumsum when it supports axis=-1
        qk_y = cumsum(F, qk_y)

        # Prevent overflow when we compute 1/beta
        beta_l = F.abs(beta_l.squeeze(axis=-1)) + tol
        beta_r = F.abs(beta_r.squeeze(axis=-1)) + tol

        return spline_knots, spline_heights, beta_l, beta_r, qk_y

    def distribution(
        self,
        distr_args,
        loc: Optional[Tensor] = None,
        scale: Optional[Tensor] = None,
    ) -> ISQF:
        """
        Build the distribution object from the mapped arguments.

        distr_args: distribution arguments
        loc: shift to the data mean
        scale: scale to the data

        When ``scale`` is None the plain distribution is returned and
        ``loc`` is not used. Otherwise the distribution is wrapped in a
        ``TransformedISQF`` with an affine transformation built from
        ``loc`` and ``scale``.
        """
        distr_args, qk_x = self.reshape_spline_args(distr_args, self.qk_x)

        # The base distribution is the same in both cases; build it once.
        distr = self.distr_cls(
            *distr_args, qk_x, self.num_qk, self.num_pieces, self.tol
        )

        if scale is None:
            return distr

        return TransformedISQF(
            distr, [AffineTransformation(loc=loc, scale=scale)]
        )

    def reshape_spline_args(self, distr_args, qk_x):
        """
        Auxiliary function that reshapes the spline knots and heights to
        (*batch_shape, num_qk-1, num_pieces) and turns the list qk_x into
        a tensor of shape (*batch_shape, num_qk).

        Returns the distribution arguments with the first two entries
        replaced by their reshaped versions, followed by the qk_x tensor.
        """
        spline_knots, spline_heights = distr_args[0], distr_args[1]
        beta_l = distr_args[2]
        qk_y = distr_args[4]

        F = getF(beta_l)

        # NOTE: workaround for the missing batch shape.
        # Convert qk_x from a list of len=num_qk to a Tensor of shape
        # (*batch_shape, num_qk).
        #
        # For example, if qk_x = [0.1, 0.5, 0.9], the result is a Tensor of
        # shape (*batch_shape, 3) with the last dimension being
        # [0.1, 0.5, 0.9].
        #
        # In PyTorch, it would be torch.tensor(qk_x).repeat(*batch_shape,1)
        batch_ones = F.expand_dims(F.ones_like(beta_l), axis=-1)
        qk_x_reshape = F.concat(
            *[batch_ones * qk_x[i] for i in range(self.num_qk)],
            dim=-1,
        )

        # NOTE: workaround for the missing batch shape.
        # knots and heights have shape (*batch_shape, (num_qk-1)*num_pieces)
        # and need shape (*batch_shape, (num_qk-1), num_pieces). A holder
        # tensor with the target shape is built from qk_y and handed to
        # reshape_like.
        shape_holder = F.repeat(
            F.expand_dims(
                F.slice_axis(qk_y, axis=-1, begin=0, end=-1), axis=-1
            ),
            repeats=self.num_pieces,
            axis=-1,
        )

        spline_knots_reshape = F.reshape_like(spline_knots, shape_holder)
        spline_heights_reshape = F.reshape_like(spline_heights, shape_holder)

        distr_args_reshape = (
            spline_knots_reshape,
            spline_heights_reshape,
        ) + distr_args[2:]

        return distr_args_reshape, qk_x_reshape

    @property
    def event_shape(self) -> Tuple:
        """Empty tuple."""
        return ()