gluonts.torch.model.predictor module
- class gluonts.torch.model.predictor.PyTorchPredictor(input_names: List[str], prediction_net: Module, batch_size: int, prediction_length: int, input_transform: Transformation, forecast_generator: ForecastGenerator = gluonts.model.forecast_generator.SampleForecastGenerator(), output_transform: Optional[Callable[[Dict[str, Any], ndarray], ndarray]] = None, lead_time: int = 0, device: Union[str, device] = 'auto')
Bases: RepresentablePredictor
- classmethod deserialize(path: Path, device: Optional[Union[device, str]] = None) → PyTorchPredictor
Load a serialized predictor from the given path.
- Parameters:
path – Path to the directory containing the serialized predictor files.
device – Optional device on which to load the predictor. If none is given, the GPU is used if available and the CPU otherwise.
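The default device choice described for `deserialize` can be sketched in plain Python. This is a minimal illustration under stated assumptions: `resolve_device` is a hypothetical helper, not part of GluonTS, and it takes an explicit `cuda_available` flag so it runs without PyTorch, whereas real code would query `torch.cuda.is_available()`.

```python
from typing import Optional


def resolve_device(device: Optional[str], cuda_available: bool) -> str:
    """Hypothetical helper mirroring the documented default:
    use the GPU if available, otherwise the CPU."""
    if device is not None:
        # An explicitly requested device always takes precedence.
        return device
    # No device given: prefer the GPU when one is present.
    return "cuda" if cuda_available else "cpu"


print(resolve_device(None, cuda_available=True))   # cuda
print(resolve_device(None, cuda_available=False))  # cpu
print(resolve_device("cpu", cuda_available=True))  # cpu
```

Passing the availability check in as an argument keeps the decision logic testable independently of the hardware it runs on.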
- property network: Module
- predict(dataset: Dataset, num_samples: Optional[int] = None) → Iterator[Forecast]
Compute forecasts for the time series in the provided dataset.
- Parameters:
dataset – The dataset containing the time series to predict.
- Returns:
Iterator over the forecasts, in the same order as the dataset iterable was provided.
- Return type:
Iterator[Forecast]
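The ordering guarantee is what one gets when inputs are processed in consecutive batches of `batch_size` and each batch's outputs are yielded in input order. The sketch below illustrates that pattern with a stand-in model; `batched`, `predict_in_order`, and `fake_net` are hypothetical names, not GluonTS internals, and the real predictor additionally applies `input_transform` and delegates output construction to its `forecast_generator`.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group items into consecutive lists of at most batch_size elements."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Emit the final, possibly smaller, batch.
        yield batch


def predict_in_order(
    dataset: Iterable[T],
    net: Callable[[List[T]], List[R]],
    batch_size: int,
) -> Iterator[R]:
    """Run net on each batch and yield one output per input, in input order."""
    for batch in batched(dataset, batch_size):
        outputs = net(batch)
        # One output per input keeps the stream aligned with the dataset.
        assert len(outputs) == len(batch)
        yield from outputs


def fake_net(batch: List[List[float]]) -> List[float]:
    # Stand-in "model": forecast the last observed value of each series.
    return [series[-1] for series in batch]


series = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0], [7.0]]
print(list(predict_in_order(series, fake_net, batch_size=3)))
# [2.0, 3.0, 6.0, 7.0]
```

Because the results are produced lazily by a generator, a caller can zip them with the original dataset and rely on the pairing being positional.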
- to(device: Union[str, device]) → PyTorchPredictor