# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import inspect
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, cast, Type
import mxnet as mx
from mxnet.gluon.block import _flatten
from gluonts.core.serde import dump_json, load_json
from gluonts.mx import Tensor
class HybridContext:
    """
    A context manager that ensures that an MXNet network is operating in a
    hybridized / not hybridized mode.

    On entry, ``net.hybridize(active=hybridize, **kwargs)`` is called and, if
    `data_batch` is given, the network is run once on it. On exit -- and also
    when that forward pass fails during entry -- the mode that was active
    before entering is re-applied with the same `kwargs`.

    Parameters
    ----------
    net
        The network whose hybrid mode has to be modified within the enclosing
        context.
    hybridize
        A boolean flag indicating whether the hybrid mode should be set or
        not.
    data_batch
        Optional list of inputs; if given, ``net(*data_batch)`` is called once
        right after the mode has been set.
    kwargs
        A dictionary of optional arguments to pass to the `hybridize()` call
        of the enclosed `HybridBlock` network.
    """

    def __init__(
        self,
        net: mx.gluon.HybridBlock,
        hybridize: bool,
        data_batch: Optional[List[mx.nd.NDArray]] = None,
        **kwargs,
    ) -> None:
        self.net = net
        self.required_mode = hybridize
        # The previous mode is read from the block's private `_active` flag;
        # a block without that attribute is treated as not hybridized.
        self.original_mode = getattr(net, "_active", False)
        self.data_batch = data_batch
        self.kwargs = kwargs

    def __enter__(self):
        self.net.hybridize(active=self.required_mode, **self.kwargs)
        if self.data_batch is not None:
            try:
                self.net(*self.data_batch)
            except BaseException:
                # Python does not call __exit__ when __enter__ raises, so the
                # original mode must be restored here before propagating.
                self._restore_mode()
                raise
        return self

    def __exit__(self, *args):
        # Returns None, so exceptions raised inside the `with` body propagate.
        self._restore_mode()

    def _restore_mode(self) -> None:
        """Re-apply the hybrid mode that was active before entering."""
        self.net.hybridize(active=self.original_mode, **self.kwargs)
def assert_shape(x: Tensor, expected_shape: Tuple[int, ...]):
    """
    Check the shape of `x` when it is an `mx.nd.NDArray`.

    Inputs that are not `mx.nd.NDArray` instances (e.g. symbols) are not
    checked at all.

    Parameters
    ----------
    x
        Input Tensor.
    expected_shape
        Expected shape. An entry of -1 matches any size in that dimension.
        Dimensions are paired with `zip`, so only the first
        ``min(len(x.shape), len(expected_shape))`` dimensions are compared.

    Raises
    ------
    AssertionError
        If a compared dimension differs from its expected size. The check is
        implemented with ``assert`` and is therefore disabled under
        ``python -O``.
    """
    if not isinstance(x, mx.nd.NDArray):
        return
    for actual, expected in zip(x.shape, expected_shape):
        if expected == -1:
            # wildcard dimension: any size is accepted
            continue
        assert (
            actual == expected
        ), f"shape mismatch got {x.shape} expected {expected_shape}"
def copy_parameters(
    net_source: mx.gluon.Block,
    net_dest: mx.gluon.Block,
    ignore_extra: bool = False,
    allow_missing: bool = False,
) -> None:
    """
    Copy the parameters of one network into another.

    The parameters are saved from `net_source` into a file inside a fresh
    temporary directory and loaded from there into `net_dest`, on
    `mx.current_context()`. The directory is removed afterwards.

    Parameters
    ----------
    net_source
        Network whose parameters are copied.
    net_dest
        Network that receives the parameters.
    ignore_extra
        Whether to ignore parameters from the source that are not
        present in the target.
    allow_missing
        Whether to allow additional parameters in the target not
        present in the source.
    """
    with tempfile.TemporaryDirectory(prefix="gluonts-estimator-temp-") as tmp:
        params_file = str(Path(tmp) / "tmp_model")
        net_source.save_parameters(params_file)
        net_dest.load_parameters(
            params_file,
            ctx=mx.current_context(),
            allow_missing=allow_missing,
            ignore_extra=ignore_extra,
        )
