# Source code for gluonts.mx.model.estimator

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import NamedTuple, Optional, Type, Union

import numpy as np

from gluonts.core import fqname_for
from gluonts.core.component import (
    GluonTSHyperparametersError,
    from_hyperparameters,
    validated,
)
from gluonts.dataset.common import Dataset
from gluonts.dataset.loader import DataLoader
from gluonts.env import env
from gluonts.itertools import Cached
from gluonts.model import Estimator, Predictor
from gluonts.mx.model.predictor import GluonPredictor
from gluonts.mx.trainer import Trainer
from gluonts.mx.util import copy_parameters
from gluonts.pydantic import ValidationError
from gluonts.transform import Transformation, TransformedDataset
from mxnet.gluon import HybridBlock


class TrainOutput(NamedTuple):
    """
    Bundle of artifacts produced by ``GluonEstimator.train_model``.
    """

    # Transformation applied entry-wise to datasets, at training and
    # inference time.
    transformation: Transformation
    # The network after training (the one that computes the loss).
    trained_net: HybridBlock
    # Predictor wrapping the trained network, ready for inference.
    predictor: Predictor
class GluonEstimator(Estimator):
    """
    An `Estimator` type with utilities for creating Gluon-based models.

    To extend this class, one needs to implement five methods:
    `create_transformation`, `create_training_network`, `create_predictor`,
    `create_training_data_loader`, and `create_validation_data_loader`.
    """

    @validated()
    def __init__(
        self,
        *,
        trainer: Trainer,
        batch_size: int = 32,
        lead_time: int = 0,
        dtype: Type = np.float32,
    ) -> None:
        """
        Parameters
        ----------
        trainer
            Trainer object driving the optimization loop.
        batch_size
            Number of entries per batch; must be strictly positive.
        lead_time
            Lead time, forwarded to the base ``Estimator``.
        dtype
            Floating-point type used by the model.
        """
        super().__init__(lead_time=lead_time)

        assert batch_size > 0, "The value of `batch_size` should be > 0"

        self.batch_size = batch_size
        self.trainer = trainer
        self.dtype = dtype
[docs] @classmethod def from_hyperparameters(cls, **hyperparameters) -> "GluonEstimator": Model = getattr(cls.__init__, "Model", None) if not Model: raise AttributeError( "Cannot find attribute Model attached to the " f"{fqname_for(cls)}. Most probably you have forgotten to mark " "the class constructor as @validated()." ) try: trainer = from_hyperparameters(Trainer, **hyperparameters) return cls( **Model(**{**hyperparameters, "trainer": trainer}).__dict__ ) except ValidationError as e: raise GluonTSHyperparametersError from e
    def create_transformation(self) -> Transformation:
        """
        Create and return the transformation needed for training and
        inference.

        Returns
        -------
        Transformation
            The transformation that will be applied entry-wise to datasets,
            at training and inference time.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError
    def create_training_network(self) -> HybridBlock:
        """
        Create and return the network used for training (i.e., computing the
        loss).

        Returns
        -------
        HybridBlock
            The network that computes the loss given input data.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError
    def create_predictor(
        self, transformation: Transformation, trained_network: HybridBlock
    ) -> Predictor:
        """
        Create and return a predictor object.

        Parameters
        ----------
        transformation
            Transformation to be applied to data before it goes into the
            model.
        trained_network
            A trained `HybridBlock` object.

        Returns
        -------
        Predictor
            A predictor wrapping a `HybridBlock` used for inference.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError
    def create_training_data_loader(
        self, data: Dataset, **kwargs
    ) -> DataLoader:
        """
        Create a data loader for training purposes.

        Parameters
        ----------
        data
            Dataset from which to create the data loader.

        Returns
        -------
        DataLoader
            The data loader, i.e. an iterable over batches of data.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError
    def create_validation_data_loader(
        self, data: Dataset, **kwargs
    ) -> DataLoader:
        """
        Create a data loader for validation purposes.

        Parameters
        ----------
        data
            Dataset from which to create the data loader.

        Returns
        -------
        DataLoader
            The data loader, i.e. an iterable over batches of data.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError
    def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        from_predictor: Optional[GluonPredictor] = None,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
    ) -> TrainOutput:
        """
        Train a network and return the transformation, trained network, and
        resulting predictor bundled in a ``TrainOutput``.

        Parameters
        ----------
        training_data
            Dataset used for training.
        validation_data
            Optional dataset used for validation during training.
        from_predictor
            Optional predictor whose network parameters are copied into the
            training network (warm start) instead of fresh initialization.
        shuffle_buffer_length
            Forwarded to ``create_training_data_loader``.
        cache_data
            If True, wrap the transformed datasets in ``Cached`` so the
            transformation is only applied once.
        """
        transformation = self.create_transformation()

        # Raise the idle-transform limit so applying the transformation does
        # not trip on large datasets (floor of 100 for small ones).
        with env._let(max_idle_transforms=max(len(training_data), 100)):
            transformed_training_data: Union[TransformedDataset, Cached] = (
                transformation.apply(training_data)
            )
            if cache_data:
                transformed_training_data = Cached(transformed_training_data)

            training_data_loader = self.create_training_data_loader(
                transformed_training_data,
                shuffle_buffer_length=shuffle_buffer_length,
            )

        validation_data_loader = None

        if validation_data is not None:
            # Same idle-transform guard for the validation dataset.
            with env._let(max_idle_transforms=max(len(validation_data), 100)):
                transformed_validation_data: Union[
                    TransformedDataset, Cached
                ] = transformation.apply(validation_data)
                if cache_data:
                    transformed_validation_data = Cached(
                        transformed_validation_data
                    )

                validation_data_loader = self.create_validation_data_loader(
                    transformed_validation_data
                )

        training_network = self.create_training_network()

        if from_predictor is None:
            # Fresh parameters on the trainer's context/initializer.
            training_network.initialize(
                ctx=self.trainer.ctx, init=self.trainer.init
            )
        else:
            # Warm start: copy weights from the existing predictor's network.
            copy_parameters(from_predictor.network, training_network)

        self.trainer(
            net=training_network,
            train_iter=training_data_loader,
            validation_iter=validation_data_loader,
        )

        # Build the predictor under the trainer's context so inference runs
        # on the same device the network was trained on.
        with self.trainer.ctx:
            predictor = self.create_predictor(transformation, training_network)

        return TrainOutput(
            transformation=transformation,
            trained_net=training_network,
            predictor=predictor,
        )
[docs] def train( self, training_data: Dataset, validation_data: Optional[Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, **kwargs, ) -> Predictor: return self.train_model( training_data=training_data, validation_data=validation_data, shuffle_buffer_length=shuffle_buffer_length, cache_data=cache_data, ).predictor
[docs] def train_from( self, predictor: GluonPredictor, training_data: Dataset, validation_data: Optional[Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ) -> Predictor: return self.train_model( training_data=training_data, validation_data=validation_data, shuffle_buffer_length=shuffle_buffer_length, cache_data=cache_data, from_predictor=predictor, ).predictor