Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][ML-10118] Keep petastorm dataset/dataloader schema fields order the same with spark dataframe #511

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion petastorm/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class TransformSpec(object):
def __init__(self, func=None, edit_fields=None, removed_fields=None):
def __init__(self, func=None, edit_fields=None, removed_fields=None, selected_fields=None):
"""TransformSpec defines a user transformation that is applied to a loaded row on a worker thread/process.

The object defines the function (callable) that perform the transform as well as the
Expand All @@ -34,10 +34,12 @@ def __init__(self, func=None, edit_fields=None, removed_fields=None):
:param edit_fields: Optional. A list of 4-tuples with the following fields:
``(name, numpy_dtype, shape, is_nullable)``
:param removed_fields: Optional[list]. A list of field names that will be removed from the original schema.
:param selected_fields: Optional[list]. A list of field names specify the fields to be selected.
"""
self.func = func
self.edit_fields = edit_fields or []
self.removed_fields = removed_fields or []
self.selected_fields = selected_fields


def transform_schema(schema, transform_spec):
Expand All @@ -61,4 +63,8 @@ def transform_schema(schema, transform_spec):
shape=field_to_edit[2], codec=None, nullable=field_to_edit[3])
fields.append(edited_unischema_field)

if transform_spec.selected_fields is not None:
fields = [f for f in fields if f.name in transform_spec.selected_fields]
fields = sorted(fields, key=lambda f: transform_spec.selected_fields.index(f.name))

return Unischema(schema._name + '_transformed', fields)
10 changes: 6 additions & 4 deletions petastorm/unischema.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def get(parent_schema_name, field_names):
:return: A namedtuple with field names defined by `field_names`
"""
# Cache key is a combination of schema name and all field names
sorted_names = list(sorted(field_names))
key = ' '.join([parent_schema_name] + sorted_names)
field_names = list(field_names)
key = ' '.join([parent_schema_name] + field_names)
if key not in _NamedtupleCache._store:
_NamedtupleCache._store[key] = \
_new_gt_255_compatible_namedtuple('{}_view'.format(parent_schema_name), sorted_names)
_new_gt_255_compatible_namedtuple('{}_view'.format(parent_schema_name), field_names)
return _NamedtupleCache._store[key]


Expand Down Expand Up @@ -185,6 +185,8 @@ def __init__(self, name, fields):
warnings.warn(('Can not create dynamic property {} because it conflicts with an existing property of '
'Unischema').format(f.name))

self._ordered_fields_name = [f.name for f in fields]

def create_schema_view(self, fields):
"""Creates a new instance of the schema using a subset of fields.

Expand Down Expand Up @@ -229,7 +231,7 @@ def create_schema_view(self, fields):
return Unischema('{}_view'.format(self._name), view_fields)

def _get_namedtuple(self):
return _NamedtupleCache.get(self._name, self._fields.keys())
return _NamedtupleCache.get(self._name, self._ordered_fields_name)

def __str__(self):
"""Represent this as the following form:
Expand Down