Skip to content

Commit

Permalink
Add tests: without tranform_spec and with list of strings with some, …
Browse files Browse the repository at this point in the history
…all elements being None.

Change the type of numpy dtype to np.object instead of np.unicode_
  • Loading branch information
Yevgeni Litvin committed Sep 14, 2022
1 parent 9b2bb69 commit 1fcf22f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
2 changes: 1 addition & 1 deletion petastorm/arrow_reader_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def read_next(self, workers_pool, schema, ngram):
column_as_numpy = column_as_pandas

if pa.types.is_string(column.type):
result_dict[column_name] = column_as_numpy.astype(np.unicode_)
result_dict[column_name] = column_as_numpy.astype(np.object)
elif pa.types.is_list(column.type) or pa.types.is_fixed_size_list(column.type):
# Assuming all lists are of the same length, hence we can collate them into a matrix
list_of_lists = column_as_numpy
Expand Down
33 changes: 33 additions & 0 deletions petastorm/tests/test_parquet_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from pyarrow import parquet as pq

Expand Down Expand Up @@ -223,6 +224,38 @@ def fill_id_with_nones(x):
assert sample.id.dtype.type == null_column_dtype


@pytest.mark.parametrize('np_dtype, pa_dtype, null_value',
((np.float32, pa.float32(), np.nan), (np.object, pa.string(), None)))
@pytest.mark.parametrize('reader_factory', _D)
def test_entire_column_of_typed_nulls(reader_factory, np_dtype, pa_dtype, null_value, tmp_path):
path = tmp_path / "dataset"
schema = pa.schema([pa.field('all_nulls', pa_dtype)])
pq.write_table(pa.Table.from_pydict({"all_nulls": [null_value] * 10}, schema=schema), path)

with reader_factory("file:///" + str(path)) as reader:
sample = next(reader)
assert sample.all_nulls.dtype == np_dtype
if np_dtype == np.float32:
assert np.all(np.isnan(sample.all_nulls))
elif np_dtype == np.object:
assert all(v is None for v in sample.all_nulls)
else:
assert False, "Unexpected np_dtype"


@pytest.mark.parametrize('reader_factory', _D)
def test_column_with_list_of_strings_some_are_null(reader_factory, tmp_path):
path = tmp_path / "dataset"
schema = pa.schema([pa.field('some_nulls', pa.list_(pa.string(), -1))])
pq.write_table(pa.Table.from_pydict({"some_nulls": [['a0', 'a1'], ['b0', None], [None, None]]}, schema=schema),
path)

with reader_factory("file:///" + str(path)) as reader:
sample = next(reader)
assert sample.some_nulls.dtype == np.object
np.testing.assert_equal(sample.some_nulls, [['a0', 'a1'], ['b0', None], [None, None]])


@pytest.mark.parametrize('reader_factory', _D)
def test_transform_spec_returns_all_none_values_in_a_list_field(scalar_dataset, reader_factory):
def fill_id_with_nones(x):
Expand Down

0 comments on commit 1fcf22f

Please sign in to comment.