Source code for gluonts.torch.model.predictor

# Copyright 2018, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
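# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing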
# permissions and limitations under the License.

from pathlib import Path
from typing import Iterator, List, Optional

import numpy as np
import torch
import torch.nn as nn

from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import Dataset
from gluonts.dataset.loader import InferenceDataLoader
from gluonts.model.forecast import Forecast
from gluonts.model.forecast_generator import (
    ForecastGenerator,
    SampleForecastGenerator,
)
from gluonts.model.predictor import OutputTransform, Predictor
from gluonts.torch.batchify import batchify
from gluonts.torch.component import equals
from gluonts.transform import Transformation

def _(prediction_net: nn.Module, inputs: torch.Tensor) -> np.ndarray:
    """
    Run ``prediction_net`` on ``inputs`` and return the result as a NumPy
    array.

    ``inputs`` is unpacked into positional arguments of the network call; the
    network output is copied to host memory before conversion.

    NOTE(review): the throwaway name ``_`` suggests this function is meant to
    be registered with a dispatcher by a decorator that is not visible here
    -- confirm against the upstream file.
    """
    network_output = prediction_net(*inputs)
    host_output = network_output.cpu()
    return host_output.numpy()

class PyTorchPredictor(Predictor):
    """
    Predictor that produces forecasts by running a PyTorch network.

    Parameters
    ----------
    input_names
        Names handed to ``forecast_generator`` as ``input_names``.
    prediction_net
        Network used for inference. It is moved to ``device`` on construction.
    batch_size
        Batch size used by the inference data loader.
    prediction_length
        Forwarded to the ``Predictor`` base class.
    freq
        Forwarded to the ``Predictor`` base class.
    input_transform
        Transformation passed to the inference data loader.
    forecast_generator
        Callable that turns the batches of the data loader into forecasts.
    output_transform
        Optional transform forwarded to ``forecast_generator``.
    lead_time
        Forwarded to the ``Predictor`` base class.
    device
        Device the network is moved to and that batches are created on.
    """

    def __init__(
        self,
        input_names: List[str],
        prediction_net: nn.Module,
        batch_size: int,
        prediction_length: int,
        freq: str,
        input_transform: Transformation,
        forecast_generator: ForecastGenerator = SampleForecastGenerator(),
        output_transform: Optional[OutputTransform] = None,
        lead_time: int = 0,
        device=torch.device("cpu"),
    ) -> None:
        super().__init__(prediction_length, freq=freq, lead_time=lead_time)
        self.input_names = input_names
        # Keep the network on the same device that `predict` batches onto.
        self.prediction_net = prediction_net.to(device)
        self.batch_size = batch_size
        self.input_transform = input_transform
        self.forecast_generator = forecast_generator
        self.output_transform = output_transform
        self.device = device

    def to(self, device) -> "PyTorchPredictor":
        """
        Move the network to ``device`` and remember it for batching.

        Returns
        -------
        PyTorchPredictor
            This predictor, so that calls can be chained.
        """
        self.prediction_net = self.prediction_net.to(device)
        self.device = device
        return self

    def predict(
        self, dataset: Dataset, num_samples: Optional[int] = None
    ) -> Iterator[Forecast]:
        """
        Lazily yield the forecasts produced for ``dataset``.

        This is a generator: the data loader is built, the network is put in
        evaluation mode and gradients are disabled only once iteration starts.

        Parameters
        ----------
        dataset
            Dataset to compute forecasts for.
        num_samples
            Forwarded unchanged to ``forecast_generator``.
        """
        inference_data_loader = InferenceDataLoader(
            dataset,
            transform=self.input_transform,
            batch_size=self.batch_size,
            stack_fn=lambda data: batchify(data, self.device),
        )

        self.prediction_net.eval()

        with torch.no_grad():
            yield from self.forecast_generator(
                inference_data_loader=inference_data_loader,
                prediction_net=self.prediction_net,
                input_names=self.input_names,
                freq=self.freq,
                output_transform=self.output_transform,
                num_samples=num_samples,
            )

    def __eq__(self, that):
        """
        Compare two predictors by exact type and by network state dict.
        """
        if type(self) is not type(that):
            return False

        # TODO: also consider equality of the pipelines
        # if not equals(self.input_transform, that.input_transform):
        #    return False

        return equals(
            self.prediction_net.state_dict(),
            that.prediction_net.state_dict(),
        )

    def serialize(self, path: Path) -> None:
        """
        Write the predictor to the directory ``path``.

        On top of what the base class writes, this creates
        ``prediction_net.json``, ``prediction_net_state``,
        ``input_transform.json`` and ``parameters.json``.
        """
        super().serialize(path)

        # serialize network
        with (path / "prediction_net.json").open("w") as fp:
            print(dump_json(self.prediction_net), file=fp)
        torch.save(
            self.prediction_net.state_dict(), path / "prediction_net_state"
        )

        # serialize transformation chain
        with (path / "input_transform.json").open("w") as fp:
            print(dump_json(self.input_transform), file=fp)

        # FIXME: also needs to serialize the output_transform

        # serialize all remaining constructor parameters
        with (path / "parameters.json").open("w") as fp:
            parameters = dict(
                batch_size=self.batch_size,
                prediction_length=self.prediction_length,
                freq=self.freq,
                lead_time=self.lead_time,
                forecast_generator=self.forecast_generator,
                input_names=self.input_names,
            )
            print(dump_json(parameters), file=fp)

    @classmethod
    def deserialize(
        cls, path: Path, device: Optional[torch.device] = None
    ) -> "PyTorchPredictor":
        """
        Rebuild a predictor from the files written by ``serialize``.

        Parameters
        ----------
        path
            Directory containing the serialized predictor.
        device
            Used as ``map_location`` when loading the network state and passed
            to the constructor as ``device``.
            NOTE(review): ``None`` is forwarded unchanged to the constructor --
            confirm that downstream code accepts it.
        """
        # deserialize constructor parameters
        with (path / "parameters.json").open("r") as fp:
            parameters = load_json(fp.read())

        # deserialize transformation chain
        with (path / "input_transform.json").open("r") as fp:
            transformation = load_json(fp.read())

        # deserialize network
        with (path / "prediction_net.json").open("r") as fp:
            prediction_net = load_json(fp.read())
        # torch.load unpickles the file: only use with trusted paths.
        prediction_net.load_state_dict(
            torch.load(path / "prediction_net_state", map_location=device)
        )

        parameters["device"] = device

        return PyTorchPredictor(
            input_transform=transformation,
            prediction_net=prediction_net,
            **parameters,
        )