Parallelize encoding of a single row #546

Draft · wants to merge 1 commit into base: master
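For context, a minimal sketch of how the pool_executor argument added in this diff could be used from writer code. The schema, field names, and executor size below are illustrative and not part of this PR:

from concurrent.futures import ThreadPoolExecutor

import numpy as np
from pyspark.sql.types import StringType

from petastorm.codecs import NdarrayCodec, ScalarCodec
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row

# Illustrative schema: one scalar field and one larger ndarray field.
ExampleSchema = Unischema('ExampleSchema', [
    UnischemaField('id', np.string_, (), ScalarCodec(StringType()), False),
    UnischemaField('image', np.uint8, (128, 128, 3), NdarrayCodec(), False),
])

row_dict = {'id': 'sample-0', 'image': np.zeros((128, 128, 3), dtype=np.uint8)}

# With an executor, the per-field codec.encode calls are submitted concurrently;
# passing None (the default) keeps the original serial behaviour.
with ThreadPoolExecutor(max_workers=4) as executor:
    spark_row = dict_to_spark_row(ExampleSchema, row_dict, pool_executor=executor)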
26 changes: 25 additions & 1 deletion petastorm/tests/test_unischema.py
@@ -25,7 +25,9 @@

from petastorm.codecs import ScalarCodec, NdarrayCodec
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row, \
insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch
insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch, encode_row

from concurrent.futures import ThreadPoolExecutor

try:
from unittest import mock
@@ -107,6 +109,28 @@ def test_as_spark_schema_unspecified_codec_type_unknown_scalar_type_raises():
TestSchema.as_spark_schema()


@pytest.mark.parametrize("pool_executor", [None, ThreadPoolExecutor(2)])
def test_encode_row(pool_executor):
"""Test various validations done on data types when converting a dictionary to a spark row"""
TestSchema = Unischema('TestSchema', [
UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
UnischemaField('int8_matrix', np.int8, (2, 2), NdarrayCodec(), False),
])

row = {'string_field': 'abc', 'int8_matrix': np.asarray([[1, 2], [3, 4]], dtype=np.int8)}
encoded_row = encode_row(TestSchema, row, pool_executor)
assert set(row.keys()) == set(encoded_row)
assert isinstance(encoded_row['int8_matrix'], bytearray)

extra_field_row = {'string_field': 'abc', 'int8_matrix': [[1, 2], [3, 4]], 'bogus': 'value'}
with pytest.raises(ValueError, match='.*not found.*bogus.*'):
encode_row(TestSchema, extra_field_row, pool_executor)

extra_field_row = {'string_field': 'abc'}
with pytest.raises(ValueError, match='int8_matrix is not found'):
encode_row(TestSchema, extra_field_row, pool_executor)


def test_dict_to_spark_row_field_validation_scalar_types():
"""Test various validations done on data types when converting a dictionary to a spark row"""
TestSchema = Unischema('TestSchema', [
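A possible follow-up check, not part of this diff: assert that the executor path and the serial path produce identical encodings. The test name and body below are hypothetical and reuse the imports already present in test_unischema.py:

def test_encode_row_executor_matches_serial():
    """Hypothetical parity check: parallel encoding should equal serial encoding."""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int8_matrix', np.int8, (2, 2), NdarrayCodec(), False),
    ])
    row = {'int8_matrix': np.asarray([[1, 2], [3, 4]], dtype=np.int8)}
    serial = encode_row(TestSchema, row, None)
    with ThreadPoolExecutor(2) as executor:
        parallel = encode_row(TestSchema, row, executor)
    # NdarrayCodec encoding is deterministic, so the encoded bytearrays should match exactly.
    assert serial == parallel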
62 changes: 48 additions & 14 deletions petastorm/unischema.py
@@ -353,7 +353,7 @@ def from_arrow_schema(cls, parquet_dataset, omit_unsupported_fields=False):
return Unischema('inferred_schema', unischema_fields)


def dict_to_spark_row(unischema, row_dict):
def dict_to_spark_row(unischema, row_dict, pool_executor=None):
"""Converts a single row into a spark Row object.

Verifies that the data conforms to the unischema definition types and encodes the data using the codec specified
@@ -363,44 +363,78 @@ def dict_to_spark_row(unischema, row_dict):

:param unischema: an instance of Unischema object
:param row_dict: a dictionary where the keys match name of fields in the unischema.
:param pool_executor: if not None, field encoding is submitted to this executor and performed concurrently
:return: a single pyspark.Row object
"""

# Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
# (currently works only with make_batch_reader)
import pyspark
encoded_dict = encode_row(unischema, row_dict, pool_executor)

field_list = list(unischema.fields.keys())
# Generate a value list that matches the schema column order.
value_list = [encoded_dict[name] for name in field_list]
# create a row by value list
row = pyspark.Row(*value_list)
# set row fields
row.__fields__ = field_list

return row


def encode_row(unischema, row_dict, pool_executor=None):
"""Verifies that the data confirms with unischema definition types and encodes the data using the codec specified
by the unischema.

:param unischema: an instance of Unischema object
:param row_dict: a dictionary where the keys match name of fields in the unischema.
:param pool_executor: if not None, field encoding is submitted to this executor and performed concurrently
:return: a dictionary of encoded fields
"""

assert isinstance(unischema, Unischema)
# Add null fields. Be careful not to mutate the input dictionary - that would be an unexpected side effect
copy_row_dict = copy.copy(row_dict)
insert_explicit_nulls(unischema, copy_row_dict)

if set(copy_row_dict.keys()) != set(unischema.fields.keys()):
raise ValueError('Dictionary fields \n{}\n do not match schema fields \n{}'.format(
'\n'.join(sorted(copy_row_dict.keys())), '\n'.join(unischema.fields.keys())))
input_field_names = set(copy_row_dict.keys())
unischema_field_names = set(unischema.fields.keys())

unknown_field_names = input_field_names - unischema_field_names

encoded_dict = {}
if unknown_field_names:
raise ValueError('The following fields of row_dict are not found in '
'unischema: {}'.format(', '.join(sorted(unknown_field_names))))

encoded_dict = dict()
futures_dict = dict()
for field_name, value in copy_row_dict.items():
schema_field = unischema.fields[field_name]
if value is None:
if not schema_field.nullable:
raise ValueError('Field {} is not "nullable", but was passed a None value'.format(field_name))
if schema_field.codec:
encoded_dict[field_name] = schema_field.codec.encode(schema_field, value) if value is not None else None
if value is None:
encoded_dict[field_name] = None
else:
if pool_executor:
futures_dict[field_name] = pool_executor.submit(schema_field.codec.encode, schema_field, value)
else:
encoded_dict[field_name] = schema_field.codec.encode(schema_field, value)
else:
if isinstance(value, (np.generic,)):
encoded_dict[field_name] = value.tolist()
else:
encoded_dict[field_name] = value

field_list = list(unischema.fields.keys())
# generate a value list which match the schema column order.
value_list = [encoded_dict[name] for name in field_list]
# create a row by value list
row = pyspark.Row(*value_list)
# set row fields
row.__fields__ = field_list
return row
for k, v in futures_dict.items():
encoded_dict[k] = v.result()

return encoded_dict


def insert_explicit_nulls(unischema, row_dict):
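The change follows a submit-then-gather pattern: every field's codec.encode call is submitted to the executor before any result is awaited, so independent encodes can overlap to the extent the codec releases the GIL (compression-heavy codecs typically do). A reduced, self-contained sketch of the same pattern, where encode_field is only a stand-in for schema_field.codec.encode:

from concurrent.futures import ThreadPoolExecutor

import numpy as np


def encode_field(value):
    # Stand-in for schema_field.codec.encode; real codecs serialize or compress the value.
    return bytearray(np.asarray(value).tobytes())


fields = {'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}

with ThreadPoolExecutor(max_workers=4) as executor:
    # Submit every field before waiting on any of them so the encodes run concurrently.
    futures = {name: executor.submit(encode_field, value) for name, value in fields.items()}
    # Gather: .result() blocks until the corresponding encode has finished.
    encoded = {name: future.result() for name, future in futures.items()}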