# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from functools import partial
from typing import (
Collection,
Optional,
Callable,
Mapping,
Dict,
List,
Iterator,
)
from typing_extensions import Protocol, runtime_checkable, Self
import numpy as np
from .aggregations import Aggregation, Mean, Sum
from .stats import (
error,
absolute_error,
absolute_label,
absolute_percentage_error,
absolute_scaled_error,
coverage,
quantile_loss,
scaled_interval_score,
scaled_quantile_loss,
squared_error,
symmetric_absolute_percentage_error,
num_masked_target_values,
)
@dataclass
class MetricCollection:
    """
    A group of metrics that are updated together and read out as a mapping
    from metric name to metric value.
    """

    metrics: List[Metric]

    def update(self, data: Mapping[str, np.ndarray]) -> Self:
        """
        Update metrics using a single data instance.
        """
        for member in self.metrics:
            member.update(data)
        return self

    def update_all(self, stream: Iterator[Mapping[str, np.ndarray]]) -> Self:
        """
        Update metrics using a stream of data instances.
        """
        for instance in stream:
            # Go through ``update`` so subclasses overriding it are honoured.
            self.update(instance)
        return self

    def get(self) -> Dict[str, np.ndarray]:
        """
        Return the current value of every metric, keyed by ``metric.name``.

        Metrics sharing a name overwrite each other; the last one wins.
        """
        results = {}
        for member in self.metrics:
            key = member.name
            results[key] = member.get()
        return results
@dataclass
class Metric:
    """
    Base class for named metrics.

    ``update`` and ``get`` raise ``NotImplementedError`` here and are meant
    to be overridden; ``update_all`` is implemented on top of ``update``.
    """

    name: str

    def update(self, data: Mapping[str, np.ndarray]) -> Self:
        """
        Update metric using a single data instance.
        """
        raise NotImplementedError

    def update_all(self, stream: Iterator[Mapping[str, np.ndarray]]) -> Self:
        """
        Update metric using a stream of data instances.
        """
        for instance in stream:
            self.update(instance)
        return self

    def get(self) -> np.ndarray:
        """
        Return the current value of the metric.
        """
        raise NotImplementedError
@dataclass
class DirectMetric(Metric):
    """
    A Metric which uses a single function and aggregation strategy.
    """

    stat: Callable
    aggregate: Aggregation

    def update(self, data: Mapping[str, np.ndarray]) -> Self:
        """
        Evaluate ``stat`` on one data instance and hand the result to the
        aggregation's ``step``.
        """
        step = self.aggregate.step
        step(self.stat(data))
        return self

    def get(self) -> np.ndarray:
        """
        Return whatever the aggregation's ``get`` reports.
        """
        aggregated = self.aggregate.get()
        return aggregated
@dataclass
class DerivedMetric(Metric):
    """
    A Metric that is computed using other metrics.

    A derived metric updates multiple, simpler metrics independently and in
    the end combines their results as defined in `post_process`.
    """

    metrics: Dict[str, Metric]
    post_process: Callable

    def update(self, data: Mapping[str, np.ndarray]) -> Self:
        """
        Forward one data instance to every underlying metric.
        """
        for component in self.metrics.values():
            component.update(data)
        return self

    def get(self) -> np.ndarray:
        """
        Read every underlying metric and combine them with `post_process`.

        The keys of ``metrics`` become the keyword-argument names passed to
        `post_process`.
        """
        combine = self.post_process
        component_values = {}
        for key, component in self.metrics.items():
            component_values[key] = component.get()
        return combine(**component_values)
@runtime_checkable
class MetricDefinition(Protocol):
    """
    Structural type for metric definitions: objects callable with an
    optional ``axis`` that return a ``Metric``.
    """

    def __call__(self, axis: Optional[int] = None) -> Metric:
        """
        Build the metric for ``axis``; the protocol's own body only raises.
        """
        raise NotImplementedError
class BaseMetricDefinition:
    """
    Base class for metric definitions.

    Subclasses override ``__call__`` to build the actual metric. This class
    supplies ``+`` and ``add`` for grouping definitions into a
    ``MetricDefinitionCollection``.
    """

    def __call__(self, axis: Optional[int] = None):
        """
        Build the metric for ``axis``. Must be overridden.

        ``axis`` defaults to ``None`` so that the signature agrees with the
        ``MetricDefinition`` protocol and calling a definition without
        arguments fails with ``NotImplementedError`` rather than a
        ``TypeError`` about a missing argument.
        """
        raise NotImplementedError()

    def __add__(self, other) -> MetricDefinitionCollection:
        """
        Combine ``self`` and ``other`` into a collection, keeping the
        left-to-right operand order.

        If ``other`` is already a collection its members are flattened into
        the result instead of being nested.
        """
        if isinstance(other, MetricDefinitionCollection):
            # Prepend explicitly: delegating to ``other + self`` would place
            # ``self`` after ``other``'s metrics and reverse operand order.
            return MetricDefinitionCollection([self, *other.metrics])

        return MetricDefinitionCollection([self, other])

    def add(self, *others):
        """
        Chain ``+`` over ``others`` from left to right.

        Returns ``self`` unchanged when called without arguments.
        """
        combined = self
        for other in others:
            combined = combined + other
        return combined
