# Source code for gluonts.zebras._split_frame

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from __future__ import annotations

import dataclasses
import itertools
from operator import itemgetter
from typing import Optional, List

import numpy as np
from toolz import first, keymap, valmap, dissoc, merge

from gluonts import maybe
from gluonts.itertools import pluck_attr, rows_to_columns

from ._base import Pad
from ._period import Periods, period
from ._repr import html_table
from ._util import _replace


@dataclasses.dataclass
class SplitFrame:
    """
    Time series columns split into a ``past`` and a ``future`` range.

    Columns live in two dicts, ``_past`` and ``_future``. A column that is
    present in both is treated as spanning the full range (see
    ``__getitem__``). ``tdims`` maps a column name to the axis that is used
    as the time axis; ``static`` is handed unchanged to both the ``past`` and
    the ``future`` view.
    """

    # Column name -> array covering the past range.
    _past: dict
    # Column name -> array covering the future range.
    _future: dict
    # Optional time index covering past and future (sliced at `past_length`).
    index: Optional[Periods]
    # Values that are shared by the past and future views.
    static: dict
    past_length: int
    future_length: int
    # Column name -> time axis; missing entries are filled in `__post_init__`.
    tdims: dict
    metadata: Optional[dict] = None
    # Time axis used for columns without an explicit entry in `tdims`.
    default_tdim: int = -1
    # Only `.left` is forwarded to `past` and only `.right` to `future`.
    _pad: Pad = Pad()

    def __post_init__(self):
        # Every known column gets a time dimension. Note that this writes
        # into the dict that was passed as `tdims`.
        for column in itertools.chain(self._past, self._future):
            self.tdims.setdefault(column, self.default_tdim)

        # Building both views triggers the checks for past_length and
        # future_length (performed by `TimeFrame` on construction).
        _, _ = self.past, self.future

    @property
    def past(self):
        """``TimeFrame`` over the past columns and the first ``past_length``
        entries of the index."""
        return TimeFrame(
            self._past,
            index=maybe.map(
                self.index, lambda index: index[: self.past_length]
            ),
            static=self.static,
            length=self.past_length,
            tdims=self.tdims,
            metadata=self.metadata,
            _pad=Pad(self._pad.left, 0),
        )

    @property
    def future(self):
        """``TimeFrame`` over the future columns and the index entries from
        ``past_length`` onwards."""
        return TimeFrame(
            self._future,
            index=maybe.map(
                self.index, lambda index: index[self.past_length :]
            ),
            static=self.static,
            length=self.future_length,
            tdims=self.tdims,
            metadata=self.metadata,
            _pad=Pad(0, self._pad.right),
        )

    def __getitem__(self, name):
        """
        Return the data of column ``name``.

        A column that exists in both past and future is returned as a single
        array, concatenated along the column's time dimension. Raises
        ``KeyError`` if the column is in neither range.
        """
        if name in self._past:
            if name in self._future:
                # `np.concatenate` expects the arrays as ONE sequence
                # argument; passing them positionally would put the second
                # array into the `axis` slot.
                return np.concatenate(
                    [self._past[name], self._future[name]],
                    axis=self.tdims[name],
                )
            return self._past[name]

        return self._future[name]

    def __len__(self):
        return self.past_length + self.future_length

    def set(self, name, value, tdim=None):
        """
        Return a copy in which ``name`` spans both past and future.

        ``value`` must have ``len(self)`` entries along ``tdim`` (defaults to
        ``default_tdim``); it is split at ``past_length`` into its past and
        future part.
        """
        tdim = maybe.unwrap_or(tdim, self.default_tdim)

        assert value.shape[tdim] == len(self)

        past, future = np.split(value, [self.past_length], axis=tdim)

        # Update the underlying column dicts (`_past` / `_future`), not the
        # `TimeFrame` views returned by the `past` / `future` properties.
        return _replace(
            self,
            _past=merge(self._past, {name: past}),
            _future=merge(self._future, {name: future}),
            tdims=merge(self.tdims, {name: tdim}),
        )

    def set_like(self, ref: str, column, value, tdim=None):
        """
        Set ``column`` on the same range(s) that the existing column ``ref``
        occupies (past, future, or both).

        Raises ``KeyError`` if ``ref`` is neither a past nor a future column.
        """
        is_past = ref in self._past
        is_future = ref in self._future

        if is_past:
            if is_future:
                return self.set(column, value, tdim)
            else:
                return self.set_past(column, value, tdim)
        elif is_future:
            return self.set_future(column, value, tdim)

        raise KeyError(f"Ref {ref} is neither past nor future")

    def set_past(self, name, value, tdim=None):
        """
        Return a copy with past column ``name`` set to ``value``.

        ``value`` must have ``past_length`` entries along ``tdim``, and
        ``tdim`` must agree with a time dimension already recorded for
        ``name``.
        """
        tdim = maybe.unwrap_or(tdim, self.default_tdim)

        assert value.shape[tdim] == self.past_length
        assert self.tdims.get(name, tdim) == tdim

        return _replace(
            self,
            _past=merge(self._past, {name: value}),
            tdims=merge(self.tdims, {name: tdim}),
        )

    def set_future(self, name, value, tdim=None):
        """
        Return a copy with future column ``name`` set to ``value``.

        ``value`` must have ``future_length`` entries along ``tdim``, and
        ``tdim`` must agree with a time dimension already recorded for
        ``name``.
        """
        tdim = maybe.unwrap_or(tdim, self.default_tdim)

        assert value.shape[tdim] == self.future_length
        assert self.tdims.get(name, tdim) == tdim

        return _replace(
            self,
            # Merge into the column dict, mirroring `set_past`.
            _future=merge(self._future, {name: value}),
            tdims=merge(self.tdims, {name: tdim}),
        )

    def remove(self, column):
        """
        Return a copy without ``column``.

        The column is dropped from the past and the future range as well as
        from ``tdims``; names that are not present are ignored.
        """
        return _replace(
            self,
            _past=dissoc(self._past, column),
            _future=dissoc(self._future, column),
            tdims=dissoc(self.tdims, column),
        )

    def _repr_html_(self):
        past = self.past._table_columns()
        future = self.future._table_columns()

        # Both halves are rendered side by side, so the shorter one is
        # padded with empty cells up to the length of the longer one.
        length = max(
            len(first(past.values())) if past else 0,
            len(first(future.values())) if future else 0,
        )

        def pad(col):
            to_pad = length - len(col)
            return list(col) + [""] * to_pad

        past = valmap(pad, past)
        future = valmap(pad, future)

        # An empty key is labelled with the bare prefix.
        past = keymap(lambda key: f"past_{key}" if key else "past", past)
        future = keymap(
            lambda key: f"future_{key}" if key else "future", future
        )

        return html_table({**past, "|": ["|"] * length, **future})

    def as_dict(self):
        """
        Flatten into one dict: past columns as ``past_<name>``, future
        columns as ``future_<name>`` and static values under their own name.

        On a key clash, later groups win (static over future over past).
        """
        past = keymap(lambda key: f"past_{key}", self._past)
        future = keymap(lambda key: f"future_{key}", self._future)

        return {**past, **future, **self.static}

    def resize(
        self,
        past_length: Optional[int] = None,
        future_length: Optional[int] = None,
        pad_value=0.0,
    ) -> SplitFrame:
        """
        Return a copy with the requested past and/or future length.

        A length that is ``None`` keeps the current value. The actual
        resizing is delegated to ``TimeFrame.resize``: the past is resized
        with ``pad="l", skip="l"`` and the future with ``pad="r", skip="r"``,
        using ``pad_value`` as fill value. If an index is set, it is rebuilt
        so that the split point stays at the same period.
        """
        index = self.index
        past_length = maybe.unwrap_or(past_length, self.past_length)
        future_length = maybe.unwrap_or(future_length, self.future_length)

        if index is not None:
            # Calculate new start. If current past_length is larger than the
            # new one, we shift it to the right, if it's smaller, we need to
            # go further into the past (shift to the left).
            start = index[0] + (self.past_length - past_length)
            index = start.periods(past_length + future_length)

        return _replace(
            self,
            _past=self.past.resize(
                past_length, pad_value, pad="l", skip="l"
            ).columns,
            past_length=past_length,
            _future=self.future.resize(
                future_length, pad_value, pad="r", skip="r"
            ).columns,
            future_length=future_length,
            index=index,
        )

    def with_index(self, index):
        """Return a copy that uses ``index`` as its index."""
        return _replace(self, index=index)

    @staticmethod
    def _batch(split_frames: List[SplitFrame]) -> BatchSplitFrame:
        """
        Stack ``split_frames`` into one ``BatchSplitFrame``.

        Columns and static values are combined per name with ``np.stack``;
        ``index``, ``metadata`` and ``_pad`` are collected per item. Lengths
        and time dimensions are taken from the first frame.
        """
        ref = split_frames[0]
        pluck = pluck_attr(split_frames)

        # `np.stack` prepends the batch axis, which moves every non-negative
        # time axis one position to the right; negative axes are unaffected.
        # This is the inverse of the shift applied in
        # `BatchSplitFrameItems.__getitem__` when a single item is taken.
        tdims = valmap(
            lambda tdim: tdim + 1 if tdim >= 0 else tdim, ref.tdims
        )

        return BatchSplitFrame(
            _past=rows_to_columns(pluck("_past"), np.stack),  # type: ignore
            _future=rows_to_columns(pluck("_future"), np.stack),  # type: ignore
            index=pluck("index"),
            static=rows_to_columns(pluck("static"), np.stack),  # type: ignore
            past_length=ref.past_length,
            future_length=ref.future_length,
            tdims=tdims,
            metadata=pluck("metadata"),
            _pad=pluck("_pad"),
        )
