Skip to content

Commit

Permalink
Merge pull request octo-models#136 from rail-berkeley/dibya_patch_1
Browse files Browse the repository at this point in the history
Make viz lib more efficient for small datasets, fix woopsie
  • Loading branch information
dibyaghosh authored Dec 7, 2023
2 parents 0505e18 + a58d6b3 commit 44bfdfb
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion orca/utils/visualization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ class Visualizer:

def __post_init__(self):
self.action_proprio_stats = self.dataset.dataset_statistics
cardinality = self.dataset.cardinality()
if (
cardinality == tf.data.INFINITE_CARDINALITY
or cardinality == tf.data.UNKNOWN_CARDINALITY
):
self.cardinality = float("inf")
else:
self.cardinality = cardinality.numpy()
self.visualized_trajs = False
self._cached_iterators = {}

Expand Down Expand Up @@ -250,7 +258,8 @@ def raw_evaluations(
return all_traj_info

def get_iterator(self, dataset, n):
if dataset not in self._cached_iterators:
n = min(n, self.cardinality)
if n not in self._cached_iterators:
self._cached_iterators[n] = (
dataset.take(n).repeat().as_numpy_iterator()
if self.freeze_trajs
Expand Down

0 comments on commit 44bfdfb

Please sign in to comment.