Source code for gluonts.torch.model.deepar.estimator

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Optional, Iterable, Dict, Any

import torch
from torch.utils.data import DataLoader

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.itertools import Cyclic, PseudoShuffled, IterableSlice
from gluonts.time_feature import (
    TimeFeature,
    time_features_from_frequency_str,
)
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
    Transformation,
    Chain,
    RemoveFields,
    SetField,
    AsNumpyArray,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AddAgeFeature,
    VstackFeatures,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    SelectFields,
)
from gluonts.torch.util import (
    IterableDataset,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
    DistributionOutput,
    StudentTOutput,
)
from gluonts.transform.sampler import InstanceSampler

from .module import DeepARModel
from .lightning_module import DeepARLightningModule

# Names of the fields handed to the network at prediction time; passed as
# ``input_names`` when the predictor is built in ``create_predictor``.
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Training and validation batches carry the prediction inputs plus the
# future target fields; used with ``SelectFields`` in the data loaders.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]


class DeepAREstimator(PyTorchLightningEstimator):
    """
    Estimator for training a DeepAR model, as described in [SFG17]_.

    The network is the one defined in ``DeepARModel``; for training it is
    wrapped into a ``DeepARLightningModule`` and fitted using PyTorch
    Lightning's ``pl.Trainer`` class.

    *Note:* the code of this model is unrelated to the implementation behind
    `SageMaker's DeepAR Forecasting Algorithm
    <https://docs.aws.amazon.com/sagemaker/latest/dg/deepar.html>`_.

    Parameters
    ----------
    freq
        Frequency of the data to train on and predict.
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of steps to unroll the RNN for before computing predictions
        (default: None, in which case context_length = prediction_length).
    num_layers
        Number of RNN layers (default: 2).
    hidden_size
        Number of RNN cells for each layer (default: 40).
    dropout_rate
        Dropout regularization parameter (default: 0.1).
    num_feat_dynamic_real
        Number of dynamic real features in the data (default: 0).
    num_feat_static_cat
        Number of static categorical features in the data (default: 0).
    num_feat_static_real
        Number of static real features in the data (default: 0).
    cardinality
        Number of values of each categorical feature.
        This must be set if ``num_feat_static_cat > 0`` (default: None).
        When it is not given, or when ``num_feat_static_cat`` is not
        positive, ``[1]`` is used instead.
    embedding_dimension
        Dimension of the embeddings for categorical features
        (default: ``[min(50, (cat+1)//2) for cat in cardinality]``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: StudentTOutput()).
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    scaling
        Whether to automatically scale the target values (default: True).
    lags_seq
        Indices of the lagged target values to use as inputs of the RNN
        (default: None, in which case these are automatically determined
        based on freq).
    time_features
        List of time features, from :py:mod:`gluonts.time_feature`, to use
        as inputs of the RNN in addition to the provided data (default:
        None, in which case these are automatically determined based on
        freq).
    num_parallel_samples
        Number of samples per time series that the resulting predictor
        should produce (default: 100).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
        They are merged over the defaults ``max_epochs=100`` and
        ``gradient_clip_val=10.0``.
    train_sampler
        Controls the sampling of windows during training (default: an
        ``ExpectedNumInstanceSampler`` with ``num_instances=1.0`` and
        ``min_future=prediction_length``).
    validation_sampler
        Controls the sampling of windows during validation (default: a
        ``ValidationSplitSampler`` with ``min_future=prediction_length``).
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        num_layers: int = 2,
        hidden_size: int = 40,
        dropout_rate: float = 0.1,
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        scaling: bool = True,
        lags_seq: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        # Entries supplied by the caller override the defaults listed first.
        merged_trainer_kwargs = {
            "max_epochs": 100,
            "gradient_clip_val": 10.0,
            **(trainer_kwargs or {}),
        }
        super().__init__(trainer_kwargs=merged_trainer_kwargs)

        # Horizon and data frequency.
        self.freq = freq
        self.prediction_length = prediction_length
        if context_length is None:
            self.context_length = prediction_length
        else:
            self.context_length = context_length

        # Network hyper-parameters.
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.scaling = scaling
        self.lags_seq = lags_seq
        self.distr_output = distr_output
        self.loss = loss

        # Feature configuration.
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        if cardinality and num_feat_static_cat > 0:
            self.cardinality = cardinality
        else:
            # Single category, matching the constant placeholder feature
            # that ``create_transformation`` sets when there are no static
            # categorical features.
            self.cardinality = [1]
        self.embedding_dimension = embedding_dimension
        if time_features is None:
            self.time_features = time_features_from_frequency_str(freq)
        else:
            self.time_features = time_features

        # Batching and sampling.
        self.num_parallel_samples = num_parallel_samples
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch
        self.train_sampler = (
            train_sampler
            if train_sampler
            else ExpectedNumInstanceSampler(
                num_instances=1.0, min_future=prediction_length
            )
        )
        self.validation_sampler = (
            validation_sampler
            if validation_sampler
            else ValidationSplitSampler(min_future=prediction_length)
        )

    def create_transformation(self) -> Transformation:
        """
        Build the transformation chain applied to every dataset entry.

        The chain removes unused real-valued feature fields, fills in
        placeholder static features where none are configured, converts
        the static features and the target to numpy arrays, adds the
        observed-values indicator, time features and age feature, and
        finally stacks the time-indexed features into
        ``FieldName.FEAT_TIME``.

        Returns
        -------
        Transformation
            A ``Chain`` of the steps described above.
        """
        unused_fields = [
            field_name
            for field_name, count in (
                (FieldName.FEAT_STATIC_REAL, self.num_feat_static_real),
                (FieldName.FEAT_DYNAMIC_REAL, self.num_feat_dynamic_real),
            )
            if count == 0
        ]
        steps: List[Transformation] = [
            RemoveFields(field_names=unused_fields)
        ]

        # Without configured static features, use constant placeholders so
        # that the fields converted below always exist.
        if self.num_feat_static_cat <= 0:
            steps.append(
                SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
            )
        if self.num_feat_static_real <= 0:
            steps.append(
                SetField(
                    output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
                )
            )

        # Fields stacked into FEAT_TIME at the end of the chain.
        stacked_fields = [FieldName.FEAT_TIME, FieldName.FEAT_AGE]
        if self.num_feat_dynamic_real > 0:
            stacked_fields.append(FieldName.FEAT_DYNAMIC_REAL)

        # in the following line, we add 1 for the time dimension
        target_ndim = 1 + len(self.distr_output.event_shape)

        steps.extend(
            [
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_CAT,
                    expected_ndim=1,
                    dtype=int,
                ),
                AsNumpyArray(
                    field=FieldName.FEAT_STATIC_REAL,
                    expected_ndim=1,
                ),
                AsNumpyArray(
                    field=FieldName.TARGET,
                    expected_ndim=target_ndim,
                ),
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
                AddAgeFeature(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_AGE,
                    pred_length=self.prediction_length,
                    log_scale=True,
                ),
                VstackFeatures(
                    output_field=FieldName.FEAT_TIME,
                    input_fields=stacked_fields,
                ),
            ]
        )
        return Chain(steps)

    def _create_instance_splitter(
        self, module: DeepARLightningModule, mode: str
    ):
        """
        Build the ``InstanceSplitter`` for the given mode.

        Parameters
        ----------
        module
            Lightning module whose ``model._past_length`` determines the
            length of the past window.
        mode
            One of ``"training"``, ``"validation"`` or ``"test"``; selects
            the train sampler, the validation sampler or a fresh
            ``TestSplitSampler`` respectively.

        Returns
        -------
        InstanceSplitter
            Splitter producing past/future windows of the target, the time
            features and the observed-values indicator.
        """
        assert mode in ("training", "validation", "test")

        samplers_by_mode = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }
        chosen_sampler = samplers_by_mode[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=chosen_sampler,
            past_length=module.model._past_length,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_TIME,
                FieldName.OBSERVED_VALUES,
            ],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: DeepARLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """
        Create the iterable of training batches.

        Parameters
        ----------
        data
            Dataset to draw training windows from; it is wrapped in
            ``Cyclic`` before the splitter is applied.
        module
            Lightning module used to configure the instance splitter.
        shuffle_buffer_length
            If given, the cycled data is additionally wrapped in
            ``PseudoShuffled`` with this buffer length.
        **kwargs
            Forwarded to ``torch.utils.data.DataLoader``.

        Returns
        -------
        Iterable
            An ``IterableSlice`` of length ``num_batches_per_epoch`` over a
            single iterator of the underlying ``DataLoader``.
        """
        pipeline = self._create_instance_splitter(
            module, "training"
        ) + SelectFields(TRAINING_INPUT_NAMES)

        source = Cyclic(data)
        if shuffle_buffer_length is not None:
            source = PseudoShuffled(
                source, shuffle_buffer_length=shuffle_buffer_length
            )
        training_instances = pipeline.apply(source)

        loader = DataLoader(
            IterableDataset(training_instances),
            batch_size=self.batch_size,
            **kwargs,
        )
        # One iterator is created here and handed to the slice; presumably
        # so that successive epochs continue the stream rather than restart.
        return IterableSlice(iter(loader), self.num_batches_per_epoch)

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: DeepARLightningModule,
        **kwargs,
    ) -> Iterable:
        """
        Create the data loader of validation batches.

        Parameters
        ----------
        data
            Dataset to draw validation windows from.
        module
            Lightning module used to configure the instance splitter.
        **kwargs
            Forwarded to ``torch.utils.data.DataLoader``.

        Returns
        -------
        Iterable
            A ``DataLoader`` over the validation instances, batched with
            ``batch_size``.
        """
        pipeline = self._create_instance_splitter(
            module, "validation"
        ) + SelectFields(TRAINING_INPUT_NAMES)
        validation_instances = pipeline.apply(data)

        return DataLoader(
            IterableDataset(validation_instances),
            batch_size=self.batch_size,
            **kwargs,
        )

    def create_lightning_module(self) -> DeepARLightningModule:
        """
        Instantiate the DeepAR network and wrap it for training.

        Returns
        -------
        DeepARLightningModule
            Lightning module holding a freshly built ``DeepARModel`` and
            the configured loss.
        """
        # Time features plus the user-provided dynamic reals plus one more;
        # the extra one presumably accounts for the age feature that
        # ``create_transformation`` stacks into FEAT_TIME.
        total_dynamic_real = (
            1 + self.num_feat_dynamic_real + len(self.time_features)
        )
        # At least one static feature of each kind is always declared,
        # since ``create_transformation`` sets placeholder values when none
        # are configured.
        total_static_real = max(1, self.num_feat_static_real)
        total_static_cat = max(1, self.num_feat_static_cat)

        network = DeepARModel(
            freq=self.freq,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            num_feat_dynamic_real=total_dynamic_real,
            num_feat_static_real=total_static_real,
            num_feat_static_cat=total_static_cat,
            cardinality=self.cardinality,
            embedding_dimension=self.embedding_dimension,
            num_layers=self.num_layers,
            hidden_size=self.hidden_size,
            distr_output=self.distr_output,
            dropout_rate=self.dropout_rate,
            lags_seq=self.lags_seq,
            scaling=self.scaling,
            num_parallel_samples=self.num_parallel_samples,
        )
        return DeepARLightningModule(model=network, loss=self.loss)

    def create_predictor(
        self,
        transformation: Transformation,
        module: DeepARLightningModule,
    ) -> PyTorchPredictor:
        """
        Build a predictor from a trained lightning module.

        Parameters
        ----------
        transformation
            Transformation to apply to input data; the test-mode instance
            splitter is appended to it.
        module
            Lightning module whose ``model`` is used as prediction network.

        Returns
        -------
        PyTorchPredictor
            Predictor placed on CUDA when ``torch.cuda.is_available()``,
            otherwise on the CPU.
        """
        test_splitter = self._create_instance_splitter(module, "test")
        device_name = "cuda" if torch.cuda.is_available() else "cpu"

        return PyTorchPredictor(
            input_transform=transformation + test_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module.model,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device(device_name),
        )