# Source code for gluonts.mx.util

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import inspect
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, cast, Type

import mxnet as mx
from mxnet.gluon.block import _flatten

from gluonts.core.serde import dump_json, load_json
from gluonts.mx import Tensor


class HybridContext:
    """
    A context manager that ensures that an MXNet network is operating in a
    hybridized / not hybridized mode.

    On entry, `hybridize()` is called on the network with the requested mode
    and, if a data batch was supplied, a single forward pass is executed.
    On exit, `hybridize()` is called again with the mode that was recorded
    when the context object was constructed.

    Parameters
    ----------
    net
        The network whose hybrid mode has to be modified within the enclosing
        context.
    hybridize
        A boolean flag indicating whether the hybrid mode should be set or
        not.
    data_batch
        Optional inputs. When given, `net(*data_batch)` is invoked once right
        after the mode has been switched on entry.
    kwargs
        A dictionary of optional arguments to pass to the `hybridize()` call
        of the enclosed `HybridBlock` network.
    """

    def __init__(
        self,
        net: mx.gluon.HybridBlock,
        hybridize: bool,
        data_batch: Optional[List[mx.nd.NDArray]] = None,
        **kwargs,
    ) -> None:
        self.net = net
        self.required_mode = hybridize
        # A network without an `_active` attribute is treated as not
        # hybridized.
        self.original_mode = getattr(net, "_active", False)
        self.data_batch = data_batch
        self.kwargs = kwargs

    def __enter__(self) -> "HybridContext":
        self.net.hybridize(active=self.required_mode, **self.kwargs)
        if self.data_batch is not None:
            self.net(*self.data_batch)
        # Hand back the manager so that `with HybridContext(...) as ctx`
        # binds a usable object rather than `None`.
        return self

    def __exit__(self, *args) -> None:
        # Returning None (falsy) lets any exception raised inside the
        # `with` body propagate after the mode has been restored.
        self.net.hybridize(active=self.original_mode, **self.kwargs)
def assert_shape(x: Tensor, expected_shape: Tuple[int, ...]):
    """
    Assert that `x` has the expected shape, but only when `x` is an
    `mx.nd.NDArray`; any other input is accepted without a check.

    Dimensions are compared pairwise, in order, and an entry of `-1` in
    `expected_shape` matches any size.

    Parameters
    ----------
    x
        Input Tensor
    expected_shape
        Expected shape

    Raises
    ------
    AssertionError
        If a compared dimension of `x` differs from the corresponding
        non-wildcard entry of `expected_shape`.
    """
    if not isinstance(x, mx.nd.NDArray):
        return

    mismatch = any(
        wanted != -1 and actual != wanted
        for actual, wanted in zip(x.shape, expected_shape)
    )
    assert (
        not mismatch
    ), f"shape mismatch got {x.shape} expected {expected_shape}"
def copy_parameters(
    net_source: mx.gluon.Block,
    net_dest: mx.gluon.Block,
    ignore_extra: bool = False,
    allow_missing: bool = False,
) -> None:
    """
    Copies parameters from one network to another.

    The source parameters are saved to a file inside a temporary directory
    and then loaded into the destination network on the current MXNet
    context.

    Parameters
    ----------
    net_source
        Input network.
    net_dest
        Output network.
    ignore_extra
        Whether to ignore parameters from the source that are not
        present in the target.
    allow_missing
        Whether to allow additional parameters in the target not
        present in the source.
    """
    scratch = tempfile.TemporaryDirectory(prefix="gluonts-estimator-temp-")
    with scratch as scratch_dir:
        params_file = str(Path(scratch_dir) / "tmp_model")
        net_source.save_parameters(params_file)

        load_options = dict(
            ctx=mx.current_context(),
            allow_missing=allow_missing,
            ignore_extra=ignore_extra,
        )
        net_dest.load_parameters(params_file, **load_options)