@dataclasses.dataclass
class BatchSplitFrame:
    """
    A batch of split frames that share lengths and time dimensions.

    ``index``, ``metadata`` and ``_pad`` hold one entry per batch item; the
    column dicts ``_past`` / ``_future`` and ``static`` hold the batched
    values per name.
    """

    _past: dict
    _future: dict
    # One (optional) index per batch item.
    index: List[Optional[Periods]]
    static: dict
    past_length: int
    future_length: int
    tdims: dict
    # One (optional) metadata dict per batch item.
    metadata: List[Optional[dict]]
    # One pad entry per batch item.
    _pad: List[Pad]

    def __len__(self):
        # Total number of time steps: past followed by future.
        total = self.past_length + self.future_length
        return total

    @property
    def batch_size(self):
        """Number of items in the batch (one index entry per item)."""
        return len(self.index)

    @property
    def past(self):
        """``BatchTimeFrame`` over the past columns; every index that is set
        is cut down to its first ``past_length`` entries."""
        head = slice(None, self.past_length)
        past_index = [
            maybe.map(periods, lambda entry: entry[head])
            for periods in self.index
        ]

        return BatchTimeFrame(
            columns=self._past,
            static=self.static,
            length=self.past_length,
            index=past_index,
            tdims=self.tdims,
            metadata=self.metadata,
            _pad=self._pad,
        )

    @property
    def future(self):
        """``BatchTimeFrame`` over the future columns; every index that is
        set keeps the entries from ``past_length`` onwards."""
        tail = slice(self.past_length, None)
        future_index = [
            maybe.map(periods, lambda entry: entry[tail])
            for periods in self.index
        ]

        return BatchTimeFrame(
            columns=self._future,
            static=self.static,
            length=self.future_length,
            index=future_index,
            tdims=self.tdims,
            metadata=self.metadata,
            _pad=self._pad,
        )

    def items(self):
        """Return a ``BatchSplitFrameItems`` accessor wrapping this batch."""
        accessor = BatchSplitFrameItems(self)
        return accessor

    def as_dict(self):
        """Flatten into one dict by reusing ``SplitFrame.as_dict``, which
        only reads ``_past``, ``_future`` and ``static``."""
        flat = SplitFrame.as_dict(self)
        return flat
