gluonts.mx.block.dropout module

class gluonts.mx.block.dropout.RNNZoneoutCell(base_cell: mxnet.gluon.rnn.rnn_cell.RecurrentCell, zoneout_outputs: float = 0.0, zoneout_states: float = 0.0)

Bases: mxnet.gluon.rnn.rnn_cell.ModifierCell

Applies zoneout to the base cell. The implementation follows [KMK16]. Unlike mx.gluon.rnn.ZoneoutCell, this implementation uses the same mask for the output and for states[0], because for RNN cells states[0] is the same as the output. ResidualCell is the exception: there, states[0] = input + output.
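As an illustration only, the shared-mask idea can be sketched in pure Python on plain lists. This is not the gluonts implementation: `zoneout_output_step` is a hypothetical helper. It assumes the usual zoneout convention that each unit keeps its previous value with probability `p_outputs`, and it falls back to a zero previous output on the first step (also an assumption). It does not model `zoneout_states`.

```python
import random
from typing import List, Optional, Tuple

Vector = List[float]


def zoneout_output_step(
    next_output: Vector,
    next_state0: Vector,
    prev_output: Optional[Vector],
    prev_state0: Vector,
    p_outputs: float,
    rng: random.Random,
) -> Tuple[Vector, Vector]:
    """Hypothetical helper: zone out the output and states[0] with ONE mask.

    Assumption: each unit keeps its previous value with probability
    p_outputs and takes the base cell's new value otherwise.
    """
    if p_outputs == 0.0:
        # A rate of 0 disables zoneout: pass the base cell's results through.
        return next_output, next_state0
    if prev_output is None:
        # Assumption: on the first step there is no previous output, use zeros.
        prev_output = [0.0] * len(next_output)
    # True -> take the new value, False -> zone out (keep the old value).
    take_new = [rng.random() >= p_outputs for _ in next_output]
    output = [n if m else o for m, n, o in zip(take_new, next_output, prev_output)]
    # The SAME mask is reused for states[0] instead of copying `output`, so
    # states[0] is zoned out at exactly the positions where the output is.
    state0 = [n if m else o for m, n, o in zip(take_new, next_state0, prev_state0)]
    return output, state0


# For a plain RNN cell, next_state0 equals next_output, so after zoneout the
# output and states[0] are still identical.
rng = random.Random(0)
out, s0 = zoneout_output_step([1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
                              [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], 0.5, rng)
print(out == s0)  # True
```

Reusing one mask rather than copying the output into states[0] matters only when the two differ, as in the ResidualCell case mentioned above.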

Parameters
  • base_cell – The cell on which to apply zoneout.

  • zoneout_outputs – The zoneout rate for outputs. No zoneout is applied to outputs if it equals 0.

  • zoneout_states – The zoneout rate for state inputs on the first state channel. No zoneout is applied to states if it equals 0.

hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], states: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) → Tuple[Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]]

Overrides to construct the symbolic graph for this Block.

Parameters
  • inputs (Symbol or NDArray) – The input tensor for the current time step.

  • states (list of Symbol or list of NDArray) – The states from the previous time step.

reset()

Reset before re-using the cell for another graph.

class gluonts.mx.block.dropout.VariationalZoneoutCell(base_cell: mxnet.gluon.rnn.rnn_cell.RecurrentCell, zoneout_outputs: float = 0.0, zoneout_states: float = 0.0)

Bases: mxnet.gluon.rnn.rnn_cell.ModifierCell

Applies variational zoneout to the base cell. The implementation follows [GG16]. Variational zoneout uses the same mask across time steps. It can be applied to both RNN outputs and states; the output mask and the state mask are not shared.

The mask is initialized at the first forward step and remains the same until .reset() is called. Therefore, if you step the cell manually instead of calling .unroll(), call .reset() after each sequence.
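The mask lifecycle described above can be sketched in pure Python, assuming the same zoneout convention (a unit keeps its previous value with probability p). `VariationalZoneoutSketch` and its `step` method are hypothetical and only mimic the stated behaviour: separate output and state masks, sampled on the first step, reused at every later step until `reset()`. It works on plain lists, with a single state vector standing in for the first state channel.

```python
import random
from typing import List, Optional, Tuple

Vector = List[float]


class VariationalZoneoutSketch:
    """Hypothetical stand-in for the mask handling of a variational zoneout cell."""

    def __init__(self, p_outputs: float, p_states: float, seed: int = 0) -> None:
        self.p_outputs = p_outputs
        self.p_states = p_states
        self.rng = random.Random(seed)
        # Masks do not exist until the first step (or the first step after reset).
        self.output_mask: Optional[List[bool]] = None
        self.state_mask: Optional[List[bool]] = None

    def _sample(self, size: int, p: float) -> List[bool]:
        # True -> take the new value, False -> zone out (keep the old value).
        return [self.rng.random() >= p for _ in range(size)]

    def reset(self) -> None:
        # Drop both masks so the next sequence samples fresh ones.
        self.output_mask = None
        self.state_mask = None

    def step(self, next_output: Vector, prev_output: Vector,
             next_state: Vector, prev_state: Vector) -> Tuple[Vector, Vector]:
        # Sample each mask once; the two masks are independent (not shared).
        if self.output_mask is None:
            self.output_mask = self._sample(len(next_output), self.p_outputs)
        if self.state_mask is None:
            self.state_mask = self._sample(len(next_state), self.p_states)
        output = [n if m else o
                  for m, n, o in zip(self.output_mask, next_output, prev_output)]
        state = [n if m else o
                 for m, n, o in zip(self.state_mask, next_state, prev_state)]
        return output, state


cell = VariationalZoneoutSketch(p_outputs=0.3, p_states=0.3, seed=1)
for t in range(3):  # manual stepping over one sequence
    cell.step([float(t + 1)] * 4, [float(t)] * 4, [float(t + 1)] * 4, [float(t)] * 4)
# The same masks were used at every time step; reset before the next sequence.
cell.reset()
```

Keeping the masks as instance state is what makes the manual-stepping rule necessary: without reset(), a second sequence would silently reuse the first sequence's masks.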

Parameters
  • base_cell – The cell on which to apply variational zoneout.

  • zoneout_outputs – The zoneout rate for outputs. No zoneout is applied to outputs if it equals 0.

  • zoneout_states – The zoneout rate for state inputs on the first state channel. No zoneout is applied to states if it equals 0.

hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], states: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) → Tuple[Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol], Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]]

Overrides to construct the symbolic graph for this Block.

Parameters
  • inputs (Symbol or NDArray) – The input tensor for the current time step.

  • states (list of Symbol or list of NDArray) – The states from the previous time step.

reset()

Reset before re-using the cell for another graph.