def get_hybrid_forward_input_names(
    hybrid_block_type: Type[mx.gluon.HybridBlock],
) -> List[str]:
    """
    List the named inputs of a block type's `hybrid_forward` method.

    The signature of `hybrid_block_type.hybrid_forward` is inspected;
    variadic parameters (`*args`, `**kwargs`) are dropped and the leading
    `self` and `F` parameters are skipped.

    Parameters
    ----------
    hybrid_block_type
        The block class whose `hybrid_forward` signature is inspected.

    Returns
    -------
    List[str]
        Names of the `hybrid_forward` parameters that follow `self` and `F`.

    Raises
    ------
    AssertionError
        If `hybrid_forward` has fewer than two named parameters, or if the
        first two are not called `self` and `F`.
    """
    params = inspect.signature(hybrid_block_type.hybrid_forward).parameters
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    param_names = [
        name
        for name, param in params.items()
        if param.kind not in variadic_kinds
    ]

    # Guard the indexing below: without this a too-short signature would
    # surface as a bare IndexError instead of a descriptive assertion.
    assert len(param_names) >= 2, (
        "Expected hybrid_forward to take at least `self` and `F`, "
        f"but found parameters {param_names}"
    )
    assert param_names[0] == "self", (
        "Expected first argument of hybrid_forward to be `self`, "
        f"but found `{param_names[0]}`"
    )
    assert param_names[1] == "F", (
        "Expected second argument of hybrid_forward to be `F`, "
        f"but found `{param_names[1]}`"
    )

    return param_names[2:]  # skip: self, F
def hybrid_block_to_symbol_block(
    hb: mx.gluon.HybridBlock, data_batch: List[mx.nd.NDArray]
) -> mx.gluon.SymbolBlock:
    """
    Converts a Gluon `HybridBlock` to a `SymbolBlock`. Following the Gluon API,
    this is achieved by a `hybridize()` call on the passed `HybridBlock`, a
    single forward pass (using the provided data batch), and a combination of
    an `export()` and an `import()` calls of the input block.

    The export/import round trip goes through a temporary directory, and the
    imported block is called once on `data_batch` before it is returned.

    Note that MXNet has `problems with this method
    <https://github.com/apache/incubator-mxnet/issues/12783>`_.

    Parameters
    ----------
    hb
        The Gluon `HybridBlock` to convert.
    data_batch
        Data to use for the forward pass after the `hybridize()` call.

    Returns
    -------
    mx.gluon.SymbolBlock
        The resulting Gluon block backed by an MXNet symbol graph.
    """
    with tempfile.TemporaryDirectory(
        prefix="gluonts-estimator-temp-"
    ) as model_dir:
        # when importing, SymbolBlock has to know about the total number
        # of input symbols, including nested Tensors
        flat_data_batch, _ = _flatten(data_batch, "input")
        num_inputs = len(flat_data_batch)

        model_dir_path = Path(model_dir)
        model_name = "gluonts-model"

        # HybridContext hybridizes `hb` and runs one forward pass on
        # `data_batch` on entry, then restores the previous mode on exit.
        with HybridContext(
            net=hb,
            hybridize=True,
            data_batch=data_batch,
            static_alloc=True,
            static_shape=True,
        ):
            export_symb_block(hb, model_dir_path, model_name)
            sb = import_symb_block(num_inputs, model_dir_path, model_name)
            # NOTE(review): the exact nesting level of this call could not
            # be recovered from the flattened source — confirm against the
            # upstream file.
            sb(*data_batch)

        return sb
def export_symb_block(
    hb: mx.gluon.HybridBlock, model_dir: Path, model_name: str, epoch: int = 0
) -> None:
    """
    Serializes a hybridized Gluon `HybridBlock`.

    Besides calling `hb.export(...)`, the block's `_in_format` and
    `_out_format` attributes are written to
    `<model_name>-in_out_format.json` in the same directory.

    Parameters
    ----------
    hb
        The block to export.
    model_dir
        The path where the model will be saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with the `model_name` identifies the
        model parameters.
    """
    export_prefix = model_dir / model_name
    hb.export(path=str(export_prefix), epoch=epoch)

    # FIXME: we persist input/output formats of hybrid blocks as mxnet does not
    # FIXME: https://github.com/apache/incubator-mxnet/issues/17488
    format_file = model_dir / f"{model_name}-in_out_format.json"
    with format_file.open("w") as fp:
        in_out_format = {
            "in_format": hb._in_format,
            "out_format": hb._out_format,
        }
        print(dump_json(in_out_format), file=fp)
