Skip to content

Commit e066d77

Browse files
yuvaltassaalimuldal
authored andcommitted
Add reset-to-keyframe functionality to Physics.reset()
PiperOrigin-RevId: 309392225 Change-Id: Id3425c8c057a6ac6d29e2bb8c6b31b1b409b2a04
1 parent 0331cd6 commit e066d77

4 files changed

Lines changed: 41 additions & 4 deletions

File tree

dm_control/mujoco/engine.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@
8383
_RENDER_FLAG_OVERRIDES_NOT_SUPPORTED_FOR_DEPTH_OR_SEGMENTATION = (
8484
'`render_flag_overrides` are not supported for depth or segmentation '
8585
'rendering.')
86+
_KEYFRAME_ID_OUT_OF_RANGE = (
87+
'`keyframe_id` must be between 0 and {max_valid} inclusive, got: {actual}.')
8688

8789

8890
class Physics(_control.Physics):
@@ -267,9 +269,22 @@ def copy(self, share_model=False):
267269
new_obj._reload_from_data(new_data) # pylint: disable=protected-access
268270
return new_obj
269271

270-
def reset(self):
271-
"""Resets internal variables of the physics simulation."""
272-
mjlib.mj_resetData(self.model.ptr, self.data.ptr)
272+
def reset(self, keyframe_id=None):
273+
"""Resets internal variables of the simulation, possibly to a keyframe.
274+
275+
Args:
276+
keyframe_id: Optional int. If specified, the keyframe (saved state) to
277+
which to set the state.
278+
279+
"""
280+
if keyframe_id is None:
281+
mjlib.mj_resetData(self.model.ptr, self.data.ptr)
282+
else:
283+
if not 0 <= keyframe_id < self.model.nkey:
284+
raise ValueError(_KEYFRAME_ID_OUT_OF_RANGE.format(
285+
max_valid=self.model.nkey-1, actual=keyframe_id))
286+
mjlib.mj_resetDataKeyframe(self.model.ptr, self.data.ptr, keyframe_id)
287+
273288
# Disable actuation since we don't yet have meaningful control inputs.
274289
with self.model.disable('actuation'):
275290
self.forward()

dm_control/mujoco/engine_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,22 @@ def testNamedIndexing(self):
311311
def testReload(self):
312312
self._physics.reload_from_xml_path(MODEL_PATH)
313313

314+
def testReset(self):
315+
self._physics.reset()
316+
self.assertEqual(self._physics.data.qpos[1], 0)
317+
keyframe_id = 0
318+
self._physics.reset(keyframe_id=keyframe_id)
319+
self.assertEqual(self._physics.data.qpos[1],
320+
self._physics.model.key_qpos[keyframe_id, 1])
321+
out_of_range = [-1, 3]
322+
max_valid = self._physics.model.nkey - 1
323+
for actual in out_of_range:
324+
with self.assertRaisesWithLiteralMatch(
325+
ValueError,
326+
engine._KEYFRAME_ID_OUT_OF_RANGE.format(
327+
max_valid=max_valid, actual=actual)):
328+
self._physics.reset(keyframe_id=actual)
329+
314330
def testLoadAndReloadFromStringWithAssets(self):
315331
physics = engine.Physics.from_xml_string(
316332
MODEL_WITH_ASSETS, assets=ASSETS)

dm_control/mujoco/testing/assets/cartpole.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Actuators (name/actuator/parameter):
1414
<mujoco model='test_cartpole'>
1515
<compiler inertiafromgeom='true' coordinate='local'/>
1616

17+
<size nkey="1"/>
18+
1719
<custom>
1820
<numeric name="control_timestep" data="0.04" />
1921
<numeric name="three_numbers" data="1.0 2.0 3.0" />
@@ -66,4 +68,8 @@ Actuators (name/actuator/parameter):
6668
<touch name="collision" site="cart sensor"/>
6769
</sensor>
6870

71+
<keyframe>
72+
<key name="hanging_down" qpos="0 1.57"/>
73+
</keyframe>
74+
6975
</mujoco>

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def find_data_files(package_dir, patterns):
166166

167167
setup(
168168
name='dm_control',
169-
version='0.0.309253177',
169+
version='0.0.309392225',
170170
description='Continuous control environments and MuJoCo Python bindings.',
171171
author='DeepMind',
172172
license='Apache License, Version 2.0',

0 commit comments

Comments
 (0)