Source code for gluonts.dataset.schema.translate

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.


"""

``gluonts.dataset.schema.translate``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This module provides a ``Translator`` class, which can be used to translate
dictionaries. It is intended to be used with GluonTS datasets, to allow for
more flexibility in the input data::

    tl = Translator.parse(target="demand", feat_dynamic_real="[price]")

"""

import re
from dataclasses import dataclass, InitVar, field
from itertools import chain
from typing import Any, Dict, List, Union, Optional, ClassVar

import numpy as np
from toolz import valfilter, valmap


class Op:
    """
    Base class for the nodes of a parsed translation expression.

    Concrete subclasses evaluate themselves against an input item via
    ``__call__`` and report which input field names they read via
    ``fields``. Both hooks must be overridden.
    """

    def __call__(self, item):
        """Evaluate this expression on ``item``."""
        raise NotImplementedError()

    def fields(self):
        """Return the names of the input fields this expression reads."""
        raise NotImplementedError()
@dataclass
class Get(Op):
    """
    Extracts the field ``name`` from the input.
    """

    # Key that is looked up in the input item.
    name: str

    def __call__(self, item):
        key = self.name
        return item[key]

    def fields(self):
        # The only input field this node reads is the one it looks up.
        return [self.name]
@dataclass
class Method(Op):
    """
    Calls the value produced by ``obj`` with the positional ``args``.

    The parser builds this for the ``(...)`` suffix, e.g. ``x.transpose(1, 0)``
    becomes ``Method(GetAttr(Get("x"), "transpose"), [1, 0])``.
    """

    obj: Op
    args: list

    def __call__(self, item):
        return self.obj(item)(*self.args)

    def fields(self):
        # A call reads exactly the input fields its callee expression reads.
        return self.obj.fields()
@dataclass
class GetAttr(Op):
    """
    Invokes ``obj.name``.
    """

    obj: Op
    name: str

    def __call__(self, item):
        return getattr(self.obj(item), self.name)

    def fields(self):
        # Attribute access reads exactly the input fields ``obj`` reads.
        return self.obj.fields()
@dataclass
class GetItem(Op):
    """
    Indexes the value produced by ``obj`` with ``dims``, i.e. ``obj[dims]``.

    The parser stores a single ``int`` for ``x[i]`` and a tuple of ints for
    ``x[i, j, ...]``.
    """

    obj: Op
    dims: Union[int, tuple]

    def __call__(self, item):
        return self.obj(item)[self.dims]

    def fields(self):
        # Indexing reads exactly the input fields ``obj`` reads.
        return self.obj.fields()
@dataclass
class Stack(Op):
    """
    Evaluates every sub-expression and stacks the results with ``np.stack``.
    """

    objects: List[Op]

    def __call__(self, item):
        values = [sub(item) for sub in self.objects]
        return np.stack(values)

    def fields(self):
        # Concatenate the input fields read by each sub-expression.
        per_object = (sub.fields() for sub in self.objects)
        return chain.from_iterable(per_object)
def one_of(s):
    """
    Return a regex character class matching any single character of ``s``.

    Every character is escaped, so regex metacharacters such as ``[`` or
    ``(`` are matched literally. The characters are concatenated directly:
    inside ``[...]`` a ``|`` would be a literal member, not an alternation.
    """
    options = "".join(map(re.escape, s))
    return rf"[{options}]"
@dataclass
class Token:
    """
    A single lexeme produced by ``TokenStream.from_str``.
    """

    # Token category: the name of the regex group that matched
    # (a key of ``TokenStream.TOKENS``, e.g. "NAME" or "NUMBER").
    name: str
    # The exact source text that was matched.
    value: str
    # The ``re.Match`` object produced while lexing.
    match: Any
