diff --git a/fog_x/dataset.py b/fog_x/dataset.py
index e9a9c27..ed0c98c 100644
--- a/fog_x/dataset.py
+++ b/fog_x/dataset.py
@@ -381,12 +381,65 @@ def load_rtx_episodes(
         import tensorflow_datasets as tfds
         from fog_x.rlds.utils import dataset2path
 
         b = tfds.builder_from_directory(builder_dir=dataset2path(name))
+        self._build_rtx_episodes_from_tfds_builder(
+            b,
+            split=split,
+            additional_metadata=additional_metadata,
+        )
 
-        ds = b.as_dataset(split=split)
+    def load_rtx_episodes_local(
+        self,
+        path: str,
+        split: str = "all",
+        additional_metadata: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Load robot data from a Tensorflow Dataset stored on the local filesystem.
+
+        Args:
+            path (str): Path to the RT-X episodes
+            split (optional str): the portion of data to load, see [Tensorflow Split API](https://www.tensorflow.org/datasets/splits)
+            additional_metadata (optional Dict[str, Any]): additional metadata to be associated with the loaded episodes
+
+        Example:
+        ```
+        >>> dataset.load_rtx_episodes_local(path="~/Downloads/berkeley_autolab_ur5")
+        >>> dataset.load_rtx_episodes_local(path="~/Downloads/berkeley_autolab_ur5", split="train[:10]", additional_metadata={"data_collector": "Alice", "custom_tag": "sample"})
+        ```
+        """
+        # this is only required if rtx format is used
+        import tensorflow_datasets as tfds
 
-        data_type = b.info.features["steps"]
+        b = tfds.builder(path)
+        self._build_rtx_episodes_from_tfds_builder(
+            b,
+            split=split,
+            additional_metadata=additional_metadata,
+        )
+
+    def _build_rtx_episodes_from_tfds_builder(
+        self,
+        builder,
+        split: str = "all",
+        additional_metadata: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Construct episodes in this dataset from a tfds DatasetBuilder.
+
+        Args:
+            builder: a tfds DatasetBuilder to read episodes from
+            split (optional str): the portion of data to load, see [Tensorflow Split API](https://www.tensorflow.org/datasets/splits)
+            additional_metadata (optional Dict[str, Any]): metadata to be associated with each loaded episode
+        """
+        # avoid a shared mutable default; normalize to a fresh dict per call
+        if additional_metadata is None:
+            additional_metadata = {}
+
+        ds = builder.as_dataset(split=split)
+
+        data_type = builder.info.features["steps"]
 
         for tf_episode in ds:
             logger.info(tf_episode)
 