Source code for gluonts.mx.trainer.model_iteration_averaging

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Any, Dict, List, Optional
import logging

import mxnet as mx
import mxnet.gluon.nn as nn
from mxnet import gluon

from gluonts.core.component import validated

from .callback import Callback


class IterationAveragingStrategy:
    r"""
    The model averaging is based on the paper
    "Stochastic Gradient Descent for Non-smooth Optimization: Convergence
    Results and Optimal Averaging Schemes"
    (http://proceedings.mlr.press/v28/shamir13.pdf), which implements
    polynomial-decay averaging, parameterized by eta. When eta = 0, it is
    equivalent to a simple average over all iterations with the same weights.
    """

    averaged_model: Optional[Dict[str, mx.nd.NDArray]]
    cached_model: Optional[Dict[str, mx.nd.NDArray]]
    average_counter: int
    averaging_started: bool

    @validated()
    def __init__(self, eta: float = 0):
        r"""
        Parameters
        ----------
        eta
            Parameter of polynomial-decay averaging.
        """
        self.eta = eta
        # Dict that maintains the averaged model parameters.
        self.averaged_model = None
        # Temporarily save the current model, so that the averaged model can
        # be used for validation.
        self.cached_model = None
        # The number of models accumulated in the average.
        self.average_counter = 0
        # Indicate whether the model averaging has started.
        self.averaging_started = False
    def update_average_trigger(
        self, metric: Any = None, epoch: int = 0, **kwargs
    ):
        r"""
        Parameters
        ----------
        metric
            The criterion used to trigger averaging.
        epoch
            The epoch to start averaging.
        """
        raise NotImplementedError()
    def apply(self, model: nn.HybridBlock) -> Optional[Dict]:
        r"""
        Parameters
        ----------
        model
            The model of the current iteration.

        Returns
        -------
            The averaged model; None if the averaging hasn't started.
        """
        if self.averaging_started:
            self.update_average(model)

        return self.averaged_model
    def update_average(self, model: nn.HybridBlock):
        r"""
        Parameters
        ----------
        model
            The model used to update the average.
        """
        self.average_counter += 1
        if self.averaged_model is None:
            self.averaged_model = {
                k: v.list_data()[0].copy()
                for k, v in model.collect_params().items()
            }
        else:
            alpha = (self.eta + 1.0) / (self.eta + self.average_counter)
            # moving average
            for name, param_avg in self.averaged_model.items():
                param_avg[:] += alpha * (
                    model.collect_params()[name].list_data()[0] - param_avg
                )
    def load_averaged_model(self, model: nn.HybridBlock):
        r"""
        When validating/evaluating the averaged model halfway through
        training, call load_averaged_model first to load the averaged model
        and overwrite the current model, do the evaluation, and then call
        load_cached_model to load the current model back.

        Parameters
        ----------
        model
            The model that the averaged model is loaded to.
        """
        if self.averaged_model is not None:
            # cache the current model
            if self.cached_model is None:
                self.cached_model = {
                    k: v.list_data()[0].copy()
                    for k, v in model.collect_params().items()
                }
            else:
                for name, param_cached in self.cached_model.items():
                    param_cached[:] = model.collect_params()[name].list_data()[
                        0
                    ]
            # load the averaged model
            for name, param_avg in self.averaged_model.items():
                model.collect_params()[name].set_data(param_avg)
    def load_cached_model(self, model: nn.HybridBlock):
        r"""
        Parameters
        ----------
        model
            The model that the cached model is loaded to.
        """
        if self.cached_model is not None:
            # load the cached model
            for name, param_cached in self.cached_model.items():
                model.collect_params()[name].set_data(param_cached)
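The polynomial-decay recursion in update_average above can be illustrated in isolation. The following is a minimal, standalone sketch (plain Python floats stand in for NDArrays, and the helper name is hypothetical; it is not part of the module): with eta = 0 the step size alpha = 1/t reproduces the ordinary running mean, while larger eta places more weight on recent iterates.

def _polynomial_decay_average(values, eta=0.0):
    # Same recursion as IterationAveragingStrategy.update_average, applied
    # to a list of scalars instead of model parameters (illustrative only).
    avg, counter = None, 0
    for x in values:
        counter += 1
        if avg is None:
            avg = x
        else:
            alpha = (eta + 1.0) / (eta + counter)
            avg += alpha * (x - avg)
    return avg

# eta = 0 gives the simple mean of all iterates ...
assert abs(_polynomial_decay_average([1.0, 2.0, 3.0, 4.0]) - 2.5) < 1e-12
# ... while eta > 0 pulls the average toward the later (usually better) iterates.
assert _polynomial_decay_average([1.0, 2.0, 3.0, 4.0], eta=5.0) > 2.5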
class NTA(IterationAveragingStrategy):
    r"""
    Implement Non-monotonically Triggered AvSGD (NTA).

    This method is based on the paper "Regularizing and Optimizing LSTM
    Language Models" (https://openreview.net/pdf?id=SyyGPP0TZ), and an
    implementation is available on the Salesforce GitHub
    (https://github.com/salesforce/awd-lstm-lm/blob/master/main.py).
    Note that it differs from the arxiv (and gluonnlp) version, which is
    referred to as NTA_V2.
    """

    val_logs: List[Any]

    @validated()
    def __init__(
        self,
        epochs: int,
        n: int = 5,
        maximize: bool = False,
        last_n_trigger: bool = False,
        eta: float = 0,
        fallback_alpha: float = 0.05,
    ):
        r"""
        Depending on the choice of metric, the user may want to minimize or
        maximize it. Set maximize = True to maximize, otherwise minimize.

        Parameters
        ----------
        epochs
            The total number of epochs.
        n
            The non-monotone interval.
        maximize
            Whether to maximize or minimize the validation metric.
        last_n_trigger
            If True, use [-n:] in the average trigger, otherwise use [:-n].
        eta
            Parameter of polynomial-decay averaging.
        fallback_alpha
            Fallback epoch proportion of averaging.
        """
        super().__init__(eta=eta)

        assert 0 <= fallback_alpha <= 1

        self.n = n
        self.maximize = maximize
        self.last_n_trigger = last_n_trigger
        # Historical validation metrics.
        self.val_logs = []
        # The epoch where we fall back to the alpha suffix. This solves the
        # edge case where the averaging is never triggered and, without the
        # fallback, the model of the last epoch would be returned.
        self.fallback_alpha_suffix = epochs * (1.0 - fallback_alpha)
    def update_average_trigger(
        self, metric: Any = None, epoch: int = 0, **kwargs
    ):
        r"""
        Parameters
        ----------
        metric
            The criterion used to trigger averaging.
        epoch
            The epoch to start averaging; not used in NTA.
        """
        # If averaging has not been triggered by the metric yet, check the
        # fallback condition first.
        if not self.averaging_started:
            if epoch >= self.fallback_alpha_suffix:
                self.averaging_started = True
        if not self.averaging_started and self.n > 0:
            min_len = self.n if self.last_n_trigger else (self.n + 1)
            sliced_val_logs = (
                self.val_logs[-self.n :]
                if self.last_n_trigger
                else self.val_logs[: -self.n]
            )
            if self.maximize:
                if len(self.val_logs) >= min_len and metric < max(
                    sliced_val_logs
                ):
                    self.averaging_started = True
            else:
                if len(self.val_logs) >= min_len and metric > min(
                    sliced_val_logs
                ):
                    self.averaging_started = True
        self.val_logs.append(metric)
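A standalone sketch (the validation losses are made up and the snippet is not part of the module) of how the non-monotone trigger behaves: with n = 5 and maximize = False, averaging starts once the current loss is worse than the best loss recorded more than n epochs earlier.

# Hypothetical loss curve: the model improves, plateaus, then degrades.
nta = NTA(epochs=20, n=5, maximize=False)
losses = [1.0, 0.8, 0.6, 0.5, 0.45, 0.44, 0.43, 0.47, 0.48, 0.50, 0.52]

for epoch, loss in enumerate(losses, start=1):
    nta.update_average_trigger(metric=loss, epoch=epoch)

# Triggered at epoch 11: the loss 0.52 is worse than the best loss seen
# more than 5 epochs earlier (0.45 at epoch 5).
assert nta.averaging_started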
class Alpha_Suffix(IterationAveragingStrategy):
    r"""
    Implement Alpha Suffix model averaging.

    This method is based on the paper "Making Gradient Descent Optimal for
    Strongly Convex Stochastic Optimization"
    (https://arxiv.org/pdf/1109.5647.pdf).
    """

    alpha_suffix: float

    @validated()
    def __init__(self, epochs: int, alpha: float = 0.75, eta: float = 0):
        r"""
        Take the iteration average over the last epochs * alpha epochs.

        Parameters
        ----------
        epochs
            The total number of epochs.
        alpha
            Proportion of averaging.
        eta
            Parameter of polynomial-decay averaging.
        """
        super().__init__(eta=eta)

        assert 0 <= alpha <= 1

        # The epoch where iteration averaging starts.
        self.alpha_suffix = epochs * (1.0 - alpha)
    def update_average_trigger(
        self, metric: Any = None, epoch: int = 0, **kwargs
    ):
        r"""
        Parameters
        ----------
        metric
            The criterion used to trigger averaging; not used in Alpha Suffix.
        epoch
            The epoch to start averaging.
        """
        if not self.averaging_started:
            if epoch >= self.alpha_suffix:
                self.averaging_started = True
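For example, with epochs = 100 and alpha = 0.75 the trigger point is 100 * (1 - 0.75) = 25, so the last 75% of the epochs are averaged. The snippet below is a standalone sketch (the toy Dense network and the numbers are placeholders, not part of the module) that also shows the load_averaged_model / load_cached_model swap used when validating with the averaged parameters.

import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(1)
net.initialize()
net(mx.nd.ones((1, 4)))  # run once so the parameters are materialized

strategy = Alpha_Suffix(epochs=100, alpha=0.75)
strategy.update_average_trigger(epoch=24)
assert not strategy.averaging_started   # 24 < 100 * (1 - 0.75) = 25
strategy.update_average_trigger(epoch=25)
assert strategy.averaging_started       # averaging covers the last 75% of epochs

strategy.apply(net)                 # fold the current parameters into the average
strategy.load_averaged_model(net)   # validate with the averaged parameters here
strategy.load_cached_model(net)     # then restore the current parameters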
class ModelIterationAveraging(Callback):
    """
    Callback to implement iteration-based model averaging strategies.

    Parameters
    ----------
    avg_strategy
        An IterationAveragingStrategy, one of NTA or Alpha_Suffix from
        gluonts.mx.trainer.model_iteration_averaging.
    """

    @validated()
    def __init__(self, avg_strategy: IterationAveragingStrategy):
        self.avg_strategy = avg_strategy
    def on_validation_epoch_start(
        self, training_network: nn.HybridBlock
    ) -> None:
        # use the averaged model for validation
        self.avg_strategy.load_averaged_model(training_network)
    def on_validation_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
    ) -> bool:
        self.avg_strategy.load_cached_model(training_network)
        return True
    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
        self.avg_strategy.apply(training_network)
        return True
    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: Dict[str, Any],
        ctx: mx.Context,
    ) -> bool:
        self.avg_strategy.update_average_trigger(
            metric=epoch_loss, epoch=epoch_no + 1
        )
        # once triggered, update the average immediately
        self.avg_strategy.apply(training_network)
        return True
    def on_train_end(
        self,
        training_network: nn.HybridBlock,
        temporary_dir: str,
        ctx: mx.context.Context = None,
    ) -> None:
        logging.info("Loading averaged parameters.")
        self.avg_strategy.load_averaged_model(training_network)
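In practice the callback is passed to the MXNet Trainer, which invokes the hooks above during training and loads the averaged parameters at the end. A minimal usage sketch (the epoch count and the averaging proportion are arbitrary placeholders):

from gluonts.mx.trainer import Trainer
from gluonts.mx.trainer.model_iteration_averaging import (
    Alpha_Suffix,
    ModelIterationAveraging,
)

# Average the parameters over the last 75% of 100 training epochs.
avg_callback = ModelIterationAveraging(
    avg_strategy=Alpha_Suffix(epochs=100, alpha=0.75)
)
trainer = Trainer(epochs=100, callbacks=[avg_callback])
# `trainer` can then be passed to any gluonts MXNet estimator, e.g.
# SomeEstimator(..., trainer=trainer).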