Source code for gluonts.torch.model.tft.estimator

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import torch

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.distributions import Output, QuantileOutput
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import (
    AddConstFeature,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AsNumpyArray,
    Chain,
    ExpectedNumInstanceSampler,
    RemoveFields,
    SetField,
    TestSplitSampler,
    Transformation,
    ValidationSplitSampler,
    VstackFeatures,
)
from gluonts.transform.sampler import InstanceSampler
from gluonts.transform.split import TFTInstanceSplitter

from .lightning_module import TemporalFusionTransformerLightningModule

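# Tensor fields consumed by the network at prediction time.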
PREDICTION_INPUT_NAMES = [
    "past_target",
    "past_observed_values",
    "feat_static_real",
    "feat_static_cat",
    "feat_dynamic_real",
    "feat_dynamic_cat",
    "past_feat_dynamic_real",
    "past_feat_dynamic_cat",
]

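# Training additionally requires the future target and its observed-values
# mask, which are used when computing the loss.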
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
    "future_target",
    "future_observed_values",
]


class TemporalFusionTransformerEstimator(PyTorchLightningEstimator):
    """
    Estimator class to train a Temporal Fusion Transformer (TFT) model, as
    described in [LAL+21]_.

    TFT internally performs feature selection when making forecasts. For this
    reason, the dimensions of real-valued features can be grouped together if
    they correspond to the same variable (e.g., treat weather features as one
    feature and holiday indicators as another feature).

    For example, if the dataset contains key "feat_static_real" with shape
    [batch_size, 3], we can, e.g.,

    - set ``static_dims = [3]`` to treat all three dimensions as a single
      feature
    - set ``static_dims = [1, 1, 1]`` to treat each dimension as a separate
      feature
    - set ``static_dims = [2, 1]`` to treat the first two dims as a single
      feature

    See ``gluonts.torch.model.tft.TemporalFusionTransformerModel.input_shapes``
    for more details on how the model configuration corresponds to the
    expected input shapes.

    Parameters
    ----------
    freq
        Frequency of the data to train on and predict.
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of previous time series values provided as input to the
        encoder (default: None, in which case context_length =
        prediction_length).
    quantiles
        List of quantiles that the model will learn to predict.
        Defaults to [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].
    distr_output
        Distribution output to use (default: ``QuantileOutput``).
    num_heads
        Number of attention heads in the self-attention layer of the decoder.
    hidden_dim
        Size of the LSTM and transformer hidden states.
    variable_dim
        Size of the feature embeddings.
    static_dims
        Sizes of the real-valued static features.
    dynamic_dims
        Sizes of the real-valued dynamic features that are known in the
        future.
    past_dynamic_dims
        Sizes of the real-valued dynamic features that are only known in the
        past.
    static_cardinalities
        Cardinalities of the categorical static features.
    dynamic_cardinalities
        Cardinalities of the categorical dynamic features that are known in
        the future.
    past_dynamic_cardinalities
        Cardinalities of the categorical dynamic features that are only known
        in the past.
    time_features
        List of time features, from :py:mod:`gluonts.time_feature`, to use as
        dynamic real features in addition to the provided data (default:
        None, in which case these are automatically determined based on
        freq).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay (default: ``1e-8``).
    dropout_rate
        Dropout regularization parameter (default: 0.1).
    patience
        Patience parameter for learning rate scheduler.
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
    train_sampler
        Controls the sampling of windows during training.
    validation_sampler
        Controls the sampling of windows during validation.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        quantiles: Optional[List[float]] = None,
        distr_output: Optional[Output] = None,
        num_heads: int = 4,
        hidden_dim: int = 32,
        variable_dim: int = 32,
        static_dims: Optional[List[int]] = None,
        dynamic_dims: Optional[List[int]] = None,
        past_dynamic_dims: Optional[List[int]] = None,
        static_cardinalities: Optional[List[int]] = None,
        dynamic_cardinalities: Optional[List[int]] = None,
        past_dynamic_cardinalities: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        dropout_rate: float = 0.1,
        patience: int = 10,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        default_trainer_kwargs = {
            "max_epochs": 100,
            "gradient_clip_val": 10.0,
        }
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.freq = freq
        self.prediction_length = prediction_length
        self.context_length = (
            context_length
            if context_length is not None
            else prediction_length
        )

        # Model architecture
        if distr_output is not None and quantiles is not None:
            raise ValueError(
                "Only one of `distr_output` and `quantiles` must be specified"
            )
        elif distr_output is not None:
            self.distr_output = distr_output
        else:
            if quantiles is None:
                quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
            self.distr_output = QuantileOutput(quantiles=quantiles)

        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.variable_dim = variable_dim
        self.static_dims = static_dims or []
        self.dynamic_dims = dynamic_dims or []
        self.past_dynamic_dims = past_dynamic_dims or []
        self.static_cardinalities = static_cardinalities or []
        self.dynamic_cardinalities = dynamic_cardinalities or []
        self.past_dynamic_cardinalities = past_dynamic_cardinalities or []

        if time_features is None:
            time_features = time_features_from_frequency_str(self.freq)
        self.time_features = time_features

        # Training procedure
        self.lr = lr
        self.weight_decay = weight_decay
        self.dropout_rate = dropout_rate
        self.patience = patience
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )
    def create_transformation(self) -> Transformation:
        remove_field_names = []
        if not self.static_dims:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if not self.dynamic_dims:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
        if not self.past_dynamic_dims:
            remove_field_names.append(FieldName.PAST_FEAT_DYNAMIC_REAL)
        if not self.static_cardinalities:
            remove_field_names.append(FieldName.FEAT_STATIC_CAT)
        if not self.dynamic_cardinalities:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_CAT)
        if not self.past_dynamic_cardinalities:
            remove_field_names.append(FieldName.PAST_FEAT_DYNAMIC_CAT)

        transforms = [
            RemoveFields(field_names=remove_field_names),
            AsNumpyArray(field=FieldName.TARGET, expected_ndim=1),
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            ),
        ]

        if len(self.time_features) > 0:
            transforms.append(
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                )
            )
        else:
            # Add dummy dynamic feature if no time features are available
            transforms.append(
                AddConstFeature(
                    output_field=FieldName.FEAT_TIME,
                    target_field=FieldName.TARGET,
                    pred_length=self.prediction_length,
                    const=0.0,
                )
            )

        # Provide dummy values if static features are missing
        if not self.static_dims:
            transforms.append(
                SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])
            )
        transforms.append(
            AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1)
        )
        if not self.static_cardinalities:
            transforms.append(
                SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
            )
        transforms.append(
            AsNumpyArray(
                field=FieldName.FEAT_STATIC_CAT,
                expected_ndim=1,
                dtype=np.int64,
            )
        )

        # Concat time features with known dynamic features
        input_fields = [FieldName.FEAT_TIME]
        if self.dynamic_dims:
            input_fields += [FieldName.FEAT_DYNAMIC_REAL]
        transforms.append(
            VstackFeatures(
                input_fields=input_fields,
                output_field=FieldName.FEAT_DYNAMIC_REAL,
            )
        )

        return Chain(transforms)
    def _create_instance_splitter(self, mode: str):
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        ts_fields = [FieldName.FEAT_DYNAMIC_REAL]
        if self.dynamic_cardinalities:
            ts_fields.append(FieldName.FEAT_DYNAMIC_CAT)
        past_ts_fields = []
        if self.past_dynamic_cardinalities:
            past_ts_fields.append(FieldName.PAST_FEAT_DYNAMIC_CAT)
        if self.past_dynamic_dims:
            past_ts_fields.append(FieldName.PAST_FEAT_DYNAMIC_REAL)

        return TFTInstanceSplitter(
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=ts_fields,
            past_time_series_fields=past_ts_fields,
        )
    def input_names(self):
        input_names = list(TRAINING_INPUT_NAMES)
        if not self.dynamic_cardinalities:
            input_names.remove("feat_dynamic_cat")
        if not self.past_dynamic_cardinalities:
            input_names.remove("past_feat_dynamic_cat")
        if not self.past_dynamic_dims:
            input_names.remove("past_feat_dynamic_real")
        return input_names
    def create_training_data_loader(
        self,
        data: Dataset,
        module: TemporalFusionTransformerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter("training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=self.input_names(),
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )
    def create_validation_data_loader(
        self,
        data: Dataset,
        module: TemporalFusionTransformerLightningModule,
        **kwargs,
    ) -> Iterable:
        instances = self._create_instance_splitter("validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=self.input_names(),
            output_type=torch.tensor,
        )
    def create_lightning_module(
        self,
    ) -> TemporalFusionTransformerLightningModule:
        return TemporalFusionTransformerLightningModule(
            lr=self.lr,
            patience=self.patience,
            weight_decay=self.weight_decay,
            model_kwargs={
                "context_length": self.context_length,
                "prediction_length": self.prediction_length,
                "d_var": self.variable_dim,
                "d_hidden": self.hidden_dim,
                "num_heads": self.num_heads,
                "distr_output": self.distr_output,
                "d_past_feat_dynamic_real": self.past_dynamic_dims,
                "c_past_feat_dynamic_cat": self.past_dynamic_cardinalities,
                "d_feat_dynamic_real": [1] * max(len(self.time_features), 1)
                + self.dynamic_dims,
                "c_feat_dynamic_cat": self.dynamic_cardinalities,
                "d_feat_static_real": self.static_dims or [1],
                "c_feat_static_cat": self.static_cardinalities or [1],
                "dropout_rate": self.dropout_rate,
            },
        )
    def create_predictor(
        self,
        transformation: Transformation,
        module: TemporalFusionTransformerLightningModule,
    ) -> PyTorchPredictor:
        # TODO
        prediction_splitter = self._create_instance_splitter("test")
        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device="auto",
            forecast_generator=self.distr_output.forecast_generator,
        )
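

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# example of constructing and training the estimator. The synthetic dataset
# and trainer settings below are assumptions for demonstration only.
if __name__ == "__main__":
    from gluonts.dataset.common import ListDataset

    # One synthetic hourly series of length 200 with a daily pattern.
    train_data = ListDataset(
        [
            {
                "start": "2021-01-01 00:00:00",
                "target": [float(i % 24) for i in range(200)],
            }
        ],
        freq="H",
    )

    estimator = TemporalFusionTransformerEstimator(
        freq="H",
        prediction_length=24,
        context_length=48,
        trainer_kwargs={"max_epochs": 1},  # kept small for the sketch
    )

    # `train` is inherited from PyTorchLightningEstimator and returns a
    # PyTorchPredictor wrapping the fitted network.
    predictor = estimator.train(train_data)
    forecasts = list(predictor.predict(train_data))

    # With the default QuantileOutput, each forecast exposes the learned
    # quantile levels, e.g. the median:
    print(forecasts[0].quantile(0.5))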