diff --git a/fsql/deser.py b/fsql/deser.py
index e1742b6..35ad8dd 100644
--- a/fsql/deser.py
+++ b/fsql/deser.py
@@ -17,14 +17,18 @@
 opt for the lazy approach (such as in Dask), and don't materialize inside neither the `read_single` nor
 `concat` methods.
 
+Existing readers such as PandasReader allow customisation by passing arbitrary kwargs through to the
+underlying pandas read method.
+
 The user should *not* bake in any specific business logic in here -- a more prefered approach is to
-return an object such as data frame as early as possible, and apply any transformations later on.
+return an object such as (lazy) data frame as early as possible, and apply any transformations later on.
 """
 from __future__ import annotations
 
 import json
 import logging
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Iterable
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto, unique
@@ -87,31 +91,51 @@ def read_and_concat(
 
 
 class PandasReader(DataReader):
+    """Wraps various pandas read methods (parquet, json, csv, excel) into a single interface.
+    Behaviour can be customised by passing arbitrary kwargs to the constructor.
+    """
+
+    def __init__(self, input_format: InputFormat = InputFormat.AUTO, **pdread_kwargs):
+        super().__init__(input_format=input_format)
+        self.pdread_user_kwargs = pdread_kwargs
+        self.pdread_default_kwargs = defaultdict(dict)
+        self.pdread_default_kwargs[InputFormat.PARQUET] = {
+            "engine": "fastparquet",
+        }
+        self.pdread_default_kwargs[InputFormat.JSON] = {
+            "lines": True,
+        }
+        self.pdread_default_kwargs[InputFormat.XLSX] = {
+            "engine": "openpyxl",
+        }
+
     def read_single(self, partition: Partition, fs: AbstractFileSystem) -> pd.DataFrame:
         logger.debug(f"read dataframe for partition {partition}")
         input_format = self.detect_format(partition.url)
-        # TODO allow for user spec of engine and other params, essentially any quark
+        logger.debug(f"detected format {input_format} for partition {partition}")
         if input_format is InputFormat.PARQUET:
-            reader = lambda fd: pd.read_parquet(fd, engine="fastparquet")  # noqa: E731
+            reader = pd.read_parquet
         elif input_format is InputFormat.JSON:
-            reader = lambda fd: pd.read_json(fd, lines=True)  # noqa: E731
+            reader = pd.read_json
         elif input_format is InputFormat.CSV:
             reader = pd.read_csv
         elif input_format is InputFormat.XLSX:
-            reader = lambda fd: pd.read_excel(fd, engine="openpyxl")  # noqa: E731
+            reader = pd.read_excel
         elif input_format is InputFormat.AUTO:
             raise ValueError(f"partition had format detected as auto -> invalid state. Partition: {partition}")
         else:
             assert_exhaustive_enum(input_format)
 
+        pdread_kwargs = {**self.pdread_default_kwargs[input_format], **self.pdread_user_kwargs}
+        logger.debug(f"reader kwargs {pdread_kwargs} for partition {partition}")
         try:
             with fs.open(partition.url, "rb") as fd:
-                df = reader(fd)
+                df = reader(fd, **pdread_kwargs)
         except FileNotFoundError as e:
             logger.warning(f"file {partition} reading exception {type(e)}, attempting cache invalidation and reread")
             fs.invalidate_cache()
             with fs.open(partition.url, "rb") as fd:
-                df = reader(fd)
+                df = reader(fd, **pdread_kwargs)
 
         for key, value in partition.columns.items():
             df[key] = value
diff --git a/tests/test_pandasreader.py b/tests/test_pandasreader.py
new file mode 100644
index 0000000..d1f0f68
--- /dev/null
+++ b/tests/test_pandasreader.py
@@ -0,0 +1,37 @@
+import pandas as pd
+import pytest
+from pandas.testing import assert_frame_equal
+
+from fsql.api import read_partitioned_table
+from fsql.deser import InputFormat, PandasReader
+from fsql.query import Q_TRUE
+
+df1 = pd.DataFrame(data={"c1": [0, 1], "c2": ["hello", "world"]})
+
+
+def test_input_format_override(tmp_path):
+    """Test that explicitly setting the format overrides the file suffix."""
+
+    case1_path = tmp_path / "table1"
+    case1_path.mkdir(parents=True)
+    df1.to_csv(case1_path / "f1.json", index=False)  # confuse the default by a bad suffix
+
+    with pytest.raises(ValueError, match="Expected object or value"):
+        # this match condition is quite brittle! A better one would be desirable
+        read_partitioned_table(f"file://{case1_path}/", Q_TRUE)
+
+    reader = PandasReader(input_format=InputFormat.CSV)
+    succ_result = read_partitioned_table(f"file://{case1_path}/", Q_TRUE, data_reader=reader)
+    assert_frame_equal(df1, succ_result)
+
+
+def test_parquet_kwargs(tmp_path):
+    """Test that a kwarg (`columns`) gets passed through and obeyed."""
+
+    case1_path = tmp_path / "table1"
+    case1_path.mkdir(parents=True)
+    df1.to_parquet(case1_path / "f1.parquet", index=False)
+
+    reader = PandasReader(columns=["c2"])
+    result = read_partitioned_table(f"file://{case1_path}/", Q_TRUE, data_reader=reader)
+    assert_frame_equal(df1[["c2"]], result)