diff --git a/acme/wrappers/concatenate_observations.py b/acme/wrappers/concatenate_observations.py index 949b4cf49f..15d5dafd22 100644 --- a/acme/wrappers/concatenate_observations.py +++ b/acme/wrappers/concatenate_observations.py @@ -18,7 +18,6 @@ from typing import Sequence, Optional from acme import types -from acme.jax import utils from acme.wrappers import base import dm_env import numpy as np @@ -40,6 +39,11 @@ def _concat(values: types.NestedArray) -> np.ndarray: return np.concatenate(leaves) +def _zeros_like(nest, dtype=None): + """Generate a nested NumPy array according to spec.""" + return tree.map_structure(lambda x: np.zeros(x.shape, dtype or x.dtype), nest) + + class ConcatObservationWrapper(base.EnvironmentWrapper): """Wrapper that concatenates observation fields. @@ -62,7 +66,7 @@ def __init__(self, environment: dm_env.Environment, name_filter = list(observation_spec.keys()) self._obs_names = [x for x in name_filter if x in observation_spec.keys()] - dummy_obs = utils.zeros_like(observation_spec) + dummy_obs = _zeros_like(observation_spec) dummy_obs = self._convert_observation(dummy_obs) self._observation_spec = dm_env.specs.BoundedArray( shape=dummy_obs.shape,