Source code for gluonts.nursery.few_shot_prediction.src.meta.metrics.nd

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Any, Callable, Optional
import torch
from torchmetrics import Metric


class NormalizedDeviation(Metric):
    """Normalized deviation (ND) accumulated over batches.

    The metric keeps two running sums as metric states: the masked absolute
    error between the selected forecast and the target (``numerator``) and
    the masked absolute target (``denom``). ``compute`` returns their ratio.
    Both states are registered with ``dist_reduce_fx="sum"``.

    Args:
        rescale: If ``True``, ``update`` maps predictions and targets back to
            the original scale using the ``scales`` argument before the
            errors are computed.
        compute_on_step: Forwarded unchanged to ``Metric.__init__``.
        dist_sync_on_step: Forwarded unchanged to ``Metric.__init__``.
        process_group: Forwarded unchanged to ``Metric.__init__``.
        dist_sync_fn: Forwarded unchanged to ``Metric.__init__``.
    """

    # pylint: disable=arguments-differ
    def __init__(
        self,
        rescale: bool = False,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ):
        # NOTE(review): `compute_on_step` is reportedly no longer accepted by
        # recent torchmetrics releases -- confirm against the pinned version.
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.rescale = rescale
        self.add_state(
            "numerator", default=torch.as_tensor(0.0), dist_reduce_fx="sum"
        )
        self.add_state(
            "denom", default=torch.as_tensor(0.0), dist_reduce_fx="sum"
        )

    def update(
        self,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        mask: torch.Tensor,
        scales: Optional[torch.Tensor] = None,
    ) -> None:
        """Accumulate the masked absolute error and absolute target of a batch.

        Args:
            y_pred: Forecasts, indexed as ``[batch, time, k]``. Only the entry
                at index ``k // 2`` of the last axis is evaluated (presumably
                the median of an odd number of sorted quantiles -- confirm
                against the caller), and the time axis is truncated to the
                length of ``y_true``.
            y_true: Targets, indexed as ``[batch, time]``.
            mask: Multiplied element-wise into the errors and the targets, so
                positions where it is zero contribute to neither sum
                (assumed to be broadcastable to ``y_true``).
            scales: Per-series statistics, with ``scales[:, 0]`` used as the
                offset and ``scales[:, 1]`` as the multiplier. Required when
                the metric was created with ``rescale=True``; ignored
                otherwise.

        Raises:
            ValueError: If ``rescale=True`` and ``scales`` is ``None``.
        """
        if self.rescale:
            if scales is None:
                raise ValueError(
                    "`scales` must be provided when `rescale=True`."
                )
            m = scales[:, 0].unsqueeze(1)
            std = scales[:, 1].unsqueeze(1)
            y_true = y_true * std + m
            # The extra trailing axis lets the statistics broadcast over the
            # last axis of the predictions.
            y_pred = y_pred * std.unsqueeze(2) + m.unsqueeze(2)

        idx_median = y_pred.size(-1) // 2
        y_pred = y_pred[:, : y_true.size(1), idx_median]

        losses = (y_pred - y_true).abs() * mask
        self.numerator += losses.sum()
        # Apply the same mask as the numerator: after rescaling, masked-out
        # targets equal the offset `m` rather than zero, so an unmasked sum
        # would inflate the denominator.
        self.denom += (y_true.abs() * mask).sum()

    def compute(self) -> torch.Tensor:
        """Return the accumulated absolute error divided by the absolute target.

        There is no guard against a zero denominator; in that case the result
        follows torch's division semantics (``nan`` or ``inf``).
        """
        return self.numerator / self.denom