Source code for gluonts.itertools

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from __future__ import annotations

import itertools
import math
import pickle
import random
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
    Callable,
    Collection,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    TypeVar,
    Sequence,
    Tuple,
    NamedTuple,
)
from typing_extensions import Protocol, runtime_checkable
from numpy.random import RandomState
from toolz import curry

import numpy as np


@runtime_checkable
class SizedIterable(Protocol):
    def __len__(self):
        ...

    def __iter__(self):
        ...

T = TypeVar("T")

# key / value
K = TypeVar("K")
V = TypeVar("V")

def maybe_len(obj) -> Optional[int]:
    try:
        return len(obj)
    except (NotImplementedError, AttributeError):
        return None

def prod(xs):
    """
    Compute the product of the elements of an iterable object.
    """
    p = 1
    for x in xs:
        p *= x
    return p

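# Example (illustrative sketch, not part of the module):
#
# >>> prod([2, 3, 4])
# 24
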
@dataclass
class Cyclic:
    """
    Like `itertools.cycle`, but does not store the data.
    """

    iterable: SizedIterable

    def __iter__(self):
        at_least_one = False
        while True:
            for el in self.iterable:
                at_least_one = True
                yield el
            if not at_least_one:
                break

    def stream(self):
        """
        Return a continuous stream of self that has no fixed start.

        When re-iterating ``Cyclic`` it will yield elements from the start of
        the passed ``iterable``. However, this is not always desired; e.g. in
        training we want to treat training data as an infinite stream of
        values and not start at the beginning of the dataset for each epoch.

        >>> from toolz import take
        >>> c = Cyclic([1, 2, 3, 4])
        >>> assert list(take(5, c)) == [1, 2, 3, 4, 1]
        >>> assert list(take(5, c)) == [1, 2, 3, 4, 1]

        >>> s = Cyclic([1, 2, 3, 4]).stream()
        >>> assert list(take(5, s)) == [1, 2, 3, 4, 1]
        >>> assert list(take(5, s)) == [2, 3, 4, 1, 2]
        """
        return iter(self)

    def __len__(self) -> int:
        return len(self.iterable)

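# Usage sketch (illustrative, not part of the module): cycling over a small
# list without materialising additional copies of the data.
#
# >>> from toolz import take
# >>> list(take(7, Cyclic([1, 2, 3])))
# [1, 2, 3, 1, 2, 3, 1]
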
def batcher(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """
    Groups elements from `iterable` into batches of size `batch_size`.

    >>> list(batcher("ABCDEFG", 3))
    [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]

    Unlike the grouper proposed in the documentation of itertools, `batcher`
    doesn't fill up missing values.
    """
    it: Iterator[T] = iter(iterable)

    def get_batch():
        return list(itertools.islice(it, batch_size))

    # has an empty list so that we have a 2D array for sure
    return iter(get_batch, [])

@dataclass
class Chain:
    """
    Chain multiple iterables into a single one.

    This is a thin wrapper around ``itertools.chain``.
    """

    iterables: Collection[SizedIterable]

    def __iter__(self):
        yield from itertools.chain.from_iterable(self.iterables)

    def __len__(self) -> int:
        return sum(map(len, self.iterables))

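# Usage sketch (illustrative, not part of the module): unlike
# ``itertools.chain``, ``Chain`` can be re-iterated and reports its length.
#
# >>> chained = Chain([[1, 2], [3]])
# >>> len(chained)
# 3
# >>> list(chained)
# [1, 2, 3]
# >>> list(chained)
# [1, 2, 3]
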
class _SubIndex(NamedTuple):
    """
    A nested index with two layers.
    """

    item: int
    local: int

@dataclass
class Fuse:
    """
    Fuse collections together to act as single collections.

    >>> a = [0, 1, 2]
    >>> b = [3, 4, 5]
    >>> fused = Fuse([a, b])
    >>> assert len(a) + len(b) == len(fused)
    >>> list(fused[2:5])
    [2, 3, 4]
    >>> list(fused[3:])
    [3, 4, 5]

    This is similar to ``Chain``, but also allows to select directly into the
    data. While ``Chain`` creates a single ``Iterable`` out of a set of
    ``Iterable``s, ``Fuse`` creates a single ``Collection`` out of a set of
    ``Collection``s.
    """

    collections: List[Sequence]
    _lengths: List[int] = field(default_factory=list)

    def __post_init__(self):
        if not self._lengths:
            self._lengths = list(map(len, self.collections))

        self._length = sum(self._lengths)
        self._offsets = np.cumsum(self._lengths)

    def __len__(self):
        return self._length

    def _get_range(self, start: _SubIndex, stop: _SubIndex) -> "Fuse":
        first = self.collections[start.item]

        if start.item == stop.item:
            return Fuse([first[start.local : stop.local]])

        items = []

        first = first[start.local :]
        if len(first) > 0:
            items.append(first)

        for item_index in range(start.item + 1, stop.item):
            items.append(self.collections[item_index])

        items.append(self.collections[stop.item][: stop.local])

        return Fuse(items)

    def _location_for(self, idx, side="right") -> _SubIndex:
        """
        Map global index to pair of index to collection and index within that
        collection.

        >>> fuse = Fuse([[0, 0], [1, 1]])
        >>> fuse._location_for(0)
        _SubIndex(item=0, local=0)
        >>> fuse._location_for(1)
        _SubIndex(item=0, local=1)
        >>> fuse._location_for(2)
        _SubIndex(item=1, local=0)
        >>> fuse._location_for(3)
        _SubIndex(item=1, local=1)
        """
        if idx == 0 or not self:
            return _SubIndex(0, 0)

        # When the index is out of bounds, we fall back to the last element
        if idx >= len(self):
            return _SubIndex(
                len(self.collections) - 1,
                len(self.collections[-1]),
            )

        part_no = np.searchsorted(self._offsets, idx, side)

        if part_no == 0:
            local_idx = idx
        else:
            local_idx = idx - self._offsets[part_no - 1]

        return _SubIndex(int(part_no), int(local_idx))

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            assert step is None or step == 1

            start_index = self._location_for(start)
            stop_index = self._location_for(stop)

            return self._get_range(start_index, stop_index)

        subindex = self._location_for(idx)
        return self.collections[subindex.item][subindex.local]

    def __iter__(self):
        yield from itertools.chain.from_iterable(self.collections)

    def __repr__(self):
        return f"Fuse<size={len(self)}>"

def split(xs: Sequence, indices: List[int]) -> List[Sequence]:
    """
    Split ``xs`` into subsets given ``indices``.

    >>> split("abcdef", [1, 3])
    ['a', 'bc', 'def']

    This is similar to ``numpy.split`` when passing a list of indices, but
    this version does not convert the underlying data into arrays.
    """
    return [
        xs[start:stop]
        for start, stop in zip(
            itertools.chain([None], indices),
            itertools.chain(indices, [None]),
        )
    ]

def split_into(xs: Sequence, n: int) -> Sequence:
    """
    Split ``xs`` into ``n`` parts of similar size.

    >>> split_into("abcd", 2)
    ['ab', 'cd']
    >>> split_into("abcd", 3)
    ['ab', 'c', 'd']
    """
    bucket_size, remainder = divmod(len(xs), n)

    # We need one fewer than `n`, since these become split positions.
    relative_splits = np.full(n - 1, bucket_size)
    # e.g. 10 by 3 -> 4, 3, 3
    relative_splits[:remainder] += 1

    return split(xs, np.cumsum(relative_splits))

@dataclass
class Cached:
    """
    An iterable wrapper, which caches values in a list while iterated.

    The primary use-case for this is to avoid re-computing the elements of the
    sequence, in case the inner iterable does it on demand.

    This should be used to wrap deterministic iterables, i.e. iterables where
    the data generation process is not random, and that yield the same
    elements when iterated multiple times.
    """

    iterable: SizedIterable
    provider: Iterable = field(init=False)
    consumed: list = field(default_factory=list, init=False)

    def __post_init__(self):
        # ensure we only iterate once over the iterable
        self.provider = iter(self.iterable)

    def __iter__(self):
        # Yield already provided values first
        yield from self.consumed

        # Now yield remaining elements.
        for element in self.provider:
            self.consumed.append(element)
            yield element

    def __len__(self) -> int:
        return len(self.iterable)

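# Usage sketch (illustrative, not part of the module): wrapping a generator so
# that values computed on the first pass are reused on later passes.
#
# >>> cached = Cached(x * x for x in range(3))
# >>> list(cached)
# [0, 1, 4]
# >>> list(cached)  # served from the in-memory cache
# [0, 1, 4]
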
@dataclass
class PickleCached:
    """
    A caching wrapper for ``iterable`` using ``pickle`` to store cached values
    on disk.

    See :class:`Cached` for more information.
    """

    iterable: SizedIterable
    cached: bool = False
    _path: Path = field(
        default_factory=lambda: Path(
            tempfile.NamedTemporaryFile(delete=False).name
        )
    )

    def __iter__(self):
        if not self.cached:
            with open(self._path, "wb") as tmpfile:
                for batch in batcher(self.iterable, 16):
                    pickle.dump(batch, tmpfile)
                    yield from batch
            self.cached = True
        else:
            with open(self._path, "rb") as tmpfile:
                while True:
                    try:
                        yield from pickle.load(tmpfile)
                    except EOFError:
                        return

    def __del__(self):
        self._path.unlink()

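# Usage sketch (illustrative, not part of the module): like ``Cached``, but
# the first pass spills values to a temporary file on disk instead of memory.
#
# >>> on_disk = PickleCached([1, 2, 3])
# >>> list(on_disk)
# [1, 2, 3]
# >>> list(on_disk)  # read back from the pickle file
# [1, 2, 3]
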
@dataclass
class PseudoShuffled:
    """
    Yield items from a given iterable in a pseudo-shuffled order.
    """

    iterable: SizedIterable
    shuffle_buffer_length: int

    def __iter__(self):
        shuffle_buffer = []

        for element in self.iterable:
            shuffle_buffer.append(element)
            if len(shuffle_buffer) >= self.shuffle_buffer_length:
                yield shuffle_buffer.pop(
                    random.randrange(len(shuffle_buffer))
                )

        while shuffle_buffer:
            yield shuffle_buffer.pop(random.randrange(len(shuffle_buffer)))

    def __len__(self) -> int:
        return len(self.iterable)

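# Usage sketch (illustrative, not part of the module): the buffer only
# reorders elements locally, so the multiset of yielded values is unchanged.
#
# >>> shuffled = list(PseudoShuffled(range(6), shuffle_buffer_length=3))
# >>> sorted(shuffled)
# [0, 1, 2, 3, 4, 5]
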
# can't make this a dataclass because of pytorch-lightning assumptions
class IterableSlice:
    """
    An iterable version of `itertools.islice`, i.e. one that can be iterated
    over multiple times:

    >>> isl = IterableSlice(iter([1, 2, 3, 4, 5]), 3)
    >>> list(isl)
    [1, 2, 3]
    >>> list(isl)
    [4, 5]
    >>> list(isl)
    []

    This needs to be a class to support re-entry iteration.
    """

    def __init__(self, iterable, length):
        self.iterable = iterable
        self.length = length

    def __iter__(self):
        yield from itertools.islice(self.iterable, self.length)

class SizedIterableSlice(IterableSlice):
    """
    Same as ``IterableSlice`` but also supports `len()`:

    >>> isl = SizedIterableSlice([1, 2, 3, 4, 5], 3)
    >>> len(isl)
    3
    """

    def __len__(self):
        # NOTE: This works correctly only when self.iterable supports `len()`.
        total_len = len(self.iterable)
        return min(total_len, self.length)

@dataclass
class Map:
    fn: Callable
    iterable: SizedIterable

    def __iter__(self):
        return map(self.fn, self.iterable)

    def __len__(self):
        return len(self.iterable)

@dataclass
class StarMap:
    fn: Callable
    iterable: SizedIterable

    def __iter__(self):
        return itertools.starmap(self.fn, self.iterable)

    def __len__(self):
        return len(self.iterable)

@dataclass
class Filter:
    fn: Callable
    iterable: SizedIterable

    def __iter__(self):
        return filter(self.fn, self.iterable)

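# Usage sketch (illustrative, not part of the module): ``Map``, ``StarMap``
# and ``Filter`` mirror the builtins, but can be re-iterated because they hold
# on to the underlying iterable rather than a one-shot iterator.
#
# >>> doubled = Map(lambda x: 2 * x, [1, 2, 3])
# >>> list(doubled), list(doubled)
# ([2, 4, 6], [2, 4, 6])
# >>> list(StarMap(pow, [(2, 3), (3, 2)]))
# [8, 9]
# >>> list(Filter(lambda x: x % 2 == 0, [1, 2, 3, 4]))
# [2, 4]
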
@dataclass
class RandomYield:
    """
    Given a list of Iterables `iterables`, generate samples from them.

    When `probabilities` is given, sample iterables according to it.
    When `probabilities` is not given, sample iterables uniformly.

    When one iterable is exhausted, its sampling probability is set to 0.

    >>> from toolz import take
    >>> a = [1, 2, 3]
    >>> b = [4, 5, 6]
    >>> it = iter(RandomYield([a, b], probabilities=[1, 0]))
    >>> list(take(5, it))
    [1, 2, 3]

    >>> a = [1, 2, 3]
    >>> b = [4, 5, 6]
    >>> it = iter(RandomYield([Cyclic(a), b], probabilities=[1, 0]))
    >>> list(take(5, it))
    [1, 2, 3, 1, 2]
    """

    iterables: List[Iterable]
    probabilities: Optional[List[float]] = None
    random_state: RandomState = field(default_factory=RandomState)

    def __post_init__(self):
        if not self.probabilities:
            self.probabilities = [1.0 / len(self.iterables)] * len(
                self.iterables
            )

    def __iter__(self):
        iterators = [iter(it) for it in self.iterables]
        probs = list(self.probabilities)

        while True:
            idx = self.random_state.choice(range(len(iterators)), p=probs)
            try:
                yield next(iterators[idx])
            except StopIteration:
                probs[idx] = 0
                if sum(probs) == 0:
                    return
                probs = [prob / sum(probs) for prob in probs]

def rows_to_columns(
    rows: Sequence[Dict[K, V]],
    wrap: Callable[[Sequence[V]], Sequence[V]] = lambda x: x,
) -> Dict[K, Sequence[V]]:
    """
    Transpose rows of dicts, to one dict containing columns.

    >>> rows_to_columns([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])
    {'a': [1, 3], 'b': [2, 4]}

    This can also be understood as stacking the values of each dict onto each
    other.
    """
    if not rows:
        return {}

    column_names = rows[0].keys()

    return {
        column_name: wrap([row[column_name] for row in rows])
        for column_name in column_names
    }

def columns_to_rows(columns: Dict[K, Sequence[V]]) -> List[Dict[K, V]]:
    """
    Transpose column-orientation to row-orientation.

    >>> columns_to_rows({'a': [1, 3], 'b': [2, 4]})
    [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
    """
    if not columns:
        return []

    column_names = columns.keys()

    return [
        dict(zip(column_names, values)) for values in zip(*columns.values())
    ]

def roundrobin(*iterables):
    """
    `roundrobin('ABC', 'D', 'EF') --> A D E B F C`

    Taken from: https://docs.python.org/3/library/itertools.html#recipes.
    """
    # Recipe credited to George Sakkis
    num_active = len(iterables)
    nexts = itertools.cycle(iter(it).__next__ for it in iterables)
    while num_active:
        try:
            for next in nexts:
                yield next()
        except StopIteration:
            # Remove the iterator we just exhausted from the cycle.
            num_active -= 1
            nexts = itertools.cycle(itertools.islice(nexts, num_active))

def partition(
    it: Iterable[T], fn: Callable[[T], bool]
) -> Tuple[List[T], List[T]]:
    """
    Partition `it` into two lists given predicate `fn`.

    This is similar to the recipe defined in Python's `itertools` docs,
    however this method consumes the iterator directly and returns lists
    instead of iterators.
    """
    left, right = [], []

    for val in it:
        if fn(val):
            left.append(val)
        else:
            right.append(val)

    return left, right

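# Usage sketch (illustrative, not part of the module): elements satisfying the
# predicate go to the left list, everything else to the right one.
#
# >>> partition([1, 2, 3, 4, 5], lambda x: x % 2 == 0)
# ([2, 4], [1, 3, 5])
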
def select(keys, source: dict, ignore_missing: bool = False) -> dict:
    """
    Select a subset of the `source` dictionary, given `keys`.

    >>> d = {"a": 1, "b": 2, "c": 3}
    >>> select(["a", "b"], d)
    {'a': 1, 'b': 2}
    """
    result = {}

    for key in keys:
        try:
            result[key] = source[key]
        except KeyError:
            if not ignore_missing:
                raise

    return result

def trim_nans(xs, trim="fb"):
    """
    Trim the leading and/or trailing `NaNs` from a 1-D array or sequence.

    Like ``np.trim_zeros`` but for `NaNs`.
    """
    trim = trim.lower()

    start = None
    end = None

    if "f" in trim:
        for start, val in enumerate(xs):
            if not math.isnan(val):
                break

    if "b" in trim:
        for end in range(len(xs), -1, -1):
            if not math.isnan(xs[end - 1]):
                break

    return xs[start:end]

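# Usage sketch (illustrative, not part of the module): trimming both ends by
# default, analogous to ``np.trim_zeros``.
#
# >>> from math import nan
# >>> trim_nans([nan, nan, 1.0, 2.0, nan])
# [1.0, 2.0]
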
def inverse(dct: Dict[K, V]) -> Dict[V, K]:
    """
    Invert a dictionary; keys become values and values become keys.
    """
    return {value: key for key, value in dct.items()}

_no_default = object()
@curry
def pluck_attr(seq, name, default=_no_default):
    """
    Get attribute ``name`` from elements in ``seq``.
    """
    if default is _no_default:
        return [getattr(el, name) for el in seq]

    return [getattr(el, name, default) for el in seq]

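# Usage sketch (illustrative, not part of the module): ``Point`` is a
# hypothetical namedtuple used only for demonstration.
#
# >>> from collections import namedtuple
# >>> Point = namedtuple("Point", ["x", "y"])
# >>> pluck_attr([Point(1, 2), Point(3, 4)], "x")
# [1, 3]
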
def power_set(iterable):
    """
    Generate all possible subsets of the given iterable, as tuples.

    >>> list(power_set(["a", "b"]))
    [(), ('a',), ('b',), ('a', 'b')]

    Adapted from
    https://docs.python.org/3/library/itertools.html#itertools-recipes
    """
    return itertools.chain.from_iterable(
        itertools.combinations(iterable, r) for r in range(len(iterable) + 1)
    )

def join_items(left, right, how="outer", default=None):
    """
    Iterate over joined dictionary items.

    Yields triples of `key`, `left_value`, `right_value`.

    Similar to SQL join statements, the join behaviour is controlled by
    ``how``:

    * ``outer`` (default): use keys from left and right
    * ``inner``: only use keys which appear in both left and right
    * ``strict``: like ``inner``, but raises an error if the keys mismatch
    * ``left``: use only keys from ``left``
    * ``right``: use only keys from ``right``

    If a key is not present in either input, ``default`` is chosen instead.
    """
    if how == "outer":
        keys = {**left, **right}
    elif how == "strict":
        assert left.keys() == right.keys()
        keys = left.keys()
    elif how == "inner":
        keys = left.keys() & right.keys()
    elif how == "left":
        keys = left.keys()
    elif how == "right":
        keys = right.keys()
    else:
        raise ValueError(f"Unknown how={how}.")

    for key in keys:
        yield key, left.get(key, default), right.get(key, default)

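# Usage sketch (illustrative, not part of the module): an ``outer`` join keeps
# keys from both sides and fills the missing side with ``default``.
#
# >>> left = {"a": 1, "b": 2}
# >>> right = {"b": 20, "c": 30}
# >>> sorted(join_items(left, right))
# [('a', 1, None), ('b', 2, 20), ('c', None, 30)]
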
def replace(values: Sequence[T], idx: int, value: T) -> Sequence[T]:
    """
    Replace value at index ``idx`` with ``value``.

    Like ``setitem``, but for tuples.

    >>> replace((1, 2, 3, 4), -1, 99)
    (1, 2, 3, 99)
    """
    xs = list(values)
    xs[idx] = value

    return type(values)(xs)  # type: ignore

def chop(
    at: int,
    take: int,
) -> slice:
    """
    Create a slice using an index ``at`` and an amount ``take``.

    >>> x = [0, 1, 2, 3, 4]
    >>> x[chop(at=1, take=2)]
    [1, 2]
    >>>
    >>> x[chop(at=-2, take=2)]
    [3, 4]
    >>> x[chop(at=3, take=-2)]
    [1, 2]
    """
    if at < 0 and take + at <= 0:
        return slice(at, None)

    if take < 0:
        return slice(at + take, at)

    return slice(at, at + take)