Source code for gluonts.dataset.arrow.enc

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field
from functools import partial, singledispatch
from itertools import chain
from pathlib import Path
from typing import List, Set, Optional, Union
from typing_extensions import Literal

from toolz.curried import keyfilter, valmap

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

from gluonts.dataset import Dataset, DatasetWriter
from gluonts.itertools import batcher, rows_to_columns


@dataclass
class ArrowEncoder:
    columns: List[str]
    ndarray_columns: Set[str] = field(default_factory=set)
    flatten_arrays: bool = True

    @classmethod
    def infer(cls, sample: dict, flatten_arrays=True):
        columns = []
        ndarray_columns = set()

        for name, value in sample.items():
            if isinstance(value, np.ndarray):
                if value.ndim > 1:
                    ndarray_columns.add(name)

            columns.append(name)

        return cls(
            columns=columns,
            ndarray_columns=ndarray_columns,
            flatten_arrays=flatten_arrays,
        )

    def encode(self, entry: dict):
        result = {}

        for column in self.columns:
            value = entry[column]

            # We need to handle arrays with more than 1 dimension specially.
            # If we don't, pyarrow complains. As an optimisation, we flatten
            # the array to 1d and store its shape to gain zero-copy reads
            # during decoding.
            if column in self.ndarray_columns:
                if self.flatten_arrays:
                    result[f"{column}._np_shape"] = list(value.shape)
                    value = value.flatten()
                else:
                    value = list(value)

            result[column] = value

        return result
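

# Usage sketch (not part of the module; ``entry`` below is a made-up sample):
# ``infer`` records every column of a sample entry and remembers which ones
# hold multi-dimensional arrays, so that ``encode`` can flatten those and emit
# a companion ``<name>._np_shape`` column carrying the original shape.
#
#     entry = {
#         "start": pd.Period("2021-01-01", freq="D"),
#         "target": np.arange(6, dtype=np.float32),
#         "feat_dynamic_real": np.ones((2, 6), dtype=np.float32),
#     }
#     encoder = ArrowEncoder.infer(entry)
#     encoded = encoder.encode(entry)
#     # encoded["feat_dynamic_real"] is now 1d, and
#     # encoded["feat_dynamic_real._np_shape"] == [2, 6]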


def into_arrow_batches(dataset, batch_size=1024, flatten_arrays=True):
    stream = iter(dataset)

    # peek 1
    first = next(stream)
    # re-assemble
    stream = chain([first], stream)

    encoder = ArrowEncoder.infer(first, flatten_arrays=flatten_arrays)
    encoded = map(encoder.encode, stream)

    row_batches = batcher(encoded, batch_size)
    column_batches = map(rows_to_columns, row_batches)

    for batch in column_batches:
        yield pa.record_batch(list(batch.values()), names=list(batch.keys()))
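

# Illustrative sketch (``my_dataset`` is an assumption, not part of the
# module): the generator yields ``pa.RecordBatch`` objects whose schema is
# inferred from the first entry, so they can be collected into a table,
# provided every value is of a type pyarrow can ingest.
#
#     batches = into_arrow_batches(my_dataset, batch_size=512)
#     table = pa.Table.from_batches(batches)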


@singledispatch
def _encode_py_to_arrow(val):
    return val


@_encode_py_to_arrow.register
def _encode_py_pd_period(val: pd.Period):
    return val.to_timestamp()
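

# ``_encode_py_to_arrow`` is a single-dispatch hook for values pyarrow cannot
# store directly; currently only ``pd.Period`` is special-cased. For example:
#
#     _encode_py_to_arrow(pd.Period("2021-01", freq="M"))  # -> Timestamp("2021-01-01")
#     _encode_py_to_arrow(3.14)                            # -> 3.14, unchanged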


def write_dataset(
    Writer, dataset, path, metadata=None, batch_size=1024, flatten_arrays=True
):
    dataset = map(keyfilter(lambda key: key != "source"), dataset)
    dataset = map(valmap(_encode_py_to_arrow), dataset)

    batches = into_arrow_batches(
        dataset, batch_size, flatten_arrays=flatten_arrays
    )
    first = next(batches)
    schema = first.schema

    if metadata is not None:
        schema = schema.with_metadata(metadata)

    with open(path, "wb") as fobj:
        writer = Writer(fobj, schema=schema)

        for batch in chain([first], batches):
            writer.write_batch(batch)

        writer.close()
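

# Low-level usage sketch (``my_dataset`` and the output path are assumptions):
# any writer class accepting ``(sink, schema=...)`` and providing
# ``write_batch``/``close`` works, e.g. pyarrow's IPC file writer. The
# higher-level ``ArrowWriter`` and ``ParquetWriter`` below wrap exactly this
# call.
#
#     write_dataset(pa.RecordBatchFileWriter, my_dataset, "train.arrow")

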
Compression = Union[Literal["lz4"], Literal["zstd"]]


@dataclass
class ArrowWriter(DatasetWriter):
    stream: bool = False
    suffix: str = ".feather"
    compression: Optional[Compression] = None
    flatten_arrays: bool = True
    metadata: Optional[dict] = None

    def write_to_file(self, dataset: Dataset, path: Path) -> None:
        options = pa.ipc.IpcWriteOptions(compression=self.compression)

        if self.stream:
            writer = partial(pa.RecordBatchStreamWriter, options=options)
        else:
            writer = partial(pa.RecordBatchFileWriter, options=options)

        write_dataset(
            writer,
            dataset,
            path,
            self.metadata,
            flatten_arrays=self.flatten_arrays,
        )

    def write_to_folder(
        self, dataset: Dataset, folder: Path, name: Optional[str] = None
    ) -> None:
        if name is None:
            name = "data"

        self.write_to_file(dataset, (folder / name).with_suffix(self.suffix))
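

# Usage sketch (dataset and paths are assumptions): ``ArrowWriter`` is the
# ``DatasetWriter`` entry point for the Arrow IPC formats; ``stream=True``
# selects the streaming format, the default writes the random-access file
# format.
#
#     writer = ArrowWriter(compression="zstd")
#     writer.write_to_folder(my_dataset, Path("datasets/electricity/train"))
#     # -> datasets/electricity/train/data.feather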


@dataclass
class ParquetWriter(DatasetWriter):
    suffix: str = ".parquet"
    flatten_arrays: bool = True
    metadata: Optional[dict] = None

    def write_to_file(self, dataset: Dataset, path: Path) -> None:
        write_dataset(
            pq.ParquetWriter,
            dataset,
            path,
            self.metadata,
            flatten_arrays=self.flatten_arrays,
        )

    def write_to_folder(
        self, dataset: Dataset, folder: Path, name: Optional[str] = None
    ) -> None:
        if name is None:
            name = "data"

        self.write_to_file(dataset, (folder / name).with_suffix(self.suffix))
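

# Usage sketch (names are assumptions): same interface as ``ArrowWriter``, but
# the data is written as Parquet; ``metadata`` ends up in the file's schema
# metadata.
#
#     ParquetWriter(metadata={"freq": "H"}).write_to_file(
#         my_dataset, Path("train.parquet")
#     )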