Parallelize encoding of a single row
When writing data into a petastorm dataset, before a pyspark sql.Row
object is created, fields containing data that is not natively supported
by the Parquet format, such as numpy arrays, are serialized into byte
arrays. Images may be compressed using png or jpeg compression.

Serializing fields on a thread pool speeds up this process in some
cases (e.g. when a row contains multiple images).
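
The snippet below is a minimal standalone sketch of that idea, not petastorm's implementation: a stand-in codec (encode_ndarray) serializes numpy arrays, and a ThreadPoolExecutor encodes the heavyweight fields of a single row concurrently. The helper and field names are made up for the example.

# Sketch only: encode the heavyweight fields of one row on a thread pool.
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import numpy as np


def encode_ndarray(value):
    # Stand-in for a petastorm codec (e.g. NdarrayCodec): turn a numpy array
    # into bytes that the Parquet format can store natively.
    buffer = BytesIO()
    np.save(buffer, value)
    return bytearray(buffer.getvalue())


def encode_row_sketch(row_dict, threads_count=10):
    # Fields that need serialization are submitted to the pool; everything
    # else is passed through unchanged.
    with ThreadPoolExecutor(threads_count) as executor:
        futures = {name: executor.submit(encode_ndarray, value)
                   for name, value in row_dict.items() if isinstance(value, np.ndarray)}
        encoded = {name: value for name, value in row_dict.items()
                   if not isinstance(value, np.ndarray)}
        encoded.update({name: future.result() for name, future in futures.items()})
    return encoded


row = {'id': 'sample-0', 'image': np.zeros((480, 640, 3), dtype=np.uint8)}
print(sorted(encode_row_sketch(row).keys()))
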
Yevgeni Litvin committed Apr 20, 2020
1 parent 83a02df commit 367736c
Showing 2 changed files with 68 additions and 14 deletions.
23 changes: 22 additions & 1 deletion petastorm/tests/test_unischema.py
@@ -25,7 +25,7 @@

from petastorm.codecs import ScalarCodec, NdarrayCodec
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row, \
    insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch
    insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch, encode_row

try:
from unittest import mock
@@ -107,6 +107,27 @@ def test_as_spark_schema_unspecified_codec_type_unknown_scalar_type_raises():
        TestSchema.as_spark_schema()


def test_encode_row():
"""Test various validations done on data types when converting a dictionary to a spark row"""
TestSchema = Unischema('TestSchema', [
UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
UnischemaField('int8_matrix', np.int8, (2, 2), NdarrayCodec(), False),
])

row = {'string_field': 'abc', 'int8_matrix': np.asarray([[1, 2], [3, 4]], dtype=np.int8)}
encoded_row = encode_row(TestSchema, row)
assert set(row.keys()) == set(encoded_row)
assert isinstance(encoded_row['int8_matrix'], bytearray)

extra_field_row = {'string_field': 'abc', 'int8_matrix': [[1, 2], [3, 4]], 'bogus': 'value'}
with pytest.raises(ValueError, match='.*not found.*bogus.*'):
encode_row(TestSchema, extra_field_row)

extra_field_row = {'string_field': 'abc'}
with pytest.raises(ValueError, match='int8_matrix is not found'):
encode_row(TestSchema, extra_field_row)


def test_dict_to_spark_row_field_validation_scalar_types():
"""Test various validations done on data types when converting a dictionary to a spark row"""
TestSchema = Unischema('TestSchema', [
59 changes: 46 additions & 13 deletions petastorm/unischema.py
@@ -20,6 +20,7 @@
import sys
import warnings
from collections import namedtuple, OrderedDict
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal

import numpy as np
@@ -369,38 +370,70 @@ def dict_to_spark_row(unischema, row_dict):
    # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
    # (currently works only with make_batch_reader)
    import pyspark
    encoded_dict = encode_row(unischema, row_dict)

    field_list = list(unischema.fields.keys())
    # generate a value list which match the schema column order.
    value_list = [encoded_dict[name] for name in field_list]
    # create a row by value list
    row = pyspark.Row(*value_list)
    # set row fields
    row.__fields__ = field_list

    return row


def encode_row(unischema, row_dict, threads_count=10):
    """Verifies that the data conforms to unischema definition types and encodes the data using the codec specified
    by the unischema.

    :param unischema: an instance of Unischema object
    :param row_dict: a dictionary where the keys match name of fields in the unischema.
    :param threads_count: number of threads in the pool used to encode fields of a single row concurrently
    :return: a dictionary of encoded fields
    """

    assert isinstance(unischema, Unischema)
    # Add null fields. Be careful not to mutate the input dictionary - that would be an unexpected side effect
    copy_row_dict = copy.copy(row_dict)
    insert_explicit_nulls(unischema, copy_row_dict)

    if set(copy_row_dict.keys()) != set(unischema.fields.keys()):
        raise ValueError('Dictionary fields \n{}\n do not match schema fields \n{}'.format(
            '\n'.join(sorted(copy_row_dict.keys())), '\n'.join(unischema.fields.keys())))
    input_field_names = set(copy_row_dict.keys())
    unischema_field_names = set(unischema.fields.keys())

    encoded_dict = {}
    unknown_field_names = input_field_names - unischema_field_names

    if unknown_field_names:
        raise ValueError('Following fields of row_dict are not found in '
                         'unischema: {}'.format(', '.join(sorted(unknown_field_names))))

    executor = ThreadPoolExecutor(threads_count)
    encoded_dict = dict()
    futures_dict = dict()
    for field_name, value in copy_row_dict.items():
        schema_field = unischema.fields[field_name]
        if value is None:
            if not schema_field.nullable:
                raise ValueError('Field {} is not "nullable", but got passed a None value'.format(field_name))
        if schema_field.codec:
            encoded_dict[field_name] = schema_field.codec.encode(schema_field, value) if value is not None else None
            if value is None:
                encoded_dict[field_name] = None
            else:
                futures_dict[field_name] = executor.submit(
                    lambda _schema_field, _value: _schema_field.codec.encode(_schema_field, _value), schema_field,
                    value)
        else:
            if isinstance(value, (np.generic,)):
                encoded_dict[field_name] = value.tolist()
            else:
                encoded_dict[field_name] = value

    field_list = list(unischema.fields.keys())
    # generate a value list which match the schema column order.
    value_list = [encoded_dict[name] for name in field_list]
    # create a row by value list
    row = pyspark.Row(*value_list)
    # set row fields
    row.__fields__ = field_list
    return row
    for k, v in futures_dict.items():
        encoded_dict[k] = v.result()

    return encoded_dict


def insert_explicit_nulls(unischema, row_dict):
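
For context, the snippet below is a hedged usage sketch (modeled on petastorm's documented write path, not part of this commit) showing where dict_to_spark_row, and therefore the new encode_row, sits when a dataset is materialized. The ExampleSchema, output path, Spark session settings and row-group size are assumptions for illustration.

# Usage sketch (assumed setup, not part of this commit): pyspark and petastorm
# are installed; the schema, output path and row-group size are illustrative.
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

from petastorm.codecs import NdarrayCodec, ScalarCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row

ExampleSchema = Unischema('ExampleSchema', [
    UnischemaField('id', np.string_, (), ScalarCodec(StringType()), False),
    UnischemaField('image', np.uint8, (8, 8), NdarrayCodec(), False),
])

output_url = 'file:///tmp/example_petastorm_dataset'
spark = SparkSession.builder.master('local[2]').getOrCreate()
row_dicts = [{'id': str(i), 'image': np.zeros((8, 8), dtype=np.uint8)} for i in range(10)]

# dict_to_spark_row (which now delegates per-field encoding to encode_row)
# turns each plain dictionary into an encoded pyspark Row before writing.
with materialize_dataset(spark, output_url, ExampleSchema, 64):
    rows_rdd = spark.sparkContext.parallelize(row_dicts) \
        .map(lambda row_dict: dict_to_spark_row(ExampleSchema, row_dict))
    spark.createDataFrame(rows_rdd, ExampleSchema.as_spark_schema()) \
        .write.mode('overwrite').parquet(output_url)
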
