From 2295490dfcf248736d50f484c4b5c05b0c33ad59 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 23 Mar 2020 11:37:33 +0800 Subject: [PATCH 1/6] init --- petastorm/reader.py | 76 +++++++++++++++++++--- petastorm/spark/spark_dataset_converter.py | 14 +++- petastorm/tf_utils.py | 1 + petastorm/unischema.py | 6 +- 4 files changed, 84 insertions(+), 13 deletions(-) diff --git a/petastorm/reader.py b/petastorm/reader.py index 35b854e9e..0febcc6d0 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -16,8 +16,10 @@ import logging import warnings +import numpy as np import six from pyarrow import parquet as pq +from pyarrow.parquet import ParquetFile from petastorm.arrow_reader_worker import ArrowReaderWorker from petastorm.cache import NullCache @@ -35,6 +37,7 @@ from petastorm.reader_impl.pyarrow_serializer import PyArrowSerializer from petastorm.selectors import RowGroupSelectorBase from petastorm.transform import transform_schema +from petastorm.unischema import Unischema, UnischemaField from petastorm.workers_pool.dummy_pool import DummyPool from petastorm.workers_pool.process_pool import ProcessPool from petastorm.workers_pool.thread_pool import ThreadPool @@ -170,6 +173,7 @@ def make_reader(dataset_url, 'shard_count': shard_count, 'cache': cache, 'transform_spec': transform_spec, + 'infer_schema_from_first_row': False, } try: @@ -196,7 +200,8 @@ def make_batch_reader(dataset_url_or_urls, cache_type='null', cache_location=None, cache_size_limit=None, cache_row_size_estimate=None, cache_extra_settings=None, hdfs_driver='libhdfs3', - transform_spec=None): + transform_spec=None, + infer_schema_from_first_row=False): """ Creates an instance of Reader for reading batches out of a non-Petastorm Parquet store. @@ -245,6 +250,14 @@ def make_batch_reader(dataset_url_or_urls, :param transform_spec: An instance of :class:`~petastorm.transform.TransformSpec` object defining how a record is transformed after it is loaded and decoded. The transformation occurs on a worker thread/process (depends on the ``reader_pool_type`` value). + :param infer_schema_from_first_row: Whether to infer schema from the row data. Only support parquet reader. + If on, before creating the reader, it will first read one row group to infer the full schema information, + and the transform spec (if exists) do not need to specify edit_fields/removed_fields. + Require: for all rows (before applying predicates), all values in each field are non-nullable and have + the same shape. + Turning on this param will address the following two issues: + 1) Auto inferring parquet schema from metadata cannot get shape information. + 2) If there's a preprocessing function, we have to specify edit/removed fields. :return: A :class:`Reader` object """ dataset_url_or_urls = normalize_dataset_url_or_urls(dataset_url_or_urls) @@ -291,7 +304,8 @@ def make_batch_reader(dataset_url_or_urls, shard_count=shard_count, cache=cache, transform_spec=transform_spec, - is_batched_reader=True) + is_batched_reader=True, + infer_schema_from_first_row=infer_schema_from_first_row) class Reader(object): @@ -304,7 +318,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, shuffle_row_groups=True, shuffle_row_drop_partitions=1, predicate=None, rowgroup_selector=None, reader_pool=None, num_epochs=1, cur_shard=None, shard_count=None, cache=None, worker_class=None, - transform_spec=None, is_batched_reader=False): + transform_spec=None, is_batched_reader=False, infer_schema_from_first_row=False): """Initializes a reader object. 
:param pyarrow_filesystem: An instance of ``pyarrow.FileSystem`` that will be used. If not specified, @@ -340,13 +354,21 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, to the main data store is either slow or expensive and the local machine has large enough storage to store entire dataset (or a partition of a dataset if shards are used). By default, use the :class:`.NullCache` implementation. + :param infer_schema_from_first_row: Whether to infer schema from the row data. Only support parquet reader. + If on, before creating the reader, it will first read one row group to infer the full schema information, + and the transform spec (if exists) do not need to specify edit_fields/removed_fields. + Require: for all rows (before applying predicates), all values in each field are non-nullable and have + the same shape. + Turning on this param will address the following two issues: + 1) Auto inferring parquet schema from metadata cannot get shape information. + 2) If there's a preprocessing function, we have to specify edit/removed fields. :param worker_class: This is the class that will be instantiated on a different thread/process. It's responsibility is to load and filter the data. """ # 1. Open the parquet storage (dataset) - # 2. Get a list of all groups + # 2. Get a list of all groups and infer schema if needed. # 3. Filter rowgroups # a. predicates # b. row-group selector (our indexing mechanism) @@ -396,14 +418,22 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, fields = schema_fields if isinstance(schema_fields, collections.Iterable) else None storage_schema = stored_schema.create_schema_view(fields) if fields else stored_schema - if transform_spec: - self.schema = transform_schema(storage_schema, transform_spec) - else: - self.schema = storage_schema - # 2. Get a list of all row groups + # 2. Get a list of all row groups and infer schema if needed. row_groups = dataset_metadata.load_row_groups(self.dataset) + if infer_schema_from_first_row: + if worker_class is not ArrowReaderWorker: + raise ValueError('infer_schema_from_first_row only support ArrowReaderWorker.') + worker0 = ArrowReaderWorker(0, None, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram, + row_groups, NullCache(), None, None)) + self.schema = self._infer_schema_from_first_row(worker0, transform_spec) + else: + if transform_spec: + self.schema = transform_schema(storage_schema, transform_spec) + else: + self.schema = storage_schema + # 3. 
Filter rowgroups filtered_row_group_indexes, worker_predicate = self._filter_row_groups(self.dataset, row_groups, predicate, rowgroup_selector, cur_shard, @@ -454,6 +484,34 @@ def reset(self): def batched_output(self): return self._results_queue_reader.batched_output + @staticmethod + def _infer_schema_from_first_row(worker0, transform_spec): + piece0 = worker0._split_pieces[0] + pq_file0 = ParquetFile(worker0._dataset.fs.open(piece0.path)) + piece0_pdf = worker0._load_rows(pq_file0, piece0, (0, 1)) + + row0_pdf = piece0_pdf.head(n=1) + if transform_spec: + row0_pdf = transform_spec.func(row0_pdf) + + unischema_fields = [] + for field_name in list(row0_pdf.columns): + field_val = row0_pdf[field_name][0] + + if np.ndim(field_val) == 0: + # scalar value + field_shape = () + field_numpy_type = np.dtype(type(field_val)).type + else: + field_shape = field_val.shape + field_numpy_type = field_val.dtype.type + + # TODO: add type checking,raise error for illegal type (such as np.object_) + unischema_field = UnischemaField(field_name, field_numpy_type, field_shape, None, False) + unischema_fields.append(unischema_field) + + return Unischema('inferred_schema', unischema_fields) + def _filter_row_groups(self, dataset, row_groups, predicate, rowgroup_selector, cur_shard, shard_count): """Calculates which rowgroups will be read during. diff --git a/petastorm/spark/spark_dataset_converter.py b/petastorm/spark/spark_dataset_converter.py index 92f641860..bdc29305a 100644 --- a/petastorm/spark/spark_dataset_converter.py +++ b/petastorm/spark/spark_dataset_converter.py @@ -179,6 +179,7 @@ def make_tf_dataset( prefetch=None, num_epochs=None, workers_count=None, + preprocess_fn=None, **petastorm_reader_kwargs ): """ @@ -201,8 +202,11 @@ def make_tf_dataset( None denotes auto tune best value (current implementation when auto tune, it will always use 4 workers, but it may be improved in future) Default value None. + :param preprocess_fn: Preprocessing function. Input is pandas dataframe of + a rowgroup data and output should be the transformed pandas dataframe. :param petastorm_reader_kwargs: arguments for `petastorm.make_batch_reader()`, - exclude these arguments: "dataset_url", "num_epochs", "workers_count". + exclude these arguments: "dataset_url_or_urls", "num_epochs", "workers_count", + "transform_spec", "infer_schema_from_first_row" :return: a context manager for a `tf.data.Dataset` object. 
when exit the returned context manager, the reader @@ -216,6 +220,14 @@ def make_tf_dataset( workers_count = 4 petastorm_reader_kwargs['workers_count'] = workers_count + if 'dataset_url_or_urls' in petastorm_reader_kwargs: + raise ValueError('User cannot set dataset_url_or_urls argument.') + + if 'transform_spec' in petastorm_reader_kwargs or \ + 'infer_schema_from_first_row' in petastorm_reader_kwargs: + raise ValueError('User cannot set transform_spec and infer_schema_from_first_row ' + 'arguments, use `preprocess_fn` argument instead.') + hvd_rank, hvd_size = _get_horovod_rank_and_size() cur_shard = petastorm_reader_kwargs.get('cur_shard') shard_count = petastorm_reader_kwargs.get('shard_count') diff --git a/petastorm/tf_utils.py b/petastorm/tf_utils.py index ec3a8eed5..3b5dfa4c0 100644 --- a/petastorm/tf_utils.py +++ b/petastorm/tf_utils.py @@ -38,6 +38,7 @@ np.string_: tf.string, np.unicode_: tf.string, np.str_: tf.string, + np.bytes_: tf.string, np.bool_: tf.bool, Decimal: tf.string, np.datetime64: tf.int64, diff --git a/petastorm/unischema.py b/petastorm/unischema.py index 472abf7c4..88008fd71 100644 --- a/petastorm/unischema.py +++ b/petastorm/unischema.py @@ -172,11 +172,11 @@ def __init__(self, name, fields): """Creates an instance of a Unischema object. :param name: name of the schema - :param fields: a list of ``UnischemaField`` instances describing the fields. The order of the fields is - not important - they are stored sorted by name internally. + :param fields: a list of ``UnischemaField`` instances describing the fields. The order of the fields + will be the order from the fields list. """ self._name = name - self._fields = OrderedDict([(f.name, f) for f in sorted(fields, key=lambda t: t.name)]) + self._fields = OrderedDict([(f.name, f) for f in fields]) # Generates attributes named by the field names as an access syntax sugar. for f in fields: if not hasattr(self, f.name): From 6ded62747121c2ca468641f529f2bafd519be183 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 23 Mar 2020 11:46:33 +0800 Subject: [PATCH 2/6] update --- petastorm/arrow_reader_worker.py | 58 ++++++++++++++++++---- petastorm/reader.py | 55 ++++---------------- petastorm/spark/spark_dataset_converter.py | 14 ++++++ 3 files changed, 71 insertions(+), 56 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index de1654b0c..bfe9ec3cc 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -24,6 +24,7 @@ from petastorm.cache import NullCache from petastorm.compat import compat_piece_read, compat_table_columns_gen, compat_column_data +from petastorm.unischema import Unischema, UnischemaField from petastorm.workers_pool import EmptyResultError from petastorm.workers_pool.worker_base import WorkerBase @@ -111,6 +112,18 @@ def __init__(self, worker_id, publish_func, args): def new_results_queue_reader(): return ArrowReaderWorkerResultsQueueReader() + def _init_dataset(self): + if not self._dataset: + self._dataset = pq.ParquetDataset( + self._dataset_path_or_paths, + filesystem=self._filesystem, + validate_schema=False) + + if self._dataset.partitions is None: + # When read from parquet file list, the `dataset.partitions` will be None. + # But other petastorm code require at least an empty `ParquetPartitions` object. + self._dataset.partitions = pq.ParquetPartitions() + # pylint: disable=arguments-differ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): """Main worker function. 
Loads and returns all rows matching the predicate from a rowgroup @@ -124,17 +137,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): of partitions. :return: """ - - if not self._dataset: - self._dataset = pq.ParquetDataset( - self._dataset_path_or_paths, - filesystem=self._filesystem, - validate_schema=False) - - if self._dataset.partitions is None: - # When read from parquet file list, the `dataset.partitions` will be None. - # But other petastorm code require at least an empty `ParquetPartitions` object. - self._dataset.partitions = pq.ParquetPartitions() + self._init_dataset() piece = self._split_pieces[piece_index] @@ -168,6 +171,39 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): if all_cols: self.publish_func(all_cols) + def infer_schema_from_first_row(self): + self._init_dataset() + if self._dataset.partitions.partition_names: + raise ValueError('infer_schema_from_first_row does not support parquet partition column.') + + piece0 = self._split_pieces[0] + pq_file0 = ParquetFile(self._dataset.fs.open(piece0.path)) + + column_names = [field_name for field_name in self._schema.fields] + piece0_pdf = compat_piece_read(piece0, lambda _: pq_file0, columns=column_names).to_pandas() + + row0_pdf = piece0_pdf.head(n=1) + if self._transform_spec: + row0_pdf = self._transform_spec.func(row0_pdf) + + unischema_fields = [] + for field_name in column_names: + field_val = row0_pdf[field_name][0] + + if np.ndim(field_val) == 0: + # scalar value + field_shape = () + field_numpy_type = np.dtype(type(field_val)).type + else: + field_shape = field_val.shape + field_numpy_type = field_val.dtype.type + + # TODO: add type checking,raise error for illegal type (such as np.object_) + unischema_field = UnischemaField(field_name, field_numpy_type, field_shape, None, False) + unischema_fields.append(unischema_field) + + return Unischema('inferred_schema', unischema_fields) + @staticmethod def _check_shape_and_ravel(x, field): if not isinstance(x, np.ndarray): diff --git a/petastorm/reader.py b/petastorm/reader.py index 0febcc6d0..8f14c46f6 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -16,10 +16,8 @@ import logging import warnings -import numpy as np import six from pyarrow import parquet as pq -from pyarrow.parquet import ParquetFile from petastorm.arrow_reader_worker import ArrowReaderWorker from petastorm.cache import NullCache @@ -37,7 +35,6 @@ from petastorm.reader_impl.pyarrow_serializer import PyArrowSerializer from petastorm.selectors import RowGroupSelectorBase from petastorm.transform import transform_schema -from petastorm.unischema import Unischema, UnischemaField from petastorm.workers_pool.dummy_pool import DummyPool from petastorm.workers_pool.process_pool import ProcessPool from petastorm.workers_pool.thread_pool import ThreadPool @@ -250,14 +247,16 @@ def make_batch_reader(dataset_url_or_urls, :param transform_spec: An instance of :class:`~petastorm.transform.TransformSpec` object defining how a record is transformed after it is loaded and decoded. The transformation occurs on a worker thread/process (depends on the ``reader_pool_type`` value). - :param infer_schema_from_first_row: Whether to infer schema from the row data. Only support parquet reader. + :param infer_schema_from_first_row: Whether to infer schema from the first row data. Only support parquet reader. 
If on, before creating the reader, it will first read one row group to infer the full schema information, and the transform spec (if exists) do not need to specify edit_fields/removed_fields. - Require: for all rows (before applying predicates), all values in each field are non-nullable and have + Require: + * for all rows (before applying predicates), all values in each field are non-nullable and have the same shape. + * Do not support parquet partition column. Turning on this param will address the following two issues: - 1) Auto inferring parquet schema from metadata cannot get shape information. - 2) If there's a preprocessing function, we have to specify edit/removed fields. + * Auto inferring parquet schema from metadata cannot get shape information. + * If there's a preprocessing function, we have to specify edit/removed fields. :return: A :class:`Reader` object """ dataset_url_or_urls = normalize_dataset_url_or_urls(dataset_url_or_urls) @@ -354,14 +353,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, to the main data store is either slow or expensive and the local machine has large enough storage to store entire dataset (or a partition of a dataset if shards are used). By default, use the :class:`.NullCache` implementation. - :param infer_schema_from_first_row: Whether to infer schema from the row data. Only support parquet reader. - If on, before creating the reader, it will first read one row group to infer the full schema information, - and the transform spec (if exists) do not need to specify edit_fields/removed_fields. - Require: for all rows (before applying predicates), all values in each field are non-nullable and have - the same shape. - Turning on this param will address the following two issues: - 1) Auto inferring parquet schema from metadata cannot get shape information. - 2) If there's a preprocessing function, we have to specify edit/removed fields. + :param infer_schema_from_first_row: Whether to infer schema from the first row data. :param worker_class: This is the class that will be instantiated on a different thread/process. It's responsibility is to load and filter the data. 
@@ -426,8 +418,9 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, if worker_class is not ArrowReaderWorker: raise ValueError('infer_schema_from_first_row only support ArrowReaderWorker.') worker0 = ArrowReaderWorker(0, None, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram, - row_groups, NullCache(), None, None)) - self.schema = self._infer_schema_from_first_row(worker0, transform_spec) + row_groups, NullCache(), transform_spec, None)) + self.schema = worker0.infer_schema_from_first_row() + logger.info('Inferred schema from first row: %s', str(self.schema)) else: if transform_spec: self.schema = transform_schema(storage_schema, transform_spec) @@ -484,34 +477,6 @@ def reset(self): def batched_output(self): return self._results_queue_reader.batched_output - @staticmethod - def _infer_schema_from_first_row(worker0, transform_spec): - piece0 = worker0._split_pieces[0] - pq_file0 = ParquetFile(worker0._dataset.fs.open(piece0.path)) - piece0_pdf = worker0._load_rows(pq_file0, piece0, (0, 1)) - - row0_pdf = piece0_pdf.head(n=1) - if transform_spec: - row0_pdf = transform_spec.func(row0_pdf) - - unischema_fields = [] - for field_name in list(row0_pdf.columns): - field_val = row0_pdf[field_name][0] - - if np.ndim(field_val) == 0: - # scalar value - field_shape = () - field_numpy_type = np.dtype(type(field_val)).type - else: - field_shape = field_val.shape - field_numpy_type = field_val.dtype.type - - # TODO: add type checking,raise error for illegal type (such as np.object_) - unischema_field = UnischemaField(field_name, field_numpy_type, field_shape, None, False) - unischema_fields.append(unischema_field) - - return Unischema('inferred_schema', unischema_fields) - def _filter_row_groups(self, dataset, row_groups, predicate, rowgroup_selector, cur_shard, shard_count): """Calculates which rowgroups will be read during. diff --git a/petastorm/spark/spark_dataset_converter.py b/petastorm/spark/spark_dataset_converter.py index bdc29305a..5fa271451 100644 --- a/petastorm/spark/spark_dataset_converter.py +++ b/petastorm/spark/spark_dataset_converter.py @@ -28,6 +28,7 @@ from petastorm import make_batch_reader from petastorm.fs_utils import FilesystemResolver +from petastorm.transform import TransformSpec DEFAULT_ROW_GROUP_SIZE_BYTES = 32 * 1024 * 1024 @@ -189,6 +190,12 @@ def make_tf_dataset( 1) Open a petastorm reader on the materialized dataset dir. 2) Create a tensorflow dataset based on the reader created in (1) + The generated dataset each element will be a batch of namedtuples. + If without specifying `preprocess_fn`, each namedtuple in result dataset will match the + schema of the original spark dataframe columns, otherwise will match the columns of the + output pandas dataframe of `preprocess_fn`. The fields order will keep the same with + original spark dataframe columns or the output pandas dataframe of `preprocess_fn`. + :param batch_size: The number of items to return per batch. Default None. If None, current implementation will set batch size to be 32, in future, None value will denotes auto tuned best value for batch size. @@ -204,6 +211,9 @@ def make_tf_dataset( Default value None. :param preprocess_fn: Preprocessing function. Input is pandas dataframe of a rowgroup data and output should be the transformed pandas dataframe. + the column order of the input pandas dataframe is undefined, but the output + pandas dataframe column order will determine the result tensorflow dataset's + element fields order. 
:param petastorm_reader_kwargs: arguments for `petastorm.make_batch_reader()`, exclude these arguments: "dataset_url_or_urls", "num_epochs", "workers_count", "transform_spec", "infer_schema_from_first_row" @@ -228,6 +238,10 @@ def make_tf_dataset( raise ValueError('User cannot set transform_spec and infer_schema_from_first_row ' 'arguments, use `preprocess_fn` argument instead.') + petastorm_reader_kwargs['infer_schema_from_first_row'] = True + if preprocess_fn: + petastorm_reader_kwargs['transform_spec'] = TransformSpec(preprocess_fn) + hvd_rank, hvd_size = _get_horovod_rank_and_size() cur_shard = petastorm_reader_kwargs.get('cur_shard') shard_count = petastorm_reader_kwargs.get('shard_count') From 105f2ad207415bf45adc45c4e89ca77025847a25 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 23 Mar 2020 14:19:55 +0800 Subject: [PATCH 3/6] update --- petastorm/arrow_reader_worker.py | 20 +++++++++++--------- petastorm/unischema.py | 6 +++--- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index bfe9ec3cc..6fabb174a 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -118,11 +118,10 @@ def _init_dataset(self): self._dataset_path_or_paths, filesystem=self._filesystem, validate_schema=False) - - if self._dataset.partitions is None: - # When read from parquet file list, the `dataset.partitions` will be None. - # But other petastorm code require at least an empty `ParquetPartitions` object. - self._dataset.partitions = pq.ParquetPartitions() + if self._dataset.partitions is None: + # When read from parquet file list, the `dataset.partitions` will be None. + # But other petastorm code require at least an empty `ParquetPartitions` object. + self._dataset.partitions = pq.ParquetPartitions() # pylint: disable=arguments-differ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): @@ -173,14 +172,13 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): def infer_schema_from_first_row(self): self._init_dataset() - if self._dataset.partitions.partition_names: - raise ValueError('infer_schema_from_first_row does not support parquet partition column.') piece0 = self._split_pieces[0] pq_file0 = ParquetFile(self._dataset.fs.open(piece0.path)) column_names = [field_name for field_name in self._schema.fields] - piece0_pdf = compat_piece_read(piece0, lambda _: pq_file0, columns=column_names).to_pandas() + + piece0_pdf = self._read_piece(piece0, pq_file0, set(column_names)).to_pandas() row0_pdf = piece0_pdf.head(n=1) if self._transform_spec: @@ -318,7 +316,7 @@ def _load_rows_with_predicate(self, pq_file, piece, worker_predicate, shuffle_ro return pa.Table.from_pandas(result, preserve_index=False) - def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_drop_partition): + def _read_piece(self, piece, pq_file, column_names): partition_names = self._dataset.partitions.partition_names # pyarrow would fail if we request a column names that the dataset is partitioned by @@ -332,6 +330,10 @@ def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_ unasked_for_columns = loaded_column_names - column_names if unasked_for_columns: table = table.drop(unasked_for_columns) + return table + + def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_drop_partition): + table = self._read_piece(piece, pq_file, column_names) num_rows = len(table) num_partitions = 
shuffle_row_drop_partition[1] diff --git a/petastorm/unischema.py b/petastorm/unischema.py index 88008fd71..472abf7c4 100644 --- a/petastorm/unischema.py +++ b/petastorm/unischema.py @@ -172,11 +172,11 @@ def __init__(self, name, fields): """Creates an instance of a Unischema object. :param name: name of the schema - :param fields: a list of ``UnischemaField`` instances describing the fields. The order of the fields - will be the order from the fields list. + :param fields: a list of ``UnischemaField`` instances describing the fields. The order of the fields is + not important - they are stored sorted by name internally. """ self._name = name - self._fields = OrderedDict([(f.name, f) for f in fields]) + self._fields = OrderedDict([(f.name, f) for f in sorted(fields, key=lambda t: t.name)]) # Generates attributes named by the field names as an access syntax sugar. for f in fields: if not hasattr(self, f.name): From 5d015026df403dc4ab65aa7c3d8943ec9fe2d0f8 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 23 Mar 2020 16:50:40 +0800 Subject: [PATCH 4/6] fix doc --- petastorm/reader.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/petastorm/reader.py b/petastorm/reader.py index 8f14c46f6..6efe7366e 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -250,13 +250,10 @@ def make_batch_reader(dataset_url_or_urls, :param infer_schema_from_first_row: Whether to infer schema from the first row data. Only support parquet reader. If on, before creating the reader, it will first read one row group to infer the full schema information, and the transform spec (if exists) do not need to specify edit_fields/removed_fields. - Require: - * for all rows (before applying predicates), all values in each field are non-nullable and have - the same shape. - * Do not support parquet partition column. - Turning on this param will address the following two issues: - * Auto inferring parquet schema from metadata cannot get shape information. - * If there's a preprocessing function, we have to specify edit/removed fields. + Require: for all rows (before applying predicates), all values in each field are non-nullable and have the + same shape. + Turning on this param will address the following two issues: (1) Auto inferring parquet schema from metadata + cannot get shape information. (2) If there's a preprocessing function, we have to specify edit/removed fields. 
:return: A :class:`Reader` object """ dataset_url_or_urls = normalize_dataset_url_or_urls(dataset_url_or_urls) From 8d41e709a2fd2f8ec109c4a325e9d183f55636d4 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 24 Mar 2020 19:54:42 +0800 Subject: [PATCH 5/6] update --- petastorm/arrow_reader_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 6fabb174a..2b2ec408a 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -183,6 +183,7 @@ def infer_schema_from_first_row(self): row0_pdf = piece0_pdf.head(n=1) if self._transform_spec: row0_pdf = self._transform_spec.func(row0_pdf) + column_names = list(row0_pdf.columns) unischema_fields = [] for field_name in column_names: From 529cb83522918d73daac3768f1c7f361f87d7554 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 24 Mar 2020 20:09:46 +0800 Subject: [PATCH 6/6] update --- petastorm/arrow_reader_worker.py | 2 +- petastorm/reader.py | 14 +++++++------- petastorm/spark/spark_dataset_converter.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py index 2b2ec408a..f8c7aa4e3 100644 --- a/petastorm/arrow_reader_worker.py +++ b/petastorm/arrow_reader_worker.py @@ -170,7 +170,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition): if all_cols: self.publish_func(all_cols) - def infer_schema_from_first_row(self): + def infer_schema_from_a_row(self): self._init_dataset() piece0 = self._split_pieces[0] diff --git a/petastorm/reader.py b/petastorm/reader.py index 6efe7366e..b63a735e0 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -198,7 +198,7 @@ def make_batch_reader(dataset_url_or_urls, cache_row_size_estimate=None, cache_extra_settings=None, hdfs_driver='libhdfs3', transform_spec=None, - infer_schema_from_first_row=False): + infer_schema_from_a_row=False): """ Creates an instance of Reader for reading batches out of a non-Petastorm Parquet store. @@ -247,7 +247,7 @@ def make_batch_reader(dataset_url_or_urls, :param transform_spec: An instance of :class:`~petastorm.transform.TransformSpec` object defining how a record is transformed after it is loaded and decoded. The transformation occurs on a worker thread/process (depends on the ``reader_pool_type`` value). - :param infer_schema_from_first_row: Whether to infer schema from the first row data. Only support parquet reader. + :param infer_schema_from_a_row: Whether to infer schema from a row data. Only support parquet reader. If on, before creating the reader, it will first read one row group to infer the full schema information, and the transform spec (if exists) do not need to specify edit_fields/removed_fields. 
Require: for all rows (before applying predicates), all values in each field are non-nullable and have the @@ -301,7 +301,7 @@ def make_batch_reader(dataset_url_or_urls, cache=cache, transform_spec=transform_spec, is_batched_reader=True, - infer_schema_from_first_row=infer_schema_from_first_row) + infer_schema_from_a_row=infer_schema_from_a_row) class Reader(object): @@ -314,7 +314,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, shuffle_row_groups=True, shuffle_row_drop_partitions=1, predicate=None, rowgroup_selector=None, reader_pool=None, num_epochs=1, cur_shard=None, shard_count=None, cache=None, worker_class=None, - transform_spec=None, is_batched_reader=False, infer_schema_from_first_row=False): + transform_spec=None, is_batched_reader=False, infer_schema_from_a_row=False): """Initializes a reader object. :param pyarrow_filesystem: An instance of ``pyarrow.FileSystem`` that will be used. If not specified, @@ -350,7 +350,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, to the main data store is either slow or expensive and the local machine has large enough storage to store entire dataset (or a partition of a dataset if shards are used). By default, use the :class:`.NullCache` implementation. - :param infer_schema_from_first_row: Whether to infer schema from the first row data. + :param infer_schema_from_a_row: Whether to infer schema from a row data. :param worker_class: This is the class that will be instantiated on a different thread/process. It's responsibility is to load and filter the data. @@ -411,12 +411,12 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None, # 2. Get a list of all row groups and infer schema if needed. row_groups = dataset_metadata.load_row_groups(self.dataset) - if infer_schema_from_first_row: + if infer_schema_from_a_row: if worker_class is not ArrowReaderWorker: raise ValueError('infer_schema_from_first_row only support ArrowReaderWorker.') worker0 = ArrowReaderWorker(0, None, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram, row_groups, NullCache(), transform_spec, None)) - self.schema = worker0.infer_schema_from_first_row() + self.schema = worker0.infer_schema_from_a_row() logger.info('Inferred schema from first row: %s', str(self.schema)) else: if transform_spec: diff --git a/petastorm/spark/spark_dataset_converter.py b/petastorm/spark/spark_dataset_converter.py index 5fa271451..e633b089e 100644 --- a/petastorm/spark/spark_dataset_converter.py +++ b/petastorm/spark/spark_dataset_converter.py @@ -234,11 +234,11 @@ def make_tf_dataset( raise ValueError('User cannot set dataset_url_or_urls argument.') if 'transform_spec' in petastorm_reader_kwargs or \ - 'infer_schema_from_first_row' in petastorm_reader_kwargs: - raise ValueError('User cannot set transform_spec and infer_schema_from_first_row ' + 'infer_schema_from_a_row' in petastorm_reader_kwargs: + raise ValueError('User cannot set transform_spec and infer_schema_from_a_row ' 'arguments, use `preprocess_fn` argument instead.') - petastorm_reader_kwargs['infer_schema_from_first_row'] = True + petastorm_reader_kwargs['infer_schema_from_a_row'] = True if preprocess_fn: petastorm_reader_kwargs['transform_spec'] = TransformSpec(preprocess_fn)
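A minimal end-to-end sketch of how the feature introduced by this patch series is meant to be used (illustrative only: the SparkSession `spark`, the cache directory URL, the `id`/`feature` columns, and TF2 eager execution for the final loop are assumptions, not part of the patches):

    from petastorm.spark import SparkDatasetConverter, make_spark_converter

    # Assumed: an active SparkSession named `spark` and a local cache dir.
    spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF,
                   'file:///tmp/petastorm_cache')

    df = spark.range(100).selectExpr('id', 'cast(id * 2 as double) as feature')
    converter = make_spark_converter(df)

    def preprocess_fn(pdf):
        # Receives the pandas DataFrame of one row group; the output column
        # order determines the field order of the tf.data.Dataset elements.
        pdf['feature'] = (pdf['feature'] / 100.0).astype('float32')
        return pdf[['feature', 'id']]

    # make_tf_dataset wraps preprocess_fn in a TransformSpec and enables
    # infer_schema_from_a_row, so no edit_fields/removed_fields are needed.
    with converter.make_tf_dataset(batch_size=32,
                                   preprocess_fn=preprocess_fn) as dataset:
        for batch in dataset.take(1):
            print(batch.feature.shape, batch.id.shape)

At the reader level, the same behaviour is reachable directly through make_batch_reader(dataset_url, infer_schema_from_a_row=True, transform_spec=TransformSpec(preprocess_fn)), which is what make_tf_dataset configures internally after these patches.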