# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Optional, Iterable, Dict, Any
import torch
from torch.utils.data import DataLoader
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.itertools import Cyclic, PseudoShuffled, IterableSlice
from gluonts.time_feature import (
TimeFeature,
time_features_from_frequency_str,
)
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
Transformation,
Chain,
RemoveFields,
SetField,
AsNumpyArray,
AddObservedValuesIndicator,
AddTimeFeatures,
AddAgeFeature,
VstackFeatures,
InstanceSplitter,
ValidationSplitSampler,
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
)
from gluonts.torch.util import (
IterableDataset,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
DistributionOutput,
StudentTOutput,
)
from gluonts.transform.sampler import InstanceSampler
from .module import DeepARModel
from .lightning_module import DeepARLightningModule
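

# Input fields consumed by the network. At prediction time only past values
# and known-future covariates are available; training additionally requires
# the future target and its observed-values mask to compute the loss.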
PREDICTION_INPUT_NAMES = [
"feat_static_cat",
"feat_static_real",
"past_time_feat",
"past_target",
"past_observed_values",
"future_time_feat",
]
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
"future_target",
"future_observed_values",
]


class DeepAREstimator(PyTorchLightningEstimator):
"""
Estimator class to train a DeepAR model, as described in [SFG17]_.
    This class uses the model defined in ``DeepARModel``, and wraps it
    into a ``DeepARLightningModule`` for training purposes: training is
    performed using PyTorch Lightning's ``pl.Trainer`` class.
*Note:* the code of this model is unrelated to the implementation behind
`SageMaker's DeepAR Forecasting Algorithm
<https://docs.aws.amazon.com/sagemaker/latest/dg/deepar.html>`_.
Parameters
----------
freq
Frequency of the data to train on and predict.
prediction_length
Length of the prediction horizon.
context_length
Number of steps to unroll the RNN for before computing predictions
(default: None, in which case context_length = prediction_length).
num_layers
Number of RNN layers (default: 2).
hidden_size
Number of RNN cells for each layer (default: 40).
dropout_rate
Dropout regularization parameter (default: 0.1).
num_feat_dynamic_real
Number of dynamic real features in the data (default: 0).
num_feat_static_real
Number of static real features in the data (default: 0).
num_feat_static_cat
Number of static categorical features in the data (default: 0).
cardinality
Number of values of each categorical feature.
This must be set if ``num_feat_static_cat > 0`` (default: None).
embedding_dimension
Dimension of the embeddings for categorical features
(default: ``[min(50, (cat+1)//2) for cat in cardinality]``).
distr_output
Distribution to use to evaluate observations and sample predictions
(default: StudentTOutput()).
loss
Loss to be optimized during training
(default: ``NegativeLogLikelihood()``).
    scaling
        Whether to automatically scale the target values (default: ``True``).
lags_seq
Indices of the lagged target values to use as inputs of the RNN
(default: None, in which case these are automatically determined
based on freq).
time_features
List of time features, from :py:mod:`gluonts.time_feature`, to use as
inputs of the RNN in addition to the provided data (default: None,
in which case these are automatically determined based on freq).
    num_parallel_samples
        Number of samples per time series that the resulting predictor
        should produce (default: 100).
batch_size
The size of the batches to be used for training (default: 32).
num_batches_per_epoch
Number of batches to be processed in each training epoch
(default: 50).
trainer_kwargs
Additional arguments to provide to ``pl.Trainer`` for construction.
train_sampler
Controls the sampling of windows during training.
validation_sampler
Controls the sampling of windows during validation.
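
    Examples
    --------
    A minimal usage sketch; ``dataset`` stands for any GluonTS ``Dataset``
    with hourly frequency, and the hyperparameter values are illustrative:

    >>> from gluonts.torch.model.deepar import DeepAREstimator
    >>> estimator = DeepAREstimator(
    ...     freq="H",
    ...     prediction_length=24,
    ...     trainer_kwargs={"max_epochs": 5},
    ... )
    >>> predictor = estimator.train(dataset)  # doctest: +SKIP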
"""
@validated()
def __init__(
self,
freq: str,
prediction_length: int,
context_length: Optional[int] = None,
num_layers: int = 2,
hidden_size: int = 40,
dropout_rate: float = 0.1,
num_feat_dynamic_real: int = 0,
num_feat_static_cat: int = 0,
num_feat_static_real: int = 0,
cardinality: Optional[List[int]] = None,
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
loss: DistributionLoss = NegativeLogLikelihood(),
scaling: bool = True,
lags_seq: Optional[List[int]] = None,
time_features: Optional[List[TimeFeature]] = None,
num_parallel_samples: int = 100,
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = None,
train_sampler: Optional[InstanceSampler] = None,
validation_sampler: Optional[InstanceSampler] = None,
) -> None:
default_trainer_kwargs = {
"max_epochs": 100,
"gradient_clip_val": 10.0,
}
if trainer_kwargs is not None:
default_trainer_kwargs.update(trainer_kwargs)
super().__init__(trainer_kwargs=default_trainer_kwargs)
self.freq = freq
self.context_length = (
context_length if context_length is not None else prediction_length
)
self.prediction_length = prediction_length
self.distr_output = distr_output
self.loss = loss
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_rate = dropout_rate
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
self.cardinality = (
cardinality if cardinality and num_feat_static_cat > 0 else [1]
)
self.embedding_dimension = embedding_dimension
self.scaling = scaling
self.lags_seq = lags_seq
self.time_features = (
time_features
if time_features is not None
else time_features_from_frequency_str(self.freq)
)
self.num_parallel_samples = num_parallel_samples
self.batch_size = batch_size
self.num_batches_per_epoch = num_batches_per_epoch
self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
num_instances=1.0, min_future=prediction_length
)
self.validation_sampler = validation_sampler or ValidationSplitSampler(
min_future=prediction_length
)

    def _create_instance_splitter(
self, module: DeepARLightningModule, mode: str
):
assert mode in ["training", "validation", "test"]
instance_sampler = {
"training": self.train_sampler,
"validation": self.validation_sampler,
"test": TestSplitSampler(),
}[mode]
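        # Cut fixed-size windows out of each series: the mode-specific
        # sampler chooses the split points, and the time-dependent fields
        # listed below are divided into "past_" and "future_" parts
        # alongside the target.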
return InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
instance_sampler=instance_sampler,
past_length=module.model._past_length,
future_length=self.prediction_length,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
dummy_value=self.distr_output.value_in_support,
)

    def create_training_data_loader(
self,
data: Dataset,
module: DeepARLightningModule,
shuffle_buffer_length: Optional[int] = None,
**kwargs,
) -> Iterable:
transformation = self._create_instance_splitter(
module, "training"
) + SelectFields(TRAINING_INPUT_NAMES)
training_instances = transformation.apply(
Cyclic(data)
if shuffle_buffer_length is None
else PseudoShuffled(
Cyclic(data), shuffle_buffer_length=shuffle_buffer_length
)
)
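        # The Cyclic stream above is infinite; IterableSlice caps each
        # training epoch at num_batches_per_epoch batches.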
return IterableSlice(
iter(
DataLoader(
IterableDataset(training_instances),
batch_size=self.batch_size,
**kwargs,
)
),
self.num_batches_per_epoch,
)

    def create_validation_data_loader(
self,
data: Dataset,
module: DeepARLightningModule,
**kwargs,
) -> Iterable:
transformation = self._create_instance_splitter(
module, "validation"
) + SelectFields(TRAINING_INPUT_NAMES)
validation_instances = transformation.apply(data)
return DataLoader(
IterableDataset(validation_instances),
batch_size=self.batch_size,
**kwargs,
)

    def create_lightning_module(self) -> DeepARLightningModule:
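        # Feature counts are adjusted for the transformation pipeline: the
        # extra 1 dynamic real feature accounts for the age feature added
        # during transformation, and max(1, ...) reserves a dummy static
        # feature when none are provided.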
model = DeepARModel(
freq=self.freq,
context_length=self.context_length,
prediction_length=self.prediction_length,
num_feat_dynamic_real=(
1 + self.num_feat_dynamic_real + len(self.time_features)
),
num_feat_static_real=max(1, self.num_feat_static_real),
num_feat_static_cat=max(1, self.num_feat_static_cat),
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
num_layers=self.num_layers,
hidden_size=self.hidden_size,
distr_output=self.distr_output,
dropout_rate=self.dropout_rate,
lags_seq=self.lags_seq,
scaling=self.scaling,
num_parallel_samples=self.num_parallel_samples,
)
return DeepARLightningModule(model=model, loss=self.loss)

    def create_predictor(
self,
transformation: Transformation,
module: DeepARLightningModule,
) -> PyTorchPredictor:
prediction_splitter = self._create_instance_splitter(module, "test")
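        # Bundle the trained network with the full input pipeline
        # (transformation + test-mode splitting); inference runs on the GPU
        # when one is available, otherwise on the CPU.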
return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module.model,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
device=torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
),
)