[WIP] Auto infer schema (including fields shape) from the first row #512
base: master
Changes from 4 commits
@@ -28,6 +28,7 @@
 from petastorm import make_batch_reader
+from petastorm.fs_utils import FilesystemResolver
 from petastorm.transform import TransformSpec

 DEFAULT_ROW_GROUP_SIZE_BYTES = 32 * 1024 * 1024
@@ -179,6 +180,7 @@ def make_tf_dataset(
     prefetch=None,
     num_epochs=None,
     workers_count=None,
+    preprocess_fn=None,
     **petastorm_reader_kwargs
 ):
     """
@@ -188,6 +190,12 @@ def make_tf_dataset(
     1) Open a petastorm reader on the materialized dataset dir.
     2) Create a tensorflow dataset based on the reader created in (1)

+    Each element of the generated dataset is a batch of namedtuples.
+    If `preprocess_fn` is not specified, each namedtuple matches the schema of the
+    original spark dataframe columns; otherwise it matches the columns of the pandas
+    dataframe returned by `preprocess_fn`. The field order follows the columns of the
+    original spark dataframe or of the `preprocess_fn` output, respectively.
+
     :param batch_size: The number of items to return per batch. Default None.
         If None, the current implementation will set the batch size to 32; in the
         future, None will denote an auto-tuned batch size.
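As a rough illustration of the field-ordering rule described in the new docstring, here is a stdlib-only sketch. Plain dicts stand in for pandas DataFrames, and `apply_preprocess` / `my_preprocess` are hypothetical names for this example, not part of the PR:

```python
from collections import namedtuple

def apply_preprocess(row_group, preprocess_fn=None):
    """Apply the optional preprocess_fn, then freeze the column order.

    The namedtuple's field order comes from the output columns, mirroring
    how the result dataset's element fields follow preprocess_fn's output.
    """
    out = preprocess_fn(row_group) if preprocess_fn else row_group
    Batch = namedtuple('Batch', list(out.keys()))  # field order = column order
    return Batch(**out)

# A preprocess_fn that renames and derives columns; its output column order
# determines the fields of each batch in the resulting dataset.
def my_preprocess(df):
    return {'label': df['y'], 'feature': [x * 2 for x in df['x']]}

batch = apply_preprocess({'x': [1, 2], 'y': [0, 1]}, my_preprocess)
print(batch._fields)  # ('label', 'feature')
```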
@@ -201,8 +209,14 @@
         None denotes the auto-tuned best value (the current implementation always
         uses 4 workers when auto-tuning, but this may be improved in the future).
         Default value None.
+    :param preprocess_fn: Preprocessing function. Its input is a pandas dataframe
+        holding one rowgroup's data, and its output should be the transformed pandas
+        dataframe. The column order of the input pandas dataframe is undefined, but
+        the column order of the output pandas dataframe determines the field order
+        of the resulting tensorflow dataset's elements.
     :param petastorm_reader_kwargs: arguments for `petastorm.make_batch_reader()`,
-        exclude these arguments: "dataset_url", "num_epochs", "workers_count".
+        excluding these arguments: "dataset_url_or_urls", "num_epochs", "workers_count",
+        "transform_spec", "infer_schema_from_first_row".

     :return: a context manager for a `tf.data.Dataset` object.
         when exit the returned context manager, the reader
@@ -216,6 +230,18 @@
         workers_count = 4
     petastorm_reader_kwargs['workers_count'] = workers_count

+    if 'dataset_url_or_urls' in petastorm_reader_kwargs:
+        raise ValueError('User cannot set dataset_url_or_urls argument.')
+
+    if 'transform_spec' in petastorm_reader_kwargs or \
+            'infer_schema_from_first_row' in petastorm_reader_kwargs:
+        raise ValueError('User cannot set transform_spec and infer_schema_from_first_row '
+                         'arguments, use `preprocess_fn` argument instead.')
+
+    petastorm_reader_kwargs['infer_schema_from_first_row'] = True
+    if preprocess_fn:
+        petastorm_reader_kwargs['transform_spec'] = TransformSpec(preprocess_fn)

     hvd_rank, hvd_size = _get_horovod_rank_and_size()
     cur_shard = petastorm_reader_kwargs.get('cur_shard')
     shard_count = petastorm_reader_kwargs.get('shard_count')

Review comments on this hunk (truncated in the source):
- Shall we also allow users to use …
- I think the param …
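The validation-and-injection logic in this hunk can be exercised standalone. The sketch below mirrors the diff's checks; `TransformSpec` here is a minimal stand-in for `petastorm.transform.TransformSpec`, and `build_reader_kwargs` is a hypothetical helper name, not a function in the PR:

```python
class TransformSpec:
    """Stand-in for petastorm.transform.TransformSpec: wraps a row-group function."""
    def __init__(self, func):
        self.func = func

def build_reader_kwargs(preprocess_fn=None, **petastorm_reader_kwargs):
    # Reject reader kwargs that the wrapper manages itself...
    if 'dataset_url_or_urls' in petastorm_reader_kwargs:
        raise ValueError('User cannot set dataset_url_or_urls argument.')
    if ('transform_spec' in petastorm_reader_kwargs or
            'infer_schema_from_first_row' in petastorm_reader_kwargs):
        raise ValueError('User cannot set transform_spec and '
                         'infer_schema_from_first_row arguments, '
                         'use `preprocess_fn` argument instead.')
    # ...then inject its own values, as the diff does.
    petastorm_reader_kwargs['infer_schema_from_first_row'] = True
    if preprocess_fn:
        petastorm_reader_kwargs['transform_spec'] = TransformSpec(preprocess_fn)
    return petastorm_reader_kwargs

kwargs = build_reader_kwargs(preprocess_fn=lambda df: df, workers_count=4)
print(sorted(kwargs))  # ['infer_schema_from_first_row', 'transform_spec', 'workers_count']
```

The point of the pattern is that `transform_spec` and `infer_schema_from_first_row` stay an internal detail: users express preprocessing only through `preprocess_fn`.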
Review comment:
nit: I'm not sure whether the partition[0] necessarily contains the "first" row? Could the partitions be out of order? If so, we may call it `infer_schema_from_a_row`.

Reply:
Here I read the first row in the index-0 row-groups. But index-0 row-groups may be non-deterministic? Not sure. `infer_schema_from_a_row` sounds good.
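The discussion above is about *which* row the sample comes from; independent of that, the inference idea itself (record each field's type and shape from one sample row) can be sketched. `infer_field_schema` and `infer_schema_from_row` are illustrative names, not petastorm's actual internals, and nested Python lists stand in for array-valued fields:

```python
def infer_field_schema(value):
    """Return (dtype_name, shape) for a scalar or (nested) list field."""
    shape = ()
    while isinstance(value, list):
        shape += (len(value),)
        value = value[0]  # descend one nesting level per dimension
    return type(value).__name__, shape

def infer_schema_from_row(row):
    """Infer a {field: (dtype, shape)} schema from a single sample row."""
    return {name: infer_field_schema(value) for name, value in row.items()}

first_row = {'id': 7, 'embedding': [0.1, 0.2, 0.3], 'image': [[0, 1], [2, 3]]}
print(infer_schema_from_row(first_row))
# {'id': ('int', ()), 'embedding': ('float', (3,)), 'image': ('int', (2, 2))}
```

Note the caveat raised in the review applies here too: the inferred shapes are only as representative as the sampled row, so variable-length fields would need explicit handling.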