# Source code for gluonts.zebras._time_frame

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from __future__ import annotations

import dataclasses
from operator import itemgetter
from typing import cast, Optional, List, Union, Collection, Any, Mapping

import numpy as np
from toolz import first, valmap, dissoc, merge, itemmap, take

from gluonts import maybe
from gluonts.itertools import (
    pluck_attr,
    columns_to_rows,
    rows_to_columns,
    select,
    join_items,
)

from ._base import Pad, TimeBase
from ._freq import Freq
from ._period import Periods, Period, period
from ._repr import html_table
from ._util import AxisView, pad_axis, _replace


@dataclasses.dataclass
class TimeFrame(TimeBase):
    """Container for one or more time series that share a time axis.

    Attributes
    ----------
    columns
        Mapping from column name to array-like data. Every column must have
        ``length`` entries along its time dimension.
    index
        Optional ``Periods`` with one entry per time step.
    static
        Mapping of values that do not vary over time.
    length
        Number of time steps.
    tdims
        Mapping from column name to the axis that represents time. Columns
        without an entry use ``default_tdim``.
    metadata
        Optional metadata, passed along to derived objects.
    default_tdim
        Time axis assumed for columns not listed in ``tdims``.
    _pad
        Number of padded values at the start (``left``) and end (``right``)
        of the time axis.
    """

    columns: dict
    index: Optional[Periods]
    static: dict
    length: int
    tdims: dict
    metadata: Optional[dict] = None
    default_tdim: int = -1
    _pad: Pad = Pad()

    def __post_init__(self):
        # Build a fresh dict instead of calling ``setdefault`` on
        # ``self.tdims`` directly: that dict may be owned by the caller or
        # shared with another frame (e.g. one derived through ``_replace``),
        # and filling in defaults must not leak into it.
        tdims = dict(self.tdims)
        for column in self.columns:
            tdims.setdefault(column, self.default_tdim)
        self.tdims = tdims

        for column in self.columns:
            assert len(self._time_view(column)) == self.length, (
                f"Column {column!r} has incorrect length in time dimension. "
                f"Expected: {len(self)}, got {len(self._time_view(column))}."
            )

        assert maybe.map_or(self.index, len, self.length) == self.length, (
            f"Index has incorrect length. "
            f"Expected: {len(self)}, got {len(self.index)}."
        )

    def eq_shape(self, other: TimeFrame) -> bool:
        """Return whether ``other`` has the same structure as ``self``.

        Compares length, index, tdims, column and static names, and the
        shape of every column and static array (but not their values).
        """
        if (
            len(self) != len(other)
            or self.index != other.index
            or self.tdims != other.tdims
            or self.columns.keys() != other.columns.keys()
            or self.static.keys() != other.static.keys()
        ):
            return False

        for _, left, right in join_items(self.columns, other.columns, "left"):
            if left.shape != right.shape:
                return False

        for _, left, right in join_items(self.static, other.static, "left"):
            if left.shape != right.shape:
                return False

        return True

    def eq_to(self, other: TimeFrame) -> bool:
        """Return whether ``other`` equals ``self`` in shape and values.

        Padding, metadata and ``default_tdim`` are not compared.
        """
        # not considered: Pad, metadata, default_tdim
        if not self.eq_shape(other):
            return False

        for _, left, right in join_items(self.columns, other.columns, "left"):
            if not np.array_equal(left, right):
                return False

        for _, left, right in join_items(self.static, other.static, "left"):
            if not np.array_equal(left, right):
                return False

        return True

    def _time_view(self, column):
        """
        View of column with respect to time.
        """
        return AxisView(self.columns[column], self.tdims[column])

    def _slice_tdim(self, idx):
        """Slice all columns (and the index) along their time dimension.

        Only contiguous slices (``step == 1``) are supported. The padding
        information is narrowed to the part of the padded regions that lies
        inside the slice.
        """
        start, stop, step = idx.indices(len(self))
        assert step == 1

        # A slice with ``stop < start`` selects nothing; clamp so the
        # resulting length is never negative.
        length = max(0, stop - start)
        stop = start + length

        # Padded values occupy ``[0, pad.left)`` and
        # ``[len(self) - pad.right, len(self))``. Count how many positions of
        # each region fall inside ``[start, stop)``.
        pad_left = max(0, min(self._pad.left, stop) - start)
        pad_right = max(0, stop - max(len(self) - self._pad.right, start))

        return _replace(
            self,
            columns={
                column: self._time_view(column)[idx]
                for column in self.columns
            },
            index=maybe.map(self.index, itemgetter(idx)),
            length=length,
            _pad=Pad(pad_left, pad_right),
        )

    def __getitem__(self, idx: Union[slice, int, str]):
        """Slice in time (``slice``/``int``) or select a column by name.

        Time-based access is delegated to ``TimeBase.__getitem__``; a string
        returns the named column as a ``TimeSeries``.
        """
        if isinstance(idx, (slice, int)):
            return TimeBase.__getitem__(self, idx)

        assert isinstance(idx, str)
        return TimeSeries(
            self.columns[idx],
            index=self.index,
            tdim=self.tdims[idx],
            metadata=self.metadata,
            name=idx,
            _pad=self._pad,
        )

    def pad(self, value, left: int = 0, right: int = 0) -> TimeFrame:
        """Pad every column along its time dimension with ``value``.

        ``left`` values are added at the start and ``right`` values at the
        end. The index (if any) is extended accordingly and ``_pad`` is
        updated. Returns ``self`` unchanged when no padding is requested.
        """
        # Return `self` if no padding is needed.
        if left == 0 and right == 0:
            return self

        assert left >= 0 and right >= 0

        columns = {
            column: pad_axis(
                self.columns[column],
                axis=self.tdims[column],
                left=left,
                right=right,
                value=value,
            )
            for column in self.columns
        }

        length = self.length + left + right
        pad_left = left + self._pad.left
        pad_right = right + self._pad.right

        index = self.index
        if self.index is not None:
            index = self.index.prepend(left).extend(right)

        return _replace(
            self,
            columns=columns,
            index=index,
            length=length,
            _pad=Pad(pad_left, pad_right),
        )

    def astype(self, type, columns=None) -> TimeFrame:
        """Cast the selected ``columns`` (default: all) to ``type``.

        Columns that are not selected are kept unchanged.
        """
        if columns is None:
            columns = self.columns

        converted = valmap(
            lambda col: col.astype(type), select(columns, self.columns)
        )
        # Merge into the existing columns so unselected ones are not dropped.
        return _replace(self, columns=merge(self.columns, converted))

    def __repr__(self) -> str:
        columns = ", ".join(self.columns)
        return f"TimeFrame<size={len(self)}, columns=[{columns}]>"

    def _table_columns(self):
        """Build the column data used by ``_repr_html_``.

        Frames longer than 10 steps show only the first and last 5 rows, with
        a placeholder row in between.
        """
        columns = {}
        if self.index is not None:
            index = pluck_attr(self.index, "data")
            if len(self) > 10:
                index = [
                    *index[:5],
                    f"[ ... {len(self) - 10} ... ]",
                    *index[-5:],
                ]
            columns[""] = index

        def move_axis(data, name):
            # Put the time axis first so that each row is one time step.
            return np.moveaxis(data, self.tdims[name], 0)

        if len(self) > 10:
            head = self.head(5)
            tail = self.tail(5)

            columns.update(
                {
                    col: [
                        *(move_axis(head[col], col)),
                        f"[ ... {len(self) - 10} ... ]",
                        *(move_axis(tail[col], col)),
                    ]
                    for col in self.columns
                }
            )
        else:
            columns.update(
                {
                    name: move_axis(values, name)
                    for name, values in self.columns.items()
                }
            )

        return columns

    def _repr_html_(self):
        """Render the frame (and its static data, if any) as HTML."""
        columns = self._table_columns()
        html = [
            html_table(columns),
            f"{len(self)} rows × {len(self.columns)} columns",
        ]

        if self.static:
            html.extend(
                [
                    "<h3>Static Data</h3>",
                    html_table(
                        {name: [val] for name, val in self.static.items()}
                    ),
                ]
            )

        return "\n".join(html)

    @classmethod
    def from_pandas(cls, df):
        """
        Turn ``pandas.DataFrame`` into ``TimeFrame``.

        Each DataFrame column becomes a 1-D column with time axis ``-1``. If
        the DataFrame's index cannot be converted to ``Periods``, the frame
        has no index. The resulting frame has no static data.
        """
        import pandas as pd

        try:
            index = Periods.from_pandas(df.index)
        except Exception:
            # Best effort: non-period indices simply yield an index-less frame.
            index = None

        return cls(
            columns=valmap(pd.Series.to_numpy, dict(df.items())),
            index=index,
            # Use an empty dict (not ``None``): other methods such as
            # ``set``, ``eq_shape`` and ``rename_static`` operate on
            # ``self.static`` as a mapping.
            static={},
            length=len(df),
            tdims={name: -1 for name in df.columns},
        )

    def set(self, name, value, tdim=None):
        """Return a copy with column ``name`` set to ``value``.

        ``tdim`` defaults to ``default_tdim``. ``name`` must not already be a
        static field.
        """
        assert name not in self.static

        tdim = maybe.unwrap_or(tdim, self.default_tdim)

        return _replace(
            self,
            columns=merge(self.columns, {name: value}),
            tdims=merge(self.tdims, {name: tdim}),
        )

    def set_static(self, name, value):
        """Return a copy with static field ``name`` set to ``value``.

        ``name`` must not already be a column.
        """
        assert name not in self.columns

        return _replace(self, static=merge(self.static, {name: value}))

    def set_like(self, ref: str, column, value, tdim=None):
        """Set ``column`` like ``set`` after checking ``ref`` is a column.

        ``ref`` is otherwise unused; ``tdim`` still falls back to
        ``default_tdim``. NOTE(review): confirm whether ``ref``'s tdim was
        meant to be used as the default.
        """
        assert ref in self.columns

        return self.set(column, value, tdim)

    def remove(self, column):
        """Return a copy without ``column`` (and its tdim entry)."""
        return _replace(
            self,
            columns=dissoc(self.columns, column),
            tdims=dissoc(self.tdims, column),
        )

    def remove_static(self, name):
        """Return a copy without static field ``name``."""
        return _replace(self, static=dissoc(self.static, name))

    def like(self, columns=None, static=None):
        """Return a copy with ``columns`` and ``static`` replaced.

        Both default to empty dicts; index, length, padding and metadata are
        kept.
        """
        columns = maybe.unwrap_or(columns, {})
        static = maybe.unwrap_or(static, {})

        return _replace(self, columns=columns, static=static)

    def rename(self, mapping=None, **kwargs):
        """Rename ``columns`` of ``TimeFrame``.

        The keys in ``mapping`` denote the target column names, i.e.
        ``rename({"target": "source"})``. For convenience one can use keyword
        parameters (``.rename(target="source")``).

        The passed ``mapping`` is not modified.
        """
        # Build a new dict so the caller's ``mapping`` is left untouched.
        mapping = {**(mapping or {}), **kwargs}

        columns = dissoc(self.columns, *mapping.values())
        tdims = dissoc(self.tdims, *mapping.values())

        for target, source in mapping.items():
            columns[target] = self.columns[source]
            tdims[target] = self.tdims[source]

        return _replace(self, columns=columns, tdims=tdims)

    def rename_static(self, mapping=None, **kwargs):
        """Rename ``static`` fields of ``TimeFrame``.

        The keys in ``mapping`` denote the target field names, i.e.
        ``rename_static({"target": "source"})``. For convenience one can use
        keyword parameters (``.rename_static(target="source")``).

        The passed ``mapping`` is not modified.
        """
        # Build a new dict so the caller's ``mapping`` is left untouched.
        mapping = {**(mapping or {}), **kwargs}

        static = dissoc(self.static, *mapping.values())

        for target, source in mapping.items():
            static[target] = self.static[source]

        return _replace(self, static=static)

    def stack(
        self,
        select: List[str],
        into: str,
        drop: bool = True,
    ) -> TimeFrame:
        """Stack the ``select`` columns into a new column ``into``.

        All selected columns must share the same tdim, which is also used
        for ``into``. Arrays are combined with ``np.vstack`` (i.e. along
        axis 0). If ``drop`` is true, the source columns are removed.

        NOTE(review): since stacking happens along axis 0, columns whose
        tdim is 0 would be concatenated along time — confirm callers only
        use this with a negative/trailing time axis.
        """
        # Ensure all tdims are the same.
        # TODO: Can we make that work for different tdims? There might be a
        # problem with what the resulting dimensions are.
        all_tdims = set(self.tdims[column] for column in select)
        assert len(all_tdims) == 1
        tdim = first(all_tdims)

        if drop:
            columns = dissoc(self.columns, *select)
            tdims = dissoc(self.tdims, *select)
        else:
            columns = dict(self.columns)
            tdims = dict(self.tdims)

        columns[into] = np.vstack([self.columns[column] for column in select])
        tdims[into] = tdim

        return _replace(self, columns=columns, tdims=tdims)

    def as_dict(self, prefix=None, static=True):
        """Return columns (keys optionally prefixed) and static values.

        Static values are included unprefixed when ``static`` is true.
        """
        result = dict(self.columns)

        if prefix is not None:
            result = {prefix + key: value for key, value in result.items()}

        if static:
            result.update(self.static)

        return result

    def rolsplit(
        self,
        index,
        *,
        distance: int = 1,
        past_length: Optional[int] = None,
        future_length: Optional[int] = None,
        n: Optional[int] = None,
        pad_value=0.0,
    ):
        """
        Create rolling split of past/future pairs.

        Parameters
        ----------
        index
            Starting index that denominates the cut off point from which
            splits are generated. Negative values count from the end;
            non-integer values are resolved via ``index_of``.
        distance
            The distance by which pairs are shifted. Defaults to ``1``.
            To avoid overlapping examples, ``distance`` has to be set to be
            at least ``past_length``. Split points stop once fewer than
            ``distance`` values would remain after them.
        future_length, optional
            Optionally enforce future length. If fewer values than
            ``future_length`` remain after a split point, ``split`` pads the
            future with ``pad_value``. NOTE(review): confirm whether padding
            the future range is intended here.
        past_length, optional
            If provided, all pairs past will have ``past_length``, padded
            with ``pad_value`` if needed.
        n, optional
            If provided, limits the number of pairs to ``n``.
        pad_value
            Value to pad past if needed, defaults to ``0.0``.

        Returns
        -------
        A stream of ``zebras.SplitFrame`` objects.
        """
        if not isinstance(index, (int, np.integer)):
            # If `index` is provided as timestamp we turn it into an integer.
            index = self.index_of(index)
        elif index < 0:
            # Ensure index is >= 0; (turn negative values into positive ones)
            index = len(self) + index

        # Split points start at `index` and advance by `distance`; the last
        # one leaves at least `distance` values after it.
        for split_index in take(
            n,
            range(index, len(self) + 1 - distance, distance),
        ):
            yield self.split(
                split_index, past_length, future_length, pad_value
            )

    def split(
        self,
        index,
        past_length=None,
        future_length=None,
        pad_value=0.0,
    ):
        """Split the frame at ``index`` into a ``SplitFrame``.

        ``index`` may be negative (counted from the end) or a non-integer
        value resolved via ``index_of``. ``past_length`` defaults to
        ``index`` and ``future_length`` to ``len(self) - index``; if either
        exceeds the available data, the frame is padded with ``pad_value``.

        Raises
        ------
        ValueError
            If ``index`` lies outside ``[0, len(self)]``.
        """
        if not isinstance(index, (int, np.integer)):
            # If `index` is provided as timestamp we turn it into an integer.
            index = self.index_of(index)
        elif index < 0:
            # Ensure index is >= 0; (turn negative values into positive ones)
            index = len(self) + index

        if not 0 <= index <= len(self):
            raise ValueError(
                "Split index out of bounds. Use `.resize(...)` or `.pad(...)` "
                "to ensure `TimeFrame` is long enough."
            )

        # If past_length is not provided, it will equal to `index`, since
        # `len(tf.split(5).past) == 5`
        past_length: int = maybe.unwrap_or(past_length, index)

        # Same logic applies to future_length, except that we deduct from the
        # right. (We can't use past_length, since it can be unequal to index).
        future_length: int = maybe.unwrap_or(future_length, len(self) - index)

        if self.index is None:
            new_index = None
        else:
            start = self.index.start + (index - past_length)
            new_index = start.periods(past_length + future_length)

        pad_left = max(0, past_length - index)
        pad_right = max(0, future_length - (len(self) - index))
        frame = self.pad(pad_value, pad_left, pad_right)

        # We need to shift the split index to the right, if we padded values
        # on the left.
        index += pad_left

        def split_item(item):
            name, data = item
            tdim = frame.tdims[name]

            past, future = np.split(data, [index], tdim)
            past = AxisView(past, tdim)[-past_length:]
            future = AxisView(future, tdim)[:future_length]

            return name, (past, future)

        past, future = columns_to_rows(itemmap(split_item, frame.columns))

        return SplitFrame(
            _past=past,
            _future=future,
            index=new_index,
            static=frame.static,
            past_length=past_length,
            future_length=future_length,
            tdims=frame.tdims,
            metadata=frame.metadata,
            _pad=frame._pad,
        )

    def apply(self, fn, columns=None):
        """Apply ``fn`` to the selected ``columns`` (default: all).

        Columns that are not selected are kept unchanged.
        """
        if columns is None:
            columns = self.columns.keys()

        # Previously ``columns`` was ignored and ``fn`` ran on every column.
        updated = valmap(fn, select(columns, self.columns))
        return _replace(self, columns=merge(self.columns, updated))

    def __len__(self) -> int:
        return self.length

    @staticmethod
    def _batch(xs: List[TimeFrame]) -> BatchTimeFrame:
        """Combine frames into a ``BatchTimeFrame``.

        Columns and static values are stacked along a new leading axis, so
        non-negative tdims are shifted by one. Length, tdims and column
        names are taken from the first frame and not validated.
        """
        # TODO: Check
        ref = xs[0]
        pluck = pluck_attr(xs)

        tdims = valmap(
            lambda tdim: tdim + 1 if tdim >= 0 else tdim,
            ref.tdims,
        )

        return BatchTimeFrame(
            columns=rows_to_columns(pluck("columns"), np.stack),  # type: ignore
            index=pluck("index"),
            static=rows_to_columns(pluck("static"), np.stack),  # type: ignore
            length=ref.length,
            tdims=tdims,
            metadata=pluck("metadata"),
            _pad=pluck("_pad"),
        )
