# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
import re
from typing import Dict, Iterator, NamedTuple, Optional, Tuple

import pandas as pd

import gluonts  # noqa
from gluonts.core.serde import load_code
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.stat import (
    DatasetStatistics,
    calculate_dataset_statistics,
)
from gluonts.evaluation import Evaluator
from gluonts.itertools import maybe_len
from gluonts.model.estimator import Estimator
from gluonts.model.forecast import Forecast
from gluonts.model.predictor import Predictor
from gluonts.transform import AdhocTransform


def make_evaluation_predictions(
    dataset: Dataset,
    predictor: Predictor,
    num_samples: int = 100,
) -> Tuple[Iterator[Forecast], Iterator[pd.DataFrame]]:
"""
Returns predictions for the trailing prediction_length observations of the
given time series, using the given predictor.
The predictor will take as input the given time series without the trailing
prediction_length observations.
Parameters
----------
dataset
Dataset where the evaluation will happen. Only the portion excluding
the prediction_length portion is used when making prediction.
predictor
Model used to draw predictions.
num_samples
Number of samples to draw on the model when evaluating. Only
sampling-based models will use this.
Returns
-------
Tuple[Iterator[Forecast], Iterator[pd.Series]]
A pair of iterators, the first one yielding the forecasts, and the
second one yielding the corresponding ground truth series.
"""
    prediction_length = predictor.prediction_length
    lead_time = predictor.lead_time

    def add_ts_dataframe(
        data_iterator: Iterator[DataEntry],
    ) -> Iterator[DataEntry]:
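        """
        Yield each entry with an added "ts" field: a DataFrame holding the
        full target, indexed by the period range implied by the start field.
        """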
        for data_entry in data_iterator:
            data = data_entry.copy()
            index = pd.period_range(
                start=data[FieldName.START],
                periods=data[FieldName.TARGET].shape[-1],
                freq=data[FieldName.START].freq,
            )
            data["ts"] = pd.DataFrame(
                index=index, data=data[FieldName.TARGET].transpose()
            )
            yield data

    def ts_iter(dataset: Dataset) -> Iterator[pd.DataFrame]:
        for data_entry in add_ts_dataframe(iter(dataset)):
            yield data_entry["ts"]

    def truncate_target(data: DataEntry) -> DataEntry:
        data = data.copy()
        target = data[FieldName.TARGET]
        # shape[-1] also handles the multivariate case
        # (target_dim, history_length)
        assert target.shape[-1] >= prediction_length + lead_time
        # keep everything except the trailing lead_time + prediction_length
        # observations, which are the ones we want to predict
        data[FieldName.TARGET] = target[..., : -prediction_length - lead_time]
        return data

    # TODO: filter out time series with a target shorter than the prediction
    # TODO: length, or fix the evaluator so it supports missing values
    # TODO: (otherwise such filtering may remove the entire test set)
    # lazily apply truncate_target to every entry of the dataset
    dataset_trunc = AdhocTransform(truncate_target).apply(dataset)

    return (
        predictor.predict(dataset_trunc, num_samples=num_samples),
        ts_iter(dataset),
    )
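

# A minimal usage sketch for make_evaluation_predictions (illustration
# only, not part of the original module). It assumes `predictor` is any
# trained gluonts Predictor; get_dataset just provides sample data:
#
#     from gluonts.dataset.repository.datasets import get_dataset
#
#     dataset = get_dataset("m4_hourly")
#     forecast_it, ts_it = make_evaluation_predictions(
#         dataset.test, predictor=predictor, num_samples=100
#     )
#     forecasts = list(forecast_it)  # Forecast objects, one per series
#     tss = list(ts_it)              # ground-truth series as DataFrames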


# Keys used by serialize_message when logging backtest information.
train_dataset_stats_key = "train_dataset_stats"
test_dataset_stats_key = "test_dataset_stats"
estimator_key = "estimator"
agg_metrics_key = "agg_metrics"


def serialize_message(logger, message: str, variable):
    """Log `variable` in the form ``gluonts[message]: variable``."""
    logger.info(f"gluonts[{message}]: {variable}")


def backtest_metrics(
    test_dataset: Dataset,
    predictor: Predictor,
    evaluator=Evaluator(
        quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
    ),
    num_samples: int = 100,
    logging_file: Optional[str] = None,
) -> Tuple[dict, pd.DataFrame]:
"""
Parameters
----------
test_dataset
Dataset to use for testing.
predictor
The predictor to test.
evaluator
Evaluator to use.
num_samples
Number of samples to use when generating sample-based forecasts. Only
sampling-based models will use this.
logging_file
If specified, information of the backtest is redirected to this file.
Returns
-------
Tuple[dict, pd.DataFrame]
A tuple of aggregate metrics and per-time-series metrics obtained by
training `forecaster` on `train_dataset` and evaluating the resulting
`evaluator` provided on the `test_dataset`.
"""
    logger = logging.getLogger(__name__)

    if logging_file is not None:
        log_formatter = logging.Formatter(
            "[%(asctime)s %(levelname)s %(thread)d] %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
        )
        handler = logging.FileHandler(logging_file)
        handler.setFormatter(log_formatter)
        logger.addHandler(handler)

    test_statistics = calculate_dataset_statistics(test_dataset)
    serialize_message(logger, test_dataset_stats_key, test_statistics)

    forecast_it, ts_it = make_evaluation_predictions(
        test_dataset, predictor=predictor, num_samples=num_samples
    )
    agg_metrics, item_metrics = evaluator(
        ts_it, forecast_it, num_series=maybe_len(test_dataset)
    )

    # we only log aggregate metrics for now, as item metrics may be very
    # large
    for name, value in agg_metrics.items():
        serialize_message(logger, f"metric-{name}", value)

    if logging_file is not None:
        # Close and detach the file handler to avoid leaving the file open.
        # https://stackoverflow.com/questions/24816456/python-logging-wont-shutdown
        handler.close()
        logger.removeHandler(handler)
        del logger, handler

    return agg_metrics, item_metrics
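

# A minimal end-to-end sketch for backtest_metrics (illustration only, not
# part of the original module). The dataset name and estimator settings
# below are assumptions; any trained gluonts Predictor works:
#
#     from gluonts.dataset.repository.datasets import get_dataset
#     from gluonts.model.deepar import DeepAREstimator
#     from gluonts.mx.trainer import Trainer
#
#     dataset = get_dataset("m4_hourly")
#     estimator = DeepAREstimator(
#         freq=dataset.metadata.freq,
#         prediction_length=dataset.metadata.prediction_length,
#         trainer=Trainer(epochs=5),
#     )
#     predictor = estimator.train(dataset.train)
#     agg_metrics, item_metrics = backtest_metrics(
#         test_dataset=dataset.test,
#         predictor=predictor,
#     )
#     print(agg_metrics["MASE"], agg_metrics["sMAPE"])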