# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.distributions import AffineTransform, TransformedDistribution
from gluonts.core.component import validated
from .distribution_output import DistributionOutput


class ISQF(torch.distributions.Distribution):
    r"""
    Distribution class for the Incremental (Spline) Quantile Function in the
    paper ``Learning Quantile Functions without Quantile Crossing for
    Distribution-free Time Series Forecasting``
    by Park, Robinson, Aubet, Kan, Gasthaus, Wang.

    Parameters
    ----------
spline_knots, spline_heights
Tensor parametrizing the x-positions (y-positions) of the spline knots
Shape: (*batch_shape, (num_qk-1), num_pieces)
qk_x, qk_y
Tensor containing the increasing x-positions (y-positions) of the
quantile knots, Shape: (*batch_shape, num_qk)
beta_l, beta_r
Tensor containing the non-negative learnable parameter of the
left (right) tail, Shape: (*batch_shape,)
"""
def __init__(
self,
spline_knots: torch.Tensor,
spline_heights: torch.Tensor,
beta_l: torch.Tensor,
beta_r: torch.Tensor,
qk_y: torch.Tensor,
qk_x: torch.Tensor,
tol: float = 1e-4,
validate_args: bool = False,
) -> None:
self.num_qk, self.num_pieces = qk_y.shape[-1], spline_knots.shape[-1]
self.spline_knots, self.spline_heights = spline_knots, spline_heights
self.beta_l, self.beta_r = beta_l, beta_r
self.qk_y_all = qk_y
self.tol = tol
super().__init__(
batch_shape=self.batch_shape, validate_args=validate_args
)
# Get quantile knots (qk) parameters
(
self.qk_x,
self.qk_x_plus,
self.qk_x_l,
self.qk_x_r,
) = ISQF.parameterize_qk(qk_x)
(
self.qk_y,
self.qk_y_plus,
self.qk_y_l,
self.qk_y_r,
) = ISQF.parameterize_qk(qk_y)
# Get spline knots (sk) parameters
self.sk_y, self.delta_sk_y = ISQF.parameterize_spline(
self.spline_heights,
self.qk_y,
self.qk_y_plus,
self.tol,
)
self.sk_x, self.delta_sk_x = ISQF.parameterize_spline(
self.spline_knots,
self.qk_x,
self.qk_x_plus,
self.tol,
)
if self.num_pieces > 1:
self.sk_x_plus = torch.cat(
[self.sk_x[..., 1:], self.qk_x_plus.unsqueeze(dim=-1)], dim=-1
)
else:
self.sk_x_plus = self.qk_x_plus.unsqueeze(dim=-1)
        # Get tail parameters
self.tail_al, self.tail_bl = ISQF.parameterize_tail(
self.beta_l, self.qk_x_l, self.qk_y_l
)
self.tail_ar, self.tail_br = ISQF.parameterize_tail(
-self.beta_r, 1 - self.qk_x_r, self.qk_y_r
)

    @staticmethod
def parameterize_qk(
quantile_knots: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
r"""
Function to parameterize the x or y positions
of the num_qk quantile knots
Parameters
----------
quantile_knots
x or y positions of the quantile knots
shape: (*batch_shape, num_qk)
Returns
-------
qk
x or y positions of the quantile knots (qk),
with index=1, ..., num_qk-1,
shape: (*batch_shape, num_qk-1)
qk_plus
x or y positions of the quantile knots (qk),
with index=2, ..., num_qk,
shape: (*batch_shape, num_qk-1)
qk_l
x or y positions of the left-most quantile knot (qk),
shape: (*batch_shape)
qk_r
x or y positions of the right-most quantile knot (qk),
shape: (*batch_shape)
"""
qk, qk_plus = quantile_knots[..., :-1], quantile_knots[..., 1:]
qk_l, qk_r = quantile_knots[..., 0], quantile_knots[..., -1]
return qk, qk_plus, qk_l, qk_r
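        # Illustrative example (not in the original source): for
        # quantile_knots = [0.05, 0.5, 0.95] (num_qk=3) the slicing above
        # gives qk = [0.05, 0.5], qk_plus = [0.5, 0.95], qk_l = 0.05 and
        # qk_r = 0.95, i.e. the left and right endpoints of each
        # quantile-knot interval plus the outermost knots.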

    @staticmethod
def parameterize_spline(
spline_knots: torch.Tensor,
qk: torch.Tensor,
qk_plus: torch.Tensor,
tol: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Function to parameterize the x or y positions of the spline knots
Parameters
----------
spline_knots
variable that parameterizes the spline knot positions
qk
x or y positions of the quantile knots (qk),
with index=1, ..., num_qk-1,
shape: (*batch_shape, num_qk-1)
qk_plus
x or y positions of the quantile knots (qk),
with index=2, ..., num_qk,
shape: (*batch_shape, num_qk-1)
tol
tolerance hyperparameter for numerical stability
Returns
-------
sk
x or y positions of the spline knots (sk),
shape: (*batch_shape, num_qk-1, num_pieces)
delta_sk
difference of x or y positions of the spline knots (sk),
shape: (*batch_shape, num_qk-1, num_pieces)
"""
        # The spacing between spline knots is parameterized by a softmax,
        # so the entries lie in [0, 1] and sum to 1.
        # We add tol to prevent overflow when computing 1/spacing in the
        # spline CRPS; after adding tol, the spacing is normalized by
        # (1 + num_pieces * tol) to keep the sum-to-1 property
num_pieces = spline_knots.shape[-1]
delta_x = (F.softmax(spline_knots, dim=-1) + tol) / (
1 + num_pieces * tol
)
zero_tensor = torch.zeros_like(
delta_x[..., 0:1]
) # 0:1 for keeping dimension
x = torch.cat(
[zero_tensor, torch.cumsum(delta_x, dim=-1)[..., :-1]], dim=-1
)
qk, qk_plus = qk.unsqueeze(dim=-1), qk_plus.unsqueeze(dim=-1)
sk = x * (qk_plus - qk) + qk
delta_sk = delta_x * (qk_plus - qk)
return sk, delta_sk
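        # Worked example (illustrative): with num_pieces=2 and raw
        # spline_knots = [0, 0], softmax gives delta_x ~ [0.5, 0.5]
        # (up to the tol correction), hence x = [0, 0.5]; the spline knots
        # sk then sit at qk and at the midpoint of [qk, qk_plus], and
        # delta_sk rescales delta_x by the interval length qk_plus - qk.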

    @staticmethod
    def parameterize_tail(
        beta: torch.Tensor, qk_x: torch.Tensor, qk_y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Function to parameterize the tail parameters.

        Note that the exponential tails are given by
        q(alpha) = a_l * log(alpha) + b_l       for the left tail,
        q(alpha) = a_r * log(1 - alpha) + b_r   for the right tail,
        where
        a_l = 1/beta_l, b_l = -a_l * log(qk_x_l) + q(qk_x_l),
        a_r = 1/beta_r, b_r = a_r * log(1 - qk_x_r) + q(qk_x_r)

        Parameters
        ----------
beta
parameterizes the left or right tail, shape: (*batch_shape,)
qk_x
left- or right-most x-positions of the quantile knots,
shape: (*batch_shape,)
qk_y
left- or right-most y-positions of the quantile knots,
shape: (*batch_shape,)
Returns
-------
tail_a
a_l or a_r as described above
tail_b
b_l or b_r as described above
"""
tail_a = 1 / beta
tail_b = -tail_a * torch.log(qk_x) + qk_y
return tail_a, tail_b
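        # Sanity check (illustrative): by construction
        # tail_a * log(qk_x) + tail_b = qk_y, so the left tail
        # a_l*log(alpha) + b_l matches the spline at alpha = qk_x_l;
        # for the right tail the caller passes (-beta_r, 1 - qk_x_r, qk_y_r),
        # making a_r*log(1 - alpha) + b_r match at alpha = qk_x_r.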

    def quantile(self, alpha: torch.Tensor) -> torch.Tensor:
        return self.quantile_internal(alpha, dim=0)

    def quantile_internal(
        self, alpha: torch.Tensor, dim: Optional[int] = None
    ) -> torch.Tensor:
        r"""
        Evaluates the quantile function at the quantile levels alpha.

        Parameters
        ----------
        alpha
            Tensor of shape (*batch_shape,) if dim=None, otherwise
            containing an additional axis at the specified position
        dim
            Index of the axis containing the different quantile levels
            to be computed; see the comments below for details

        Returns
        -------
        Tensor
            Quantiles tensor, of the same shape as alpha
        """
qk_x, qk_x_l, qk_x_plus = self.qk_x, self.qk_x_l, self.qk_x_plus
# The following describes the parameters reshaping in
# quantile_internal, quantile_spline and quantile_tail
# tail parameters: tail_al, tail_ar, tail_bl, tail_br,
# shape = (*batch_shape,)
# spline parameters: sk_x, sk_x_plus, sk_y, sk_y_plus,
# shape = (*batch_shape, num_qk-1, num_pieces)
# quantile knots parameters: qk_x, qk_x_plus, qk_y, qk_y_plus,
# shape = (*batch_shape, num_qk-1)
# dim=None - passed at inference when num_samples is None
# shape of input_alpha = (*batch_shape,), will be expanded to
# (*batch_shape, 1, 1) to perform operation
# The shapes of parameters are as described above,
# no reshaping is needed
# dim=0 - passed at inference when num_samples is not None
# shape of input_alpha = (num_samples, *batch_shape)
# it will be expanded to
# (num_samples, *batch_shape, 1, 1) to perform operation
#
# The shapes of tail parameters
# should be (num_samples, *batch_shape)
#
# The shapes of spline parameters
# should be (num_samples, *batch_shape, num_qk-1, num_pieces)
#
# The shapes of quantile knots parameters
# should be (num_samples, *batch_shape, num_qk-1)
#
# We expand at dim=0 for all of them
# dim=-2 - passed at training when we evaluate quantiles at
# spline knots in order to compute alpha_tilde
#
# This is only for the quantile_spline function
# shape of input_alpha = (*batch_shape, num_qk-1, num_pieces)
# it will be expanded to
# (*batch_shape, num_qk-1, num_pieces, 1) to perform operation
#
# The shapes of spline and quantile knots parameters should be
# (*batch_shape, num_qk-1, 1, num_pieces)
# and (*batch_shape, num_qk-1, 1), respectively
#
# We expand at dim=-2 and dim=-1 for
# spline and quantile knots parameters, respectively
if dim is not None:
qk_x_l = qk_x_l.unsqueeze(dim=dim)
qk_x = qk_x.unsqueeze(dim=dim)
qk_x_plus = qk_x_plus.unsqueeze(dim=dim)
quantile = torch.where(
alpha < qk_x_l,
self.quantile_tail(alpha, dim=dim, left_tail=True),
self.quantile_tail(alpha, dim=dim, left_tail=False),
)
spline_val = self.quantile_spline(alpha, dim=dim)
for spline_idx in range(self.num_qk - 1):
is_in_between = torch.logical_and(
qk_x[..., spline_idx] <= alpha,
alpha < qk_x_plus[..., spline_idx],
)
quantile = torch.where(
is_in_between,
spline_val[..., spline_idx],
quantile,
)
return quantile
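        # Illustrative note: the torch.where cascade first fills every
        # position with a tail value (left tail where alpha < qk_x_l,
        # right tail elsewhere), then overwrites the positions whose alpha
        # lies in [qk_x[k], qk_x[k+1]) with the corresponding spline value.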

    def quantile_spline(
self,
alpha: torch.Tensor,
dim: Optional[int] = None,
) -> torch.Tensor:
# Refer to the description in quantile_internal
qk_y = self.qk_y
sk_x, delta_sk_x, delta_sk_y = (
self.sk_x,
self.delta_sk_x,
self.delta_sk_y,
)
if dim is not None:
qk_y = qk_y.unsqueeze(dim=0 if dim == 0 else -1)
sk_x = sk_x.unsqueeze(dim=dim)
delta_sk_x = delta_sk_x.unsqueeze(dim=dim)
delta_sk_y = delta_sk_y.unsqueeze(dim=dim)
if dim is None or dim == 0:
alpha = alpha.unsqueeze(dim=-1)
alpha = alpha.unsqueeze(dim=-1)
spline_val = (alpha - sk_x) / delta_sk_x
spline_val = torch.maximum(
torch.minimum(spline_val, torch.ones_like(spline_val)),
torch.zeros_like(spline_val),
)
return qk_y + torch.sum(spline_val * delta_sk_y, dim=-1)
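        # Sketch of the math (illustrative): within each knot interval,
        # spline_val = clamp((alpha - sk_x) / delta_sk_x, 0, 1), so the
        # returned q(alpha) = qk_y + sum_s spline_val[s] * delta_sk_y[s]
        # accumulates the full height of every spline piece left of alpha
        # plus a fractional height for the piece containing alpha.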

    def quantile_tail(
self,
alpha: torch.Tensor,
dim: Optional[int] = None,
left_tail: bool = True,
) -> torch.Tensor:
# Refer to the description in quantile_internal
if left_tail:
tail_a, tail_b = self.tail_al, self.tail_bl
else:
tail_a, tail_b = self.tail_ar, self.tail_br
alpha = 1 - alpha
if dim is not None:
tail_a, tail_b = tail_a.unsqueeze(dim=dim), tail_b.unsqueeze(
dim=dim
)
return tail_a * torch.log(alpha) + tail_b

    def cdf_spline(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        For observations z and splines defined on [qk_x[k], qk_x[k+1]],
        computes the quantile level alpha_tilde such that
        alpha_tilde
        = q^{-1}(z) if qk_y[k] <= z <= qk_y[k+1]
        = qk_x[k]   if z < qk_y[k]
        = qk_x[k+1] if z > qk_y[k+1]

        Parameters
        ----------
z
Observation, shape = (*batch_shape,)
Returns
-------
alpha_tilde
Corresponding quantile level, shape = (*batch_shape, num_qk-1)
"""
qk_y, qk_y_plus = self.qk_y, self.qk_y_plus
qk_x, qk_x_plus = self.qk_x, self.qk_x_plus
sk_x, delta_sk_x, delta_sk_y = (
self.sk_x,
self.delta_sk_x,
self.delta_sk_y,
)
z_expand = z.unsqueeze(dim=-1)
if self.num_pieces > 1:
qk_y_expand = qk_y.unsqueeze(dim=-1)
z_expand_twice = z_expand.unsqueeze(dim=-1)
knots_eval = self.quantile_spline(sk_x, dim=-2)
# Compute \sum_{s=0}^{s_0-1} \Delta sk_y[s],
# where \Delta sk_y[s] = (sk_y[s+1]-sk_y[s])
mask_sum_s0 = torch.lt(knots_eval, z_expand_twice)
mask_sum_s0_minus = torch.cat(
[
mask_sum_s0[..., 1:],
torch.zeros_like(qk_y_expand, dtype=torch.bool),
],
dim=-1,
)
sum_delta_sk_y = torch.sum(mask_sum_s0_minus * delta_sk_y, dim=-1)
mask_s0_only = torch.logical_and(
mask_sum_s0, torch.logical_not(mask_sum_s0_minus)
)
# Compute (sk_x[s_0+1]-sk_x[s_0])/(sk_y[s_0+1]-sk_y[s_0])
frac_s0 = torch.sum(
(mask_s0_only * delta_sk_x) / delta_sk_y, dim=-1
)
# Compute sk_x_{s_0}
sk_x_s0 = torch.sum(mask_s0_only * sk_x, dim=-1)
# Compute alpha_tilde
alpha_tilde = (
sk_x_s0 + (z_expand - qk_y - sum_delta_sk_y) * frac_s0
)
else:
# num_pieces=1, ISQF reduces to IQF
alpha_tilde = qk_x + (z_expand - qk_y) / (qk_y_plus - qk_y) * (
qk_x_plus - qk_x
)
alpha_tilde = torch.minimum(
torch.maximum(alpha_tilde, qk_x), qk_x_plus
)
return alpha_tilde
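        # Illustrative note: when num_pieces=1 the branch above reduces
        # ISQF to IQF, i.e. alpha_tilde is the inverse of linear
        # interpolation between the quantile knots; the final min/max
        # clamps alpha_tilde into [qk_x, qk_x_plus].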

    def cdf_tail(
        self, z: torch.Tensor, left_tail: bool = True
    ) -> torch.Tensor:
        r"""
        Computes the quantile level alpha_tilde such that
        alpha_tilde
        = q^{-1}(z)         if z is in the tail region
        = qk_x_l or qk_x_r  if z is in the non-tail region

        Parameters
        ----------
z
Observation, shape = (*batch_shape,)
left_tail
If True, compute alpha_tilde for the left tail
Otherwise, compute alpha_tilde for the right tail
Returns
-------
alpha_tilde
Corresponding quantile level, shape = (*batch_shape,)
"""
if left_tail:
tail_a, tail_b, qk_x = self.tail_al, self.tail_bl, self.qk_x_l
else:
tail_a, tail_b, qk_x = self.tail_ar, self.tail_br, 1 - self.qk_x_r
log_alpha_tilde = torch.minimum((z - tail_b) / tail_a, torch.log(qk_x))
alpha_tilde = torch.exp(log_alpha_tilde)
return alpha_tilde if left_tail else 1 - alpha_tilde
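        # Derivation note (illustrative): inverting z = a*log(u) + b gives
        # u = exp((z - b)/a); taking the minimum with log(qk_x) clamps u
        # to the tail boundary whenever z lies outside the tail region,
        # and the right tail maps back via alpha_tilde = 1 - u.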

    def crps_tail(
        self, z: torch.Tensor, left_tail: bool = True
    ) -> torch.Tensor:
        r"""
        Compute CRPS in analytical form for the left/right tail.

        Parameters
        ----------
z
Observation to evaluate. shape = (*batch_shape,)
left_tail
If True, compute CRPS for the left tail
Otherwise, compute CRPS for the right tail
Returns
-------
Tensor
Tensor containing the CRPS, of the same shape as z
"""
alpha_tilde = self.cdf_tail(z, left_tail=left_tail)
if left_tail:
tail_a, tail_b, qk_x, qk_y = (
self.tail_al,
self.tail_bl,
self.qk_x_l,
self.qk_y_l,
)
term1 = (z - tail_b) * (qk_x**2 - 2 * qk_x + 2 * alpha_tilde)
term2 = qk_x**2 * tail_a * (-torch.log(qk_x) + 0.5)
term2 = term2 + 2 * torch.where(
z < qk_y,
qk_x * tail_a * (torch.log(qk_x) - 1)
+ alpha_tilde * (-z + tail_b + tail_a),
torch.zeros_like(qk_x),
)
else:
tail_a, tail_b, qk_x, qk_y = (
self.tail_ar,
self.tail_br,
self.qk_x_r,
self.qk_y_r,
)
term1 = (z - tail_b) * (-1 - qk_x**2 + 2 * alpha_tilde)
term2 = tail_a * (
-0.5 * (qk_x + 1) ** 2
+ (qk_x**2 - 1) * torch.log(1 - qk_x)
+ 2 * alpha_tilde
)
term2 = term2 + 2 * torch.where(
z > qk_y,
(1 - alpha_tilde) * (z - tail_b),
tail_a * (1 - qk_x) * torch.log(1 - qk_x),
)
return term1 + term2

    def crps_spline(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        Compute CRPS in analytical form for the spline.

        Parameters
        ----------
z
Observation to evaluate. shape = (*batch_shape,)
Returns
-------
Tensor
Tensor containing the CRPS, of the same shape as z
"""
qk_x, qk_x_plus, qk_y = self.qk_x, self.qk_x_plus, self.qk_y
sk_x, sk_x_plus = self.sk_x, self.sk_x_plus
delta_sk_x, delta_sk_y = self.delta_sk_x, self.delta_sk_y
z_expand = z.unsqueeze(dim=-1)
qk_x_plus_expand = qk_x_plus.unsqueeze(dim=-1)
alpha_tilde = self.cdf_spline(z)
alpha_tilde_expand = alpha_tilde.unsqueeze(dim=-1)
r = torch.minimum(torch.maximum(alpha_tilde_expand, sk_x), sk_x_plus)
coeff1 = (
-2 / 3 * sk_x_plus**3
+ sk_x * sk_x_plus**2
+ sk_x_plus**2
- (1 / 3) * sk_x**3
- 2 * sk_x * sk_x_plus
- r**2
+ 2 * sk_x * r
)
coeff2 = (
-2 * torch.maximum(alpha_tilde_expand, sk_x_plus)
+ sk_x_plus**2
+ 2 * qk_x_plus_expand
- qk_x_plus_expand**2
)
result = (
(qk_x_plus**2 - qk_x**2) * (z_expand - qk_y)
+ 2 * (qk_x_plus - alpha_tilde) * (qk_y - z_expand)
+ torch.sum((delta_sk_y / delta_sk_x) * coeff1, dim=-1)
+ torch.sum(delta_sk_y * coeff2, dim=-1)
)
return torch.sum(result, dim=-1)

    def loss(self, z: torch.Tensor) -> torch.Tensor:
        return self.crps(z)

    def crps(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        Compute CRPS in analytical form.

        Parameters
        ----------
z
Observation to evaluate. Shape = (*batch_shape,)
Returns
-------
Tensor
Tensor containing the CRPS, of the same shape as z
"""
crps_lt = self.crps_tail(z, left_tail=True)
crps_rt = self.crps_tail(z, left_tail=False)
return crps_lt + crps_rt + self.crps_spline(z)

    def cdf(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        Computes the quantile level alpha_tilde such that
        q(alpha_tilde) = z.

        Parameters
        ----------
z
Tensor of shape = (*batch_shape,)
Returns
-------
alpha_tilde
Tensor of shape = (*batch_shape,)
"""
qk_y, qk_y_l, qk_y_plus = self.qk_y, self.qk_y_l, self.qk_y_plus
alpha_tilde = torch.where(
z < qk_y_l,
self.cdf_tail(z, left_tail=True),
self.cdf_tail(z, left_tail=False),
)
spline_alpha_tilde = self.cdf_spline(z)
for spline_idx in range(self.num_qk - 1):
is_in_between = torch.logical_and(
qk_y[..., spline_idx] <= z, z < qk_y_plus[..., spline_idx]
)
alpha_tilde = torch.where(
is_in_between, spline_alpha_tilde[..., spline_idx], alpha_tilde
)
return alpha_tilde

    def rsample(
        self, sample_shape: torch.Size = torch.Size()
    ) -> torch.Tensor:
        r"""
        Draws random samples from the distribution by inverse transform
        sampling.

        Parameters
        ----------
        sample_shape
            Shape of the samples to draw

        Returns
        -------
        Tensor
            Tensor of shape (*batch_shape,) if sample_shape is empty,
            else (*sample_shape, *batch_shape)
        """
        # If sample_shape=() then alpha should have the same shape as
        # beta_l, i.e., (*batch_shape,);
        # otherwise alpha should have shape (*sample_shape, *batch_shape)
target_shape = (
self.beta_l.shape
if sample_shape == torch.Size()
else torch.Size(sample_shape) + self.beta_l.shape
)
alpha = torch.rand(
target_shape,
dtype=self.beta_l.dtype,
device=self.beta_l.device,
layout=self.beta_l.layout,
)
sample = self.quantile(alpha)
if sample_shape == torch.Size():
sample = sample.squeeze(dim=0)
return sample
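        # Illustrative example: with batch_shape = (B,), calling
        # rsample(torch.Size((S,))) draws alpha ~ Uniform(0, 1) of shape
        # (S, B) and maps it through the quantile function (inverse
        # transform sampling), so the samples stay differentiable with
        # respect to the distribution parameters.
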
@property
def batch_shape(self) -> torch.Size:
return self.beta_l.shape


class ISQFOutput(DistributionOutput):
    r"""
    DistributionOutput class for the Incremental (Spline) Quantile Function.

    Parameters
    ----------
num_pieces
number of spline pieces for each spline
ISQF reduces to IQF when num_pieces = 1
qk_x
List containing the x-positions of quantile knots
tol
tolerance for numerical safeguarding
"""
distr_cls: type = ISQF
@validated()
def __init__(
self, num_pieces: int, qk_x: List[float], tol: float = 1e-4
) -> None:
# ISQF reduces to IQF when num_pieces = 1
super().__init__()
assert (
isinstance(num_pieces, int) and num_pieces > 0
), "num_pieces should be an integer and greater than 0"
self.num_pieces = num_pieces
self.qk_x = sorted(qk_x)
self.num_qk = len(qk_x)
self.tol = tol
self.args_dim: Dict[str, int] = {
"spline_knots": (self.num_qk - 1) * num_pieces,
"spline_heights": (self.num_qk - 1) * num_pieces,
"beta_l": 1,
"beta_r": 1,
"quantile_knots": self.num_qk,
}

    @classmethod
def domain_map( # type: ignore
cls,
spline_knots: torch.Tensor,
spline_heights: torch.Tensor,
beta_l: torch.Tensor,
beta_r: torch.Tensor,
quantile_knots: torch.Tensor,
tol: float = 1e-4,
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
"""
Domain map function The inputs of this function are specified by
self.args_dim.
spline_knots, spline_heights:
parameterizing the x-/ y-positions of the spline knots,
shape = (*batch_shape, (num_qk-1)*num_pieces)
beta_l, beta_r:
parameterizing the left/right tail, shape = (*batch_shape, 1)
quantile_knots:
parameterizing the y-positions of the quantile knots,
shape = (*batch_shape, num_qk)
"""
# Add tol to prevent the y-distance of
# two quantile knots from being too small
#
# Because in this case the spline knots could be squeezed together
# and cause overflow in spline CRPS computation
qk_y = torch.cat(
[
quantile_knots[..., 0:1],
torch.abs(quantile_knots[..., 1:]) + tol,
],
dim=-1,
)
qk_y = torch.cumsum(qk_y, dim=-1)
# Prevent overflow when we compute 1/beta
beta_l = torch.abs(beta_l.squeeze(-1)) + tol
beta_r = torch.abs(beta_r.squeeze(-1)) + tol
return spline_knots, spline_heights, beta_l, beta_r, qk_y
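        # Illustrative note: the cumulative sum over
        # [quantile_knots[..., 0:1], abs(quantile_knots[..., 1:]) + tol]
        # makes qk_y strictly increasing with consecutive gaps of at least
        # tol, which guarantees a non-crossing (monotone) quantile function.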

    def distribution(
self,
distr_args,
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> ISQF:
"""
function outputing the distribution class
distr_args: distribution arguments
loc: shift to the data mean
scale: scale to the data
"""
distr_args, qk_x = self.reshape_spline_args(distr_args, self.qk_x)
distr = self.distr_cls(*distr_args, qk_x, self.tol)
if scale is None:
return distr
else:
return TransformedISQF(
distr, [AffineTransform(loc=loc, scale=scale)]
)

    def reshape_spline_args(self, distr_args, qk_x: List[float]):
        """
        Auxiliary function reshaping the knots and heights to
        (*batch_shape, num_qk-1, num_pieces) and qk_x to
        (*batch_shape, num_qk).
        """
spline_knots, spline_heights = distr_args[0], distr_args[1]
batch_shape = spline_knots.shape[:-1]
num_qk, num_pieces = self.num_qk, self.num_pieces
# repeat qk_x from (num_qk,) to (*batch_shape, num_qk)
qk_x_repeat = torch.tensor(
qk_x, dtype=spline_knots.dtype, device=spline_knots.device
).repeat(*batch_shape, 1)
# knots and heights have shape (*batch_shape, (num_qk-1)*num_pieces)
# reshape them to (*batch_shape, (num_qk-1), num_pieces)
spline_knots_reshape = spline_knots.reshape(
*batch_shape, (num_qk - 1), num_pieces
)
spline_heights_reshape = spline_heights.reshape(
*batch_shape, (num_qk - 1), num_pieces
)
distr_args_reshape = (
spline_knots_reshape,
spline_heights_reshape,
*distr_args[2:],
)
return distr_args_reshape, qk_x_repeat
@property
def event_shape(self) -> Tuple:
return ()
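

# A minimal usage sketch (illustrative, not part of the library: random
# tensors stand in for the outputs of a network head sized per args_dim).
# It builds an ISQF distribution end-to-end, evaluates the CRPS loss and
# draws samples.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_shape = (2,)
    output = ISQFOutput(num_pieces=4, qk_x=[0.05, 0.25, 0.5, 0.75, 0.95])
    # One raw tensor per entry of args_dim, as a projection layer would emit
    raw = {
        name: torch.randn(*batch_shape, size)
        for name, size in output.args_dim.items()
    }
    distr_args = output.domain_map(
        raw["spline_knots"],
        raw["spline_heights"],
        raw["beta_l"],
        raw["beta_r"],
        raw["quantile_knots"],
    )
    distr = output.distribution(distr_args)
    z = torch.randn(*batch_shape)
    print("CRPS:", distr.loss(z))  # shape (2,)
    print("sample shape:", distr.rsample(torch.Size((3,))).shape)  # (3, 2)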