@dataclass
class MetricDefinitionCollection(BaseMetricDefinition):
    """
    A list of metric definitions that is turned into a ``MetricCollection``
    with a single call.
    """

    metrics: List[BaseMetricDefinition]

    def __call__(self, axis: Optional[int] = None) -> MetricCollection:
        """
        Call every contained definition with ``axis`` and wrap the resulting
        metrics in a ``MetricCollection``.
        """
        instantiated = []
        for definition in self.metrics:
            instantiated.append(definition(axis=axis))
        return MetricCollection(instantiated)

    def __add__(self, other) -> MetricDefinitionCollection:
        """
        Return a new collection with ``other`` appended.

        Another collection contributes its members (flattened); anything
        else is appended as a single entry.
        """
        if isinstance(other, MetricDefinitionCollection):
            appended = other.metrics
        else:
            appended = [other]
        return MetricDefinitionCollection([*self.metrics, *appended])
class MeanAbsoluteLabel(BaseMetricDefinition):
    """
    Definition of the "mean_absolute_label" metric: the ``absolute_label``
    statistic aggregated with ``Mean``.
    """

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = "mean_absolute_label"
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=absolute_label, aggregate=aggregation
        )


# Ready-to-use module-level instance of the definition.
mean_absolute_label = MeanAbsoluteLabel()
class SumAbsoluteLabel(BaseMetricDefinition):
    """
    Definition of the "sum_absolute_label" metric: the ``absolute_label``
    statistic aggregated with ``Sum``.
    """

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Sum`` aggregation.
        """
        metric_name = "sum_absolute_label"
        aggregation = Sum(axis=axis)
        return DirectMetric(
            name=metric_name, stat=absolute_label, aggregate=aggregation
        )


# Ready-to-use module-level instance of the definition.
sum_absolute_label = SumAbsoluteLabel()
class SumNumMaskedTargetValues(BaseMetricDefinition):
    """
    Definition of the "sum_num_masked_target_values" metric: the
    ``num_masked_target_values`` statistic aggregated with ``Sum``.
    """

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Sum`` aggregation.
        """
        metric_name = "sum_num_masked_target_values"
        aggregation = Sum(axis=axis)
        return DirectMetric(
            name=metric_name,
            stat=num_masked_target_values,
            aggregate=aggregation,
        )


