# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import glob
import json
import logging
from typing import Dict, List, Tuple
import mxnet as mx
import mxnet.gluon.nn as nn
import numpy as np
from gluonts.core.component import validated
from .callback import Callback
# Marker embedded in the file name of every epoch information JSON file.
EPOCH_INFO_STRING = "epoch-info"


def save_epoch_info(tmp_path: str, epoch_info: dict) -> None:
    r"""
    Dumps the information of one epoch to a JSON file.

    The target file name is built from the given base path followed by
    ``-epoch-info.json``; an already existing file is overwritten.

    Parameters
    ----------
    tmp_path
        Temporary base path; the suffix ``-epoch-info.json`` is appended to
        it to form the output file name.
    epoch_info
        Epoch information dictionary containing the parameters path, the
        epoch number and the tracking metric value. Must be JSON
        serializable.

    Returns
    -------
    None
    """
    target_file = "{}-{}.json".format(tmp_path, EPOCH_INFO_STRING)
    with open(target_file, "w") as out_stream:
        json.dump(epoch_info, out_stream)
class AveragingStrategy:
    r"""
    Base class for strategies that average the parameters of several
    serialized checkpoints.

    Subclasses implement ``select_checkpoints`` to decide which checkpoints
    take part in the average and with which weights. The remaining methods
    collect the checkpoint information, combine the parameters and store
    the result.
    """

    @validated()
    def __init__(
        self,
        num_models: int = 5,
        metric: str = "score",
        maximize: bool = False,
    ):
        r"""
        Parameters
        ----------
        num_models
            Number of model checkpoints to average.
        metric
            Metric which is used to average models.
        maximize
            Boolean flag to indicate whether the metric should be maximized or
            minimized.
        """
        self.num_models = num_models
        self.metric = metric
        self.maximize = maximize

    def apply(self, model_path: str) -> str:
        r"""
        Averages model parameters of serialized models based on the selected
        model strategy and metric.

        IMPORTANT: Depending on the metric the user might want to minimize or
        maximize. The maximize flag has to be chosen appropriately to reflect
        this.

        Parameters
        ----------
        model_path
            Path to the models directory.

        Returns
        -------
        Path to file with the averaged model, i.e.
        ``<model_path>/averaged_model-0000.params``.
        """
        checkpoints = self.get_checkpoint_information(model_path)
        checkpoint_paths, weights = self.select_checkpoints(checkpoints)
        average_parms = self.average(checkpoint_paths, weights)
        average_parms_path = model_path + "/averaged_model-0000.params"
        mx.nd.save(average_parms_path, average_parms)
        return average_parms_path

    @staticmethod
    def get_checkpoint_information(model_path: str) -> List[Dict]:
        r"""
        Collects the epoch information dictionaries stored in a directory.

        Every file in ``model_path`` whose name ends in ``-epoch-info.json``
        is parsed as JSON.
        NOTE(review): this assumes the files were written by
        ``save_epoch_info`` with a base path located directly inside
        ``model_path`` -- confirm against the code that saves checkpoints.

        Parameters
        ----------
        model_path
            Path to the models directory.

        Returns
        -------
        List of checkpoint information dictionaries, ordered by file name.

        Raises
        ------
        ValueError
            If the directory contains no epoch information file.
        """
        # Escape the directory so glob meta characters in it are literal.
        pattern = "{}/*-{}.json".format(
            glob.escape(model_path), EPOCH_INFO_STRING
        )
        # Sorted so the result does not depend on the directory listing order.
        epoch_info_files = sorted(glob.glob(pattern))
        if not epoch_info_files:
            raise ValueError("No checkpoints found in {}.".format(model_path))

        all_checkpoint_info = []
        for epoch_info_file in epoch_info_files:
            with open(epoch_info_file) as f:
                all_checkpoint_info.append(json.load(f))
        return all_checkpoint_info

    def select_checkpoints(
        self, checkpoints: List[Dict]
    ) -> Tuple[List[str], List[float]]:
        r"""
        Selects checkpoints and computes weights for the selected checkpoints.

        Must be implemented by subclasses.

        Parameters
        ----------
        checkpoints
            List of checkpoint information dictionaries.

        Returns
        -------
        List of selected checkpoint paths and list of corresponding
        weights.
        """
        raise NotImplementedError()

    def average(self, param_paths: List[str], weights: List[float]) -> Dict:
        r"""
        Averages parameters from a list of .params file paths.

        The parameter names are taken from the first file; every other file
        has to provide an entry for each of those names.

        Parameters
        ----------
        param_paths
            List of paths to parameter files.
        weights
            List of weights for the parameter average.

        Returns
        -------
        Averaged parameter dictionary.

        Raises
        ------
        ValueError
            If ``param_paths`` is empty.
        """
        if not param_paths:
            raise ValueError("param_paths is empty.")

        all_arg_params = [mx.nd.load(path) for path in param_paths]

        avg_params = {}
        for k in all_arg_params[0]:
            arrays = [p[k] for p in all_arg_params]
            avg_params[k] = self.average_arrays(arrays, weights)
        return avg_params

    @staticmethod
    def average_arrays(
        arrays: List[mx.nd.NDArray], weights: List[float]
    ) -> mx.nd.NDArray:
        r"""
        Takes a list of arrays of the same shape and computes the element wise
        weighted average.

        Parameters
        ----------
        arrays
            List of NDArrays with the same shape that will be averaged.
        weights
            List of weights for the parameter average. The weights are used
            as given (no normalization happens here).

        Returns
        -------
        The weighted sum of the arrays. A single array is returned unchanged,
        without applying its weight.

        Raises
        ------
        ValueError
            If ``arrays`` is empty, or if more than one array is given and
            the number of weights differs from the number of arrays.
        AssertionError
            If the arrays do not all have the same shape.
        """
        # Check for emptiness first: the shape assertion below would
        # otherwise fail on an empty list and hide this error.
        if not arrays:
            raise ValueError("arrays is empty.")

        shape_set = {array.shape for array in arrays}
        assert len(shape_set) == 1, (
            "All arrays should be the same shape. Found arrays with these"
            " shapes instead :{}".format(shape_set)
        )

        if len(arrays) == 1:
            return arrays[0]

        # zip() would silently drop the surplus entries and thereby return a
        # wrong average, so a mismatch is reported instead.
        if len(weights) != len(arrays):
            raise ValueError(
                "Expected {} weights but got {}.".format(
                    len(arrays), len(weights)
                )
            )

        return mx.nd.add_n(*[a * w for a, w in zip(arrays, weights)])
