# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import functools
import itertools
import json
import logging
import multiprocessing as mp
import sys
import traceback
from pathlib import Path
from pydoc import locate
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Callable, Iterator, Optional, Type
import numpy as np
import gluonts
from gluonts.core import fqname_for
from gluonts.core.component import equals, from_hyperparameters, validated
from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.exceptions import GluonTSException
from gluonts.model.forecast import Forecast
if TYPE_CHECKING: # avoid circular import
from gluonts.model.estimator import Estimator # noqa
OutputTransform = Callable[[DataEntry, np.ndarray], np.ndarray]


class Predictor:
    """
    Abstract class representing predictor objects.

    Parameters
    ----------
    prediction_length
        Prediction horizon.
    lead_time
        Prediction lead time, i.e. the gap (in time steps) between the end
        of the input and the start of the forecast window (default: 0).
    """

__version__: str = gluonts.__version__

    def __init__(self, prediction_length: int, lead_time: int = 0) -> None:
assert (
prediction_length > 0
), "The value of `prediction_length` should be > 0"
assert lead_time >= 0, "The value of `lead_time` should be >= 0"
self.prediction_length = prediction_length
self.lead_time = lead_time

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
        """
        Compute forecasts for the time series in the provided dataset.
        This method is not implemented in this abstract class; please
        use one of the subclasses.

        Parameters
        ----------
        dataset
            The dataset containing the time series to predict.

        Returns
        -------
        Iterator[Forecast]
            Iterator over the forecasts, in the same order as the dataset
            iterable was provided.
        """
raise NotImplementedError

    def serialize(self, path: Path) -> None:
# serialize Predictor type
with (path / "type.txt").open("w") as fp:
fp.write(fqname_for(self.__class__))
with (path / "version.json").open("w") as fp:
json.dump(
{"model": self.__version__, "gluonts": gluonts.__version__}, fp
)

    @classmethod
def deserialize(cls, path: Path, **kwargs) -> "Predictor":
"""
Load a serialized predictor from the given path.
Parameters
----------
path
Path to the serialized files predictor.
**kwargs
Optional context/device parameter to be used with the predictor.
If nothing is passed will use the GPU if available and CPU
otherwise.
"""
# deserialize Predictor type
with (path / "type.txt").open("r") as fp:
tpe_str = fp.readline()
tpe = locate(tpe_str)
assert tpe is not None, f"Cannot locate {tpe_str}."
# ensure that predictor_cls is a subtype of Predictor
if not issubclass(tpe, Predictor):
raise OSError(
f"Class {fqname_for(tpe)} is not "
f"a subclass of {fqname_for(Predictor)}"
)
# call deserialize() for the concrete Predictor type
return tpe.deserialize(path, **kwargs)

    @classmethod
def from_hyperparameters(cls, **hyperparameters):
return from_hyperparameters(cls, **hyperparameters)

    @classmethod
def derive_auto_fields(cls, train_iter):
return {}
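
# A minimal usage sketch (not part of the library; `my_predictor` stands in
# for any concrete, already-constructed Predictor instance):
#
#     from pathlib import Path
#     from tempfile import TemporaryDirectory
#
#     with TemporaryDirectory() as tempdir:
#         my_predictor.serialize(Path(tempdir))
#         # deserialize() reads "type.txt" to locate the concrete class and
#         # then dispatches to that class's own deserialize().
#         restored = Predictor.deserialize(Path(tempdir))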


class RepresentablePredictor(Predictor):
    """
    An abstract predictor that can be subclassed by models that are not
    based on Gluon. Subclasses should have ``@validated()`` constructors.
    (De)serialization and value equality are all implemented on top of the
    ``@validated()`` logic.

    Parameters
    ----------
    prediction_length
        Prediction horizon.
    """

@validated()
def __init__(self, prediction_length: int, lead_time: int = 0) -> None:
super().__init__(
lead_time=lead_time, prediction_length=prediction_length
)

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
for item in dataset:
yield self.predict_item(item)

    def predict_item(self, item: DataEntry) -> Forecast:
raise NotImplementedError

    def __eq__(self, that):
"""
Two RepresentablePredictor instances are considered equal if they have
the same constructor arguments.
"""
return equals(self, that)

    def serialize(self, path: Path) -> None:
# call Predictor.serialize() in order to serialize the class name
super().serialize(path)
with (path / "predictor.json").open("w") as fp:
print(dump_json(self), file=fp)

    @classmethod
def deserialize(cls, path: Path) -> "RepresentablePredictor":
with (path / "predictor.json").open("r") as fp:
return load_json(fp.read())
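
# A minimal sketch of the serde contract (hypothetical subclass, for
# illustration only): thanks to the @validated() constructor, JSON
# round-trips and value equality come for free.
#
#     class ConstantScalePredictor(RepresentablePredictor):
#         @validated()
#         def __init__(self, prediction_length: int, scale: float = 1.0) -> None:
#             super().__init__(prediction_length=prediction_length)
#             self.scale = scale
#
#     p = ConstantScalePredictor(prediction_length=7)
#     assert p == load_json(dump_json(p))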


class WorkerError:
    """
    Holds a formatted traceback of an exception raised in a worker process.
    """

    def __init__(self, msg):
        self.msg = msg


def _worker_loop(
predictor_path: Path,
input_queue: mp.Queue,
output_queue: mp.Queue,
worker_id,
**kwargs,
):
"""
Worker loop for multiprocessing Predictor.
Loads the predictor serialized in predictor_path reads inputs from
input_queue and writes forecasts to output_queue
"""
predictor = Predictor.deserialize(predictor_path)
while True:
idx, data_chunk = input_queue.get()
if idx is None:
output_queue.put((None, None, None))
break
try:
result = list(predictor.predict(data_chunk, **kwargs))
except Exception:
we = WorkerError(
"".join(traceback.format_exception(*sys.exc_info()))
)
output_queue.put((we, None, None))
break
output_queue.put((idx, worker_id, result))


class ParallelizedPredictor(Predictor):
    """
    Runs multiple instances (workers) of a predictor in parallel.

    Exceptions are propagated from the workers.

    Note: there is currently an issue with tqdm that will cause things to
    hang if the ParallelizedPredictor is used with tqdm and an exception
    occurs during prediction: https://github.com/tqdm/tqdm/issues/548

    Parameters
    ----------
    base_predictor
        A representable predictor that will be used.
    num_workers
        Number of workers (processes) to use. If set to None, one worker
        per CPU will be used.
    chunk_size
        Number of items to pass per call (default: 1).
    """

def __init__(
self,
base_predictor: Predictor,
num_workers: Optional[int] = None,
chunk_size=1,
) -> None:
super().__init__(
lead_time=base_predictor.lead_time,
prediction_length=base_predictor.prediction_length,
)
self._base_predictor = base_predictor
self._num_workers = (
num_workers if num_workers is not None else mp.cpu_count()
)
self._chunk_size = chunk_size
self._num_running_workers = 0
self._input_queues = []
self._output_queue = None

    def _grouper(self, iterable, n):
iterator = iter(iterable)
group = tuple(itertools.islice(iterator, n))
while group:
yield group
group = tuple(itertools.islice(iterator, n))
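
    # Example: self._grouper(range(5), 2) yields (0, 1), (2, 3), (4,); the
    # last group may be shorter than n.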

    def terminate(self):
for q in self._input_queues:
q.put((None, None))
for w in self._workers:
w.terminate()
for i, w in enumerate(self._workers):
w.join()

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
with TemporaryDirectory() as tempdir:
predictor_path = Path(tempdir)
self._base_predictor.serialize(predictor_path)
# TODO: Consider using shared memory for the data transfer.
self._input_queues = [mp.Queue() for _ in range(self._num_workers)]
self._output_queue = mp.Queue()
workers = []
for worker_id, in_q in enumerate(self._input_queues):
worker = mp.Process(
target=_worker_loop,
args=(predictor_path, in_q, self._output_queue, worker_id),
kwargs=kwargs,
)
worker.daemon = True
worker.start()
workers.append(worker)
self._num_running_workers += 1
self._workers = workers
chunked_data = self._grouper(dataset, self._chunk_size)
self._send_idx = 0
self._next_idx = 0
self._data_buffer = {}
worker_ids = list(range(self._num_workers))
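
            # Scheduling: each worker has exactly one chunk in flight. When a
            # result arrives, buffered results are re-emitted in submission
            # order and the now-idle worker is handed the next chunk.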
def receive():
idx, worker_id, result = self._output_queue.get()
if isinstance(idx, WorkerError):
self._num_running_workers -= 1
self.terminate()
raise Exception(idx.msg)
if idx is not None:
self._data_buffer[idx] = result
return idx, worker_id, result
def get_next_from_buffer():
while self._next_idx in self._data_buffer:
result_batch = self._data_buffer.pop(self._next_idx)
self._next_idx += 1
yield from result_batch
def send(worker_id, chunk):
q = self._input_queues[worker_id]
q.put((self._send_idx, chunk))
self._send_idx += 1
try:
# prime the queues
for wid in worker_ids:
chunk = next(chunked_data)
send(wid, chunk)
while True:
idx, wid, result = receive()
yield from get_next_from_buffer()
chunk = next(chunked_data)
send(wid, chunk)
except StopIteration:
# signal workers end of data
for q in self._input_queues:
q.put((None, None))
# collect any outstanding results
while self._num_running_workers > 0:
idx, worker_id, result = receive()
if idx is None:
self._num_running_workers -= 1
continue
yield from get_next_from_buffer()
assert len(self._data_buffer) == 0
assert self._send_idx == self._next_idx
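
# A minimal usage sketch (hypothetical `base_predictor` and `dataset`): fan
# prediction out over four worker processes, two dataset entries per chunk.
# Forecasts arrive in the original dataset order.
#
#     parallel = ParallelizedPredictor(base_predictor, num_workers=4, chunk_size=2)
#     forecasts = list(parallel.predict(dataset))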


class Localizer(Predictor):
    """
    A Predictor that uses an estimator to train a local model per time
    series, which is immediately used to predict.

    Parameters
    ----------
    estimator
        The estimator object to train on each dataset entry at prediction
        time.
    """

def __init__(self, estimator: "Estimator"):
super().__init__(
lead_time=estimator.lead_time,
prediction_length=estimator.prediction_length,
)
self.estimator = estimator

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
logger = logging.getLogger(__name__)
for i, ts in enumerate(dataset, start=1):
logger.info(f"training for time series {i} / {len(dataset)}")
trained_pred = self.estimator.train([ts])
logger.info(f"predicting for time series {i} / {len(dataset)}")
yield from trained_pred.predict([ts], **kwargs)
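
# A minimal usage sketch (hypothetical `my_estimator` and `dataset`): a fresh
# model is trained on each series right before predicting it.
#
#     localized = Localizer(my_estimator)
#     forecasts = list(localized.predict(dataset))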


class FallbackPredictor(Predictor):
    @classmethod
    def from_predictor(
        cls, base: RepresentablePredictor, **overrides
    ) -> Predictor:
        # Create a predictor based on an existing predictor; this lets us
        # create e.g. a fallback MeanPredictor on the fly.
        return cls.from_hyperparameters(
            **getattr(base, "__init_args__"), **overrides
        )


def fallback(fallback_cls: Type[FallbackPredictor]):
def decorator(predict_item):
@functools.wraps(predict_item)
def fallback_predict(self, item: DataEntry) -> Forecast:
try:
return predict_item(self, item)
except GluonTSException:
raise
except Exception:
logging.warning(
f"Base predictor failed with: {traceback.format_exc()}"
)
fallback_predictor = fallback_cls.from_predictor(self)
return fallback_predictor.predict_item(item)
return fallback_predict
return decorator
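
# A usage sketch (all names hypothetical): decorate predict_item so that any
# non-GluonTS failure in the base model is retried with a simpler fallback.
#
#     class MeanFallback(FallbackPredictor, RepresentablePredictor):
#         def predict_item(self, item: DataEntry) -> Forecast:
#             ...  # e.g. forecast the historical mean of item["target"]
#
#     class MyPredictor(RepresentablePredictor):
#         @fallback(MeanFallback)
#         def predict_item(self, item: DataEntry) -> Forecast:
#             ...  # may raise; the decorator catches and falls back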