Source code for gluonts.mx.trainer.learning_rate_scheduler

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import field
from typing import Dict, Any, Optional

from pydantic.dataclasses import dataclass
from typing_extensions import Literal

import numpy as np
import mxnet as mx
from mxnet import gluon
import mxnet.gluon.nn as nn

from gluonts.core.component import validated

from .callback import Callback


@dataclass
class Objective:
    best: float

    @staticmethod
    def from_str(s: Literal["min", "max"]) -> "Objective":
        if s == "min":
            return Min()
        else:
            return Max()

    def update(self, metric: float) -> bool:
        if self.should_update(metric):
            self.best = metric
            return True
        return False

    def should_update(self, metric: float) -> bool:
        raise NotImplementedError


@dataclass
class Min(Objective):
    best: float = np.Inf

    def should_update(self, metric: float) -> bool:
        return metric < self.best


@dataclass
class Max(Objective):
    best: float = -np.Inf

    def should_update(self, metric: float) -> bool:
        return metric > self.best
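

# Illustrative example (not part of the original module): `Objective.from_str`
# maps the strings "min" / "max" to trackers of the best metric value seen so
# far, whose `update` method reports whether the latest value is a new best.
#
#     obj = Objective.from_str("min")  # -> Min(best=inf)
#     obj.update(0.8)                  # True: 0.8 is a new best
#     obj.update(0.9)                  # False: 0.9 does not improve on 0.8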


@dataclass
class Patience:
    """
    Simple patience tracker.

    Given an `Objective`, it checks whether the metric has improved and
    updates its patience count accordingly. A better value sets the patience
    back to zero.

    In addition, one needs to call ``reset()`` explicitly after the patience
    was exceeded, otherwise a `RuntimeError` is raised when trying to invoke
    `step`.

    ``Patience`` keeps track of the number of invocations of ``reset`` via
    ``num_resets``.
    """

    patience: int = field(metadata={"ge": 0})
    objective: Objective
    current_patience: int = field(default=0, init=False)
    num_resets: int = field(default=0, init=False)
    exceeded: bool = field(default=False, init=False)

    def reset(self) -> None:
        self.current_patience = 0
        self.exceeded = False
        self.num_resets += 1

    def step(self, metric_value: float) -> bool:
        if self.exceeded:
            raise RuntimeError("Patience already exceeded.")

        has_improved = self.objective.update(metric_value)

        if has_improved:
            self.current_patience = 0
        else:
            self.current_patience += 1

        # this can also trigger in case of improvement when `self.patience = 0`
        self.exceeded = self.current_patience >= self.patience
        return self.exceeded
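

# Illustrative example (not part of the original module): with `patience=1`,
# a single non-improving observation exceeds the patience, after which
# `reset()` must be called before `step` may be used again.
#
#     p = Patience(patience=1, objective=Min())
#     p.step(1.0)  # improvement over inf -> False
#     p.step(2.0)  # no improvement, patience exceeded -> True
#     p.reset()    # required before the next call to `step`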


@dataclass
class MetricAttentiveScheduler:
    """
    This scheduler decreases the learning rate based on the value of some
    validation metric to be optimized (maximized or minimized). The value of
    such metric is provided by calling the `step` method on the scheduler.

    A `patience` parameter must be provided, and the scheduler will reduce
    the learning rate if the metric does not improve within `patience`
    observations of the metric.

    Examples:

        `patience = 0`: learning rate will decrease at every call to `step`,
        regardless of the metric value

        `patience = 1`: learning rate is reduced as soon as `step` is called
        with a metric value which does not improve over the best encountered

        `patience = 10`: learning rate is reduced if no improvement in the
        metric is recorded in 10 successive calls to `step`

    Parameters
    ----------
    patience
        Patience tracker (wrapping a nonnegative patience count and a
        `"min"`/`"max"` objective) used to decide when to reduce the
        learning rate.
    learning_rate
        Initial learning rate to be used.
    decay_factor
        Factor (between 0 and 1) by which to decrease the learning rate.
    min_learning_rate
        Lower bound for the learning rate; the learning rate will never go
        below `min_learning_rate`.
    max_num_decays
        Optional bound on the number of learning rate decays; once reached,
        `step` indicates that training should stop.
    """

    patience: Patience
    learning_rate: float = field(default=0.01, metadata={"gt": 0})
    decay_factor: float = field(default=0.5, metadata={"gt": 0, "lt": 1})
    min_learning_rate: float = 0.0
    max_num_decays: Optional[int] = None

    def __post_init__(self) -> None:
        assert self.learning_rate > self.min_learning_rate

    def step(self, metric_value: float) -> bool:
        """
        Inform the scheduler of the new value of the metric that is being
        optimized. This method should be invoked at regular intervals (e.g.
        at the end of every epoch, after computing a validation score).

        Parameters
        ----------
        metric_value
            Value of the metric that is being optimized.

        Returns
        -------
        bool
            Value indicating whether to continue training.
        """
        self.patience.step(metric_value)

        should_continue = True

        if self.patience.exceeded or not np.isfinite(metric_value):
            if (
                self.learning_rate == self.min_learning_rate
                or self.max_num_decays is not None
                and self.max_num_decays <= self.patience.num_resets
            ):
                should_continue = False

            # Even though we ask not to continue, we still reset the patience
            # because we might still end up continuing training. (Can happen
            # in testing.)
            self.patience.reset()

            self.learning_rate *= self.decay_factor

            # ensure that we don't go below the minimum learning rate
            if self.learning_rate < self.min_learning_rate:
                self.learning_rate = self.min_learning_rate

        return should_continue
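

# Illustrative example (not part of the original module): feeding a sequence
# of validation losses to the scheduler and reading back the adjusted
# learning rate.
#
#     scheduler = MetricAttentiveScheduler(
#         patience=Patience(patience=1, objective=Min()),
#         learning_rate=0.01,
#         decay_factor=0.5,
#         min_learning_rate=0.001,
#     )
#     scheduler.step(1.0)  # improvement -> True, learning_rate stays 0.01
#     scheduler.step(2.0)  # patience exceeded -> True, learning_rate becomes 0.005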


class LearningRateReduction(Callback):
    """
    This Callback decreases the learning rate based on the value of some
    validation metric to be optimized (maximized or minimized). The value of
    such metric is provided by calling the `step` method on the scheduler.

    A `patience` parameter must be provided, and the scheduler will reduce
    the learning rate if the metric does not improve within `patience`
    observations of the metric.

    Examples:

        `patience = 0`: learning rate will decrease at every call to `step`,
        regardless of the metric value

        `patience = 1`: learning rate is reduced as soon as `step` is called
        with a metric value which does not improve over the best encountered

        `patience = 10`: learning rate is reduced if no improvement in the
        metric is recorded in 10 successive calls to `step`

    Parameters
    ----------
    objective
        String, can either be `"min"` or `"max"`.
    patience
        The patience to observe before reducing the learning rate, a
        nonnegative integer.
    base_lr
        Initial learning rate to be used.
    decay_factor
        Factor (between 0 and 1) by which to decrease the learning rate.
    min_lr
        Lower bound for the learning rate; the learning rate will never go
        below `min_lr`.
    """

    @validated()
    def __init__(
        self,
        objective: Literal["min", "max"],
        patience: int,
        base_lr: float = 0.01,
        decay_factor: float = 0.5,
        min_lr: float = 0.0,
    ) -> None:
        assert (
            0 < decay_factor < 1
        ), "The value of `decay_factor` should be in the (0, 1) range"
        assert patience >= 0, "The value of `patience` should be >= 0"
        assert (
            0 <= min_lr <= base_lr
        ), "The value of `min_lr` should be >= 0 and <= base_lr"

        self.lr_scheduler = MetricAttentiveScheduler(
            patience=Patience(patience, Objective.from_str(objective)),
            learning_rate=base_lr,
            decay_factor=decay_factor,
            min_learning_rate=min_lr,
        )

    def on_epoch_end(
        self,
        epoch_no: int,
        epoch_loss: float,
        training_network: nn.HybridBlock,
        trainer: gluon.Trainer,
        best_epoch_info: Dict[str, Any],
        ctx: mx.Context,
    ) -> bool:
        should_continue = self.lr_scheduler.step(metric_value=epoch_loss)
        trainer.optimizer.set_learning_rate(self.lr_scheduler.learning_rate)
        return should_continue
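

# Illustrative example (not part of the original module; the exact `Trainer`
# signature may differ across GluonTS versions): the callback is typically
# passed to the MXNet-based `Trainer` through its `callbacks` argument, so
# that the learning rate is adjusted at the end of each epoch.
#
#     from gluonts.mx.trainer import Trainer
#
#     lr_callback = LearningRateReduction(
#         objective="min", patience=10, base_lr=1e-2, decay_factor=0.5, min_lr=1e-4
#     )
#     trainer = Trainer(epochs=100, callbacks=[lr_callback])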