# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
import dataclasses
from operator import itemgetter
from typing import Collection, Union, Optional, Dict, List, Tuple
import numpy as np
from toolz import first
from gluonts import maybe
from gluonts.itertools import pluck_attr
from ._base import Pad, TimeBase
from ._period import period, Period, Periods
from ._util import AxisView, pad_axis, _replace
@dataclasses.dataclass(eq=False)
class TimeSeries(TimeBase):
    """A single array of values with an optional time index.

    The axis of ``values`` selected by ``tdim`` is the time axis; its extent
    is the length of the series (see ``__len__``).

    Attributes
    ----------
    values
        The underlying data.
    index
        Optional periods describing the time axis. When given, it must have
        exactly ``len(self)`` entries (checked in ``__post_init__``).
    name
        Optional name of the series.
    tdim
        Axis of ``values`` that represents time, by default the last axis.
    metadata
        Optional free-form dictionary attached to the series.
    _pad
        Padding bookkeeping; updated by ``pad`` and by slicing along time.
    """

    values: np.ndarray
    index: Optional[Periods] = None
    name: Optional[str] = None
    tdim: int = -1
    metadata: Optional[dict] = None
    _pad: Pad = Pad()

    def __post_init__(self):
        """Check that ``index``, if present, matches the series length."""
        # ``maybe.map_or`` falls back to ``len(self)`` when there is no index,
        # so the check passes trivially and the message (which calls
        # ``len(self.index)``) is only ever formatted with a real index.
        assert maybe.map_or(self.index, len, len(self)) == len(self), (
            "Index has incorrect length. "
            f"Expected: {len(self)}, got {len(self.index)}."
        )

    def __eq__(self, other):
        """Return ``self.values == other``.

        This delegates to the comparison of ``values`` instead of producing a
        single boolean for the whole object; the generated dataclass
        ``__eq__`` is disabled via ``eq=False``.
        """
        return self.values == other

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the underlying ``values`` array."""
        return self.values.shape

    def __array__(self, dtype=None, copy=None):
        """Return the values as a NumPy array (NumPy array protocol).

        Parameters
        ----------
        dtype, optional
            Requested dtype. ``None`` keeps the dtype of ``values``.
        copy, optional
            ``True`` forces a copy, ``False`` forbids one, ``None`` copies
            only when a conversion makes it necessary.

        Raises
        ------
        ValueError
            If ``copy=False`` but a conversion is required.
        """
        # NumPy passes ``dtype`` (and, since NumPy 2, ``copy``) to this hook;
        # without these parameters ``np.asarray(ts, dtype=...)`` fails.
        converted = np.asarray(self.values, dtype=dtype)

        if converted is self.values:
            # No conversion happened; only copy when explicitly requested.
            return converted.copy() if copy else converted

        if copy is False:
            raise ValueError(
                "Unable to avoid copy while creating an array as requested."
            )
        return converted

    def to_numpy(self):
        """Return the underlying ``values`` array (no copy is made)."""
        return self.values

    def __len__(self):
        """Number of entries along the time axis ``tdim``."""
        return self.values.shape[self.tdim]

    def _slice_tdim(self, idx):
        """Select along the time axis.

        An ``int`` returns whatever ``AxisView(values, tdim)[idx]`` yields.
        A slice returns a new ``TimeSeries`` with ``values``, ``index`` and
        ``_pad`` adjusted; only step ``1`` is supported.
        """
        if isinstance(idx, int):
            return AxisView(self.values, self.tdim)[idx]

        start, stop, step = idx.indices(len(self))
        assert step == 1

        return _replace(
            self,
            values=AxisView(self.values, self.tdim)[idx],
            index=maybe.map(self.index, itemgetter(idx)),
            # ``start`` entries are dropped on the left and
            # ``len(self) - stop`` on the right, hence the negative amounts.
            _pad=self._pad.extend(-start, stop - len(self)),
        )

    def pad(self, value, left: int = 0, right: int = 0) -> TimeSeries:
        """Return a copy padded with ``value`` along the time axis.

        Parameters
        ----------
        value
            Fill value handed to ``pad_axis``.
        left, right
            Non-negative number of entries to add before / after the series.

        Returns
        -------
        A new ``TimeSeries``; ``index`` (if any) and ``_pad`` are extended
        by the same amounts.
        """
        assert left >= 0 and right >= 0

        values = pad_axis(
            self.values,
            axis=self.tdim,
            left=left,
            right=right,
            value=value,
        )

        index = self.index
        if index is not None:
            index = index.prepend(left).extend(right)

        return _replace(
            self,
            values=values,
            index=index,
            _pad=self._pad.extend(left, right),
        )

    @staticmethod
    def _batch(xs: List[TimeSeries]) -> BatchTimeSeries:
        """Stack several ``TimeSeries`` into one ``BatchTimeSeries``.

        All inputs must be exactly of type ``TimeSeries`` and share the same
        ``tdim``. ``values`` are combined with ``np.stack`` (new leading
        axis); the remaining fields are collected per item.
        """
        for series in xs:
            assert type(series) is TimeSeries

        pluck = pluck_attr(xs)

        tdims = set(pluck("tdim"))
        assert len(tdims) == 1
        tdim = first(tdims)

        if tdim >= 0:
            # We insert a new axis at the front, so if tdim is counting from
            # the left (tdim is positive) we need to shift by one to the right.
            tdim += 1

        values = np.stack(pluck("values"))

        return BatchTimeSeries(
            values=values,
            tdim=tdim,
            index=pluck("index"),
            name=pluck("name"),
            metadata=pluck("metadata"),
            _pad=pluck("_pad"),
        )

    def plot(self):
        """Plot the values with matplotlib, against ``index`` if available."""
        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib.pyplot as plt

        if self.index is None:
            plt.plot(self.values)
        else:
            plt.plot(self.index, self.values)
@dataclasses.dataclass
class BatchTimeSeries(TimeBase):
    """A stack of time series sharing one ``values`` array.

    ``index``, ``name``, ``metadata`` and ``_pad`` are lists holding one
    entry per batch item; ``len(values)`` is the batch size and the ``tdim``
    axis of ``values`` is the time axis.
    """

    values: np.ndarray
    index: List[Optional[Periods]]
    name: List[Optional[str]]
    tdim: int
    metadata: List[Optional[dict]]
    _pad: List[Pad]

    def _slice_tdim(self, idx):
        """Select along the time axis.

        An ``int`` returns whatever ``AxisView(values, tdim)[idx]`` yields.
        A slice returns a new ``BatchTimeSeries`` with ``values``, every
        per-item ``index`` and every per-item ``_pad`` adjusted; only step
        ``1`` is supported.
        """
        if isinstance(idx, int):
            return AxisView(self.values, self.tdim)[idx]

        length = len(self)
        start, stop, step = idx.indices(length)
        assert step == 1

        return _replace(
            self,
            values=AxisView(self.values, self.tdim)[idx],
            index=[maybe.map(index, itemgetter(idx)) for index in self.index],
            # ``start`` entries are dropped on the left and ``length - stop``
            # on the right. This is the same adjustment
            # ``TimeSeries._slice_tdim`` applies, so that slicing a batch and
            # slicing its items agree (a full slice leaves padding unchanged).
            _pad=[pad.extend(-start, stop - length) for pad in self._pad],
        )

    @property
    def batch_size(self):
        """Number of series in the batch, i.e. ``len(values)``."""
        return len(self.values)

    def __len__(self):
        """Number of entries along the time axis ``tdim``."""
        return self.values.shape[self.tdim]

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the underlying ``values`` array."""
        return self.values.shape

    def __array__(self, dtype=None, copy=None):
        """Return the values as a NumPy array (NumPy array protocol).

        Parameters
        ----------
        dtype, optional
            Requested dtype. ``None`` keeps the dtype of ``values``.
        copy, optional
            ``True`` forces a copy, ``False`` forbids one, ``None`` copies
            only when a conversion makes it necessary.

        Raises
        ------
        ValueError
            If ``copy=False`` but a conversion is required.
        """
        # NumPy passes ``dtype`` (and, since NumPy 2, ``copy``) to this hook;
        # without these parameters ``np.asarray(batch, dtype=...)`` fails.
        converted = np.asarray(self.values, dtype=dtype)

        if converted is self.values:
            # No conversion happened; only copy when explicitly requested.
            return converted.copy() if copy else converted

        if copy is False:
            raise ValueError(
                "Unable to avoid copy while creating an array as requested."
            )
        return converted

    def items(self):
        """Return a ``TimeSeriesItems`` view for per-item access."""
        return TimeSeriesItems(self)

    def pad(self, value, left: int = 0, right: int = 0) -> BatchTimeSeries:
        """Return a copy padded with ``value`` along the time axis.

        Parameters
        ----------
        value
            Fill value handed to ``pad_axis``.
        left, right
            Non-negative number of entries to add before / after each series.

        Returns
        -------
        A new ``BatchTimeSeries``; every non-``None`` index and every pad
        entry is extended by the same amounts.
        """
        assert left >= 0 and right >= 0

        values = pad_axis(
            self.values,
            axis=self.tdim,
            left=left,
            right=right,
            value=value,
        )

        def extend_index(index):
            return index.prepend(left).extend(right)

        return _replace(
            self,
            values=values,
            index=[maybe.map(index_, extend_index) for index_ in self.index],
            _pad=[pad.extend(left, right) for pad in self._pad],
        )

    def like(self, values: np.ndarray, name: Optional[str] = None):
        """Return a copy of this batch with ``values`` and ``name`` replaced.

        NOTE(review): ``name`` is a per-item list on this class, but the
        default ``None`` (or a plain string) replaces that list as a whole.
        Confirm this is what callers expect.
        """
        return _replace(self, values=values, name=name)
