# Source code for gluonts.torch.model.deepar.estimator

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Optional, Iterable, Dict, Any

import torch

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.time_feature import (
    TimeFeature,
    time_features_from_frequency_str,
)
from gluonts.transform import (
    Transformation,
    Chain,
    RemoveFields,
    SetField,
    AsNumpyArray,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AddAgeFeature,
    VstackFeatures,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    MissingValueImputation,
    DummyValueImputation,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.transform.sampler import InstanceSampler

from .lightning_module import DeepARLightningModule

# Batch fields handed to the network when producing forecasts
# (used as ``input_names`` of the predictor).
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Batch fields used by the training and validation data loaders: every
# prediction input followed by ``future_target`` and
# ``future_observed_values``.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]


class DeepAREstimator(PyTorchLightningEstimator):
    """
    Estimator class to train a DeepAR model, as described in [SFG17]_.

    This class uses the model defined in ``DeepARModel``, and wraps it
    into a ``DeepARLightningModule`` for training purposes: training is
    performed using PyTorch Lightning's ``pl.Trainer`` class.

    *Note:* the code of this model is unrelated to the implementation behind
    `SageMaker's DeepAR Forecasting Algorithm
    <https://docs.aws.amazon.com/sagemaker/latest/dg/deepar.html>`_.

    Parameters
    ----------
    freq
        Frequency of the data to train on and predict.
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of steps to unroll the RNN for before computing predictions
        (default: None, in which case context_length = prediction_length).
    num_layers
        Number of RNN layers (default: 2).
    hidden_size
        Number of RNN cells for each layer (default: 40).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    dropout_rate
        Dropout regularization parameter (default: 0.1).
    patience
        Patience parameter for learning rate scheduler (default: 10).
    num_feat_dynamic_real
        Number of dynamic real features in the data (default: 0).
    num_feat_static_cat
        Number of static categorical features in the data (default: 0).
    num_feat_static_real
        Number of static real features in the data (default: 0).
    cardinality
        Number of values of each categorical feature.
        This must be set if ``num_feat_static_cat > 0`` (default: None).
    embedding_dimension
        Dimension of the embeddings for categorical features
        (default: ``[min(50, (cat+1)//2) for cat in cardinality]``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: StudentTOutput()).
    scaling
        Whether to automatically scale the target values (default: True).
    default_scale
        Default scale that is applied if the context length window is
        completely unobserved. If not set, the scale in this case will be
        the mean scale in the batch.
    lags_seq
        Indices of the lagged target values to use as inputs of the RNN
        (default: None, in which case these are automatically determined
        based on freq).
    time_features
        List of time features, from :py:mod:`gluonts.time_feature`, to use as
        inputs of the RNN in addition to the provided data (default: None,
        in which case these are automatically determined based on freq).
    num_parallel_samples
        Number of samples per time series that the resulting predictor
        should produce (default: 100).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    imputation_method
        Imputation passed to ``AddObservedValuesIndicator`` to fill missing
        target values (default: None, in which case a
        ``DummyValueImputation`` with ``distr_output.value_in_support`` is
        used).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
        They are applied on top of the defaults ``max_epochs=100`` and
        ``gradient_clip_val=10.0``.
    train_sampler
        Controls the sampling of windows during training (default: None,
        in which case an ``ExpectedNumInstanceSampler`` with
        ``num_instances=1.0`` and ``min_future=prediction_length`` is used).
    validation_sampler
        Controls the sampling of windows during validation (default: None,
        in which case a ``ValidationSplitSampler`` with
        ``min_future=prediction_length`` is used).
    nonnegative_pred_samples
        Should final prediction samples be non-negative? If yes, an
        activation function is applied to ensure non-negative. Observe that
        this is applied only to the final samples and this is not applied
        during training.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        num_layers: int = 2,
        hidden_size: int = 40,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        dropout_rate: float = 0.1,
        patience: int = 10,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        scaling: bool = True,
        default_scale: Optional[float] = None,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        imputation_method: Optional[MissingValueImputation] = None,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
        nonnegative_pred_samples: bool = False,
    ) -> None:
        # Caller-supplied trainer options take precedence over these.
        default_trainer_kwargs = {
            "max_epochs": 100,
            "gradient_clip_val": 10.0,
        }
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.freq = freq
        self.context_length = (
            context_length if context_length is not None else prediction_length
        )
        self.prediction_length = prediction_length
        self.patience = patience
        self.distr_output = distr_output
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lr = lr
        self.weight_decay = weight_decay
        self.dropout_rate = dropout_rate
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        # Without static categorical features a single dummy feature of
        # cardinality 1 is used; create_transformation sets its value to [0].
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.scaling = scaling
        self.default_scale = default_scale
        self.lags_seq = lags_seq
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        self.num_parallel_samples = num_parallel_samples
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.imputation_method = (
            imputation_method
            if imputation_method is not None
            else DummyValueImputation(self.distr_output.value_in_support)
        )

        # Use explicit ``is not None`` checks, like the other optional
        # arguments above, instead of relying on sampler truthiness.
        self.train_sampler = (
            train_sampler
            if train_sampler is not None
            else ExpectedNumInstanceSampler(
                num_instances=1.0, min_future=prediction_length
            )
        )
        self.validation_sampler = (
            validation_sampler
            if validation_sampler is not None
            else ValidationSplitSampler(min_future=prediction_length)
        )
        self.nonnegative_pred_samples = nonnegative_pred_samples

    @classmethod
    def derive_auto_fields(cls, train_iter):
        """
        Derive feature counts and cardinalities from the training data.

        Returns a dict with ``num_feat_dynamic_real``,
        ``num_feat_static_cat`` and ``cardinality`` computed by
        ``calculate_dataset_statistics``.
        """
        stats = calculate_dataset_statistics(train_iter)

        return {
            "num_feat_dynamic_real": stats.num_feat_dynamic_real,
            "num_feat_static_cat": len(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def create_transformation(self) -> Transformation:
        """
        Build the per-series preprocessing chain.

        Unused feature fields are removed. Dummy static features are
        inserted when none are configured. The fields are converted to
        numpy arrays, and observed-value, time and age features are added.
        The time, age and optional dynamic real features are stacked into
        ``FieldName.FEAT_TIME``.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        return Chain(
            [RemoveFields(field_names=remove_field_names)]
            + (
                # Dummy categorical feature matching cardinality [1].
                [SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])]
                if not self.num_feat_static_cat > 0
                else []
            )
            + (
                # Dummy real feature so the model always sees one column.
                [
                    SetField(
                        output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
                    )
                ]
                if not self.num_feat_static_real > 0
                else []
            )
            + [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    dtype=int,
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    # in the following line, we add 1 for the time dimension
                    expected_ndim=1 + len(self.distr_output.event_shape),
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                    imputation_method=self.imputation_method,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                    + (
                        [FieldName.FEAT_DYNAMIC_REAL]
                        if self.num_feat_dynamic_real > 0
                        else []
                    ),
                ),
                AsNumpyArray(FieldName.FEAT_TIME, expected_ndim=2),
            ]
        )

    def _create_instance_splitter(
        self, module: DeepARLightningModule, mode: str
    ):
        """
        Return an ``InstanceSplitter`` for ``mode``.

        ``mode`` must be one of "training", "validation" or "test". It
        selects the train sampler, the validation sampler or a
        ``TestSplitSampler``, respectively.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the supported values.
        """
        samplers = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }
        # Explicit check instead of ``assert``, which is stripped under -O.
        if mode not in samplers:
            raise ValueError(
                f"mode must be one of {list(samplers)}, got {mode!r}"
            )
        instance_sampler = samplers[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=module.model._past_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: DeepARLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """
        Return batches of training windows.

        The dataset is cycled endlessly with ``Cyclic(...).stream()``, so
        epoch length is set by ``num_batches_per_epoch``. Extra ``kwargs``
        are accepted for interface compatibility and ignored.
        """
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: DeepARLightningModule,
        **kwargs,
    ) -> Iterable:
        """
        Return batches of validation windows.

        Each series is passed once through the validation splitter.
        Extra ``kwargs`` are accepted for interface compatibility and
        ignored.
        """
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
        )

    def create_lightning_module(self) -> DeepARLightningModule:
        """Build the Lightning module wrapping a DeepAR model."""
        return DeepARLightningModule(
            lr=self.lr,
            weight_decay=self.weight_decay,
            patience=self.patience,
            model_kwargs={
                "freq": self.freq,
                "context_length": self.context_length,
                "prediction_length": self.prediction_length,
                # Stacked into FEAT_TIME by create_transformation: the age
                # feature (the "1"), user dynamic reals, and time features.
                "num_feat_dynamic_real": (
                    1 + self.num_feat_dynamic_real + len(self.time_features)
                ),
                # At least 1: a dummy static feature is inserted when none
                # are configured.
                "num_feat_static_real": max(1, self.num_feat_static_real),
                "num_feat_static_cat": max(1, self.num_feat_static_cat),
                "cardinality": self.cardinality,
                "embedding_dimension": self.embedding_dimension,
                "num_layers": self.num_layers,
                "hidden_size": self.hidden_size,
                "distr_output": self.distr_output,
                "dropout_rate": self.dropout_rate,
                "lags_seq": self.lags_seq,
                "scaling": self.scaling,
                "default_scale": self.default_scale,
                "num_parallel_samples": self.num_parallel_samples,
                "nonnegative_pred_samples": self.nonnegative_pred_samples,
            },
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module: DeepARLightningModule,
    ) -> PyTorchPredictor:
        """
        Wrap ``module`` into a ``PyTorchPredictor``.

        The predictor applies ``transformation`` followed by the test-mode
        instance splitter.
        """
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device="auto",
        )