# Source code for gluonts.mx.linalg_util

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Optional

import mxnet as mx
import numpy as np

from gluonts.core.component import DType
from gluonts.mx import Tensor

def batch_diagonal(
    F,
    matrix: "Tensor",
    num_data_points: Optional[int] = None,
    float_type=np.float32,
) -> "Tensor":
    """
    This function extracts the diagonal of a batch matrix.

    Parameters
    ----------
    F
        A module that can either refer to the Symbol API or the NDArray
        API in MXNet.
    matrix
        matrix of shape (batch_size, num_data_points, num_data_points).
    num_data_points
        Number of rows in the kernel_matrix.
    float_type
        Determines whether to use single or double precision.

    Returns
    -------
    Tensor
        Diagonals of kernel_matrix of shape (batch_size, num_data_points, 1).

    """
    # Zero out the off-diagonal entries with an identity mask, then sum each
    # row by multiplying with a column vector of ones: the product has shape
    # (batch_size, num_data_points, 1) and holds the diagonal entries.
    return F.linalg.gemm2(
        matrix * F.eye(num_data_points, dtype=float_type),
        F.ones_like(F.slice_axis(matrix, axis=2, begin=0, end=1)),
    )

def lower_triangular_ones(F, d: int, offset: int = 0) -> "Tensor":
    """
    Constructs a lower triangular matrix consisting of ones.

    Parameters
    ----------
    F
        A module that can either refer to the Symbol API or the NDArray
        API in MXNet.
    d
        Dimension of the output tensor, whose shape will be (d, d).
    offset
        Indicates how many diagonals to set to zero in the lower triangular
        part. By default, offset = 0, so the output matrix contains also the
        main diagonal. For example, if offset = 1 then the output will be a
        strictly lower triangular matrix (i.e. the main diagonal will be zero).

    Returns
    -------
    Tensor
        Tensor of shape (d, d) consisting of ones in the lower triangular
        part (starting `offset` diagonals below the main one), and zeros
        elsewhere.

    """
    mask = F.zeros_like(F.eye(d))
    # Accumulate one shifted identity per sub-diagonal: F.eye(d, d, -k) has
    # ones on the k-th diagonal below the main one.
    for k in range(offset, d):
        mask = mask + F.eye(d, d, -k)
    return mask

# noinspection PyMethodOverriding,PyPep8Naming
def jitter_cholesky_eig(
    F,
    matrix: "Tensor",
    num_data_points: Optional[int] = None,
    float_type: "DType" = np.float64,
    diag_weight: float = 1e-6,
) -> "Tensor":
    """
    This function applies the jitter method using the eigenvalue
    decomposition. The eigenvalues are bound below by the jitter, which is
    proportional to the mean of the diagonal elements.

    Parameters
    ----------
    F
        A module that can either refer to the Symbol API or the NDArray
        API in MXNet.
    matrix
        Matrix of shape (batch_size, num_data_points, num_data_points).
    num_data_points
        Number of rows in the kernel_matrix.
    float_type
        Determines whether to use single or double precision.
    diag_weight
        Multiple of the mean diagonal entry used as the eigenvalue floor.

    Returns
    -------
    Tensor
        Returns the approximate lower triangular Cholesky factor L
        of shape (batch_size, num_data_points, num_data_points)
    """
    diag = batch_diagonal(
        F, matrix, num_data_points, float_type
    )  # shape (batch_size, num_data_points, 1)
    diag_mean = diag.mean(axis=1).expand_dims(
        axis=2
    )  # shape (batch_size, 1, 1)
    U, Lambda = F.linalg.syevd(matrix)
    # Constant jitter per batch entry, proportional to the mean diagonal.
    jitter = F.broadcast_mul(diag_mean, F.ones_like(diag)) * diag_weight
    # K = U^T Lambda U, where the rows of U are the eigenvectors of K.
    # The eigendecomposition :math:`U^T Lambda U` is used instead of
    # :math:`U Lambda U^T` to utilize row-based computation
    # (see Section 4, Seeger et al., 2018).
    # The identity scaled by the clipped eigenvalues yields a batch diagonal
    # matrix diag(max(jitter, Lambda)) of shape
    # (batch_size, num_data_points, num_data_points).
    return F.linalg.potrf(
        F.linalg.gemm2(
            U,
            F.linalg.gemm2(
                F.broadcast_mul(
                    F.eye(num_data_points, dtype=float_type),
                    F.maximum(jitter, Lambda.expand_dims(axis=2)),
                ),
                U,
            ),
            transpose_a=True,
        )
    )

# noinspection PyMethodOverriding,PyPep8Naming
[docs]def jitter_cholesky(
F,
matrix: Tensor,
num_data_points: Optional[int] = None,
float_type: DType = np.float64,
max_iter_jitter: int = 10,
neg_tol: float = -1e-8,
diag_weight: float = 1e-6,
increase_jitter: int = 10,
) -> Optional[Tensor]:
"""
This function applies the jitter method.  It iteratively tries to compute the Cholesky decomposition and
adds a positive tolerance to the diagonal that increases at each iteration until the matrix is positive definite
or the maximum number of iterations has been reached.

Parameters
----------
matrix
Kernel matrix of shape (batch_size, num_data_points, num_data_points).
num_data_points
Number of rows in the kernel_matrix.
float_type
Determines whether to use single or double precision.
max_iter_jitter
Maximum number of iterations for jitter to iteratively make the matrix positive definite.
neg_tol
Parameter in the jitter methods to eliminate eliminate matrices with diagonal elements smaller than this
when checking if a matrix is positive definite.
diag_weight
Multiple of mean of diagonal entries to initialize the jitter.
increase_jitter
Each iteration multiply by jitter by this amount
Returns
-------
Optional[Tensor]
The method either fails to make the matrix positive definite within the maximum number of iterations
and outputs an error or succeeds and returns the lower triangular Cholesky factor L
of shape (batch_size, num_data_points, num_data_points)
"""
num_iter = 0
diag = batch_diagonal(
F, matrix, num_data_points, float_type
)  # shape (batch_size, num_data_points, 1)
diag_mean = diag.mean(axis=1).expand_dims(
axis=2
)  # shape (batch_size, 1, 1)
jitter = F.zeros_like(diag)  # shape (batch_size, num_data_points, 1)
# Ensure that diagonal entries are numerically non-negative, as defined by neg_tol
# TODO: Add support for symbolic case: Cannot use < operator with symbolic variables
if F.sum(diag <= neg_tol) > 0:
raise mx.base.MXNetError(
" Matrix is not positive definite: negative diagonal elements"
)
while num_iter <= max_iter_jitter:
try:
L = F.linalg.potrf(
matrix,
F.eye(num_data_points, dtype=float_type),
jitter,
),
)
)
# gpu will not throw error but will store nans. If nan, L.sum() = nan and
# L.nansum() computes the sum treating nans as zeros so the error tolerance can be large.
# for axis = Null, nansum() and sum() will sum over all elements and return scalar array with shape (1,)
# TODO: Add support for symbolic case: Cannot use <= operator with symbolic variables
assert F.abs(L.nansum() - L.sum()) <= 1e-1
return L
except:
if num_iter == 0:
# Initialize the jitter: constant jitter per each batch
jitter = (