# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
``gluonts.dataset.schema.translate``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This module provides a ``Translator`` class, which can be used to translate
dictionaries. It is intended to be used with GluonTS datasets, to allow for
more flexibility in the input data::

    tl = Translator.parse(target="demand", feat_dynamic_real="[price]")
"""
import re
from dataclasses import dataclass, InitVar, field
from itertools import chain
from typing import Any, Dict, List, Union, Optional, ClassVar
import numpy as np
from toolz import valfilter, valmap
class Op:
    """
    Base class of all translation operations.

    An operation is a callable that is applied to an input item. Both methods
    raise ``NotImplementedError`` here and have to be overridden by
    subclasses.
    """

    def fields(self):
        """Return the names of the input fields this operation reads."""
        raise NotImplementedError

    def __call__(self, item):
        """Apply the operation to ``item`` and return the result."""
        raise NotImplementedError
@dataclass
class Get(Op):
    """
    Looks up the entry stored under ``name`` in the input item.
    """

    name: str

    def fields(self):
        # A plain lookup depends on exactly one input field.
        return [self.name]

    def __call__(self, item):
        key = self.name
        return item[key]
@dataclass
class Method(Op):
    """
    Calls the value produced by ``obj`` with the positional arguments
    ``args``, i.e. evaluates ``obj(item)(*args)``.
    """

    obj: Op
    args: list

    def __call__(self, item):
        return self.obj(item)(*self.args)

    def fields(self):
        # The wrapped operation is stored as `obj`; there is no `op`
        # attribute on this class, so `self.op` raised AttributeError.
        return self.obj.fields()
@dataclass
class GetAttr(Op):
    """
    Invokes ``obj.name``, i.e. reads the attribute ``name`` from the value
    produced by ``obj``.
    """

    obj: Op
    name: str

    def __call__(self, item):
        return getattr(self.obj(item), self.name)

    def fields(self):
        # The wrapped operation is stored as `obj`; there is no `op`
        # attribute on this class, so `self.op` raised AttributeError.
        return self.obj.fields()
@dataclass
class GetItem(Op):
    """
    Indexes the value produced by ``obj`` with ``dims``, i.e. evaluates
    ``obj(item)[dims]``.
    """

    obj: Op
    # ``Parser.parse_getitem`` passes a single int for ``x[0]`` and a tuple of
    # ints for ``x[0, 1]``.
    dims: Union[int, tuple, List[int]]

    def __call__(self, item):
        return self.obj(item)[self.dims]

    def fields(self):
        # The wrapped operation is stored as `obj`; there is no `op`
        # attribute on this class, so `self.op` raised AttributeError.
        return self.obj.fields()
@dataclass
class Stack(Op):
    """
    Evaluates every operation in ``objects`` on the same item and combines
    the results with ``np.stack``.
    """

    objects: List[Op]

    def fields(self):
        # Lazily concatenate the field names of all member operations.
        per_member = (member.fields() for member in self.objects)
        return chain.from_iterable(per_member)

    def __call__(self, item):
        parts = [member(item) for member in self.objects]
        return np.stack(parts)
def one_of(s):
    """
    Build a regex character class that matches exactly one of the characters
    in ``s``.

    Each character is escaped, so characters with a special meaning in
    regular expressions (e.g. ``[`` or ``)``) are matched literally.
    """
    # Inside `[...]` alternatives are simply listed one after the other. A
    # "|" separator would be taken literally and make the class match "|".
    options = "".join(map(re.escape, s))
    return rf"[{options}]"
@dataclass
class Token:
    """
    A single lexical token.
    """

    # Name of the token kind, e.g. "NAME" or "NUMBER".
    name: str
    # The text the token was created from.
    value: str
    # The match object the token originates from.
    match: Any
@dataclass
class TokenStream:
    """
    A sequence of tokens together with a read position.

    ``tokens`` holds all tokens, ``idx`` is the position of the next token to
    be read. Length, iteration and ``repr`` only consider the tokens that have
    not been consumed yet.
    """

    # Token kinds and their patterns. The order matters: the alternatives are
    # tried left to right, so "INVALID" only catches what nothing else matched.
    TOKENS: ClassVar[dict] = {
        "DOT": re.escape("."),
        "COMMA": re.escape(","),
        "PAREN_OPEN": one_of("[("),
        "PARAN_CLOSE": one_of("])"),
        "NUMBER": r"\-?\d+",
        "NAME": r"\w+",
        "WHITESPACE": r"\s+",
        "INVALID": r".+",
    }
    RX: ClassVar[str] = "|".join(
        f"(?P<{name}>{pattern})" for name, pattern in TOKENS.items()
    )

    tokens: List[Token]
    # A regular field, so that an ``idx`` passed to the constructor is
    # honoured. It is excluded from comparison, so equality is still decided
    # by ``tokens`` alone.
    idx: int = field(default=0, repr=False, compare=False)

    @classmethod
    def from_str(cls, s):
        """
        Tokenize the string ``s``; whitespace is discarded.

        Raises ``ValueError`` if ``s`` contains text that matches no token
        kind.
        """
        stream = cls(
            [
                Token(name, value, match)
                for match in re.finditer(cls.RX, s)
                # Only the group that actually matched has a non-empty value.
                for name, value in valfilter(bool, match.groupdict()).items()
                if name != "WHITESPACE"
            ]
        )

        for token in stream:
            if token.name == "INVALID":
                raise ValueError(f"Invalid token: {token}")

        return stream

    def pop(self, ty=None, val=None):
        """
        Consume and return the next token.

        If ``ty`` (and optionally ``val``) is given, the token has to match.
        Raises ``ValueError`` if it does not, or if no token is left.
        """
        if self.idx >= len(self.tokens):
            expected = "a token" if ty is None else ty
            raise ValueError(f"Expected {expected}, got end of input.")

        token = self.tokens[self.idx]
        # Raise explicitly: an `assert` would be stripped under `python -O`.
        if not check_type(token, ty, val):
            raise ValueError(f"Expected {ty}, got {token}.")

        self.idx += 1
        return token

    def pop_if(self, ty=None, val=None):
        """
        Consume and return the next token if it matches, otherwise return
        ``None`` and leave the stream untouched.
        """
        if self.peek(ty, val):
            return self.pop()
        return None

    def peek(self, ty=None, val=None):
        """
        Return the next token without consuming it, or ``None`` if the stream
        is exhausted or the token does not match ``ty`` / ``val``.
        """
        if self:
            token = self.tokens[self.idx]
            if check_type(token, ty, val):
                return token
        return None

    def __len__(self):
        # Number of tokens not consumed yet; never negative, as `len()`
        # rejects negative values.
        return max(len(self.tokens) - self.idx, 0)

    def __repr__(self):
        return "".join(token.value for token in self.tokens[self.idx :])

    def __iter__(self):
        yield from self.tokens[self.idx :]
def check_type(token, ty, val):
    """
    Check whether ``token`` is of the requested kind.

    Parameters
    ----------
    token
        The token to check; its ``name`` and ``value`` attributes are read.
    ty
        ``None`` to accept any token, a token name, or a list of token names
        of which at least one has to match.
    val
        If not ``None``, the token's value additionally has to equal ``val``.

    Returns
    -------
    bool
        ``True`` if the token matches, ``False`` otherwise (including when
        ``ty`` is neither ``None``, a string nor a list).
    """
    if ty is None:
        return True

    def matches(name):
        return token.name == name and (val is None or token.value == val)

    if isinstance(ty, str):
        return matches(ty)

    if isinstance(ty, list):
        # Try each candidate name against the one token.
        return any(map(matches, ty))

    return False
@dataclass
class Parser:
    """
    Recursive-descent parser that turns a ``TokenStream`` into an ``Op``.

    An expression is either a name (``Get``) or a bracketed, comma-separated
    list of expressions (``Stack``), followed by any number of attribute
    accesses (``.name``), invocations (``(1, 2)``) and index operations
    (``[0, 1]``).
    """

    stream: TokenStream

    def parse_number(self):
        """Consume a ``NUMBER`` token and return it as ``int``."""
        return int(self.stream.pop("NUMBER").value)

    def parse_args(self):
        """
        Parse the comma-separated integer arguments of an invocation.

        The closing ``)`` is never consumed here, neither for an empty nor
        for a non-empty argument list; ``parse_invoke`` consumes it.

        Raises ``ValueError`` if an argument is followed by anything but
        ``,`` or ``)``.
        """
        args = []

        # no args: `f()`
        if self.stream.peek("PARAN_CLOSE", ")"):
            return args

        while True:
            args.append(self.parse_number())

            if self.stream.pop_if("COMMA"):
                continue

            # Only peek: popping ")" here would make `parse_invoke` fail when
            # it consumes the closing parenthesis itself.
            if self.stream.peek("PARAN_CLOSE", ")"):
                return args

            raise ValueError(f"Invalid token {self.stream.peek()}")

    def parse_getitem(self, obj):
        """
        Parse ``[i]`` or ``[i, j, ...]`` and wrap ``obj`` in a ``GetItem``.

        A single index is passed on as ``int``, several as a ``tuple``.
        """
        self.stream.pop("PAREN_OPEN", "[")

        dims = [self.parse_number()]
        while self.stream.pop_if("COMMA"):
            dims.append(self.parse_number())

        self.stream.pop("PARAN_CLOSE", "]")

        if len(dims) == 1:
            return GetItem(obj, dims[0])
        return GetItem(obj, tuple(dims))

    def parse_dot(self, obj):
        """Parse ``.name`` and wrap ``obj`` in a ``GetAttr``."""
        self.stream.pop("DOT")
        name = self.stream.pop("NAME").value
        return GetAttr(obj, name)

    def parse_invoke(self, obj):
        """Parse ``(args...)`` and wrap ``obj`` in a ``Method``."""
        self.stream.pop("PAREN_OPEN", "(")
        args = self.parse_args()
        self.stream.pop("PARAN_CLOSE", ")")
        return Method(obj, args)

    def parse_expr(self):
        """
        Parse one expression and return the resulting ``Op``.

        Parsing stops at the first token that cannot continue the expression;
        that token and everything after it is left in the stream.
        """
        if self.stream.peek("PAREN_OPEN", "["):
            # `[a, b, ...]`: stack of sub-expressions.
            self.stream.pop()

            expr = [self.parse_expr()]
            while self.stream.pop_if("COMMA"):
                expr.append(self.parse_expr())

            self.stream.pop("PARAN_CLOSE", "]")
            obj = Stack(expr)
        else:
            token = self.stream.pop("NAME")
            obj = Get(token.value)

        while self.stream:
            if self.stream.peek("DOT"):
                obj = self.parse_dot(obj)
            elif self.stream.peek("PAREN_OPEN", "("):
                obj = self.parse_invoke(obj)
            elif self.stream.peek("PAREN_OPEN", "["):
                obj = self.parse_getitem(obj)
            else:
                # Not an error at this level: inside a stack the next token
                # is the "," or "]" handled by the enclosing call.
                break

        return obj
def parse(x: Union[str, list]) -> Op:
    """
    Turn ``x`` into an ``Op``.

    A string is tokenized and parsed as a single expression; a list is
    parsed element by element and the results are combined into a ``Stack``.
    """
    if not isinstance(x, list):
        tokens = TokenStream.from_str(x)
        return Parser(tokens).parse_expr()

    return Stack([parse(element) for element in x])
@dataclass
class Translator:
    """
    Simple translation for GluonTS Datasets.

    A given translator transforms an input dictionary (data-entry) into an
    output dictionary.

    Basic usage::

        >>> tl = Translator.parse(x="a[0]")
        >>> data = {"a": [1, 2, 3]}
        >>> assert tl(data)["x"] == 1

    A translator first copies all input fields into a new dictionary, before
    applying the translations. Thus, an empty ``Translator`` acts like the
    identity function for dictionaries::

        >>> identity = Translator()
        >>> data = {"a": 1, "b": 2, "c": 3}
        >>> assert identity(data) == data

    If ``drop`` is set, the input fields that are read by the translations are
    not copied to the output.

    Using ``Translator.parse(...)``, one can define expressions to be applied
    to the input data. For example, ``Translator.parse(x="y")`` will write the
    value of ``y`` to the ``x`` column in the output.

    These right-hand expressions support indexing (e.g. ``y[1]``), attribute
    access (e.g. ``x.T``) and method invocation (e.g. ``y.transpose(1, 0)``).
    """

    fields: Dict[str, Op] = field(default_factory=dict)
    drop: bool = False

    @staticmethod
    def parse(
        fields: Optional[dict] = None, drop: bool = False, **kwargs_fields
    ) -> "Translator":
        """
        Build a ``Translator`` from unparsed expressions.

        Expressions can be given through ``fields`` and as keyword arguments;
        on a name clash the keyword argument wins.
        """
        expressions = {}
        if fields is not None:
            expressions.update(fields)
        expressions.update(kwargs_fields)

        operations = valmap(parse, expressions)
        return Translator(operations, drop)

    def __call__(self, item):
        if not self.drop:
            result = dict(item)
        else:
            # Copy only the fields that no translation reads.
            kept = item.keys() - self.get_fields()
            result = {key: item[key] for key in kept}

        # Translations always read from the unmodified input `item`.
        for name, operation in self.fields.items():
            result[name] = operation(item)

        return result

    def get_fields(self):
        """Return an iterator over the input fields read by all operations."""
        return chain.from_iterable(
            operation.fields() for operation in self.fields.values()
        )