NaNs in qpos when rolling out muscle-driven arm26.xml with MJX and random actions. #2317

Open
2 tasks done
P-Schumacher opened this issue Dec 28, 2024 · 2 comments

@P-Schumacher

Intro

Hi!

I am a PhD student at the University of Tuebingen, and I use MuJoCo for my research on musculoskeletal control.

A group of other researchers and I are very excited about the support for muscle actuators in MJX.
We have tried using MJX with a simple musculoskeletal model (a single finger) in conjunction with Brax, but we keep running into NaNs during training.

We then went back to the arm26 model provided in the MuJoCo GitHub repo.
This model runs without problems in the provided testspeed.py file, but it produces NaNs as soon as we use random actions in the rollout loop.

I have created a standalone script that reproduces this behavior. We can improve stability with smaller timesteps and other settings, but we can almost always reproduce NaNs by just increasing the number of environments and the length of the rollouts in the following script. The model string there is just a 1-to-1 copy of the arm26.xml model contained in the MuJoCo GitHub code.
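
For anyone who wants to try the smaller-timestep variant mentioned above without editing the model string by hand, here is a minimal sketch using only the Python standard library. The helper name `with_timestep` and the short `demo_xml` stand-in are hypothetical and not part of the original script. Note that `ElementTree` drops the license comment that precedes the root element.

```python
import xml.etree.ElementTree as ET


def with_timestep(model_xml: str, timestep: float) -> str:
    """Return a copy of an MJCF string with <option timestep=...> overridden."""
    root = ET.fromstring(model_xml)
    option = root.find("option")
    if option is None:
        # No <option> element yet: create one (assumption: position in the file does not matter).
        option = ET.SubElement(root, "option")
    option.set("timestep", repr(timestep))
    return ET.tostring(root, encoding="unicode")


# Stand-in for the full arm26 string; only the <option> element matters here.
demo_xml = '<mujoco model="demo"><option timestep="0.005" iterations="50"/></mujoco>'
smaller_dt_xml = with_timestep(demo_xml, 0.001)
```

The result can be passed to `mujoco.MjModel.from_xml_string` in place of `model_str` in the script below. As described above, this only seems to delay the NaNs rather than remove them.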

Are we missing something? Is this expected behavior?

CCing the involved researchers: @vikashplus @Balint-H @jamesheald

My setup

OS: Ubuntu 24.04.1 LTS
Python: 3.10.16
mujoco: 3.2.6
mujoco-mjx: 3.2.6
jax: 0.4.38
jax-cuda12-pjrt: 0.4.38
jax-cuda12-plugin: 0.4.38
jaxlib: 0.4.38
NVIDIA-SMI 550.120
Driver Version: 550.120
CUDA Version: 12.4 

Hardware: NVIDIA GeForce RTX 4080

What's happening? What did you expect?

The script produces NaNs in the model qpos after several rollout steps with random actions.

I would expect the provided parameters for this simple model to be stable under these conditions. We have tried to improve stability, but can almost always reproduce NaNs when running this script for long enough.
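
To narrow down when an environment first diverges, one option is to record qpos at every step and look for the first step that contains a NaN. The sketch below assumes a trajectory array of shape (batch, n_steps, nq), which is what you would get if `step_with_action` in the script returned `data.qpos` instead of `None` as its per-step scan output. The function name `first_nan_step` and the toy trajectory are hypothetical.

```python
import jax
from jax import numpy as jp


def first_nan_step(qpos_traj):
    """Per environment, return the first step index containing a NaN, or -1 if there is none.

    qpos_traj has shape (batch, n_steps, nq).
    """
    # True wherever any joint coordinate is NaN at that step: shape (batch, n_steps).
    bad = jp.any(jp.isnan(qpos_traj), axis=-1)
    # argmax gives the index of the first 1 along the step axis (0 if there is none).
    first = jp.argmax(bad.astype(jp.int32), axis=1)
    return jp.where(jp.any(bad, axis=1), first, -1)


# Toy trajectory: 2 environments, 4 steps, 2 joints. Environment 1 goes NaN from step 2 on.
traj = jp.zeros((2, 4, 2)).at[1, 2:, 0].set(jp.nan)
result = first_nan_step(traj)
```

Because NaNs propagate through later steps, the first affected step is usually the interesting one. The state and controls just before it can then be replayed in isolation.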

Steps for reproduction

Run the script I provided with python test_script.py
On my system it produced NaNs in the qpos.

Minimal model for reproduction

I used the provided arm26.xml model.

minimal XML
<!-- Copyright 2021 DeepMind Technologies Limited

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
     You may obtain a copy of the License at

         http://www.apache.org/licenses/LICENSE-2.0

     Unless required by applicable law or agreed to in writing, software
     distributed under the License is distributed on an "AS IS" BASIS,
     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     See the License for the specific language governing permissions and
     limitations under the License.
-->

<mujoco model="2-link 6-muscle arm">
  <option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>

  <visual>
    <rgba haze=".3 .3 .3 1"/>
  </visual>

  <default>
    <joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
    <muscle ctrllimited="true" ctrlrange="0 1"/>
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>

    <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>

    <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>

    <light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>

    <site name="s0" pos="-0.15 0 0" size="0.02"/>
    <site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

    <body pos="0 0 0">
      <geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
      <joint name="shoulder"/>
      <geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

      <site name="s1" pos="0.15 0.06 0" size="0.02"/>
      <site name="s2" pos="0.15 -0.06 0" size="0.02"/>
      <site name="s3" pos="0.4 0.06 0" size="0.02"/>
      <site name="s4" pos="0.4 -0.06 0" size="0.02"/>
      <site name="s5" pos="0.25 0.1 0" size="0.02"/>
      <site name="s6" pos="0.25 -0.1 0" size="0.02"/>
      <site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

      <body pos="0.5 0 0">
        <geom name="forearm" type="capsule" size="0.035" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
        <joint name="elbow"/>
        <geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

        <site name="s7" pos="0.11 0.05 0" size="0.02"/>
        <site name="s8" pos="0.11 -0.05 0" size="0.02"/>
      </body>
    </body>
  </worldbody>

  <tendon>
    <spatial name="SF" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s1"/>
    </spatial>

    <spatial name="SE" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s2"/>
    </spatial>

    <spatial name="EF" width="0.01">
      <site site="s3"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="EE" width="0.01">
      <site site="s4"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>

    <spatial name="BF" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s5"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="BE" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s6"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>
  </tendon>

  <actuator>
    <muscle name="SF" tendon="SF"/>
    <muscle name="SE" tendon="SE"/>
    <muscle name="EF" tendon="EF"/>
    <muscle name="EE" tendon="EE"/>
    <muscle name="BF" tendon="BF"/>
    <muscle name="BE" tendon="BE"/>
  </actuator>
</mujoco>

Code required for reproduction

import mujoco
from mujoco import mjx
import jax
from jax import numpy as jp


model_str = """
<!-- Copyright 2021 DeepMind Technologies Limited

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
     You may obtain a copy of the License at

         http://www.apache.org/licenses/LICENSE-2.0

     Unless required by applicable law or agreed to in writing, software
     distributed under the License is distributed on an "AS IS" BASIS,
     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     See the License for the specific language governing permissions and
     limitations under the License.
-->

<mujoco model="2-link 6-muscle arm">
  <option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>

  <visual>
    <rgba haze=".3 .3 .3 1"/>
  </visual>

  <default>
    <joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
    <muscle ctrllimited="true" ctrlrange="0 1"/>
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>

    <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>

    <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>

    <light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>

    <site name="s0" pos="-0.15 0 0" size="0.02"/>
    <site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

    <body pos="0 0 0">
      <geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
      <joint name="shoulder"/>
      <geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

      <site name="s1" pos="0.15 0.06 0" size="0.02"/>
      <site name="s2" pos="0.15 -0.06 0" size="0.02"/>
      <site name="s3" pos="0.4 0.06 0" size="0.02"/>
      <site name="s4" pos="0.4 -0.06 0" size="0.02"/>
      <site name="s5" pos="0.25 0.1 0" size="0.02"/>
      <site name="s6" pos="0.25 -0.1 0" size="0.02"/>
      <site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

      <body pos="0.5 0 0">
        <geom name="forearm" type="capsule" size="0.035" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
        <joint name="elbow"/>
        <geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

        <site name="s7" pos="0.11 0.05 0" size="0.02"/>
        <site name="s8" pos="0.11 -0.05 0" size="0.02"/>
      </body>
    </body>
  </worldbody>

  <tendon>
    <spatial name="SF" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s1"/>
    </spatial>

    <spatial name="SE" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s2"/>
    </spatial>

    <spatial name="EF" width="0.01">
      <site site="s3"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="EE" width="0.01">
      <site site="s4"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>

    <spatial name="BF" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s5"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="BE" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s6"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>
  </tendon>

  <actuator>
    <muscle name="SF" tendon="SF"/>
    <muscle name="SE" tendon="SE"/>
    <muscle name="EF" tendon="EF"/>
    <muscle name="EE" tendon="EE"/>
    <muscle name="BF" tendon="BF"/>
    <muscle name="BE" tendon="BE"/>
  </actuator>
</mujoco>
"""


def get_rollout_callable(mjx_model, mjx_data, n_steps):
    """
    Get a callable that runs the environment rollout.
    """

    def collect(mjx_data, key):
        """
        Rollout in the environment.
        """

        def step_with_action(carry, _):
            """
            Perform a step with a random action.
            """
            data, key = carry[0], carry[1]
            # Split once per step: `key` is carried to the next step, `action_key` is consumed here.
            key, action_key = jax.random.split(key, 2)
            action = jax.random.uniform(action_key, shape=[mjx_model.nu,], minval=0, maxval=1.0, dtype=data.ctrl.dtype)
            data = data.replace(ctrl=action)
            data = mjx.step(mjx_model, data)
            return (data, key), None

        carry, _ = jax.lax.scan(step_with_action, (mjx_data, key), (), n_steps)
        return carry[0]

    return collect


if __name__ == "__main__":
    batch_size = 100
    rollout_length = 1000
    model = mujoco.MjModel.from_xml_string(model_str)
    mjx_model = mjx.put_model(model)
    
    # vmap data creation to get parallel environments
    mjx_data = jax.vmap(lambda dummy: mjx.make_data(mjx_model))(jp.arange(batch_size))

    rollout_callable = get_rollout_callable(mjx_model, mjx_data, rollout_length)

    key = jax.random.PRNGKey(0)
    key = jax.random.split(key, batch_size)

    mjx_data = jax.jit(jax.vmap(rollout_callable))(mjx_data, key)
    
    # detect NaNs
    assert not jp.any(jp.isnan(mjx_data.qpos)), f"NaNs detected at: {jp.where(jp.isnan(mjx_data.qpos))}"

@P-Schumacher added the bug label on Dec 28, 2024
@thowell self-assigned this on Dec 30, 2024
@Beanpow
Contributor

Beanpow commented Jan 8, 2025

def tendon(m: Model, d: Data) -> Data:

I think it may be due to a bug in the tendon length calculation. Currently MJX doesn't support muscle wrapping on the cylinder. When I remove the geom wrapping from the model string, the script works.
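
A minimal sketch of the experiment described above, assuming "removing the geom wrapping" means deleting the `<geom .../>` entries inside each `<spatial>` tendon and leaving the cylinders in the worldbody untouched. The helper name `strip_tendon_wrapping` and the short `demo_xml` are hypothetical. It uses only the standard library.

```python
import xml.etree.ElementTree as ET


def strip_tendon_wrapping(model_xml: str) -> str:
    """Return a copy of an MJCF string with wrapping geoms removed from spatial tendons."""
    root = ET.fromstring(model_xml)
    # Materialize the list first so removing children does not disturb the tree walk.
    for spatial in list(root.iter("spatial")):
        for wrap in list(spatial.findall("geom")):
            spatial.remove(wrap)
    return ET.tostring(root, encoding="unicode")


# Cut-down stand-in for the SF tendon of arm26.
demo_xml = """
<mujoco>
  <tendon>
    <spatial name="SF">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s1"/>
    </spatial>
  </tendon>
</mujoco>
"""
stripped = strip_tendon_wrapping(demo_xml)
```

This turns every tendon into straight segments between its sites, so it changes the muscle paths and moment arms. It is useful for isolating the wrapping code path, not as a fix for the model.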

@yuvaltassa
Collaborator

@Beanpow, you're using internal wrapping? External wrapping is supported.

I believe both internal wrapping and a bug fix for the issue described in the OP are coming soon. Is that right, @thowell?
