gluonts.mx.trainer.model_iteration_averaging module

class gluonts.mx.trainer.model_iteration_averaging.Alpha_Suffix(epochs: int, alpha: float = 0.75, eta: float = 0)[source]

Bases: gluonts.mx.trainer.model_iteration_averaging.IterationAveragingStrategy

Implements Alpha Suffix model averaging. This method is based on the paper “Making Gradient Descent Optimal for Strongly Convex Stochastic Optimization” (https://arxiv.org/pdf/1109.5647.pdf).

alpha_suffix: float = None
update_average_trigger(metric: Any = None, epoch: int = 0, **kwargs)[source]
Parameters
  • metric – The criterion that triggers averaging; not used in Alpha Suffix.

  • epoch – The epoch to start averaging.
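
The relationship between epochs, alpha and the epoch at which averaging starts can be sketched as follows. This is a minimal illustration and not the library's code: it assumes, following the α-suffix idea in the cited paper, that averaging covers the final alpha fraction of training, so it starts at epochs * (1 - alpha). The helper names alpha_suffix_start_epoch and averaging_started_by are hypothetical.

```python
def alpha_suffix_start_epoch(epochs: int, alpha: float = 0.75) -> float:
    """Epoch threshold after which averaging is assumed to be active."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    # Assumption: the last `alpha` fraction of the epochs is averaged.
    return epochs * (1.0 - alpha)


def averaging_started_by(epoch: int, epochs: int, alpha: float = 0.75) -> bool:
    """True once `epoch` has reached the assumed suffix threshold."""
    return epoch >= alpha_suffix_start_epoch(epochs, alpha)
```

Under this assumption, a run with epochs=100 and the default alpha=0.75 would begin averaging at epoch 25.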

class gluonts.mx.trainer.model_iteration_averaging.IterationAveragingStrategy(eta: float = 0)[source]

Bases: object

This strategy implements polynomial-decay averaging, parameterized by eta, from the paper “Stochastic Gradient Descent for Non-smooth Optimization: Convergence Results and Optimal Averaging Schemes” (http://proceedings.mlr.press/v28/shamir13.pdf). When eta = 0, it is equivalent to a simple average over all iterations with equal weights.
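
The polynomial-decay update from the cited paper can be sketched in plain Python. This is an illustration on lists of floats rather than on network parameters, and the function name polynomial_decay_average is hypothetical. At step t the newest iterate receives weight (eta + 1) / (eta + t), which reduces to 1 / t, the running mean, when eta = 0.

```python
from typing import Iterable, List


def polynomial_decay_average(
    iterates: Iterable[List[float]], eta: float = 0.0
) -> List[float]:
    """Running polynomial-decay average of a sequence of parameter vectors."""
    average: List[float] = []
    for t, weights in enumerate(iterates, start=1):
        # Weight of the newest iterate; equals 1/t when eta == 0.
        alpha = (eta + 1.0) / (eta + t)
        if t == 1:
            # alpha == 1 here, so the first iterate is simply copied.
            average = list(weights)
        else:
            average = [
                (1.0 - alpha) * a + alpha * w
                for a, w in zip(average, weights)
            ]
    return average
```

Larger values of eta put more weight on recent iterates, so the average tracks the end of training more closely.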

apply(model: mxnet.gluon.block.HybridBlock) → Optional[Dict][source]
Parameters

model – The model of the current iteration.

Returns

The averaged model, None if the averaging hasn’t started.

Return type

Optional[Dict]

average_counter: int = None
averaged_model: Optional[Dict[str, mx.nd.NDArray]] = None
averaging_started: bool = None
cached_model: Optional[Dict[str, mx.nd.NDArray]] = None
load_averaged_model(model: mxnet.gluon.block.HybridBlock)[source]

To validate or evaluate the averaged model partway through training, first call load_averaged_model to overwrite the current model with the averaged one, run the evaluation, and then call load_cached_model to restore the current model.

Parameters

model – The model that the averaged model is loaded to.

load_cached_model(model: mxnet.gluon.block.HybridBlock)[source]
Parameters

model – The model that the cached model is loaded to.
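
The load, evaluate, restore sequence around load_averaged_model and load_cached_model can be wrapped in a context manager so that the in-training parameters are restored even if evaluation raises. The sketch below is not part of the library: averaged_parameters is a hypothetical helper, and it only assumes that the strategy object exposes the two methods documented here.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def averaged_parameters(strategy: Any, model: Any) -> Iterator[Any]:
    """Temporarily swap the averaged parameters into ``model``."""
    # Overwrite the current parameters with the averaged ones.
    strategy.load_averaged_model(model)
    try:
        yield model
    finally:
        # Always restore the in-training parameters afterwards.
        strategy.load_cached_model(model)
```

A caller would then write `with averaged_parameters(strategy, net): ...` and run its own evaluation routine inside the block.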

update_average(model: mxnet.gluon.block.HybridBlock)[source]
Parameters

model – The model to update the average.

update_average_trigger(metric: Any = None, epoch: int = 0, **kwargs)[source]
Parameters
  • metric – The criterion that triggers averaging.

  • epoch – The epoch to start averaging.

class gluonts.mx.trainer.model_iteration_averaging.ModelIterationAveraging(avg_strategy: gluonts.mx.trainer.model_iteration_averaging.IterationAveragingStrategy)[source]

Bases: gluonts.mx.trainer.callback.Callback

Callback that implements iteration-based model averaging strategies.

Parameters

avg_strategy – The IterationAveragingStrategy to apply, either NTA or Alpha_Suffix from gluonts.mx.trainer.model_iteration_averaging.

on_epoch_end(epoch_no: int, epoch_loss: float, training_network: mxnet.gluon.block.HybridBlock, trainer: mxnet.gluon.trainer.Trainer, best_epoch_info: Dict[str, Any], ctx: mxnet.context.Context) → bool[source]

Hook that is called after every epoch. Like on_train_epoch_end and on_validation_epoch_end, it returns a boolean indicating whether training should continue. This hook is always called after on_train_epoch_end and on_validation_epoch_end, regardless of those hooks’ return values.

Parameters
  • epoch_no – The current epoch (the first epoch has epoch_no = 0).

  • epoch_loss – The validation loss that was recorded in the last epoch if validation data was provided. The training loss otherwise.

  • training_network – The network that is being trained.

  • trainer – The trainer which is running the training.

  • best_epoch_info – Aggregate information about the best epoch. Contains keys params_path, epoch_no and score. The score is the best validation loss if validation data is provided or the best training loss otherwise.

  • ctx – The MXNet context used.

Returns

A boolean indicating whether training should continue. Defaults to True.

Return type

bool

on_train_batch_end(training_network: mxnet.gluon.block.HybridBlock) → None[source]

Hook that is called after each training batch.

Parameters

training_network – The network that is being trained.

on_train_end(training_network: mxnet.gluon.block.HybridBlock, temporary_dir: str, ctx: mxnet.context.Context = None) → None[source]

Hook that is called after training is finished. This is the last hook to be called.

Parameters
  • training_network – The network that was trained.

  • temporary_dir – The directory where model parameters are logged throughout training.

  • ctx – The MXNet context used.

on_validation_epoch_end(epoch_no: int, epoch_loss: float, training_network: mxnet.gluon.block.HybridBlock, trainer: mxnet.gluon.trainer.Trainer) → bool[source]

Hook that is called after each validation epoch. Like on_train_epoch_end, this method returns a boolean indicating whether training should continue. Within a single epoch it is always called after on_train_epoch_end. If on_train_epoch_end returned False, this method is not called.

Parameters
  • epoch_no – The current epoch (the first epoch has epoch_no = 0).

  • epoch_loss – The validation loss that was recorded in the last epoch.

  • training_network – The network that is being trained.

  • trainer – The trainer which is running the training.

Returns

A boolean indicating whether training should continue. Defaults to True.

Return type

bool

on_validation_epoch_start(training_network: mxnet.gluon.block.HybridBlock) → None[source]

Hook that is called prior to each validation epoch. This hook is never called if no validation data is available during training.

Parameters

training_network – The network that is being trained.

class gluonts.mx.trainer.model_iteration_averaging.NTA(epochs: int, n: int = 5, maximize: bool = False, last_n_trigger: bool = False, eta: float = 0, fallback_alpha: float = 0.05)[source]

Bases: gluonts.mx.trainer.model_iteration_averaging.IterationAveragingStrategy

Implements Non-monotonically Triggered AvSGD (NTA). This method is based on the paper “Regularizing and Optimizing LSTM Language Models” (https://openreview.net/pdf?id=SyyGPP0TZ); an implementation is available in the Salesforce GitHub repository (https://github.com/salesforce/awd-lstm-lm/blob/master/main.py). Note that it differs from the arXiv (and gluonnlp) version, which is referred to as NTA_V2 below.

update_average_trigger(metric: Any = None, epoch: int = 0, **kwargs)[source]
Parameters
  • metric – The criterion that triggers averaging.

  • epoch – The epoch to start averaging; not used in NTA.

val_logs: List[Any] = None
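
The non-monotonic trigger can be illustrated with a small stand-alone function. This sketch follows the rule used in the referenced Salesforce script: averaging starts once the newest validation metric is worse than the best value recorded more than n checks ago. It is an assumption that the class here behaves the same way; last_n_trigger and fallback_alpha are not modeled, and nta_should_trigger is a hypothetical name.

```python
from typing import List


def nta_should_trigger(
    val_logs: List[float], metric: float, n: int = 5, maximize: bool = False
) -> bool:
    """True when ``metric`` is worse than the best value older than ``n`` checks."""
    if len(val_logs) <= n:
        return False  # not enough history to judge non-monotonicity
    history = val_logs[:-n]  # everything except the n most recent checks
    if maximize:
        return metric < max(history)
    return metric > min(history)
```

In a training loop the caller would evaluate this check first and then append the new metric to val_logs, so that the history only contains earlier checks.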