# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import abc
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from toolz import drop, first, take

from .dec import ArrowDecoder


class File:
    SUFFIXES = {".parquet", ".arrow", ".feather"}

    @staticmethod
def infer(
path: Path,
) -> Union["ArrowFile", "ArrowStreamFile", "ParquetFile"]:
"""
Return `ArrowFile`, `ArrowStreamFile` or `ParquetFile` by inspecting
provided path.
Arrow's `random-access` format starts with `ARROW1`, so we peek the
provided file for it.
"""
with open(path, "rb") as in_file:
peek = in_file.read(6)
if peek == b"ARROW1":
return ArrowFile(path)
elif peek.startswith(b"PAR1"):
return ParquetFile(path)
else:
return ArrowStreamFile(path)

    @abc.abstractmethod
    def __iter__(self): ...

    @abc.abstractmethod
    def __len__(self): ...


@dataclass
class ArrowFile(File):
path: Path
reader: pa.RecordBatchFileReader = field(init=False)
decoder: ArrowDecoder = field(init=False)
_batch_offsets: Optional[np.ndarray] = field(
default=None, init=False, repr=False
)
_start: int = 0
_take: Optional[int] = None

    @property
def batch_offsets(self):
if self._batch_offsets is None:
self._batch_offsets = np.cumsum(
list(map(len, self.iter_batches()))
)
return self._batch_offsets

    def __post_init__(self):
self.reader = pa.RecordBatchFileReader(self.path)
self.decoder = ArrowDecoder.from_schema(self.schema)

    def location_for(self, idx):
if idx == 0:
return 0, 0
        # batch_offsets holds cumulative batch sizes, so the containing batch
        # is the first one whose offset is strictly greater than idx.
        batch_no = np.searchsorted(self.batch_offsets, idx, side="right")
if batch_no == 0:
batch_idx = idx
else:
batch_idx = idx - self.batch_offsets[batch_no - 1]
return batch_no, batch_idx
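
    # Worked example of the offset arithmetic above (illustrative): for two
    # record batches of 5 rows each, batch_offsets is [5, 10] and
    #
    #     location_for(3) -> (0, 3)   # row 3 lives in the first batch
    #     location_for(5) -> (1, 0)   # first row of the second batch
    #     location_for(7) -> (1, 2)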

    @property
def schema(self):
return self.reader.schema

    def iter_batches(self):
for batch_no in range(self.reader.num_record_batches):
yield self.reader.get_batch(batch_no)

    def __len__(self):
if self._take is not None:
return self._take
if len(self.batch_offsets) > 0:
return self.batch_offsets[-1] - self._start
# empty file
return 0

    def __iter__(self):
def iter_values():
# yield from starting batch
batch_no, batch_idx = self.location_for(self._start)
sub_batch = self.reader.get_batch(batch_no)[batch_idx:]
yield from self.decoder.decode_batch(sub_batch)
for batch_no_ in range(
batch_no + 1, self.reader.num_record_batches
):
yield from self.decoder.decode_batch(
self.reader.get_batch(batch_no_)
)

        # take(None, ...) imposes no limit, so when _take is not set every
        # remaining row is yielded.
        yield from take(self._take, iter_values())

    def __getitem__(self, idx):
if isinstance(idx, slice):
assert idx.step is None or idx.step == 1
# normalize index
start, stop, _step = idx.indices(len(self))
idx = slice(start + self._start, stop + self._start)
return ArrowFile(
self.path,
_start=idx.start,
_take=max(0, idx.stop - idx.start),
)
if self._start is not None:
idx += self._start
batch_no, batch_idx = self.location_for(idx)
return self.decoder.decode(self.reader.get_batch(batch_no), batch_idx)
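

# Slicing sketch (illustrative, hypothetical path): ArrowFile slices are cheap
# views over the same file with an adjusted _start/_take window:
#
#     f = ArrowFile(Path("data.arrow"))
#     g = f[10:20]   # no rows are decoded here
#     len(g)         # -> 10 (or fewer if the file is shorter)
#     g[0]           # decodes the same row as f[10]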


@dataclass
class ArrowStreamFile(File):
path: Path
_decoder: Optional[ArrowDecoder] = field(default=None, init=False)
_start: int = 0
_take: Optional[int] = None

    def __iter__(self):
def iter_values():
with open(self.path, "rb") as infile:
reader = pa.RecordBatchStreamReader(infile)
if self._decoder is None:
self._decoder = ArrowDecoder.from_schema(reader.schema)
while True:
try:
batch = reader.read_next_batch()
except StopIteration:
return
yield from self._decoder.decode_batch(batch)
yield from take(self._take, drop(self._start, iter_values()))
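
    # Streaming sketch (illustrative): with _start=2 and _take=3 the iterator
    # above yields rows 2, 3 and 4 of the stream -- drop skips the first two
    # decoded rows and take stops after three more, so the file is read
    # front-to-back exactly once and only as far as needed.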

    def __len__(self):
        # The streaming format has no footer with row counts, so computing
        # the length requires a full pass over the file.
        return sum(1 for _ in self)

    def __getitem__(self, idx):
if isinstance(idx, slice):
assert idx.step is None or idx.step == 1
# normalize index
start, stop, _step = idx.indices(len(self))
idx = slice(start + self._start, stop + self._start)
return ArrowStreamFile(
self.path,
_start=idx.start,
_take=max(0, idx.stop - idx.start),
)
return first(self[idx:])


@dataclass
class ParquetFile(File):
path: Path
reader: pq.ParquetFile = field(init=False)
_start: int = 0
_take: Optional[int] = None
    # Note: holds the cumulative number of rows per row group (a running
    # sum), not the individual row-group sizes.
    _row_group_sizes: List[int] = field(default_factory=list)

    def __post_init__(self):
self.reader = pq.ParquetFile(self.path)
self.decoder = ArrowDecoder.from_schema(self.reader.schema_arrow)
        # len() works for both the default empty list and the numpy array
        # passed along when slicing an existing ParquetFile.
        if len(self._row_group_sizes) == 0:
self._row_group_sizes = np.cumsum(
[
self.reader.metadata.row_group(row_group).num_rows
for row_group in range(self.reader.metadata.num_row_groups)
]
)

    def location_for(self, idx):
if idx == 0:
return 0, 0
        # _row_group_sizes holds cumulative row counts, so the containing row
        # group is the first one whose count is strictly greater than idx.
        row_group = np.searchsorted(self._row_group_sizes, idx, side="right")
if row_group == 0:
row_index = idx
else:
row_index = idx - self._row_group_sizes[row_group - 1]
return row_group, row_index

    def __iter__(self):
def iter_values():
row_group, row_index = self.location_for(self._start)
table = self.reader.read_row_group(row_group)
yield from self.decoder.decode_batch(table[row_index:])
for row_group_ in range(row_group + 1, len(self._row_group_sizes)):
table = self.reader.read_row_group(row_group_)
yield from self.decoder.decode_batch(table)
yield from take(self._take, iter_values())

    def __getitem__(self, idx):
if isinstance(idx, slice):
assert idx.step is None or idx.step == 1
# normalize index
start, stop, _step = idx.indices(len(self))
idx = slice(start + self._start, stop + self._start)
return ParquetFile(
self.path,
_start=idx.start,
_take=max(0, idx.stop - idx.start),
_row_group_sizes=self._row_group_sizes,
)
return first(self[idx:])

    def __len__(self):
if self._take is not None:
return self._take
# One would think that pq.ParquetFile had a nicer way to get its length
return max(0, self.reader.metadata.num_rows - self._start)
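

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes write access to the current directory, uses hypothetical file
# names, and the exact decoded row layout depends on ArrowDecoder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    table = pa.table({"x": [float(i) for i in range(10)]})

    # Write the same data as Parquet and as Arrow's random-access format.
    pq.write_table(table, "data.parquet")
    with pa.OSFile("data.arrow", "wb") as sink:
        with pa.ipc.new_file(sink, table.schema) as writer:
            writer.write_table(table)

    for name in ["data.parquet", "data.arrow"]:
        f = File.infer(Path(name))  # dispatches on the magic bytes
        print(type(f).__name__, len(f))
        print(f[3])          # random access to a single decoded row
        print(list(f[2:5]))  # slices are lazy views, decoded on iteration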