# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from functools import partial
from typing import Optional

from mxnet.gluon import HybridBlock

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.loader import (
DataLoader,
TrainDataLoader,
ValidationDataLoader,
)
from gluonts.model.predictor import Predictor
from gluonts.model.tpp import PointProcessGluonPredictor
from gluonts.model.tpp.distribution import TPPDistributionOutput, WeibullOutput
from gluonts.mx.batchify import batchify
from gluonts.mx.model.estimator import GluonEstimator
from gluonts.mx.trainer import Trainer
from gluonts.mx.util import get_hybrid_forward_input_names
from gluonts.transform import (
Chain,
ContinuousTimeInstanceSplitter,
ContinuousTimePointSampler,
ContinuousTimePredictionSampler,
ContinuousTimeUniformSampler,
RenameFields,
SelectFields,
Transformation,
)

from ._network import DeepTPPPredictionNetwork, DeepTPPTrainingNetwork


class DeepTPPEstimator(GluonEstimator):
r"""
DeepTPP is a multivariate point process model based on an RNN.
After each event :math:`(\tau_i, m_i)`, we feed the inter-arrival time
:math:`\tau_i` and the mark :math:`m_i` into the RNN. The state :math:`h_i`
of the RNN represents the history embedding. We use :math:`h_i` to
parametrize the distribution over the next inter-arrival time
:math:`p(\tau_{i+1} | h_i)` and the distribution over the next mark
:math:`p(m_{i+1} | h_i)`. The distribution over the marks is always
categorical, but different choices are possible for the distribution over
inter-arrival times - see :code:`gluonts.model.tpp.distribution`.
The model is a generalization of the approaches described in [DDT+16]_,
[TWJ19]_ and [SBG20]_.

    References
    ----------
    .. [DDT+16] Du, Nan, et al.
        "Recurrent marked temporal point processes: Embedding event history
        to vector." Proceedings of the 22nd ACM SIGKDD International
        Conference on Knowledge Discovery and Data Mining. 2016.

    .. [TWJ19] Türkmen, Ali Caner, Yuyang Wang, and Tim Januschowski.
        "Intermittent demand forecasting with deep renewal processes."
        arXiv preprint (2019).

    .. [SBG20] Shchur, Oleksandr, Marin Biloš, and Stephan Günnemann.
        "Intensity-free learning of temporal point processes."
        International Conference on Learning Representations (ICLR). 2020.

Parameters
----------
prediction_interval_length
The length of the interval (in continuous time) that the estimator will
predict at prediction time.
    context_interval_length
        The length of intervals (in continuous time) that the estimator will
        be trained with. If set to ``None``, this defaults to
        ``prediction_interval_length``.
num_marks
The number of marks (distinct processes), i.e., the cardinality of the
mark set.
time_distr_output
TPPDistributionOutput for the distribution over the inter-arrival
times. See :code:`gluonts.model.tpp.distribution` for possible choices.
embedding_dim
The dimension of vector embeddings for marks (used as input to the
GRU).
trainer
:code:`gluonts.mx.trainer.Trainer` object which will be used to train
the estimator. Note that :code:`Trainer(hybridize=False)` must be set
as :code:`DeepTPPEstimator` currently does not support hybridization.
num_hidden_dimensions
Number of hidden units in the GRU network.
    num_parallel_samples
        The number of samples returned by the learned :code:`Predictor`.
num_training_instances
The number of training instances to be sampled from each entry in the
data set provided during training.
freq
Similar to the :code:`freq` of discrete-time models, specifies the time
unit by which inter-arrival times are given.
    batch_size
        The size of the batches to be used during training and prediction.
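
    Examples
    --------
    A minimal usage sketch (illustrative rather than a tested recipe;
    assumes ``train_ds`` is a GluonTS :code:`Dataset` of continuous-time
    event records in the format the TPP data loaders expect):

    >>> estimator = DeepTPPEstimator(
    ...     prediction_interval_length=24.0,
    ...     context_interval_length=72.0,
    ...     num_marks=5,
    ...     freq="H",
    ...     trainer=Trainer(epochs=10, hybridize=False),
    ... )  # doctest: +SKIP
    >>> predictor = estimator.train(train_ds)  # doctest: +SKIP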
"""
@validated()
def __init__(
self,
prediction_interval_length: float,
        context_interval_length: Optional[float],
num_marks: int,
time_distr_output: TPPDistributionOutput = WeibullOutput(),
embedding_dim: int = 5,
trainer: Trainer = Trainer(hybridize=False),
num_hidden_dimensions: int = 10,
num_parallel_samples: int = 100,
num_training_instances: int = 100,
freq: str = "H",
batch_size: int = 32,
) -> None:
assert (
not trainer.hybridize
), "DeepTPP currently only supports the non-hybridized training"
super().__init__(trainer=trainer, batch_size=batch_size)
assert (
prediction_interval_length > 0
), "The value of `prediction_interval_length` should be > 0"
        assert (
            context_interval_length is None or context_interval_length > 0
        ), "The value of `context_interval_length` should be > 0 (or None)"
assert (
num_hidden_dimensions > 0
), "The value of `num_hidden_dimensions` should be > 0"
assert (
num_parallel_samples > 0
), "The value of `num_parallel_samples` should be > 0"
assert num_marks > 0, "The value of `num_marks` should be > 0"
assert (
num_training_instances > 0
), "The value of `num_training_instances` should be > 0"
self.num_hidden_dimensions = num_hidden_dimensions
self.prediction_interval_length = prediction_interval_length
self.context_interval_length = (
context_interval_length
if context_interval_length is not None
else prediction_interval_length
)
self.num_marks = num_marks
self.time_distr_output = time_distr_output
self.embedding_dim = embedding_dim
self.num_parallel_samples = num_parallel_samples
self.num_training_instances = num_training_instances
        self.freq = freq

def _create_instance_splitter(self, mode: str):
assert mode in ["training", "validation", "test"]
instance_sampler = {
"training": ContinuousTimeUniformSampler(
num_instances=self.num_training_instances,
min_past=self.context_interval_length,
min_future=self.prediction_interval_length,
),
"validation": ContinuousTimePredictionSampler(
allow_empty_interval=True,
min_past=self.context_interval_length,
min_future=self.prediction_interval_length,
),
"test": ContinuousTimePredictionSampler(
min_past=self.context_interval_length,
allow_empty_interval=False,
),
}[mode]
assert isinstance(instance_sampler, ContinuousTimePointSampler)
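        # The splitter cuts a (past, future) window out of each series in
        # continuous time; the rename maps the past fields to the names the
        # networks expect as inputs.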
return Chain(
[
ContinuousTimeInstanceSplitter(
past_interval_length=self.context_interval_length,
future_interval_length=self.prediction_interval_length,
instance_sampler=instance_sampler,
),
RenameFields(
{
"past_target": "target",
"past_valid_length": "valid_length",
}
),
]
        )

    def create_training_data_loader(
self,
data: Dataset,
**kwargs,
) -> DataLoader:
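        # Keep only the fields consumed by the training network's
        # `hybrid_forward` in each batch.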
input_names = get_hybrid_forward_input_names(DeepTPPTrainingNetwork)
instance_splitter = self._create_instance_splitter("training")
return TrainDataLoader(
dataset=data,
transform=instance_splitter + SelectFields(input_names),
batch_size=self.batch_size,
stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
**kwargs,
        )

    def create_validation_data_loader(
self,
data: Dataset,
**kwargs,
) -> DataLoader:
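        # Validation batches feed the same training network (to compute the
        # loss), so they select the same input fields.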
input_names = get_hybrid_forward_input_names(DeepTPPTrainingNetwork)
instance_splitter = self._create_instance_splitter("validation")
return ValidationDataLoader(
dataset=data,
transform=instance_splitter + SelectFields(input_names),
batch_size=self.batch_size,
stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
        )

    def create_training_network(self) -> HybridBlock:
return DeepTPPTrainingNetwork(
num_marks=self.num_marks,
time_distr_output=self.time_distr_output,
interval_length=self.prediction_interval_length,
embedding_dim=self.embedding_dim,
num_hidden_dimensions=self.num_hidden_dimensions,
        )

    def create_predictor(
self,
transformation: Transformation,
trained_network: DeepTPPTrainingNetwork,
) -> Predictor:
prediction_splitter = self._create_instance_splitter("test")
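        # Build a sampling network that shares the trained parameters via
        # `collect_params`.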
prediction_network = DeepTPPPredictionNetwork(
num_marks=self.num_marks,
prediction_interval_length=self.prediction_interval_length,
interval_length=self.context_interval_length,
embedding_dim=self.embedding_dim,
num_hidden_dimensions=self.num_hidden_dimensions,
time_distr_output=trained_network.time_distr_output,
params=trained_network.collect_params(),
num_parallel_samples=self.num_parallel_samples,
)
return PointProcessGluonPredictor(
input_names=["target", "valid_length"],
prediction_net=prediction_network,
batch_size=self.batch_size,
prediction_interval_length=self.prediction_interval_length,
freq=self.freq,
ctx=self.trainer.ctx,
input_transform=transformation + prediction_splitter,
)