@dataclass
class TokenStream:
    """
    Lexed tokens plus a read cursor, consumed by the expression parser.

    ``idx`` points at the next unread token; ``pop``/``pop_if`` advance it,
    ``peek`` does not. ``len()``, ``iter()`` and ``repr()`` only consider the
    tokens that have not been consumed yet.
    """

    # Token name -> regex. Alternatives are tried in this order, so the
    # catch-all ``INVALID`` must stay last. (``PARAN_CLOSE`` is misspelled,
    # but the parser refers to it by this name, so it is kept.)
    TOKENS: ClassVar[dict] = {
        "DOT": re.escape("."),
        "COMMA": re.escape(","),
        "PAREN_OPEN": one_of("[("),
        "PARAN_CLOSE": one_of("])"),
        "NUMBER": r"\-?\d+",
        "NAME": r"\w+",
        "WHITESPACE": r"\s+",
        "INVALID": r".+",
    }
    # Combined lexer regex with one named group per entry of ``TOKENS``.
    RX: ClassVar[str] = "|".join(
        f"(?P<{name}>{pattern})" for name, pattern in TOKENS.items()
    )

    tokens: List[Token]
    # Position of the next unconsumed token. A regular field (instead of an
    # ``InitVar`` with no ``__post_init__``), so a value passed to the
    # constructor is honoured and the cursor lives on the instance.
    idx: int = 0

    @classmethod
    def from_str(cls, s):
        """
        Tokenize ``s``, discarding whitespace.

        Raises ``ValueError`` if part of ``s`` matches no valid token.
        """
        stream = cls(
            [
                Token(name, value, match)
                for match in re.finditer(cls.RX, s)
                # Exactly one named group participates in each match; drop
                # the others (None) and keep the one that matched.
                for name, value in valfilter(bool, match.groupdict()).items()
                if name != "WHITESPACE"
            ]
        )

        for token in stream:
            if token.name == "INVALID":
                raise ValueError(f"Invalid token: {token}")

        return stream

    def pop(self, ty=None, val=None):
        """
        Consume and return the next token.

        If ``ty`` (and optionally ``val``) is given, the token must match it
        according to ``check_type``.

        Raises ``ValueError`` if the stream is exhausted or the next token
        does not match. An explicit raise (rather than ``assert``) keeps the
        check active when Python runs with ``-O``.
        """
        if not self:
            raise ValueError(f"Expected {ty}, got end of input.")
        token = self.tokens[self.idx]
        if not check_type(token, ty, val):
            raise ValueError(f"Expected {ty}, got {token}.")
        self.idx += 1
        return token

    def pop_if(self, ty=None, val=None):
        """Consume and return the next token if it matches, else ``None``."""
        if self.peek(ty, val):
            return self.pop()
        return None

    def peek(self, ty=None, val=None):
        """Return the next token if it matches (without consuming it), else ``None``."""
        if self:
            token = self.tokens[self.idx]
            if check_type(token, ty, val):
                return token
        return None

    def __len__(self):
        return len(self.tokens) - self.idx

    def __repr__(self):
        return "".join(token.value for token in self.tokens[self.idx :])

    def __iter__(self):
        yield from self.tokens[self.idx :]
def check_type(token, ty, val):
    """
    Return whether ``token`` matches the expected type and value.

    ``ty`` may be ``None`` (anything matches, ``val`` is ignored), a single
    token name, or a list/tuple of acceptable token names. If ``val`` is not
    ``None``, the token's text must additionally equal ``val``. Any other
    kind of ``ty`` never matches.
    """
    if ty is None:
        return True

    if isinstance(ty, str):
        names = (ty,)
    elif isinstance(ty, (list, tuple)):
        # Compare the token against each acceptable name. The previous code
        # iterated the token itself, which raised TypeError.
        names = ty
    else:
        return False

    return token.name in names and (val is None or token.value == val)
@dataclass
class Parser:
    """
    Recursive-descent parser that turns a ``TokenStream`` into an ``Op`` tree.

    Grammar accepted (informally)::

        expr   := ( "[" expr ("," expr)* "]" | NAME ) suffix*
        suffix := "." NAME
                | "(" [ NUMBER ("," NUMBER)* ] ")"
                | "[" NUMBER ("," NUMBER)* "]"
    """

    stream: TokenStream

    def parse_number(self):
        """Consume a ``NUMBER`` token and return it as an ``int``."""
        return int(self.stream.pop("NUMBER").value)

    def parse_args(self):
        """
        Parse comma-separated integer call arguments.

        Stops *before* the closing ``)``; ``parse_invoke`` consumes it. This
        applies both to ``f()`` and to calls with arguments. Previously the
        latter consumed ``)`` here too, so ``parse_invoke`` failed on its
        second pop.
        """
        args = []

        # no args: `f()`
        if self.stream.peek("PARAN_CLOSE", ")"):
            return args

        while True:
            args.append(self.parse_number())

            if self.stream.pop_if("COMMA"):
                continue

            if self.stream.peek("PARAN_CLOSE", ")"):
                return args

            raise ValueError(f"Invalid token {self.stream.peek()}")

    def parse_getitem(self, obj):
        """Parse ``[i]`` or ``[i, j, ...]`` applied to ``obj``."""
        self.stream.pop("PAREN_OPEN", "[")

        dims = [self.parse_number()]
        while self.stream.pop_if("COMMA"):
            dims.append(self.parse_number())

        self.stream.pop("PARAN_CLOSE", "]")

        # A single index stays a plain int; several become a tuple index.
        if len(dims) == 1:
            return GetItem(obj, dims[0])
        return GetItem(obj, tuple(dims))

    def parse_dot(self, obj):
        """Parse ``.name`` applied to ``obj``."""
        self.stream.pop("DOT")
        name = self.stream.pop("NAME").value
        return GetAttr(obj, name)

    def parse_invoke(self, obj):
        """Parse ``(args...)`` applied to ``obj``."""
        self.stream.pop("PAREN_OPEN", "(")
        args = self.parse_args()
        self.stream.pop("PARAN_CLOSE", ")")
        return Method(obj, args)

    def parse_expr(self):
        """
        Parse one expression: a field name or a bracketed stack of
        expressions, followed by any number of ``.attr``, ``(...)`` or
        ``[...]`` suffixes.

        Parsing stops at the first token that cannot continue the expression
        (e.g. ``,`` or ``]`` of an enclosing stack); such tokens are left in
        the stream for the caller.
        """
        if self.stream.peek("PAREN_OPEN", "["):
            self.stream.pop()
            expr = [self.parse_expr()]
            while self.stream.pop_if("COMMA"):
                expr.append(self.parse_expr())
            self.stream.pop("PARAN_CLOSE", "]")
            obj = Stack(expr)
        else:
            token = self.stream.pop("NAME")
            obj = Get(token.value)

        while self.stream:
            if self.stream.peek("DOT"):
                obj = self.parse_dot(obj)
            elif self.stream.peek("PAREN_OPEN", "("):
                obj = self.parse_invoke(obj)
            elif self.stream.peek("PAREN_OPEN", "["):
                obj = self.parse_getitem(obj)
            else:
                break

        return obj
