# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Optional, Union

import torch

def resolve_device(
    device: Union[str, torch.device]
) -> Union[str, torch.device]:
    """
    Resolves a torch device to the most appropriate one.

    The ``"auto"`` device is resolved to ``"cuda"`` if CUDA is available,
    and to ``"cpu"`` otherwise. Any other device is returned unchanged.
    """
    if device == "auto":
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"
    return device
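
# A minimal usage sketch (the variable names below are illustrative, not
# part of this module):
#
#     device = resolve_device("auto")  # "cuda" if CUDA is available, else "cpu"
#     net = torch.nn.Linear(4, 2).to(device)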

def copy_parameters(
    net_source: torch.nn.Module,
    net_dest: torch.nn.Module,
    strict: bool = True,
) -> None:
    """
    Copies parameters from one network to another.

    Parameters
    ----------
    net_source
        Input network, whose parameters are copied.
    net_dest
        Output network, whose parameters are overwritten.
    strict
        Whether to strictly enforce that the keys in the source network's
        :meth:`~torch.nn.Module.state_dict` match the keys returned by the
        destination module's :meth:`~torch.nn.Module.state_dict` function.
        Default: ``True``.
    """
    net_dest.load_state_dict(net_source.state_dict(), strict=strict)
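
# Illustrative sketch: copying weights between two identically-shaped
# modules (the linear layers here are assumptions for the example):
#
#     src = torch.nn.Linear(8, 2)
#     dst = torch.nn.Linear(8, 2)
#     copy_parameters(src, dst)
#     assert torch.equal(src.weight, dst.weight)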

def weighted_average(
    x: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    dim: Optional[int] = None,
) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given dim,
    masking values associated with weight zero, meaning instead of
    ``nan * 0 = nan`` you will get ``0 * 0 = 0``.

    Parameters
    ----------
    x
        Input tensor, of which the average must be computed.
    weights
        Weights tensor, of the same shape as ``x``.
    dim
        The dim along which to average ``x``.

    Returns
    -------
    torch.Tensor
        The tensor with values averaged along the specified ``dim``.
    """
    if weights is not None:
        # Mask out entries with zero weight, so that NaNs in ``x`` at those
        # positions do not propagate into the sum.
        weighted_tensor = torch.where(
            weights != 0, x * weights, torch.zeros_like(x)
        )
        # Clamp the normalizer to avoid division by zero when all weights
        # along ``dim`` are zero.
        sum_weights = torch.clamp(
            weights.sum(dim=dim) if dim is not None else weights.sum(),
            min=1.0,
        )
        return (
            weighted_tensor.sum(dim=dim)
            if dim is not None
            else weighted_tensor.sum()
        ) / sum_weights
    else:
        return x.mean(dim=dim)
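
# Illustrative sketch of the zero-weight masking (values are assumptions
# for the example): the NaN entry carries zero weight, so it is masked out
# instead of propagating:
#
#     x = torch.tensor([1.0, float("nan"), 3.0])
#     w = torch.tensor([1.0, 0.0, 1.0])
#     weighted_average(x, weights=w, dim=0)  # tensor(2.)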

def lagged_sequence_values(
    indices: List[int],
    prior_sequence: torch.Tensor,
    sequence: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    """
    Constructs a tensor of lagged values from a given sequence.

    Parameters
    ----------
    indices
        Indices of the lagged observations. For example, ``[0]`` indicates
        that, at any time ``t``, the output will have only the observation
        from time ``t`` itself; instead, ``[0, 24]`` indicates that the
        output will have observations from times ``t`` and ``t-24``.
    prior_sequence
        Tensor containing the input sequence prior to the time range for
        which the output is required.
    sequence
        Tensor containing the input sequence in the time range where the
        output is required.
    dim
        Time dimension.

    Returns
    -------
    torch.Tensor
        A tensor of shape ``(*sequence.shape, len(indices))``.
    """
    assert max(indices) <= prior_sequence.shape[dim], (
        f"lags cannot go further than prior sequence length, found lag"
        f" {max(indices)} while prior sequence is only"
        f" {prior_sequence.shape[dim]}-long"
    )
    # Concatenate the history and the current range, then slice out one
    # shifted window of length ``sequence.shape[dim]`` per lag index.
    full_sequence = torch.cat((prior_sequence, sequence), dim=dim)
    lags_values = []
    for lag_index in indices:
        begin_index = -lag_index - sequence.shape[dim]
        end_index = -lag_index if lag_index > 0 else None
        lags_values.append(
            slice_along_dim(
                full_sequence, dim=dim, slice_=slice(begin_index, end_index)
            )
        )
    return torch.stack(lags_values, dim=-1)
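
# Illustrative sketch (values are assumptions for the example): with a
# 3-step history and lags ``[0, 2]``, each output step pairs the current
# value with the value two steps earlier:
#
#     prior = torch.tensor([1.0, 2.0, 3.0])
#     seq = torch.tensor([4.0, 5.0])
#     lagged_sequence_values([0, 2], prior, seq, dim=0)
#     # tensor([[4., 2.],
#     #         [5., 3.]])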

def repeat_along_dim(a: torch.Tensor, dim: int, repeats: int) -> torch.Tensor:
    """
    Repeat a tensor along a given dimension, using ``torch.Tensor.repeat``
    internally.

    Parameters
    ----------
    a
        Original tensor to repeat.
    dim
        Dimension to repeat data over.
    repeats
        How many times to repeat the input tensor.

    Returns
    -------
    torch.Tensor
        A tensor with the same size as the input one, except along dimension
        ``dim``, whose length is multiplied by ``repeats``.
    """
    if repeats == 1:
        return a
    # Repeat once along every dimension except ``dim``.
    r = [1] * len(a.shape)
    r[dim] = repeats
    return a.repeat(*r)
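
# Illustrative sketch (tensor values are assumptions for the example):
#
#     a = torch.arange(6).reshape(2, 3)
#     repeat_along_dim(a, dim=0, repeats=3).shape  # torch.Size([6, 3])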

def slice_along_dim(a: torch.Tensor, dim: int, slice_: slice) -> torch.Tensor:
    """
    Slice a tensor along a given dimension.

    Parameters
    ----------
    a
        Original tensor to slice.
    dim
        Dimension to slice over.
    slice_
        Slice to take.

    Returns
    -------
    torch.Tensor
        A tensor with the same size as the input one, except along dimension
        ``dim``, whose length equals the slice length.
    """
    # Take everything along every dimension except ``dim``; indexing with a
    # tuple avoids the deprecated list-based multidimensional indexing.
    idx = [slice(None)] * len(a.shape)
    idx[dim] = slice_
    return a[tuple(idx)]
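
# Illustrative sketch (tensor values are assumptions for the example):
#
#     a = torch.arange(12).reshape(3, 4)
#     slice_along_dim(a, dim=1, slice_=slice(1, 3)).shape  # torch.Size([3, 2])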

def take_last(a: torch.Tensor, dim: int, num: int) -> torch.Tensor:
    """
    Take the last elements from a given tensor along a given dimension.

    Parameters
    ----------
    a
        Original tensor to slice.
    dim
        Dimension to slice over.
    num
        Number of trailing elements to retain (non-negative).

    Returns
    -------
    torch.Tensor
        A tensor with the same size as the input one, except along dimension
        ``dim``, whose length equals ``num``.
    """
    assert num >= 0, "the number of trailing elements must be non-negative"
    return slice_along_dim(a, dim, slice(a.shape[dim] - num, None))
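
# Illustrative sketch (tensor values are assumptions for the example):
#
#     a = torch.arange(5)  # tensor([0, 1, 2, 3, 4])
#     take_last(a, dim=0, num=2)  # tensor([3, 4])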

def unsqueeze_expand(a: torch.Tensor, dim: int, size: int) -> torch.Tensor:
    """
    Unsqueeze a dimension and expand over it in one go.

    Parameters
    ----------
    a
        Original tensor to unsqueeze.
    dim
        Dimension to unsqueeze.
    size
        Size for the new dimension.

    Returns
    -------
    torch.Tensor
        A tensor with an added dimension ``dim`` of size ``size``.
    """
    # ``expand`` returns a view, so no data is copied for the new dimension.
    a = a.unsqueeze(dim)
    sizes = list(a.shape)
    sizes[dim] = size
    return a.expand(*sizes)
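
# Illustrative sketch (tensor values are assumptions for the example):
#
#     a = torch.zeros(3, 4)
#     unsqueeze_expand(a, dim=1, size=5).shape  # torch.Size([3, 5, 4])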