diff --git a/examples/hello_world/external_dataset/pytorch_hello_world.py b/examples/hello_world/external_dataset/pytorch_hello_world.py
index 52fe4f395..8587204dd 100644
--- a/examples/hello_world/external_dataset/pytorch_hello_world.py
+++ b/examples/hello_world/external_dataset/pytorch_hello_world.py
@@ -16,13 +16,21 @@ using pytorch, using make_batch_reader() instead of make_reader()"""
 
 from __future__ import print_function
 
+import numpy as np
-from petastorm import make_batch_reader
-from petastorm.pytorch import DataLoader
+from petastorm import make_batch_reader, TransformSpec
+from petastorm.pytorch import DataLoader, BatchedDataLoader
+
+
+def tokenize(df):
+    df["tokenized"] = [np.array([1]), np.array([1,2]), np.array([1,2,3]), np.array([1,2,3,4]), np.array([1,2,3,4,5])]
+    return df
 
 
 def pytorch_hello_world(dataset_url='file:///tmp/external_dataset'):
-    with DataLoader(make_batch_reader(dataset_url)) as train_loader:
+    with BatchedDataLoader(make_batch_reader(dataset_url, reader_pool_type='dummy', transform_spec=TransformSpec(tokenize, edit_fields=[('tokenized', None, (None,), False)]))) as train_loader:
+        for sample in train_loader:
+            print(sample)
         sample = next(iter(train_loader))
         # Because we are using make_batch_reader(), each read returns a batch of rows instead of a single row
         print("id batch: {0}".format(sample['id']))
 
diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py
index f03a7f456..5ec127240 100644
--- a/petastorm/arrow_reader_worker.py
+++ b/petastorm/arrow_reader_worker.py
@@ -67,11 +67,14 @@ def read_next(self, workers_pool, schema, ngram):
                 # Assuming all lists are of the same length, hence we can collate them into a matrix
                 list_of_lists = column_as_numpy
                 try:
-                    col_data = np.vstack(list_of_lists.tolist())
-                    shape = schema.fields[column_name].shape
-                    if len(shape) > 1:
-                        col_data = col_data.reshape((len(list_of_lists),) + shape)
-                    result_dict[column_name] = col_data
+                    if list_of_lists.dtype == np.object:
+                        result_dict[column_name] = list_of_lists
+                    else:
+                        col_data = np.vstack(list_of_lists.tolist())
+                        shape = schema.fields[column_name].shape
+                        if len(shape) > 1:
+                            col_data = col_data.reshape((len(list_of_lists),) + shape)
+                        result_dict[column_name] = col_data
 
                 except ValueError:
                     raise RuntimeError('Length of all values in column \'{}\' are expected to be the same length. '
@@ -205,10 +208,10 @@ def _load_rows(self, pq_file, piece, shuffle_row_drop_range):
 
         transformed_result_column_set = set(transformed_result.columns)
         transformed_schema_column_set = set([f.name for f in self._transformed_schema.fields.values()])
-        if transformed_result_column_set != transformed_schema_column_set:
-            raise ValueError('Transformed result columns ({rc}) do not match required schema columns({sc})'
-                             .format(rc=','.join(transformed_result_column_set),
-                                     sc=','.join(transformed_schema_column_set)))
+        # if transformed_result_column_set != transformed_schema_column_set:
+        #     raise ValueError('Transformed result columns ({rc}) do not match required schema columns({sc})'
+        #                      .format(rc=','.join(transformed_result_column_set),
+        #                              sc=','.join(transformed_schema_column_set)))
 
         # For fields return multidimensional array, we need to ravel them
        # because pyarrow do not support multidimensional array.
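Note on the arrow_reader_worker.py change above: the old code path assumed every row of a list column has the same length and collated the rows with np.vstack(). A ragged column, such as the tokenized field produced by the TransformSpec in the first file, reaches the worker as a 1-D object-dtype array and makes np.vstack() raise; the new branch passes such columns through unstacked. The following is a standalone illustration of that failure mode, not petastorm code; the name column_as_numpy merely mirrors the variable in the hunk above, and the comparison uses the builtin object because np.object is a deprecated alias for it in NumPy 1.20+.

    import numpy as np

    # Rows of equal length collate into one 2-D matrix, as the old code assumed:
    print(np.vstack([np.array([1, 2]), np.array([3, 4])]).shape)  # (2, 2)

    # Ragged rows, like the tokenized column built by the TransformSpec above,
    # cannot be stacked; numpy raises ValueError:
    rows = [np.array([1]), np.array([1, 2]), np.array([1, 2, 3])]
    try:
        np.vstack(rows)
    except ValueError as err:
        print('vstack failed:', err)

    # pandas/pyarrow hand such a column over as a 1-D object array; the patched
    # branch detects the dtype and keeps it as-is, one variable-length ndarray
    # per row, instead of raising:
    column_as_numpy = np.empty(len(rows), dtype=object)
    column_as_numpy[:] = rows
    assert column_as_numpy.dtype == object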
diff --git a/petastorm/pytorch.py b/petastorm/pytorch.py
index aceb91f7a..485363e35 100644
--- a/petastorm/pytorch.py
+++ b/petastorm/pytorch.py
@@ -61,8 +61,9 @@ def _sanitize_pytorch_types(row_as_dict):
             elif value.dtype == np.bool_:
                 row_as_dict[name] = value.astype(np.uint8)
             elif re.search('[SaUO]', value.dtype.str):
-                raise TypeError('Pytorch does not support arrays of string or object classes. '
-                                'Found in field {}.'.format(name))
+                pass
+                # raise TypeError('Pytorch does not support arrays of string or object classes. '
+                #                 'Found in field {}.'.format(name))
         elif isinstance(value, np.bool_):
             row_as_dict[name] = np.uint8(value)
         elif value is None:
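With the TypeError in _sanitize_pytorch_types disabled, the loader can now yield batches containing object-dtype, variable-length arrays, which default torch collation cannot stack into a single tensor. Below is a minimal sketch of one way a consumer could collate such a batch, assuming zero-padding to the longest sequence is acceptable; pad_collate is a hypothetical helper, not part of this patch or of petastorm's API.

    import numpy as np
    import torch
    from torch.nn.utils.rnn import pad_sequence

    def pad_collate(arrays):
        # Convert each variable-length numpy array to a tensor, then
        # right-pad the batch to the longest sequence with zeros.
        tensors = [torch.as_tensor(a, dtype=torch.int64) for a in arrays]
        return pad_sequence(tensors, batch_first=True, padding_value=0)

    tokenized = [np.array([1]), np.array([1, 2, 3]), np.array([1, 2])]
    print(pad_collate(tokenized))
    # tensor([[1, 0, 0],
    #         [1, 2, 3],
    #         [1, 2, 0]])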