# Copyright 2018, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
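# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing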
# permissions and limitations under the License.

# Third-party imports
import mxnet as mx

# First-party imports
from gluonts.core.component import validated
from import Tensor
from import _broadcast_param
from import SNMLPBlock
from import getF
from import (

def log_abs_det(A: Tensor) -> Tensor:
    """
    Logarithm of the absolute value of the determinant of matrix `A`.

    The value is read off the Cholesky factor `L` of the Gram matrix
    `A^T A`: because `det(A^T A) = det(A)^2 = det(L)^2`, summing the logs
    of the absolute diagonal entries of `L` gives `log|det(A)|`.

    Parameters
    ----------
    A
        Tensor matrix from which to compute the log absolute value of its
        determinant

    Returns
    -------
    Tensor
        Sum, over the last axis, of the log of the absolute diagonal
        entries of the Cholesky factor of `A^T A`.
    """
    F = getF(A)
    # Gram matrix A^T A, which is what gets Cholesky-factorized below.
    gram = F.linalg.gemm2(A, A, transpose_a=True)
    # Diagonal (over the two trailing axes) of the Cholesky factor.
    chol_diag = F.diag(F.linalg.potrf(gram), axis1=-2, axis2=-1)
    log_diag = chol_diag.abs().log()
    return log_diag.sum(-1)
class InvertibleResnetHybridBlock(BijectionHybridBlock):
    """
    Based on [BJC19]_, apart from f and f_inv that are swapped.

    Here the residual network is applied directly in `f_inv`
    (`x = y + block(y)`), while `f` is obtained by a fixed-point iteration
    (`y <- x - block(y)`) run for `num_inv_iters` steps.

    Parameters
    ----------
    event_shape
        Shape of a single event; must have exactly one dimension. Its only
        entry is used as both input and output size of the internal block.
    hidden_units
        Number of hidden units passed to the internal `SNMLPBlock`.
    num_hidden_layers
        Number of hidden layers passed to the internal `SNMLPBlock`.
    num_inv_iters
        Number of fixed-point iterations performed by `f`.
    ignore_logdet
        If True, `log_abs_det_jac` returns zeros instead of computing the
        log-determinant.
    activation
        Activation name passed to the internal `SNMLPBlock`.
    num_power_iter
        Stored on the instance (see the review note in `__init__`).
    flatten
        Passed to the internal `SNMLPBlock`.
    coeff
        Stored on the instance (see the review note in `__init__`).
    use_caching
        If True, the most recent `(x, y)` pair produced by `f` or `f_inv`
        is remembered and reused when the very same tensor object is
        passed again.
    """

    @validated()
    def __init__(
        self,
        event_shape,
        hidden_units: int = 16,
        num_hidden_layers: int = 1,
        num_inv_iters: int = 10,
        ignore_logdet: bool = False,
        activation: str = "lipswish",
        num_power_iter: int = 1,
        flatten: bool = False,
        coeff: float = 0.9,
        use_caching: bool = True,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert len(event_shape) == 1
        self._event_shape = event_shape
        self._hidden_units = hidden_units
        self._num_hidden_layers = num_hidden_layers
        self._num_inv_iters = num_inv_iters
        self._ignore_logdet = ignore_logdet
        self._activation = activation
        # NOTE(review): `coeff` and `num_power_iter` are stored but are not
        # forwarded to SNMLPBlock below, so the block presumably falls back
        # to its own defaults -- confirm whether they should be passed on.
        self._coeff = coeff
        self._num_power_iter = num_power_iter
        self._flatten = flatten
        self._use_caching = use_caching

        # Last (x, y) pair seen; compared by identity in `f` / `f_inv`.
        self._cached_x = None
        self._cached_y = None

        with self.name_scope():
            self._block = SNMLPBlock(
                self._event_shape[0],
                self._hidden_units,
                self._event_shape[0],
                num_hidden_layers=self._num_hidden_layers,
                activation=self._activation,
                flatten=self._flatten,
            )

    @property
    def event_shape(self):
        """Event shape given at construction time."""
        return self._event_shape

    @property
    def event_dim(self) -> int:
        """Number of event dimensions, i.e. `len(event_shape)`."""
        return len(self.event_shape)

    def f(self, x: Tensor) -> Tensor:
        r"""
        Forward transformation of iResnet

        Computed by iterating `y <- x - block(y)` starting from `y = x`
        for `num_inv_iters` steps. If `x` is the very same object as the
        cached input, the cached output is returned instead.

        Parameters
        ----------
        x
            observations

        Returns
        -------
        Tensor
            transformed tensor `\text{iResnet}(x)`
        """
        if x is self._cached_x:
            return self._cached_y

        y = x
        for _ in range(self._num_inv_iters):
            y = x - self._block(y)

        if self._use_caching:
            self._cached_x = x
            self._cached_y = y
        return y

    def f_inv(self, y: Tensor) -> Tensor:
        r"""
        Inverse transformation of `iResnet`

        Computed in a single pass as `y + block(y)`. If `y` is the very
        same object as the cached output, the cached input is returned
        instead.

        Parameters
        ----------
        y
            input tensor

        Returns
        -------
        Tensor
            transformed tensor `\text{iResnet}^{-1}(y)`
        """
        if y is self._cached_y:
            return self._cached_x

        x = y + self._block(y)

        if self._use_caching:
            self._cached_x = x
            self._cached_y = y
        return x

    def log_abs_det_jac(self, x: Tensor, y: Tensor) -> Tensor:
        """
        Logarithm of the absolute value of the Jacobian determinant
        corresponding to the iResnet Transform

        Parameters
        ----------
        x
            input of the forward transformation or output of the inverse
            transform
        y
            output of the forward transform or input of the inverse
            transform

        Returns
        -------
        Tensor
            Jacobian evaluated for x as input or y as output. When
            `ignore_logdet` is set, a vector of zeros of length
            `x.shape[0]` is returned instead.
        """
        if self._ignore_logdet:
            assert x is not None
            # Allocate on the same device as `x` so the result can be
            # combined with tensors derived from it.
            return mx.nd.zeros(x.shape[0], ctx=x.context)

        # we take the negative value since we use the forward pass of mlp to
        # compute the inverse, and we want ladj of forward, which is
        # opposite of ladj of reverse
        jac_block = self._block.jacobian(y)
        batch_shape, (output_dim, input_dim) = (
            jac_block.shape[:-2],
            jac_block.shape[-2:],
        )
        # Identity broadcast over the batch axes, created on the Jacobian's
        # device so that `identity + jac_block` does not mix contexts.
        identity = _broadcast_param(
            mx.nd.eye(output_dim, input_dim, ctx=jac_block.context),
            axes=range(len(batch_shape)),
            sizes=batch_shape,
        )
        return -log_abs_det(identity + jac_block)
def iresnet(num_blocks: int, **block_kwargs) -> ComposedBijectionHybridBlock:
    """
    Build a composition of `num_blocks` invertible ResNet blocks.

    Parameters
    ----------
    num_blocks
        number of iResnet blocks
    block_kwargs
        keyword arguments given to initialize each block object

    Returns
    -------
    ComposedBijectionHybridBlock
        composition wrapping the list of newly created
        `InvertibleResnetHybridBlock` instances
    """
    blocks = []
    for _ in range(num_blocks):
        # Every block is built independently from the same keyword arguments.
        blocks.append(InvertibleResnetHybridBlock(**block_kwargs))
    return ComposedBijectionHybridBlock(blocks)