# Ready-to-use module-level instance of the definition.
sum_num_masked_target_values = SumNumMaskedTargetValues()
@dataclass
class SumError(BaseMetricDefinition):
    """
    Definition of the summed ``error`` statistic.

    ``forecast_type`` is forwarded to ``error`` as a keyword argument and is
    embedded in the metric name.
    """

    forecast_type: str = "0.5"

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Sum`` aggregation.
        """
        metric_name = f"sum_error[{self.forecast_type}]"
        statistic = partial(error, forecast_type=self.forecast_type)
        aggregation = Sum(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default forecast type.
sum_error = SumError()
@dataclass
class SumAbsoluteError(BaseMetricDefinition):
    """
    Definition of the summed ``absolute_error`` statistic.

    ``forecast_type`` is forwarded to ``absolute_error`` as a keyword
    argument and is embedded in the metric name.
    """

    forecast_type: str = "0.5"

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Sum`` aggregation.
        """
        metric_name = f"sum_absolute_error[{self.forecast_type}]"
        statistic = partial(absolute_error, forecast_type=self.forecast_type)
        aggregation = Sum(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default forecast type.
sum_absolute_error = SumAbsoluteError()
@dataclass
class MAE(BaseMetricDefinition):
    """
    Mean Absolute Error.

    Built from the ``absolute_error`` statistic (called with
    ``forecast_type``) and the ``Mean`` aggregation.
    """

    forecast_type: str = "0.5"

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = f"MAE[{self.forecast_type}]"
        statistic = partial(absolute_error, forecast_type=self.forecast_type)
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default forecast type.
mae = MAE()
@dataclass
class MSE(BaseMetricDefinition):
    """
    Mean Squared Error.

    Built from the ``squared_error`` statistic (called with
    ``forecast_type``) and the ``Mean`` aggregation.
    """

    forecast_type: str = "mean"

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = f"MSE[{self.forecast_type}]"
        statistic = partial(squared_error, forecast_type=self.forecast_type)
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default forecast type.
mse = MSE()
@dataclass
class SumQuantileLoss(BaseMetricDefinition):
    """
    Definition of the summed ``quantile_loss`` statistic for one quantile
    level ``q``; ``q`` is also embedded in the metric name.
    """

    q: float

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Sum`` aggregation.
        """
        metric_name = f"sum_quantile_loss[{self.q}]"
        statistic = partial(quantile_loss, q=self.q)
        aggregation = Sum(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )
@dataclass
class Coverage(BaseMetricDefinition):
    """
    Definition of the averaged ``coverage`` statistic for one quantile level
    ``q``; ``q`` is also embedded in the metric name.
    """

    q: float

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = f"coverage[{self.q}]"
        statistic = partial(coverage, q=self.q)
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )
@dataclass
class MAPE(BaseMetricDefinition):
    """
    Mean Absolute Percentage Error.

    Built from the ``absolute_percentage_error`` statistic (called with
    ``forecast_type``) and the ``Mean`` aggregation.
    """

    forecast_type: str = "0.5"

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = f"MAPE[{self.forecast_type}]"
        statistic = partial(
            absolute_percentage_error, forecast_type=self.forecast_type
        )
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default forecast type.
mape = MAPE()
@dataclass
class SMAPE(BaseMetricDefinition):
    """
    Symmetric Mean Absolute Percentage Error.

    Built from the ``symmetric_absolute_percentage_error`` statistic (called
    with ``forecast_type``) and the ``Mean`` aggregation.
    """

    forecast_type: str = "0.5"

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = f"sMAPE[{self.forecast_type}]"
        statistic = partial(
            symmetric_absolute_percentage_error,
            forecast_type=self.forecast_type,
        )
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default forecast type.
smape = SMAPE()
@dataclass
class MSIS(BaseMetricDefinition):
    """
    Mean Scaled Interval Score.

    Built from the ``scaled_interval_score`` statistic (called with
    ``alpha``) and the ``Mean`` aggregation.
    """

    alpha: float = 0.05

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        # NOTE(review): unlike the other parameterised metrics in this
        # module, the name does not encode ``alpha``, so two MSIS metrics
        # with different ``alpha`` values share the same name.
        metric_name = "MSIS"
        statistic = partial(scaled_interval_score, alpha=self.alpha)
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default ``alpha``.
msis = MSIS()
@dataclass
class MASE(BaseMetricDefinition):
    """
    Mean Absolute Scaled Error.

    Built from the ``absolute_scaled_error`` statistic (called with
    ``forecast_type``) and the ``Mean`` aggregation.
    """

    forecast_type: str = "0.5"

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = f"MASE[{self.forecast_type}]"
        statistic = partial(
            absolute_scaled_error, forecast_type=self.forecast_type
        )
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )


# Module-level instance using the default forecast type.
mase = MASE()
@dataclass
class MeanScaledQuantileLoss(BaseMetricDefinition):
    """
    Definition of the averaged ``scaled_quantile_loss`` statistic for one
    quantile level ``q``; ``q`` is also embedded in the metric name.
    """

    q: float

    def __call__(self, axis: Optional[int] = None) -> DirectMetric:
        """
        Build the metric, forwarding ``axis`` to the ``Mean`` aggregation.
        """
        metric_name = f"mean_scaled_quantile_loss[{self.q}]"
        statistic = partial(scaled_quantile_loss, q=self.q)
        aggregation = Mean(axis=axis)
        return DirectMetric(
            name=metric_name, stat=statistic, aggregate=aggregation
        )
@dataclass
class ND(BaseMetricDefinition):
    """
    Normalized Deviation.

    Derived metric: the summed absolute error divided by the summed
    absolute label.
    """

    forecast_type: str = "0.5"

    @staticmethod
    def normalized_deviation(
        sum_absolute_error: np.ndarray, sum_absolute_label: np.ndarray
    ) -> np.ndarray:
        """
        Divide the summed absolute error by the summed absolute label.
        """
        ratio = sum_absolute_error / sum_absolute_label
        return ratio

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric; ``axis`` is passed to both components.
        """
        metric_name = f"ND[{self.forecast_type}]"
        error_definition = SumAbsoluteError(forecast_type=self.forecast_type)
        # Keys must match the parameter names of ``normalized_deviation``.
        components = {
            "sum_absolute_error": error_definition(axis=axis),
            "sum_absolute_label": sum_absolute_label(axis=axis),
        }
        return DerivedMetric(
            name=metric_name,
            metrics=components,
            post_process=self.normalized_deviation,
        )


# Module-level instance using the default forecast type.
nd = ND()
@dataclass
class RMSE(BaseMetricDefinition):
    """
    Root Mean Squared Error.

    Derived metric: the square root of ``MSE`` for the same
    ``forecast_type``.
    """

    forecast_type: str = "mean"

    @staticmethod
    def root_mean_squared_error(mean_squared_error: np.ndarray) -> np.ndarray:
        """
        Take the element-wise square root of the mean squared error.
        """
        root = np.sqrt(mean_squared_error)
        return root

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric; ``axis`` is passed to the ``MSE`` component.
        """
        metric_name = f"RMSE[{self.forecast_type}]"
        mse_definition = MSE(forecast_type=self.forecast_type)
        # Key must match the parameter name of ``root_mean_squared_error``.
        components = {"mean_squared_error": mse_definition(axis=axis)}
        return DerivedMetric(
            name=metric_name,
            metrics=components,
            post_process=self.root_mean_squared_error,
        )


# Module-level instance using the default forecast type.
rmse = RMSE()
@dataclass
class NRMSE(BaseMetricDefinition):
    """
    RMSE, normalized by the mean absolute label.
    """

    forecast_type: str = "mean"

    @staticmethod
    def normalize_root_mean_squared_error(
        root_mean_squared_error: np.ndarray, mean_absolute_label: np.ndarray
    ) -> np.ndarray:
        """
        Divide the root mean squared error by the mean absolute label.
        """
        normalized = root_mean_squared_error / mean_absolute_label
        return normalized

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric; ``axis`` is passed to both components.
        """
        metric_name = f"NRMSE[{self.forecast_type}]"
        rmse_definition = RMSE(forecast_type=self.forecast_type)
        # Keys must match the parameter names of
        # ``normalize_root_mean_squared_error``.
        components = {
            "root_mean_squared_error": rmse_definition(axis=axis),
            "mean_absolute_label": mean_absolute_label(axis=axis),
        }
        return DerivedMetric(
            name=metric_name,
            metrics=components,
            post_process=self.normalize_root_mean_squared_error,
        )


# Module-level instance using the default forecast type.
nrmse = NRMSE()
@dataclass
class WeightedSumQuantileLoss(BaseMetricDefinition):
    """
    Derived metric: the summed quantile loss at level ``q`` divided by the
    summed absolute label.
    """

    q: float

    @staticmethod
    def weight_sum_quantile_loss(
        sum_quantile_loss: np.ndarray, sum_absolute_label: np.ndarray
    ) -> np.ndarray:
        """
        Divide the summed quantile loss by the summed absolute label.
        """
        weighted = sum_quantile_loss / sum_absolute_label
        return weighted

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric; ``axis`` is passed to both components.
        """
        metric_name = f"weighted_sum_quantile_loss[{self.q}]"
        loss_definition = SumQuantileLoss(q=self.q)
        # Keys must match the parameter names of ``weight_sum_quantile_loss``.
        components = {
            "sum_quantile_loss": loss_definition(axis=axis),
            "sum_absolute_label": sum_absolute_label(axis=axis),
        }
        return DerivedMetric(
            name=metric_name,
            metrics=components,
            post_process=self.weight_sum_quantile_loss,
        )
@dataclass
class MeanSumQuantileLoss(BaseMetricDefinition):
    """
    Derived metric: the average of ``SumQuantileLoss`` over all
    ``quantile_levels``.
    """

    quantile_levels: Collection[float]

    @staticmethod
    def mean(**quantile_losses: np.ndarray) -> np.ndarray:
        """
        Stack the per-quantile results along a new leading axis and average
        over that axis.
        """
        per_level = list(quantile_losses.values())
        stacked = np.stack(per_level, axis=0)
        return np.mean(stacked, axis=0)

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric with one ``SumQuantileLoss`` component per
        quantile level; ``axis`` is passed to every component.
        """
        components = {}
        for q in self.quantile_levels:
            key = f"quantile_loss[{q}]"
            components[key] = SumQuantileLoss(q=q)(axis=axis)
        return DerivedMetric(
            name="mean_sum_quantile_loss",
            metrics=components,
            post_process=self.mean,
        )
@dataclass
class MeanWeightedSumQuantileLoss(BaseMetricDefinition):
    """
    Derived metric: the average of ``WeightedSumQuantileLoss`` over all
    ``quantile_levels``.
    """

    quantile_levels: Collection[float]

    @staticmethod
    def mean(**quantile_losses: np.ndarray) -> np.ndarray:
        """
        Stack the per-quantile results along a new leading axis and average
        over that axis.
        """
        per_level = list(quantile_losses.values())
        stacked = np.stack(per_level, axis=0)
        return np.mean(stacked, axis=0)

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric with one ``WeightedSumQuantileLoss``
        component per quantile level; ``axis`` is passed to every component.
        """
        components = {}
        for q in self.quantile_levels:
            key = f"quantile_loss[{q}]"
            components[key] = WeightedSumQuantileLoss(q=q)(axis=axis)
        return DerivedMetric(
            name="mean_weighted_sum_quantile_loss",
            metrics=components,
            post_process=self.mean,
        )
@dataclass
class AverageMeanScaledQuantileLoss(BaseMetricDefinition):
    """
    Derived metric: the average of ``MeanScaledQuantileLoss`` over all
    ``quantile_levels``.
    """

    quantile_levels: Collection[float]

    @staticmethod
    def mean(**quantile_losses: np.ndarray) -> np.ndarray:
        """
        Stack the per-quantile results along a new leading axis and average
        over that axis.
        """
        per_level = list(quantile_losses.values())
        stacked = np.stack(per_level, axis=0)
        return np.mean(stacked, axis=0)

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric with one ``MeanScaledQuantileLoss``
        component per quantile level; ``axis`` is passed to every component.
        """
        components = {}
        for q in self.quantile_levels:
            key = f"mean_scaled_quantile_loss[{q}]"
            components[key] = MeanScaledQuantileLoss(q=q)(axis=axis)
        return DerivedMetric(
            name="average_mean_scaled_quantile_loss",
            metrics=components,
            post_process=self.mean,
        )
@dataclass
class MAECoverage(BaseMetricDefinition):
    """
    Derived metric: the mean absolute difference between each coverage
    value and the quantile level it was computed for.
    """

    quantile_levels: Collection[float]

    @staticmethod
    def mean(
        quantile_levels: Collection[float], **coverages: np.ndarray
    ) -> np.ndarray:
        """
        Average ``|coverage[q] - q|`` over ``quantile_levels``.

        ``coverages`` must contain an entry named ``coverage[q]`` for every
        ``q`` in ``quantile_levels``.
        """
        deviations = []
        for q in quantile_levels:
            deviations.append(np.abs(coverages[f"coverage[{q}]"] - q))
        stacked = np.stack(deviations, axis=0)
        return np.mean(stacked, axis=0)

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric with one ``Coverage`` component per
        quantile level; ``axis`` is passed to every component.
        """
        components = {}
        for q in self.quantile_levels:
            key = f"coverage[{q}]"
            components[key] = Coverage(q=q)(axis=axis)
        # Bind the levels so that ``mean`` only receives the coverages from
        # the derived metric.
        combine = partial(self.mean, quantile_levels=self.quantile_levels)
        return DerivedMetric(
            name="MAE_coverage",
            metrics=components,
            post_process=combine,
        )
@dataclass
class OWA(BaseMetricDefinition):
    """
    Overall Weighted Average.

    Derived metric: the average of sMAPE and MASE, each divided by the same
    metric evaluated with ``forecast_type="naive_2"``.
    """

    forecast_type: str = "0.5"

    @staticmethod
    def calculate_OWA(
        smape: np.ndarray,
        smape_naive2: np.ndarray,
        mase: np.ndarray,
        mase_naive2: np.ndarray,
    ) -> np.ndarray:
        """
        Return ``0.5 * (smape / smape_naive2 + mase / mase_naive2)``.
        """
        relative_smape = smape / smape_naive2
        relative_mase = mase / mase_naive2
        return 0.5 * (relative_smape + relative_mase)

    def __call__(self, axis: Optional[int] = None) -> DerivedMetric:
        """
        Build the derived metric; ``axis`` is passed to all four components.
        """
        metric_name = f"OWA[{self.forecast_type}]"
        # Keys must match the parameter names of ``calculate_OWA``.
        components = {
            "smape": SMAPE(forecast_type=self.forecast_type)(axis=axis),
            "smape_naive2": SMAPE(forecast_type="naive_2")(axis=axis),
            "mase": MASE(forecast_type=self.forecast_type)(axis=axis),
            "mase_naive2": MASE(forecast_type="naive_2")(axis=axis),
        }
        return DerivedMetric(
            name=metric_name,
            metrics=components,
            post_process=self.calculate_OWA,
        )


# Module-level instance using the default forecast type.
owa = OWA()