# gluonts.mx.model.transformer._estimator
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from functools import partial
from typing import List, Optional

from mxnet.gluon import HybridBlock

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import (
DataLoader,
TrainDataLoader,
ValidationDataLoader,
)
from gluonts.model.predictor import Predictor
from gluonts.mx.batchify import batchify
from gluonts.mx.distribution import DistributionOutput, StudentTOutput
from gluonts.mx.model.estimator import GluonEstimator
from gluonts.mx.model.predictor import RepresentableBlockPredictor
from gluonts.mx.trainer import Trainer
from gluonts.mx.util import copy_parameters, get_hybrid_forward_input_names
from gluonts.time_feature import (
TimeFeature,
get_lags_for_frequency,
time_features_from_frequency_str,
)
from gluonts.transform import (
AddAgeFeature,
AddObservedValuesIndicator,
AddTimeFeatures,
AsNumpyArray,
Chain,
ExpectedNumInstanceSampler,
InstanceSampler,
InstanceSplitter,
RemoveFields,
SelectFields,
SetField,
TestSplitSampler,
Transformation,
ValidationSplitSampler,
VstackFeatures,
)

from ._network import TransformerPredictionNetwork, TransformerTrainingNetwork
from .trans_decoder import TransformerDecoder
from .trans_encoder import TransformerEncoder


class TransformerEstimator(GluonEstimator):
"""
Construct a Transformer estimator.
This implements a Transformer model, close to the one described in
[Vaswani2017]_.
.. [Vaswani2017] Vaswani, Ashish, et al. "Attention is all you need."
Advances in neural information processing systems. 2017.
Parameters
----------
freq
Frequency of the data to train on and predict
prediction_length
Length of the prediction horizon
context_length
        Number of past time steps the model conditions on before computing
        predictions
        (default: None, in which case context_length = prediction_length)
trainer
Trainer object to be used (default: Trainer())
dropout_rate
Dropout regularization parameter (default: 0.1)
cardinality
        Number of values of each categorical feature (default: [1])
embedding_dimension
        Dimension of the embeddings for categorical features (the same
        dimension is used for all embeddings, default: 20)
distr_output
Distribution to use to evaluate observations and sample predictions
(default: StudentTOutput())
model_dim
Dimension of the transformer network, i.e., embedding dimension of the
input (default: 32)
inner_ff_dim_scale
Dimension scale of the inner hidden layer of the transformer's
feedforward network (default: 4)
    pre_seq
        Sequence that defines the operations of the processing block applied
        before the main transformer network. Available operations: 'd' for
        dropout, 'r' for residual connections and 'n' for normalization
        (default: 'dn')
    post_seq
        Sequence that defines the operations of the processing block applied
        in and after the main transformer network. Available operations: 'd'
        for dropout, 'r' for residual connections and 'n' for normalization
        (default: 'drn')
act_type
Activation type of the transformer network (default: 'softrelu')
num_heads
Number of heads in the multi-head attention (default: 8)
scaling
        Whether to automatically scale the target values (default: True)
    lags_seq
        Indices of the lagged target values to use as inputs of the model
        (default: None, in which case these are automatically determined
        based on freq)
    time_features
        Time features to use as inputs of the model (default: None, in which
        case these are automatically determined based on freq)
    use_feat_dynamic_real
        Whether to use the ``feat_dynamic_real`` field from the data
        (default: False)
    use_feat_static_cat
        Whether to use the ``feat_static_cat`` field from the data
        (default: False)
    num_parallel_samples
        Number of evaluation samples per time series to increase parallelism
        during inference. This is a model optimization that does not affect
        the accuracy (default: 100)
train_sampler
Controls the sampling of windows during training.
validation_sampler
Controls the sampling of windows during validation.
batch_size
        The size of the batches to be used during training and prediction.
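
    Examples
    --------
    A minimal usage sketch; ``training_data`` here stands in for any
    ``Dataset`` whose series match the given frequency:

    >>> estimator = TransformerEstimator(freq="H", prediction_length=24)
    >>> predictor = estimator.train(training_data)  # doctest: +SKIP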
"""

    @validated()
def __init__(
self,
freq: str,
prediction_length: int,
context_length: Optional[int] = None,
trainer: Trainer = Trainer(),
dropout_rate: float = 0.1,
cardinality: Optional[List[int]] = None,
embedding_dimension: int = 20,
distr_output: DistributionOutput = StudentTOutput(),
model_dim: int = 32,
inner_ff_dim_scale: int = 4,
pre_seq: str = "dn",
post_seq: str = "drn",
act_type: str = "softrelu",
num_heads: int = 8,
scaling: bool = True,
lags_seq: Optional[List[int]] = None,
time_features: Optional[List[TimeFeature]] = None,
use_feat_dynamic_real: bool = False,
use_feat_static_cat: bool = False,
num_parallel_samples: int = 100,
train_sampler: Optional[InstanceSampler] = None,
validation_sampler: Optional[InstanceSampler] = None,
batch_size: int = 32,
) -> None:
super().__init__(trainer=trainer, batch_size=batch_size)
assert (
prediction_length > 0
), "The value of `prediction_length` should be > 0"
assert (
context_length is None or context_length > 0
), "The value of `context_length` should be > 0"
assert dropout_rate >= 0, "The value of `dropout_rate` should be >= 0"
assert (
cardinality is not None or not use_feat_static_cat
), "You must set `cardinality` if `use_feat_static_cat=True`"
assert cardinality is None or all(
[c > 0 for c in cardinality]
), "Elements of `cardinality` should be > 0"
assert (
embedding_dimension > 0
), "The value of `embedding_dimension` should be > 0"
assert (
num_parallel_samples > 0
), "The value of `num_parallel_samples` should be > 0"
self.prediction_length = prediction_length
self.context_length = (
context_length if context_length is not None else prediction_length
)
self.distr_output = distr_output
self.dropout_rate = dropout_rate
self.use_feat_dynamic_real = use_feat_dynamic_real
self.use_feat_static_cat = use_feat_static_cat
self.cardinality = cardinality if use_feat_static_cat else [1]
self.embedding_dimension = embedding_dimension
self.num_parallel_samples = num_parallel_samples
self.lags_seq = (
lags_seq
if lags_seq is not None
else get_lags_for_frequency(freq_str=freq)
)
self.time_features = (
time_features
if time_features is not None
else time_features_from_frequency_str(freq)
)
self.history_length = self.context_length + max(self.lags_seq)
self.scaling = scaling
self.config = {
"model_dim": model_dim,
"pre_seq": pre_seq,
"post_seq": post_seq,
"dropout_rate": dropout_rate,
"inner_ff_dim_scale": inner_ff_dim_scale,
"act_type": act_type,
"num_heads": num_heads,
}
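        # encoder and decoder are built from the same hyperparameter set;
        # the prefixes keep their parameter namespaces distinct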
self.encoder = TransformerEncoder(
self.context_length, self.config, prefix="enc_"
)
self.decoder = TransformerDecoder(
self.prediction_length, self.config, prefix="dec_"
)
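        # by default, draw on average one training window per series,
        # keeping at least prediction_length future steps as the target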
self.train_sampler = (
train_sampler
if train_sampler is not None
else ExpectedNumInstanceSampler(
num_instances=1.0, min_future=prediction_length
)
)
self.validation_sampler = (
validation_sampler
if validation_sampler is not None
else ValidationSplitSampler(min_future=prediction_length)
)

    def create_transformation(self) -> Transformation:
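        # fields the model never consumes; dynamic real features are also
        # dropped below unless the model is configured to use them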
remove_field_names = [
FieldName.FEAT_DYNAMIC_CAT,
FieldName.FEAT_STATIC_REAL,
]
if not self.use_feat_dynamic_real:
remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)
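        # the chain below: drop unused fields, set a dummy static categorical
        # feature when that field is not used, convert arrays, flag missing
        # target values, and stack time/age (plus optional dynamic real)
        # features into FEAT_TIME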
return Chain(
[RemoveFields(field_names=remove_field_names)]
+ (
[SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0.0])]
if not self.use_feat_static_cat
else []
)
+ [
AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
AsNumpyArray(
field=FieldName.TARGET,
# in the following line, we add 1 for the time dimension
expected_ndim=1 + len(self.distr_output.event_shape),
),
AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
),
AddTimeFeatures(
start_field=FieldName.START,
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_TIME,
time_features=self.time_features,
pred_length=self.prediction_length,
),
AddAgeFeature(
target_field=FieldName.TARGET,
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
log_scale=True,
),
VstackFeatures(
output_field=FieldName.FEAT_TIME,
input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.use_feat_dynamic_real
else []
),
),
]
)

    def _create_instance_splitter(self, mode: str):
assert mode in ["training", "validation", "test"]
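        # each mode gets its own window sampler; test windows are always
        # taken at the end of each series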
instance_sampler = {
"training": self.train_sampler,
"validation": self.validation_sampler,
"test": TestSplitSampler(),
}[mode]
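        # cut windows of history_length past steps and prediction_length
        # future steps; time-indexed fields are split along with the target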
return InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
instance_sampler=instance_sampler,
past_length=self.history_length,
future_length=self.prediction_length,
time_series_fields=[
FieldName.FEAT_TIME,
FieldName.OBSERVED_VALUES,
],
)

    def create_training_data_loader(
self,
data: Dataset,
**kwargs,
) -> DataLoader:
input_names = get_hybrid_forward_input_names(
TransformerTrainingNetwork
)
instance_splitter = self._create_instance_splitter("training")
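        # windows are sampled lazily while iterating, so every epoch sees a
        # fresh draw of training instances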
return TrainDataLoader(
dataset=data,
transform=instance_splitter + SelectFields(input_names),
batch_size=self.batch_size,
stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
**kwargs,
)

    def create_validation_data_loader(
self,
data: Dataset,
**kwargs,
) -> DataLoader:
input_names = get_hybrid_forward_input_names(
TransformerTrainingNetwork
)
instance_splitter = self._create_instance_splitter("validation")
return ValidationDataLoader(
dataset=data,
transform=instance_splitter + SelectFields(input_names),
batch_size=self.batch_size,
stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
)

    def create_training_network(self) -> TransformerTrainingNetwork:
return TransformerTrainingNetwork(
encoder=self.encoder,
decoder=self.decoder,
history_length=self.history_length,
context_length=self.context_length,
prediction_length=self.prediction_length,
distr_output=self.distr_output,
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
lags_seq=self.lags_seq,
scaling=self.scaling,
)

    def create_predictor(
self, transformation: Transformation, trained_network: HybridBlock
) -> Predictor:
prediction_splitter = self._create_instance_splitter("test")
prediction_network = TransformerPredictionNetwork(
encoder=self.encoder,
decoder=self.decoder,
history_length=self.history_length,
context_length=self.context_length,
prediction_length=self.prediction_length,
distr_output=self.distr_output,
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
lags_seq=self.lags_seq,
scaling=self.scaling,
num_parallel_samples=self.num_parallel_samples,
)
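        # transfer the trained weights into the network that unrolls the
        # sampling procedure at prediction time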
copy_parameters(trained_network, prediction_network)
return RepresentableBlockPredictor(
input_transform=transformation + prediction_splitter,
prediction_net=prediction_network,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
ctx=self.trainer.ctx,
)