class SelectNBestSoftmax(AveragingStrategy):
    r"""
    Averaging strategy that weights the best checkpoints by the softmax of
    their metric values.
    """

    def select_checkpoints(
        self, checkpoints: List[Dict]
    ) -> Tuple[List[str], List[float]]:
        r"""
        Selects the checkpoints with the best metric values.

        The weights are the softmax of the metric values, i.e.,

        w_i = exp(v_i) / sum(exp(v_j)) if maximize=True

        w_i = exp(-v_i) / sum(exp(-v_j)) if maximize=False

        The exponentials are evaluated after subtracting the largest
        exponent. This leaves the weights mathematically unchanged but
        avoids overflow and underflow (and the resulting NaN weights) for
        metric values of large magnitude.

        Parameters
        ----------
        checkpoints
            List of checkpoint information dictionaries. Each one must
            provide the keys ``self.metric`` and ``"params_path"``.

        Returns
        -------
        List of selected checkpoint paths and list of corresponding
        weights. Both lists are empty if no checkpoint is selected.
        """
        metric_path_tuple = [
            (c[self.metric], c["params_path"]) for c in checkpoints
        ]
        # Best first; at most num_models entries are kept.
        top_checkpoints = sorted(metric_path_tuple, reverse=self.maximize)[
            : self.num_models
        ]
        if not top_checkpoints:
            return [], []

        # Exponents of the softmax: v_i when maximizing, -v_i otherwise.
        exponents = [
            value if self.maximize else -value for value, _ in top_checkpoints
        ]
        # Shifting by the maximum makes the largest term exp(0) == 1, so the
        # normalizer can neither overflow nor collapse to zero.
        shift = max(exponents)
        unnormalized = [np.exp(e - shift) for e in exponents]
        normalizer = sum(unnormalized)
        weights = [x / normalizer for x in unnormalized]

        # paths of top checkpoints
        checkpoint_paths = [path for _, path in top_checkpoints]
        return checkpoint_paths, weights
class SelectNBestMean(AveragingStrategy):
    r"""
    Averaging strategy that gives every one of the best checkpoints the
    same weight.
    """

    def select_checkpoints(
        self, checkpoints: List[Dict]
    ) -> Tuple[List[str], List[float]]:
        r"""
        Selects the checkpoints with the best metric values.

        The weights are equal for all checkpoints, i.e., w_i = 1/N.

        Parameters
        ----------
        checkpoints
            List of checkpoint information dictionaries. Each one must
            provide the keys ``self.metric`` and ``"params_path"``.

        Returns
        -------
        List of selected checkpoint paths and list of corresponding
        weights. Both lists are empty if no checkpoint is selected.
        """
        metric_path_tuple = [
            (c[self.metric], c["params_path"]) for c in checkpoints
        ]
        # Best first; at most num_models entries are kept.
        top_checkpoints = sorted(metric_path_tuple, reverse=self.maximize)[
            : self.num_models
        ]

        # paths of top checkpoints
        checkpoint_paths = [path for _, path in top_checkpoints]

        # Guard the division: with nothing selected 1 / N would raise a
        # ZeroDivisionError instead of yielding an empty weight list.
        num_selected = len(checkpoint_paths)
        if num_selected == 0:
            return [], []

        weights = [1 / num_selected] * num_selected
        return checkpoint_paths, weights
class ModelAveraging(Callback):
    """
    Callback that replaces the parameters of the trained network by
    averaged parameters once training has finished.

    The averaged parameter file is produced by the given averaging strategy
    and then loaded into the training network.

    Parameters
    ----------
    avg_strategy
        AveragingStrategy, one of SelectNBestSoftmax or SelectNBestMean from
        gluonts.mx.trainer.model_averaging.
    """

    @validated()
    def __init__(self, avg_strategy: AveragingStrategy):
        self.avg_strategy = avg_strategy

    def on_train_end(
        self,
        training_network: nn.HybridBlock,
        temporary_dir: str,
        ctx: mx.context.Context = None,
    ) -> None:
        """
        Computes the averaged parameters and loads them into the network.

        Parameters
        ----------
        training_network
            Network that receives the averaged parameters.
        temporary_dir
            Directory handed to the averaging strategy as the models
            directory.
        ctx
            Context forwarded to ``load_parameters``.
        """
        strategy = self.avg_strategy

        logging.info("Computing averaged parameters.")
        params_file = strategy.apply(temporary_dir)

        logging.info("Loading averaged parameters.")
        training_network.load_parameters(params_file, ctx)