@dataclasses.dataclass(repr=False)
class BatchSplitFrameItems:
    """
    Item-wise access to a ``BatchSplitFrame``.

    An integer index yields a single ``SplitFrame``; any other index (for
    example a slice) yields a ``BatchSplitFrame`` over the selected items.
    """

    data: BatchSplitFrame

    def __len__(self):
        return self.data.batch_size

    def __getitem__(self, idx):
        tdims = self.data.tdims

        # NumPy integer scalars (e.g. from ``np.arange``) select a single
        # item just like a builtin ``int`` does.
        if isinstance(idx, (int, np.integer)):
            cls = SplitFrame
            # Taking one item drops the leading batch axis, so non-negative
            # time axes move one position to the left; negative axes count
            # from the end and stay valid.
            tdims = valmap(lambda tdim: tdim - 1 if tdim >= 0 else tdim, tdims)
        else:
            cls = BatchSplitFrame

        return cls(
            _past=valmap(itemgetter(idx), self.data._past),
            _future=valmap(itemgetter(idx), self.data._future),
            index=self.data.index[idx],
            static=valmap(itemgetter(idx), self.data.static),
            tdims=tdims,
            metadata=self.data.metadata[idx],
            _pad=self.data._pad[idx],
            past_length=self.data.past_length,
            future_length=self.data.future_length,
        )
