Source code for gluonts.model.seasonal_naive._predictor

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Callable, Union

import numpy as np

from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry
from gluonts.dataset.util import forecast_start
from gluonts.dataset.field_names import FieldName
from gluonts.model.forecast import Forecast, SampleForecast
from gluonts.model.predictor import RepresentablePredictor
from gluonts.transform.feature import (
    LastValueImputation,
    MissingValueImputation,
)


[docs]class SeasonalNaivePredictor(RepresentablePredictor): """ Seasonal naïve forecaster. For each time series :math:`y`, this predictor produces a forecast :math:`\\tilde{y}(T+k) = y(T+k-h)`, where :math:`T` is the forecast time, :math:`k = 0, ...,` `prediction_length - 1`, and :math:`h =` `season_length`. If `prediction_length > season_length`, then the season is repeated multiple times. If a time series is shorter than season_length, then the mean observed value is used as prediction. Parameters ---------- prediction_length Number of time points to predict. season_length Seasonality used to make predictions. If this is an integer, then a fixed sesasonlity is applied; if this is a function, then it will be called on each given entry's ``freq`` attribute of the ``"start"`` field, and will return the seasonality to use. imputation_method The imputation method to use in case of missing values. Defaults to :py:class:`LastValueImputation` which replaces each missing value with the last value that was not missing. """ @validated() def __init__( self, prediction_length: int, season_length: Union[int, Callable], imputation_method: MissingValueImputation = LastValueImputation(), ) -> None: super().__init__(prediction_length=prediction_length) assert ( not isinstance(season_length, int) or season_length > 0 ), "The value of `season_length` should be > 0" self.prediction_length = prediction_length self.season_length = season_length self.imputation_method = imputation_method
[docs] def predict_item(self, item: DataEntry) -> Forecast: if isinstance(self.season_length, int): season_length = self.season_length else: season_length = self.season_length(item["start"].freq) target = np.asarray(item[FieldName.TARGET], np.float32) len_ts = len(target) forecast_start_time = forecast_start(item) assert ( len_ts >= 1 ), "all time series should have at least one data point" if np.isnan(target).any(): target = target.copy() target = self.imputation_method(target) if len_ts >= season_length: indices = [ len_ts - season_length + k % season_length for k in range(self.prediction_length) ] samples = target[indices].reshape((1, self.prediction_length)) else: samples = np.full( shape=(1, self.prediction_length), fill_value=np.nanmean(target), ) return SampleForecast( samples=samples, start_date=forecast_start_time, item_id=item.get("item_id", None), info=item.get("info", None), )