Skip to content

Commit

Permalink
octo dataloader working
Browse files Browse the repository at this point in the history
  • Loading branch information
KeplerC committed Aug 27, 2024
1 parent 3467b3f commit b79068a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
4 changes: 2 additions & 2 deletions fog_x/DLdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def _convert_h5_cache_to_tensor():
output_tf_traj[key] = tf.convert_to_tensor(h5_cache[key])
return output_tf_traj
output = {"steps" : _convert_h5_cache_to_tensor()}
print(output)

yield output

Expand Down Expand Up @@ -253,7 +252,8 @@ def _convert_h5_cache_to_tensor():
output_signature = {"steps" : tf.nest.map_structure(
lambda spec: tf.TensorSpec(shape=spec.shape, dtype=spec.dtype), step_spec
)}

print(output_signature)

dataset = _wrap(tf.data.Dataset.from_generator, False)(
generator,
output_signature=output_signature
Expand Down
8 changes: 6 additions & 2 deletions fog_x/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def from_str(self, feature_str: str):
shape = eval(shape.split("=")[1][:-1]) # strip brackets
return FeatureType(dtype=dtype, shape=shape)

def to_tf_feature_type(self):
def to_tf_feature_type(self, first_dim_none=False):
"""
Convert to tf feature
"""
Expand Down Expand Up @@ -176,7 +176,11 @@ def to_tf_feature_type(self):
else:
return Scalar(dtype=tf_detype)
elif len(self.shape) >= 1:
return Tensor(shape=self.shape, dtype=tf_detype)
if first_dim_none:
tf_shape = [None] + list(self.shape[1:])
return Tensor(shape=tf_shape, dtype=tf_detype)
else:
return Tensor(shape=self.shape, dtype=tf_detype)
else:
raise ValueError(f"Unsupported conversion to tf feature: {self}")

Expand Down
8 changes: 7 additions & 1 deletion fog_x/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,13 @@ def _get_length_of_stream(container, stream):
h5_cache = h5py.File(self.cache_file_name, "w")
for feature_name, data in np_cache.items():
if data.dtype == object:
continue # TODO
for i in range(len(data)):
if data[i] is not None:
data[i] = str(data[i])
h5_cache.create_dataset(
feature_name,
data=data
)
else:
h5_cache.create_dataset(feature_name, data=data)
h5_cache.close()
Expand Down
5 changes: 3 additions & 2 deletions fog_x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def data_to_tf_schema(data: Dict[str, Any]) -> Dict[str, FeatureType]:
main_key, sub_key = k.split("/")
if main_key not in schema:
schema[main_key] = {}
schema[main_key][sub_key] = FeatureType.from_data(v).to_tf_feature_type()
schema[main_key][sub_key] = FeatureType.from_data(v).to_tf_feature_type(first_dim_none=True)
# replace first element of shape with None
else:
schema[k] = FeatureType.from_data(v).to_tf_feature_type()
schema[k] = FeatureType.from_data(v).to_tf_feature_type(first_dim_none=True)
return schema

0 comments on commit b79068a

Please sign in to comment.