# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import NamedTuple, Optional, Type, Union
import numpy as np
from gluonts.core import fqname_for
from gluonts.core.component import (
GluonTSHyperparametersError,
from_hyperparameters,
validated,
)
from gluonts.dataset.common import Dataset
from gluonts.dataset.loader import DataLoader
from gluonts.env import env
from gluonts.itertools import Cached
from gluonts.model import Estimator, Predictor
from gluonts.mx.model.predictor import GluonPredictor
from gluonts.mx.trainer import Trainer
from gluonts.mx.util import copy_parameters
from gluonts.pydantic import ValidationError
from gluonts.transform import Transformation, TransformedDataset
from mxnet.gluon import HybridBlock
class TrainOutput(NamedTuple):
    """
    Bundle of artifacts produced by `GluonEstimator.train_model`.

    Being a ``NamedTuple``, it can be unpacked positionally in the order
    ``(transformation, trained_net, predictor)`` or accessed by field name.
    """

    # Data transformation that was applied before feeding the network.
    transformation: Transformation
    # The network object that went through the training loop.
    trained_net: HybridBlock
    # Predictor built from the transformation and the trained network.
    predictor: Predictor
class GluonEstimator(Estimator):
    """
    An `Estimator` type with utilities for creating Gluon-based models.

    To extend this class, one needs to implement five methods:
    `create_transformation`, `create_training_network`, `create_predictor`,
    `create_training_data_loader`, and `create_validation_data_loader`.

    Parameters
    ----------
    trainer
        Trainer that runs the optimization loop. Its ``ctx`` and ``init``
        attributes are used to initialize the training network, and ``ctx``
        is also entered while the predictor is created.
    batch_size
        Batch size stored as ``self.batch_size``; must be strictly positive.
    lead_time
        Forwarded to the `Estimator` base class.
    dtype
        Stored as ``self.dtype`` (default: ``np.float32``); presumably read
        by subclasses when building networks/loaders — confirm per subclass.
    """

    @validated()
    def __init__(
        self,
        *,
        trainer: Trainer,
        batch_size: int = 32,
        lead_time: int = 0,
        dtype: Type = np.float32,
    ) -> None:
        super().__init__(lead_time=lead_time)

        # NOTE(review): `assert` is stripped under `python -O`; it is kept so
        # callers keep seeing the same exception type (AssertionError).
        assert batch_size > 0, "The value of `batch_size` should be > 0"

        self.batch_size = batch_size
        self.trainer = trainer
        self.dtype = dtype

    @classmethod
    def from_hyperparameters(cls, **hyperparameters) -> "GluonEstimator":
        """
        Construct an estimator from a flat mapping of hyperparameters.

        A `Trainer` is first built from the same hyperparameters; then the
        validation model that ``@validated()`` attaches to ``cls.__init__``
        parses the hyperparameters (with the trainer injected), and its
        fields are passed to the constructor.

        Raises
        ------
        AttributeError
            If ``cls.__init__`` has no ``Model`` attribute, i.e. the
            constructor was not decorated with ``@validated()``.
        GluonTSHyperparametersError
            If validation of the hyperparameters fails.
        """
        Model = getattr(cls.__init__, "Model", None)
        if Model is None:
            raise AttributeError(
                "Cannot find attribute Model attached to the "
                f"{fqname_for(cls)}. Most probably you have forgotten to mark "
                "the class constructor as @validated()."
            )

        try:
            trainer = from_hyperparameters(Trainer, **hyperparameters)
            return cls(
                **Model(**{**hyperparameters, "trainer": trainer}).__dict__
            )
        except ValidationError as e:
            raise GluonTSHyperparametersError from e

    def create_training_network(self) -> HybridBlock:
        """
        Create and return the network used for training (i.e., computing the
        loss).

        Returns
        -------
        HybridBlock
            The network that computes the loss given input data.
        """
        raise NotImplementedError

    def create_predictor(
        self, transformation: Transformation, trained_network: HybridBlock
    ) -> Predictor:
        """
        Create and return a predictor object.

        Parameters
        ----------
        transformation
            Transformation to be applied to data before it goes into the model.
        trained_network
            A trained `HybridBlock` object. `train_model` passes the network
            returned by `create_training_network` after training it.

        Returns
        -------
        Predictor
            A predictor wrapping a `HybridBlock` used for inference.
        """
        raise NotImplementedError

    def create_training_data_loader(
        self, data: Dataset, **kwargs
    ) -> DataLoader:
        """
        Create a data loader for training purposes.

        Parameters
        ----------
        data
            Dataset from which to create the data loader.
        **kwargs
            Extra options; `train_model` passes ``shuffle_buffer_length``.

        Returns
        -------
        DataLoader
            The data loader, i.e. an iterable over batches of data.
        """
        raise NotImplementedError

    def create_validation_data_loader(
        self, data: Dataset, **kwargs
    ) -> DataLoader:
        """
        Create a data loader for validation purposes.

        Parameters
        ----------
        data
            Dataset from which to create the data loader.

        Returns
        -------
        DataLoader
            The data loader, i.e. an iterable over batches of data.
        """
        raise NotImplementedError

    @staticmethod
    def _load_transformed(
        transformation: Transformation,
        data: Dataset,
        cache_data: bool,
        make_data_loader,
        **loader_kwargs,
    ) -> DataLoader:
        """
        Apply ``transformation`` to ``data`` and build a data loader over it.

        Shared by the training and validation paths of `train_model`.

        Parameters
        ----------
        transformation
            Transformation whose ``apply`` method is called on ``data``.
        data
            Dataset to transform; ``len(data)`` must be defined.
        cache_data
            Whether to wrap the transformed dataset in `Cached`.
        make_data_loader
            Callable (e.g. ``self.create_training_data_loader``) receiving
            the transformed dataset followed by ``loader_kwargs``.
        **loader_kwargs
            Extra keyword arguments forwarded to ``make_data_loader``.

        Returns
        -------
        DataLoader
            Whatever ``make_data_loader`` returns.
        """
        # Raise the `max_idle_transforms` setting to at least len(data)
        # (floor 100) while transforming and building the loader.
        # NOTE(review): presumably keeps the idle-transform guard from
        # tripping on large datasets — confirm against gluonts.env.
        with env._let(max_idle_transforms=max(len(data), 100)):
            transformed_data: Union[TransformedDataset, Cached] = (
                transformation.apply(data)
            )
            if cache_data:
                transformed_data = Cached(transformed_data)
            return make_data_loader(transformed_data, **loader_kwargs)

    def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        from_predictor: Optional[GluonPredictor] = None,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
    ) -> TrainOutput:
        """
        Train a network and wrap it into a predictor.

        Parameters
        ----------
        training_data
            Dataset used to build the training data loader.
        validation_data
            Optional dataset used to build the validation data loader that
            is handed to the trainer; ``None`` means no validation loader.
        from_predictor
            If given, the parameters of ``from_predictor.network`` are copied
            into the new training network instead of initializing it with
            ``self.trainer.init``.
        shuffle_buffer_length
            Forwarded to `create_training_data_loader`.
        cache_data
            Whether to wrap the transformed datasets in `Cached`.

        Returns
        -------
        TrainOutput
            The transformation, the trained network and the predictor.
        """
        transformation = self.create_transformation()

        training_data_loader = self._load_transformed(
            transformation,
            training_data,
            cache_data,
            self.create_training_data_loader,
            shuffle_buffer_length=shuffle_buffer_length,
        )

        validation_data_loader = None
        if validation_data is not None:
            validation_data_loader = self._load_transformed(
                transformation,
                validation_data,
                cache_data,
                self.create_validation_data_loader,
            )

        training_network = self.create_training_network()

        if from_predictor is None:
            training_network.initialize(
                ctx=self.trainer.ctx, init=self.trainer.init
            )
        else:
            # Start from an existing predictor's network rather than a fresh
            # initialization (this is what `train_from` relies on).
            copy_parameters(from_predictor.network, training_network)

        self.trainer(
            net=training_network,
            train_iter=training_data_loader,
            validation_iter=validation_data_loader,
        )

        # NOTE(review): the predictor is built inside the trainer's context,
        # presumably so its parameters land on the same device — confirm.
        with self.trainer.ctx:
            predictor = self.create_predictor(transformation, training_network)

        return TrainOutput(
            transformation=transformation,
            trained_net=training_network,
            predictor=predictor,
        )

    def train(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
        **kwargs,
    ) -> Predictor:
        """
        Train a freshly initialized network and return its predictor.

        Thin wrapper around `train_model` with ``from_predictor=None``.
        Additional ``kwargs`` are accepted but ignored.
        """
        return self.train_model(
            training_data=training_data,
            validation_data=validation_data,
            shuffle_buffer_length=shuffle_buffer_length,
            cache_data=cache_data,
        ).predictor

    def train_from(
        self,
        predictor: GluonPredictor,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
    ) -> Predictor:
        """
        Continue training from the network parameters of ``predictor``.

        Thin wrapper around `train_model` with ``from_predictor=predictor``;
        returns the newly created predictor.
        """
        return self.train_model(
            training_data=training_data,
            validation_data=validation_data,
            shuffle_buffer_length=shuffle_buffer_length,
            cache_data=cache_data,
            from_predictor=predictor,
        ).predictor