def split_frame(
    full=None,
    *,
    past=None,
    future=None,
    past_length=None,
    future_length=None,
    static=None,
    index=None,
    start=None,
    freq=None,
    tdims=None,
    metadata=None,
    default_tdim=-1,
):
    """
    Create a ``zebras.SplitFrame`` where columns can either be `past`,
    `future` or `full`, which spans both past and future.

    ``past_length`` and ``future_length`` are derived from the input data if
    possible, or default to ``0`` in case no respective data is available. It
    is possible to set these values explicitly for enforcing consistency or
    to provide a length even though no time series spans that range.

    Parameters
    ----------
    full, optional
        Time series columns that span past and future.
    past, optional
        Time series columns that are past only.
    future, optional
        Time series columns that are future only.
    past_length, optional
        Set length of the past section, derived from data if not provided.
    future_length, optional
        Set length of the future section, derived from data if not provided.
    static, optional
        Values that are independent of time.
    index, optional
        A ``zebras.Periods`` object representing timestamps. Must have the
        same length as the full range.
    start, optional
        The start time represented by a string (e.g., "2023-01-01"), or a
        ``zebras.Period`` object. Only used if no ``index`` is given. If
        ``freq`` is passed, ``start`` is first converted using
        ``period(start, freq)``; the index is then built from
        ``start.periods(...)``.
    freq, optional
        The frequency to use for constructing the index.
    tdims, optional
        A dictionary specifying the time dimension for each column, this
        applies to past, future and full. The passed dictionary is not
        modified.
    default_tdim, optional
        The default time dimension, by default -1
    metadata, optional
        A dictionary of metadata associated with the TimeFrame, by default
        None

    Returns
    -------
    A ``zebras.SplitFrame`` object.

    Raises
    ------
    ValueError
        If only ``full`` columns are provided, so that neither the past nor
        the future length can be determined.
    AssertionError
        If a ``full`` column does not have ``past_length + future_length``
        entries along its time dimension.
    """
    full = valmap(np.array, maybe.unwrap_or_else(full, dict))
    past = valmap(np.array, maybe.unwrap_or_else(past, dict))
    future = valmap(np.array, maybe.unwrap_or_else(future, dict))
    static = valmap(np.array, maybe.unwrap_or_else(static, dict))
    # Work on a copy: `SplitFrame.__post_init__` fills in missing entries via
    # `setdefault`, which would otherwise write into the caller's dict.
    tdims = dict(maybe.unwrap_or_else(tdims, dict))

    # Resolve `past_length` and `future_length` if not directly set from
    # provided data. If no data is passed for either field, the value is
    # still `None` after this.
    if past_length is None and past:
        column = first(past)
        past_length = past[column].shape[tdims.get(column, default_tdim)]

    if future_length is None and future:
        column = first(future)
        future_length = future[column].shape[tdims.get(column, default_tdim)]

    if full:
        column = first(full)
        full_length = full[column].shape[tdims.get(column, default_tdim)]

        # With `full` data alone there is no way to tell where the past ends.
        if maybe.or_(past_length, future_length) is None:
            raise ValueError(
                "Cannot determine past and future length if only "
                "`full` is provided."
            )
    # No data and no lengths are passed
    elif maybe.or_(past_length, future_length) is None:
        past_length = 0
        future_length = 0
        full_length = 0
    else:
        # If past_length and/or future_length is provided, but no `full`
        # data is given, then we first resolve past and future length and
        # then just calculate full_length later.
        full_length = None

    # At most one of the two lengths is still unknown here. It is the
    # remainder of the full length, or 0 if there is no `full` data.
    if past_length is None:
        past_length = maybe.map_or(
            full_length,
            lambda fl: fl - future_length,
            0,
        )
    elif future_length is None:
        future_length = maybe.map_or(
            full_length,
            lambda fl: fl - past_length,
            0,
        )

    full_length = past_length + future_length

    # create copies to not mutate user provided dicts
    past = dict(past)
    future = dict(future)

    # Every `full` column is split at `past_length` and stored as one past
    # and one future column of the same name.
    for name, data in full.items():
        tdim = tdims.get(name, default_tdim)

        assert data.shape[tdim] == full_length

        past_data, future_data = np.split(data, [past_length], axis=tdim)
        past[name] = past_data
        future[name] = future_data

    sf = SplitFrame(
        _past=past,
        _future=future,
        index=index,
        static=static,
        tdims=tdims,
        past_length=past_length,
        future_length=future_length,
        default_tdim=default_tdim,
        metadata=metadata,
    )

    # Only build an index from `start` if none was passed explicitly.
    if sf.index is None and start is not None:
        if freq is not None:
            start = period(start, freq)

        return sf.with_index(start.periods(len(sf)))

    return sf
# We defer these imports to avoid circular imports. from ._time_series import TimeSeries # noqa from ._time_frame import BatchTimeFrame, TimeFrame # noqa