def import_symb_block(
    num_inputs: int, model_dir: Path, model_name: str, epoch: int = 0
) -> mx.gluon.SymbolBlock:
    """
    Deserializes a hybridized Gluon `HybridBlock` as a `SymbolBlock`.

    If the saved parameter file loads as empty, no parameter file is passed
    to `SymbolBlock.imports`. If a `<model_name>-in_out_format.json` file
    exists in `model_dir`, the block's `_in_format` and `_out_format`
    attributes are restored from it.

    Parameters
    ----------
    num_inputs
        The number of inputs of the serialized block.
    model_dir
        The path where the model is saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with the `model_name` identifies the
        model parameters.

    Returns
    -------
    mx.gluon.SymbolBlock
        The deserialized block.
    """
    # A single input is named "data"; otherwise "data0", "data1", ...
    input_names = (
        ["data"]
        if num_inputs == 1
        else [f"data{i}" for i in range(num_inputs)]
    )

    # FIXME: prevents mxnet from failing with empty saved parameters list
    # FIXME: https://github.com/apache/incubator-mxnet/issues/17488
    params_path = model_dir / f"{model_name}-{epoch:04}.params"
    param_file: Optional[str] = str(params_path)
    if not mx.nd.load(param_file):
        param_file = None

    # FIXME: mx.gluon.SymbolBlock cannot infer float_type and uses default
    # np.float32
    # FIXME: https://github.com/apache/incubator-mxnet/issues/11849
    symbol_path = model_dir / f"{model_name}-symbol.json"
    sb = mx.gluon.SymbolBlock.imports(
        symbol_file=str(symbol_path),
        input_names=input_names,
        param_file=param_file,
        ctx=mx.current_context(),
    )

    # FIXME: try to retrieve input/output format
    # FIXME: https://github.com/apache/incubator-mxnet/issues/17488
    format_json_path = model_dir / f"{model_name}-in_out_format.json"
    if not format_json_path.exists():
        return sb

    formats = load_json(format_json_path.read_text())
    sb._in_format = formats["in_format"]
    sb._out_format = formats["out_format"]
    return sb
def export_repr_block(
    rb: mx.gluon.HybridBlock, model_dir: Path, model_name: str, epoch: int = 0
) -> None:
    """
    Serializes a representable Gluon block.

    The block itself is dumped as JSON to `<model_name>-network.json` and
    its parameters are saved to `<model_name>-<epoch>.params` (epoch
    zero-padded to four digits), both inside `model_dir`.

    Parameters
    ----------
    rb
        The block to export.
    model_dir
        The path where the model will be saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with the `model_name` identifies the
        model parameters.
    """
    network_file = model_dir / f"{model_name}-network.json"
    with network_file.open("w") as fp:
        print(dump_json(rb), file=fp)

    params_file = model_dir / f"{model_name}-{epoch:04}.params"
    rb.save_parameters(str(params_file))
def import_repr_block(
    model_dir: Path, model_name: str, epoch: int = 0
) -> mx.gluon.HybridBlock:
    """
    Deserializes a representable Gluon block.

    The block is rebuilt from `<model_name>-network.json` and its parameters
    are then loaded from `<model_name>-<epoch>.params` (epoch zero-padded to
    four digits) on the current MXNet context. Neither missing nor extra
    parameters are tolerated.

    Parameters
    ----------
    model_dir
        The path where the model is saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with the `model_name` identifies the
        model parameters.

    Returns
    -------
    mx.gluon.HybridBlock:
        The deserialized block.
    """
    network_file = model_dir / f"{model_name}-network.json"
    rb = cast(mx.gluon.HybridBlock, load_json(network_file.read_text()))

    params_file = model_dir / f"{model_name}-{epoch:04}.params"
    rb.load_parameters(
        str(params_file),
        ctx=mx.current_context(),
        allow_missing=False,
        ignore_extra=False,
    )
    return rb
