The feature, motivation and pitch
Description
MJX's JAX backend produces incorrect batch dimension ordering when mjx.step is called under nested jax.vmap. The contact solver's dot_general operations transpose the batch dimensions, causing an XLA HLO verification error.
This affects evolutionary computation (EC) workloads where the outer vmap iterates over a population of candidates and the inner vmap parallelizes multiple environment instances per candidate.
Contact-free environments (e.g., CartpoleSwingup with nconmax=0) work correctly. The issue only manifests in contact-rich environments (e.g., HopperHop with nconmax=50000), suggesting it's specific to the contact solver's batching rules.
Minimal Reproduction
import jax
import mujoco
from mujoco import mjx
from mujoco_playground import registry

# Load a contact-rich environment
env = registry.load("HopperHop")

# Force JAX backend (WARP has a separate issue with nested vmap)
mjx_model = mjx.put_model(env._mj_model, impl='jax')
env._mjx_model = mjx_model

# Single vmap: works ✅
def single_step(key):
    state = env.reset(key)
    return env.step(state, jax.numpy.zeros(env.action_size))

keys = jax.random.split(jax.random.PRNGKey(0), 512)
jax.vmap(single_step)(keys)  # OK

# Nested vmap: crashes ❌
def inner_step(key):
    state = env.reset(key)
    return env.step(state, jax.numpy.zeros(env.action_size))

def outer_step(key):
    inner_keys = jax.random.split(key, 4)
    return jax.vmap(inner_step)(inner_keys)

outer_keys = jax.random.split(jax.random.PRNGKey(0), 512)
jax.vmap(outer_step)(outer_keys)  # CRASHES
Error
jax.errors.JaxRuntimeError: INTERNAL: during context [hlo verifier]:
Expected instruction to have shape equal to f32[2,512,4,6,3],
actual shape is f32[4,2,512,6,3]:
%dot.55 = f32[4,2,512,6,3]{4,3,1,0,2} dot(%reshape.15280, %reshape.15285),
lhs_batch_dims={0,1,2}, lhs_contracting_dims={4},
rhs_batch_dims={0,1,2}, rhs_contracting_dims={3},
metadata={op_name="jit(step)/vmap()/while/body/closed_call/vmap(forward)/vmap()/dot_general"}
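The outer vmap dimension (512) and inner vmap dimension (4) are in the wrong order in the dot_general output shape: the verifier expects 512 before 4 (f32[2,512,4,6,3]), but the instruction produces 4 ahead of 512 (f32[4,2,512,6,3]).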
Expected Behavior
Nested jax.vmap should compose correctly, producing shapes f32[512, 4, ...] (outer dim first, inner dim second).
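For reference, here is a minimal sketch of the composition I expect, using plain JAX with no MJX involved. `toy_step` is a hypothetical stand-in for `env.step` that just does a small matmul; the sizes (8 candidates, 4 envs, 6x3 output) are arbitrary and chosen only to make the batch axes easy to spot.

```python
import jax
import jax.numpy as jnp

def toy_step(key):
    # Hypothetical stand-in for env.reset + env.step: one small matmul.
    a = jax.random.normal(key, (6, 3))
    b = jnp.eye(3)
    return a @ b  # per-environment result of shape (6, 3)

def outer_step(key):
    # Inner vmap: 4 environment instances per candidate.
    inner_keys = jax.random.split(key, 4)
    return jax.vmap(toy_step)(inner_keys)

# Outer vmap: 8 candidates in the population.
outer_keys = jax.random.split(jax.random.PRNGKey(0), 8)
out = jax.vmap(outer_step)(outer_keys)

# Outer batch axis first, inner batch axis second, then the per-env shape.
print(out.shape)
```

With this toy function the result has shape (8, 4, 6, 3), which is the ordering I would expect the MJX contact solver path to preserve as well.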
Environment
- mujoco: 3.6.0
- mujoco-mjx: 3.6.0
- mujoco-playground: 0.2.0
- JAX: 0.9.2
- GPU: NVIDIA RTX 4090, CUDA 12.9
Notes
- WARP backend has a separate issue: its FFI dimension validation rejects nested vmap entirely (ValueError: Leaf node leading dim (512) does not match nconmax (50000)). This is expected since WARP kernels have fixed batch semantics.
- JAX backend should theoretically support nested vmap but has this batch transposition bug in the contact solver path.
- Workaround: Use single-level vmap with sequential episode evaluation (num_envs=1 in EC frameworks).
- Contact-free envs work fine under nested vmap; the issue is specific to the contact solver's dot_general batching rules.
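One way the single-level workaround could look in plain JAX is sketched below. This is an assumption about how to structure it, not something I have verified against MJX: `toy_episode` and `evaluate_candidate` are hypothetical names, and `toy_episode` stands in for a real reset + rollout. The idea is to keep only the outer vmap over the population and run each candidate's episodes sequentially with `jax.lax.map` instead of an inner vmap.

```python
import jax
import jax.numpy as jnp

def toy_episode(key):
    # Hypothetical stand-in for one episode; returns a scalar "fitness".
    x = jax.random.normal(key, (6, 3))
    return jnp.sum(x @ x.T)

def evaluate_candidate(key, num_episodes=4):
    episode_keys = jax.random.split(key, num_episodes)
    # jax.lax.map evaluates the episodes one after another (no inner vmap).
    return jax.lax.map(toy_episode, episode_keys)

# Single-level vmap over a population of 8 candidates.
population_keys = jax.random.split(jax.random.PRNGKey(0), 8)
fitness = jax.vmap(evaluate_candidate)(population_keys)

# One fitness value per (candidate, episode) pair.
print(fitness.shape)
```

This trades the inner-level parallelism for sequential evaluation, so it is slower than a working nested vmap would be.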
Alternatives
No response
Additional context
No response