gluonts.nursery.gmm_tpp.gmm_base module
- class gluonts.nursery.gmm_tpp.gmm_base.GMMModel(*data_template, num_clusters, log_prior_=None, mu_=None, kR_=None, lr_mult=None, ctx=None, hybridize=None)
Bases: mxnet.gluon.block.HybridBlock
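Variational lower bound on the log-marginal likelihood. The prior is uniform over the K = num_clusters components, and d = input_dim:

$$\log p(x) \ge \mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z) + \log p(z) - \log q(z \mid x)\right]$$

$$p(z) = \mathrm{Unif}\{1, \dots, K\}, \qquad p(x \mid z) = \mathcal{N}\left(x;\ \mu_z,\ (\mathrm{kR}_z^\top \mathrm{kR}_z)^{-1}\right)$$

$$q(z \mid x) = \operatorname{softmax}_z\left(-\tfrac{1}{2}\lVert \mathrm{kR}_z (x - \mu_z)\rVert^2 + \log\det \mathrm{kR}_z - \tfrac{d}{2}\log(2\pi)\right)$$

The M-step updates follow. Each expectation is a q(z|x)-weighted average over the data:

$$\mu_z = \mathbb{E}_{q(z \mid x)}[x], \qquad \mathrm{cov}_z = \mathbb{E}_{q(z \mid x)}[x x^\top] - \mu_z \mu_z^\top, \qquad \mathrm{kR}_z = \operatorname{chol}(\mathrm{cov}_z)^{-1}$$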
Shapes:
  - x: (batch_size, input_dim)
  - log_marg: (batch_size,)
  - qz: (batch_size, num_clusters)
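The formulas above can be made concrete with a minimal NumPy sketch of the E-step. It computes q(z|x) and log_marg, with the listed shapes, from cluster means `mu` and factors `kR`. This is not the GMMModel implementation, which is an MXNet HybridBlock. The function name `gmm_e_step` and the `(num_clusters, input_dim, input_dim)` layout of `kR` are assumptions for illustration.

```python
import numpy as np


def gmm_e_step(x, mu, kR):
    """Hypothetical helper: log p(x) and q(z|x) for a GMM with uniform p(z).

    x:  (batch_size, input_dim)
    mu: (num_clusters, input_dim)
    kR: (num_clusters, input_dim, input_dim), assumed layout; precision = kR_z^T kR_z
    """
    num_clusters, d = mu.shape
    diff = x[:, None, :] - mu[None, :, :]               # (batch, K, d)
    white = np.einsum("kij,nkj->nki", kR, diff)          # kR_z (x - mu_z), (batch, K, d)
    logdet = np.log(np.abs(np.linalg.det(kR)))           # log det kR_z, (K,)
    log_px_z = (-0.5 * np.sum(white ** 2, axis=-1)
                + logdet - 0.5 * d * np.log(2 * np.pi))  # log p(x|z), (batch, K)
    log_joint = log_px_z - np.log(num_clusters)          # + log p(z) for a uniform prior
    # Log-sum-exp over clusters gives the log-marginal log p(x).
    m = log_joint.max(axis=1, keepdims=True)
    log_marg = (m + np.log(np.exp(log_joint - m).sum(axis=1, keepdims=True)))[:, 0]
    qz = np.exp(log_joint - log_marg[:, None])           # softmax over clusters
    return log_marg, qz
```

Because q(z|x) here is the exact posterior, the bound is tight and log_marg equals the ELBO. The uniform log p(z) term cancels inside the softmax, which is why it does not appear in the q(z|x) formula.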
- class gluonts.nursery.gmm_tpp.gmm_base.GMMTrainer(model, pseudo_count=0.1, jitter=1e-06)
Bases: object
Trainer based on M-step summary statistics. Statistics from several mini-batches can be accumulated and then applied as a single full-batch update.
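The accumulation idea can be shown with a minimal NumPy sketch. It sums the q-weighted statistics Σq, Σq·x and Σq·x xᵀ over mini-batches, then runs one M-step that returns the means and the inverse Cholesky factors kR. The class `SufficientStatsTrainer` is hypothetical and is not GMMTrainer's API. How `pseudo_count` and `jitter` enter the update is also an assumption, not taken from the GMMTrainer source: here `pseudo_count` is extra mass in each count and `jitter` is a ridge on each covariance diagonal.

```python
import numpy as np


class SufficientStatsTrainer:
    """Hypothetical sketch: full-batch M-step from accumulated mini-batch statistics."""

    def __init__(self, num_clusters, input_dim, pseudo_count=0.1, jitter=1e-6):
        self.pseudo_count = pseudo_count  # assumed role: extra mass added to each cluster count
        self.jitter = jitter              # assumed role: ridge added to each covariance diagonal
        self.s0 = np.zeros(num_clusters)                          # sum_n q(z|x_n)
        self.s1 = np.zeros((num_clusters, input_dim))             # sum_n q(z|x_n) x_n
        self.s2 = np.zeros((num_clusters, input_dim, input_dim))  # sum_n q(z|x_n) x_n x_n^T

    def add(self, x, qz):
        """Accumulate one mini-batch; x: (batch, d), qz: (batch, K)."""
        self.s0 += qz.sum(axis=0)
        self.s1 += qz.T @ x
        self.s2 += np.einsum("nk,ni,nj->kij", qz, x, x)

    def m_step(self):
        """Return (mu, cov, kR) with kR_z = inv(chol(cov_z))."""
        d = self.s1.shape[1]
        n = self.s0 + self.pseudo_count                   # smoothed counts, (K,)
        mu = self.s1 / n[:, None]                         # (K, d)
        cov = self.s2 / n[:, None, None] - np.einsum("ki,kj->kij", mu, mu)
        cov = cov + self.jitter * np.eye(d)               # keep each cov_z positive definite
        chol = np.linalg.cholesky(cov)                    # lower factor L_z: cov_z = L_z L_z^T
        kR = np.linalg.inv(chol)                          # so kR_z^T kR_z = cov_z^{-1}
        return mu, cov, kR
```

The statistics are plain sums, so adding two half-batches produces the same update as adding the full batch in one call.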