Skip to content

Commit 473e09a

Browse files
committed
Add support for named indexing into arrays with dimension na (e.g. mjData->act)
These dimensions correspond to the subset of actuators that have internal state. PiperOrigin-RevId: 207272312
1 parent a56b451 commit 473e09a

4 files changed

Lines changed: 78 additions & 3 deletions

File tree

dm_control/mujoco/index.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ def _get_size_name_to_element_names(model):
228228
assert None not in mocap_body_names
229229
size_name_to_element_names['nmocap'] = mocap_body_names
230230

231+
# Arrays with dimension `na` correspond to stateful actuators. MuJoCo's
232+
# compiler requires that these are always defined after stateless actuators,
233+
# so we only need the final `na` elements in the list of all actuator names.
234+
if model.na:
235+
all_actuator_names = size_name_to_element_names['nu']
236+
size_name_to_element_names['na'] = all_actuator_names[-model.na:]
237+
231238
return size_name_to_element_names
232239

233240

@@ -257,11 +264,10 @@ def make_axis_indexers(model):
257264
"""Returns a dict that maps size names to `Axis` indexers.
258265
259266
Args:
260-
model: An instance of `mjbindings.mjModelWrapper`.
267+
model: An instance of `mjbindings.MjModelWrapper`.
261268
262269
Returns:
263-
A `dict` mapping from a size name (e.g. `'nbody'`) to a `Axis`
264-
instance.
270+
A `dict` mapping from a size name (e.g. `'nbody'`) to an `Axis` instance.
265271
"""
266272

267273
size_name_to_element_names = _get_size_name_to_element_names(model)

dm_control/mujoco/index_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333

3434
MODEL = assets.get_contents('cartpole.xml')
3535
MODEL_NO_NAMES = assets.get_contents('cartpole_no_names.xml')
36+
MODEL_3RD_ORDER_ACTUATORS = assets.get_contents(
37+
'model_with_third_order_actuators.xml')
38+
MODEL_INCORRECT_ACTUATOR_ORDER = assets.get_contents(
39+
'model_incorrect_actuator_order.xml')
3640

3741
FIELD_REPR = {
3842
'act': ('FieldIndexer(act):\n'
@@ -145,6 +149,42 @@ def testDataNamedIndexing(self, field_name, key, numeric_key):
145149
# indexing.
146150
np.testing.assert_array_equal(field[numeric_key], indexer[key])
147151

152+
@parameterized.parameters(
153+
# (field name, named index key, expected integer index key)
154+
('act', 'cylinder', 0),
155+
('act_dot', 'general', 1),
156+
('act', ['general', 'cylinder', 'general'], [1, 0, 1]))
157+
def testIndexThirdOrderActuators(self, field_name, key, numeric_key):
158+
model = wrapper.MjModel.from_xml_string(MODEL_3RD_ORDER_ACTUATORS)
159+
data = wrapper.MjData(model)
160+
size_to_axis_indexer = index.make_axis_indexers(model)
161+
data_indexers = index.struct_indexer(data, 'mjdata', size_to_axis_indexer)
162+
163+
indexer = getattr(data_indexers, field_name)
164+
field = getattr(data, field_name)
165+
166+
# Explicit check that the converted key matches the numeric key.
167+
converted_key = indexer._convert_key(key)
168+
self.assertIndexExpressionEqual(numeric_key, converted_key)
169+
170+
# This writes unique values to the underlying buffer to prevent false
171+
# negatives.
172+
field.flat[:] = np.arange(field.size)
173+
174+
# Check that the result of named indexing matches the result of numeric
175+
# indexing.
176+
np.testing.assert_array_equal(field[numeric_key], indexer[key])
177+
178+
def testIncorrectActuatorOrder(self):
179+
# Our indexing of third-order actuators relies on an undocumented
180+
# requirement of MuJoCo's compiler that all third-order actuators come after
181+
# all second-order actuators. This test ensures that the rule still holds
182+
# (e.g. in future versions of MuJoCo).
183+
with self.assertRaisesRegexp(
184+
wrapper.Error,
185+
'2nd-order actuators must come before 3rd-order'):
186+
wrapper.MjModel.from_xml_string(MODEL_INCORRECT_ACTUATOR_ORDER)
187+
148188
@parameterized.parameters(
149189
# (field name, named index key)
150190
('xpos', 'pole'),
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
<mujoco>
2+
<worldbody>
3+
<body>
4+
<geom type="sphere" size="0.1"/>
5+
<joint type="slide" name="slide_joint"/>
6+
</body>
7+
</worldbody>
8+
<actuator>
9+
<!-- 3rd-order actuator preceding a 2nd-order actuator, which is disallowed by the compiler -->
10+
<cylinder name="cylinder" joint="slide_joint"/>
11+
<motor name="motor" joint="slide_joint"/>
12+
</actuator>
13+
</mujoco>
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<mujoco>
2+
<worldbody>
3+
<body>
4+
<geom type="sphere" size="0.1"/>
5+
<joint type="slide" name="slide_joint"/>
6+
</body>
7+
</worldbody>
8+
<actuator>
9+
<!-- Second-order actuators -->
10+
<motor name="motor" joint="slide_joint"/>
11+
<velocity name="velocity" joint="slide_joint"/>
12+
<!-- Third-order actuators -->
13+
<cylinder name="cylinder" joint="slide_joint"/>
14+
<general name="general" joint="slide_joint" dyntype="integrator" biastype="affine" dynprm="1 0 0"/>
15+
</actuator>
16+
</mujoco>

0 commit comments

Comments
 (0)