gluonts.torch.model.predictor module

class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: torch.nn.modules.module.Module, batch_size: int, prediction_length: int, freq: str, input_transform: gluonts.transform._base.Transformation, forecast_generator: gluonts.model.forecast_generator.ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], numpy.ndarray], numpy.ndarray]] = None, lead_time: int = 0, device=device(type='cpu'))

Bases: gluonts.model.predictor.Predictor
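The `input_names` argument names the fields of each transformed batch that are handed to `prediction_net`. The pure-Python sketch below illustrates that idea under stated assumptions: `select_inputs`, `run_net`, and the toy network are hypothetical, not GluonTS API, and the sketch passes the selected fields positionally in `input_names` order. Whether a given GluonTS version passes them positionally or by keyword is an assumption to check against its source.

```python
from typing import Any, Callable, Dict, List


def select_inputs(batch: Dict[str, Any], input_names: List[str]) -> List[Any]:
    """Hypothetical helper: pick the network inputs from a batch, in input_names order."""
    # Fields not listed in input_names (e.g. "item_id") are simply ignored.
    return [batch[name] for name in input_names]


def toy_prediction_net(past_target: List[float], scale: float) -> List[float]:
    """Hypothetical stand-in for prediction_net: repeats the last value, scaled."""
    return [past_target[-1] * scale] * 3


def run_net(net: Callable[..., Any], batch: Dict[str, Any], input_names: List[str]) -> Any:
    # Assumption: inputs are passed positionally, so their order must match the net's signature.
    return net(*select_inputs(batch, input_names))


batch = {"past_target": [1.0, 2.0, 4.0], "scale": 0.5, "item_id": "A"}
print(run_net(toy_prediction_net, batch, ["past_target", "scale"]))  # [2.0, 2.0, 2.0]
```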

classmethod deserialize(path: pathlib.Path, device: Optional[torch.device] = None) → gluonts.torch.model.predictor.PyTorchPredictor

Load a serialized predictor from the given path.

Parameters

  • path – Path to the directory containing the serialized predictor files.

  • device – Optional device on which to load the predictor. If none is passed, the GPU is used if available, and the CPU otherwise.
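The device fallback for deserialize can be sketched in plain Python. `resolve_device` is a hypothetical helper, and GPU availability is passed in as a boolean so the sketch runs without PyTorch; in real code that flag would typically come from `torch.cuda.is_available()`.

```python
from typing import Optional


def resolve_device(requested: Optional[str], cuda_available: bool) -> str:
    """Hypothetical mirror of the fallback: an explicit device wins, else GPU if available, else CPU."""
    if requested is not None:
        return requested
    # With real PyTorch, cuda_available would come from torch.cuda.is_available().
    return "cuda" if cuda_available else "cpu"


print(resolve_device(None, cuda_available=True))   # cuda
print(resolve_device(None, cuda_available=False))  # cpu
print(resolve_device("cpu", cuda_available=True))  # cpu
```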

predict(dataset: gluonts.dataset.common.Dataset, num_samples: Optional[int] = None) → Iterator[gluonts.model.forecast.Forecast]

Compute forecasts for the time series in the provided dataset.

Parameters

  • dataset – The dataset containing the time series to predict.

  • num_samples – Optional number of samples, passed on to the predictor's forecast generator.

Returns

Iterator over the forecasts, in the same order as the dataset iterable was provided.

Return type

Iterator[gluonts.model.forecast.Forecast]

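Because predict returns an iterator whose forecasts follow the dataset order, it can be pictured as grouping entries into batches of the constructor's `batch_size`, running the network once per batch, and emitting one forecast per entry. The pure-Python sketch below assumes that model; `batched`, `predict_in_batches`, and the mean "network" are hypothetical stand-ins for the input transformation, data loading, and `prediction_net`, not GluonTS code.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
F = TypeVar("F")


def batched(entries: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group entries into lists of at most batch_size, keeping their order."""
    batch: List[T] = []
    for entry in entries:
        batch.append(entry)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # the last batch may be smaller than batch_size
        yield batch


def predict_in_batches(
    dataset: Iterable[T], batch_size: int, net: Callable[[List[T]], List[F]]
) -> Iterator[F]:
    """Hypothetical sketch: run net on each batch and yield one forecast per entry, in order."""
    for batch in batched(dataset, batch_size):
        yield from net(batch)


def mean_forecast(batch: List[List[float]]) -> List[float]:
    # Toy "network": forecast each series by its mean value.
    return [sum(series) / len(series) for series in batch]


dataset = [[1.0, 3.0], [2.0, 2.0, 8.0], [5.0]]
print(list(predict_in_batches(dataset, batch_size=2, net=mean_forecast)))  # [2.0, 4.0, 5.0]
```

Since the result is a generator, forecasts are only computed as the caller iterates over them.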

serialize(path: pathlib.Path) → None

Serialize the predictor to the directory at path, so that it can be restored later with deserialize().
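The serialize/deserialize pair is meant to round-trip a predictor through a directory. The sketch below shows only that contract: `ToyPredictor` and its single `parameters.json` file are hypothetical, and the actual set of files GluonTS writes (network, transformation, and so on) is not reproduced here.

```python
import json
import tempfile
from pathlib import Path


class ToyPredictor:
    """Hypothetical stand-in; it does not reproduce the GluonTS file layout."""

    def __init__(self, prediction_length: int, freq: str) -> None:
        self.prediction_length = prediction_length
        self.freq = freq

    def serialize(self, path: Path) -> None:
        # Write the constructor arguments into a file inside the target directory.
        path.mkdir(parents=True, exist_ok=True)
        params = {"prediction_length": self.prediction_length, "freq": self.freq}
        (path / "parameters.json").write_text(json.dumps(params))

    @classmethod
    def deserialize(cls, path: Path) -> "ToyPredictor":
        # Rebuild an equivalent predictor from the stored arguments.
        params = json.loads((path / "parameters.json").read_text())
        return cls(**params)


with tempfile.TemporaryDirectory() as tmp:
    ToyPredictor(prediction_length=24, freq="H").serialize(Path(tmp) / "model")
    restored = ToyPredictor.deserialize(Path(tmp) / "model")
    print(restored.prediction_length, restored.freq)  # 24 H
```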
to(device) → gluonts.torch.model.predictor.PyTorchPredictor

Move the prediction network to the given device and return the predictor itself.
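Since `to` returns the predictor, calls can be chained, for example directly after construction or deserialization. The sketch below mimics that contract with hypothetical `ToyNet` and `MovablePredictor` classes and string device names so that it runs without PyTorch; it assumes the predictor keeps its network in a `prediction_net` attribute, matching the constructor argument name.

```python
class ToyNet:
    """Hypothetical stand-in for a torch.nn.Module that only records its device."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "ToyNet":
        self.device = device
        return self


class MovablePredictor:
    """Hypothetical predictor exposing the same to() contract."""

    def __init__(self, prediction_net: ToyNet) -> None:
        self.prediction_net = prediction_net
        self.device = prediction_net.device

    def to(self, device: str) -> "MovablePredictor":
        # Move the network, remember the device, and return self so calls can be chained.
        self.prediction_net = self.prediction_net.to(device)
        self.device = device
        return self


predictor = MovablePredictor(ToyNet()).to("cuda")
print(predictor.device, predictor.prediction_net.device)  # cuda cuda
```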