Skip to content

Commit

Permalink
Merge pull request #178 from ethanluoyc:fix-wrappers-chex-177
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 420277750
Change-Id: Idf5745935bdadfe4e13035533698f48b661849c0
  • Loading branch information
mwhoffman committed Jan 8, 2022
2 parents 286bf85 + f29cd4e commit 702bd51
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions acme/wrappers/concatenate_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Sequence, Optional

from acme import types
from acme.jax import utils
from acme.wrappers import base
import dm_env
import numpy as np
Expand All @@ -40,6 +39,11 @@ def _concat(values: types.NestedArray) -> np.ndarray:
return np.concatenate(leaves)


def _zeros_like(nest, dtype=None):
"""Generate a nested NumPy array according to spec."""
return tree.map_structure(lambda x: np.zeros(x.shape, dtype or x.dtype), nest)


class ConcatObservationWrapper(base.EnvironmentWrapper):
"""Wrapper that concatenates observation fields.
Expand All @@ -62,7 +66,7 @@ def __init__(self, environment: dm_env.Environment,
name_filter = list(observation_spec.keys())
self._obs_names = [x for x in name_filter if x in observation_spec.keys()]

dummy_obs = utils.zeros_like(observation_spec)
dummy_obs = _zeros_like(observation_spec)
dummy_obs = self._convert_observation(dummy_obs)
self._observation_spec = dm_env.specs.BoundedArray(
shape=dummy_obs.shape,
Expand Down

0 comments on commit 702bd51

Please sign in to comment.