Source code for gluonts.mx.trainer._base

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import itertools
import logging
import os
import tempfile
import time
import uuid
import warnings
from typing import cast, List, Optional, Union

import mxnet as mx
import mxnet.autograd as autograd
import mxnet.gluon.nn as nn
import numpy as np
from mxnet.metric import ndarray

from gluonts.core.component import validated
from gluonts.dataset.loader import DataLoader
from gluonts.exceptions import GluonTSDataError
from gluonts.gluonts_tqdm import tqdm
from gluonts.mx.context import get_mxnet_context
from gluonts.mx.trainer.callback import Callback, CallbackList
from gluonts.mx.util import HybridContext

from .learning_rate_scheduler import LearningRateReduction
from .model_averaging import SelectNBestMean, save_epoch_info, ModelAveraging

logger = logging.getLogger("gluonts").getChild("trainer")


MODEL_ARTIFACT_FILE_NAME = "model"
STATE_ARTIFACT_FILE_NAME = "state"

# make the IDE happy: mx.py does not explicitly import autograd
mx.autograd = autograd


def check_loss_finite(val: float) -> None:
    if not np.isfinite(val):
        raise GluonTSDataError(
            "Encountered invalid loss value! Try reducing the learning rate "
            "or try a different likelihood."
        )


def loss_value(loss: mx.metric.Loss) -> float:
    return loss.get_name_value()[0][1]
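

# For reference: `mx.metric.Loss.get_name_value()` returns a list of
# (name, value) pairs, e.g. [("loss", 1.23)], so `loss_value` pulls out the
# accumulated average loss as a plain float.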


class Trainer:
    r"""
    A trainer specifies how a network is going to be trained.

    A trainer is mainly defined by two sets of parameters. The first one
    determines the number of examples that the network will be trained on
    (`epochs`, `num_batches_per_epoch`), while the second one specifies how
    the gradient updates are performed (`learning_rate`,
    `learning_rate_decay_factor`, `patience`, `minimum_learning_rate`,
    `clip_gradient` and `weight_decay`).

    Parameters
    ----------
    ctx
        MXNet context to use for training; defaults to the context returned
        by `get_mxnet_context()`.
    epochs
        Number of epochs that the network will train (default: 100).
    num_batches_per_epoch
        Number of batches at each epoch (default: 50).
    learning_rate
        Initial learning rate (default: :math:`10^{-3}`).
    learning_rate_decay_factor
        Factor (between 0 and 1) by which to decrease the learning rate
        (default: 0.5).
    patience
        The patience to observe before reducing the learning rate, a
        nonnegative integer (default: 10).
    minimum_learning_rate
        Lower bound for the learning rate (default: :math:`5\cdot 10^{-5}`).
    clip_gradient
        Maximum value of gradient. The gradient is clipped if it is too large
        (default: 10).
    weight_decay
        The weight decay (or L2 regularization) coefficient. Modifies the
        objective by adding a penalty for having large weights (default:
        :math:`10^{-8}`).
    init
        Initializer of the weights of the network (default: "xavier").
    hybridize
        If set to True, the network will be hybridized before training.
    callbacks
        A list of `gluonts.mx.trainer.callback.Callback` to control the
        training.
    add_default_callbacks
        bool, True by default. If True, the LearningRateReduction and
        ModelAveraging callbacks are used in addition to the callbacks
        specified in the `callbacks` argument. Only set this to True if you
        don't specify one of the default callbacks yourself, otherwise there
        will be "duplicate callbacks". Default callbacks:

        >>> callbacks = [
        ...     ModelAveraging(avg_strategy=SelectNBestMean(num_models=1)),
        ...     LearningRateReduction(
        ...         base_lr=1e-3,  # learning_rate
        ...         decay_factor=0.5,  # learning_rate_decay_factor
        ...         patience=10,  # patience
        ...         min_lr=5e-5,  # minimum_learning_rate
        ...         objective="min",
        ...     )
        ... ]
    """

    @validated()
    def __init__(
        self,
        ctx: Optional[mx.Context] = None,
        epochs: int = 100,
        batch_size: Optional[int] = None,
        num_batches_per_epoch: int = 50,
        learning_rate: float = 1e-3,
        learning_rate_decay_factor: float = 0.5,
        patience: int = 10,
        minimum_learning_rate: float = 5e-5,
        clip_gradient: float = 10.0,
        weight_decay: float = 1e-8,
        init: Union[str, mx.initializer.Initializer] = "xavier",
        hybridize: bool = True,
        callbacks: Optional[List[Callback]] = None,
        add_default_callbacks: bool = True,
    ) -> None:
        if batch_size is not None:
            warnings.warn(
                "batch_size argument is deprecated",
                DeprecationWarning,
                stacklevel=2,
            )
        else:
            batch_size = 32

        assert isinstance(batch_size, int)

        # TODO: add param `disable_default_callbacks` to get backwards
        # compatibility deprecation warnings; in the future, the following
        # callbacks should be controlled by altering `callbacks`:
        if learning_rate_decay_factor is not None:
            warnings.warn(
                'Trainer argument "learning_rate_decay_factor" is deprecated.'
                " Use callbacks instead.",
                DeprecationWarning,
            )
            assert 0 <= learning_rate_decay_factor < 1, (
                "The value of `learning_rate_decay_factor` should be in the"
                " [0, 1) range"
            )
        if patience is not None:
            warnings.warn(
                'Trainer argument "patience" is deprecated. Use callbacks'
                " instead.",
                DeprecationWarning,
            )
            assert 0 <= patience, "The value of `patience` should be >= 0"
        if minimum_learning_rate:
            warnings.warn(
                'Trainer argument "minimum_learning_rate" is deprecated.'
                " Use callbacks instead.",
                DeprecationWarning,
            )
            assert (
                0 <= minimum_learning_rate
            ), "The value of `minimum_learning_rate` should be >= 0"

        assert (
            0 <= epochs < float("inf")
        ), "The value of `epochs` should be >= 0"
        assert 0 < batch_size, "The value of `batch_size` should be > 0"
        assert (
            0 < num_batches_per_epoch
        ), "The value of `num_batches_per_epoch` should be > 0"
        assert (
            0 < learning_rate < float("inf")
        ), "The value of `learning_rate` should be > 0"
        assert 0 < clip_gradient, "The value of `clip_gradient` should be > 0"
        assert 0 <= weight_decay, "The value of `weight_decay` should be >= 0"

        self.epochs = epochs
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch
        self.learning_rate = learning_rate
        self.learning_rate_decay_factor = learning_rate_decay_factor
        self.patience = patience
        self.minimum_learning_rate = minimum_learning_rate
        self.clip_gradient = clip_gradient
        self.weight_decay = weight_decay
        self.init = init
        self.hybridize = hybridize
        self.ctx = ctx if ctx is not None else get_mxnet_context()
        self.halt = False

        # Make sure `callbacks` is a list -- it is assigned to
        # `self.callbacks` below
        callbacks = callbacks or []

        # TODO: the following is done for backwards compatibility. For future
        # versions, add the default callbacks as a default argument.
        if add_default_callbacks:
            default_callbacks = [
                ModelAveraging(avg_strategy=SelectNBestMean(num_models=1)),
                LearningRateReduction(
                    base_lr=learning_rate,
                    decay_factor=learning_rate_decay_factor,
                    patience=patience,
                    min_lr=minimum_learning_rate,
                    objective="min",
                ),
            ]
            self.callbacks = CallbackList(callbacks + default_callbacks)
        else:
            self.callbacks = CallbackList(callbacks)

    def count_model_params(self, net: nn.HybridBlock) -> int:
        params = net.collect_params()
        num_params = 0
        for p in params:
            v = params[p]
            num_params += np.prod(v.shape)
        return num_params
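
    # Note: `np.prod(v.shape)` multiplies the dimensions of each parameter
    # array, so e.g. a (64, 32) weight matrix contributes 2048 parameters.
    # The total is logged once, on the first batch of the first epoch (see
    # `loop` below).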

    def __call__(
        self,
        net: nn.HybridBlock,
        train_iter: DataLoader,
        validation_iter: Optional[DataLoader] = None,
    ) -> None:  # TODO: we may want to return some training information here
        """
        Train a network, given an iterable over training (and optionally
        validation) batches.

        Parameters
        ----------
        net
            Network to be trained. This is a Gluon HybridBlock, assumed to
            produce a tensor of loss values as output.
        train_iter
            An iterable over batches to be used for training. Batches are
            assumed to be dictionaries, whose values are MXNet arrays that
            correspond to the network inputs.
        validation_iter
            Similar to `train_iter` but the batches produced here are used to
            compute validation metrics.
        """
        is_validation_available = validation_iter is not None

        logger.info("Start model training")
        net.initialize(ctx=self.ctx, init=self.init)

        with tempfile.TemporaryDirectory(
            prefix="gluonts-trainer-temp-"
        ) as gluonts_temp, HybridContext(
            net=net,
            hybridize=self.hybridize,
            static_alloc=True,
            static_shape=True,
        ):

            def base_path() -> str:
                return os.path.join(
                    gluonts_temp,
                    f"{STATE_ARTIFACT_FILE_NAME}_{uuid.uuid4()}",
                )

            best_epoch_info = {
                "params_path": "{}-{}.params".format(base_path(), "init"),
                "epoch_no": -1,
                "score": np.Inf,
            }

            optimizer = mx.optimizer.Adam(
                learning_rate=self.learning_rate,
                wd=self.weight_decay,
                clip_gradient=self.clip_gradient,
            )

            trainer = mx.gluon.Trainer(
                net.collect_params(),
                optimizer=optimizer,
                kvstore="device",  # FIXME: initialize properly
            )

            first_forward = True

            def loop(  # todo call run epoch
                epoch_no,
                batch_iter,
                num_batches_to_use: Optional[int] = None,
                is_training: bool = True,
            ) -> mx.metric.Loss:
                nonlocal first_forward
                tic = time.time()

                epoch_loss = mx.metric.Loss()

                if is_training:
                    # We should not call this method if we haven't compiled
                    # the network yet. Instead, this callback is called after
                    # network initialization.
                    if not first_forward:
                        self.callbacks.on_train_epoch_start(
                            training_network=net
                        )
                else:
                    self.callbacks.on_validation_epoch_start(
                        training_network=net
                    )

                batch_iter = itertools.islice(batch_iter, num_batches_to_use)
                it = tqdm(batch_iter, total=num_batches_to_use)
                any_batches = False

                for batch_no, batch in enumerate(it, start=1):
                    any_batches = True
                    # `batch` here is expected to be a dictionary whose fields
                    # should correspond 1-to-1 with the network inputs;
                    # see below how `batch.values()` is fed into the network
                    if self.halt:
                        break

                    if first_forward:
                        first_forward = False
                        _ = net(*batch.values())
                        self.callbacks.on_network_initializing_end(
                            training_network=net
                        )
                        # Call the epoch start callback, as the model was not
                        # compiled before
                        self.callbacks.on_train_epoch_start(
                            training_network=net
                        )

                    with mx.autograd.record():
                        # We set the mode explicitly, since by default mxnet
                        # assumes predict mode; dropout layers would not be
                        # used if the mode were not explicitly set to training
                        mode = (
                            autograd.train_mode
                            if is_training
                            else autograd.predict_mode
                        )
                        with mode():
                            output = net(*batch.values())

                        # The network can return several outputs, the first
                        # one always being the loss. When there are multiple
                        # outputs, the forward pass returns a list in the
                        # hybridized case and a tuple otherwise; we may wrap
                        # network outputs in the future to avoid this type
                        # check
                        if isinstance(output, (list, tuple)):
                            loss = output[0]
                        else:
                            loss = output

                        batch_size = loss.shape[0]

                    if not np.isfinite(ndarray.sum(loss).asscalar()):
                        logger.warning(
                            "Batch [%d] of Epoch[%d] gave NaN loss and it"
                            " will be ignored",
                            batch_no,
                            epoch_no,
                        )
                        should_continue = True
                    else:
                        if is_training:
                            loss.backward()
                            trainer.step(batch_size)

                            should_continue = (
                                self.callbacks.on_train_batch_end(
                                    training_network=net
                                )
                            )
                        else:
                            should_continue = (
                                self.callbacks.on_validation_batch_end(
                                    training_network=net
                                )
                            )

                        epoch_loss.update(None, preds=loss)
                        lv = loss_value(epoch_loss)
                        it.set_postfix(
                            ordered_dict={
                                "epoch": f"{epoch_no + 1}/{self.epochs}",
                                ("" if is_training else "validation_")
                                + "avg_epoch_loss": lv,
                            },
                            refresh=False,
                        )
                        # print out parameters of the network on the first
                        # pass
                        if batch_no == 1 and epoch_no == 0:
                            net_name = type(net).__name__
                            num_model_param = self.count_model_params(net)
                            logger.info(
                                f"Number of parameters in {net_name}:"
                                f" {num_model_param}"
                            )

                    if not should_continue:
                        self.halt = True
                        break
                it.close()

                if not any_batches:
                    raise GluonTSDataError(
                        "No training data batch could be constructed; "
                        "this usually indicates that the training dataset "
                        "is empty, or consists of too short series."
                    )

                # mark epoch end time and log time cost of current epoch
                if not self.halt:
                    toc = time.time()
                    logger.info(
                        "Epoch[%d] Elapsed time %.3f seconds",
                        epoch_no,
                        (toc - tic),
                    )

                    logger.info(
                        "Epoch[%d] Evaluation metric '%s'=%f",
                        epoch_no,
                        ("" if is_training else "validation_") + "epoch_loss",
                        lv,
                    )

                return epoch_loss

            self.callbacks.on_train_start(max_epochs=self.epochs)

            try:
                for epoch_no in range(self.epochs):
                    if self.halt:
                        logger.info(
                            f"Epoch[{epoch_no}] Interrupting training"
                        )
                        break

                    curr_lr = trainer.learning_rate
                    logger.info(
                        f"Epoch[{epoch_no}] Learning rate is {curr_lr}"
                    )

                    epoch_loss = loop(
                        epoch_no,
                        train_iter,
                        num_batches_to_use=self.num_batches_per_epoch,
                    )

                    should_continue = self.callbacks.on_train_epoch_end(
                        epoch_no=epoch_no,
                        epoch_loss=loss_value(epoch_loss),
                        training_network=net,
                        trainer=trainer,
                    )

                    if is_validation_available:
                        epoch_loss = loop(
                            epoch_no, validation_iter, is_training=False
                        )

                        should_continue = (
                            should_continue
                            and self.callbacks.on_validation_epoch_end(
                                epoch_no=epoch_no,
                                epoch_loss=loss_value(epoch_loss),
                                training_network=net,
                                trainer=trainer,
                            )
                        )

                    # save model and epoch info
                    bp = base_path()
                    epoch_info = {
                        "params_path": f"{bp}-0000.params",
                        "epoch_no": epoch_no,
                        "score": loss_value(epoch_loss),
                    }

                    net.save_parameters(
                        epoch_info["params_path"]
                    )  # TODO: handle possible exception

                    save_epoch_info(bp, epoch_info)

                    # update best epoch info
                    if loss_value(epoch_loss) < cast(
                        float, best_epoch_info["score"]
                    ):
                        best_epoch_info = epoch_info.copy()

                    should_continue = (
                        should_continue
                        and self.callbacks.on_epoch_end(
                            epoch_no=epoch_no,
                            epoch_loss=loss_value(epoch_loss),
                            training_network=net,
                            trainer=trainer,
                            best_epoch_info=best_epoch_info,
                            ctx=self.ctx,
                        )
                    )

                    if not should_continue:
                        logger.info("Stopping training")
                        break
            except KeyboardInterrupt:
                warnings.warn(
                    "Detected KeyboardInterrupt, attempting graceful "
                    "shutdown..."
                )
                # save model and epoch info
                bp = base_path()
                epoch_info = {
                    "params_path": f"{bp}-0000.params",
                    "epoch_no": epoch_no,
                    "score": loss_value(epoch_loss),
                }
                net.save_parameters(epoch_info["params_path"])
                save_epoch_info(bp, epoch_info)

            self.callbacks.on_train_end(
                training_network=net,
                temporary_dir=gluonts_temp,
                ctx=self.ctx,
            )

            logger.info("End model training")
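

# Usage sketch (a minimal, illustrative example; `net`, `train_loader` and
# `val_loader` are assumed to exist, with `net` a HybridBlock returning loss
# values and the loaders yielding dict batches as described in `__call__`):
#
#     trainer = Trainer(
#         epochs=2,
#         num_batches_per_epoch=10,
#         hybridize=False,  # plain imperative mode can ease debugging
#     )
#     trainer(net, train_loader)                               # training only
#     trainer(net, train_loader, validation_iter=val_loader)   # + validation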