From b79068a5a72455441b9c007c6c47c76d9266e9ba Mon Sep 17 00:00:00 2001 From: Kaiyuan Eric Chen Date: Tue, 27 Aug 2024 00:25:41 -0700 Subject: [PATCH] octo dataloader working --- fog_x/DLdataset.py | 4 ++-- fog_x/feature.py | 8 ++++++-- fog_x/trajectory.py | 8 +++++++- fog_x/utils.py | 5 +++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/fog_x/DLdataset.py b/fog_x/DLdataset.py index dec7605..d64f7da 100644 --- a/fog_x/DLdataset.py +++ b/fog_x/DLdataset.py @@ -209,7 +209,6 @@ def _convert_h5_cache_to_tensor(): output_tf_traj[key] = tf.convert_to_tensor(h5_cache[key]) return output_tf_traj output = {"steps" : _convert_h5_cache_to_tensor()} - print(output) yield output @@ -253,7 +252,8 @@ def _convert_h5_cache_to_tensor(): output_signature = {"steps" : tf.nest.map_structure( lambda spec: tf.TensorSpec(shape=spec.shape, dtype=spec.dtype), step_spec )} - + print(output_signature) + dataset = _wrap(tf.data.Dataset.from_generator, False)( generator, output_signature=output_signature diff --git a/fog_x/feature.py b/fog_x/feature.py index 4aa0495..08c3e1b 100644 --- a/fog_x/feature.py +++ b/fog_x/feature.py @@ -140,7 +140,7 @@ def from_str(self, feature_str: str): shape = eval(shape.split("=")[1][:-1]) # strip brackets return FeatureType(dtype=dtype, shape=shape) - def to_tf_feature_type(self): + def to_tf_feature_type(self, first_dim_none=False): """ Convert to tf feature """ @@ -176,7 +176,11 @@ def to_tf_feature_type(self): else: return Scalar(dtype=tf_detype) elif len(self.shape) >= 1: - return Tensor(shape=self.shape, dtype=tf_detype) + if first_dim_none: + tf_shape = [None] + list(self.shape[1:]) + return Tensor(shape=tf_shape, dtype=tf_detype) + else: + return Tensor(shape=self.shape, dtype=tf_detype) else: raise ValueError(f"Unsupported conversion to tf feature: {self}") diff --git a/fog_x/trajectory.py b/fog_x/trajectory.py index 6c35fd0..46618b5 100644 --- a/fog_x/trajectory.py +++ b/fog_x/trajectory.py @@ -558,7 +558,13 @@ def _get_length_of_stream(container, stream): h5_cache = h5py.File(self.cache_file_name, "w") for feature_name, data in np_cache.items(): if data.dtype == object: - continue # TODO + for i in range(len(data)): + if data[i] is not None: + data[i] = str(data[i]) + h5_cache.create_dataset( + feature_name, + data=data + ) else: h5_cache.create_dataset(feature_name, data=data) h5_cache.close() diff --git a/fog_x/utils.py b/fog_x/utils.py index 06cbed3..cdbf925 100644 --- a/fog_x/utils.py +++ b/fog_x/utils.py @@ -14,7 +14,8 @@ def data_to_tf_schema(data: Dict[str, Any]) -> Dict[str, FeatureType]: main_key, sub_key = k.split("/") if main_key not in schema: schema[main_key] = {} - schema[main_key][sub_key] = FeatureType.from_data(v).to_tf_feature_type() + schema[main_key][sub_key] = FeatureType.from_data(v).to_tf_feature_type(first_dim_none=True) + # replace first element of shape with None else: - schema[k] = FeatureType.from_data(v).to_tf_feature_type() + schema[k] = FeatureType.from_data(v).to_tf_feature_type(first_dim_none=True) return schema \ No newline at end of file