Source code for gluonts.torch.model.predictor

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pathlib import Path
from typing import Iterator, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.loader import InferenceDataLoader
from gluonts.model.forecast import Forecast
from gluonts.model.forecast_generator import (
    ForecastGenerator,
    SampleForecastGenerator,
    to_numpy,
)
from gluonts.model.predictor import OutputTransform, RepresentablePredictor
from gluonts.torch.batchify import batchify
from gluonts.torch.util import resolve_device
from gluonts.transform import SelectFields, Transformation


@to_numpy.register(torch.Tensor)
def _(x: torch.Tensor) -> np.ndarray:
    """
    ``to_numpy`` implementation registered for ``torch.Tensor`` inputs.

    The tensor is copied to the CPU, detached, and then exposed as a NumPy
    array.
    """
    host_tensor = x.cpu()
    return host_tensor.detach().numpy()


class PyTorchPredictor(RepresentablePredictor):
    """
    Predictor that produces forecasts by running a PyTorch network.

    Parameters
    ----------
    input_names
        Names of the data fields that are selected from each transformed
        entry and passed on to the forecast generator as ``input_names``.
    prediction_net
        The network used for inference; it is moved to ``device``.
    batch_size
        Batch size given to the inference data loader.
    prediction_length
        Forwarded to the parent class constructor.
    input_transform
        Transformation applied to the dataset before field selection.
    forecast_generator
        Callable that turns batches and the network into forecasts.
    output_transform
        Optional transform forwarded to the forecast generator.
    lead_time
        Forwarded to the parent class constructor.
    device
        Device specification, resolved through ``resolve_device``.
    """

    @validated()
    def __init__(
        self,
        input_names: List[str],
        prediction_net: nn.Module,
        batch_size: int,
        prediction_length: int,
        input_transform: Transformation,
        forecast_generator: ForecastGenerator = SampleForecastGenerator(),
        output_transform: Optional[OutputTransform] = None,
        lead_time: int = 0,
        device: Union[str, torch.device] = "auto",
    ) -> None:
        super().__init__(prediction_length, lead_time=lead_time)
        self.input_names = input_names
        self.batch_size = batch_size
        self.input_transform = input_transform
        self.forecast_generator = forecast_generator
        self.output_transform = output_transform
        self.device = resolve_device(device)
        self.prediction_net = prediction_net.to(self.device)
        # Fields kept in every batch in addition to ``input_names``; they
        # are selected with ``allow_missing=True`` in ``predict``.
        self.required_fields = ["forecast_start", "item_id", "info"]

    def to(self, device: Union[str, torch.device]) -> "PyTorchPredictor":
        """
        Move the predictor and its network to ``device``.

        Returns ``self`` so that calls can be chained.
        """
        self.device = resolve_device(device)
        self.prediction_net = self.prediction_net.to(self.device)
        return self

    @property
    def network(self) -> nn.Module:
        """The underlying prediction network."""
        return self.prediction_net

    def predict(  # type: ignore
        self, dataset: Dataset, num_samples: Optional[int] = None
    ) -> Iterator[Forecast]:
        """
        Lazily yield forecasts for the entries of ``dataset``.

        This is a generator: the data loader is built and the network is
        switched to evaluation mode only once iteration starts.

        Parameters
        ----------
        dataset
            Dataset to generate forecasts for.
        num_samples
            Forwarded unchanged to the forecast generator.
        """
        selected_fields = SelectFields(
            self.input_names + self.required_fields, allow_missing=True
        )
        inference_data_loader = InferenceDataLoader(
            dataset,
            transform=self.input_transform + selected_fields,
            batch_size=self.batch_size,
            # Evaluated per batch, so it uses the device current at that time.
            stack_fn=lambda data: batchify(data, self.device),
        )

        self.prediction_net.eval()

        with torch.no_grad():
            yield from self.forecast_generator(
                inference_data_loader=inference_data_loader,
                prediction_net=self.prediction_net,
                input_names=self.input_names,
                output_transform=self.output_transform,
                num_samples=num_samples,
            )

    def serialize(self, path: Path) -> None:
        """
        Serialize the predictor into the directory ``path``.

        On top of what the parent class writes, the network's state dict is
        stored in ``prediction-net-state.pt``.
        """
        super().serialize(path)
        torch.save(
            self.prediction_net.state_dict(), path / "prediction-net-state.pt"
        )

    @classmethod
    def deserialize(  # type: ignore
        cls, path: Path, device: Optional[Union[str, torch.device]] = None
    ) -> "PyTorchPredictor":
        """
        Restore a predictor previously written by ``serialize``.

        Parameters
        ----------
        path
            Directory containing the serialized predictor.
        device
            Optional device to move the predictor to before the weights
            are loaded. When ``None``, the device the predictor resolved
            during construction is kept.
        """
        predictor = super().deserialize(path)
        assert isinstance(predictor, cls)

        if device is not None:
            device = resolve_device(device)
            predictor.to(device)

        # Always map the checkpoint onto the predictor's current device.
        # With ``map_location=None`` torch tries to restore tensors onto the
        # device they were saved from, which fails e.g. for a checkpoint
        # written on a GPU and loaded on a CPU-only host.
        state_dict = torch.load(
            path / "prediction-net-state.pt",
            map_location=predictor.device,
        )
        predictor.prediction_net.load_state_dict(state_dict)
        return predictor