From e9227fefa704ee6d07deab3759971ce497866d6e Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Fri, 17 Jan 2025 21:30:56 +0100 Subject: [PATCH] Fix `get_wrapper_attr` / `set_wrapper_attr`. --- gymnasium/core.py | 25 +++++++++++++------------ tests/test_core.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index 9dfe63876..1e54fb321 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -19,6 +19,8 @@ ActType = TypeVar("ActType") RenderFrame = TypeVar("RenderFrame") +NOT_FOUND = object() + class Env(Generic[ObsType, ActType]): r"""The main Gymnasium class for implementing Reinforcement Learning Agents environments. @@ -415,15 +417,15 @@ def get_wrapper_attr(self, name: str) -> Any: Returns: The variable with name in wrapper or lower environments """ - if hasattr(self, name): - return getattr(self, name) - else: + attr = getattr(self, name, NOT_FOUND) + if attr is NOT_FOUND: try: return self.env.get_wrapper_attr(name) except AttributeError as e: raise AttributeError( f"wrapper {self.class_name()} has no attribute {name!r}" ) from e + return attr def set_wrapper_attr(self, name: str, value: Any): """Sets an attribute on this wrapper or lower environment if `name` is already defined. @@ -432,18 +434,17 @@ def set_wrapper_attr(self, name: str, value: Any): name: The variable name value: The new variable value """ - sub_env = self.env - attr_set = False - - while attr_set is False and isinstance(sub_env, Wrapper): + sub_env = self + while True: if hasattr(sub_env, name): setattr(sub_env, name, value) - attr_set = True - else: + return + if isinstance(sub_env, Wrapper): sub_env = sub_env.env - - if attr_set is False: - setattr(sub_env, name, value) + else: + sub_env = self + break + setattr(sub_env, name, value) def __str__(self): """Returns the wrapper name and the :attr:`env` representation string.""" diff --git a/tests/test_core.py b/tests/test_core.py index 196b64f73..7e7391aad 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -215,6 +215,16 @@ def test_get_set_wrapper_attr(): env.unwrapped._disable_render_order_enforcing assert env.get_wrapper_attr("_disable_render_order_enforcing") is True + # Test with top-most wrapper + env.MY_ATTRIBUTE_1 = True + assert env.get_wrapper_attr("MY_ATTRIBUTE_1") is True + env.set_wrapper_attr("MY_ATTRIBUTE_1", False) + assert env.get_wrapper_attr("MY_ATTRIBUTE_1") is False + + # Test with non-existing attribute + env.set_wrapper_attr("MY_ATTRIBUTE_2", True) + assert getattr(env, "MY_ATTRIBUTE_2") is True + class TestRandomSeeding: @staticmethod