# Source code for gluonts.nursery.spliced_binned_pareto.gaussian_model
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions.normal import Normal
class GaussianModel(nn.Module):
    r"""
    Model to learn a univariate Gaussian distribution.

    Arguments
    ---------
    mu: Mean of the Gaussian distribution
    sigma: Standard deviation of the Gaussian distribution
    device: The torch.device to use, typically cpu or gpu id
    """

    def __init__(self, mu, sigma, device=None):
        super().__init__()
        # Always define ``device`` (None when not given) so reading
        # ``self.device`` never raises AttributeError.
        self.device = device
        if device is not None:
            mu = mu.to(device)
            sigma = sigma.to(device)
        self.mu = mu
        self.sigma = sigma
        self.distr = Normal(self.mu, self.sigma)

    def to_device(self, device):
        """
        Moves members to a specified torch.device.

        Tensor-valued parameters (``mu``, ``sigma`` and, if present, the
        mini-batch parameters ``mu_batch`` / ``sigma_batch``) are moved, and
        the current distribution is rebuilt on ``device``.
        """
        self.device = device
        for name in ("mu", "sigma", "mu_batch", "sigma_batch"):
            value = getattr(self, name, None)
            # mu/sigma may be plain Python numbers when no device was given.
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.to(device))
        # Normal keeps its parameters as tensors (loc/scale); rebuild from
        # those so a batched distribution stays batched.
        self.distr = Normal(
            self.distr.loc.to(device), self.distr.scale.to(device)
        )

    def forward(self, x):
        """
        Takes input x as new distribution parameters.

        Element/column 0 of ``x`` is the mean; element/column 1 is passed
        through softplus to obtain a strictly positive standard deviation.

        Parameters
        ----------
        x:
            Either a 1-D tensor ``[mu, raw_sigma]`` (single distribution), or
            a tensor with more than one dimension whose ``x[:, 0]`` and
            ``x[:, 1]`` hold per-sample means and raw scales (mini-batch;
            presumably shape ``(batch, 2)`` -- confirm against callers).

        Returns
        -------
        The updated ``Normal`` distribution (batched in mini-batch mode).
        """
        if x.dim() > 1:
            # Mini-batching: keep per-sample parameters for log_prob and
            # return a distribution that actually reflects them.
            self.mu_batch = x[:, 0]
            self.sigma_batch = F.softplus(x[:, 1])
            self.distr = Normal(self.mu_batch, self.sigma_batch)
        else:
            # Single distribution: update the scale as well as the mean.
            self.mu = x[0]
            self.sigma = F.softplus(x[1])
            self.distr = Normal(self.mu, self.sigma)
        return self.distr

    def log_prob(self, x):
        """
        Log-likelihood of the observations ``x`` under the current parameters.

        ``x`` is flattened to 1-D. A single observation is scored with the
        current distribution ``self.distr``. For several observations,
        element ``i`` is scored under ``Normal(mu_batch[i], sigma_batch[i])``
        from the most recent mini-batch ``forward`` call. The model's stored
        parameters are not modified.

        Returns
        -------
        A 1-D tensor with one log-likelihood per observation.

        Raises
        ------
        IndexError
            If there are more observations than mini-batch parameters.
        """
        x = x.view(x.shape.numel())
        n = x.shape[0]
        if n == 0:
            return x.new_empty(0)
        if n == 1:
            return self.distr.log_prob(x[0]).view(1)
        if self.mu_batch.shape[0] < n:
            raise IndexError(
                f"log_prob got {n} observations but only "
                f"{self.mu_batch.shape[0]} mini-batch parameters"
            )
        # Score all observations in one vectorized call; observation i pairs
        # with the i-th mini-batch mean and scale.
        batch_distr = Normal(self.mu_batch[:n], self.sigma_batch[:n])
        return batch_distr.log_prob(x)