# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Optional, Tuple
import mxnet as mx
import numpy as np
from gluonts.core.component import validated
from gluonts.mx import Tensor
from gluonts.mx.linalg_util import jitter_cholesky
from gluonts.mx.util import _broadcast_param, make_nd_diag
from . import Distribution, Gaussian, MultivariateGaussian
from .distribution import getF
class ParameterBounds:
@validated()
def __init__(self, lower, upper) -> None:
assert (
lower <= upper
), "lower bound should be less than or equal to upper bound"
self.lower = lower
self.upper = upper
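# Example (values are illustrative): bounds for the projected noise
# standard deviation could be ParameterBounds(lower=1e-4, upper=1.0).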
def _safe_split(x, num_outputs, axis, squeeze_axis, *args, **kwargs):
"""
A type-stable wrapper around mx.nd.split.
Currently mx.nd.split behaves weirdly if num_outputs=1:
a = mx.nd.ones(shape=(1, 1, 2))
l = a.split(axis=1, num_outputs=1, squeeze_axis=False)
type(l) # mx.NDArray
l.shape # (1, 1, 2)
Compare that with the case num_outputs=2:
a = mx.nd.ones(shape=(1, 2, 2))
l = a.split(axis=1, num_outputs=2, squeeze_axis=False)
type(l) # list
len(l) # 2
l[0].shape # (1, 1, 2)
This wrapper makes the behavior consistent by always returning a list
of length num_outputs, whose elements will have one less axis than x
in case x.shape[axis]==num_outputs and squeeze_axis==True, and the same
number of axes as x otherwise.
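For comparison, calling this wrapper with the array `a` from the first
snippet above gives:
l = _safe_split(a, num_outputs=1, axis=1, squeeze_axis=False)
type(l) # list
len(l) # 1
l[0].shape # (1, 1, 2)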
"""
if num_outputs > 1:
return x.split(
axis=axis,
num_outputs=num_outputs,
squeeze_axis=squeeze_axis,
*args,
**kwargs,
)
return [x.squeeze(axis=axis)] if squeeze_axis else [x]
class LDS(Distribution):
r"""
Implements Linear Dynamical System (LDS) as a distribution.
The LDS is given by
.. math::
z_t = A_t l_{t-1} + b_t + \epsilon_t \\
l_t = C_t l_{t-1} + g_t \nu
where
.. math::
\epsilon_t \sim N(0, S_v) \\
\nu \sim N(0, 1)
:math:`A_t`, :math:`C_t` and :math:`g_t` are the emission, transition and
innovation coefficients respectively. The residual terms are denoted
by :math:`b_t`.
The target :math:`z_t` can be :math:`d`-dimensional in which case
.. math::
A_t \in R^{d \times h}, b_t \in R^{d}, C_t \in R^{h \times h}, g_t \in R^{h}
where :math:`h` is the dimension of the latent state.
Parameters
----------
emission_coeff
Tensor of shape (batch_size, seq_length, obs_dim, latent_dim)
transition_coeff
Tensor of shape (batch_size, seq_length, latent_dim, latent_dim)
innovation_coeff
Tensor of shape (batch_size, seq_length, latent_dim)
noise_std
Tensor of shape (batch_size, seq_length, obs_dim)
residuals
Tensor of shape (batch_size, seq_length, obs_dim)
prior_mean
Tensor of shape (batch_size, latent_dim)
prior_cov
Tensor of shape (batch_size, latent_dim, latent_dim)
latent_dim
Dimension of the latent state
output_dim
Dimension of the output
seq_length
Sequence length
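Examples
--------
A minimal construction sketch; all shapes and values below are
illustrative (batch_size=1, seq_length=2, output_dim=1, latent_dim=2):
lds = LDS(
emission_coeff=mx.nd.ones(shape=(1, 2, 1, 2)),
transition_coeff=mx.nd.broadcast_to(
mx.nd.eye(2).reshape((1, 1, 2, 2)), shape=(1, 2, 2, 2)
),
innovation_coeff=0.1 * mx.nd.ones(shape=(1, 2, 2)),
noise_std=0.5 * mx.nd.ones(shape=(1, 2, 1)),
residuals=mx.nd.zeros(shape=(1, 2, 1)),
prior_mean=mx.nd.zeros(shape=(1, 2)),
prior_cov=mx.nd.eye(2).reshape((1, 2, 2)),
latent_dim=2,
output_dim=1,
seq_length=2,
)
samples = lds.sample(num_samples=10)  # (10, 1, 2, 1)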
""" # noqa: E501
@validated()
def __init__(
self,
emission_coeff: Tensor,
transition_coeff: Tensor,
innovation_coeff: Tensor,
noise_std: Tensor,
residuals: Tensor,
prior_mean: Tensor,
prior_cov: Tensor,
latent_dim: int,
output_dim: int,
seq_length: int,
) -> None:
self.latent_dim = latent_dim
self.output_dim = output_dim
self.seq_length = seq_length
# Split coefficients along time axis for easy access
# emission_coef[t]: (batch_size, obs_dim, latent_dim)
self.emission_coeff = _safe_split(
emission_coeff,
axis=1,
num_outputs=self.seq_length,
squeeze_axis=True,
)
# innovation_coef[t]: (batch_size, latent_dim)
self.innovation_coeff = _safe_split(
innovation_coeff,
axis=1,
num_outputs=self.seq_length,
squeeze_axis=False,
)
# transition_coeff: (batch_size, latent_dim, latent_dim)
self.transition_coeff = _safe_split(
transition_coeff,
axis=1,
num_outputs=self.seq_length,
squeeze_axis=True,
)
# noise_std[t]: (batch_size, obs_dim)
self.noise_std = _safe_split(
noise_std, axis=1, num_outputs=self.seq_length, squeeze_axis=True
)
# residuals[t]: (batch_size, obs_dim)
self.residuals = _safe_split(
residuals, axis=1, num_outputs=self.seq_length, squeeze_axis=True
)
self.prior_mean = prior_mean
self.prior_cov = prior_cov
@property
def F(self):
return getF(self.prior_mean)
@property
def batch_shape(self) -> Tuple:
return self.emission_coeff[0].shape[:1] + (self.seq_length,)
@property
def event_shape(self) -> Tuple:
return (self.output_dim,)
@property
def event_dim(self) -> int:
return 2
def log_prob(
self,
x: Tensor,
scale: Optional[Tensor] = None,
observed: Optional[Tensor] = None,
):
"""
Compute the log probability of observations.
This method also returns the final state of the system.
Parameters
----------
x
Observations, shape (batch_size, seq_length, output_dim)
scale
Scale of each sequence in x, shape (batch_size, output_dim)
observed
Flag tensor indicating which observations are genuine (1.0) and
which are missing (0.0)
Returns
-------
Tensor
Log probabilities, shape (batch_size, seq_length)
Tensor
Final mean, shape (batch_size, latent_dim)
Tensor
Final covariance, shape (batch_size, latent_dim, latent_dim)
"""
if scale is not None:
x = self.F.broadcast_div(x, scale.expand_dims(axis=1))
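# Rescale the observations; the corresponding change-of-variables
# term (`ladj` below) is added to the log density after filtering.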
# TODO: Based on form of the prior decide to do either filtering
# or residual-sum-of-squares
log_p, final_mean, final_cov = self.kalman_filter(x, observed)
if scale is not None:
F = self.F
# log_abs_det_jac: sum over all output dimensions.
ladj = -F.sum(F.log(F.abs(scale)), axis=-1, keepdims=True)
# Add `ladj` to the log density of each time step (broadcast over the
# time axis).
log_p = F.broadcast_add(log_p, ladj)
return log_p, final_mean, final_cov
def kalman_filter(
self, targets: Tensor, observed: Tensor
) -> Tuple[Tensor, ...]:
"""
Performs Kalman filtering given observations.
Parameters
----------
targets
Observations, shape (batch_size, seq_length, output_dim)
observed
Flag tensor indicating which observations are genuine (1.0) and
which are missing (0.0)
Returns
-------
Tensor
Log probabilities, shape (batch_size, seq_length)
Tensor
Mean of p(l_T | l_{T-1}), where T is seq_length, with shape
(batch_size, latent_dim)
Tensor
Covariance of p(l_T | l_{T-1}), where T is seq_length, with shape
(batch_size, latent_dim, latent_dim)
"""
F = self.F
# targets[t]: (batch_size, obs_dim)
targets = _safe_split(
targets, axis=1, num_outputs=self.seq_length, squeeze_axis=True
)
log_p_seq = []
mean = self.prior_mean
cov = self.prior_cov
observed = (
_safe_split(
observed,
axis=1,
num_outputs=self.seq_length,
squeeze_axis=True,
)
if observed is not None
else None
)
for t in range(self.seq_length):
# Compute the filtered distribution
# p(l_t | z_1, ..., z_t)
# and the log-probability
# log p(z_t | z_1, ..., z_{t-1})
filtered_mean, filtered_cov, log_p = kalman_filter_step(
F,
target=targets[t],
prior_mean=mean,
prior_cov=cov,
emission_coeff=self.emission_coeff[t],
residual=self.residuals[t],
noise_std=self.noise_std[t],
latent_dim=self.latent_dim,
output_dim=self.output_dim,
)
log_p_seq.append(log_p.expand_dims(axis=1))
# Mean of p(l_{t+1} | l_t)
mean = F.linalg_gemm2(
self.transition_coeff[t],
(
filtered_mean.expand_dims(axis=-1)
if observed is None
else F.where(
observed[t], x=filtered_mean, y=mean
).expand_dims(axis=-1)
),
).squeeze(axis=-1)
# Covariance of p(l_{t+1} | l_t)
cov = F.linalg_gemm2(
self.transition_coeff[t],
F.linalg_gemm2(
(
filtered_cov
if observed is None
else F.where(observed[t], x=filtered_cov, y=cov)
),
self.transition_coeff[t],
transpose_b=True,
),
) + F.linalg_gemm2(
self.innovation_coeff[t],
self.innovation_coeff[t],
transpose_a=True,
)
# Return sequence of log likelihoods, as well as
# final mean and covariance of p(l_T | l_{T-1}), where T is seq_length
return F.concat(*log_p_seq, dim=1), mean, cov
def sample(
self, num_samples: Optional[int] = None, scale: Optional[Tensor] = None
) -> Tensor:
r"""
Generates samples from the LDS: p(z_1, z_2, \ldots, z_{`seq_length`}).
Parameters
----------
num_samples
Number of samples to generate
scale
Scale of each sequence in x, shape (batch_size, output_dim)
Returns
-------
Tensor
Samples, shape (num_samples, batch_size, seq_length, output_dim)
"""
F = self.F
# Note on shapes: here we work with tensors of the following shape
# in each time step t: (num_samples, batch_size, dim, dim),
# where dim can be obs_dim or latent_dim or a constant 1 to facilitate
# generalized matrix multiplication (gemm2)
# Sample observation noise for all time steps
# noise_std: (batch_size, seq_length, obs_dim, 1)
noise_std = F.stack(*self.noise_std, axis=1).expand_dims(axis=-1)
# samples_eps_obs[t]: (num_samples, batch_size, obs_dim, 1)
samples_eps_obs = _safe_split(
Gaussian(noise_std.zeros_like(), noise_std).sample(num_samples),
axis=-3,
num_outputs=self.seq_length,
squeeze_axis=True,
)
# Sample standard normal for all time steps
# samples_std_normal[t]: (num_samples, batch_size, obs_dim, 1)
samples_std_normal = _safe_split(
Gaussian(noise_std.zeros_like(), noise_std.ones_like()).sample(
num_samples
),
axis=-3,
num_outputs=self.seq_length,
squeeze_axis=True,
)
# Sample the prior state.
# samples_lat_state: (num_samples, batch_size, latent_dim, 1)
# The prior covariance can acquire slightly negative eigenvalues
# whenever there is excessive zero padding at the beginning of the time
# series. We add a positive tolerance to the diagonal to avoid numerical
# issues. Note that `jitter_cholesky` adds this tolerance only if the
# decomposition without jitter fails.
state = MultivariateGaussian(
self.prior_mean,
jitter_cholesky(
F, self.prior_cov, self.latent_dim, float_type=np.float32
),
)
samples_lat_state = state.sample(num_samples).expand_dims(axis=-1)
samples_seq = []
for t in range(self.seq_length):
# Expand all coefficients to include samples in axis 0
# emission_coeff_t: (num_samples, batch_size, obs_dim, latent_dim)
# transition_coeff_t:
# (num_samples, batch_size, latent_dim, latent_dim)
# innovation_coeff_t: (num_samples, batch_size, 1, latent_dim)
emission_coeff_t, transition_coeff_t, innovation_coeff_t = (
(
_broadcast_param(coeff, axes=[0], sizes=[num_samples])
if num_samples is not None
else coeff
)
for coeff in [
self.emission_coeff[t],
self.transition_coeff[t],
self.innovation_coeff[t],
]
)
# Expand residuals as well
# residual_t: (num_samples, batch_size, obs_dim, 1)
residual_t = (
_broadcast_param(
self.residuals[t].expand_dims(axis=-1),
axes=[0],
sizes=[num_samples],
)
if num_samples is not None
else self.residuals[t].expand_dims(axis=-1)
)
# samples_t, after swapping the last two axes:
# (num_samples, batch_size, 1, obs_dim)
samples_t = (
F.linalg_gemm2(emission_coeff_t, samples_lat_state)
+ residual_t
+ samples_eps_obs[t]
)
samples_t = (
samples_t.swapaxes(dim1=2, dim2=3)
if num_samples is not None
else samples_t.swapaxes(dim1=1, dim2=2)
)
samples_seq.append(samples_t)
# sample next state: (num_samples, batch_size, latent_dim, 1)
samples_lat_state = F.linalg_gemm2(
transition_coeff_t, samples_lat_state
) + F.linalg_gemm2(
innovation_coeff_t, samples_std_normal[t], transpose_a=True
)
# (num_samples, batch_size, seq_length, obs_dim)
samples = F.concat(*samples_seq, dim=-2)
return (
samples
if scale is None
else F.broadcast_mul(
samples,
(
scale.expand_dims(axis=1).expand_dims(axis=0)
if num_samples is not None
else scale.expand_dims(axis=1)
),
)
)
def sample_marginals(
self, num_samples: Optional[int] = None, scale: Optional[Tensor] = None
) -> Tensor:
r"""
Generates samples from the marginals p(z_t),
t = 1, \ldots, `seq_length`.
Parameters
----------
num_samples
Number of samples to generate
scale
Scale of each sequence in x, shape (batch_size, output_dim)
Returns
-------
Tensor
Samples, shape (num_samples, batch_size, seq_length, output_dim)
"""
F = self.F
state_mean = self.prior_mean.expand_dims(axis=-1)
state_cov = self.prior_cov
output_mean_seq = []
output_cov_seq = []
for t in range(self.seq_length):
# compute and store observation mean at time t
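# E[z_t] = A_t m_t + b_t, with m_t the current state mean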
output_mean = F.linalg_gemm2(
self.emission_coeff[t], state_mean
) + self.residuals[t].expand_dims(axis=-1)
output_mean_seq.append(output_mean)
# compute and store observation cov at time t
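# Cov[z_t] = A_t P_t A_t^T + diag(noise_std_t^2), with P_t the current
# state covariance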
output_cov = F.linalg_gemm2(
self.emission_coeff[t],
F.linalg_gemm2(
state_cov, self.emission_coeff[t], transpose_b=True
),
) + make_nd_diag(
F=F, x=self.noise_std[t] * self.noise_std[t], d=self.output_dim
)
output_cov_seq.append(output_cov.expand_dims(axis=1))
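# Propagate the state: m_{t+1} = C_t m_t and
# P_{t+1} = C_t P_t C_t^T + g_t g_t^T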
state_mean = F.linalg_gemm2(self.transition_coeff[t], state_mean)
state_cov = F.linalg_gemm2(
self.transition_coeff[t],
F.linalg_gemm2(
state_cov, self.transition_coeff[t], transpose_b=True
),
) + F.linalg_gemm2(
self.innovation_coeff[t],
self.innovation_coeff[t],
transpose_a=True,
)
output_mean = F.concat(*output_mean_seq, dim=1)
output_cov = F.concat(*output_cov_seq, dim=1)
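# Cholesky factor of each time step's output covariance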
L = F.linalg_potrf(output_cov)
output_distribution = MultivariateGaussian(output_mean, L)
samples = output_distribution.sample(num_samples=num_samples)
return (
samples
if scale is None
else F.broadcast_mul(samples, scale.expand_dims(axis=1))
)
class LDSArgsProj(mx.gluon.HybridBlock):
def __init__(
self,
output_dim: int,
noise_std_bounds: ParameterBounds,
innovation_bounds: ParameterBounds,
) -> None:
super().__init__()
self.output_dim = output_dim
self.dense_noise_std = mx.gluon.nn.Dense(
units=1, flatten=False, activation="sigmoid"
)
self.dense_innovation = mx.gluon.nn.Dense(
units=1, flatten=False, activation="sigmoid"
)
self.dense_residual = mx.gluon.nn.Dense(
units=output_dim, flatten=False
)
self.innovation_bounds = innovation_bounds
self.noise_std_bounds = noise_std_bounds
def hybrid_forward(self, F, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
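# The sigmoid outputs of the dense layers lie in (0, 1); they are
# rescaled affinely into the configured (lower, upper) intervals.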
noise_std = (
self.dense_noise_std(x)
* (self.noise_std_bounds.upper - self.noise_std_bounds.lower)
+ self.noise_std_bounds.lower
)
innovation = (
self.dense_innovation(x)
* (self.innovation_bounds.upper - self.innovation_bounds.lower)
+ self.innovation_bounds.lower
)
residual = self.dense_residual(x)
return noise_std, innovation, residual
def kalman_filter_step(
F,
target: Tensor,
prior_mean: Tensor,
prior_cov: Tensor,
emission_coeff: Tensor,
residual: Tensor,
noise_std: Tensor,
latent_dim: int,
output_dim: int,
):
"""
One step of the Kalman filter.
This function computes the filtered state (mean and covariance) given the
linear system coefficients, the prior state (mean and covariance), and the
observations.
Parameters
----------
F
MXNet function namespace (`mx.nd` or `mx.sym`)
target
Observations of the system output, shape (batch_size, output_dim)
prior_mean
Prior mean of the latent state, shape (batch_size, latent_dim)
prior_cov
Prior covariance of the latent state, shape
(batch_size, latent_dim, latent_dim)
emission_coeff
Emission coefficient, shape (batch_size, output_dim, latent_dim)
residual
Residual component, shape (batch_size, output_dim)
noise_std
Standard deviation of the output noise, shape (batch_size, output_dim)
latent_dim
Dimension of the latent state vector
output_dim
Dimension of the system output
Returns
-------
Tensor
Filtered_mean, shape (batch_size, latent_dim)
Tensor
Filtered_covariance, shape (batch_size, latent_dim, latent_dim)
Tensor
Log probability, shape (batch_size, )
"""
# output_mean: mean of the target (batch_size, obs_dim)
output_mean = F.linalg_gemm2(
emission_coeff, prior_mean.expand_dims(axis=-1)
).squeeze(axis=-1)
# noise covariance
noise_cov = make_nd_diag(F=F, x=noise_std * noise_std, d=output_dim)
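# noise_cov = diag(noise_std^2): (batch_size, output_dim, output_dim)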
S_hh_x_A_tr = F.linalg_gemm2(prior_cov, emission_coeff, transpose_b=True)
# covariance of the target
output_cov = F.linalg_gemm2(emission_coeff, S_hh_x_A_tr) + noise_cov
# compute the Cholesky decomposition output_cov = LL^T
L_output_cov = F.linalg_potrf(output_cov)
# Compute Kalman gain matrix K:
# K = S_hh X with X = A^T output_cov^{-1}
# We have X = A^T output_cov^{-1} => X output_cov = A^T => X LL^T = A^T
# We can thus obtain X by solving two linear systems involving L
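# i.e. K = S_hh A^T (A S_hh A^T + S_v)^{-1}, the standard Kalman gain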
kalman_gain = F.linalg_trsm(
L_output_cov,
F.linalg_trsm(
L_output_cov, S_hh_x_A_tr, rightside=True, transpose=True
),
rightside=True,
)
# compute the prediction error (innovation)
target_minus_residual = target - residual
delta = target_minus_residual - output_mean
# filtered estimates
filtered_mean = prior_mean.expand_dims(axis=-1) + F.linalg_gemm2(
kalman_gain, delta.expand_dims(axis=-1)
)
filtered_mean = filtered_mean.squeeze(axis=-1)
# Joseph's symmetrized update for covariance:
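# filtered_cov = (I - K A) prior_cov (I - K A)^T + K noise_cov K^T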
ImKA = F.broadcast_sub(
F.eye(latent_dim), F.linalg_gemm2(kalman_gain, emission_coeff)
)
filtered_cov = F.linalg_gemm2(
ImKA, F.linalg_gemm2(prior_cov, ImKA, transpose_b=True)
) + F.linalg_gemm2(
kalman_gain, F.linalg_gemm2(noise_cov, kalman_gain, transpose_b=True)
)
# likelihood term: (batch_size,)
log_p = MultivariateGaussian(output_mean, L_output_cov).log_prob(
target_minus_residual
)
return filtered_mean, filtered_cov, log_p