# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import json
import logging
from typing import Any, Optional, Type, Union

import gluonts
from gluonts.core import fqname_for
from gluonts.dataset.common import Dataset
from gluonts.evaluation import Evaluator, backtest
from gluonts.itertools import maybe_len
from gluonts.model.estimator import Estimator, IncrementallyTrainable
from gluonts.model.forecast import Quantile
from gluonts.model.predictor import Predictor
from gluonts.transform import FilterTransformation

from .env import TrainEnv
from .util import invoke_with

logger = logging.getLogger(__name__)


def log_metric(metric: str, value: Any) -> None:
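    """Log `value` for `metric` in the standard `gluonts[<metric>]` format."""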
logger.info(f"gluonts[{metric}]: {value}")


def log_version(forecaster_type: Type[Union[Estimator, Predictor]]) -> None:
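    """Log the version of gluonts and of the forecaster class in use."""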
    name = fqname_for(forecaster_type)
    version = forecaster_type.__version__
    logger.info(f"Using gluonts v{gluonts.__version__}")
    logger.info(f"Using forecaster {name} v{version}")


def run_train_and_test(
    env: TrainEnv, forecaster_type: Type[Union[Estimator, Predictor]]
) -> None:
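    """Train a forecaster on the `train` data channel and, if a `test`
    channel is available, evaluate the resulting predictor on it.
    """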
    log_version(forecaster_type)

    logger.info(
        "Using the following data channels: %s", ", ".join(env.datasets)
    )

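    # `from_inputs` builds the forecaster from the training data and the
    # given hyperparameters; the result is either an Estimator that still
    # needs training or a ready-to-use Predictor.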
    forecaster = forecaster_type.from_inputs(
        env.datasets["train"], **env.hyperparameters
    )
    logger.info(
        "The forecaster can be reconstructed with the following expression: "
        f"{forecaster}"
    )

    if isinstance(forecaster, Predictor):
        predictor = forecaster
    else:
        predictor = run_train(
            forecaster=forecaster,
            train_dataset=env.datasets["train"],
            validation_dataset=env.datasets.get("validation"),
            hyperparameters=env.hyperparameters,
            from_predictor=env.datasets.get("model"),
        )

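    # Persist the trained predictor so it can be loaded again for inference.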
    predictor.serialize(env.path.model)

    if "test" in env.datasets:
        run_test(env, predictor, env.datasets["test"], env.hyperparameters)


def run_train(
    forecaster: Estimator,
    train_dataset: Dataset,
    hyperparameters: dict,
    validation_dataset: Optional[Dataset],
    from_predictor: Optional[Predictor],
) -> Predictor:
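    """Train `forecaster` on `train_dataset` and return the resulting
    Predictor, optionally continuing from a previously trained predictor.
    """
    # Optional data-loading settings, read from the hyperparameters when
    # present.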
    num_workers = (
        int(hyperparameters["num_workers"])
        if "num_workers" in hyperparameters
        else None
    )
    shuffle_buffer_length = (
        int(hyperparameters["shuffle_buffer_length"])
        if "shuffle_buffer_length" in hyperparameters
        else None
    )
    num_prefetch = (
        int(hyperparameters["num_prefetch"])
        if "num_prefetch" in hyperparameters
        else None
    )

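    # If a pre-trained predictor was passed through the `model` channel,
    # continue training from it; this requires the estimator to implement
    # the IncrementallyTrainable protocol.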
    if from_predictor is not None:
        assert isinstance(forecaster, IncrementallyTrainable), (
            "The model provided does not implement the "
            "IncrementallyTrainable protocol"
        )
        return invoke_with(
            forecaster.train_from,
            from_predictor,
            training_data=train_dataset,
            validation_data=validation_dataset,
            num_workers=num_workers,
            num_prefetch=num_prefetch,
            shuffle_buffer_length=shuffle_buffer_length,
        )

    return invoke_with(
        forecaster.train,
        training_data=train_dataset,
        validation_data=validation_dataset,
        num_workers=num_workers,
        num_prefetch=num_prefetch,
        shuffle_buffer_length=shuffle_buffer_length,
    )


def run_test(
    env: TrainEnv,
    predictor: Predictor,
    test_dataset: Dataset,
    hyperparameters: dict,
) -> None:
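    """Evaluate `predictor` on `test_dataset`, log the aggregate metrics,
    and store them next to the serialized model.
    """
    # Drop series whose target is not strictly longer than the prediction
    # length, since no forecast window can be evaluated on them.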
    len_original = maybe_len(test_dataset)

    test_dataset = FilterTransformation(
        lambda x: x["target"].shape[-1] > predictor.prediction_length
    ).apply(test_dataset)

    len_filtered = len(test_dataset)

    if len_original is not None and len_original > len_filtered:
        logger.warning(
            "Not all time series in the test-channel have "
            "enough data to be used for evaluation. Proceeding with "
            f"{len_filtered}/{len_original} "
            f"(~{int(len_filtered / len_original * 100)}%) items."
        )

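    # Produce forecasts along with the corresponding ground-truth series.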
    forecast_it, ts_it = backtest.make_evaluation_predictions(
        dataset=test_dataset, predictor=predictor, num_samples=100
    )

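    # Evaluate against a custom set of quantiles if `test_quantiles` is
    # given in the hyperparameters, otherwise use the Evaluator defaults.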
    test_quantiles = (
        [
            Quantile.parse(quantile).name
            for quantile in hyperparameters["test_quantiles"]
        ]
        if "test_quantiles" in hyperparameters
        else None
    )

    if test_quantiles is not None:
        logger.info(f"Using quantiles `{test_quantiles}` for evaluation.")
        evaluator = Evaluator(quantiles=test_quantiles)
    else:
        evaluator = Evaluator()

    agg_metrics, item_metrics = evaluator(
        ts_iterator=ts_it,
        fcst_iterator=forecast_it,
        num_series=len(test_dataset),
    )

    # we only log aggregate metrics for now, as item metrics may be very large
    for name, score in agg_metrics.items():
        logger.info(f"#test_score ({env.current_host}, {name}): {score}")

    # store the computed metrics next to the serialized model
    with open(env.path.model / "agg_metrics.json", "w") as agg_metric_file:
        json.dump(agg_metrics, agg_metric_file)
    with open(env.path.model / "item_metrics.csv", "w") as item_metrics_file:
        item_metrics.to_csv(item_metrics_file, index=False)