# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from mxnet.gluon import HybridBlock, rnn
from gluonts.core.component import validated
from gluonts.mx import Tensor


class RNN(HybridBlock):
"""
Defines an RNN block.
Parameters
----------
mode
type of the RNN. Can be either: rnn_relu (RNN with relu activation),
rnn_tanh, (RNN with tanh activation), lstm or gru.
num_hidden
number of units per hidden layer.
num_layers
number of hidden layers.
bidirectional
toggle use of bi-directional RNN as encoder.
"""
@validated()
def __init__(
self,
mode: str,
num_hidden: int,
num_layers: int,
bidirectional: bool = False,
**kwargs,
):
super().__init__(**kwargs)
with self.name_scope():
if mode == "rnn_relu":
self.rnn = rnn.RNN(
num_hidden,
num_layers,
bidirectional=bidirectional,
activation="relu",
layout="NTC",
)
elif mode == "rnn_tanh":
self.rnn = rnn.RNN(
num_hidden,
num_layers,
bidirectional=bidirectional,
layout="NTC",
)
elif mode == "lstm":
self.rnn = rnn.LSTM(
num_hidden,
num_layers,
bidirectional=bidirectional,
layout="NTC",
)
elif mode == "gru":
self.rnn = rnn.GRU(
num_hidden,
num_layers,
bidirectional=bidirectional,
layout="NTC",
)
            else:
                raise ValueError(
                    f"Invalid mode {mode}. Options are: rnn_relu, rnn_tanh,"
                    " lstm, gru"
                )

    def hybrid_forward(self, F, inputs: Tensor) -> Tensor:  # NTC in, NTC out
"""
Parameters
----------
F
A module that can either refer to the Symbol API or the NDArray
API in MXNet.
inputs
input tensor with shape (batch_size, num_timesteps, num_dimensions)
Returns
-------
Tensor
rnn output with shape (batch_size, num_timesteps, num_dimensions)
"""
return self.rnn(inputs)
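

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module): builds an LSTM-mode block and checks the NTC-in / NTC-out shapes
# documented above. Assumes MXNet is installed and importable as mx.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import mxnet as mx

    batch_size, num_timesteps, num_dimensions = 4, 10, 3
    net = RNN(mode="lstm", num_hidden=16, num_layers=2)
    net.initialize()

    inputs = mx.nd.random.normal(
        shape=(batch_size, num_timesteps, num_dimensions)
    )
    outputs = net(inputs)
    # The channel dimension becomes num_hidden (2 * num_hidden if the block
    # were constructed with bidirectional=True).
    print(outputs.shape)  # (4, 10, 16)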