# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import logging
from typing import Callable, Iterable, Optional

import numpy as np

from gluonts.dataset import DataBatch, Dataset
from gluonts.itertools import (
    Cyclic,
    IterableSlice,
    PseudoShuffled,
    batcher,
    rows_to_columns,
)
from gluonts.pydantic import BaseModel
from gluonts.transform import (
    AdhocTransform,
    Identity,
    SelectFields,
    Transformation,
    Valmap,
)

logger = logging.getLogger(__name__)

DataLoader = Iterable[DataBatch]


# TODO: the following are for backward compatibility
# and could eventually be removed
class Batch(Transformation, BaseModel):
    batch_size: int

    def __call__(self, data, is_train):
        yield from batcher(data, self.batch_size)


class Stack(Transformation, BaseModel):
    def __call__(self, data, is_train):
        for batch in data:
            yield rows_to_columns(batch, np.array)
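

# A minimal usage sketch for the backward-compatible ``Batch`` and ``Stack``
# transformations above (``my_dataset`` is a hypothetical iterable of
# ``dict`` data entries, not defined in this module):
#
#     stack_batches = Batch(batch_size=32) + Stack()
#     for batch in stack_batches.apply(my_dataset, is_train=True):
#         # ``batch`` maps each field name to a stacked ``np.ndarray``
#         ...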


def as_stacked_batches(
    dataset: Dataset,
    *,
    batch_size: int,
    output_type: Optional[Callable] = None,
    num_batches_per_epoch: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
    field_names: Optional[list] = None,
):
    """
    Prepare data in batches to be passed to a network.

    Input data is collected into batches of size ``batch_size``, and then
    columns are stacked on top of each other. In addition, the result is
    wrapped in ``output_type`` if provided.

    If ``num_batches_per_epoch`` is provided, only that many batches are
    effectively returned. This is especially useful for training when
    providing a cyclic dataset.

    To pseudo-shuffle the data, ``shuffle_buffer_length`` can be set to
    collect inputs into a buffer first, from which entries are then randomly
    sampled.

    Setting ``field_names`` restricts the output to those columns of the
    input data, discarding all other values.
    """
    if shuffle_buffer_length:
        dataset = PseudoShuffled(dataset, shuffle_buffer_length)

    transform: Transformation = Identity()

    if field_names is not None:
        transform += SelectFields(field_names)

    transform += Batch(batch_size=batch_size)
    transform += Stack()

    if output_type is not None:
        transform += Valmap(output_type)

    # Note: is_train needs to be provided but does not have an effect
    transformed_dataset = transform.apply(dataset, is_train=True)
    return IterableSlice(transformed_dataset, num_batches_per_epoch)
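

# A minimal usage sketch for ``as_stacked_batches`` (``my_dataset`` is a
# hypothetical ``Dataset`` of ``dict`` entries, not defined in this module):
#
#     batches = as_stacked_batches(
#         Cyclic(my_dataset),
#         batch_size=32,
#         output_type=np.asarray,
#         num_batches_per_epoch=50,
#         shuffle_buffer_length=256,
#         field_names=["target"],
#     )
#     for batch in batches:
#         # ``batch["target"]`` holds the stacked targets of 32 entries
#         ...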


def TrainDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
    num_batches_per_epoch: Optional[int] = None,
    shuffle_buffer_length: Optional[int] = None,
):
    """
    Construct an iterator of batches for training purposes.

    This function wraps around ``DataLoader`` to offer training-specific
    behaviour and options, as follows:

    1. The provided dataset is iterated cyclically, so that one can go over
       it multiple times in a single epoch.
    2. A transformation must be provided, which is lazily applied as the
       dataset is being iterated; this is useful e.g. to slice random
       instances of fixed length out of each time series in the dataset.
    3. The resulting batches can be iterated in a pseudo-shuffled order.

    The returned object is a stateful iterator, whose length is either
    ``num_batches_per_epoch`` (if not ``None``) or infinite (otherwise).

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "training mode" (``is_train=True``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches.
        This can be used to set a specific array type or computing device
        the arrays should end up on (CPU, GPU).
    num_batches_per_epoch
        Length of the iterator. If ``None``, then the iterator is endless.
    shuffle_buffer_length
        Size of the buffer used for shuffling. Default: ``None``, in which
        case no shuffling occurs.

    Returns
    -------
    Iterator[DataBatch]
        An iterator of batches.
    """
    dataset: Dataset = Cyclic(dataset)

    if shuffle_buffer_length:
        dataset = PseudoShuffled(dataset, shuffle_buffer_length)

    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    transformed_dataset = transform.apply(dataset, is_train=True)
    batches = iter(transformed_dataset)
    return IterableSlice(batches, num_batches_per_epoch)
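

# A minimal usage sketch for ``TrainDataLoader`` (``my_dataset`` and
# ``my_transform`` are hypothetical placeholders, not defined in this
# module; the ``stack_fn`` shown simply stacks each field with NumPy):
#
#     train_loader = TrainDataLoader(
#         my_dataset,
#         transform=my_transform,
#         batch_size=32,
#         stack_fn=lambda data: rows_to_columns(data, np.array),
#         num_batches_per_epoch=100,
#     )
#     for epoch in range(10):
#         for batch in train_loader:
#             ...  # 100 batches per epoch, cycling over the dataset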


def ValidationDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
):
    """
    Construct an iterable of batches for validation purposes.

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "training mode" (``is_train=True``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches.
        This can be used to set a specific array type or computing device
        the arrays should end up on (CPU, GPU).

    Returns
    -------
    Iterable[DataBatch]
        An iterable sequence of batches.
    """
    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    return transform.apply(dataset, is_train=True)
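

# A minimal usage sketch for ``ValidationDataLoader`` (``my_dataset`` and
# ``my_transform`` are hypothetical placeholders, as above). Unlike
# ``TrainDataLoader``, the result is a plain iterable, so each loop below
# makes a fresh, full pass over the dataset:
#
#     val_loader = ValidationDataLoader(
#         my_dataset,
#         transform=my_transform,
#         batch_size=32,
#         stack_fn=lambda data: rows_to_columns(data, np.array),
#     )
#     for epoch in range(10):
#         for batch in val_loader:
#             ...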


def InferenceDataLoader(
    dataset: Dataset,
    *,
    transform: Transformation = Identity(),
    batch_size: int,
    stack_fn: Callable,
):
    """
    Construct an iterable of batches for inference purposes.

    Parameters
    ----------
    dataset
        Data to iterate over.
    transform
        Transformation to be lazily applied as data is being iterated.
        The transformation is applied in "inference mode"
        (``is_train=False``).
    batch_size
        Number of entries to include in a batch.
    stack_fn
        Function to use to stack data entries into batches.
        This can be used to set a specific array type or computing device
        the arrays should end up on (CPU, GPU).

    Returns
    -------
    Iterable[DataBatch]
        An iterable sequence of batches.
    """
    transform += Batch(batch_size=batch_size) + AdhocTransform(stack_fn)
    return transform.apply(dataset, is_train=False)