from typing import Dict, List, Tuple

import mxnet as mx

from gluonts.core.component import validated
from gluonts.mx import Tensor

from .bijection import Bijection, InverseBijection
from .bijection_output import BijectionOutput
from .distribution import getF, softplus

class BoxCoxTransform(Bijection):
    r"""
    Implements Box-Cox transformation of a uni-variate random variable.
    The Box-Cox transformation of an observation :math:`z` is given by

    .. math::

        BoxCox(z; \lambda_1, \lambda_2) = \begin{cases}
            ((z + \lambda_2)^{\lambda_1} - 1) / \lambda_1, \quad & \text{if }
            \lambda_1 \neq 0, \\
            \log (z + \lambda_2), \quad & \text{otherwise.}
        \end{cases}

    Here, :math:`\lambda_1` and :math:`\lambda_2` are learnable parameters.
    Note that the domain of the transformation is not restricted.

    For numerical stability, instead of checking :math:`\lambda_1` is exactly
    zero, we use the condition

    .. math::

        |\lambda_1| < tol\_lambda\_1

    for a pre-specified tolerance `tol_lambda_1`.

    Inverse of the Box-Cox Transform is given by

    .. math::

        BoxCox^{-1}(y; \lambda_1, \lambda_2) = \begin{cases}
            (y \lambda_1 + 1)^{(1/\lambda_1)} - \lambda_2, \quad & \text{if }
            \lambda_1 \neq 0, \\
            \exp (y) - \lambda_2, \quad & \text{otherwise.}
        \end{cases}

    **Notes on numerical stability:**

    1.  For the forward transformation, :math:`\lambda_2` must always be
        chosen such that

        .. math::

            z + \lambda_2 > 0.

        To achieve this one needs to know a priori the lower bound on the
        observations. This is set in `BoxCoxTransformOutput`, since
        :math:`\lambda_2` is learnable.

    2.  Similarly for the inverse transformation to work reliably, a
        sufficient condition is

        .. math::

            y \lambda_1 + 1 \geq 0,

        where :math:`y` is the input to the inverse transformation.

        This cannot always be guaranteed especially when :math:`y` is a
        sample from a transformed distribution. Hence we always truncate
        :math:`y \lambda_1 + 1` at zero.

        An example showing why this could happen in our case:
        consider transforming observations from the unit interval (0, 1)
        with parameters

        .. math::

            \begin{align}
                \lambda_1 = &\ 1.1, \\
                \lambda_2 = &\ 0.
            \end{align}

        Then the range of the transformation is (-0.9090, 0.0).
        If Gaussian is fit to the transformed observations and a sample is
        drawn from it, then it is likely that the sample is outside this
        range, e.g., when the mean is close to -0.9. The subsequent inverse
        transformation of the sample is not a real number anymore.

        >>> y = -0.91
        >>> lambda_1 = 1.1
        >>> lambda_2 = 0.0
        >>> (y * lambda_1 + 1) ** (1 / lambda_1) - lambda_2
        (-0.0017979146510711471+0.0005279153735965289j)

    Parameters
    ----------
    lambda_1
        Exponent parameter :math:`\lambda_1` of the transformation.
    lambda_2
        Shift parameter :math:`\lambda_2` added to the observations.
    tol_lambda_1
        For numerical stability, treat `lambda_1` as zero if its absolute
        value is less than `tol_lambda_1`.
    F
        MXNet API module used for all tensor operations. If not given, it is
        obtained by calling `getF(lambda_1)`.
    """

    arg_names = ["box_cox.lambda_1", "box_cox.lambda_2"]

    @validated()
    def __init__(
        self,
        lambda_1: Tensor,
        lambda_2: Tensor,
        tol_lambda_1: float = 1e-2,
        F=None,
    ) -> None:
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.tol_lambda_1 = tol_lambda_1
        self.F = F if F else getF(lambda_1)

        # The element-wise power function is looked up as `power` on
        # `mx.nd` and as `pow` on any other API module.
        self._power = self.F.power if self.F == mx.nd else self.F.pow

    @property
    def args(self) -> List:
        r"""
        List: current values of the parameters
        """
        return [self.lambda_1, self.lambda_2]

    @property
    def event_dim(self) -> int:
        return 0

    @property
    def sign(self) -> Tensor:
        return 1.0

    def _lambda_1_is_nonzero(self) -> Tensor:
        r"""
        Element-wise indicator of :math:`|\lambda_1| \geq tol\_lambda\_1`,
        i.e. of the entries for which the power branch of the transformation
        (rather than the log/exp branch) is used.
        """
        return self.F.abs(self.lambda_1) >= self.tol_lambda_1

    def f(self, z: Tensor) -> Tensor:
        r"""
        Forward transformation of observations `z`

        Parameters
        ----------
        z
            observations

        Returns
        -------
        Tensor
            Transformed observations
        """
        F = self.F
        lambda_1 = self.lambda_1
        lambda_2 = self.lambda_2
        _power = self._power

        # Both branches are computed; `where` then picks, per element, the
        # power branch where |lambda_1| >= tol_lambda_1 and the log branch
        # otherwise. The condition is broadcast to the shape of `z` so that
        # it lines up with the branch values.
        return F.where(
            condition=self._lambda_1_is_nonzero().broadcast_like(z),
            x=(_power(z + lambda_2, lambda_1) - 1.0) / lambda_1,
            y=F.log(z + lambda_2),
            name="Box_Cox_trans",
        )

    def f_inv(self, y: Tensor) -> Tensor:
        r"""Inverse of the Box-Cox Transform

        Parameters
        ----------
        y
            Transformed observations

        Returns
        -------
        Tensor
            Observations

        """
        F = self.F
        lambda_1 = self.lambda_1
        lambda_2 = self.lambda_2
        _power = self._power

        # For numerical stability we truncate :math:`y * \lambda_1 + 1.0`
        # at zero.
        base = F.relu(y * lambda_1 + 1.0)

        # Same branch selection as in `f`: power branch where
        # |lambda_1| >= tol_lambda_1, exponential branch otherwise.
        return F.where(
            condition=self._lambda_1_is_nonzero().broadcast_like(y),
            x=_power(base, 1.0 / lambda_1) - lambda_2,
            y=F.exp(y) - lambda_2,
            name="Box_Cox_inverse_trans",
        )

    def log_abs_det_jac(self, z: Tensor, y: Tensor = None) -> Tensor:
        r"""
        Logarithm of the absolute value of the Jacobian determinant
        corresponding to the Box-Cox Transform is given by

        .. math::
            \log \frac{d}{dz} BoxCox(z; \lambda_1, \lambda_2) = \begin{cases}
                \log (z + \lambda_2) (\lambda_1 - 1), \quad & \text{if }
                \lambda_1 \neq 0, \\
                -\log (z + \lambda_2), \quad & \text{otherwise.}
            \end{cases}

        Note that the derivative of the transformation is always
        non-negative.

        Parameters
        ----------
        z
            observations
        y
            not used

        Returns
        -------
        Tensor

        """
        F = self.F
        lambda_1 = self.lambda_1
        lambda_2 = self.lambda_2

        return F.where(
            condition=self._lambda_1_is_nonzero(),
            x=F.log(z + lambda_2) * (lambda_1 - 1.0),
            y=-F.log(z + lambda_2),
            name="Box_Cox_trans_log_det_jac",
        )

