# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging
from typing import Optional, Tuple, Iterator
import numpy as np
import pandas as pd
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.split import split
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.dataset.util import period_index
from gluonts.evaluation import Evaluator
from gluonts.model.forecast import Forecast
from gluonts.model.predictor import Predictor
from gluonts.itertools import maybe_len
def _to_dataframe(input_label: Tuple[DataEntry, DataEntry]) -> pd.DataFrame:
    """
    Join consecutive (in time) data entries, e.g. an ``(input, label)`` pair,
    into a single DataFrame.

    The targets of all entries are concatenated along their last axis (time
    is assumed to be the last axis of each target), and the result is indexed
    by periods that begin at the start of the first entry.
    """
    origin = input_label[0][FieldName.START]
    joined = np.concatenate(
        [entry[FieldName.TARGET] for entry in input_label], axis=-1
    )
    periods = period_index({FieldName.START: origin, FieldName.TARGET: joined})
    # Transpose so that time runs along the rows of the frame.
    return pd.DataFrame(joined.T, index=periods)
def make_evaluation_predictions(
    dataset: Dataset,
    predictor: Predictor,
    num_samples: int = 100,
) -> Tuple[Iterator[Forecast], Iterator[pd.DataFrame]]:
    """
    Return predictions for the trailing window of every series in the given
    dataset, together with the corresponding ground truth.

    The trailing window has length
    ``predictor.prediction_length + predictor.lead_time``. The predictor is
    given each series with that trailing window removed.

    Parameters
    ----------
    dataset
        Dataset where the evaluation will happen. Only the portion excluding
        the trailing ``prediction_length + lead_time`` observations is used
        when making predictions.
    predictor
        Model used to draw predictions.
    num_samples
        Number of samples to draw on the model when evaluating. Only
        sampling-based models will use this.

    Returns
    -------
    Tuple[Iterator[Forecast], Iterator[pd.DataFrame]]
        A pair of iterators. The first yields the forecasts. The second (a
        lazy ``map``) yields, per series, a DataFrame holding the model input
        followed by the held-out window, indexed by period.
    """
    window_length = predictor.prediction_length + predictor.lead_time
    # A negative offset selects the trailing `window_length` observations
    # as the test window.
    _, test_template = split(dataset, offset=-window_length)
    test_data = test_template.generate_instances(window_length)
    return (
        predictor.predict(test_data.input, num_samples=num_samples),
        map(_to_dataframe, test_data),
    )
# Names of backtest-related entries. `test_dataset_stats_key` labels the log
# line carrying the test dataset statistics in `backtest_metrics`.
(
    train_dataset_stats_key,
    test_dataset_stats_key,
    estimator_key,
    agg_metrics_key,
) = (
    "train_dataset_stats",
    "test_dataset_stats",
    "estimator",
    "agg_metrics",
)
def serialize_message(logger, message: str, variable):
    """
    Log ``variable`` at INFO level on ``logger``, tagged with ``message``.

    The emitted line has the form ``gluonts[<message>]: <variable>``.
    """
    line = f"gluonts[{message}]: {variable}"
    logger.info(line)
def backtest_metrics(
    test_dataset: Dataset,
    predictor: Predictor,
    evaluator: Optional[Evaluator] = None,
    num_samples: int = 100,
    logging_file: Optional[str] = None,
) -> Tuple[dict, pd.DataFrame]:
    """
    Evaluate ``predictor`` on the trailing window of every series in
    ``test_dataset`` and return the resulting metrics.

    Statistics of the test dataset and every aggregate metric are logged at
    INFO level through this module's logger, one line per value.

    Parameters
    ----------
    test_dataset
        Dataset to use for testing.
    predictor
        The predictor to test.
    evaluator
        Evaluator to use. If ``None`` (the default), a new ``Evaluator`` with
        quantiles 0.1, 0.2, ..., 0.9 is created for this call.
    num_samples
        Number of samples to use when generating sample-based forecasts. Only
        sampling-based models will use this.
    logging_file
        If specified, a file handler writing to this path is attached to this
        module's logger for the duration of the call. It is removed and closed
        afterwards, also when an exception is raised.

    Returns
    -------
    Tuple[dict, pd.DataFrame]
        A tuple of aggregate metrics and metrics per time series, obtained by
        evaluating the forecasts of ``predictor`` on ``test_dataset`` with
        ``evaluator``.
    """
    if evaluator is None:
        # Built per call instead of as a default argument value, so no
        # Evaluator is created at import time or shared between calls.
        evaluator = Evaluator(
            quantiles=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
        )

    logger = logging.getLogger(__name__)
    handler = None
    if logging_file is not None:
        handler = logging.FileHandler(logging_file)
        handler.setFormatter(
            logging.Formatter(
                "[%(asctime)s %(levelname)s %(thread)d] %(message)s",
                datefmt="%m/%d/%Y %H:%M:%S",
            )
        )
        logger.addHandler(handler)
    # NOTE(review): this function never sets the logger's level, so whether
    # the INFO lines below reach the file depends on logging configuration
    # done elsewhere — confirm callers enable INFO.

    try:
        test_statistics = calculate_dataset_statistics(test_dataset)
        serialize_message(logger, test_dataset_stats_key, test_statistics)

        forecast_it, ts_it = make_evaluation_predictions(
            test_dataset, predictor=predictor, num_samples=num_samples
        )

        agg_metrics, item_metrics = evaluator(
            ts_it, forecast_it, num_series=maybe_len(test_dataset)
        )

        # We only log aggregate metrics for now, as item metrics may be very
        # large.
        for name, value in agg_metrics.items():
            serialize_message(logger, f"metric-{name}", value)
    finally:
        if handler is not None:
            # removeHandler alone does not release the file; close() does.
            # Doing both in `finally` also avoids leaking the handler (and
            # duplicating log lines on later calls) when an error occurs.
            logger.removeHandler(handler)
            handler.close()

    return agg_metrics, item_metrics