Skip to content

Commit

Permalink
Add AutoresetMode to metadata and warning about it if missing
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Nov 27, 2024
1 parent 606bfaf commit 781e6f5
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 4 deletions.
3 changes: 2 additions & 1 deletion gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gymnasium import logger, spaces
from gymnasium.envs.classic_control import utils
from gymnasium.error import DependencyNotInstalled
from gymnasium.vector import VectorEnv
from gymnasium.vector import AutoresetMode, VectorEnv
from gymnasium.vector.utils import batch_space


Expand Down Expand Up @@ -355,6 +355,7 @@ class CartPoleVectorEnv(VectorEnv):
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 50,
"autoreset-mode": AutoresetMode.NEXT_STEP,
}

def __init__(
Expand Down
3 changes: 2 additions & 1 deletion gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gymnasium.envs.registration import EnvSpec
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector import AutoresetMode
from gymnasium.vector.utils import batch_space


Expand Down Expand Up @@ -115,7 +116,7 @@ def __init__(
"""Initialize the environment from a FuncEnv."""
super().__init__()
if metadata is None:
metadata = {}
metadata = {"AutoresetMode": AutoresetMode.NEXT_STEP}
self.func_env = func_env
self.num_envs = num_envs

Expand Down
8 changes: 7 additions & 1 deletion gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode


RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
Expand Down Expand Up @@ -272,7 +273,12 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any):
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 50,
"jax": True,
"AutoresetMode": AutoresetMode.NEXT_STEP,
}

def __init__(
self,
Expand Down
8 changes: 7 additions & 1 deletion gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode


RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
Expand Down Expand Up @@ -225,7 +226,12 @@ def get_default_params(self, **kwargs) -> PendulumParams:
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 30,
"jax": True,
"AutoresetMode": AutoresetMode.NEXT_STEP,
}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
Expand Down
9 changes: 9 additions & 0 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import gymnasium as gym
from gymnasium import Env, Wrapper, error, logger
from gymnasium.logger import warn
from gymnasium.vector import AutoresetMode


if sys.version_info < (3, 10):
Expand Down Expand Up @@ -976,6 +978,13 @@ def create_single_env() -> Env:
copied_id_spec.kwargs["wrappers"] = wrappers
env.unwrapped.spec = copied_id_spec

if "AutoresetMode" not in env.metadata:
warn(f"The VectorEnv ({env}) is missing AutoresetMode metadata.")
elif not isinstance(env.metadata["AutoresetMode"], AutoresetMode):
warn(
f"The VectorEnv ({env}) AutoresetMode metadata is not an instance of AutoresetMode, {type(env.metadata['AutoresetMode'])}."
)

return env


Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/tabular/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle, seeding
from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering


Expand Down Expand Up @@ -239,6 +240,7 @@ class BlackjackFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"AutoresetMode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/tabular/cliffwalking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering


Expand Down Expand Up @@ -136,6 +137,7 @@ class CliffWalkingFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"AutoresetMode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import (
AutoresetMode,
VectorActionWrapper,
VectorEnv,
VectorObservationWrapper,
Expand All @@ -21,4 +22,5 @@
"SyncVectorEnv",
"AsyncVectorEnv",
"utils",
"AutoresetMode",
]

0 comments on commit 781e6f5

Please sign in to comment.