def parse(x: Union[str, list]) -> Op:
    """
    Parse a translation expression into an ``Op`` tree.

    A list is parsed element-wise and wrapped in a ``Stack``. A string is
    tokenized and parsed as a single expression. Any tokens left over after
    that expression raise ``ValueError`` instead of being silently dropped
    (previously ``parse("a b")`` returned ``Get("a")``).
    """
    if isinstance(x, list):
        return Stack(list(map(parse, x)))

    stream = TokenStream.from_str(x)
    expr = Parser(stream).parse_expr()

    if stream:
        # repr() of the stream shows the unconsumed remainder of the input.
        raise ValueError(f"Unexpected trailing input: {stream!r}")

    return expr
@dataclass
class Translator:
    """
    Simple translation for GluonTS Datasets.

    A given translator transforms an input dictionary (data-entry) into an
    output dictionary.

    Basic usage::

        >>> tl = Translator.parse(x="a[0]")
        >>> data = {"a": [1, 2, 3]}
        >>> assert tl(data)["x"] == 1

    A translator first copies all input fields into a new dictionary, before
    applying the translations. Thus, an empty `Translator` acts like the
    identity function for dictionaries:

        >>> identity = Translator()
        >>> data = {"a": 1, "b": 2, "c": 3}
        >>> assert identity(data) == data

    Using ``Translator.parse(...)``, one can define expressions to be applied
    to the input data. For example, ``Translator.parse(x="y")`` will write the
    value of `y` to the `x` column in the output.

    These right-hand expressions support indexing (e.g. ``y[1]``), attribute
    access (e.g. ``x.T``) and method invocation (e.g. ``y.transpose(1, 0)``).

    If ``drop`` is ``True``, input fields read by any of the expressions are
    not copied into the output (unless an expression writes that name
    back). The remaining input fields keep their original order.
    """

    # Output field name -> expression evaluated on the input item.
    fields: Dict[str, "Op"] = field(default_factory=dict)
    # Whether to omit the input fields consumed by the expressions.
    drop: bool = False

    @classmethod
    def parse(
        cls, fields: Optional[dict] = None, drop: bool = False, **kwargs_fields
    ) -> "Translator":
        """
        Build a translator from expression strings (or lists of them).

        ``fields`` and keyword arguments are merged, with keyword arguments
        taking precedence. Returns an instance of ``cls`` so subclasses
        construct themselves.
        """
        fields_ = {}
        if fields is not None:
            fields_.update(fields)
        fields_.update(kwargs_fields)

        # ``parse`` here is the module-level expression parser, not this
        # method: method names are not in scope inside the method body.
        return cls(valmap(parse, fields_), drop)

    def __call__(self, item):
        if self.drop:
            consumed = set(self.get_fields())
            # Filter in input order so the output order is deterministic
            # (set iteration order depends on string hashing).
            result = {
                key: value
                for key, value in item.items()
                if key not in consumed
            }
        else:
            result = dict(item)

        result.update({name: op(item) for name, op in self.fields.items()})
        return result

    def get_fields(self):
        """Return an iterator over all input field names read by the expressions."""
        return chain.from_iterable(op.fields() for op in self.fields.values())