gluonts.torch.model.tft.lightning_module module
- class gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule(model_kwargs: dict, lr: float = 0.001, patience: int = 10, weight_decay: float = 0.0)
Bases: LightningModule
A pl.LightningModule class that can be used to train a TemporalFusionTransformerModel with PyTorch Lightning.

This is a thin layer around a (wrapped) TemporalFusionTransformerModel object that exposes the methods to evaluate training and validation loss.

- Parameters:
model_kwargs – Keyword arguments to construct the TemporalFusionTransformerModel to be trained.
lr – Learning rate.
weight_decay – Weight decay regularization parameter.
patience – Patience parameter for the learning rate scheduler.
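The parameter list only says that `patience` belongs to "the learning rate scheduler". Assuming that scheduler is plateau-based, in the style of PyTorch's `ReduceLROnPlateau` (an assumption this page does not confirm), the sketch below shows how `lr` and `patience` would interact: the rate starts at `lr` and is cut only after the monitored loss has failed to improve for more than `patience` consecutive epochs. The class `PlateauLR` and its `factor=0.5` default are hypothetical, for illustration only; this is not the GluonTS implementation.

```python
class PlateauLR:
    """Hypothetical plateau scheduler, for illustration only.

    Assumed to behave like the "min" mode of PyTorch's ReduceLROnPlateau
    with threshold=0 and cooldown=0 (any strict decrease counts as improvement).
    """

    def __init__(self, lr: float = 0.001, patience: int = 10, factor: float = 0.5):
        self.lr = lr              # current learning rate, starts at `lr`
        self.patience = patience  # non-improving epochs tolerated before decaying
        self.factor = factor      # hypothetical multiplicative decay factor
        self.best = float("inf")  # best monitored loss seen so far
        self.num_bad_epochs = 0   # consecutive epochs without improvement

    def step(self, loss: float) -> float:
        """Record one epoch's monitored loss; return the learning rate to use next."""
        if loss < self.best:
            self.best = loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Decay only once the plateau has lasted *more* than `patience` epochs.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


if __name__ == "__main__":
    scheduler = PlateauLR(lr=0.001, patience=2)
    for epoch, loss in enumerate([1.0, 0.9, 0.95, 0.95, 0.95, 0.95]):
        print(epoch, loss, scheduler.step(loss))
```

With `patience=2`, the rate stays at 0.001 through epoch 3 and is halved at epoch 4, the third consecutive epoch without improvement. A larger `patience` therefore tolerates longer plateaus before the learning rate is reduced.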