diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py
index de1654b0c..f8c7aa4e3 100644
--- a/petastorm/arrow_reader_worker.py
+++ b/petastorm/arrow_reader_worker.py
@@ -24,6 +24,7 @@
 from petastorm.cache import NullCache
 from petastorm.compat import compat_piece_read, compat_table_columns_gen, compat_column_data
+from petastorm.unischema import Unischema, UnischemaField
 from petastorm.workers_pool import EmptyResultError
 from petastorm.workers_pool.worker_base import WorkerBase
 
 
@@ -111,6 +112,17 @@ def __init__(self, worker_id, publish_func, args):
     def new_results_queue_reader():
         return ArrowReaderWorkerResultsQueueReader()
 
+    def _init_dataset(self):
+        if not self._dataset:
+            self._dataset = pq.ParquetDataset(
+                self._dataset_path_or_paths,
+                filesystem=self._filesystem,
+                validate_schema=False)
+            if self._dataset.partitions is None:
+                # When reading from a parquet file list, `dataset.partitions` will be None,
+                # but other petastorm code requires at least an empty `ParquetPartitions` object.
+                self._dataset.partitions = pq.ParquetPartitions()
+
     # pylint: disable=arguments-differ
     def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
         """Main worker function. Loads and returns all rows matching the predicate from a rowgroup
@@ -124,17 +136,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
             of partitions.
         :return:
         """
-
-        if not self._dataset:
-            self._dataset = pq.ParquetDataset(
-                self._dataset_path_or_paths,
-                filesystem=self._filesystem,
-                validate_schema=False)
-
-        if self._dataset.partitions is None:
-            # When read from parquet file list, the `dataset.partitions` will be None.
-            # But other petastorm code require at least an empty `ParquetPartitions` object.
-            self._dataset.partitions = pq.ParquetPartitions()
+        self._init_dataset()
 
         piece = self._split_pieces[piece_index]
 
@@ -168,6 +170,39 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
         if all_cols:
             self.publish_func(all_cols)
 
+    def infer_schema_from_a_row(self):
+        self._init_dataset()
+
+        piece0 = self._split_pieces[0]
+        pq_file0 = ParquetFile(self._dataset.fs.open(piece0.path))
+
+        column_names = [field_name for field_name in self._schema.fields]
+
+        piece0_pdf = self._read_piece(piece0, pq_file0, set(column_names)).to_pandas()
+
+        row0_pdf = piece0_pdf.head(n=1)
+        if self._transform_spec:
+            row0_pdf = self._transform_spec.func(row0_pdf)
+            column_names = list(row0_pdf.columns)
+
+        unischema_fields = []
+        for field_name in column_names:
+            field_val = row0_pdf[field_name][0]
+
+            if np.ndim(field_val) == 0:
+                # scalar value
+                field_shape = ()
+                field_numpy_type = np.dtype(type(field_val)).type
+            else:
+                field_shape = field_val.shape
+                field_numpy_type = field_val.dtype.type
+
+            # TODO: add type checking, raise an error for illegal types (such as np.object_)
+            unischema_field = UnischemaField(field_name, field_numpy_type, field_shape, None, False)
+            unischema_fields.append(unischema_field)
+
+        return Unischema('inferred_schema', unischema_fields)
+
     @staticmethod
     def _check_shape_and_ravel(x, field):
         if not isinstance(x, np.ndarray):
@@ -282,7 +317,7 @@ def _load_rows_with_predicate(self, pq_file, piece, worker_predicate, shuffle_ro
         return pa.Table.from_pandas(result, preserve_index=False)
 
-    def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_drop_partition):
+    def _read_piece(self, piece, pq_file, column_names):
         partition_names = self._dataset.partitions.partition_names
 
         # pyarrow would fail if we request a column names that the dataset is partitioned by
@@ -296,6 +331,10 @@ def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_
         unasked_for_columns = loaded_column_names - column_names
         if unasked_for_columns:
             table = table.drop(unasked_for_columns)
+        return table
+
+    def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_drop_partition):
+        table = self._read_piece(piece, pq_file, column_names)
 
         num_rows = len(table)
         num_partitions = shuffle_row_drop_partition[1]
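
Not part of the patch: a minimal standalone sketch of the inference rule that `infer_schema_from_a_row` applies above. The two-column dataframe is hypothetical and stands in for the first row of a rowgroup; only numpy and pandas are used, so the petastorm plumbing is left out.

import numpy as np
import pandas as pd

# Hypothetical first row: one scalar column and one array column.
row0_pdf = pd.DataFrame({'id': [42], 'image': [np.zeros((8, 8), dtype=np.float32)]})

for field_name in row0_pdf.columns:
    field_val = row0_pdf[field_name][0]
    if np.ndim(field_val) == 0:
        # Scalar: empty shape, numpy type derived from the Python value's type.
        field_shape, field_numpy_type = (), np.dtype(type(field_val)).type
    else:
        # Array: shape and element type taken directly from the ndarray.
        field_shape, field_numpy_type = field_val.shape, field_val.dtype.type
    print(field_name, field_numpy_type, field_shape)
# id    -> numpy.int64,   shape ()
# image -> numpy.float32, shape (8, 8)
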
diff --git a/petastorm/reader.py b/petastorm/reader.py
index 35b854e9e..b63a735e0 100644
--- a/petastorm/reader.py
+++ b/petastorm/reader.py
@@ -170,6 +170,7 @@ def make_reader(dataset_url,
         'shard_count': shard_count,
         'cache': cache,
         'transform_spec': transform_spec,
+        'infer_schema_from_a_row': False,
     }
 
     try:
@@ -196,7 +197,8 @@ def make_batch_reader(dataset_url_or_urls,
                       cache_type='null', cache_location=None, cache_size_limit=None,
                       cache_row_size_estimate=None, cache_extra_settings=None,
                       hdfs_driver='libhdfs3',
-                      transform_spec=None):
+                      transform_spec=None,
+                      infer_schema_from_a_row=False):
     """
     Creates an instance of Reader for reading batches out of a non-Petastorm Parquet store.
 
@@ -245,6 +247,13 @@ def make_batch_reader(dataset_url_or_urls,
     :param transform_spec: An instance of :class:`~petastorm.transform.TransformSpec` object defining how a record
         is transformed after it is loaded and decoded. The transformation occurs on a worker thread/process
        (depends on the ``reader_pool_type`` value).
+    :param infer_schema_from_a_row: Whether to infer the schema from a single row of data. Only supported by the
+        parquet reader. If enabled, one row group is read before the reader is created in order to infer the full
+        schema, and the transform spec (if present) does not need to specify ``edit_fields``/``removed_fields``.
+        Requires that, for all rows (before predicates are applied), every field value is non-nullable and has
+        the same shape.
+        Enabling this parameter addresses two issues: (1) a parquet schema inferred from metadata carries no shape
+        information; (2) with a preprocessing function, edit/removed fields would otherwise have to be listed manually.
     :return: A :class:`Reader` object
     """
     dataset_url_or_urls = normalize_dataset_url_or_urls(dataset_url_or_urls)
@@ -291,7 +300,8 @@ def make_batch_reader(dataset_url_or_urls,
                   shard_count=shard_count,
                   cache=cache,
                   transform_spec=transform_spec,
-                  is_batched_reader=True)
+                  is_batched_reader=True,
+                  infer_schema_from_a_row=infer_schema_from_a_row)
 
 
 class Reader(object):
@@ -304,7 +314,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
                  shuffle_row_groups=True, shuffle_row_drop_partitions=1,
                  predicate=None, rowgroup_selector=None, reader_pool=None, num_epochs=1,
                  cur_shard=None, shard_count=None, cache=None, worker_class=None,
-                 transform_spec=None, is_batched_reader=False):
+                 transform_spec=None, is_batched_reader=False, infer_schema_from_a_row=False):
         """Initializes a reader object.
 
         :param pyarrow_filesystem: An instance of ``pyarrow.FileSystem`` that will be used. If not specified,
@@ -340,13 +350,14 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
             to the main data store is either slow or expensive and the local machine has large enough storage
             to store entire dataset (or a partition of a dataset if shards are used). By default, use the
             :class:`.NullCache` implementation.
+        :param infer_schema_from_a_row: Whether to infer the schema from a single row of data instead of the stored metadata.
 
         :param worker_class: This is the class that will be instantiated on a different thread/process.
             It's responsibility is to load and filter the data.
         """
 
         # 1. Open the parquet storage (dataset)
-        # 2. Get a list of all groups
+        # 2. Get a list of all groups and infer schema if needed.
         # 3. Filter rowgroups
         #    a. predicates
         #    b. row-group selector (our indexing mechanism)
@@ -396,14 +407,23 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
         fields = schema_fields if isinstance(schema_fields, collections.Iterable) else None
         storage_schema = stored_schema.create_schema_view(fields) if fields else stored_schema
 
-        if transform_spec:
-            self.schema = transform_schema(storage_schema, transform_spec)
-        else:
-            self.schema = storage_schema
 
-        # 2. Get a list of all row groups
+        # 2. Get a list of all row groups and infer schema if needed.
         row_groups = dataset_metadata.load_row_groups(self.dataset)
 
+        if infer_schema_from_a_row:
+            if worker_class is not ArrowReaderWorker:
+                raise ValueError('infer_schema_from_a_row is only supported by ArrowReaderWorker.')
+            worker0 = ArrowReaderWorker(0, None, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram,
+                                                  row_groups, NullCache(), transform_spec, None))
+            self.schema = worker0.infer_schema_from_a_row()
+            logger.info('Inferred schema from first row: %s', str(self.schema))
+        else:
+            if transform_spec:
+                self.schema = transform_schema(storage_schema, transform_spec)
+            else:
+                self.schema = storage_schema
+
         # 3. Filter rowgroups
         filtered_row_group_indexes, worker_predicate = self._filter_row_groups(self.dataset, row_groups, predicate,
                                                                                rowgroup_selector, cur_shard,
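
Not part of the patch: a hedged usage sketch of the new `make_batch_reader` flag. The dataset URL is a placeholder; any Parquet store readable by petastorm would do.

from petastorm import make_batch_reader

# Placeholder URL. With infer_schema_from_a_row=True the reader first loads one row group,
# so reader.schema carries field shapes rather than only the dtypes found in the Parquet metadata.
with make_batch_reader('file:///tmp/parquet_dataset', infer_schema_from_a_row=True) as reader:
    print(reader.schema)
    for batch in reader:
        print(batch)
        break
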
diff --git a/petastorm/spark/spark_dataset_converter.py b/petastorm/spark/spark_dataset_converter.py
index 92f641860..e633b089e 100644
--- a/petastorm/spark/spark_dataset_converter.py
+++ b/petastorm/spark/spark_dataset_converter.py
@@ -28,6 +28,7 @@
 from petastorm import make_batch_reader
 from petastorm.fs_utils import FilesystemResolver
+from petastorm.transform import TransformSpec
 
 DEFAULT_ROW_GROUP_SIZE_BYTES = 32 * 1024 * 1024
 
@@ -179,6 +180,7 @@ def make_tf_dataset(
             prefetch=None,
             num_epochs=None,
             workers_count=None,
+            preprocess_fn=None,
             **petastorm_reader_kwargs
     ):
         """
@@ -188,6 +190,12 @@ def make_tf_dataset(
         1) Open a petastorm reader on the materialized dataset dir.
         2) Create a tensorflow dataset based on the reader created in (1)
 
+        Each element of the generated dataset is a batch of namedtuples.
+        If `preprocess_fn` is not specified, the namedtuple fields match the columns of the
+        original spark dataframe; otherwise they match the columns of the pandas dataframe
+        returned by `preprocess_fn`. The field order follows the column order of the
+        original spark dataframe or of the dataframe returned by `preprocess_fn`, respectively.
+
         :param batch_size: The number of items to return per batch. Default None.
             If None, current implementation will set batch size to be 32, in future,
             None value will denotes auto tuned best value for batch size.
@@ -201,8 +209,14 @@ def make_tf_dataset(
             None denotes auto tune best value (current implementation when auto tune,
             it will always use 4 workers, but it may be improved in future)
             Default value None.
+        :param preprocess_fn: Preprocessing function. It is called with the pandas dataframe of one rowgroup
+            and must return the transformed pandas dataframe.
+            The column order of the input pandas dataframe is undefined, but the column order of the
+            output pandas dataframe determines the field order of the elements of the resulting
+            tensorflow dataset.
         :param petastorm_reader_kwargs: arguments for `petastorm.make_batch_reader()`,
-            exclude these arguments: "dataset_url", "num_epochs", "workers_count".
+            exclude these arguments: "dataset_url_or_urls", "num_epochs", "workers_count",
+            "transform_spec", "infer_schema_from_a_row".
 
         :return: a context manager for a `tf.data.Dataset` object.
             when exit the returned context manager, the reader
@@ -216,6 +230,18 @@
             workers_count = 4
         petastorm_reader_kwargs['workers_count'] = workers_count
 
+        if 'dataset_url_or_urls' in petastorm_reader_kwargs:
+            raise ValueError('User cannot set the dataset_url_or_urls argument.')
+
+        if 'transform_spec' in petastorm_reader_kwargs or \
+                'infer_schema_from_a_row' in petastorm_reader_kwargs:
+            raise ValueError('User cannot set the transform_spec and infer_schema_from_a_row '
+                             'arguments; use the `preprocess_fn` argument instead.')
+
+        petastorm_reader_kwargs['infer_schema_from_a_row'] = True
+        if preprocess_fn:
+            petastorm_reader_kwargs['transform_spec'] = TransformSpec(preprocess_fn)
+
         hvd_rank, hvd_size = _get_horovod_rank_and_size()
         cur_shard = petastorm_reader_kwargs.get('cur_shard')
         shard_count = petastorm_reader_kwargs.get('shard_count')
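
Not part of the patch: a sketch of how `preprocess_fn` might be used. It assumes `converter` was already obtained from petastorm's `make_spark_converter()` on a dataframe with a flat float-array `features` column and a `label` column; those names, and the reshape itself, are illustrative only.

import numpy as np

def preprocess_fn(pdf):
    # Runs on a worker against the pandas dataframe of one rowgroup.
    # The output column order defines the field order of the namedtuples
    # yielded by the resulting tf.data.Dataset.
    pdf['features'] = pdf['features'].map(lambda v: np.asarray(v, dtype=np.float32).reshape(2, 2))
    return pdf[['features', 'label']]

with converter.make_tf_dataset(batch_size=32, preprocess_fn=preprocess_fn) as dataset:
    for batch in dataset.take(1):
        print(batch.features.shape, batch.label.shape)
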
diff --git a/petastorm/tf_utils.py b/petastorm/tf_utils.py
index ec3a8eed5..3b5dfa4c0 100644
--- a/petastorm/tf_utils.py
+++ b/petastorm/tf_utils.py
@@ -38,6 +38,7 @@
     np.string_: tf.string,
     np.unicode_: tf.string,
     np.str_: tf.string,
+    np.bytes_: tf.string,
     np.bool_: tf.bool,
     Decimal: tf.string,
     np.datetime64: tf.int64,
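
Not part of the patch: the `np.bytes_` entry matters for inferred schemas, because a field whose first-row value is a raw bytes object resolves to the `np.bytes_` numpy type, which the dtype map routes to `tf.string`:

import numpy as np

# The dtype-inference rule used by infer_schema_from_a_row() on a bytes scalar:
assert np.dtype(type(b'some bytes')).type is np.bytes_
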