# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Optional, Tuple

import mxnet as mx
from mxnet.gluon import HybridBlock

from gluonts.mx import Tensor


def split_heads(F, x: Tensor, dim_per_head: int, heads: int) -> Tensor:
    r"""
    Returns a tensor with the head dimension folded into the batch dimension
    and the last dimension divided by the number of heads.

    Parameters
    ----------
    x
        Tensor of shape (batch_size, time_length, dim).
    dim_per_head
        Dimension per head.
    heads
        Number of heads.

    Returns
    -------
    Tensor of shape (batch_size * heads, time_length, dim_per_head).
    """
    # (batch_size, time_length, heads, dim_per_head)
    x = F.reshape(data=x, shape=(0, -1, heads, dim_per_head))
    # (batch_size, heads, time_length, dim_per_head)
    x = F.transpose(data=x, axes=(0, 2, 1, 3))
    # (batch_size * heads, time_length, dim_per_head)
    return F.reshape(data=x, shape=(-3, -1, dim_per_head))
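

# A minimal usage sketch (editorial addition, not part of the original
# module): with the imperative ``mx.nd`` API standing in for ``F``, an input
# of shape (2, 10, 32) split across 8 heads of size 4 has its heads folded
# into the batch axis. The helper name ``_example_split_heads`` is
# hypothetical.
def _example_split_heads() -> None:
    x = mx.nd.ones(shape=(2, 10, 32))  # (batch_size, time_length, dim)
    y = split_heads(mx.nd, x, dim_per_head=4, heads=8)
    assert y.shape == (2 * 8, 10, 4)  # heads folded into the batch dimension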


def dot_attention(
    F,
    queries: Tensor,
    keys: Tensor,
    values: Tensor,
    mask: Optional[Tensor] = None,
    dropout: float = 0.0,
) -> Tensor:
    r"""
    Computes dot attention of the queries against the keys and returns, for
    each query, a weighted combination of the values.

    Parameters
    ----------
    queries
        Attention queries of shape (n, lq, d).
    keys
        Attention keys of shape (n, lk, d).
    values
        Attention values of shape (n, lk, dv).
    mask
        Optional mask tensor that is broadcast-added to the attention logits.
    dropout
        Dropout rate applied to the attention probabilities.

    Returns
    -------
    'Context' vectors for each query, of shape (n, lq, dv).
    """
    # attention logits: (n, lq, lk)
    logits = F.batch_dot(lhs=queries, rhs=keys, transpose_b=True)
    if mask is not None:
        logits = F.broadcast_add(logits, mask)
    probs = F.softmax(logits, axis=-1)
    probs = F.Dropout(probs, p=dropout) if dropout > 0.0 else probs
    # (n, lq, lk) x (n, lk, dv) -> (n, lq, dv)
    return F.batch_dot(lhs=probs, rhs=values)
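

# Usage sketch (editorial addition, not part of the original module): per
# batch element, dot attention computes softmax(Q K^T + mask) V, so queries
# of length 5 attending over a memory of length 7 yield one context vector
# per query. ``mx.nd`` stands in for ``F``; the helper name is hypothetical.
def _example_dot_attention() -> None:
    n, lq, lk, d, dv = 16, 5, 7, 4, 4
    queries = mx.nd.random.normal(shape=(n, lq, d))
    keys = mx.nd.random.normal(shape=(n, lk, d))
    values = mx.nd.random.normal(shape=(n, lk, dv))
    contexts = dot_attention(mx.nd, queries, keys, values)
    assert contexts.shape == (n, lq, dv)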


def combine_heads(F, x: Tensor, dim_per_head: int, heads: int) -> Tensor:
    r"""
    Inverse of :func:`split_heads`: folds the head dimension back into the
    hidden dimension.

    Parameters
    ----------
    x
        Tensor of shape (batch_size * heads, time_length, dim_per_head).
    dim_per_head
        Dimension per head.
    heads
        Number of heads.

    Returns
    -------
    Tensor of shape (batch_size, time_length, dim).
    """
    # (batch_size, heads, time_length, dim_per_head)
    x = F.reshape(data=x, shape=(-4, -1, heads, 0, dim_per_head))
    # (batch_size, time_length, heads, dim_per_head)
    x = F.transpose(x, axes=(0, 2, 1, 3))
    # (batch_size, time_length, dim)
    return F.reshape(x, shape=(-1, 0, dim_per_head * heads))
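

# Round-trip sketch (editorial addition, not part of the original module):
# ``combine_heads`` undoes ``split_heads``, so chaining the two recovers the
# original (batch_size, time_length, dim) shape. ``mx.nd`` stands in for
# ``F``; the helper name is hypothetical.
def _example_combine_heads() -> None:
    x = mx.nd.random.normal(shape=(2, 10, 32))
    y = split_heads(mx.nd, x, dim_per_head=4, heads=8)  # (16, 10, 4)
    z = combine_heads(mx.nd, y, dim_per_head=4, heads=8)  # (2, 10, 32)
    assert z.shape == x.shape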


class LayerNormalization(HybridBlock):
    """
    Implements layer normalization as proposed in [BKH16]_.

    Parameters
    ----------
    scale_init
        Initializer for the scale (gamma) parameter.
    shift_init
        Initializer for the shift (beta) parameter.
    eps
        Small constant added to the variance for numerical stability.
    """

    def __init__(
        self,
        scale_init: str = "ones",
        shift_init: str = "zeros",
        eps: float = 1e-06,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.scale_init = scale_init
        self.shift_init = shift_init
        with self.name_scope():
            self.lnorm = mx.gluon.nn.LayerNorm(
                axis=-1,
                gamma_initializer=self.scale_init,
                beta_initializer=self.shift_init,
                epsilon=eps,
            )

    def hybrid_forward(self, F, data: Tensor) -> Tensor:
        r"""
        Normalizes the hidden units of data as follows:

        data = scale * (data - mean) / sqrt(var + eps) + shift

        Normalization is performed over the last dimension of the input data.

        Parameters
        ----------
        data
            Data to normalize, of shape (d0, ..., dn, num_hidden).

        Returns
        -------
        Normalized inputs of shape (d0, ..., dn, num_hidden).
        """
        return self.lnorm(data)
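

# Usage sketch (editorial addition, not part of the original module): the
# block wraps ``mx.gluon.nn.LayerNorm`` and normalizes over the last axis,
# so the input shape is preserved; parameter shapes are inferred on the
# first forward pass. The helper name is hypothetical.
def _example_layer_normalization() -> None:
    lnorm = LayerNormalization()
    lnorm.initialize()
    data = mx.nd.random.normal(shape=(2, 10, 32))
    assert lnorm(data).shape == data.shape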


class MultiHeadAttentionBase(HybridBlock):
    """
    Base class for multi-head attention.

    Parameters
    ----------
    att_dim_in
        Attention dimension (number of hidden units)
    heads
        Number of attention heads
    att_dim_out
        Output dimension (number of output units)
    dropout
        Dropout rate on attention scores
    """

    def __init__(
        self,
        att_dim_in: int = 32,
        heads: int = 8,
        att_dim_out: int = 32,
        dropout: float = 0.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert (
            att_dim_in % heads == 0
        ), "Number of heads {} must divide the attention dimension att_dim_in {}".format(
            heads, att_dim_in
        )

        self.att_dim_in = att_dim_in
        self.heads = heads
        self.att_dim_out = att_dim_out
        self.dropout = dropout
        self.dim_per_head = self.att_dim_in // self.heads

        with self.name_scope():
            self.dense_att = mx.gluon.nn.Dense(
                units=self.att_dim_out, flatten=False
            )

    def _attend(
        self,
        F,
        queries: Tensor,
        keys: Tensor,
        values: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Returns context vectors of multi-head dot attention.

        Parameters
        ----------
        queries
            Queries tensor of shape (batch_size, query_max_length, dim)
        keys
            Keys tensor of shape (batch_size, memory_max_length, dim)
        values
            Values tensor of shape (batch_size, memory_max_length, dim)
        mask
            Optional tensor to mask the attention scores

        Returns
        -------
        Context vectors of shape (batch_size, query_max_length, att_dim_out)
        """
        # scale the queries by 1/sqrt(dim_per_head)
        queries = queries * (self.dim_per_head**-0.5)

        # (batch_size * heads, length, dim_per_head)
        queries = split_heads(F, queries, self.dim_per_head, self.heads)
        keys = split_heads(F, keys, self.dim_per_head, self.heads)
        values = split_heads(F, values, self.dim_per_head, self.heads)

        # (batch_size * heads, query_max_length, dim_per_head)
        contexts = dot_attention(
            F, queries, keys, values, mask=mask, dropout=self.dropout
        )
        # (batch_size, query_max_length, att_dim_in)
        contexts = combine_heads(F, contexts, self.dim_per_head, self.heads)

        # (batch_size, query_max_length, att_dim_out)
        contexts = self.dense_att(contexts)

        return contexts


class MultiHeadSelfAttention(MultiHeadAttentionBase):
    r"""
    Multi-head self-attention. Independent linear projections of the inputs
    serve as queries, keys, and values for the attention.

    Parameters
    ----------
    att_dim_in
        Attention dimension (number of hidden units)
    heads
        Number of attention heads
    att_dim_out
        Output dimension (number of output units)
    dropout
        Dropout rate on attention scores
    """

    def __init__(
        self,
        att_dim_in: int = 32,
        heads: int = 8,
        att_dim_out: int = 32,
        dropout: float = 0.0,
        **kwargs,
    ) -> None:
        super().__init__(att_dim_in, heads, att_dim_out, dropout, **kwargs)

        with self.name_scope():
            self.dense_pre_satt = mx.gluon.nn.Dense(
                units=self.att_dim_in * 3, flatten=False
            )

    def hybrid_forward(
        self,
        F,
        inputs: Tensor,
        mask: Optional[Tensor] = None,
        cache: Optional[Dict[str, Optional[Tensor]]] = None,
    ) -> Tuple[Tensor, Optional[Dict]]:
        r"""
        Computes multi-head attention on a set of inputs that serve as
        queries, keys, and values. If a mask is provided, it is used to mask
        the attention scores. A cache of previously computed keys and values
        may also be used; newly computed keys and values are appended to it.

        Parameters
        ----------
        inputs
            Input data of shape (batch_size, max_length, att_dim_in)
        mask
            Optional tensor to mask the attention scores
        cache
            Optional dictionary of previously computed keys and values

        Returns
        -------
        Tensor
            Context tensor of shape (batch_size, max_length, att_dim_out)
        Optional[Dict]
            The updated cache, if one was provided
        """
        # Q = K = V -> Q * W_q, K * W_k, V * W_v
        # combined: (batch_size, max_length, att_dim_in * 3)
        combined = self.dense_pre_satt(inputs)

        # split into queries, keys and values
        # (batch_size, max_length, att_dim_in)
        queries, keys, values = F.split(data=combined, num_outputs=3, axis=2)

        if cache is not None:
            # append the new keys and values to the cache and update it
            keys = cache["k"] = (
                keys
                if "k" not in cache
                else F.concat(cache["k"], keys, dim=1)
            )
            values = cache["v"] = (
                values
                if "v" not in cache
                else F.concat(cache["v"], values, dim=1)
            )

        return self._attend(F, queries, keys, values, mask), cache
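

# Usage sketch (editorial addition, not part of the original module): run
# imperatively, so ``F`` is ``mx.nd``. Self-attention over a
# (batch_size, max_length, att_dim_in) input returns a context tensor and
# the (possibly updated) cache; passing a dict as ``cache`` stores the
# projected keys and values, e.g. for incremental decoding. The helper name
# is hypothetical.
def _example_multi_head_self_attention() -> None:
    att = MultiHeadSelfAttention(att_dim_in=32, heads=8, att_dim_out=32)
    att.initialize()
    inputs = mx.nd.random.normal(shape=(4, 20, 32))
    contexts, cache = att(inputs)
    assert contexts.shape == (4, 20, 32)
    assert cache is None  # no cache was passed in

    # with a cache dict, the newly computed keys and values are stored in it
    step = mx.nd.random.normal(shape=(4, 1, 32))
    _, cache = att(step, None, {})
    assert cache["k"].shape == (4, 1, 32)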


class MultiHeadAttention(MultiHeadAttentionBase):
    r"""
    Multi-head attention layer for queries that are independent from the
    keys/values.

    Parameters
    ----------
    att_dim_in
        Attention dimension (number of hidden units)
    heads
        Number of attention heads
    att_dim_out
        Output dimension (number of output units)
    dropout
        Dropout rate on attention scores
    """

    def __init__(
        self,
        att_dim_in: int = 32,
        heads: int = 8,
        att_dim_out: int = 32,
        dropout: float = 0.0,
        **kwargs,
    ) -> None:
        super().__init__(att_dim_in, heads, att_dim_out, dropout, **kwargs)

        with self.name_scope():
            self.dense_pre_att_q = mx.gluon.nn.Dense(
                units=self.att_dim_in, flatten=False
            )
            self.dense_pre_att_k = mx.gluon.nn.Dense(
                units=self.att_dim_in, flatten=False
            )
            self.dense_pre_att_v = mx.gluon.nn.Dense(
                units=self.att_dim_in, flatten=False
            )

    def hybrid_forward(
        self, F, queries: Tensor, memory: Tensor, mask: Optional[Tensor] = None
    ) -> Tensor:
        r"""
        Computes multi-head attention for the queries given a memory tensor.
        If a mask tensor is provided, it is used to mask the attention scores.

        Parameters
        ----------
        queries
            Queries tensor of shape (batch_size, query_max_length, att_dim_in)
        memory
            Memory tensor to attend to, of shape
            (batch_size, memory_max_length, att_dim_in)
        mask
            Optional tensor to mask the attention scores

        Returns
        -------
        Tensor of shape (batch_size, query_max_length, att_dim_out)
        """
        # Q -> Q * W_q
        # K = V -> K * W_k, V * W_v
        # (batch_size, query_max_length, att_dim_in)
        queries = self.dense_pre_att_q(queries)
        # (batch_size, memory_max_length, att_dim_in)
        keys = self.dense_pre_att_k(memory)
        # (batch_size, memory_max_length, att_dim_in)
        values = self.dense_pre_att_v(memory)

        return self._attend(F, queries, keys, values, mask=mask)
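

# Usage sketch (editorial addition, not part of the original module): queries
# attend over a separate memory sequence, as in encoder-decoder attention.
# Run imperatively, so ``F`` is ``mx.nd``; the helper name is hypothetical.
def _example_multi_head_attention() -> None:
    att = MultiHeadAttention(att_dim_in=32, heads=8, att_dim_out=32)
    att.initialize()
    queries = mx.nd.random.normal(shape=(4, 6, 32))
    memory = mx.nd.random.normal(shape=(4, 20, 32))
    contexts = att(queries, memory)
    assert contexts.shape == (4, 6, 32)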