# noinspection PyProtectedMember
def hybrid_block_to_symbol_block(
    hb: mx.gluon.HybridBlock, data_batch: List[mx.nd.NDArray]
) -> mx.gluon.SymbolBlock:
    """
    Convert a Gluon `HybridBlock` into a `SymbolBlock`.

    Following the Gluon API, `hb` is hybridized (with ``static_alloc`` and
    ``static_shape`` enabled) and run once on `data_batch`. It is then
    exported to a temporary directory and imported back from there as a
    `SymbolBlock`, which is itself run once on `data_batch`. The previous
    hybrid mode of `hb` is re-applied when the hybrid context is left.

    Note that MXNet has `problems with this method
    <https://github.com/apache/incubator-mxnet/issues/12783>`_.

    Parameters
    ----------
    hb
        The Gluon `HybridBlock` to convert.
    data_batch
        Data used for the forward passes; it may contain nested tensors.

    Returns
    -------
    mx.gluon.SymbolBlock
        The resulting Gluon block backed by an MXNet symbol graph.
    """
    model_name = "gluonts-model"
    with tempfile.TemporaryDirectory(
        prefix="gluonts-estimator-temp-"
    ) as tmp_dir:
        export_dir = Path(tmp_dir)

        # SymbolBlock import needs the total number of input symbols,
        # nested tensors included, hence the flattening.
        num_inputs = len(_flatten(data_batch, "input")[0])

        with HybridContext(
            net=hb,
            hybridize=True,
            data_batch=data_batch,
            static_alloc=True,
            static_shape=True,
        ):
            export_symb_block(hb, export_dir, model_name)
            symbol_block = import_symb_block(
                num_inputs, export_dir, model_name
            )

        symbol_block(*data_batch)
    return symbol_block
# noinspection PyProtectedMember
def export_symb_block(
    hb: mx.gluon.HybridBlock, model_dir: Path, model_name: str, epoch: int = 0
) -> None:
    """
    Serialize a hybridized Gluon `HybridBlock`.

    Besides the files written by ``hb.export``, the block's input/output
    formats (``hb._in_format`` / ``hb._out_format``) are stored as JSON in
    ``{model_name}-in_out_format.json`` inside `model_dir`.

    Parameters
    ----------
    hb
        The block to export.
    model_dir
        The directory where the model will be saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with `model_name` identifies the
        model parameters.
    """
    hb.export(path=str(model_dir / model_name), epoch=epoch)

    # FIXME: mxnet does not persist the input/output formats of hybrid
    # FIXME: blocks, so they are written to a side-car JSON file.
    # FIXME: https://github.com/apache/incubator-mxnet/issues/17488
    format_file = model_dir / f"{model_name}-in_out_format.json"
    with format_file.open("w") as fp:
        formats = {"in_format": hb._in_format, "out_format": hb._out_format}
        fp.write(dump_json(formats) + "\n")
def import_symb_block(
    num_inputs: int, model_dir: Path, model_name: str, epoch: int = 0
) -> mx.gluon.SymbolBlock:
    """
    Deserialize a hybridized Gluon `HybridBlock` as a `SymbolBlock`.

    Reads ``{model_name}-symbol.json`` and ``{model_name}-{epoch:04}.params``
    from `model_dir` and places the block on `mx.current_context()`. If a
    ``{model_name}-in_out_format.json`` file exists, the input/output formats
    stored in it are restored on the returned block.

    Parameters
    ----------
    num_inputs
        The number of inputs of the serialized block.
    model_dir
        The directory where the model is saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with `model_name` identifies the
        model parameters.

    Returns
    -------
    mx.gluon.SymbolBlock
        The deserialized block.
    """
    # Input symbols are named "data" for a single input and
    # "data0", "data1", ... otherwise.
    input_names = (
        ["data"]
        if num_inputs == 1
        else [f"data{i}" for i in range(num_inputs)]
    )

    # FIXME: mxnet fails on an empty saved parameters list, so no parameter
    # FIXME: file is passed in that case.
    # FIXME: https://github.com/apache/incubator-mxnet/issues/17488
    params_path = str(model_dir / f"{model_name}-{epoch:04}.params")
    param_file: Optional[str] = (
        params_path if mx.nd.load(params_path) else None
    )

    # FIXME: mx.gluon.SymbolBlock cannot infer float_type and uses default
    # np.float32
    # FIXME: https://github.com/apache/incubator-mxnet/issues/11849
    sb = mx.gluon.SymbolBlock.imports(
        symbol_file=str(model_dir / f"{model_name}-symbol.json"),
        input_names=input_names,
        param_file=param_file,
        ctx=mx.current_context(),
    )

    # FIXME: mxnet does not persist input/output formats; restore them from
    # FIXME: the side-car file when it is present.
    # FIXME: https://github.com/apache/incubator-mxnet/issues/17488
    format_path = model_dir / f"{model_name}-in_out_format.json"
    if format_path.exists():
        formats = load_json(format_path.read_text())
        sb._in_format = formats["in_format"]
        sb._out_format = formats["out_format"]
    return sb
