# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from functools import partial
from typing import Dict, Optional, Tuple, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Distribution, Beta, constraints

from gluonts.core.component import validated
from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.modules.lambda_layer import LambdaLayer


class QuantileLayer(nn.Module):
    r"""
    Implicit Quantile Layer from the paper ``IQN for Distributional
    Reinforcement Learning`` (https://arxiv.org/abs/1806.06923)
    by Dabney et al. 2018.
    """

def __init__(self, num_output: int, cos_embedding_dim: int = 128):
super().__init__()
self.output_layer = nn.Sequential(
nn.Linear(cos_embedding_dim, cos_embedding_dim),
nn.PReLU(),
nn.Linear(cos_embedding_dim, num_output),
)
self.register_buffer("integers", torch.arange(0, cos_embedding_dim))

    def forward(self, tau: torch.Tensor) -> torch.Tensor:  # tau: [B, T]
cos_emb_tau = torch.cos(tau.unsqueeze(-1) * self.integers * torch.pi)
return self.output_layer(cos_emb_tau)
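

# Shape sketch (illustrative only, not part of the library API): for `tau` of
# shape [batch, time], QuantileLayer builds the cosine embedding
# cos(pi * i * tau) for i = 0, ..., cos_embedding_dim - 1, giving a tensor of
# shape [batch, time, cos_embedding_dim], and then maps it through
# Linear -> PReLU -> Linear to [batch, time, num_output]:
#
#     layer = QuantileLayer(num_output=32, cos_embedding_dim=128)
#     taus = torch.rand(8, 24)  # [batch, time]
#     emb = layer(taus)         # [8, 24, 32]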


class ImplicitQuantileModule(nn.Module):
    r"""
    Implicit Quantile Network from the paper ``IQN for Distributional
    Reinforcement Learning`` (https://arxiv.org/abs/1806.06923)
    by Dabney et al. 2018.
    """

def __init__(
self,
in_features: int,
args_dim: Dict[str, int],
domain_map: Callable[..., Tuple[torch.Tensor]],
concentration1: float = 1.0,
concentration0: float = 1.0,
output_domain_map=None,
cos_embedding_dim: int = 64,
):
super().__init__()
self.output_domain_map = output_domain_map
self.domain_map = domain_map
self.beta = Beta(
concentration1=concentration1, concentration0=concentration0
)
self.quantile_layer = QuantileLayer(
in_features, cos_embedding_dim=cos_embedding_dim
)
self.output_layer = nn.Sequential(
nn.Linear(in_features, in_features), nn.PReLU()
)
self.proj = nn.ModuleList(
[nn.Linear(in_features, dim) for dim in args_dim.values()]
)

    def forward(self, inputs: torch.Tensor):
if self.training:
taus = self.beta.sample(sample_shape=inputs.shape[:-1]).to(
inputs.device
)
else:
taus = torch.rand(size=inputs.shape[:-1], device=inputs.device)
emb_taus = self.quantile_layer(taus)
emb_inputs = inputs * (1.0 + emb_taus)
emb_outputs = self.output_layer(emb_inputs)
outputs = [proj(emb_outputs).squeeze(-1) for proj in self.proj]
if self.output_domain_map is not None:
outputs = [self.output_domain_map(output) for output in outputs]
return (*self.domain_map(*outputs), taus)
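

# Usage sketch (illustrative only; the sizes and inputs below are assumptions,
# not from the GluonTS docs): during training the taus are drawn from the Beta
# prior, at inference from Uniform(0, 1); the input features are modulated by
# (1 + embedded taus) and projected to the quantile value at each tau:
#
#     module = ImplicitQuantileModule(
#         in_features=32,
#         args_dim={"quantile_function": 1},
#         domain_map=lambda *args: args,
#     )
#     values, taus = module(torch.randn(8, 24, 32))  # both of shape [8, 24]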


class ImplicitQuantileNetwork(Distribution):
    r"""
    Distribution class for the Implicit Quantile Network from which we can
    sample or compute the quantile loss.

Parameters
----------
outputs
Outputs from the Implicit Quantile Network.
    taus
        Tensor of random numbers drawn from the Beta (training) or Uniform
        (inference) distribution, one per corresponding output.
"""
    arg_constraints: Dict[str, constraints.Constraint] = {}

def __init__(
self, outputs: torch.Tensor, taus: torch.Tensor, validate_args=None
):
self.taus = taus
self.outputs = outputs
super().__init__(
batch_shape=outputs.shape, validate_args=validate_args
)

    @torch.no_grad()
def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
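        # Each forward pass of the IQN draws fresh random taus, so the stored
        # quantile values are themselves a draw from the predictive
        # distribution; `sample_shape` is therefore ignored here.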
return self.outputs

    def quantile_loss(self, value: torch.Tensor) -> torch.Tensor:
# penalize by tau for under-predicting
# and by 1-tau for over-predicting
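        # e.g. (illustrative numbers): with tau = 0.9, a predicted quantile of
        # 2.0 and a target of 5.0 (under-prediction) the loss is
        # 0.9 * 3.0 = 2.7, while a target of 1.0 (over-prediction) gives
        # (0.9 - 1.0) * (1.0 - 2.0) = 0.1.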
return (self.taus - (value < self.outputs).float()) * (
value - self.outputs
)


class ImplicitQuantileNetworkOutput(DistributionOutput):
    r"""
    DistributionOutput class for the IQN from the paper ``Probabilistic Time
    Series Forecasting with Implicit Quantile Networks``
    (https://arxiv.org/abs/2107.03743) by Gouttes et al. 2021.

Parameters
----------
output_domain
Optional domain mapping of the output. Can be "positive", "unit"
or None.
concentration1
Alpha parameter of the Beta distribution when sampling the taus
during training.
concentration0
Beta parameter of the Beta distribution when sampling the taus
during training.
cos_embedding_dim
The embedding dimension for the taus embedding layer of IQN.
Default is 64.
"""
distr_cls = ImplicitQuantileNetwork
args_dim = {"quantile_function": 1}
@validated()
def __init__(
self,
output_domain: Optional[str] = None,
concentration1: float = 1.0,
concentration0: float = 1.0,
cos_embedding_dim: int = 64,
) -> None:
super().__init__()
self.concentration1 = concentration1
self.concentration0 = concentration0
self.cos_embedding_dim = cos_embedding_dim
if output_domain in ["positive", "unit"]:
output_domain_map_func = {
"positive": F.softplus,
"unit": partial(F.softmax, dim=-1),
}
self.output_domain_map = output_domain_map_func[output_domain]
else:
self.output_domain_map = None

    def get_args_proj(self, in_features: int) -> nn.Module:
return ImplicitQuantileModule(
in_features=in_features,
args_dim=self.args_dim,
output_domain_map=self.output_domain_map,
domain_map=LambdaLayer(self.domain_map),
concentration1=self.concentration1,
concentration0=self.concentration0,
cos_embedding_dim=self.cos_embedding_dim,
)

    @classmethod
def domain_map(cls, *args):
return args

    def distribution(
self, distr_args, loc=0, scale=None
) -> ImplicitQuantileNetwork:
(outputs, taus) = distr_args
if scale is not None:
outputs = outputs * scale
if loc is not None:
outputs = outputs + loc
        return self.distr_cls(outputs=outputs, taus=taus)

    @property
def event_shape(self):
return ()

    def loss(
self,
target: torch.Tensor,
distr_args: Tuple[torch.Tensor, ...],
loc: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
distribution = self.distribution(distr_args, loc=loc, scale=scale)
return distribution.quantile_loss(target)
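

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only; the shapes, sizes, and
    # random inputs below are made-up assumptions, not part of the library).
    iqn_output = ImplicitQuantileNetworkOutput(output_domain="positive")
    args_proj = iqn_output.get_args_proj(in_features=32)

    decoder_output = torch.randn(8, 24, 32)  # [batch, time, in_features]
    distr_args = args_proj(decoder_output)   # (quantile values, taus)
    distr = iqn_output.distribution(distr_args)

    target = torch.rand(8, 24)
    # Pinball loss per (batch, time) position, averaged into a scalar.
    print(distr.quantile_loss(target).mean())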