Source code for gluonts.mx.trainer.callback

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field
from typing import List, Union, Dict, Any, Optional
import logging
import math
import time

# Third-party imports
import mxnet.gluon.nn as nn
import mxnet as mx
from mxnet import gluon
from pydantic import BaseModel, PrivateAttr

# First-party imports
from gluonts.core.component import validated
from gluonts.mx.util import copy_parameters

logger = logging.getLogger(__name__)


class Callback:
    """
    Base class for callbacks used by the GluonTS trainer.

    A callback influences training through a fixed set of hook methods.
    Custom callbacks subclass ``Callback`` and override whichever hooks they
    need; every hook defined here does nothing (hooks with a boolean return
    value simply return ``True``). A hook with a boolean return value stops
    the training when it returns ``False``.
    """

    def on_train_start(self, max_epochs: int) -> None:
        """
        Hook invoked once before training begins; no hook runs earlier.

        Parameters
        ----------
        max_epochs
            Upper bound on the number of epochs to train. Fewer epochs may
            actually run if another callback hook stops training early.
        """

    def on_network_initializing_end(
        self, training_network: nn.HybridBlock
    ) -> None:
        """
        Hook invoked before training, once the training network has been
        initialized. It is the first hook that receives the network.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_train_epoch_start(self, training_network: nn.HybridBlock) -> None:
        """
        Hook invoked at the beginning of every training epoch.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_validation_epoch_start(
        self, training_network: nn.HybridBlock
    ) -> None:
        """
        Hook invoked at the beginning of every validation epoch.

        It is never invoked when training runs without validation data.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """

    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
        """
        Hook invoked after every training batch.

        Parameters
        ----------
        training_network
            The network that is being trained.

        Returns
        -------
        bool
            Whether training should continue. The base implementation
            always returns `True`.
        """
        return True

    def on_validation_batch_end(
        self, training_network: nn.HybridBlock
    ) -> bool:
        """
        Hook invoked after every validation batch.

        It is never invoked when training runs without validation data.

        Parameters
        ----------
        training_network
            The network that is being trained.

        Returns
        -------
        bool
            Whether training should continue. The base implementation
            always returns `True`.
        """
        return True

    def on_train_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        """
        Hook invoked after every training epoch.

        Parameters
        ----------
        epoch_no
            Index of the current epoch (the first epoch has `epoch_no = 0`).
        epoch_loss
            The loss recorded in the epoch that just finished.
        training_network
            The network that is being trained.
        trainer
            The trainer which is running the training.

        Returns
        -------
        bool
            Whether training should continue. The base implementation
            always returns `True`.
        """
        return True

    def on_validation_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        """
        Hook invoked after every validation epoch.

        Within a single epoch it always runs after `on_train_epoch_end`,
        and it is skipped entirely if `on_train_epoch_end` returned `False`.

        Parameters
        ----------
        epoch_no
            Index of the current epoch (the first epoch has `epoch_no = 0`).
        epoch_loss
            The validation loss recorded in the epoch that just finished.
        training_network
            The network that is being trained.
        trainer
            The trainer which is running the training.

        Returns
        -------
        bool
            Whether training should continue. The base implementation
            always returns `True`.
        """
        return True

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: Dict[str, Any],
        ctx: mx.Context,
    ) -> bool:
        """
        Hook invoked at the very end of every epoch.

        It always runs after `on_train_epoch_end` and
        `on_validation_epoch_end`, and it runs regardless of what those
        hooks returned.

        Parameters
        ----------
        epoch_no
            Index of the current epoch (the first epoch has `epoch_no = 0`).
        epoch_loss
            The validation loss of the epoch that just finished if
            validation data was provided, the training loss otherwise.
        training_network
            The network that is being trained.
        trainer
            The trainer which is running the training.
        best_epoch_info
            Aggregate information about the best epoch so far. Contains the
            keys `params_path`, `epoch_no` and `score`. The score is the
            best validation loss if validation data is provided, or the
            best training loss otherwise.
        ctx
            The MXNet context used.

        Returns
        -------
        bool
            Whether training should continue. The base implementation
            always returns `True`.
        """
        return True

    def on_train_end(
        self,
        training_network: nn.HybridBlock,
        temporary_dir: str,
        ctx: mx.context.Context = None,
    ) -> None:
        """
        Hook invoked once training has finished; no hook runs later.

        Parameters
        ----------
        training_network
            The network that was trained.
        temporary_dir
            The directory where model parameters are logged throughout
            training.
        ctx
            An MXNet context used.
        """
