From 7e870a80a22516b9fafeb694af4737d6fd05a7f7 Mon Sep 17 00:00:00 2001 From: Gaetan Lepage Date: Tue, 8 Oct 2024 14:40:21 +0200 Subject: [PATCH] Lazily import locomotion envs to prevent ModuleNotFoundError when labmaze is not installed --- shimmy/registration.py | 62 ++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/shimmy/registration.py b/shimmy/registration.py index 492ddbd4..5817fe06 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -68,6 +68,40 @@ def _make_dm_control_generic_env(env, **render_kwargs): # Register all suite environments import dm_control.suite + def _register_locomotion_envs(): + try: + from dm_control import composer + from dm_control.locomotion.examples import basic_cmu_2019, basic_rodent_2020 + except ImportError: + print("Warning, registration of `dm_control` locomotion envs has failed due to an ImportError") + return + + def _make_dm_control_example_locomotion_env( + env_fn: Callable[[np.random.RandomState | None], composer.Environment], + random_state: np.random.RandomState | None = None, + **render_kwargs, + ): + return DmControlCompatibilityV0(env_fn(random_state), **render_kwargs) + + for locomotion_env, nondeterministic in ( + (basic_cmu_2019.cmu_humanoid_run_walls, False), + (basic_cmu_2019.cmu_humanoid_run_gaps, False), + (basic_cmu_2019.cmu_humanoid_go_to_target, False), + (basic_cmu_2019.cmu_humanoid_maze_forage, True), + (basic_cmu_2019.cmu_humanoid_heterogeneous_forage, True), + (basic_rodent_2020.rodent_escape_bowl, False), + (basic_rodent_2020.rodent_run_gaps, False), + (basic_rodent_2020.rodent_maze_forage, True), + (basic_rodent_2020.rodent_two_touch, True), + # (cmu_2020_tracking.cmu_humanoid_tracking, False), + ): + register( + f"dm_control/{locomotion_env.__name__.title().replace('_', '')}-v0", + partial(_make_dm_control_example_locomotion_env, env_fn=locomotion_env), + nondeterministic=nondeterministic, + ) + + def _make_dm_control_suite_env( domain_name: str, task_name: str, @@ -98,33 +132,7 @@ def _make_dm_control_suite_env( # Register all example locomotion environments # Listed in https://github.com/deepmind/dm_control/blob/main/dm_control/locomotion/examples/examples_test.py - from dm_control import composer - from dm_control.locomotion.examples import basic_cmu_2019, basic_rodent_2020 - - def _make_dm_control_example_locomotion_env( - env_fn: Callable[[np.random.RandomState | None], composer.Environment], - random_state: np.random.RandomState | None = None, - **render_kwargs, - ): - return DmControlCompatibilityV0(env_fn(random_state), **render_kwargs) - - for locomotion_env, nondeterministic in ( - (basic_cmu_2019.cmu_humanoid_run_walls, False), - (basic_cmu_2019.cmu_humanoid_run_gaps, False), - (basic_cmu_2019.cmu_humanoid_go_to_target, False), - (basic_cmu_2019.cmu_humanoid_maze_forage, True), - (basic_cmu_2019.cmu_humanoid_heterogeneous_forage, True), - (basic_rodent_2020.rodent_escape_bowl, False), - (basic_rodent_2020.rodent_run_gaps, False), - (basic_rodent_2020.rodent_maze_forage, True), - (basic_rodent_2020.rodent_two_touch, True), - # (cmu_2020_tracking.cmu_humanoid_tracking, False), - ): - register( - f"dm_control/{locomotion_env.__name__.title().replace('_', '')}-v0", - partial(_make_dm_control_example_locomotion_env, env_fn=locomotion_env), - nondeterministic=nondeterministic, - ) + _register_locomotion_envs() # Register all manipulation environments import dm_control.manipulation