# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Any, Dict, List, Optional
import logging
import mxnet as mx
import mxnet.gluon.nn as nn
from mxnet import gluon
from gluonts.core.component import validated
from .callback import Callback
class IterationAveragingStrategy:
r"""
    The model averaging is based on the paper
    "Stochastic Gradient Descent for Non-smooth Optimization: Convergence
    Results and Optimal Averaging Schemes"
    (http://proceedings.mlr.press/v28/shamir13.pdf). It implements
    polynomial-decay averaging, parameterized by eta. When eta = 0, this is
    equivalent to a simple average over all iterations with equal weights.
"""
averaged_model: Optional[Dict[str, mx.nd.NDArray]]
cached_model: Optional[Dict[str, mx.nd.NDArray]]
average_counter: int
averaging_started: bool
@validated()
def __init__(self, eta: float = 0):
r"""
Parameters
----------
eta
Parameter of polynomial-decay averaging.
"""
self.eta = eta
# Dict that maintains the averaged model parameters.
self.averaged_model = None
# Temporarily save the current model, so that the averaged model can be
# used for validation.
self.cached_model = None
# The number of models accumulated in the average.
self.average_counter = 0
# Indicate whether the model averaging has started.
self.averaging_started = False
    def update_average_trigger(
self, metric: Any = None, epoch: int = 0, **kwargs
):
r"""
Parameters
----------
metric
            The criterion used to trigger averaging.
epoch
The epoch to start averaging.
Returns
-------
"""
raise NotImplementedError()
    def apply(self, model: nn.HybridBlock) -> Optional[Dict]:
r"""
Parameters
----------
model
The model of the current iteration.
Returns
-------
        The averaged model, or None if averaging hasn't started.
"""
if self.averaging_started:
self.update_average(model)
return self.averaged_model
    def update_average(self, model: nn.HybridBlock):
r"""
Parameters
----------
model
The model to update the average.
"""
self.average_counter += 1
if self.averaged_model is None:
self.averaged_model = {
k: v.list_data()[0].copy()
for k, v in model.collect_params().items()
}
else:
alpha = (self.eta + 1.0) / (self.eta + self.average_counter)
# moving average
for name, param_avg in self.averaged_model.items():
param_avg[:] += alpha * (
model.collect_params()[name].list_data()[0] - param_avg
)
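    # Note on the update above: with t = average_counter and
    # alpha_t = (eta + 1) / (eta + t), the recursion
    # avg <- avg + alpha_t * (param - avg) realizes polynomial-decay
    # averaging. For eta = 0 it reduces to the running mean; e.g. averaging
    # the values 1.0, 2.0, 3.0 gives 1.0, then 1.0 + (1/2) * (2.0 - 1.0) =
    # 1.5, then 1.5 + (1/3) * (3.0 - 1.5) = 2.0, i.e. the plain mean.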
    def load_averaged_model(self, model: nn.HybridBlock):
r"""
        When validating or evaluating the averaged model partway through
        training, call load_averaged_model first to overwrite the current
        model with the averaged one, run the evaluation, and then call
        load_cached_model to restore the current model.
Parameters
----------
model
The model that the averaged model is loaded to.
"""
if self.averaged_model is not None:
# cache the current model
if self.cached_model is None:
self.cached_model = {
k: v.list_data()[0].copy()
for k, v in model.collect_params().items()
}
else:
for name, param_cached in self.cached_model.items():
param_cached[:] = model.collect_params()[name].list_data()[
0
]
# load the averaged model
for name, param_avg in self.averaged_model.items():
model.collect_params()[name].set_data(param_avg)
    def load_cached_model(self, model: nn.HybridBlock):
r"""
Parameters
----------
model
The model that the cached model is loaded to.
"""
if self.cached_model is not None:
# load the cached model
for name, param_cached in self.cached_model.items():
model.collect_params()[name].set_data(param_cached)
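# A minimal usage sketch of the methods above (not part of the original
# module; `strategy` stands for an instance of a concrete subclass such as
# NTA below, `net` for the training HybridBlock, and `evaluate` for any
# user-supplied validation routine):
#
#     strategy.apply(net)                # update the average after a batch
#     strategy.load_averaged_model(net)  # swap in the averaged parameters
#     val_metric = evaluate(net)         # validate with the averaged model
#     strategy.load_cached_model(net)    # restore the current parameters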
class NTA(IterationAveragingStrategy):
r"""
Implement Non-monotonically Triggered AvSGD (NTA).
This method is based on paper "Regularizing and Optimizing LSTM Language
Models", (https://openreview.net/pdf?id=SyyGPP0TZ), and an implementation
is available in Salesforce GitHub
(https://github.com/salesforce/awd-lstm-lm/blob/master/main.py). Note that
it mismatches the arxiv (and gluonnlp) version, which is referred to as
NTA_V2 below.
"""
val_logs: List[Any]
@validated()
def __init__(
self,
epochs: int,
n: int = 5,
maximize: bool = False,
last_n_trigger: bool = False,
eta: float = 0,
fallback_alpha: float = 0.05,
):
r"""
        Depending on the choice of validation metric, the user may want to
        minimize or maximize it. Set maximize = True to maximize the metric;
        otherwise it is minimized.
Parameters
----------
epochs
The total number of epochs.
n
            The non-monotone interval.
maximize
Whether to maximize or minimize the validation metric.
eta
Parameter of polynomial-decay averaging.
last_n_trigger
If True, use [-n:] in average trigger, otherwise use [:-n].
fallback_alpha
Fallback epoch proportion of averaging.
"""
super().__init__(eta=eta)
assert 0 <= fallback_alpha <= 1
self.n = n
self.maximize = maximize
self.last_n_trigger = last_n_trigger
# Historical validation metrics.
self.val_logs = []
        # The epoch at which we fall back to the alpha suffix. This handles
        # the edge case where averaging is never triggered; without the
        # fallback, the model from the last epoch would be returned.
self.fallback_alpha_suffix = epochs * (1.0 - fallback_alpha)
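        # For example, with epochs = 100 and fallback_alpha = 0.05 (the
        # default), fallback_alpha_suffix = 95, so averaging is forced to
        # start once epoch >= 95 even if the non-monotone criterion below
        # never fires.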
    def update_average_trigger(
self, metric: Any = None, epoch: int = 0, **kwargs
):
r"""
Parameters
----------
metric
            The criterion used to trigger averaging.
        epoch
            The current epoch; in NTA it is used only for the fallback
            trigger.
Returns
-------
"""
        # If averaging has not already been triggered by the metric, check
        # the fallback condition.
if not self.averaging_started:
if epoch >= self.fallback_alpha_suffix:
self.averaging_started = True
if not self.averaging_started and self.n > 0:
min_len = self.n if self.last_n_trigger else (self.n + 1)
sliced_val_logs = (
self.val_logs[-self.n :]
if self.last_n_trigger
else self.val_logs[: -self.n]
)
if self.maximize:
if len(self.val_logs) >= min_len and metric < max(
sliced_val_logs
):
self.averaging_started = True
else:
if len(self.val_logs) >= min_len and metric > min(
sliced_val_logs
):
self.averaging_started = True
self.val_logs.append(metric)
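# Sketch of the non-monotone trigger above (minimize case, last_n_trigger =
# False, n = 5; the numbers are purely illustrative): once at least n + 1
# previous metrics are recorded, averaging starts as soon as the current
# metric is worse than the best metric observed more than n epochs ago.
# E.g. with prior val_logs = [5, 6, 7, 8, 9, 10], a current metric of 6
# exceeds min(val_logs[:-5]) = 5 and triggers averaging.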
class Alpha_Suffix(IterationAveragingStrategy):
r"""
Implement Alpha Suffix model averaging.
This method is based on paper "Making Gradient Descent Optimalfor Strongly
Convex Stochastic Optimization" (https://arxiv.org/pdf/1109.5647.pdf).
"""
alpha_suffix: float
@validated()
def __init__(self, epochs: int, alpha: float = 0.75, eta: float = 0):
r"""
        Take the iteration average over the last alpha * epochs epochs.
Parameters
----------
epochs
The total number of epochs.
alpha
Proportion of averaging.
eta
Parameter of polynomial-decay averaging.
"""
super().__init__(eta=eta)
assert 0 <= alpha <= 1
# The epoch where iteration averaging starts.
self.alpha_suffix = epochs * (1.0 - alpha)
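        # For example, with epochs = 100 and alpha = 0.75, alpha_suffix = 25,
        # so averaging starts once epoch >= 25, i.e. roughly the last 75% of
        # training.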
    def update_average_trigger(
self, metric: Any = None, epoch: int = 0, **kwargs
):
r"""
Parameters
----------
metric
            The criterion used to trigger averaging; not used in Alpha Suffix.
epoch
The epoch to start averaging.
Returns
-------
"""
if not self.averaging_started:
if epoch >= self.alpha_suffix:
self.averaging_started = True
class ModelIterationAveraging(Callback):
"""
    Callback to implement iteration-based model averaging strategies.
Parameters
----------
avg_strategy
IterationAveragingStrategy, one of NTA or Alpha_Suffix from
gluonts.mx.trainer.model_iteration_averaging
"""
@validated()
def __init__(self, avg_strategy: IterationAveragingStrategy):
self.avg_strategy = avg_strategy
    def on_validation_epoch_start(
self, training_network: nn.HybridBlock
) -> None:
# use averaged model for validation
self.avg_strategy.load_averaged_model(training_network)
    def on_validation_epoch_end(
self,
epoch_no: int,
epoch_loss: float,
training_network: nn.HybridBlock,
trainer: gluon.Trainer,
) -> bool:
self.avg_strategy.load_cached_model(training_network)
return True
    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
self.avg_strategy.apply(training_network)
return True
    def on_epoch_end(
self,
epoch_no: int,
epoch_loss: float,
training_network: nn.HybridBlock,
trainer: gluon.Trainer,
best_epoch_info: Dict[str, Any],
ctx: mx.Context,
) -> bool:
self.avg_strategy.update_average_trigger(
metric=epoch_loss, epoch=epoch_no + 1
)
# once triggered, update the average immediately
self.avg_strategy.apply(training_network)
return True
    def on_train_end(
self,
training_network: nn.HybridBlock,
temporary_dir: str,
ctx: mx.context.Context = None,
) -> None:
logging.info("Loading averaged parameters.")
self.avg_strategy.load_averaged_model(training_network)
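# A hedged usage sketch (not part of the original module): wiring the
# callback into a GluonTS MXNet Trainer. The `callbacks` argument of
# `Trainer` is assumed to be available in the surrounding GluonTS version.
#
#     from gluonts.mx.trainer import Trainer
#     from gluonts.mx.trainer.model_iteration_averaging import (
#         NTA,
#         ModelIterationAveraging,
#     )
#
#     avg_callback = ModelIterationAveraging(avg_strategy=NTA(epochs=100))
#     trainer = Trainer(epochs=100, callbacks=[avg_callback])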