gluonts.torch.model.tft.lightning_module module

class gluonts.torch.model.tft.lightning_module.TemporalFusionTransformerLightningModule(model: gluonts.torch.model.tft.module.TemporalFusionTransformerModel, lr: float = 0.001, patience: int = 10, weight_decay: float = 0.0)

Bases: pytorch_lightning.core.module.LightningModule

A pl.LightningModule class that can be used to train a TemporalFusionTransformerModel with PyTorch Lightning.

This is a thin wrapper around a TemporalFusionTransformerModel object that exposes methods to evaluate the training and validation loss.

Parameters
  • model – TemporalFusionTransformerModel to be trained.

  • lr – Learning rate, default: 1e-3.

  • weight_decay – Weight decay regularization parameter, default: 0.0.

  • patience – Patience parameter for the learning rate scheduler, default: 10.
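The parameter list only says that `patience` belongs to the learning rate scheduler. As a hedged illustration, assuming the scheduler follows a reduce-on-plateau rule (lower the learning rate once the monitored loss has failed to improve for more than `patience` epochs), the interaction between `lr` and `patience` can be sketched in plain Python. `PlateauLR`, `lr_schedule`, and the `factor` argument are hypothetical; the actual scheduler, its reduction factor, and the metric it monitors may differ.

```python
from typing import List


class PlateauLR:
    """Hypothetical, simplified reduce-on-plateau rule (no threshold, no cooldown).

    Not the GluonTS or PyTorch implementation; it only illustrates how
    `lr` and `patience` could interact.
    """

    def __init__(self, lr: float = 1e-3, patience: int = 10, factor: float = 0.5) -> None:
        self.lr = lr              # current learning rate
        self.patience = patience  # non-improving epochs tolerated before a reduction
        self.factor = factor      # hypothetical multiplicative reduction
        self.best = float("inf")  # best monitored loss seen so far
        self.bad_epochs = 0       # consecutive epochs without improvement

    def step(self, loss: float) -> float:
        """Record one epoch's monitored loss and return the learning rate to use next."""
        if loss < self.best:
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0  # restart the count after each reduction
        return self.lr


def lr_schedule(losses: List[float], lr: float, patience: int) -> List[float]:
    """Learning rate in effect after each epoch, given per-epoch losses."""
    scheduler = PlateauLR(lr=lr, patience=patience)
    return [scheduler.step(loss) for loss in losses]


if __name__ == "__main__":
    # The loss stalls at 0.9; with patience=2 the third non-improving epoch halves lr.
    print(lr_schedule([1.0, 0.9, 0.9, 0.9, 0.9, 0.8], lr=1e-3, patience=2))
```

Because the counter is reset after each reduction in this sketch, the rate can only drop again after another `patience + 1` consecutive non-improving epochs.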

configure_optimizers()

Returns the optimizer to use.

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

Parameters
  • *args – Positional arguments to pass to the forward method.

  • **kwargs – Keyword arguments to pass to the forward method.

Returns

Your model’s output

training_step(batch, batch_idx: int)

Execute a training step, computing the training loss on the given batch.

validation_step(batch, batch_idx: int)

Execute a validation step, computing the validation loss on the given batch.
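In PyTorch Lightning, the trainer calls `training_step(batch, batch_idx)` for each training batch and uses the returned loss to update the parameters, and calls `validation_step(batch, batch_idx)` on validation batches without updating anything. The dependency-free toy below mimics only that calling convention. `ToyModule` and `fit` are hypothetical stand-ins for the wrapped model and for Lightning's `Trainer`; they fit a single weight by hand-coded gradient descent and are not GluonTS code.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


class ToyModule:
    """Hypothetical stand-in for a LightningModule: fits y ≈ w * x."""

    def __init__(self, lr: float = 0.05) -> None:
        self.w = 0.0  # the single trainable weight
        self.lr = lr

    def loss(self, batch: Batch) -> float:
        # Mean squared error over the batch.
        return sum((self.w * x - y) ** 2 for x, y in batch) / len(batch)

    def grad(self, batch: Batch) -> float:
        # d(loss)/dw, computed by hand in place of autograd's backward().
        return sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)

    def training_step(self, batch: Batch, batch_idx: int) -> float:
        return self.loss(batch)

    def validation_step(self, batch: Batch, batch_idx: int) -> float:
        return self.loss(batch)


def fit(module: ToyModule, train: List[Batch], val: List[Batch],
        epochs: int) -> Tuple[List[float], List[float]]:
    """Hypothetical trainer loop; returns mean train and validation loss per epoch."""
    train_hist: List[float] = []
    val_hist: List[float] = []
    for _ in range(epochs):
        epoch_losses = []
        for batch_idx, batch in enumerate(train):
            # The returned loss is what a real trainer would backpropagate.
            epoch_losses.append(module.training_step(batch, batch_idx))
            # Stand-in for loss.backward() followed by optimizer.step().
            module.w -= module.lr * module.grad(batch)
        # Validation only evaluates the loss; the weight is left untouched.
        val_losses = [module.validation_step(b, i) for i, b in enumerate(val)]
        train_hist.append(sum(epoch_losses) / len(epoch_losses))
        val_hist.append(sum(val_losses) / len(val_losses))
    return train_hist, val_hist


if __name__ == "__main__":
    model = ToyModule()
    train_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
    val_batches = [[(4.0, 8.0)]]
    print(fit(model, train_batches, val_batches, epochs=5), model.w)
```

Lightning additionally takes care of device placement, disabling gradient tracking during validation, logging, and checkpointing; the toy omits all of that.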