gluonts.nursery.gmm_tpp.gmm_base module

class gluonts.nursery.gmm_tpp.gmm_base.GMMModel(*data_template, num_clusters, log_prior_=None, mu_=None, kR_=None, lr_mult=None, ctx=None, hybridize=None)

Bases: mxnet.gluon.block.HybridBlock

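The model maximizes the evidence lower bound on the log-likelihood:

$$\log p(x) \ge \mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z) + \log p(z) - \log q(z \mid x)\right]$$

with

$$p(z) = \text{uniform}$$
$$p(x \mid z) = \mathcal{N}\left(x;\ \mu_z,\ (\mathrm{kR}_z^\top \mathrm{kR}_z)^{-1}\right)$$
$$q(z \mid x) = \operatorname{softmax}_z\left(-\tfrac{1}{2}\left\lVert \mathrm{kR}_z (x - \mu_z)\right\rVert^2 + \log\det \mathrm{kR}_z - \tfrac{d}{2}\log(2\pi)\right)$$

where d is input_dim. The parameters are updated as

$$\mu_z = \mathbb{E}_{q(z \mid x)}[x]$$
$$\mathrm{cov}_z = \mathbb{E}_{q(z \mid x)}\left[x x^\top\right] - \mathbb{E}_{q(z \mid x)}[x]\,\mathbb{E}_{q(z \mid x)}[x]^\top$$
$$\mathrm{kR}_z = \operatorname{chol}(\mathrm{cov}_z)^{-1}$$

Here the expectations in the updates for mu_z and cov_z are averages over the data, weighted by q(z|x).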
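The parametrization above stores the inverse Cholesky factor kR_z of the covariance rather than the covariance itself. The following NumPy sketch (not the MXNet code of this module; the helper names are hypothetical) checks numerically that kR_z = inv(chol(cov_z)) satisfies kR_z' kR_z = cov_z^{-1}, and that the softmax argument in q(z|x) is exactly the Gaussian log-density log p(x|z).

```python
import numpy as np

def log_gauss_kR(x, mu, kR):
    """Hypothetical helper: log N(x; mu, (kR^T kR)^{-1}) written in terms of kR."""
    d = x.shape[-1]
    r = (x - mu) @ kR.T                    # each row is kR (x_n - mu)
    _, logdet_kR = np.linalg.slogdet(kR)   # log|det kR|
    return -0.5 * np.sum(r ** 2, axis=-1) + logdet_kR - 0.5 * d * np.log(2 * np.pi)

def log_gauss_cov(x, mu, cov):
    """Hypothetical reference: the same density written with the covariance."""
    d = x.shape[-1]
    diff = x - mu
    sol = np.linalg.solve(cov, diff.T).T   # cov^{-1} (x_n - mu), row-wise
    _, logdet_cov = np.linalg.slogdet(cov)
    return (-0.5 * np.sum(diff * sol, axis=-1)
            - 0.5 * logdet_cov - 0.5 * d * np.log(2 * np.pi))

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
cov = A @ A.T + 3 * np.eye(3)              # a positive-definite covariance
mu = rng.normal(size=3)
x = rng.normal(size=(5, 3))

kR = np.linalg.inv(np.linalg.cholesky(cov))  # kR_z = inv(chol(cov_z))
assert np.allclose(kR.T @ kR, np.linalg.inv(cov))
assert np.allclose(log_gauss_kR(x, mu, kR), log_gauss_cov(x, mu, cov))
```

Since chol(cov) = L with cov = L L', the inverse factor gives kR' kR = L^{-T} L^{-1} = cov^{-1}, and log det kR = -0.5 log det cov, which is why the docstring can use logdet(kR_z) directly.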

Shapes:
    x: (batch_size, input_dim)
    log_marg: (batch_size,)
    qz: (batch_size, num_clusters)

hybrid_forward(F, x, log_prior_, mu_, kR_)

E-step: computes the per-sample log marginal likelihood (log_marg) and the cluster posterior q(z|x) (qz).
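A minimal NumPy sketch of an E-step with the shapes listed above, assuming the parameters are laid out as log_prior (num_clusters,), mu (num_clusters, input_dim) and kR (num_clusters, input_dim, input_dim). The function name `e_step` is hypothetical and this is not the module's MXNet implementation. It assumes log_marg is the exact mixture log-likelihood, computed with a numerically stable log-sum-exp over clusters, and that qz is the corresponding softmax.

```python
import numpy as np

def e_step(x, log_prior, mu, kR):
    """Hypothetical E-step sketch.
    x: (batch_size, input_dim); log_prior: (K,); mu: (K, d); kR: (K, d, d).
    Returns log_marg: (batch_size,), qz: (batch_size, K)."""
    d = x.shape[1]
    diff = x[:, None, :] - mu[None, :, :]              # (batch, K, d)
    r = np.einsum("kij,bkj->bki", kR, diff)            # kR_z (x - mu_z)
    logdet = np.linalg.slogdet(kR)[1]                  # (K,) log|det kR_z|
    log_joint = (log_prior[None, :]
                 - 0.5 * np.sum(r ** 2, axis=-1)
                 + logdet[None, :]
                 - 0.5 * d * np.log(2 * np.pi))        # log p(z) + log p(x|z)
    m = log_joint.max(axis=1, keepdims=True)           # shift for a stable log-sum-exp
    log_marg = (m + np.log(np.exp(log_joint - m).sum(axis=1, keepdims=True)))[:, 0]
    qz = np.exp(log_joint - log_marg[:, None])         # softmax over clusters
    return log_marg, qz
```

With a uniform prior, log_prior is constant across clusters, so it cancels in qz and only shifts log_marg by -log(num_clusters).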

static m_step(x, qz)

M-step: computes the summary statistics of x, weighted by qz, in NumPy.
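The sketch below assumes the "summary statistics" are the usual zeroth, first and second moments weighted by q(z|x), and turns them into mu_z, cov_z and kR_z following the update equations in the class docstring. Both function names are hypothetical; the actual return value of m_step may differ.

```python
import numpy as np

def m_step_stats(x, qz):
    """Hypothetical helper: moments of x weighted by q(z|x).
    x: (batch_size, d), qz: (batch_size, K)."""
    n = qz.sum(axis=0)                           # (K,) soft counts
    sx = qz.T @ x                                # (K, d): sum_n q(z|x_n) x_n
    sxx = np.einsum("bk,bi,bj->kij", qz, x, x)   # (K, d, d): sum_n q(z|x_n) x_n x_n^T
    return n, sx, sxx

def params_from_stats(n, sx, sxx):
    """Hypothetical helper: parameters from the accumulated moments."""
    mu = sx / n[:, None]                                             # mu_z = E_q[x]
    cov = sxx / n[:, None, None] - np.einsum("ki,kj->kij", mu, mu)  # E_q[x x^T] - mu mu^T
    kR = np.linalg.inv(np.linalg.cholesky(cov))                      # kR_z = inv(chol(cov_z))
    return mu, cov, kR
```

Because each statistic is a sum over samples, statistics from several mini-batches can simply be added before computing the parameters.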

class gluonts.nursery.gmm_tpp.gmm_base.GMMTrainer(model, pseudo_count=0.1, jitter=1e-06)

Bases: object

Trainer based on M-step summary statistics. It can accumulate statistics from several mini-batches and then perform a single full-batch update.
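A sketch of the accumulate-then-update pattern, under stated assumptions: the class name `AccumulatingTrainer` is hypothetical, not GMMTrainer's API, and the roles of pseudo_count and jitter are guesses. Here pseudo_count is added to the soft cluster counts (smoothing the mixture weights and guarding against empty clusters), and jitter is added to the covariance diagonal before the Cholesky factorization.

```python
import numpy as np

class AccumulatingTrainer:
    """Hypothetical sketch, not gluonts' GMMTrainer: sums per-batch statistics,
    then performs one full-batch parameter update."""

    def __init__(self, num_clusters, input_dim, pseudo_count=0.1, jitter=1e-6):
        self.pseudo_count = pseudo_count
        self.jitter = jitter
        self.n = np.zeros(num_clusters)
        self.sx = np.zeros((num_clusters, input_dim))
        self.sxx = np.zeros((num_clusters, input_dim, input_dim))

    def add(self, x, qz):
        # The statistics are sums over samples, so mini-batch contributions add up.
        self.n += qz.sum(axis=0)
        self.sx += qz.T @ x
        self.sxx += np.einsum("bk,bi,bj->kij", qz, x, x)

    def update(self):
        n = self.n + self.pseudo_count   # assumed: pseudo-count smoothing of soft counts
        log_prior = np.log(n / n.sum())
        mu = self.sx / n[:, None]
        cov = self.sxx / n[:, None, None] - np.einsum("ki,kj->kij", mu, mu)
        cov += self.jitter * np.eye(cov.shape[-1])  # assumed: diagonal jitter for a stable Cholesky
        kR = np.linalg.inv(np.linalg.cholesky(cov))
        # Reset the accumulators for the next full-batch pass.
        self.n[:] = 0
        self.sx[:] = 0
        self.sxx[:] = 0
        return log_prior, mu, kR
```

Adding two mini-batches and calling update() once gives the same parameters as adding the concatenated batch, which is what makes a full-batch update possible without holding all the data in memory.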


Similar to enumerate, but prepends the elapsed time since the loop started.
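A minimal sketch of such a generator. The name `enumerate_elapsed` is hypothetical, since the documented function's name is not shown here. It also assumes the elapsed time, in seconds, takes the place of enumerate's index, yielding (elapsed, item) pairs.

```python
import time

def enumerate_elapsed(iterable):
    """Hypothetical sketch: yield (elapsed_seconds, item) for each item."""
    start = time.perf_counter()   # monotonic clock, unaffected by system time changes
    for item in iterable:
        yield time.perf_counter() - start, item
```

Usage: `for t, batch in enumerate_elapsed(batches): ...` makes it easy to log wall-clock progress inside a training loop.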

gluonts.nursery.gmm_tpp.gmm_base.infer_lambda(model, *_, xmin, xmax)

Infers lambda and the intercept by linear fitting at the base points.
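The docstring does not say which quantity is fitted, so the following is only a generic illustration of a least-squares line fit over base points in [xmin, xmax]. The helper `fit_line_at_base_points`, the callable `f` standing in for the model-derived quantity, and the evenly spaced grid are all assumptions.

```python
import numpy as np

def fit_line_at_base_points(f, xmin, xmax, num_points=10):
    """Hypothetical helper: least-squares fit f(x) ~ lam * x + intercept on a grid."""
    base = np.linspace(xmin, xmax, num_points)          # assumed evenly spaced base points
    lam, intercept = np.polyfit(base, f(base), deg=1)   # coefficients, highest degree first
    return lam, intercept
```

np.polyfit with deg=1 returns the slope before the intercept, so the two values unpack directly into lambda and the intercept.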