# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Optional, Tuple

import mxnet as mx
import numpy as np
from mxnet import nd

from gluonts.core.component import validated
from gluonts.model.tpp import distribution
from gluonts.model.tpp.distribution.base import TPPDistributionOutput
from gluonts.mx import Tensor
from gluonts.mx.distribution import CategoricalOutput


# noinspection PyAbstractClass
class DeepTPPNetworkBase(mx.gluon.HybridBlock):
"""
    Temporal point process model based on a recurrent neural network.

    Parameters
    ----------
    num_marks
        Number of discrete marks (correlated processes) that are available
        in the data.
    interval_length
        The length of the total time interval that is in the prediction
        range. Note that, in contrast to the discrete-time models in the
        rest of GluonTS, this network is trained to predict an interval in
        continuous time.
time_distr_output
Output distribution for the inter-arrival times. Available
distributions can be found in gluonts.model.tpp.distribution.
embedding_dim
Dimension of vector embeddings of marks (used only as input).
num_hidden_dimensions
Number of hidden units in the RNN.
output_scale
Positive scaling applied to the inter-event times. You should provide
this argument if the average inter-arrival time is much larger than 1.
apply_log_to_rnn_inputs
Apply logarithm to inter-event times that are fed into the RNN.
"""

    @validated()
def __init__(
self,
num_marks: int,
interval_length: float,
time_distr_output: TPPDistributionOutput = distribution.WeibullOutput(), # noqa: E501
embedding_dim: int = 5,
num_hidden_dimensions: int = 10,
output_scale: Optional[Tensor] = None,
apply_log_to_rnn_inputs: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.num_marks = num_marks
self.interval_length = interval_length
self.rnn_hidden_size = num_hidden_dimensions
self.output_scale = output_scale
self.apply_log_to_rnn_inputs = apply_log_to_rnn_inputs
with self.name_scope():
self.embedding = mx.gluon.nn.Embedding(
input_dim=num_marks, output_dim=embedding_dim
)
self.rnn = mx.gluon.rnn.GRU(
num_hidden_dimensions,
input_size=embedding_dim + 1,
layout="NTC",
)
# Conditional distribution over the inter-arrival times
self.time_distr_output = time_distr_output
self.time_distr_args_proj = self.time_distr_output.get_args_proj()
# Conditional distribution over the marks
if num_marks > 1:
self.mark_distr_output = CategoricalOutput(num_marks)
self.mark_distr_args_proj = (
self.mark_distr_output.get_args_proj()
)

    def hybridize(self, active=True, **kwargs):
        if active:
            raise NotImplementedError(
                "DeepTPP blocks do not support hybridization"
            )
        else:
            super().hybridize(active, **kwargs)


class DeepTPPTrainingNetwork(DeepTPPNetworkBase):
    # noinspection PyMethodOverriding,PyPep8Naming,PyIncorrectDocstring
    def hybrid_forward(
self,
F,
target: Tensor,
valid_length: Tensor,
**kwargs,
) -> Tensor:
"""
        Computes the negative log likelihood loss for the given sequences.

        As the model is trained on past (resp. future) or context
        (resp. prediction) "intervals", as opposed to fixed-length
        "sequences", the number of data points available varies across
        observations. To account for this, the data is passed to the
        training network as a padded ("ragged") tensor, and the number of
        valid entries in each sequence is provided in a separate variable,
        :code:`valid_length`.

        Parameters
        ----------
F
MXNet backend.
        target
            Tensor with observations.
            Shape: (batch_size, max_sequence_length, target_dim).
        valid_length
            The `valid_length` or number of valid entries in the target
            tensor. Shape: (batch_size,).

        Returns
        -------
Tensor
Loss tensor. Shape: (batch_size,).
"""
if F is mx.sym:
raise ValueError(
"The DeepTPP model currently doesn't support hybridization."
)
batch_size = target.shape[0]
# IMPORTANT: We add an additional zero at the end of each sequence
# It will be used to store the time until the end of the interval
target = F.concat(target, F.zeros((batch_size, 1, 2)), dim=1)
# (N, T + 1, 2)
ia_times, marks = F.split(
target, num_outputs=2, axis=-1
) # inter-arrival times, marks
marks = marks.squeeze(axis=-1) # (N, T + 1)
valid_length = valid_length.reshape(-1).astype(
ia_times.dtype
) # make sure shape is (batch_size,)
if self.apply_log_to_rnn_inputs:
ia_times_input = ia_times.clip(1e-8, np.inf).log()
else:
ia_times_input = ia_times
rnn_input = F.concat(ia_times_input, self.embedding(marks), dim=-1)
rnn_output = self.rnn(rnn_input) # (N, T + 1, H)
rnn_init_state = F.zeros([batch_size, 1, self.rnn_hidden_size])
history_emb = F.slice_axis(
F.concat(rnn_init_state, rnn_output, dim=1),
axis=1,
begin=0,
end=-1,
) # (N, T + 1, H)
# Augment ia_times by adding the time remaining until interval_length
# Afterwards, each row of ia_times will sum up to interval_length
ia_times = ia_times.squeeze(axis=-1) # (N, T + 1)
time_remaining = self.interval_length - ia_times.sum(-1) # (N)
# Equivalent to ia_times[F.arange(N), valid_length] = time_remaining
indices = F.stack(F.arange(batch_size), valid_length)
time_remaining_tensor = F.scatter_nd(
time_remaining, indices, ia_times.shape
)
ia_times_aug = ia_times + time_remaining_tensor
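        # For example (hypothetical values): with interval_length = 10,
        # ia_times = [[2., 3., 0.]] and valid_length = [2], we get
        # time_remaining = [5.] and ia_times_aug = [[2., 3., 5.]], so each
        # row of ia_times_aug sums to interval_length.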
time_distr_args = self.time_distr_args_proj(history_emb)
time_distr = self.time_distr_output.distribution(
time_distr_args, scale=self.output_scale
)
log_intensity = time_distr.log_intensity(ia_times_aug) # (N, T + 1)
log_survival = time_distr.log_survival(ia_times_aug) # (N, T + 1)
if self.num_marks > 1:
mark_distr_args = self.mark_distr_args_proj(history_emb)
mark_distr = self.mark_distr_output.distribution(mark_distr_args)
log_intensity = log_intensity + mark_distr.log_prob(marks)
def _mask(x, sequence_length):
return F.SequenceMask(
data=x,
sequence_length=sequence_length,
axis=1,
use_sequence_length=True,
)
log_likelihood = F.sum(
(
_mask(log_intensity, valid_length)
+ _mask(log_survival, valid_length + 1)
),
axis=-1,
) # (N)
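        # Each observed inter-event time contributes its log-density
        # (log-intensity + log-survival); the final, censored interval until
        # the end of the prediction range contributes only a log-survival
        # term, which is why log_survival is masked with valid_length + 1
        # while log_intensity is masked with valid_length.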
return -log_likelihood


class DeepTPPPredictionNetwork(DeepTPPNetworkBase):
@validated()
def __init__(
self,
prediction_interval_length: float,
num_parallel_samples: int = 100,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.num_parallel_samples = num_parallel_samples
self.prediction_interval_length = prediction_interval_length

    # noinspection PyMethodOverriding,PyPep8Naming,PyIncorrectDocstring
    def hybrid_forward(
self,
F,
past_target: Tensor,
past_valid_length: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
        Draw forward samples from the model.

        At each step, we sample an inter-event time and feed it into the RNN
        to obtain the parameters of the distribution over the next
        inter-event time.

        Parameters
        ----------
        F
            MXNet backend.
        past_target
            Tensor with past observations.
            Shape: (batch_size, context_length, target_dim). Has to comply
            with :code:`self.interval_length`.
        past_valid_length
            The `valid_length` or number of valid entries in the past_target
            tensor. Shape: (batch_size,).

        Returns
        -------
        sampled_target: Tensor
            Predicted inter-event times and marks.
            Shape: (samples, batch_size, max_prediction_length, target_dim).
        sampled_valid_length: Tensor
            The number of valid entries in the time axis of each sample.
            Shape: (samples, batch_size).
"""
# Variable-length generation (while t < t_max) is a potential problem
if F is mx.sym:
raise ValueError(
"The DeepTPP model currently doesn't support hybridization."
)
        assert (
            past_target.shape[-1] == 2
        ), "TPP data must have target_dim == 2: inter-arrival times and marks"
batch_size = past_target.shape[0]
# condition the prediction network on the past events
past_ia_times, past_marks = F.split(
past_target, num_outputs=2, axis=-1
)
past_valid_length = past_valid_length.reshape(-1).astype(
past_ia_times.dtype
)
if self.apply_log_to_rnn_inputs:
past_ia_times_input = past_ia_times.clip(1e-8, np.inf).log()
else:
past_ia_times_input = past_ia_times
rnn_input = F.concat(
past_ia_times_input,
self.embedding(past_marks.squeeze(axis=-1)),
dim=-1,
)
rnn_output = self.rnn(rnn_input) # (N, T, H)
rnn_init_state = F.zeros([batch_size, 1, self.rnn_hidden_size])
past_history_emb = F.concat(
rnn_init_state, rnn_output, dim=1
) # (N, T + 1, H)
# Select the history embedding after the last event in the past
indices = F.stack(F.arange(batch_size), past_valid_length)
history_emb = F.gather_nd(past_history_emb, indices) # (N, H)
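        # For example (hypothetical values), with past_valid_length = [3, 1]
        # the gather_nd above selects past_history_emb[0, 3] and
        # past_history_emb[1, 1], i.e. the RNN state right after the last
        # valid past event of each sequence.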
num_total_samples = self.num_parallel_samples * batch_size
history_emb = history_emb.expand_dims(0).repeat(
self.num_parallel_samples, axis=0
) # (S, N, H)
history_emb = history_emb.reshape(
[num_total_samples, self.rnn_hidden_size]
) # (S * N, H)
sampled_ia_times_list: List[nd.NDArray] = []
sampled_marks_list: List[nd.NDArray] = []
arrival_times = F.zeros([num_total_samples])
        # past_time_elapsed: total time covered by the past events;
        # past_time_remaining: time from the last observed event until the
        # end of the past (context) interval.
        past_time_elapsed = past_ia_times.squeeze(axis=-1).sum(-1)
        past_time_remaining = self.interval_length - past_time_elapsed  # (N)
past_time_remaining_repeat = (
past_time_remaining.expand_dims(0)
.repeat(self.num_parallel_samples, axis=0)
.reshape([num_total_samples])
) # (S * N)
first_step = True
while F.sum(arrival_times < self.prediction_interval_length) > 0:
# Sample the next inter-arrival time
time_distr_args = self.time_distr_args_proj(history_emb)
time_distr = self.time_distr_output.distribution(
time_distr_args,
scale=self.output_scale,
)
if first_step:
# Time from the last event until the next event
next_ia_times = time_distr.sample(
lower_bound=past_time_remaining_repeat
)
# Time from the prediction interval start until the next event
clipped_ia_times = next_ia_times - past_time_remaining_repeat
sampled_ia_times_list.append(clipped_ia_times)
arrival_times = arrival_times + clipped_ia_times
first_step = False
else:
next_ia_times = time_distr.sample()
sampled_ia_times_list.append(next_ia_times)
arrival_times = arrival_times + next_ia_times
# Sample the next marks
if self.num_marks > 1:
mark_distr_args = self.mark_distr_args_proj(history_emb)
next_marks = self.mark_distr_output.distribution(
mark_distr_args
).sample()
else:
next_marks = F.zeros([num_total_samples])
sampled_marks_list.append(next_marks)
# Pass the generated ia_times & marks into the RNN to obtain
# the next history embedding
if self.apply_log_to_rnn_inputs:
next_ia_times_input = next_ia_times.clip(1e-8, np.inf).log()
else:
next_ia_times_input = next_ia_times
rnn_input = F.concat(
next_ia_times_input.expand_dims(-1),
self.embedding(next_marks),
dim=-1,
).expand_dims(1)
            history_emb = self.rnn(rnn_input).squeeze(axis=1)  # (S * N, H)
sampled_ia_times = F.stack(*sampled_ia_times_list, axis=-1)
sampled_marks = F.stack(*sampled_marks_list, axis=-1).astype("float32")
sampled_valid_length = F.sum(
F.cumsum(sampled_ia_times, axis=1)
< self.prediction_interval_length,
axis=-1,
)
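        # Number of sampled events whose cumulative arrival time still falls
        # inside the prediction interval; the remaining (overshooting)
        # entries are masked out below.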
def _mask(x, sequence_length):
return F.SequenceMask(
data=x,
sequence_length=sequence_length,
axis=1,
use_sequence_length=True,
)
sampled_ia_times = _mask(sampled_ia_times, sampled_valid_length)
sampled_marks = _mask(sampled_marks, sampled_valid_length)
sampled_ia_times = sampled_ia_times.reshape(
[self.num_parallel_samples, batch_size, -1]
)
sampled_marks = sampled_marks.reshape(
[self.num_parallel_samples, batch_size, -1]
)
sampled_valid_length = sampled_valid_length.reshape(
[self.num_parallel_samples, batch_size]
)
sampled_target = F.stack(sampled_ia_times, sampled_marks, axis=-1)
return sampled_target, sampled_valid_length
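

if __name__ == "__main__":
    # Minimal usage sketch with made-up toy data, kept here purely for
    # illustration. The shapes follow the docstrings above; in practice
    # these networks are constructed and trained by the surrounding
    # estimator code, and the prediction network would reuse the trained
    # parameters rather than the random initialization used below.
    mx.random.seed(0)

    num_marks = 2

    train_net = DeepTPPTrainingNetwork(
        num_marks=num_marks, interval_length=10.0
    )
    train_net.initialize()

    # Two padded sequences with 3 and 2 valid events; the last axis holds
    # (inter-arrival time, mark).
    target = nd.array(
        [
            [[0.5, 0], [1.2, 1], [0.3, 0]],
            [[2.0, 1], [0.7, 0], [0.0, 0]],
        ]
    )
    valid_length = nd.array([3, 2])

    loss = train_net(target, valid_length)
    print("negative log-likelihood per sequence:", loss)

    pred_net = DeepTPPPredictionNetwork(
        prediction_interval_length=5.0,
        num_parallel_samples=4,
        num_marks=num_marks,
        interval_length=10.0,
    )
    pred_net.initialize()

    samples, sample_valid_length = pred_net(target, valid_length)
    # samples: (num_parallel_samples, batch_size, max_prediction_length, 2)
    # sample_valid_length: (num_parallel_samples, batch_size)
    print("sampled target shape:", samples.shape)
    print("sampled valid lengths:", sample_valid_length)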