Skip to content

Commit

Permalink
Add "display_image" feature to robotics dataset importer builder.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638110575
  • Loading branch information
The TensorFlow Datasets Authors committed May 29, 2024
1 parent 6bbba45 commit ff98c91
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tensorflow_datasets/robotics/dataset_importer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class DatasetImporterBuilder(
'ssot_session_key',
]

images_from_observation_dict = {}


@abc.abstractmethod
def get_description(self):
Expand Down Expand Up @@ -83,9 +85,15 @@ def _info(self) -> tfds.core.DatasetInfo:

tmp = dict(features)

# add all image features from observations to a new featuresdict
self.images_from_observation_dict = self.get_images_from_observation_dict()
if self.images_from_observation_dict:
tmp['display_image'] = self.images_from_observation_dict

for key in self.KEYS_TO_STRIP:
if key in tmp:
del tmp[key]

features = tfds.features.FeaturesDict(tmp)

return tfds.core.DatasetInfo(
Expand Down Expand Up @@ -120,15 +128,28 @@ def _generate_examples(
def converter_fn(example):
# Decode the RLDS Episode and transform it to numpy.
example_out = dict(example)

example_out['steps'] = tf.data.Dataset.from_tensor_slices(
example_out['steps']
).map(decode_fn)

steps = list(iter(example_out['steps'].take(-1)))
example_out['steps'] = steps

example_out = dataset_utils.as_numpy(example_out)
first_step = example_out['steps'][0]
image_feature_dict = {}

for feature_name in self.images_from_observation_dict:
image_feature_dict[feature_name] = first_step['observation'][
feature_name
]

if image_feature_dict:
example_out['display_image'] = image_feature_dict

example_id = example_out['tfds_id'].decode('utf-8')

del example_out['tfds_id']
for key in self.KEYS_TO_STRIP:
if key in example_out:
Expand All @@ -148,3 +169,17 @@ def get_ds_builder(self):
ds_location = self.get_dataset_location()
ds_builder = tfds.builder_from_directory(ds_location)
return ds_builder

def get_images_from_observation_dict(self):
features = self.get_ds_builder().info.features
tmp = dict(features)
images_from_observation = {}
if 'steps' in tmp and 'observation' in tmp['steps']:
observation = tmp['steps']['observation']
for feature_name, feature_data in observation.items():
if isinstance(feature_data, tfds.features.Image):
images_from_observation[feature_name] = feature_data
images_from_observation_dict = tfds.features.FeaturesDict(
images_from_observation
)
return images_from_observation_dict

0 comments on commit ff98c91

Please sign in to comment.