Skip to content

Commit

Permalink
add a new feature "display_image"
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638014851
  • Loading branch information
The TensorFlow Datasets Authors committed May 28, 2024
1 parent 6bbba45 commit 25714a7
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tensorflow_datasets/robotics/dataset_importer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,22 @@ def _info(self) -> tfds.core.DatasetInfo:

tmp = dict(features)

# add all image features from observations to a new featuresdict
image_from_observation = {}
if 'steps' in tmp and 'observation' in tmp['steps']:
observation = tmp['steps']['observation']
for feature_name, feature_data in observation.items():
if isinstance(feature_data, tfds.features.Image):
image_from_observation[feature_name] = feature_data
image_from_observation_dict = tfds.features.FeaturesDict(
image_from_observation
)
tmp['display_image'] = image_from_observation_dict

for key in self.KEYS_TO_STRIP:
if key in tmp:
del tmp[key]

features = tfds.features.FeaturesDict(tmp)

return tfds.core.DatasetInfo(
Expand Down Expand Up @@ -115,20 +128,42 @@ def _generate_examples(
split = split_info.name
read_config = read_config_lib.ReadConfig(add_tfds_id=True)

features = self.get_ds_builder().info.features
tmp = dict(features)
image_name_from_observation_set = set()
if 'steps' in tmp and 'observation' in tmp['steps']:
observation = tmp['steps']['observation']
for feature_name, feature_data in observation.items():
if isinstance(feature_data, tfds.features.Image):
image_name_from_observation_set.add(feature_name)

decode_fn = builder.info.features['steps'].feature.decode_example

def converter_fn(example):
# Decode the RLDS Episode and transform it to numpy.
example_out = dict(example)

example_out['steps'] = tf.data.Dataset.from_tensor_slices(
example_out['steps']
).map(decode_fn)

steps = list(iter(example_out['steps'].take(-1)))
example_out['steps'] = steps

example_out = dataset_utils.as_numpy(example_out)
new_step = example_out['steps']
first_step = new_step[0]
image_feature_dict = {}

for feature_name, feature_data in first_step['observation'].items():
# all the features are arrays now we cannot distinguish image/others
if feature_name in image_name_from_observation_set:
image_feature_dict[feature_name] = feature_data
if image_feature_dict:
example_out['display_image'] = image_feature_dict

example_id = example_out['tfds_id'].decode('utf-8')

del example_out['tfds_id']
for key in self.KEYS_TO_STRIP:
if key in example_out:
Expand Down

0 comments on commit 25714a7

Please sign in to comment.