@dataclasses.dataclass
class BatchTimeFrame:
    """A batch of ``TimeFrame`` objects, as produced by ``TimeFrame._batch``.

    ``columns`` and ``static`` hold arrays with a leading batch axis, while
    ``index``, ``metadata`` and ``_pad`` hold one entry per batch element.
    """

    columns: dict
    index: Collection[Optional[Periods]]
    static: dict
    length: int
    tdims: dict
    metadata: Collection[Optional[dict]]
    _pad: Collection[Pad]

    @property
    def batch_size(self):
        """Number of frames in the batch (one ``index`` entry per frame)."""
        return len(self.index)

    def __len__(self):
        """Number of time steps of the batched frames."""
        return self.length

    def like(self, columns=None, static=None, tdims=None):
        """Return a copy with ``columns``, ``static`` and ``tdims`` replaced.

        All three default to empty dicts. Columns without a ``tdims`` entry
        get time axis ``-1``. The passed ``tdims`` is not modified.
        """
        columns = maybe.unwrap_or(columns, {})
        static = maybe.unwrap_or(static, {})
        # Copy before ``setdefault`` so the caller's dict is left untouched.
        tdims = dict(maybe.unwrap_or(tdims, {}))

        for name in columns:
            tdims.setdefault(name, -1)

        return _replace(
            self, columns=columns, index=self.index, static=static, tdims=tdims
        )

    def items(self):
        """Return an indexable view over the individual frames of the batch."""
        return BatchTimeFrameItems(self)

    def __getitem__(self, name):
        """Return column ``name`` as a ``BatchTimeSeries``."""
        return BatchTimeSeries(
            self.columns[name],
            index=self.index,
            tdim=self.tdims[name],
            metadata=self.metadata,
            name=[name] * self.batch_size,
            _pad=self._pad,
        )

    def as_dict(self, prefix=None, static=True):
        """Return columns (keys optionally prefixed) and static values.

        Reuses ``TimeFrame.as_dict``, which only reads ``columns`` and
        ``static``.
        """
        return TimeFrame.as_dict(self, prefix, static)