def export_repr_block(
    rb: mx.gluon.HybridBlock, model_dir: Path, model_name: str, epoch: int = 0
) -> None:
    """
    Serialize a representable Gluon block.

    The block itself is written with `dump_json` to
    ``{model_name}-network.json`` and its parameters to
    ``{model_name}-{epoch:04}.params``, both inside `model_dir`.

    Parameters
    ----------
    rb
        The block to export.
    model_dir
        The directory where the model will be saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with `model_name` identifies the
        model parameters.
    """
    network_file = model_dir / f"{model_name}-network.json"
    with network_file.open("w") as fp:
        fp.write(dump_json(rb) + "\n")

    params_file = model_dir / f"{model_name}-{epoch:04}.params"
    rb.save_parameters(str(params_file))
def import_repr_block(
    model_dir: Path, model_name: str, epoch: int = 0
) -> mx.gluon.HybridBlock:
    """
    Deserialize a representable Gluon block.

    The block is rebuilt with `load_json` from ``{model_name}-network.json``
    and its parameters are loaded strictly (no missing and no extra
    parameters allowed) from ``{model_name}-{epoch:04}.params`` onto
    `mx.current_context()`.

    Parameters
    ----------
    model_dir
        The directory where the model is saved.
    model_name
        The name identifying the model.
    epoch
        The epoch number, which together with `model_name` identifies the
        model parameters.

    Returns
    -------
    mx.gluon.HybridBlock
        The deserialized block.
    """
    network_json = (model_dir / f"{model_name}-network.json").read_text()
    rb = cast(mx.gluon.HybridBlock, load_json(network_json))

    params_file = model_dir / f"{model_name}-{epoch:04}.params"
    rb.load_parameters(
        str(params_file),
        ctx=mx.current_context(),
        allow_missing=False,
        ignore_extra=False,
    )
    return rb
def cumsum(
    F, x: Tensor, exclusive: bool = False, reverse: bool = False
) -> Tensor:
    r"""
    Compute the cumulative sum over the last axis as a matrix product with a
    lower-triangular matrix of ones:

    .. math::

       \operatorname{cumsum}(x) =
       \begin{cases}
         \operatorname{ltr\_ones} \times x
           & \text{for cumulative sum}\\
         x \times \operatorname{ltr\_ones}
           & \text{for cumulative sum in the reverse order}
       \end{cases}

    With `exclusive` set, the sum starts at zero; this is obtained by
    subtracting `x` from the inclusive sum. For example, if
    :math:`x = [a, b, c]`, we have

    .. math::

       \operatorname{cumsum}(x) =
       \begin{cases}
         [a, a + b, a + b + c]
           & \text{if }\mathit{reverse = False, exclusive = False}\\
         [0, a, a + b]
           & \text{if }\mathit{reverse = False, exclusive = True}\\
         [a + b + c, b + c, c]
           & \text{if }\mathit{reverse = True, exclusive = False}\\
         [b + c, c, 0]
           & \text{if }\mathit{reverse = True, exclusive = True}\\
       \end{cases}

    Parameters
    ----------
    F
        The function space to use.
    x
        A tensor with shape :math:`(..., n)`.
    exclusive
        If `True`, the cumulative sum starts with zero.
    reverse
        If `True`, the cumulative sum is performed in the opposite direction.

    Returns
    -------
    Tensor:
        A tensor with the same shape as `x` holding the cumulative sums along
        the last axis.
    """
    # x becomes a row vector (..., 1, n) in reverse mode and a column vector
    # (..., n, 1) otherwise, so the sum can be written as a matrix product.
    new_axis = -1
    if reverse:
        new_axis = -2
    x_mat = x.expand_dims(axis=new_axis)

    # (..., n, n) matrix of ones, built as an outer product of ones-vectors.
    all_ones = F.linalg_gemm2(
        F.ones_like(x_mat),
        F.ones_like(x_mat),
        transpose_a=reverse,
        transpose_b=not reverse,
    )

    # Multiply by the lower triangle of `all_ones`: from the left for a
    # forward sum, from the right in reverse mode.
    result = F.linalg_trmm(all_ones, x_mat, rightside=reverse)
    if exclusive:
        # removing each element's own contribution shifts the sum by one
        result = result - x_mat
    return result.squeeze(axis=new_axis)
