Skip to content

Commit

Permalink
Update metadata to autoreset_mode and add tests on it
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Nov 28, 2024
1 parent 781e6f5 commit d031c42
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 11 deletions.
2 changes: 1 addition & 1 deletion gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class CartPoleVectorEnv(VectorEnv):
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 50,
"autoreset-mode": AutoresetMode.NEXT_STEP,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
"""Initialize the environment from a FuncEnv."""
super().__init__()
if metadata is None:
metadata = {"AutoresetMode": AutoresetMode.NEXT_STEP}
metadata = {"autoreset_mode": AutoresetMode.NEXT_STEP}
self.func_env = func_env
self.num_envs = num_envs

Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"render_modes": ["rgb_array"],
"render_fps": 50,
"jax": True,
"AutoresetMode": AutoresetMode.NEXT_STEP,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"render_modes": ["rgb_array"],
"render_fps": 30,
"jax": True,
"AutoresetMode": AutoresetMode.NEXT_STEP,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
Expand Down
10 changes: 6 additions & 4 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,11 +978,13 @@ def create_single_env() -> Env:
copied_id_spec.kwargs["wrappers"] = wrappers
env.unwrapped.spec = copied_id_spec

if "AutoresetMode" not in env.metadata:
warn(f"The VectorEnv ({env}) is missing AutoresetMode metadata.")
elif not isinstance(env.metadata["AutoresetMode"], AutoresetMode):
if "autoreset_mode" not in env.metadata:
warn(
f"The VectorEnv ({env}) AutoresetMode metadata is not an instance of AutoresetMode, {type(env.metadata['AutoresetMode'])}."
f"The VectorEnv ({env}) is missing AutoresetMode metadata, metadata={env.metadata}"
)
elif not isinstance(env.metadata["autoreset_mode"], AutoresetMode):
warn(
f"The VectorEnv ({env}) metadata['autoreset_mode'] is not an instance of AutoresetMode, {type(env.metadata['autoreset_mode'])}."
)

return env
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class BlackjackFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"AutoresetMode": AutoresetMode.NEXT_STEP,
"autoreseet-mode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/envs/tabular/cliffwalking.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class CliffWalkingFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"AutoresetMode": AutoresetMode.NEXT_STEP,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
44 changes: 43 additions & 1 deletion tests/envs/registration/test_make_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import multiprocessing
import re
import warnings

import pytest

import gymnasium as gym
from gymnasium import VectorizeMode, error, wrappers
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.envs.classic_control.cartpole import CartPoleVectorEnv
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
from gymnasium.wrappers import TimeLimit, TransformObservation
from tests.wrappers.utils import has_wrapper

Expand Down Expand Up @@ -282,3 +283,44 @@ def test_make_vec_with_spec_additional_wrappers():

del gym.registry["TestEnv-v0"]
del gym.registry["TestEnv-v1"]


class MissingMetadataVecEnv(VectorEnv):
metadata = {"render_fps": 1}

def __init__(self, num_envs: int):
self.num_envs = num_envs


class IncorrectMetadataVecEnv(VectorEnv):
metadata = {"autoreset_mode": "next_step"}

def __init__(self, num_envs: int):
self.num_envs = num_envs


def test_missing_autoreset_mode_metadata():
gym.register("MissingMetadataVecEnv-v0", vector_entry_point=MissingMetadataVecEnv)
gym.register(
"IncorrectMetadataVecEnv-v0", vector_entry_point=IncorrectMetadataVecEnv
)

with warnings.catch_warnings():
with pytest.warns(
UserWarning,
match=re.escape(
"The VectorEnv (MissingMetadataVecEnv(MissingMetadataVecEnv-v0, num_envs=1)) is missing AutoresetMode metadata, metadata={'render_fps': 1}"
),
):
gym.make_vec("MissingMetadataVecEnv-v0")

with pytest.warns(
UserWarning,
match=re.escape(
"The VectorEnv (IncorrectMetadataVecEnv(IncorrectMetadataVecEnv-v0, num_envs=1)) metadata['autoreset_mode'] is not an instance of AutoresetMode, <class 'str'>."
),
):
gym.make_vec("IncorrectMetadataVecEnv-v0")

gym.registry.pop("MissingMetadataVecEnv-v0")
gym.registry.pop("IncorrectMetadataVecEnv-v0")

0 comments on commit d031c42

Please sign in to comment.