@dataclasses.dataclass(repr=False)
class BatchTimeFrameItems:
    """Indexable view over the frames contained in a ``BatchTimeFrame``.

    An integer index returns a single ``TimeFrame``; any other index (e.g. a
    slice) returns a ``BatchTimeFrame`` with the selected elements.
    """

    data: BatchTimeFrame

    def __len__(self):
        return self.data.batch_size

    def __getitem__(self, idx):
        tdims = self.data.tdims

        # Accept numpy integers too, consistent with ``TimeFrame.split``;
        # otherwise they would fall through to the batch branch and produce
        # a ``BatchTimeFrame`` holding a single (non-collection) index.
        if isinstance(idx, (int, np.integer)):
            cls = TimeFrame
            # Selecting one element removes the leading batch axis, so
            # non-negative time axes move one position to the left.
            tdims = valmap(lambda tdim: tdim - 1 if tdim >= 0 else tdim, tdims)
        else:
            cls = BatchTimeFrame

        return cls(
            columns=valmap(itemgetter(idx), self.data.columns),
            index=self.data.index[idx],
            static=valmap(itemgetter(idx), self.data.static),
            tdims=tdims,
            metadata=self.data.metadata[idx],
            _pad=self.data._pad[idx],
            length=self.data.length,
        )
def time_frame(
    columns: Optional[Mapping[str, Collection]] = None,
    *,
    index: Optional[Periods] = None,
    start: Optional[Union[Period, str]] = None,
    freq: Optional[Union[str, Freq]] = None,
    static: Optional[Mapping[str, Any]] = None,
    tdims: Optional[Mapping[str, int]] = None,
    length: Optional[int] = None,
    default_tdim: int = -1,
    metadata: Optional[Mapping] = None,
):
    """
    Create a ``zebras.TimeFrame`` object that represents one or more time
    series.

    Parameters
    ----------
    columns, optional
        A dictionary where keys are strings representing column names and
        values are sequences (e.g., list, numpy arrays). Values are converted
        with ``np.array``. All columns must have the same length in the time
        dimension, by default None
    index, optional
        A ``zebras.Periods`` object representing timestamps. Must have the
        same length as the columns, by default None
    start, optional
        The start time represented by a string (e.g., "2023-01-01"), or a
        ``zebras.Period`` object. An index will be constructed using this
        start time and the specified frequency, by default None
    freq, optional
        The frequency of the period, e.g., "H" for hourly. Only used together
        with ``start``; a string ``start`` requires ``freq``, by default None
    static, optional
        A dictionary of static-in-time features, by default None
    tdims, optional
        A dictionary specifying the time dimension for each column. The keys
        should match those in `columns`. If unspecified for a column, the
        `default_tdim` is used, by default None. The passed mapping is not
        modified.
    length, optional
        The length (in time) of the TimeFrame. If omitted, it is taken from
        ``index``, else from the first column, else ``0``, by default None
    default_tdim, optional
        The default time dimension, by default -1
    metadata, optional
        A dictionary of metadata associated with the TimeFrame, by default
        None

    Returns
    -------
    A ``zebras.TimeFrame`` object.
    """
    assert (
        index is None or start is None
    ), "Both index and start cannot be specified."

    columns = maybe.unwrap_or_else(columns, dict)
    # Take a private copy: the ``TimeFrame`` constructor fills in missing
    # tdims entries, which must not leak into the caller's mapping.
    tdims = dict(maybe.unwrap_or_else(tdims, dict))
    static = maybe.unwrap_or_else(static, dict)

    columns = valmap(np.array, columns)
    static = valmap(np.array, static)

    if length is None:
        if index is not None:
            length = len(index)
        elif columns:
            # Infer the length from the first column's time axis.
            column = first(columns)
            length = columns[column].shape[tdims.get(column, default_tdim)]
        else:
            length = 0

    tf = TimeFrame(
        columns=columns,
        index=index,
        static=static,
        tdims=tdims,
        length=length,
        default_tdim=default_tdim,
        metadata=cast(Optional[dict], metadata),
    )

    if tf.index is None and start is not None:
        if freq is not None:
            start = period(start, freq)
        else:
            # Without a frequency, ``start`` must already be a ``Period``.
            assert isinstance(start, Period)

        return tf.with_index(start.periods(len(tf)))

    return tf
# We defer these imports to avoid circular imports. from ._time_series import BatchTimeSeries, TimeSeries # noqa from ._split_frame import SplitFrame # noqa