Source code for gluonts.mx.trainer.callback

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Standard library imports
import logging
import math
from typing import Any, Dict, List, Optional, Union

# Third-party imports
import mxnet as mx
import mxnet.gluon.nn as nn
from mxnet import gluon

# First-party imports
from gluonts.core.component import validated
from gluonts.mx.util import copy_parameters


class Callback:
    """
    Abstract Callback base class.

    Callbacks control the training of the GluonTS trainer. To write a custom
    Callback, you can subclass Callback and overwrite one or more of the hook
    methods. Hook methods with boolean return value stop the training if
    False is returned.
    """

    def on_train_start(self, max_epochs: int) -> None:
        """
        Hook that is called prior to training. This is the very first hook
        to be called.

        Parameters
        ----------
        max_epochs
            The maximum number of epochs that training is running. The actual
            number of epochs may be fewer if another callback hook stops
            training early.
        """

    def on_network_initializing_end(
        self, training_network: nn.HybridBlock
    ) -> None:
        """
        Hook that is called prior to training, after the training network has
        been initialized. This is the first hook where the network is passed.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_train_epoch_start(self, training_network: nn.HybridBlock) -> None:
        """
        Hook that is called prior to each training epoch.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_validation_epoch_start(
        self, training_network: nn.HybridBlock
    ) -> None:
        """
        Hook that is called prior to each validation epoch. This hook is never
        called if no validation data is available during training.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_train_batch_end(self, training_network: nn.HybridBlock) -> None:
        """
        Hook that is called after each training batch.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_validation_batch_end(
        self, training_network: nn.HybridBlock
    ) -> None:
        """
        Hook that is called after each validation batch. This hook is never
        called if no validation data is available during training.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_train_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        """
        Hook that is called after each training epoch. This method returns a
        boolean whether training should continue.

        Parameters
        ----------
        epoch_no
            The current epoch (the first epoch has `epoch_no = 0`).
        epoch_loss
            The loss that was recorded in the last epoch.
        training_network
            The network that is being trained.
        trainer
            The trainer which is running the training.

        Returns
        -------
        bool
            A boolean whether the training should continue. Defaults to
            `True`.
        """
        return True

    def on_validation_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        """
        Hook that is called after each validation epoch. Similar to
        `on_train_epoch_end`, this method returns a boolean whether training
        should continue. Note that it is always called after
        `on_train_epoch_end` within a single epoch. If `on_train_epoch_end`
        returned `False`, this method will not be called.

        Parameters
        ----------
        epoch_no
            The current epoch (the first epoch has `epoch_no = 0`).
        epoch_loss
            The validation loss that was recorded in the last epoch.
        training_network
            The network that is being trained.
        trainer
            The trainer which is running the training.

        Returns
        -------
        bool
            A boolean whether the training should continue. Defaults to
            `True`.
        """
        return True

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: Dict[str, Any],
        ctx: mx.Context,
    ) -> bool:
        """
        Hook that is called after every epoch. As `on_train_epoch_end` and
        `on_validation_epoch_end`, it returns a boolean whether training
        should continue. This hook is always called after
        `on_train_epoch_end` and `on_validation_epoch_end`. It is called
        regardless of these hooks' return values.

        Parameters
        ----------
        epoch_no
            The current epoch (the first epoch has `epoch_no = 0`).
        epoch_loss
            The validation loss that was recorded in the last epoch if
            validation data was provided. The training loss otherwise.
        training_network
            The network that is being trained.
        trainer
            The trainer which is running the training.
        best_epoch_info
            Aggregate information about the best epoch. Contains keys
            `params_path`, `epoch_no` and `score`. The score is the best
            validation loss if validation data is provided or the best
            training loss otherwise.
        ctx
            The MXNet context used.

        Returns
        -------
        bool
            A boolean whether the training should continue. Defaults to
            `True`.
        """
        return True

    def on_train_end(
        self,
        training_network: nn.HybridBlock,
        temporary_dir: str,
        # `Optional[...]` added: the default is `None`, so the implicit
        # Optional of the original annotation is made explicit (PEP 484).
        ctx: Optional[mx.context.Context] = None,
    ) -> None:
        """
        Hook that is called after training is finished. This is the last hook
        to be called.

        Parameters
        ----------
        training_network
            The network that was trained.
        temporary_dir
            The directory where model parameters are logged throughout
            training.
        ctx
            An MXNet context used.
        """
class CallbackList(Callback):
    """
    Used to chain a list of callbacks to one Callback.

    Boolean hook methods are logically joined with AND, meaning that if
    at least one callback method returns False, the training is stopped.

    Attributes
    ----------
    callbacks
        A list of gluonts.mx.trainer.callback.Callback's.
    """

    @validated()
    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks

    def _exec(
        self, function_name: str, *args: Any, **kwargs: Any
    ) -> Union[List[bool], List[None]]:
        # Dispatch the named hook to every wrapped callback in order and
        # collect each return value.
        results = []
        for callback in self.callbacks:
            hook = getattr(callback, function_name)
            results.append(hook(*args, **kwargs))
        return results

    def on_train_start(self, *args: Any, **kwargs: Any) -> None:
        self._exec("on_train_start", *args, **kwargs)

    def on_network_initializing_end(self, *args: Any, **kwargs: Any) -> None:
        self._exec("on_network_initializing_end", *args, **kwargs)

    def on_train_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        self._exec("on_train_epoch_start", *args, **kwargs)

    def on_validation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        self._exec("on_validation_epoch_start", *args, **kwargs)

    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
        self._exec("on_train_batch_end", *args, **kwargs)

    def on_validation_batch_end(self, *args: Any, **kwargs: Any) -> None:
        self._exec("on_validation_batch_end", *args, **kwargs)

    def on_train_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
        # Every callback runs (no short-circuit: _exec builds the full list
        # first); training continues only if all of them agree.
        outcomes = self._exec("on_train_epoch_end", *args, **kwargs)
        return all(outcomes)

    def on_validation_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
        outcomes = self._exec("on_validation_epoch_end", *args, **kwargs)
        return all(outcomes)

    def on_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
        outcomes = self._exec("on_epoch_end", *args, **kwargs)
        return all(outcomes)

    def on_train_end(self, *args: Any, **kwargs: Any) -> None:
        self._exec("on_train_end", *args, **kwargs)
class TrainingHistory(Callback):
    """
    Callback that keeps a per-epoch record of the training loss and, when
    validation data is available, the validation loss.
    """

    @validated()
    def __init__(self):
        # One entry per completed epoch; validation history stays empty
        # when the corresponding hook is never invoked.
        self.loss_history = []
        self.validation_loss_history = []

    def on_train_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        # Record the epoch's training loss and never stop the run.
        self.loss_history.append(epoch_loss)
        return True

    def on_validation_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        # Record the epoch's validation loss and never stop the run.
        self.validation_loss_history.append(epoch_loss)
        return True
class TerminateOnNaN(Callback):
    """
    Callback that stops training as soon as the epoch training loss is NaN.
    """

    def on_train_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        # Continue only while the loss is a real number; a NaN loss cannot
        # recover, so abort and tell the user why.
        loss_is_nan = math.isnan(epoch_loss)
        if loss_is_nan:
            logging.warning(
                f"TerminateOnNaN Callback initiated stop of training at epoch {epoch_no}."
            )
        return not loss_is_nan
class WarmStart(Callback):
    """
    Callback that warm-starts training by copying the parameters of a given
    predictor's prediction network into the freshly initialized training
    network.
    """

    @validated()
    def __init__(self, predictor):
        # Predictor whose prediction_net supplies the initial parameters.
        self.predictor = predictor

    def on_network_initializing_end(
        self, training_network: nn.HybridBlock
    ) -> None:
        # Overwrite the fresh initialization with the predictor's parameters.
        copy_parameters(self.predictor.prediction_net, training_network)