gluonts.torch.util module

class gluonts.torch.util.IterableDataset(iterable)

Bases: torch.utils.data.dataset.IterableDataset

gluonts.torch.util.copy_parameters(net_source: torch.nn.modules.module.Module, net_dest: torch.nn.modules.module.Module, strict: Optional[bool] = True) → None

Copies parameters from one network to another.

Parameters
  • net_source – Source network, whose parameters are copied.

  • net_dest – Destination network, which receives the copied parameters.

  • strict – Whether to strictly enforce that the keys in the state_dict of net_source match the keys returned by the state_dict() function of net_dest. Default: True.
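The description of strict mirrors the wording of torch.nn.Module.load_state_dict, so the copy can plausibly be sketched with that method. The snippet below is an illustration under that assumption, not the library's confirmed implementation; the networks and the names net_partial and result are made up for the example.

```python
import torch
from torch import nn

# Two networks with the same architecture but independently initialized weights.
net_source = nn.Linear(4, 2)
net_dest = nn.Linear(4, 2)

# Assumed equivalent of copy_parameters(net_source, net_dest, strict=True):
# load the source state_dict into the destination network.
net_dest.load_state_dict(net_source.state_dict(), strict=True)

# After the copy, every parameter tensor in net_dest equals its counterpart.
params_match = all(
    torch.equal(p_src, p_dst)
    for p_src, p_dst in zip(net_source.parameters(), net_dest.parameters())
)

# With strict=False, mismatched keys are reported instead of raising an error.
# net_partial is a hypothetical network whose keys are "0.weight" and "0.bias".
net_partial = nn.Sequential(nn.Linear(4, 2))
result = net_partial.load_state_dict(net_source.state_dict(), strict=False)
missing = set(result.missing_keys)
unexpected = set(result.unexpected_keys)
```

With strict=True, the same mismatched call would raise a RuntimeError rather than return the lists of missing and unexpected keys.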

gluonts.torch.util.get_forward_input_names(module: Type[torch.nn.modules.module.Module])

gluonts.torch.util.weighted_average(x: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) → torch.Tensor

Computes the weighted average of a given tensor across a given dim, masking values associated with weight zero. A masked entry contributes 0 to the weighted sum even when its value is NaN, so instead of nan * 0 = nan you get 0 * 0 = 0.

Parameters
  • x – Input tensor, of which the average must be computed.

  • weights – Weights tensor, of the same shape as x.

  • dim – The dimension along which to average x.

Returns

The tensor with values averaged along the specified dim.

Return type

Tensor
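The masking behavior described for weighted_average can be sketched in plain PyTorch as follows. This is a minimal illustration, not the library's code: the function name masked_weighted_average is hypothetical, and both the fallback to an unweighted mean when weights is None and the handling of an all-zero weight sum are assumptions.

```python
from typing import Optional

import torch


def masked_weighted_average(
    x: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    dim: Optional[int] = None,
) -> torch.Tensor:
    # Assumption: without weights, fall back to an unweighted mean.
    if weights is None:
        return x.mean() if dim is None else x.mean(dim=dim)

    # Select 0 wherever the weight is 0, so that nan * 0 never enters the sum.
    weighted = torch.where(weights != 0, x * weights, torch.zeros_like(x))

    if dim is None:
        total, weight_sum = weighted.sum(), weights.sum()
    else:
        total, weight_sum = weighted.sum(dim=dim), weights.sum(dim=dim)

    # Assumption: where all weights are zero, divide by 1 so the result is 0.
    safe_sum = torch.where(
        weight_sum != 0, weight_sum, torch.ones_like(weight_sum)
    )
    return total / safe_sum


x = torch.tensor([[1.0, float("nan"), 3.0], [2.0, 4.0, 6.0]])
weights = torch.tensor([[1.0, 0.0, 1.0], [1.0, 1.0, 2.0]])

# Row 0: (1 + 3) / 2; row 1: (2 + 4 + 12) / 4. The NaN entry is masked out.
averaged = masked_weighted_average(x, weights, dim=1)
```

Using torch.where instead of multiplying by the weights alone is what keeps the NaN from propagating: a plain (x * weights).sum() would return NaN for the first row.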