gluonts.torch.model.wavenet.lightning_module module

class gluonts.torch.model.wavenet.lightning_module.WaveNetLightningModule(model_kwargs: dict, lr: float = 0.001, weight_decay: float = 1e-08)

Bases: LightningModule

LightningModule wrapper over WaveNet.

Parameters:
  • model_kwargs (dict) – Keyword arguments to pass to WaveNet.

  • lr (float, optional) – Learning rate, by default 1e-3.

  • weight_decay (float, optional) – Weight decay, by default 1e-8.
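The parameter list separates model settings from optimizer settings: model_kwargs is passed on to WaveNet, while lr and weight_decay are kept for the optimizer. The framework-free sketch below mimics that split. ToyWaveNet and ToyWaveNetModule are hypothetical stand-ins, and the num_bins and context_length arguments are invented for illustration; they are not taken from the real WaveNet signature.

```python
from typing import Any, Dict


class ToyWaveNet:
    """Hypothetical stand-in for WaveNet; the arguments are illustrative only."""

    def __init__(self, num_bins: int = 1024, context_length: int = 48) -> None:
        self.num_bins = num_bins
        self.context_length = context_length


class ToyWaveNetModule:
    """Mimics the constructor split of WaveNetLightningModule without Lightning."""

    def __init__(
        self,
        model_kwargs: Dict[str, Any],
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
    ) -> None:
        # Model hyperparameters are unpacked into the wrapped model...
        self.model = ToyWaveNet(**model_kwargs)
        # ...while optimizer hyperparameters stay on the wrapper.
        self.lr = lr
        self.weight_decay = weight_decay


module = ToyWaveNetModule(model_kwargs={"num_bins": 256})
print(module.model.num_bins, module.model.context_length, module.lr, module.weight_decay)
# 256 48 0.001 1e-08
```

Any model argument not supplied in model_kwargs falls back to the model's own default (context_length here), so only the settings that differ need to be passed.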

configure_optimizers()

Returns the optimizer to use.
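The lr and weight_decay arguments presumably configure the optimizer returned here, but this page does not say which optimizer that is. As an illustration only, the sketch below performs one Adam update on a scalar weight, assuming PyTorch-style L2 weight decay: the term weight_decay * w is added to the gradient before the moment estimates are updated, which is how torch.optim.Adam treats its weight_decay argument. AdamState and adam_step are hypothetical names.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdamState:
    """Moment estimates for a single scalar parameter."""
    m: float = 0.0  # first moment: running mean of gradients
    v: float = 0.0  # second moment: running mean of squared gradients
    t: int = 0      # number of steps taken so far


def adam_step(
    w: float,
    grad: float,
    state: AdamState,
    lr: float = 1e-3,
    weight_decay: float = 1e-8,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
) -> float:
    """Return the updated weight after one Adam step with L2 weight decay."""
    beta1, beta2 = betas
    g = grad + weight_decay * w  # L2 penalty folded into the gradient
    state.t += 1
    state.m = beta1 * state.m + (1 - beta1) * g
    state.v = beta2 * state.v + (1 - beta2) * g * g
    m_hat = state.m / (1 - beta1 ** state.t)  # bias-corrected first moment
    v_hat = state.v / (1 - beta2 ** state.t)  # bias-corrected second moment
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)


state = AdamState()
w1 = adam_step(1.0, 0.5, state)
print(w1)  # about 0.999
```

On the first step the bias-corrected moments reduce to g and g², so for gradients well above eps the weight moves by almost exactly lr. With the default weight_decay of 1e-8 the decay term barely changes the gradient unless the weights grow very large.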

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

Parameters:
  • *args – Positional arguments passed to the forward method.

  • **kwargs – Keyword arguments passed to the forward method.

Returns:

The model’s output.
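This page only states that forward() behaves like torch.nn.Module.forward(). A common pattern for a wrapper module is to hand every argument straight to the wrapped network. The plain-Python sketch below shows that delegation, assuming (not confirming) that this module does the same; ToyModel and ToyWrapper are hypothetical.

```python
from typing import Any, List


class ToyModel:
    """Hypothetical network: scales every input value."""

    def __call__(self, values: List[float], scale: float = 2.0) -> List[float]:
        return [v * scale for v in values]


class ToyWrapper:
    """Forwards all positional and keyword arguments to the wrapped model."""

    def __init__(self) -> None:
        self.model = ToyModel()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # Pass everything through unchanged; the wrapper adds no logic of its own.
        return self.model(*args, **kwargs)

    # torch.nn.Module routes __call__ to forward(); mimic that here.
    __call__ = forward


wrapper = ToyWrapper()
print(wrapper([1.0, 2.5], scale=3.0))  # [3.0, 7.5]
```

Because the wrapper accepts *args and **kwargs, its call signature automatically tracks the wrapped model's without being restated.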

training_step(batch, batch_idx: int)

Execute a training step.
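With PyTorch Lightning's automatic optimization, training_step returns the loss for a batch and the Trainer runs the backward pass and optimizer step itself. The sketch below shows only the reduction of a batch to a single scalar loss. It assumes batches are dicts keyed by past_target and future_target (GluonTS field names), and series_loss is a hypothetical squared-error loss against a last-value forecast that stands in for WaveNet's actual loss.

```python
from statistics import fmean
from typing import Dict, List

Batch = Dict[str, List[List[float]]]


def series_loss(past: List[float], future: List[float]) -> float:
    """Hypothetical loss: mean squared error of a last-value forecast."""
    last = past[-1]
    return fmean((y - last) ** 2 for y in future)


def training_step(batch: Batch, batch_idx: int) -> float:
    """Average per-series losses into one scalar (batch_idx is unused here)."""
    losses = [
        series_loss(past, future)
        for past, future in zip(batch["past_target"], batch["future_target"])
    ]
    return fmean(losses)


batch = {
    "past_target": [[1.0, 2.0], [4.0, 4.0]],
    "future_target": [[3.0, 4.0], [4.0, 6.0]],
}
print(training_step(batch, batch_idx=0))  # 2.25
```

The first series has squared errors 1 and 4 (mean 2.5), the second 0 and 4 (mean 2.0), so the batch loss is 2.25.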

validation_step(batch, batch_idx: int)

Execute a validation step.
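A validation step computes a loss on held-out batches without updating any weights, and the per-batch values are then combined into an epoch-level figure. The sketch below shows one way to combine them, weighting each batch's mean loss by its size so that a short final batch does not skew the result. epoch_validation_loss is a hypothetical helper, and this page does not state how the module aggregates its validation loss.

```python
from typing import List, Tuple


def epoch_validation_loss(batch_results: List[Tuple[float, int]]) -> float:
    """Combine (mean_loss, batch_size) pairs into a size-weighted epoch mean."""
    total_loss = sum(loss * size for loss, size in batch_results)
    total_items = sum(size for _, size in batch_results)
    return total_loss / total_items


# A batch of 6 series followed by a short final batch of 2 series.
results = [(2.0, 6), (4.0, 2)]
print(epoch_validation_loss(results))  # 2.5
```

An unweighted mean of the two batch losses would give 3.0, overstating the influence of the two series in the short final batch.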