# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from dataclasses import dataclass, field
from typing import List, Optional, Type
import numpy as np
from statsforecast.models import (
ADIDA,
AutoARIMA,
AutoCES,
AutoETS,
AutoTheta,
CrostonClassic,
CrostonOptimized,
CrostonSBA,
DynamicOptimizedTheta,
DynamicTheta,
HistoricAverage,
Holt,
HoltWinters,
IMAPA,
MSTL,
Naive,
OptimizedTheta,
RandomWalkWithDrift,
SeasonalExponentialSmoothing,
SeasonalExponentialSmoothingOptimized,
SeasonalNaive,
SeasonalWindowAverage,
SimpleExponentialSmoothing,
SimpleExponentialSmoothingOptimized,
TSB,
Theta,
WindowAverage,
)
from gluonts.core.component import validated
from gluonts.dataset import DataEntry
from gluonts.dataset.util import forecast_start
from gluonts.model.predictor import RepresentablePredictor
from gluonts.model.forecast import QuantileForecast
[docs]@dataclass
class ModelConfig:
quantile_levels: Optional[List[float]] = None
forecast_keys: List[str] = field(init=False)
statsforecast_keys: List[str] = field(init=False)
intervals: Optional[List[int]] = field(init=False)
def __post_init__(self):
self.forecast_keys = ["mean"]
self.statsforecast_keys = ["mean"]
if self.quantile_levels is None:
self.intervals = None
return
intervals = set()
for quantile_level in self.quantile_levels:
interval = round(
200 * (max(quantile_level, 1 - quantile_level) - 0.5)
)
intervals.add(interval)
side = "hi" if quantile_level > 0.5 else "lo"
self.forecast_keys.append(str(quantile_level))
self.statsforecast_keys.append(f"{side}-{interval}")
self.intervals = sorted(intervals)
[docs]class StatsForecastPredictor(RepresentablePredictor):
"""
A predictor type that wraps models from the `statsforecast`_ package.
This class is used via subclassing and setting the ``ModelType`` class
attribute to specify the ``statsforecast`` model type to use.
.. _statsforecast: https://github.com/Nixtla/statsforecast
Parameters
----------
prediction_length
Prediction length for the model to use.
quantile_levels
Optional list of quantile levels that we want predictions for.
Note: this is only supported by specific types of models, such as
``AutoARIMA``. By default this is ``None``, giving only the mean
prediction.
**model_params
Keyword arguments to be passed to the model type for construction.
The specific arguments accepted or required depend on the
``ModelType``; please refer to the documentation of ``statsforecast``
for details.
"""
ModelType: Type
@validated()
def __init__(
self,
prediction_length: int,
quantile_levels: Optional[List[float]] = None,
**model_params,
) -> None:
super().__init__(prediction_length=prediction_length)
self.model = self.ModelType(**model_params)
self.config = ModelConfig(quantile_levels=quantile_levels)
[docs] def predict_item(self, entry: DataEntry) -> QuantileForecast:
# TODO use also exogenous features
kwargs = {}
if self.config.intervals is not None:
kwargs["level"] = self.config.intervals
prediction = self.model.forecast(
y=entry["target"],
h=self.prediction_length,
**kwargs,
)
forecast_arrays = [
prediction[k] for k in self.config.statsforecast_keys
]
return QuantileForecast(
forecast_arrays=np.stack(forecast_arrays, axis=0),
forecast_keys=self.config.forecast_keys,
start_date=forecast_start(entry),
item_id=entry.get("item_id"),
info=entry.get("info"),
)
[docs]class ADIDAPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``ADIDA`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = ADIDA
[docs]class AutoARIMAPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``AutoARIMA`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = AutoARIMA
[docs]class AutoCESPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``AutoCES`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = AutoCES
[docs]class AutoETSPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``AutoETS`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = AutoETS
[docs]class AutoThetaPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``AutoTheta`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = AutoTheta
[docs]class CrostonClassicPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``CrostonClassic`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = CrostonClassic
[docs]class CrostonOptimizedPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``CrostonOptimized`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = CrostonOptimized
[docs]class CrostonSBAPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``CrostonSBA`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = CrostonSBA
[docs]class IMAPAPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``IMAPA`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = IMAPA
[docs]class DynamicOptimizedThetaPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``DynamicOptimizedTheta`` model from
`statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = DynamicOptimizedTheta
[docs]class DynamicThetaPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``DynamicTheta`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = DynamicTheta
[docs]class HistoricAveragePredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``HistoricAverage`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = HistoricAverage
[docs]class HoltPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``Holt`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = Holt
[docs]class HoltWintersPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``HoltWinters`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = HoltWinters
[docs]class MSTLPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``MSTL`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = MSTL
[docs]class NaivePredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``Naive`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = Naive
[docs]class OptimizedThetaPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``OptimizedTheta`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = OptimizedTheta
[docs]class RandomWalkWithDriftPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``RandomWalkWithDrift`` model from
`statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = RandomWalkWithDrift
[docs]class SeasonalExponentialSmoothingPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``SeasonalExponentialSmoothing`` model from
`statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = SeasonalExponentialSmoothing
[docs]class SeasonalExponentialSmoothingOptimizedPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``SeasonalExponentialSmoothingOptimized`` model
from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = SeasonalExponentialSmoothingOptimized
[docs]class SeasonalNaivePredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``SeasonalNaive`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = SeasonalNaive
[docs]class SeasonalWindowAveragePredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``SeasonalWindowAverage`` model from
`statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = SeasonalWindowAverage
[docs]class SimpleExponentialSmoothingPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``SimpleExponentialSmoothing`` model from
`statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = SimpleExponentialSmoothing
[docs]class SimpleExponentialSmoothingOptimizedPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``SimpleExponentialSmoothingOptimized`` model from
`statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = SimpleExponentialSmoothingOptimized
[docs]class TSBPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``TSB`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = TSB
[docs]class ThetaPredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``Theta`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = Theta
[docs]class WindowAveragePredictor(StatsForecastPredictor):
"""
A predictor wrapping the ``WindowAverage`` model from `statsforecast`_.
See :class:`StatsForecastPredictor` for the list of arguments.
.. _statsforecast: https://github.com/Nixtla/statsforecast
"""
ModelType = WindowAverage