gluonts.torch.model.estimator module
- class gluonts.torch.model.estimator.PyTorchLightningEstimator(trainer_kwargs: Dict[str, Any], lead_time: int = 0)
Bases: Estimator
An Estimator type with utilities for creating PyTorch-Lightning-based models.
To extend this class, one needs to implement five methods: create_transformation, create_lightning_module, create_predictor, create_training_data_loader, and create_validation_data_loader.
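The following is a minimal, standard-library sketch of that extension contract. `EstimatorContract`, `Incomplete`, `Complete`, and `can_instantiate` are hypothetical names. The sketch uses `abc.ABC` to make the five hooks explicit, but the real `PyTorchLightningEstimator` may signal a missing hook differently, for example by raising `NotImplementedError` only when the hook is called.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterable


class EstimatorContract(ABC):
    """Hypothetical stand-in listing the five hooks a subclass must implement."""

    @abstractmethod
    def create_transformation(self) -> Any: ...

    @abstractmethod
    def create_lightning_module(self) -> Any: ...

    @abstractmethod
    def create_predictor(self, transformation: Any, module: Any) -> Any: ...

    @abstractmethod
    def create_training_data_loader(self, data: Any, module: Any, **kwargs: Any) -> Iterable: ...

    @abstractmethod
    def create_validation_data_loader(self, data: Any, module: Any, **kwargs: Any) -> Iterable: ...


class Incomplete(EstimatorContract):
    # Implements only one hook, so it cannot be instantiated.
    def create_transformation(self):
        return lambda entry: entry


class Complete(Incomplete):
    # Fills in the remaining four hooks with trivial placeholders.
    def create_lightning_module(self):
        return object()

    def create_predictor(self, transformation, module):
        return (transformation, module)

    def create_training_data_loader(self, data, module, **kwargs):
        return iter(data)

    def create_validation_data_loader(self, data, module, **kwargs):
        return iter(data)


def can_instantiate(cls: type) -> bool:
    """Return True if `cls` has no abstract hooks left."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(Incomplete))  # False
print(can_instantiate(Complete))    # True
```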
- create_lightning_module() → LightningModule
Create and return the network used for training (i.e., computing the loss).
- Returns:
The network that computes the loss given input data.
- Return type:
pl.LightningModule
- create_predictor(transformation: Transformation, module) → PyTorchPredictor
Create and return a predictor object.
- Parameters:
transformation – Transformation to be applied to data before it goes into the model.
module – A trained pl.LightningModule object.
- Returns:
A predictor wrapping an nn.Module used for inference.
- Return type:
PyTorchPredictor
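One way to picture this contract is that the predictor bundles the inference-time transformation with the trained module and applies them in sequence to each entry. The sketch below uses plain Python callables in place of a `Transformation` and a trained `pl.LightningModule`. `ToyPredictor`, `keep_last_three`, `mean_model`, and the use of a `"past_target"` field are hypothetical and do not reproduce `PyTorchPredictor`'s actual interface.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List

Entry = Dict[str, Any]


class ToyPredictor:
    """Hypothetical stand-in: applies the transformation, then the trained model, per entry."""

    def __init__(
        self,
        transformation: Callable[[Entry], Entry],
        module: Callable[[List[float]], float],
    ) -> None:
        self.transformation = transformation
        self.module = module

    def predict(self, dataset: Iterable[Entry]) -> Iterator[float]:
        for entry in dataset:
            # The same preprocessing used in training is applied at inference time.
            transformed = self.transformation(entry)
            yield self.module(transformed["past_target"])


def create_predictor(transformation, module) -> ToyPredictor:
    # Mirrors the documented contract: (transformation, trained module) -> predictor.
    return ToyPredictor(transformation, module)


def keep_last_three(entry: Entry) -> Entry:
    return {**entry, "past_target": entry["target"][-3:]}


def mean_model(window: List[float]) -> float:
    # Stands in for a trained network: predicts the mean of the input window.
    return sum(window) / len(window)


predictor = create_predictor(keep_last_three, mean_model)
print(list(predictor.predict([{"target": [1.0, 2.0, 3.0, 4.0, 5.0]}])))  # [4.0]
```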
- create_training_data_loader(data: Dataset, module, **kwargs) → Iterable
Create a data loader for training purposes.
- Parameters:
data – Dataset from which to create the data loader.
module – The pl.LightningModule object that will receive the batches from the data loader.
- Returns:
The data loader, i.e. an iterable over batches of data.
- Return type:
Iterable
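To make "an iterable over batches" concrete, here is a standard-library sketch of a typical training loader. It cycles over the dataset so that every epoch can contain a fixed number of batches, and it optionally applies a buffer-based approximate shuffle. The names `training_batches` and `shuffle_with_buffer` are hypothetical, and it is an assumption that the `shuffle_buffer_length` argument of `train` reaches this method through `**kwargs`. A real loader would also stack each batch's fields into tensors rather than return a list of dicts.

```python
import itertools
import random
from typing import Any, Dict, Iterable, Iterator, List, Optional

Entry = Dict[str, Any]


def shuffle_with_buffer(
    stream: Iterable[Entry], buffer_length: int, rng: random.Random
) -> Iterator[Entry]:
    """Approximate shuffle: hold up to `buffer_length` items and emit a random one per new item."""
    buffer: List[Entry] = []
    for item in stream:
        buffer.append(item)
        if len(buffer) >= buffer_length:
            # Swap a random element to the end so pop() is O(1).
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Flush whatever is left (only reached for finite streams).
    rng.shuffle(buffer)
    yield from buffer


def training_batches(
    data: List[Entry],
    batch_size: int,
    num_batches: int,
    shuffle_buffer_length: Optional[int] = None,
    seed: int = 0,
) -> Iterator[List[Entry]]:
    """Cycle over `data` indefinitely and yield `num_batches` batches of `batch_size` entries."""
    stream: Iterator[Entry] = itertools.cycle(data)
    if shuffle_buffer_length is not None:
        stream = shuffle_with_buffer(stream, shuffle_buffer_length, random.Random(seed))
    for _ in range(num_batches):
        yield list(itertools.islice(stream, batch_size))


entries = [{"id": i} for i in range(5)]
print([[e["id"] for e in b] for b in training_batches(entries, batch_size=2, num_batches=4)])
# [[0, 1], [2, 3], [4, 0], [1, 2]]
```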
- create_transformation() → Transformation
Create and return the transformation needed for training and inference.
- Returns:
The transformation that will be applied entry-wise to datasets, at training and inference time.
- Return type:
Transformation
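"Entry-wise" means that each dataset entry (here a plain `dict`) passes through the same sequence of steps independently, and the same transformation object is reused at training and inference time. The standard-library sketch below illustrates that idea. `ToyChain`, `add_length`, and `scale_target` are hypothetical and only imitate the role of a gluonts `Transformation`, not its actual classes or method names.

```python
from typing import Any, Callable, Dict, Iterable, Iterator

Entry = Dict[str, Any]
EntryFn = Callable[[Entry], Entry]


class ToyChain:
    """Hypothetical: applies a fixed sequence of entry-wise steps lazily to each dataset entry."""

    def __init__(self, *steps: EntryFn) -> None:
        self.steps = steps

    def __call__(self, dataset: Iterable[Entry]) -> Iterator[Entry]:
        for entry in dataset:
            for step in self.steps:
                entry = step(entry)
            yield entry


def add_length(entry: Entry) -> Entry:
    # Returns a new dict; the input entry is left unchanged.
    return {**entry, "length": len(entry["target"])}


def scale_target(entry: Entry) -> Entry:
    # Divide by the largest absolute value; fall back to 1.0 for an all-zero series.
    scale = max(abs(x) for x in entry["target"]) or 1.0
    return {**entry, "target": [x / scale for x in entry["target"]]}


transformation = ToyChain(add_length, scale_target)
print(list(transformation([{"target": [1.0, 2.0, 4.0]}])))
# [{'target': [0.25, 0.5, 1.0], 'length': 3}]
```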
- create_validation_data_loader(data: Dataset, module, **kwargs) → Iterable
Create a data loader for validation purposes.
- Parameters:
data – Dataset from which to create the data loader.
module – The pl.LightningModule object that will receive the batches from the data loader.
- Returns:
The data loader, i.e. an iterable over batches of data.
- Return type:
Iterable
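A validation loader commonly differs from a training loader in that it makes a single, unshuffled pass, so that validation losses stay comparable across epochs. The sketch below assumes that behavior, which is a common convention rather than something this page specifies. `validation_batches` is a hypothetical name, and batches are left as lists of dicts instead of stacked tensors.

```python
from typing import Any, Dict, Iterable, Iterator, List

Entry = Dict[str, Any]


def validation_batches(data: Iterable[Entry], batch_size: int) -> Iterator[List[Entry]]:
    """One deterministic pass over `data`; the last batch may be smaller than `batch_size`."""
    batch: List[Entry] = []
    for entry in data:
        batch.append(entry)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Keep the remainder so every validation entry is scored exactly once.
        yield batch


entries = [{"id": i} for i in range(5)]
print([[e["id"] for e in b] for b in validation_batches(entries, batch_size=2)])
# [[0, 1], [2, 3], [4]]
```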
- train(training_data: Dataset, validation_data: Optional[Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None, **kwargs) → PyTorchPredictor
Train the estimator on the given data.
- Parameters:
training_data – Dataset to train the model on.
validation_data – Dataset to validate the model on during training.
- Returns:
The predictor containing the trained model.
- Return type:
PyTorchPredictor
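The order in which `train` is expected to use the hooks can be sketched with a toy, standard-library estimator: build the transformation and module, wrap the transformed data in loaders, fit, optionally evaluate on validation data, and finally pass the fitted module to `create_predictor`. Everything here is hypothetical (`ToyMeanEstimator`, the dict-based "module", the closed-form fit), and it omits `shuffle_buffer_length`, `cache_data`, and `ckpt_path`. The real method delegates fitting to a PyTorch Lightning `Trainer` configured by `trainer_kwargs` and returns a `PyTorchPredictor`.

```python
from typing import Any, Callable, Dict, Iterable, List, Optional

Entry = Dict[str, Any]


class ToyMeanEstimator:
    """Hypothetical estimator whose train() calls the five documented hooks in order."""

    validation_loss: Optional[float] = None

    def create_transformation(self) -> Callable[[Iterable[Entry]], List[Entry]]:
        def transform(data: Iterable[Entry]) -> List[Entry]:
            return [{**e, "past_target": e["target"][-3:]} for e in data]
        return transform

    def create_lightning_module(self) -> Dict[str, float]:
        # The "network" is a single learnable level.
        return {"level": 0.0}

    def create_training_data_loader(self, data, module, **kwargs) -> List[List[float]]:
        return [e["past_target"] for e in data]

    def create_validation_data_loader(self, data, module, **kwargs) -> List[List[float]]:
        return [e["past_target"] for e in data]

    def create_predictor(self, transformation, module) -> Callable[[Iterable[Entry]], List[float]]:
        def predict(dataset: Iterable[Entry]) -> List[float]:
            return [module["level"] for _ in transformation(dataset)]
        return predict

    def train(self, training_data: List[Entry], validation_data: Optional[List[Entry]] = None):
        transformation = self.create_transformation()
        module = self.create_lightning_module()
        train_loader = self.create_training_data_loader(transformation(training_data), module)
        # "Fit": under squared error, the loss-minimizing constant is the mean.
        values = [x for window in train_loader for x in window]
        module["level"] = sum(values) / len(values)
        if validation_data is not None:
            val_loader = self.create_validation_data_loader(transformation(validation_data), module)
            errors = [(x - module["level"]) ** 2 for window in val_loader for x in window]
            self.validation_loss = sum(errors) / len(errors)
        return self.create_predictor(transformation, module)


estimator = ToyMeanEstimator()
predictor = estimator.train(
    [{"target": [1.0, 2.0, 3.0, 4.0, 5.0]}],
    validation_data=[{"target": [2.0, 4.0, 6.0]}],
)
print(predictor([{"target": [0.0]}]))  # [4.0]
print(estimator.validation_loss)        # 2.6666666666666665
```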
- train_from(predictor: Predictor, training_data: Dataset, validation_data: Optional[Dataset] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None) → PyTorchPredictor
- train_model(training_data: Dataset, validation_data: Optional[Dataset] = None, from_predictor: Optional[PyTorchPredictor] = None, shuffle_buffer_length: Optional[int] = None, cache_data: bool = False, ckpt_path: Optional[str] = None, **kwargs) → TrainOutput
- class gluonts.torch.model.estimator.TrainOutput(transformation, trained_net, trainer, predictor)
Bases: NamedTuple
- predictor: PyTorchPredictor
Alias for field number 3
- trained_net: Module
Alias for field number 1
- trainer: Trainer
Alias for field number 2
- transformation: Transformation
Alias for field number 0
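Because TrainOutput is a NamedTuple, each field can be read either by name or by position, which is what "Alias for field number N" means. The stand-in below (`TrainOutputSketch`, filled with placeholder strings) reproduces only the field names and the order given in the class signature, not the real field types.

```python
from typing import Any, NamedTuple


class TrainOutputSketch(NamedTuple):
    """Stand-in with TrainOutput's field names and order; types simplified to Any."""

    transformation: Any
    trained_net: Any
    trainer: Any
    predictor: Any


out = TrainOutputSketch("transformation", "net", "trainer", "predictor")
print(TrainOutputSketch._fields)
# ('transformation', 'trained_net', 'trainer', 'predictor')
print(out[3] is out.predictor)  # True: "field number 3" is the predictor

# Tuple unpacking follows the same field order.
transformation, trained_net, trainer, predictor = out
print(trained_net)  # net
```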