gluonts.mx.block.rnn module
class gluonts.mx.block.rnn.RNN(mode: str, num_hidden: int, num_layers: int, bidirectional: bool = False, **kwargs)
Bases: mxnet.gluon.block.HybridBlock
Defines an RNN block.
Parameters
mode – type of the RNN. Must be one of: rnn_relu (RNN with ReLU activation), rnn_tanh (RNN with tanh activation), lstm, or gru.
num_hidden – number of units per hidden layer.
num_layers – number of hidden layers.
bidirectional – if True, use a bidirectional RNN as the encoder.
hybrid_forward(F, inputs: Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]) → Union[mxnet.ndarray.ndarray.NDArray, mxnet.symbol.symbol.Symbol]
Parameters
F – a module that refers to either the Symbol API or the NDArray API in MXNet.
inputs – input tensor with shape (batch_size, num_timesteps, num_dimensions).
Returns
RNN output with shape (batch_size, num_timesteps, num_hidden), or (batch_size, num_timesteps, 2 * num_hidden) if bidirectional is True.
Return type
Tensor
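To make the shape contract of hybrid_forward concrete, the following dependency-free sketch runs a single-layer tanh RNN (the recurrence h_t = tanh(W_x x_t + W_h h_{t-1}) behind `rnn_tanh`, with bias terms omitted) over a batch-major input of shape (batch_size, num_timesteps, num_dimensions). It is an illustration, not the GluonTS implementation. The function `rnn_tanh_forward` is hypothetical, the weights are fixed toy values, only one layer is modeled, and the bidirectional case reuses the forward weights for the backward pass (a real layer learns separate ones). It assumes that a bidirectional run concatenates forward and backward hidden states along the last axis, so each time step yields num_hidden features, or 2 * num_hidden when bidirectional.

```python
import math
from typing import List

Matrix = List[List[float]]


def _step(x: List[float], h: List[float], w_x: Matrix, w_h: Matrix) -> List[float]:
    """One tanh-RNN step: h_new[j] = tanh(w_x[j] . x + w_h[j] . h), no bias."""
    return [
        math.tanh(
            sum(wx * xi for wx, xi in zip(w_x[j], x))
            + sum(wh * hk for wh, hk in zip(w_h[j], h))
        )
        for j in range(len(w_x))
    ]


def _run(seq: List[List[float]], w_x: Matrix, w_h: Matrix) -> List[List[float]]:
    """Apply the recurrence over one (num_timesteps, num_dimensions) sequence."""
    h = [0.0] * len(w_x)  # zero initial hidden state
    outputs = []
    for x in seq:
        h = _step(x, h, w_x, w_h)
        outputs.append(h)
    return outputs


def rnn_tanh_forward(
    inputs: List[List[List[float]]], num_hidden: int, bidirectional: bool = False
) -> List[List[List[float]]]:
    """Map (batch, time, dims) to (batch, time, num_hidden * num_directions)."""
    num_dims = len(inputs[0][0])
    # Toy, deterministic weights (hypothetical; a trained layer learns these).
    w_x = [[0.1 * (j + 1)] * num_dims for j in range(num_hidden)]
    w_h = [[0.05] * num_hidden for _ in range(num_hidden)]
    result = []
    for seq in inputs:
        fwd = _run(seq, w_x, w_h)
        if bidirectional:
            # Backward pass over the reversed sequence, re-aligned to forward
            # time order (reuses the forward weights for brevity).
            bwd = _run(seq[::-1], w_x, w_h)[::-1]
            fwd = [f + b for f, b in zip(fwd, bwd)]  # concatenate per time step
        result.append(fwd)
    return result


batch = [[[1.0, 0.5, -0.5]] * 4] * 2  # batch_size=2, num_timesteps=4, num_dimensions=3
out = rnn_tanh_forward(batch, num_hidden=5)
print(len(out), len(out[0]), len(out[0][0]))  # 2 4 5
out_bi = rnn_tanh_forward(batch, num_hidden=5, bidirectional=True)
print(len(out_bi), len(out_bi[0]), len(out_bi[0][0]))  # 2 4 10
```

Note that the batch and time axes pass through unchanged; only the last axis changes, from the input feature size to the hidden size (times the number of directions).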