# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Callable, Union
import numpy as np
from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry
from gluonts.dataset.util import forecast_start
from gluonts.dataset.field_names import FieldName
from gluonts.model.forecast import Forecast, SampleForecast
from gluonts.model.predictor import RepresentablePredictor
from gluonts.transform.feature import (
LastValueImputation,
MissingValueImputation,
)
class SeasonalAggregatePredictor(RepresentablePredictor):
    r"""
    Seasonal aggregate forecaster.

    For each time series :math:`y`, this predictor produces a forecast
    :math:`\tilde{y}(T+k) = f\big(y(T+k'-h), y(T+k'-2h), ...,
    y(T+k'-mh)\big)` with :math:`k' = k \bmod h`, where :math:`T` is the
    forecast time, :math:`k = 0, ...,` ``prediction_length - 1``,
    :math:`m =` ``num_seasons``, :math:`h =` ``season_length`` and
    :math:`f =` ``agg_fun``.

    If ``prediction_length > season_length``, then the seasonal aggregate
    is repeated with period ``season_length``. If a time series is shorter
    than ``season_length`` :math:`\times` ``num_seasons``, then ``agg_fun``
    is applied to the full time series and the resulting value is used for
    every predicted time point.

    Parameters
    ----------
    prediction_length
        Number of time points to predict.
    season_length
        Seasonality used to make predictions. If this is an integer, then a
        fixed seasonality is applied; if this is a function, then it will be
        called on each given entry's ``freq`` attribute of the ``"start"``
        field, and the returned seasonality will be used.
    num_seasons
        Number of seasons to aggregate.
    agg_fun
        Aggregate function. It is called as ``agg_fun(values, axis=0)`` on
        a 2D array with one row per season, and as ``agg_fun(values)`` on
        the full series when the series is too short.
    imputation_method
        The imputation method to use in case of missing values.
        Defaults to :py:class:`LastValueImputation` which replaces each
        missing value with the last value that was not missing.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        season_length: Union[int, Callable],
        num_seasons: int,
        agg_fun: Callable = np.nanmean,
        imputation_method: MissingValueImputation = LastValueImputation(),
    ) -> None:
        super().__init__(prediction_length=prediction_length)

        # A callable `season_length` cannot be checked here; its result is
        # validated per item in `predict_item`.
        assert (
            not isinstance(season_length, int) or season_length > 0
        ), "The value of `season_length` should be > 0"
        assert (
            isinstance(num_seasons, int) and num_seasons > 0
        ), "The value of `num_seasons` should be > 0"

        self.prediction_length = prediction_length
        self.season_length = season_length
        self.num_seasons = num_seasons
        self.agg_fun = agg_fun
        self.imputation_method = imputation_method

    def predict_item(self, item: DataEntry) -> Forecast:
        """
        Compute the seasonal aggregate forecast for a single data entry.

        Parameters
        ----------
        item
            Data entry providing the target series and, when
            ``season_length`` is a function, a ``"start"`` field with a
            ``freq`` attribute. The optional ``"item_id"`` and ``"info"``
            fields are passed through to the forecast.

        Returns
        -------
        Forecast
            A ``SampleForecast`` holding a single sample path of shape
            ``(1, prediction_length)``.

        Raises
        ------
        AssertionError
            If the target series is empty, or if a callable
            ``season_length`` returns a non-positive value.
        """
        if isinstance(self.season_length, int):
            season_length = self.season_length
        else:
            season_length = self.season_length(item["start"].freq)
            # Same constraint as the one enforced on integer values in
            # `__init__`; a zero here would otherwise surface as an opaque
            # ZeroDivisionError in the modulo below.
            assert (
                season_length > 0
            ), "The value of `season_length` should be > 0"

        target = np.asarray(item[FieldName.TARGET], np.float32)
        len_ts = len(target)
        forecast_start_time = forecast_start(item)

        assert (
            len_ts >= 1
        ), "all time series should have at least one data point"

        if np.isnan(target).any():
            # `np.asarray` may hand back the caller's own array, so impute
            # on a copy in case the imputation method works in place.
            target = target.copy()
            target = self.imputation_method(target)

        if len_ts >= season_length * self.num_seasons:
            # `indices` is a 2D integer array of shape
            # (num_seasons, prediction_length): row j holds the positions
            # taken from the (j + 1)-th most recent season, and column k
            # holds the in-season offset `k % season_length`.
            # An explicit ndarray is used (rather than a nested list) so
            # the lookup is unambiguously a fancy index along the single
            # time axis; older NumPy releases could interpret a short
            # nested list as a tuple of per-axis indices.
            step_offsets = np.arange(self.prediction_length) % season_length
            season_starts = len_ts - season_length * np.arange(
                1, self.num_seasons + 1
            )
            indices = season_starts[:, np.newaxis] + step_offsets

            samples = self.agg_fun(target[indices], axis=0).reshape(
                (1, self.prediction_length)
            )
        else:
            # Not enough history for `num_seasons` full seasons: fall back
            # to a constant forecast from the aggregate of the whole series.
            samples = np.full(
                shape=(1, self.prediction_length),
                fill_value=self.agg_fun(target),
            )

        return SampleForecast(
            samples=samples,
            start_date=forecast_start_time,
            item_id=item.get("item_id", None),
            info=item.get("info", None),
        )