gluonts.torch.model.predictor module#

class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: Module, batch_size: int, prediction_length: int, input_transform: Transformation, forecast_generator: ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], ndarray], ndarray]] = None, lead_time: int = 0, device: Union[str, device] = 'auto')[source]#

Bases: RepresentablePredictor
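
A PyTorchPredictor bundles a trained torch Module (prediction_net) with the input transformation, batching, and forecast generation needed to turn raw time series into Forecast objects. It is usually obtained from an estimator's train() call rather than constructed by hand. A minimal sketch, using the DeepAREstimator and the "airpassengers" dataset from the built-in repository purely as illustration:

    from gluonts.dataset.repository import get_dataset
    from gluonts.torch import DeepAREstimator

    # Illustrative dataset and estimator; any PyTorch-based GluonTS estimator works.
    dataset = get_dataset("airpassengers")

    estimator = DeepAREstimator(
        freq="M",
        prediction_length=12,
        trainer_kwargs={"max_epochs": 5},
    )

    # train() returns a PyTorchPredictor wrapping the fitted network,
    # the input transformation, and the batching logic.
    predictor = estimator.train(dataset.train)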

classmethod deserialize(path: Path, device: Optional[Union[device, str]] = None) → PyTorchPredictor[source]#

Load a serialized predictor from the given path.

Parameters:
  • path – Path to the directory containing the serialized predictor files.

  • device – Optional device on which to run the predictor. If nothing is passed, the GPU is used if available, and the CPU otherwise.
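
A short sketch of restoring a predictor from disk, assuming a hypothetical directory my-predictor that was previously written by serialize():

    from pathlib import Path

    from gluonts.torch.model.predictor import PyTorchPredictor

    # "my-predictor" is a hypothetical directory created earlier by serialize().
    predictor = PyTorchPredictor.deserialize(Path("my-predictor"), device="cpu")

    # Omitting device picks the GPU if one is available, otherwise the CPU.
    predictor_auto = PyTorchPredictor.deserialize(Path("my-predictor"))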

property network: Module#
predict(dataset: Dataset, num_samples: Optional[int] = None) → Iterator[Forecast][source]#

Compute forecasts for the time series in the provided dataset.

Parameters:
  • dataset – The dataset containing the time series to predict.

  • num_samples – Number of sample paths to draw per time series (only used by sample-based forecast generators).

Returns:

Iterator over the forecasts, in the same order as provided by the dataset iterable.

Return type:

Iterator[Forecast]
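
A minimal usage sketch, where predictor is a trained PyTorchPredictor and test_data is any GluonTS Dataset (both names are placeholders):

    # num_samples is only relevant for sample-based forecasts
    # (the default SampleForecastGenerator produces these).
    forecasts = predictor.predict(test_data, num_samples=100)

    for forecast in forecasts:
        # Forecast objects expose summary statistics over the prediction horizon.
        print(forecast.start_date, forecast.mean, forecast.quantile(0.9))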

serialize(path: Path) → None[source]#
to(device: Union[str, device]) → PyTorchPredictor[source]#
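
A sketch of saving a predictor and moving it between devices, using a hypothetical model-output directory created up front; predictor is assumed to be a trained PyTorchPredictor:

    from pathlib import Path

    # Hypothetical target directory for the serialized predictor files.
    model_dir = Path("model-output")
    model_dir.mkdir(parents=True, exist_ok=True)

    predictor.serialize(model_dir)

    # to() returns a predictor whose underlying network lives on the requested device.
    cpu_predictor = predictor.to("cpu")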