def weighted_average(
    F,
    x: Tensor,
    weights: Optional[Tensor] = None,
    axis: Optional[int] = None,
    include_zeros_in_denominator=False,
) -> Tensor:
    """
    Compute the weighted average of a tensor along an axis.

    Entries with weight zero are masked out. Instead of `nan * 0 = nan`, such
    an entry contributes `0`, so NaNs at zero-weight positions do not
    propagate. The denominator is clamped to at least 1.0, so the result is
    only a normalized average when the weights sum to at least one.

    Parameters
    ----------
    F
        The function space to use.
    x
        Input tensor, of which the average must be computed.
    weights
        Weights tensor, of the same shape as `x`. If None, the plain mean of
        `x` along `axis` is returned.
    axis
        The axis along which to average `x`.
    include_zeros_in_denominator
        If True, the denominator counts every entry (weights treated as
        ones), not just the sum of weights. Can be useful for sparse time
        series because the loss can be dominated by few observed examples.

    Returns
    -------
    Tensor:
        The tensor with values averaged along the specified `axis`.
    """
    if weights is None:
        return x.mean(axis=axis)

    # zero-weight positions get an explicit 0 rather than x * 0
    weighted_tensor = F.where(
        condition=weights, x=x * weights, y=F.zeros_like(x)
    )
    denominator_terms = (
        F.ones_like(weights) if include_zeros_in_denominator else weights
    )
    # clamp to avoid dividing by zero when all weights are zero
    sum_weights = F.maximum(1.0, denominator_terms.sum(axis=axis))
    return weighted_tensor.sum(axis=axis) / sum_weights
def make_nd_diag(F, x: Tensor, d: int) -> Tensor:
    """
    Build a (batched) diagonal matrix from its diagonal entries.

    Parameters
    ----------
    F
        The function space to use.
    x
        Diagonal to use, shape :math:`(..., d)`.
    d
        Size of the last dimension of `x`.

    Returns
    -------
    Tensor
        A tensor y of shape :math:`(..., d, d)` such that
        :math:`y[..., i, i] = x[..., i]` and all off-diagonal entries are 0.
    """
    identity = F.eye(d)
    # (..., d, 1): scales row i of the identity by x[..., i]
    column = x.expand_dims(axis=-1)
    return F.broadcast_mul(identity, column)
def _broadcast_param(param, axes, sizes):
    """
    Insert and broadcast new axes of `param`, one per (axis, size) pair.

    For each pair, in order, a new axis is inserted at position `axis` with
    `expand_dims` and then broadcast to length `size` with `broadcast_axes`.
    Pairs are formed with `zip`, so surplus entries of the longer of `axes`
    and `sizes` are ignored.
    """
    result = param
    for axis, size in zip(axes, sizes):
        expanded = result.expand_dims(axis=axis)
        result = expanded.broadcast_axes(axis=axis, size=size)
    return result
def mx_switch(F, *args, **kwargs) -> Tensor:
    """
    A switch statement for mxnet, built from nested ``F.where`` calls.

    mx_switch((A, x), (B, y), z)

    corresponds to

    if A -> x
    elif B -> y
    else -> z

    Parameters
    ----------
    F
        The function space to use.
    args
        One or more ``(condition, value)`` pairs followed by the else value,
        which must not be a tuple or list. At least three arguments are
        required.
    kwargs
        Keyword arguments; only ``scope`` is accepted, and it is not used.

    Returns
    -------
    Tensor
        A tensor with the respective switch entries.
    """
    assert set(kwargs).issubset({"scope"})
    assert len(args) >= 3
    *branches, default = args
    assert not isinstance(
        default, (tuple, list)
    ), "Last element should be the else clause"

    # Fold from the last branch to the first so that earlier conditions take
    # precedence over later ones.
    result = default
    for condition, value in reversed(branches):
        result = F.where(condition, value, result)
    return result