Skip to content

Commit

Permalink
Support for PdReader kwargs (AmpX-AI#8)
Browse files Browse the repository at this point in the history
Co-authored-by: vojta tuma <vtuma@amp.energy>
  • Loading branch information
tmi and vtuma authored Jun 14, 2022
1 parent 497673d commit f5077ec
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 7 deletions.
38 changes: 31 additions & 7 deletions fsql/deser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
opt for the lazy approach (such as in Dask), and don't materialize inside either the `read_single` or
`concat` methods.
Existing readers such as PandasReader allow customisation via passing through any kwargs to the underlying
pandas read method.
The user should *not* bake in any specific business logic in here -- a more preferred approach is to
return an object such as data frame as early as possible, and apply any transformations later on.
return an object such as (lazy) data frame as early as possible, and apply any transformations later on.
"""
from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto, unique
Expand Down Expand Up @@ -87,31 +91,51 @@ def read_and_concat(


class PandasReader(DataReader):
"""Wraps various pandas read methods (parquet, json, csv, excel) into a single interface.
Behaviour can be customised via passing any kwargs to the constructor.
"""

def __init__(self, input_format=InputFormat.AUTO, **pdread_kwargs):
    """Create a reader, remembering user kwargs and per-format default kwargs.

    Args:
        input_format: passed unchanged to the parent constructor.
        **pdread_kwargs: kept as `pdread_user_kwargs`; `read_single` merges them over
            the defaults of the detected format, so a user-supplied key wins.
    """
    super().__init__(input_format=input_format)
    self.pdread_user_kwargs = pdread_kwargs
    # defaultdict(dict): formats without an explicit entry below (e.g. CSV) resolve to `{}`.
    self.pdread_default_kwargs = defaultdict(dict)
    self.pdread_default_kwargs[InputFormat.PARQUET] = {
        "engine": "fastparquet",
    }
    self.pdread_default_kwargs[InputFormat.JSON] = {
        # Must be a real bool: the string "true" only worked by being truthy.
        "lines": True,
    }
    self.pdread_default_kwargs[InputFormat.XLSX] = {
        "engine": "openpyxl",
    }

def read_single(self, partition: Partition, fs: AbstractFileSystem) -> pd.DataFrame:
logger.debug(f"read dataframe for partition {partition}")
input_format = self.detect_format(partition.url)
# TODO allow for user spec of engine and other params, essentially any quark
logger.debug(f"format detected for partition {input_format} <- {partition}")
if input_format is InputFormat.PARQUET:
reader = lambda fd: pd.read_parquet(fd, engine="fastparquet") # noqa: E731
reader = pd.read_parquet
elif input_format is InputFormat.JSON:
reader = lambda fd: pd.read_json(fd, lines=True) # noqa: E731
reader = pd.read_json
elif input_format is InputFormat.CSV:
reader = pd.read_csv
elif input_format is InputFormat.XLSX:
reader = lambda fd: pd.read_excel(fd, engine="openpyxl") # noqa: E731
reader = pd.read_excel
elif input_format is InputFormat.AUTO:
raise ValueError(f"partition had format detected as auto -> invalid state. Partition: {partition}")
else:
assert_exhaustive_enum(input_format)

pdread_kwargs = {**self.pdread_default_kwargs[input_format], **self.pdread_user_kwargs}
logger.debug(f"reader kwargs {pdread_kwargs} for partition {partition}")
try:
with fs.open(partition.url, "rb") as fd:
df = reader(fd)
df = reader(fd, **pdread_kwargs)
except FileNotFoundError as e:
logger.warning(f"file {partition} reading exception {type(e)}, attempting cache invalidation and reread")
fs.invalidate_cache()
with fs.open(partition.url, "rb") as fd:
df = reader(fd)
df = reader(fd, **pdread_kwargs)

for key, value in partition.columns.items():
df[key] = value
Expand Down
37 changes: 37 additions & 0 deletions tests/test_pandasreader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from fsql.api import read_partitioned_table
from fsql.deser import InputFormat, PandasReader
from fsql.query import Q_TRUE

# Two-row, two-column frame shared by the tests in this module.
df1 = pd.DataFrame({"c1": [0, 1], "c2": ["hello", "world"]})


def test_input_format_override(tmp_path):
    """Test that explicitly setting format overrides suffix."""

    case1_path = tmp_path / "table1"
    case1_path.mkdir(parents=True)
    # CSV content written under a ".json" name: confuse the default by bad suffix.
    df1.to_csv(case1_path / "f1.json", index=False)

    with pytest.raises(ValueError, match="Expected object or value"):
        # this test condition is quite brittle! A better match would be desired
        # The call is expected to raise, so its result is deliberately not bound.
        read_partitioned_table(f"file://{case1_path}/", Q_TRUE)

    # With the format forced to CSV the same file must read back unchanged.
    reader = PandasReader(input_format=InputFormat.CSV)
    succ_result = read_partitioned_table(f"file://{case1_path}/", Q_TRUE, data_reader=reader)
    assert_frame_equal(df1, succ_result)


def test_parquet_kwargs(tmp_path):
    """Check that the `columns` kwarg is forwarded to the parquet read and honoured."""

    table_dir = tmp_path / "table1"
    table_dir.mkdir(parents=True)
    parquet_file = table_dir / "f1.parquet"
    df1.to_parquet(parquet_file, index=False)

    # Only "c2" is requested, so only that column should come back.
    expected = df1[["c2"]]
    actual = read_partitioned_table(
        f"file://{table_dir}/",
        Q_TRUE,
        data_reader=PandasReader(columns=["c2"]),
    )
    assert_frame_equal(expected, actual)

0 comments on commit f5077ec

Please sign in to comment.