# Source code for gluonts.torch.model.tft.module

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from gluonts.core.component import validated
from gluonts.torch.modules.quantile_output import QuantileOutput
from gluonts.torch.modules.scaler import StdScaler
from gluonts.torch.util import weighted_average

from .layers import (
    FeatureEmbedder,
    FeatureProjector,
    GatedResidualNetwork,
    TemporalFusionDecoder,
    TemporalFusionEncoder,
    VariableSelectionNetwork,
)


class TemporalFusionTransformerModel(nn.Module):
    """
    Temporal Fusion Transformer neural network.

    Partially based on the implementation in
    github.com/kashif/pytorch-transformer-ts.

    Inputs feat_static_real, feat_static_cat and feat_dynamic_real are
    mandatory. Inputs feat_dynamic_cat, past_feat_dynamic_real and
    past_feat_dynamic_cat are optional.

    Parameters
    ----------
    context_length
        Number of past time steps; the time length of ``past_target``
        (see ``input_shapes``).
    prediction_length
        Number of future time steps; known dynamic features span
        ``context_length + prediction_length`` steps.
    d_feat_static_real
        Per-feature sizes of the static real features, handed to the
        feature projector. Falls back to ``[1]`` when missing or empty.
    c_feat_static_cat
        Per-feature cardinalities of the static categorical features,
        handed to the feature embedder. Falls back to ``[1]`` when missing
        or empty.
    d_feat_dynamic_real
        Per-feature sizes of the known dynamic real features. Falls back
        to ``[1]`` when missing or empty.
    c_feat_dynamic_cat
        Per-feature cardinalities of the known dynamic categorical
        features. Falls back to ``[]``.
    d_past_feat_dynamic_real
        Per-feature sizes of the past-only dynamic real features. Falls
        back to ``[]``.
    c_past_feat_dynamic_cat
        Per-feature cardinalities of the past-only dynamic categorical
        features. Falls back to ``[]``.
    quantiles
        Quantile levels given to ``QuantileOutput``. Defaults to
        ``0.1, 0.2, ..., 0.9``.
    num_heads
        Value passed as ``num_heads`` to the temporal decoder.
    d_hidden
        Hidden size passed to the temporal encoder/decoder; also the
        output size of the two initial-state networks.
    d_var
        Embedding size shared by every projected/embedded variable.
    dropout_rate
        Dropout value passed to the selection networks, the gated
        residual networks and the temporal decoder.
    """

    @validated()
    def __init__(
        self,
        context_length: int,
        prediction_length: int,
        d_feat_static_real: Optional[List[int]] = None,  # Defaults to [1]
        c_feat_static_cat: Optional[List[int]] = None,  # Defaults to [1]
        d_feat_dynamic_real: Optional[List[int]] = None,  # Defaults to [1]
        c_feat_dynamic_cat: Optional[List[int]] = None,  # Defaults to []
        d_past_feat_dynamic_real: Optional[List[int]] = None,  # Defaults to []
        c_past_feat_dynamic_cat: Optional[List[int]] = None,  # Defaults to []
        quantiles: Optional[List[float]] = None,
        num_heads: int = 4,
        d_hidden: int = 32,
        d_var: int = 32,
        dropout_rate: float = 0.1,
    ):
        super().__init__()

        # Plain hyper-parameters.
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.num_heads = num_heads
        self.d_hidden = d_hidden
        self.d_var = d_var
        self.dropout_rate = dropout_rate
        self.quantiles = (
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
            if quantiles is None
            else quantiles
        )

        # Feature layouts; `or` also replaces an explicitly empty list.
        self.d_feat_static_real = d_feat_static_real or [1]
        self.d_feat_dynamic_real = d_feat_dynamic_real or [1]
        self.d_past_feat_dynamic_real = d_past_feat_dynamic_real or []
        self.c_feat_static_cat = c_feat_static_cat or [1]
        self.c_feat_dynamic_cat = c_feat_dynamic_cat or []
        self.c_past_feat_dynamic_cat = c_past_feat_dynamic_cat or []

        # Number of variables (real + categorical) in each feature group.
        self.num_feat_static = len(self.d_feat_static_real) + len(
            self.c_feat_static_cat
        )
        self.num_feat_dynamic = len(self.d_feat_dynamic_real) + len(
            self.c_feat_dynamic_cat
        )
        self.num_past_feat_dynamic = len(self.d_past_feat_dynamic_real) + len(
            self.c_past_feat_dynamic_cat
        )

        def make_projector(dims: List[int]) -> Optional[nn.Module]:
            # One d_var-sized projection per real feature; None if no feature.
            if not dims:
                return None
            return FeatureProjector(
                feature_dims=dims,
                embedding_dims=[self.d_var] * len(dims),
            )

        def make_embedder(cardinalities: List[int]) -> Optional[nn.Module]:
            # One d_var-sized embedding per categorical feature; None if none.
            if not cardinalities:
                return None
            return FeatureEmbedder(
                cardinalities=cardinalities,
                embedding_dims=[self.d_var] * len(cardinalities),
            )

        # NOTE: submodules are created and assigned in a fixed order on
        # purpose; registration order drives parameter/state_dict ordering
        # and the sequence of random initialisations.
        self.scaler = StdScaler(dim=1, keepdim=True)
        self.target_proj = nn.Linear(in_features=1, out_features=self.d_var)

        # Past-only dynamic features
        self.past_feat_dynamic_proj = make_projector(
            self.d_past_feat_dynamic_real
        )
        self.past_feat_dynamic_embed = make_embedder(
            self.c_past_feat_dynamic_cat
        )

        # Known dynamic features
        self.feat_dynamic_proj = make_projector(self.d_feat_dynamic_real)
        self.feat_dynamic_embed = make_embedder(self.c_feat_dynamic_cat)

        # Static features
        self.feat_static_proj = make_projector(self.d_feat_static_real)
        self.feat_static_embed = make_embedder(self.c_feat_static_cat)

        # Variable selection: static variables, context window (target +
        # past-only + known features) and prediction window (known only).
        selector_kwargs = dict(d_hidden=self.d_var, dropout=self.dropout_rate)
        self.static_selector = VariableSelectionNetwork(
            num_vars=self.num_feat_static,
            **selector_kwargs,
        )
        self.ctx_selector = VariableSelectionNetwork(
            num_vars=self.num_past_feat_dynamic + self.num_feat_dynamic + 1,
            add_static=True,
            **selector_kwargs,
        )
        self.tgt_selector = VariableSelectionNetwork(
            num_vars=self.num_feat_dynamic,
            add_static=True,
            **selector_kwargs,
        )

        # Networks turning the selected static variable into context
        # vectors and into the two initial encoder states.
        grn_kwargs = dict(d_hidden=self.d_var, dropout=self.dropout_rate)
        self.selection = GatedResidualNetwork(**grn_kwargs)
        self.enrichment = GatedResidualNetwork(**grn_kwargs)
        self.state_h = GatedResidualNetwork(
            d_output=self.d_hidden, **grn_kwargs
        )
        self.state_c = GatedResidualNetwork(
            d_output=self.d_hidden, **grn_kwargs
        )

        self.temporal_encoder = TemporalFusionEncoder(
            d_input=self.d_var,
            d_hidden=self.d_hidden,
        )
        self.temporal_decoder = TemporalFusionDecoder(
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            d_hidden=self.d_hidden,
            d_var=self.d_var,
            num_heads=self.num_heads,
            dropout=self.dropout_rate,
        )

        self.output = QuantileOutput(quantiles=self.quantiles)
        self.output_proj = self.output.get_args_proj(in_features=self.d_hidden)

    def input_shapes(self, batch_size=1) -> Dict[str, Tuple[int, ...]]:
        """
        Return the expected shape of every model input, keyed by input name.

        Parameters
        ----------
        batch_size
            Size used for the leading dimension of every shape.

        Returns
        -------
        Dict[str, Tuple[int, ...]]
            Real-valued inputs use the sum of their feature sizes as last
            dimension, categorical inputs use the number of features.
        """
        past_len = self.context_length
        full_len = self.context_length + self.prediction_length
        return {
            "past_target": (batch_size, past_len),
            "past_observed_values": (batch_size, past_len),
            "feat_static_real": (batch_size, sum(self.d_feat_static_real)),
            "feat_static_cat": (batch_size, len(self.c_feat_static_cat)),
            "feat_dynamic_real": (
                batch_size,
                full_len,
                sum(self.d_feat_dynamic_real),
            ),
            "feat_dynamic_cat": (
                batch_size,
                full_len,
                len(self.c_feat_dynamic_cat),
            ),
            "past_feat_dynamic_real": (
                batch_size,
                past_len,
                sum(self.d_past_feat_dynamic_real),
            ),
            "past_feat_dynamic_cat": (
                batch_size,
                past_len,
                len(self.c_past_feat_dynamic_cat),
            ),
        }

    def input_types(self) -> Dict[str, torch.dtype]:
        """
        Return the expected dtype of every model input, keyed by input name.

        Categorical inputs (names ending in ``_cat``) are ``torch.long``;
        all other inputs are ``torch.float``.
        """
        input_names = (
            "past_target",
            "past_observed_values",
            "feat_static_real",
            "feat_static_cat",
            "feat_dynamic_real",
            "feat_dynamic_cat",
            "past_feat_dynamic_real",
            "past_feat_dynamic_cat",
        )
        return {
            name: torch.long if name.endswith("_cat") else torch.float
            for name in input_names
        }

    def _preprocess(
        self,
        past_target: torch.Tensor,  # [N, T]
        past_observed_values: torch.Tensor,  # [N, T]
        feat_static_real: torch.Tensor,  # [N, D_sr]
        feat_static_cat: torch.Tensor,  # [N, D_sc]
        feat_dynamic_real: torch.Tensor,  # [N, T + H, D_dr]
        feat_dynamic_cat: torch.Tensor,  # [N, T + H, D_dc]
        past_feat_dynamic_real: torch.Tensor,  # [N, T, D_pr]
        past_feat_dynamic_cat: torch.Tensor,  # [N, T, D_pc]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        List[torch.Tensor],
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Scale the target and turn every input into per-variable embeddings.

        An input is only consumed when the matching projector/embedder was
        created in ``__init__``; ``forward`` may pass ``None`` for the
        optional inputs, which is fine as long as the corresponding
        feature group was not configured.

        Returns
        -------
        Tuple
            ``(past_covariates, future_covariates, static_covariates, loc,
            scale)``. The first three are lists with one tensor per
            variable; ``loc`` and ``scale`` come from the scaler. List
            order is: target, past-only real, past-only categorical, known
            real, known categorical (past); known real, known categorical
            (future); real, categorical (static).
        """
        scaled_target, loc, scale = self.scaler(
            data=past_target, weights=past_observed_values
        )

        # The scaled target is always the first context variable.
        past_covariates = [self.target_proj(scaled_target.unsqueeze(-1))]
        future_covariates = []
        static_covariates = []

        # Past-only dynamic features go to the context window only.
        if self.past_feat_dynamic_proj is not None:
            past_covariates += self.past_feat_dynamic_proj(
                past_feat_dynamic_real
            )
        if self.past_feat_dynamic_embed is not None:
            past_covariates += self.past_feat_dynamic_embed(
                past_feat_dynamic_cat
            )

        # Known dynamic features cover T + H steps: real ones first, then
        # categorical ones, each split at the context/prediction boundary.
        known_features = []
        if self.feat_dynamic_proj is not None:
            known_features.extend(self.feat_dynamic_proj(feat_dynamic_real))
        if self.feat_dynamic_embed is not None:
            known_features.extend(self.feat_dynamic_embed(feat_dynamic_cat))

        boundary = self.context_length
        for feature in known_features:
            past_covariates.append(feature[..., :boundary, :])
            future_covariates.append(feature[..., boundary:, :])

        # Static features.
        if self.feat_static_proj is not None:
            static_covariates += self.feat_static_proj(feat_static_real)
        if self.feat_static_embed is not None:
            static_covariates += self.feat_static_embed(feat_static_cat)

        return (
            past_covariates,
            future_covariates,
            static_covariates,
            loc,
            scale,
        )

    def forward(
        self,
        past_target: torch.Tensor,  # [N, T]
        past_observed_values: torch.Tensor,  # [N, T]
        feat_static_real: Optional[torch.Tensor],  # [N, D_sr]
        feat_static_cat: Optional[torch.Tensor],  # [N, D_sc]
        feat_dynamic_real: Optional[torch.Tensor],  # [N, T + H, D_dr]
        feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T + H, D_dc]
        past_feat_dynamic_real: Optional[torch.Tensor] = None,  # [N, T, D_pr]
        past_feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T, D_pc]
    ) -> torch.Tensor:
        """
        Compute quantile forecasts for the prediction window.

        The inputs are embedded by ``_preprocess``, passed through the
        variable selection networks and the temporal encoder/decoder,
        projected to one value per quantile, and finally mapped back to
        the original scale with the ``loc``/``scale`` of the target scaler.

        Returns
        -------
        torch.Tensor
            Forecasts of shape ``[N, Q, H]``.
        """
        (
            ctx_vars,  # [[N, T, d_var], ...]
            tgt_vars,  # [[N, H, d_var], ...]
            static_vars,  # [[N, d_var], ...]
            loc,  # [N, 1]
            scale,  # [N, 1]
        ) = self._preprocess(
            past_target=past_target,
            past_observed_values=past_observed_values,
            feat_static_real=feat_static_real,
            feat_static_cat=feat_static_cat,
            feat_dynamic_real=feat_dynamic_real,
            feat_dynamic_cat=feat_dynamic_cat,
            past_feat_dynamic_real=past_feat_dynamic_real,
            past_feat_dynamic_cat=past_feat_dynamic_cat,
        )

        # Static context; the second selector output is not used here.
        static_var, _ = self.static_selector(static_vars)  # [N, d_var]
        c_selection = self.selection(static_var).unsqueeze(1)  # [N, 1, d_var]
        c_enrichment = self.enrichment(static_var).unsqueeze(1)
        # Initial encoder states (h, then c), each [1, N, d_hidden].
        init_states = [
            state_net(static_var).unsqueeze(0)
            for state_net in (self.state_h, self.state_c)
        ]

        ctx_input, _ = self.ctx_selector(
            ctx_vars, c_selection
        )  # [N, T, d_var]
        tgt_input, _ = self.tgt_selector(
            tgt_vars, c_selection
        )  # [N, H, d_var]

        encoding = self.temporal_encoder(
            ctx_input, tgt_input, init_states
        )  # [N, T + H, d_hidden]
        decoding = self.temporal_decoder(
            encoding, c_enrichment, past_observed_values
        )  # [N, H, d_hidden]

        quantile_preds = self.output_proj(decoding)
        # Undo the target scaling; loc/scale broadcast over the last axis.
        rescaled = quantile_preds * scale[..., None] + loc[..., None]
        return rescaled.transpose(1, 2)  # [N, Q, H]

    def loss(
        self,
        past_target: torch.Tensor,  # [N, T]
        past_observed_values: torch.Tensor,  # [N, T]
        future_target: torch.Tensor,  # [N, H]
        future_observed_values: torch.Tensor,  # [N, H]
        feat_static_real: torch.Tensor,  # [N, D_sr]
        feat_static_cat: torch.Tensor,  # [N, D_sc]
        feat_dynamic_real: torch.Tensor,  # [N, T + H, D_dr]
        feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T + H, D_dc]
        past_feat_dynamic_real: Optional[torch.Tensor] = None,  # [N, T, D_pr]
        past_feat_dynamic_cat: Optional[torch.Tensor] = None,  # [N, T, D_pc]
    ) -> torch.Tensor:
        """
        Compute the training loss for a batch.

        Runs ``forward``, evaluates the quantile loss of the forecasts
        against ``future_target``, averages it with
        ``future_observed_values`` as weights, and returns the mean of the
        result.

        Returns
        -------
        torch.Tensor
            The value of ``.mean()`` over the weighted loss.
        """
        # Calls forward directly (not self(...)), as the original did.
        forecasts = self.forward(
            past_target=past_target,
            past_observed_values=past_observed_values,
            feat_static_real=feat_static_real,
            feat_static_cat=feat_static_cat,
            feat_dynamic_real=feat_dynamic_real,
            feat_dynamic_cat=feat_dynamic_cat,
            past_feat_dynamic_real=past_feat_dynamic_real,
            past_feat_dynamic_cat=past_feat_dynamic_cat,
        )  # [N, Q, H]

        # quantile_loss takes the quantile axis last, hence the transpose.
        step_loss = self.output.quantile_loss(
            y_true=future_target, y_pred=forecasts.transpose(1, 2)
        )  # presumably [N, H] -- confirm against QuantileOutput
        # No reduction dim is passed; the resulting shape depends on
        # weighted_average's default -- TODO confirm.
        masked_loss = weighted_average(step_loss, future_observed_values)
        return masked_loss.mean()