MJX JAX backend: dot_general batch dimension transposition under nested vmap #3209

@seyeint

The feature, motivation and pitch

Description

MJX's JAX backend produces incorrect batch dimension ordering when mjx.step is called under nested jax.vmap. The contact solver's dot_general operations transpose the batch dimensions, causing an XLA HLO verification error.

This affects evolutionary computation (EC) workloads where the outer vmap iterates over a population of candidates and the inner vmap parallelizes multiple environment instances per candidate.

Contact-free environments (e.g., CartpoleSwingup with nconmax=0) work correctly. The issue only manifests in contact-rich environments (e.g., HopperHop with nconmax=50000), suggesting it's specific to the contact solver's batching rules.

Minimal Reproduction

import jax
import mujoco
from mujoco import mjx
from mujoco_playground import registry

# Load a contact-rich environment
env = registry.load("HopperHop")

# Force JAX backend (WARP has a separate issue with nested vmap)
mjx_model = mjx.put_model(env._mj_model, impl='jax')
env._mjx_model = mjx_model

# Single vmap — works ✅
def single_step(key):
    state = env.reset(key)
    return env.step(state, jax.numpy.zeros(env.action_size))

keys = jax.random.split(jax.random.PRNGKey(0), 512)
jax.vmap(single_step)(keys)  # OK

# Nested vmap — crashes ❌
def inner_step(key):
    state = env.reset(key)
    return env.step(state, jax.numpy.zeros(env.action_size))

def outer_step(key):
    inner_keys = jax.random.split(key, 4)
    return jax.vmap(inner_step)(inner_keys)

outer_keys = jax.random.split(jax.random.PRNGKey(0), 512)
jax.vmap(outer_step)(outer_keys)  # CRASHES

Error

jax.errors.JaxRuntimeError: INTERNAL: during context [hlo verifier]:
Expected instruction to have shape equal to f32[2,512,4,6,3],
actual shape is f32[4,2,512,6,3]:

%dot.55 = f32[4,2,512,6,3]{4,3,1,0,2} dot(%reshape.15280, %reshape.15285),
  lhs_batch_dims={0,1,2}, lhs_contracting_dims={4},
  rhs_batch_dims={0,1,2}, rhs_contracting_dims={3},
  metadata={op_name="jit(step)/vmap()/while/body/closed_call/vmap(forward)/vmap()/dot_general"}

The outer vmap dimension (512) and inner vmap dimension (4) appear in swapped positions in the dot_general output shape.

Expected Behavior

Nested jax.vmap should compose correctly, producing shapes f32[512, 4, ...] (outer dim first, inner dim second).
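For reference, a pure-JAX sketch (no MJX involved) showing the batch ordering that nested vmap normally produces for a dot_general with the same trailing shapes as the failing op:

```python
import jax
import jax.numpy as jnp

# Pure-JAX sanity check: nested vmap over a plain matvec, with the same
# per-example shapes as the failing dot above, keeps the outer batch
# dimension first in the output.
def matvec(m, v):
    return m @ v  # lowers to a dot_general

m = jnp.ones((512, 4, 6, 3))   # (outer, inner, 6, 3)
v = jnp.ones((512, 4, 3))      # (outer, inner, 3)

out = jax.vmap(jax.vmap(matvec))(m, v)
print(out.shape)  # (512, 4, 6): outer dim first, inner dim second
```

This is the ordering the contact solver's batching rule should preserve; the error above shows it emitting the inner dimension first instead.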

Environment

  • mujoco: 3.6.0
  • mujoco-mjx: 3.6.0
  • mujoco-playground: 0.2.0
  • JAX: 0.9.2
  • GPU: NVIDIA RTX 4090, CUDA 12.9

Notes

  • WARP backend has a separate issue: its FFI dimension validation rejects nested vmap entirely (ValueError: Leaf node leading dim (512) does not match nconmax (50000)). This is expected since WARP kernels have fixed batch semantics.
  • JAX backend should theoretically support nested vmap but has this batch transposition bug in the contact solver path.
  • Workaround: Use single-level vmap with sequential episode evaluation (num_envs=1 in EC frameworks).
  • Contact-free envs work fine under nested vmap — the issue is specific to the contact solver's dot_general batching rules.
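A related way to sidestep the bug entirely is to collapse the population and per-candidate environment axes into one batch, vmap once, and reshape afterward. A hypothetical sketch (`step_one` stands in for the real `env.reset`/`env.step` pair):

```python
import jax
import jax.numpy as jnp

# Hypothetical workaround sketch: fold the (population, envs) axes into a
# single batch so only one vmap level is needed, then restore the shape.
def step_one(key):
    # Placeholder for the per-environment reset/step; returns a dummy obs.
    return jax.random.normal(key, (3,))

pop_size, num_envs = 512, 4
keys = jax.random.split(jax.random.PRNGKey(0), pop_size * num_envs)

flat = jax.vmap(step_one)(keys)                          # (2048, 3)
out = flat.reshape(pop_size, num_envs, *flat.shape[1:])  # (512, 4, 3)
```

Since vmap over a flat batch is semantically equivalent to the nested form here, this keeps full parallelism while avoiding the broken code path.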

Alternatives

No response

Additional context

No response

    Labels

    enhancement (New feature or request)
