Source code for gluonts.model.r_forecast._univariate_predictor

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, Optional, List, Union

import numpy as np
import pandas as pd

from gluonts.core.component import validated
from gluonts.model.forecast import SampleForecast, QuantileForecast
from gluonts.model.r_forecast import RBasePredictor


# Prefix passed to the base predictor as `r_file_prefix`.
R_FILE_PREFIX = "univariate"

# R method names, grouped by the kind of output each one is treated as
# producing: sample paths, quantiles, or a single point (mean) forecast.
UNIVARIATE_SAMPLE_FORECAST_METHODS = ["ets", "arima"]
UNIVARIATE_QUANTILE_FORECAST_METHODS = ["tbats", "thetaf", "stlar"]
UNIVARIATE_POINT_FORECAST_METHODS = ["croston", "mlp"]

# Every accepted `method_name`, in sample / quantile / point order.
SUPPORTED_UNIVARIATE_METHODS = [
    *UNIVARIATE_SAMPLE_FORECAST_METHODS,
    *UNIVARIATE_QUANTILE_FORECAST_METHODS,
    *UNIVARIATE_POINT_FORECAST_METHODS,
]


class RForecastPredictor(RBasePredictor):
    """
    Wrapper for calling the `R forecast package
    <http://pkg.robjhyndman.com/forecast/>`_.

    In order to use it you need to install R and run::

        pip install 'rpy2>=2.9.*,<3.*'
        R -e 'install.packages(c("forecast", "nnfor"), repos="https://cloud.r-project.org")'  # noqa

    Parameters
    ----------
    freq
        The granularity of the time series (e.g. '1H')
    prediction_length
        Number of time points to be predicted.
    method_name
        The method from rforecast to be used, one of
        "ets", "arima", "tbats", "thetaf", "stlar", "croston", "mlp".
    period
        The period to be used (this is called `frequency` in the R forecast
        package); falls back to a tentative reasonable default if not
        specified (for instance 24 for hourly freq '1H')
    trunc_length
        Maximum history length to feed to the model (some models become slow
        with very long series).
    params
        Extra parameters used when calling the forecast method; they are
        merged over the defaults (`prediction_length`, `output_types`,
        `frequency`). Note that, as `output_type`, only 'samples' and
        `quantiles` (depending on the underlying R method) are supported
        currently.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        method_name: str = "ets",
        period: Optional[int] = None,
        trunc_length: Optional[int] = None,
        params: Optional[Dict] = None,
    ) -> None:
        super().__init__(
            freq=freq,
            prediction_length=prediction_length,
            period=period,
            trunc_length=trunc_length,
            r_file_prefix=R_FILE_PREFIX,
        )

        assert method_name in SUPPORTED_UNIVARIATE_METHODS, (
            f"method {method_name} is not supported please "
            f"use one of {SUPPORTED_UNIVARIATE_METHODS}"
        )

        self.method_name = method_name

        # Handle to the R function of the same name.
        self._r_method = self._robjects.r[method_name]

        # Defaults first, so that user-supplied `params` take precedence.
        self.params = {
            "prediction_length": self.prediction_length,
            "output_types": ["samples"],
            "frequency": self.period,
        }
        if params is not None:
            self.params.update(params)

    @staticmethod
    def _interval_to_level(interval: float, side: str) -> float:
        """
        Convert a central prediction interval (in percent) to the quantile
        level of its upper or lower bound, e.g. ``(80, "upper") -> 0.9``.

        Raises
        ------
        ValueError
            If ``side`` is neither ``"upper"`` nor ``"lower"``.
        """
        if side == "upper":
            level = 50 + interval / 2
        elif side == "lower":
            level = 50 - interval / 2
        else:
            raise ValueError(
                f"side must be 'upper' or 'lower', got {side!r}"
            )
        return level / 100

    def _get_r_forecast(self, data: Dict, params: Dict) -> Dict:
        """
        Run the configured R method on ``data["target"]`` and return its
        output as a Python dict keyed by the names R reported.

        When the R output carries quantiles, a ``"quantiles"`` entry is
        built that maps quantile-level strings (plus ``"mean"``) to the
        corresponding forecasts; ``"mean"`` is moved into that mapping.
        """
        make_ts = self._stats_pkg.ts
        r_params = self._robjects.vectors.ListVector(params)
        vec = self._robjects.FloatVector(data["target"])
        ts = make_ts(vec, frequency=self.period)
        forecast = self._r_method(ts, r_params)

        forecast_dict = dict(
            zip(forecast.names, map(self._unlist, list(forecast)))
        )

        if "quantiles" in forecast_dict or "upper_quantiles" in forecast_dict:
            # Post-processing quantiles on the Python side for the
            # convenience of asserting and debugging.
            upper_quantiles = [
                str(self._interval_to_level(interval, side="upper"))
                for interval in params["intervals"]
            ]
            lower_quantiles = [
                str(self._interval_to_level(interval, side="lower"))
                for interval in params["intervals"]
            ]

            # Median forecasts would be available at two places:
            # Lower 0 and Higher 0 (0-prediction interval)
            # NOTE(review): dropping the first upper entry assumes the
            # first interval is 0 -- confirm for user-supplied intervals.
            forecast_dict["quantiles"] = dict(
                zip(
                    lower_quantiles + upper_quantiles[1:],
                    forecast_dict["lower_quantiles"]
                    + forecast_dict["upper_quantiles"][1:],
                )
            )

            # `QuantileForecast` allows "mean" as the key;
            # we store them as well since they can differ from median.
            forecast_dict["quantiles"].update(
                {"mean": forecast_dict.pop("mean")}
            )

        return forecast_dict

    def _preprocess_data(self, data: Dict) -> Dict:
        """
        Keep only the last ``trunc_length`` target values (when set) and
        advance ``data["start"]`` by the number of values dropped.

        The entry is modified in place and also returned.
        """
        if self.trunc_length:
            shift_by = max(data["target"].shape[0] - self.trunc_length, 0)
            data["start"] = data["start"] + shift_by
            data["target"] = data["target"][-self.trunc_length :]
        return data

    def _override_params(
        self, params: Dict, num_samples: int, intervals: Optional[List] = None
    ) -> Dict:
        """
        Adjust ``params`` (in place) for the configured method and return it.

        Sets ``num_samples`` always; forces ``output_types`` for point and
        quantile methods; for quantile methods also sets ``intervals``
        (sorted ascending, or 0..90 in steps of 10 when not given).
        """
        params["num_samples"] = num_samples

        if self.method_name in UNIVARIATE_POINT_FORECAST_METHODS:
            params["output_types"] = ["mean"]
        elif self.method_name in UNIVARIATE_QUANTILE_FORECAST_METHODS:
            params["output_types"] = ["quantiles", "mean"]
            if intervals is None:
                # This corresponds to quantiles: 0.05 to 0.95 in steps of 0.05.
                params["intervals"] = list(range(0, 100, 10))
            else:
                params["intervals"] = np.sort(intervals).tolist()

        return params

    def _warning_message(self) -> None:
        """
        Print a notice that `output_types` is overridden for point and
        quantile forecast methods; prints nothing for sample methods.
        """
        if self.method_name in UNIVARIATE_POINT_FORECAST_METHODS:
            print(
                "Overriding `output_types` to `mean` since"
                f" {self.method_name} is a point forecast method."
            )
        elif self.method_name in UNIVARIATE_QUANTILE_FORECAST_METHODS:
            print(
                "Overriding `output_types` to `quantiles` since "
                f"{self.method_name} is a quantile forecast method."
            )

    def _forecast_dict_to_obj(
        self,
        forecast_dict: Dict,
        num_samples: int,
        forecast_start_date: pd.Timestamp,
        item_id: Optional[str],
        info: Dict,
    ) -> Union[QuantileForecast, SampleForecast]:
        """
        Wrap the dict returned by `_get_r_forecast` in a forecast object.

        Quantile methods yield a `QuantileForecast` built from
        ``forecast_dict["quantiles"]``. All other methods yield a
        `SampleForecast` of shape ``(num_samples, prediction_length)``;
        for point methods the mean forecast is repeated ``num_samples``
        times.

        Raises
        ------
        AssertionError
            If the samples do not have the expected shape.
        """
        if self.method_name in UNIVARIATE_QUANTILE_FORECAST_METHODS:
            quantile_forecasts_dict = forecast_dict["quantiles"]

            return QuantileForecast(
                forecast_arrays=np.array(
                    list(quantile_forecasts_dict.values())
                ),
                forecast_keys=list(quantile_forecasts_dict.keys()),
                start_date=forecast_start_date,
                item_id=item_id,
            )

        if self.method_name in UNIVARIATE_POINT_FORECAST_METHODS:
            # Handling special cases outside of R is better,
            # since it is more visible and is easier to change.
            # Repeat mean forecasts `num_samples` times. `np.tile` stacks
            # rows regardless of whether the mean is a list or an array
            # (`list * n` repeats, but `ndarray * n` would scale instead).
            samples = np.tile(
                np.asarray(forecast_dict["mean"]), (num_samples, 1)
            )
        else:
            samples = np.array(forecast_dict["samples"])

        expected_shape = (
            num_samples,
            self.prediction_length,
        )
        assert (
            samples.shape == expected_shape
        ), f"Expected shape {expected_shape} but found {samples.shape}"

        return SampleForecast(
            samples,
            start_date=forecast_start_date,
            info=info,
            item_id=item_id,
        )