def cumsum(
    F, x: Tensor, exclusive: bool = False, reverse: bool = False
) -> Tensor:
    r"""
    Find cumulative sum on the last axis by multiplying with lower triangular
    ones-matrix:

    .. math::

       \operatorname{cumsum}(x) =
       \begin{cases}
         \operatorname{ltr\_ones} \times x
           & \text{for cumulative sum}\\
         x \times \operatorname{ltr\_ones}
           & \text{for cumulative sum in the reverse order}
       \end{cases}

    Also supports `exclusive` flag to start the cumsum with zero.
    For example, if :math:`x = [a, b, c]`, we have

    .. math::

       \operatorname{cumsum}(x) =
       \begin{cases}
         [a, a + b, a + b + c]
           & \text{if }\mathit{reverse = False, exclusive = False}\\
         [0, a, a + b]
           & \text{if }\mathit{reverse = False, exclusive = True}\\
         [a + b + c, b + c, c]
           & \text{if }\mathit{reverse = True, exclusive = False}\\
         [b + c, c, 0]
           & \text{if }\mathit{reverse = True, exclusive = True}\\
       \end{cases}

    Parameters
    ----------
    F
        The function space to use.
    x
        A tensor with shape :math:`(..., n)`.
    exclusive
        If `True`, the cumulative sum starts with zero.
    reverse
        If `True`, the cumulative sum is performed in the opposite direction.

    Returns
    -------
    Tensor:
        A modified tensor with identical shape and cumulative sums in the last
        axis.
    """
    # The extra axis needed for the matrix product goes last in the forward
    # case and second-to-last in the reverse case.
    new_axis = -2 if reverse else -1

    # (..., 1, n) if reverse is True and (..., n, 1) otherwise
    expanded = x.expand_dims(axis=new_axis)

    # Outer product of two ones-vectors: an all-ones matrix (..., n, n)
    all_ones = F.linalg_gemm2(
        F.ones_like(expanded),
        F.ones_like(expanded),
        transpose_a=reverse,
        transpose_b=not reverse,
    )

    # Triangular matrix multiply; `rightside` selects x * L over L * x.
    inclusive = F.linalg_trmm(all_ones, expanded, rightside=reverse)

    # Removing each element's own contribution yields the exclusive sum.
    result = inclusive - expanded if exclusive else inclusive

    return result.squeeze(axis=new_axis)
def weighted_average(
    F,
    x: Tensor,
    weights: Optional[Tensor] = None,
    axis: Optional[int] = None,
    include_zeros_in_denominator=False,
) -> Tensor:
    """
    Computes the weighted average of a given tensor across a given axis,
    masking values associated with weight zero,

    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    When `weights` is `None`, the plain mean of `x` along `axis` is returned.
    Otherwise the denominator is never allowed to drop below `1.0`.

    Parameters
    ----------
    F
        The function space to use.
    x
        Input tensor, of which the average must be computed.
    weights
        Weights tensor, of the same shape as `x`.
    axis
        The axis along which to average `x`
    include_zeros_in_denominator
        Include zeros in the denominator. Can be useful for sparse time series
        because the loss can be dominated by few observed examples.

    Returns
    -------
    Tensor:
        The tensor with values averaged along the specified `axis`.
    """
    if weights is None:
        return x.mean(axis=axis)

    # Entries with zero weight are replaced by zero rather than multiplied,
    # so a NaN in `x` at such a position does not leak into the sum.
    masked = F.where(condition=weights, x=x * weights, y=F.zeros_like(x))

    # Either count every position or only add up the weights themselves.
    counted = (
        F.ones_like(weights) if include_zeros_in_denominator else weights
    )
    denominator = F.maximum(1.0, counted.sum(axis=axis))

    return masked.sum(axis=axis) / denominator
def make_nd_diag(F, x: Tensor, d: int) -> Tensor:
    """
    Make a diagonal tensor, given the diagonal.

    Parameters
    ----------
    F
        The function space to use.
    x
        Diagonal to use, shape :math:`(..., d)`.
    d
        Last dimension of `x`.

    Returns
    -------
    Tensor
        A tensor y of shape :math:`(..., d, d)` such that
        :math:`y[..., i, i] = x[..., i]`.
    """
    identity = F.eye(d)
    # Append a trailing axis so that the diagonal broadcasts against the
    # rows of the identity matrix.
    column = x.expand_dims(axis=-1)
    return F.broadcast_mul(identity, column)
def _broadcast_param(param, axes, sizes): for axis, size in zip(axes, sizes): param = param.expand_dims(axis=axis).broadcast_axes( axis=axis, size=size ) return param
def mx_switch(F, *args, **kwargs) -> Tensor:
    """
    A switch statement for mxnet.

    mx_switch((A, x), (B, y), z)

    corresponds to

    if A -> x
    elif B -> y
    else -> z

    Parameters
    ----------
    F
        The function space to use.
    args
        Arguments: `(condition, value)` pairs followed by the else value.
    kwargs
        Keyword arguments; only `scope` is accepted.

    Returns
    -------
    Tensor
        A tensor with the respective switch entries.
    """
    assert set(kwargs.keys()).issubset({"scope"})
    assert len(args) >= 3

    *when_stmts, else_stmt = args
    assert not isinstance(
        else_stmt, (tuple, list)
    ), "Last element should be the else clause"

    # Fold from the last case backwards so that the first case ends up as
    # the outermost `where` and therefore takes precedence.
    result = else_stmt
    for cond, then_stmt in reversed(when_stmts):
        result = F.where(cond, then_stmt, result)
    return result