gluonts.torch.model.tft.lightning_module module
- class gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule(model: gluonts.torch.model.tft.module.TemporalFusionTransformerModel, lr: float = 0.001, patience: int = 10, weight_decay: float = 0.0)
Bases: pytorch_lightning.core.module.LightningModule
A pl.LightningModule class that can be used to train a TemporalFusionTransformerModel with PyTorch Lightning.
This is a thin layer around a (wrapped) TemporalFusionTransformerModel object that exposes the methods to evaluate training and validation loss.
- Parameters
model – TemporalFusionTransformerModel to be trained.
lr – Learning rate, default: 1e-3.
weight_decay – Weight decay regularization parameter, default: 0.0.
patience – Patience parameter for the learning rate scheduler, default: 10.
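The effect of a patience parameter on a learning rate scheduler can be illustrated with a simplified reduce-on-plateau schedule. The sketch below is plain Python and assumes that the learning rate is multiplied by a reduction factor once the monitored loss has failed to improve for more than `patience` consecutive epochs. The class `PlateauLRScheduler`, the `factor` value of 0.5, and the strict "loss < best" improvement test are hypothetical choices made for illustration; they are not the GluonTS or PyTorch Lightning API, and the scheduler this module actually configures may differ.

```python
import math
from dataclasses import dataclass, field


@dataclass
class PlateauLRScheduler:
    """Hypothetical, simplified reduce-on-plateau schedule (not a GluonTS API).

    The learning rate is multiplied by ``factor`` once the monitored loss has
    not improved for more than ``patience`` consecutive epochs.
    """

    lr: float = 1e-3    # initial learning rate (same value as the documented default)
    patience: int = 10  # non-improving epochs tolerated before a cut (documented default)
    factor: float = 0.5  # ASSUMED reduction factor; not documented above
    best: float = field(default=math.inf, init=False)
    num_bad_epochs: int = field(default=0, init=False)

    def step(self, loss: float) -> float:
        """Record one epoch's loss and return the learning rate for the next epoch."""
        if loss < self.best:
            # Strict improvement: remember the new best and reset the counter.
            self.best = loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Plateau detected: shrink the learning rate and start counting again.
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


# Usage: with patience=2, the third epoch in a row without improvement triggers a cut.
sched = PlateauLRScheduler(lr=1e-3, patience=2)
losses = [1.0, 0.8, 0.8, 0.8, 0.8]
lrs = [sched.step(loss) for loss in losses]
print(lrs)
```

Here the learning rate stays at 1e-3 for the first four epochs and is halved to 5e-4 after the third consecutive epoch without improvement. PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` adds further controls such as `threshold`, `cooldown` and `min_lr`, which this sketch omits for brevity.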