# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from functools import partial
from typing import Optional, Callable, List, Union
import torch
from torch import nn
from torch.distributions import (
Categorical,
MixtureSameFamily,
Normal,
)
from gluonts.core.component import validated
from gluonts.torch.distributions import DiscreteDistribution
from .scaling import (
min_max_scaling,
standard_normal_scaling,
)


INPUT_SCALING_MAP = {
"min_max_scaling": partial(min_max_scaling, dim=1, keepdim=True),
"standard_normal_scaling": partial(
standard_normal_scaling, dim=1, keepdim=True
),
}
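

# A minimal sketch (illustrative only, not part of the library) of the
# contract the scaling callables above are assumed to satisfy, inferred
# from their use in `DeepNPTSNetwork.forward` below: given a batch of
# series of shape (batch, context_length), they return `loc` and `scale`
# tensors that broadcast against the input.
def _demo_input_scaling():
    x = torch.arange(20.0).reshape(2, 10)
    loc, scale = INPUT_SCALING_MAP["min_max_scaling"](x)
    x_scaled = (x - loc) / scale  # same normalization as in `forward`
    assert x_scaled.shape == x.shape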


def init_weights(module: nn.Module, scale: float = 1.0):
    if isinstance(module, nn.Linear):
        nn.init.uniform_(module.weight, -scale, scale)
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        # `nn.Embedding` has no bias; only its weight is initialized. Without
        # this branch, applying `init_weights` to the embedders in
        # `FeatureEmbedder` below would be a no-op.
        nn.init.uniform_(module.weight, -scale, scale)


class FeatureEmbedder(nn.Module):
"""
Creates a feature embedding for the static categorical features.
"""
@validated()
def __init__(
self,
cardinalities: List[int],
embedding_dimensions: List[int],
):
super().__init__()
assert (
len(cardinalities) > 0
), "Length of `cardinalities` list must be greater than zero"
        assert len(cardinalities) == len(
            embedding_dimensions
        ), "Lengths of `cardinalities` and `embedding_dimensions` should match"
        assert all(
            c > 0 for c in cardinalities
        ), "Elements of `cardinalities` should be > 0"
        assert all(
            d > 0 for d in embedding_dimensions
        ), "Elements of `embedding_dimensions` should be > 0"
        # Use `nn.ModuleList` so that the embedding weights are registered as
        # parameters (a plain Python list would not be trained or moved to
        # the right device).
        self.embedders = nn.ModuleList(
            [
                nn.Embedding(num_embeddings=card, embedding_dim=dim)
                for card, dim in zip(cardinalities, embedding_dimensions)
            ]
        )
for embedder in self.embedders:
embedder.apply(init_weights)

    def forward(self, features: torch.Tensor):
"""
Parameters
----------
features
            Input features to the model, shape: (-1, num_features).

        Returns
        -------
        torch.Tensor
            Embedding, shape: (-1, sum(embedding_dimensions)).
"""
embedded_features = torch.cat(
[
embedder(features[:, i].long())
for i, embedder in enumerate(self.embedders)
],
dim=-1,
)
return embedded_features
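

# Illustrative usage sketch of `FeatureEmbedder` (toy values, not part of
# the library): two categorical features with cardinalities 5 and 3 are
# embedded into 4- and 2-dimensional vectors and concatenated.
def _demo_feature_embedder():
    embedder = FeatureEmbedder(
        cardinalities=[5, 3], embedding_dimensions=[4, 2]
    )
    features = torch.tensor([[0, 2], [4, 1]])  # (batch=2, num_features=2)
    assert embedder(features).shape == (2, 4 + 2)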


class DeepNPTSNetwork(nn.Module):
"""
    Base class implementing a simple feed-forward neural network that takes
    in static and dynamic features and produces `num_hidden_nodes[-1]`
    independent outputs. These outputs are then used by derived classes to
    construct the forecast distribution for a single time step.

    Note that the dynamic features are treated as plain, independent
    features, without regard to their temporal ordering.
"""
@validated()
def __init__(
self,
context_length: int,
num_hidden_nodes: List[int],
cardinality: List[int],
embedding_dimension: List[int],
num_time_features: int,
batch_norm: bool = False,
input_scaling: Optional[Union[Callable, str]] = None,
dropout_rate: float = 0.0,
):
super().__init__()
self.context_length = context_length
self.num_hidden_nodes = num_hidden_nodes
self.batch_norm = batch_norm
self.input_scaling = (
INPUT_SCALING_MAP[input_scaling]
if isinstance(input_scaling, str)
else input_scaling
)
self.dropout_rate = dropout_rate
# Embedding for categorical features
self.embedder = FeatureEmbedder(
cardinalities=cardinality, embedding_dimensions=embedding_dimension
)
total_embedding_dim = sum(embedding_dimension)
        # There are two target-related features, `past_target` and the
        # observed-values indicator, each of length `context_length`;
        # the additional +1 is for the single static real feature.
dimensions = [
context_length * (num_time_features + 2) + total_embedding_dim + 1
] + num_hidden_nodes
modules: List[nn.Module] = []
for in_features, out_features in zip(dimensions[:-1], dimensions[1:]):
modules += [nn.Linear(in_features, out_features), nn.ReLU()]
if self.batch_norm:
modules.append(nn.BatchNorm1d(out_features))
if self.dropout_rate > 0:
modules.append(nn.Dropout(self.dropout_rate))
self.model = nn.Sequential(*modules)
self.model.apply(partial(init_weights, scale=0.07))
# TODO: Handle missing values using the observed value indicator.

    def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: torch.Tensor,
):
"""
Parameters
----------
feat_static_cat
Shape (-1, num_features).
feat_static_real
Shape (-1, num_features).
past_target
Shape (-1, context_length).
past_observed_values
Shape (-1, context_length).
past_time_feat
Shape (-1, context_length, self.num_time_features).
"""
x = past_target
if self.input_scaling:
loc, scale = self.input_scaling(x)
x_scaled = (x - loc) / scale
else:
x_scaled = x
embedded_cat = self.embedder(feat_static_cat)
static_feat = torch.cat(
            (embedded_cat, feat_static_real),
dim=1,
)
time_features = torch.cat(
[
x_scaled.unsqueeze(dim=-1),
past_observed_values.unsqueeze(dim=-1),
past_time_feat,
],
dim=-1,
)
features = torch.cat(
[
time_features.reshape(time_features.shape[0], -1),
static_feat,
],
dim=-1,
)
return self.model(features)


class DeepNPTSNetworkDiscrete(DeepNPTSNetwork):
"""
    Extends `DeepNPTSNetwork` by implementing the output layer, which
    converts the outputs of the base network into a vector of
    `context_length` probabilities. These probabilities, together with the
    past values in the context window, constitute the one-step-ahead
    forecast distribution.
Specifically, the forecast is always one of the values observed in the
context window with the corresponding predicted probability.

    Parameters
----------
*args
Arguments to ``DeepNPTSNetwork``.
    use_softmax
        Flag indicating whether the outputs of the base network are
        converted to probabilities via softmax, or via softplus followed by
        L1 normalization.
    **kwargs
        Keyword arguments to ``DeepNPTSNetwork``.
"""
@validated()
def __init__(self, *args, use_softmax: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.use_softmax = use_softmax
        modules: List[nn.Module] = (
            [] if self.dropout_rate == 0 else [nn.Dropout(self.dropout_rate)]
        )
modules.append(
nn.Linear(self.num_hidden_nodes[-1], self.context_length)
)
self.output_layer = nn.Sequential(*modules)
self.output_layer.apply(init_weights)

    def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: torch.Tensor,
) -> DiscreteDistribution:
h = super().forward(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
)
outputs = self.output_layer(h)
probs = (
nn.functional.softmax(outputs, dim=1)
if self.use_softmax
else nn.functional.normalize(
nn.functional.softplus(outputs), p=1, dim=1
)
)
return DiscreteDistribution(values=past_target, probs=probs)
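

# Sketch of the two probability mappings used in
# `DeepNPTSNetworkDiscrete.forward` above (toy tensor, illustration only):
# softmax versus softplus followed by L1 normalization. Both map the
# unconstrained network outputs to rows that sum to one.
def _demo_prob_normalization():
    outputs = torch.randn(2, 5)
    p_softmax = nn.functional.softmax(outputs, dim=1)
    p_softplus = nn.functional.normalize(
        nn.functional.softplus(outputs), p=1, dim=1
    )
    assert torch.allclose(p_softmax.sum(dim=1), torch.ones(2))
    assert torch.allclose(p_softplus.sum(dim=1), torch.ones(2))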


class DeepNPTSNetworkSmooth(DeepNPTSNetwork):
"""
    Extends `DeepNPTSNetwork` by implementing the output layer which converts
the outputs from the base network into a smoothed mixture distribution. The
components of the mixture are Gaussians centered around the observations in
the context window. The mixing probabilities as well as the width of the
Gaussians are predicted by the network.

    This mixture distribution represents the one-step-ahead forecast
distribution. Note that the forecast can contain values not observed in the
context window.
"""
@validated()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        modules: List[nn.Module] = (
            [] if self.dropout_rate == 0 else [nn.Dropout(self.dropout_rate)]
        )
modules += [
nn.Linear(self.num_hidden_nodes[-1], self.context_length + 1),
nn.Softplus(),
]
self.output_layer = nn.Sequential(*modules)
self.output_layer.apply(init_weights)

    def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: torch.Tensor,
) -> MixtureSameFamily:
h = super().forward(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
)
outputs = self.output_layer(h)
probs = outputs[:, :-1]
kernel_width = outputs[:, -1:]
mix = Categorical(probs)
components = Normal(loc=past_target, scale=kernel_width)
return MixtureSameFamily(
mixture_distribution=mix, component_distribution=components
)
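

# Sketch of the mixture construction in `DeepNPTSNetworkSmooth.forward`
# above, with toy numbers (illustration only): a batch of 2 context windows
# of length 4, uniform mixing weights, and one kernel width per batch
# element. Samples are drawn near one of the observed values.
def _demo_smoothed_mixture():
    past_target = torch.randn(2, 4)
    probs = torch.ones(2, 4) / 4.0  # `Categorical` normalizes weights anyway
    kernel_width = torch.full((2, 1), 0.1)
    mixture = MixtureSameFamily(
        mixture_distribution=Categorical(probs),
        # loc (2, 4) and scale (2, 1) broadcast to a (2, 4) batch of Normals
        component_distribution=Normal(loc=past_target, scale=kernel_width),
    )
    assert mixture.sample().shape == (2,)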


class DeepNPTSMultiStepNetwork(nn.Module):
"""
Implements multi-step prediction given a trained `DeepNPTSNetwork` model
    that outputs a one-step-ahead forecast distribution.
"""
@validated()
def __init__(
self,
net: DeepNPTSNetwork,
prediction_length: int,
num_parallel_samples: int = 100,
):
super().__init__()
self.net = net
self.prediction_length = prediction_length
self.num_parallel_samples = num_parallel_samples

    def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: torch.Tensor,
future_time_feat: torch.Tensor,
):
"""
Generates samples from the forecast distribution.

        Parameters
----------
feat_static_cat
Shape (-1, num_features).
feat_static_real
Shape (-1, num_features).
past_target
Shape (-1, context_length).
past_observed_values
Shape (-1, context_length).
past_time_feat
Shape (-1, context_length, self.num_time_features).
future_time_feat
Shape (-1, prediction_length, self.num_time_features).

        Returns
-------
torch.Tensor
Tensor containing samples from the predicted distribution.
Shape is (-1, self.num_parallel_samples, self.prediction_length).
"""
        # Expand `past_target` by the number of parallel samples required:
        # (batch_size * num_parallel_samples, context_length).
past_target = past_target.repeat_interleave(
self.num_parallel_samples, dim=0
)
        # GluonTS returns an empty `future_observed_values`, so assume all
        # future values are observed.
        future_observed_values = torch.ones(
            (past_observed_values.shape[0], self.prediction_length),
            device=past_observed_values.device,
        )
observed_values = torch.cat(
[past_observed_values, future_observed_values], dim=1
)
observed_values = observed_values.repeat_interleave(
self.num_parallel_samples, dim=0
)
time_feat = torch.cat([past_time_feat, future_time_feat], dim=1)
time_feat = time_feat.repeat_interleave(
self.num_parallel_samples, dim=0
)
feat_static_cat = feat_static_cat.repeat_interleave(
self.num_parallel_samples, dim=0
)
feat_static_real = feat_static_real.repeat_interleave(
self.num_parallel_samples, dim=0
)
future_samples = []
for t in range(self.prediction_length):
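            # At step t, slide a window of length `context_length` over the
            # concatenated past/future observed values and time features.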
distr = self.net(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_target=past_target,
past_observed_values=observed_values[
:, t : -self.prediction_length + t
],
past_time_feat=time_feat[
:, t : -self.prediction_length + t, :
],
)
samples = distr.sample()
if past_target.dim() != samples.dim():
samples = samples.unsqueeze(dim=-1)
future_samples.append(samples)
past_target = torch.cat([past_target[:, 1:], samples], dim=1)
# (batch_size * num_parallel_samples, prediction_length)
samples_out = torch.stack(future_samples, dim=1)
# (batch_size, num_parallel_samples, prediction_length)
return samples_out.reshape(
-1, self.num_parallel_samples, self.prediction_length
)
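

# End-to-end sketch (toy shapes, illustration only): wire an untrained
# `DeepNPTSNetworkDiscrete` into `DeepNPTSMultiStepNetwork` and draw
# multi-step samples. All sizes below are arbitrary.
def _demo_multi_step_forecast():
    batch, context_length, prediction_length = 4, 10, 3
    num_time_features = 2
    net = DeepNPTSNetworkDiscrete(
        context_length=context_length,
        num_hidden_nodes=[16],
        cardinality=[5],
        embedding_dimension=[3],
        num_time_features=num_time_features,
    )
    multi_step = DeepNPTSMultiStepNetwork(
        net=net,
        prediction_length=prediction_length,
        num_parallel_samples=100,
    )
    samples = multi_step(
        feat_static_cat=torch.zeros(batch, 1),
        feat_static_real=torch.zeros(batch, 1),
        past_target=torch.rand(batch, context_length),
        past_observed_values=torch.ones(batch, context_length),
        past_time_feat=torch.rand(batch, context_length, num_time_features),
        future_time_feat=torch.rand(
            batch, prediction_length, num_time_features
        ),
    )
    # (batch, num_parallel_samples, prediction_length)
    assert samples.shape == (batch, 100, prediction_length)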