Source code for gluonts.nursery.pipeline.action

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Type

from gluonts import itertools as it


[docs]class Action:
[docs] def apply(self, stream): raise NotImplementedError
[docs] def apply_one(self, data): return list(self.apply([data]))
[docs] def apply_schema(self, schema): raise NotImplementedError
[docs] def bind(self, schema): return Bind(self, schema, self.apply_schema(schema))
[docs] def requires(self): raise NotImplementedError
def __add__(self, other): return other.__radd__(self) def __radd__(self, other): if isinstance(other, Pipeline): return Pipeline(other.actions + [self]) return Pipeline([other, self])
[docs]@dataclass class Bind(Action): action: Action input_schema: "Schema" output_schema: "Schema"
[docs] def apply(self, stream): return self.action.apply(stream)
[docs]@dataclass class Pipeline(Action): actions: List[Action]
[docs] def apply(self, stream): for action in self.actions: stream = action.apply(stream) return stream
[docs] def apply_schema(self, schema): for action in self.actions: schema = action.apply_schema(schema) return schema
def __radd__(self, other): if isinstance(other, Pipeline): return Pipeline(other.actions + self.actions) return Pipeline([other] + self.actions)
[docs]class Filter(Action):
[docs] def filter(self, data): raise NotImplementedError
[docs] def apply_schema(self, schema): return schema
[docs] def apply(self, data): return it.Filter(self.filter, data)
[docs]class Map(Action): def __call__(self, data: dict) -> dict: raise NotImplementedError
[docs] def apply(self, stream): return it.Map(self, stream)
def __radd__(self, other): if isinstance(other, Map): if isinstance(other, MapPipeline): return MapPipeline(other.actions + [self]) return MapPipeline([other, self]) return Action.__radd__(self, other)
[docs]@dataclass class MapPipeline(Pipeline, Map): def __call__(self, data): for result in self.apply_one(data): return result def __radd__(self, other): if isinstance(other, MapPipeline): return MapPipeline(other.actions + self.actions) elif isinstance(other, Map): return MapPipeline([other] + self.actions) else: return Pipeline.__radd__(self, other)
[docs]class Identity(Map): def __add__(self, other): return other def __radd__(self, other): return other def __call__(self, data): return data
[docs] def apply_schema(self, schema): return schema
[docs]class Copy(Map): def __call__(self, data): return dict(data)
[docs] def apply_schema(self, schema): return schema
[docs]@dataclass class Set(Map): name: str value: Any def __call__(self, data): data[self.name] = self.value return data
[docs] def apply_schema(self, schema): return schema.add(self.name, type(self.value))
[docs]@dataclass class SetDefault(Map): name: str value: Any def __call__(self, data): data.setdefault(self.name, self.value) return data
[docs] def apply_schema(self, schema): return schema.union(self.name, type(self.value))
[docs]@dataclass class Update(Map): fields: Dict[str, Type] def __call__(self, data): data.update(self.fields) return data
[docs] def apply_schema(self, schema): for name, value in self.fields.items(): schema = schema.add(name, type(value)) return schema
[docs]@dataclass class UpdateDefault(Map): fields: Dict[str, Type] def __call__(self, data): result = dict(self.fields) result.update(data) return result
[docs] def apply_schema(self, schema): for name, value in self.fields.items(): schema = schema.union(name, type(value)) return schema
[docs]@dataclass class Remove(Map): name: str def __call__(self, data): del data[self.name] return data
[docs] def apply_schema(self, schema): return schema.remove(self.name)
[docs]@dataclass class Move(Map): source: str target: str def __call__(self, data): data[self.target] = data.pop(self.source) return data
[docs] def apply_schema(self, schema): schema, ty = schema.pop(self.source) return schema.set(self.target, ty)