class BoxCoxTransformOutput(BijectionOutput):
    r"""
    Output class for `BoxCoxTransform`.

    Parameters
    ----------
    lb_obs
        Value subtracted when constructing :math:`\lambda_2` in
        `domain_map`; described in `domain_map` as the lower bound of the
        observations.
    fix_lambda_2
        If True, :math:`\lambda_2` is set to the constant `-lb_obs`;
        otherwise it is computed from the corresponding input argument.
    """

    bij_cls: type = BoxCoxTransform
    # Each of the transformation's arguments has dimension one.
    args_dim: Dict[str, int] = {
        arg_name: 1 for arg_name in BoxCoxTransform.arg_names
    }

    @validated()
    def __init__(self, lb_obs: float = 0.0, fix_lambda_2: bool = True) -> None:
        super().__init__()
        self.lb_obs = lb_obs
        self.fix_lambda_2 = fix_lambda_2

    def domain_map(self, F, *args: Tensor) -> Tuple[Tensor, ...]:
        r"""
        Map the raw arguments to the parameters of the transformation.

        Parameters
        ----------
        F
            API module providing `ones_like`.
        args
            Exactly two tensors: the raw values for :math:`\lambda_1` and
            :math:`\lambda_2`.

        Returns
        -------
        Tuple[Tensor, ...]
            The pair `(lambda_1, lambda_2)`, each with its last axis
            squeezed out.
        """
        raw_lambda_1, raw_lambda_2 = args
        ones = F.ones_like(raw_lambda_2)

        if self.fix_lambda_2:
            lambda_2 = -self.lb_obs * ones
        else:
            # This makes sure that :math:`z + \lambda_2 > 0`, where
            # :math:`z > lb\_obs`.
            lambda_2 = softplus(F, raw_lambda_2) - self.lb_obs * ones

        # we squeeze the output since event_shape is ()
        return raw_lambda_1.squeeze(axis=-1), lambda_2.squeeze(axis=-1)

    @property
    def event_shape(self) -> Tuple:
        return ()

class InverseBoxCoxTransform(InverseBijection):
    """
    Implements the inverse of Box-Cox transformation as a bijection.

    The constructor arguments are forwarded unchanged to `BoxCoxTransform`,
    and the resulting bijection is wrapped by `InverseBijection`.
    """

    arg_names = ["box_cox.lambda_1", "box_cox.lambda_2"]

    @validated()
    def __init__(
        self,
        lambda_1: Tensor,
        lambda_2: Tensor,
        tol_lambda_1: float = 1e-2,
        F=None,
    ) -> None:
        forward_transform = BoxCoxTransform(
            lambda_1, lambda_2, tol_lambda_1, F
        )
        super().__init__(forward_transform)

    @property
    def event_dim(self) -> int:
        return 0

class InverseBoxCoxTransformOutput(BoxCoxTransformOutput):
    """
    Output class for `InverseBoxCoxTransform`; everything not overridden
    here is inherited from `BoxCoxTransformOutput`.
    """

    bij_cls: type = InverseBoxCoxTransform

    # Each of the transformation's arguments has dimension one.
    args_dim: Dict[str, int] = {
        arg_name: 1 for arg_name in InverseBoxCoxTransform.arg_names
    }

    @property
    def event_shape(self) -> Tuple:
        return ()