class CallbackList(Callback):
    """
    Bundle several callbacks so they can be used as one Callback.

    Each hook is forwarded, in list order, to every wrapped callback. For
    hooks with a boolean return value the individual results are joined
    with logical AND: training is stopped if at least one wrapped callback
    returns False. All wrapped callbacks are invoked before the results are
    combined, so a False from one callback does not skip the others.

    Attributes
    ----------
    callbacks
        A list of gluonts.mx.trainer.callback.Callback's.
    """

    @validated()
    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks

    def _exec(
        self, function_name: str, *args: Any, **kwargs: Any
    ) -> Union[List[bool], List[None]]:
        """
        Call the hook named ``function_name`` on every wrapped callback.

        Parameters
        ----------
        function_name
            Name of the hook method to look up on each callback.
        *args, **kwargs
            Arguments passed through unchanged to each hook.

        Returns
        -------
        list
            The hooks' return values, in the order of ``self.callbacks``.
        """
        results = []
        for callback in self.callbacks:
            hook = getattr(callback, function_name)
            results.append(hook(*args, **kwargs))
        return results

    def on_train_start(self, *args: Any, **kwargs: Any) -> None:
        """Forward `on_train_start` to every wrapped callback."""
        self._exec("on_train_start", *args, **kwargs)

    def on_network_initializing_end(self, *args: Any, **kwargs: Any) -> None:
        """Forward `on_network_initializing_end` to every wrapped callback."""
        self._exec("on_network_initializing_end", *args, **kwargs)

    def on_train_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Forward `on_train_epoch_start` to every wrapped callback."""
        self._exec("on_train_epoch_start", *args, **kwargs)

    def on_validation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Forward `on_validation_epoch_start` to every wrapped callback."""
        self._exec("on_validation_epoch_start", *args, **kwargs)

    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> bool:
        """Return True only if every callback's `on_train_batch_end` does."""
        outcomes = self._exec("on_train_batch_end", *args, **kwargs)
        return all(outcomes)

    def on_validation_batch_end(self, *args: Any, **kwargs: Any) -> bool:
        """
        Return True only if every callback's `on_validation_batch_end` does.
        """
        outcomes = self._exec("on_validation_batch_end", *args, **kwargs)
        return all(outcomes)

    def on_train_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
        """Return True only if every callback's `on_train_epoch_end` does."""
        outcomes = self._exec("on_train_epoch_end", *args, **kwargs)
        return all(outcomes)

    def on_validation_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
        """
        Return True only if every callback's `on_validation_epoch_end` does.
        """
        outcomes = self._exec("on_validation_epoch_end", *args, **kwargs)
        return all(outcomes)

    def on_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
        """Return True only if every callback's `on_epoch_end` does."""
        outcomes = self._exec("on_epoch_end", *args, **kwargs)
        return all(outcomes)

    def on_train_end(self, *args: Any, **kwargs: Any) -> None:
        """Forward `on_train_end` to every wrapped callback."""
        self._exec("on_train_end", *args, **kwargs)
class TrainingHistory(Callback):
    """
    Callback that records the loss reported at the end of each epoch.

    Attributes
    ----------
    loss_history
        Training losses, one entry appended per `on_train_epoch_end` call.
    validation_loss_history
        Validation losses, one entry appended per
        `on_validation_epoch_end` call.
    """

    @validated()
    def __init__(self):
        self.loss_history: List[float] = []
        self.validation_loss_history: List[float] = []

    def on_train_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        """
        Append the epoch's training loss to `loss_history`.

        Only `epoch_loss` is used; the other arguments are ignored.

        Returns
        -------
        bool
            Always `True`; this callback never stops training.
        """
        self.loss_history.append(epoch_loss)
        return True

    def on_validation_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        """
        Append the epoch's validation loss to `validation_loss_history`.

        Only `epoch_loss` is used; the other arguments are ignored.

        Returns
        -------
        bool
            Always `True`; this callback never stops training.
        """
        self.validation_loss_history.append(epoch_loss)
        return True
class TerminateOnNaN(Callback):
    """
    Callback that stops training once an epoch's training loss is NaN.
    """

    def on_train_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        """
        Check the epoch's training loss and request a stop if it is NaN.

        Parameters
        ----------
        epoch_no
            Index of the current epoch; only used in the log message.
        epoch_loss
            The training loss recorded in the epoch that just finished.
        training_network
            Unused.
        trainer
            Unused.

        Returns
        -------
        bool
            `False` (stop training) if `epoch_loss` is NaN, `True`
            otherwise.
        """
        if not math.isnan(epoch_loss):
            return True

        # Log through the module-level logger rather than the root logger so
        # the message honours per-module logging configuration.
        logger.warning(
            "TerminateOnNaN Callback initiated stop of training at epoch"
            " %s.",
            epoch_no,
        )
        return False
class WarmStart(Callback):
    """
    Callback that seeds the training network from an existing predictor.

    Attributes
    ----------
    predictor
        Object exposing a `prediction_net` attribute, which is handed to
        `copy_parameters` as the first argument.
    """

    @validated()
    def __init__(self, predictor):
        self.predictor = predictor

    def on_network_initializing_end(
        self, training_network: nn.HybridBlock
    ) -> None:
        """
        Call `copy_parameters` with the predictor's network and the freshly
        initialized training network.

        Parameters
        ----------
        training_network
            The network that is being trained.
        """
        # NOTE(review): presumably copies parameters from the first network
        # into the second -- confirm against gluonts.mx.util.copy_parameters.
        source_network = self.predictor.prediction_net
        copy_parameters(source_network, training_network)
@dataclass class _Timer: duration: float _start: Optional[float] = field(init=False, default=None) def start(self): assert self._start is None self._start = time.time() def remaining(self) -> float: assert self._start is not None return max(self._start - time.time(), 0.0) def is_running(self) -> bool: return self.remaining() > 0
class TrainingTimeLimit(BaseModel, Callback):
    """Limit time spent for training.

    Useful for keeping the training of a model within a time budget, for
    example when doing AutoML.

    With `stop_within_epoch` set to true, the time limit is also checked
    after each training batch, so training can be stopped in the middle of
    an epoch; otherwise it is only checked at the end of each epoch.
    """

    time_limit: float
    stop_within_epoch: bool = False
    _timer: _Timer = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        # The timer is created here but only started in `on_train_start`.
        self._timer = _Timer(duration=self.time_limit)

    def on_train_start(self, max_epochs: int) -> None:
        """Start the timer; `max_epochs` is not used."""
        self._timer.start()

    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
        """
        Return whether training should continue after this batch.

        Always `True` unless `stop_within_epoch` is set, in which case the
        result is whether the timer is still running.
        """
        if not self.stop_within_epoch:
            return True
        return self._timer.is_running()

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: Dict[str, Any],
        ctx: mx.Context,
    ) -> bool:
        """
        Return whether the timer is still running.

        None of the arguments are used.
        """
        still_within_budget = self._timer.is_running()
        return still_within_budget