Source code for gluonts.transform._base

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import abc
from typing import Callable, Iterable, Iterator, List

from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.env import env


[docs]class Transformation(metaclass=abc.ABCMeta): """ Base class for all Transformations. A Transformation processes works on a stream (iterator) of dictionaries. """ @abc.abstractmethod def __call__( self, data_it: Iterable[DataEntry], is_train: bool ) -> Iterable[DataEntry]: pass
[docs] def chain(self, other: "Transformation") -> "Chain": return Chain([self, other])
def __add__(self, other: "Transformation") -> "Chain": return self.chain(other)
[docs] def apply( self, dataset: Dataset, is_train: bool = True ) -> "TransformedDataset": return TransformedDataset(dataset, self, is_train=is_train)
[docs]class Chain(Transformation): """ Chain multiple transformations together. """ @validated() def __init__(self, trans: List[Transformation]) -> None: self.transformations: List[Transformation] = [] for transformation in trans: # flatten chains if isinstance(transformation, Chain): self.transformations.extend(transformation.transformations) else: self.transformations.append(transformation) def __call__( self, data_it: Iterable[DataEntry], is_train: bool ) -> Iterable[DataEntry]: tmp = data_it for t in self.transformations: tmp = t(tmp, is_train) return tmp
[docs]class TransformedDataset(Dataset): """ A dataset that corresponds to applying a list of transformations to each element in the base_dataset. This only supports SimpleTransformations, which do the same thing at prediction and training time. Parameters ---------- base_dataset Dataset to transform transformations List of transformations to apply """ def __init__( self, base_dataset: Dataset, transformation: Transformation, is_train=True, ) -> None: self.base_dataset = base_dataset self.transformation = transformation self.is_train = is_train def __len__(self): # NOTE this is unsafe when transformations are run with is_train = True # since some transformations may not be deterministic (instance splitter) return sum(1 for _ in self) def __iter__(self) -> Iterator[DataEntry]: yield from self.transformation( self.base_dataset, is_train=self.is_train )
[docs]class Identity(Transformation): def __call__( self, data_it: Iterable[DataEntry], is_train: bool ) -> Iterable[DataEntry]: return data_it
[docs]class MapTransformation(Transformation): """ Base class for Transformations that returns exactly one result per input in the stream. """ def __call__( self, data_it: Iterable[DataEntry], is_train: bool ) -> Iterator: for data_entry in data_it: try: yield self.map_transform(data_entry.copy(), is_train) except Exception as e: raise e
[docs] @abc.abstractmethod def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry: pass
[docs]class SimpleTransformation(MapTransformation): """ Element wise transformations that are the same in train and test mode """
[docs] def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry: return self.transform(data)
[docs] @abc.abstractmethod def transform(self, data: DataEntry) -> DataEntry: pass
[docs]class AdhocTransform(SimpleTransformation): """ Applies a function as a transformation This is called ad-hoc, because it is not serializable. It is OK to use this for experiments and outside of a model pipeline that needs to be serialized. """ def __init__(self, func: Callable[[DataEntry], DataEntry]) -> None: self.func = func
[docs] def transform(self, data: DataEntry) -> DataEntry: return self.func(data.copy())
[docs]class FlatMapTransformation(Transformation): """ Transformations that yield zero or more results per input, but do not combine elements from the input stream. """ @validated() def __init__(self): self.max_idle_transforms = max(env.max_idle_transforms, 100) def __call__( self, data_it: Iterable[DataEntry], is_train: bool ) -> Iterator: num_idle_transforms = 0 for data_entry in data_it: num_idle_transforms += 1 for result in self.flatmap_transform(data_entry.copy(), is_train): num_idle_transforms = 0 yield result if num_idle_transforms > self.max_idle_transforms: raise Exception( f"Reached maximum number of idle transformation calls.\n" f"This means the transformation looped over " f"{self.max_idle_transforms} inputs without returning any " f"output.\nThis occurred in the following transformation:\n" f"{self}" )
[docs] @abc.abstractmethod def flatmap_transform( self, data: DataEntry, is_train: bool ) -> Iterator[DataEntry]: pass
[docs]class FilterTransformation(FlatMapTransformation): def __init__(self, condition: Callable[[DataEntry], bool]) -> None: super().__init__() self.condition = condition
[docs] def flatmap_transform( self, data: DataEntry, is_train: bool ) -> Iterator[DataEntry]: if self.condition(data): yield data