@dataclasses.dataclass(repr=False)
class TimeSeriesItems:
    """Item-wise (batch axis) access to a ``BatchTimeSeries``.

    Indexing with an integer yields a single ``TimeSeries``; any other index
    (e.g. a slice) yields a smaller ``BatchTimeSeries``.
    """

    data: BatchTimeSeries

    def __len__(self):
        """Number of items, i.e. the batch size of the wrapped data."""
        return self.data.batch_size

    def __getitem__(self, idx):
        """Return the item(s) selected by ``idx`` along the batch axis."""
        tdim = self.data.tdim

        # NumPy integer scalars (e.g. from ``np.argmax``) are not ``int``
        # instances but select a single item just the same. Treating them as
        # a batch index would build a ``BatchTimeSeries`` from one item's
        # fields.
        if isinstance(idx, (int, np.integer)):
            cls = TimeSeries
            # The leading batch axis disappears, so a time axis counted from
            # the left moves one position to the left.
            if tdim > 0:
                tdim -= 1
        else:
            cls = BatchTimeSeries

        return cls(
            values=self.data.values[idx],
            index=self.data.index[idx],
            name=self.data.name[idx],
            metadata=self.data.metadata[idx],
            _pad=self.data._pad[idx],
            tdim=tdim,
        )
def time_series(
    values: Collection,
    *,
    index: Optional[Periods] = None,
    start: Optional[Union[Period, str]] = None,
    freq: Optional[str] = None,
    tdim: int = -1,
    name: Optional[str] = None,
    metadata: Optional[Dict] = None,
):
    """
    Build a ``zebras.TimeSeries`` from raw values.

    Parameters
    ----------
    values
        A sequence (e.g., list, numpy array) holding the values of the time
        series. It is converted with ``np.array``.
    index, optional
        A ``zebras.Periods`` object representing timestamps. Must have the
        same length as `values`, by default None
    start, optional
        The start time, given either as a string (e.g., "2023-01-01") or as
        a ``zebras.Period`` object. Only used when no `index` is passed: an
        index is then constructed from this start time. A string requires
        `freq` to be set as well, by default None
    freq, optional
        The frequency of the period, e.g, "H" for hourly. When given, `start`
        is turned into a period using this frequency, by default None
    tdim, optional
        The time dimension in `values`, by default -1
    name, optional
        A description for the time series. This will be the column name when
        returned from a ``TimeFrame``.
    metadata, optional
        A dictionary of metadata associated with the time series, by default
        None

    Returns
    -------
    A ``zebras.TimeSeries`` object.
    """
    series = TimeSeries(
        np.array(values),
        index=index,
        tdim=tdim,
        name=name,
        metadata=metadata,
    )

    # An explicit index takes precedence, and without a start time there is
    # nothing to derive an index from.
    if series.index is not None or start is None:
        return series

    if freq is None:
        # Without a frequency the start must already be a full period.
        assert isinstance(start, Period)
    else:
        start = period(start, freq)

    return series.with_index(start.periods(len(series)))