Skip to content

Commit 78dd720

Browse files
yuvaltassaalimuldal
authored andcommitted
All MuJoCo warnings trigger an invalid state error.
PiperOrigin-RevId: 182929842
1 parent 8e0510b commit 78dd720

2 files changed

Lines changed: 12 additions & 20 deletions

File tree

dm_control/mujoco/engine.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,6 @@
7171
'bottom right': enums.mjtGridPos.mjGRID_BOTTOMRIGHT,
7272
}
7373

74-
_DIVERGENCE_WARNINGS = [
75-
enums.mjtWarning.mjWARN_INERTIA,
76-
enums.mjtWarning.mjWARN_BADQPOS,
77-
enums.mjtWarning.mjWARN_BADQVEL,
78-
enums.mjtWarning.mjWARN_BADQACC,
79-
enums.mjtWarning.mjWARN_BADCTRL,
80-
]
81-
8274
Contexts = collections.namedtuple('Contexts', ['gl', 'mujoco'])
8375
Selected = collections.namedtuple(
8476
'Selected', ['body', 'geom', 'world_position'])
@@ -146,7 +138,7 @@ def step(self):
146138
if self.model.opt.integrator != enums.mjtIntegrator.mjINT_EULER:
147139
mjlib.mj_step1(self.model.ptr, self.data.ptr)
148140

149-
self.check_divergence()
141+
self.check_invalid_state()
150142

151143
def render(self, height=240, width=320, camera_id=-1, overlays=(),
152144
depth=False, scene_option=None):
@@ -250,16 +242,16 @@ def forward(self):
250242
# http://www.mujoco.org/book/programming.html#siForward
251243
mjlib.mj_forward(self.model.ptr, self.data.ptr)
252244

253-
def check_divergence(self):
254-
"""Raises a `base.PhysicsError` if the simulation state is divergent."""
255-
warning_counts = [self.data.warning[i].number for i in _DIVERGENCE_WARNINGS]
245+
def check_invalid_state(self):
246+
"""Raises a `base.PhysicsError` if the simulation state is invalid."""
247+
warning_counts = [self.data.warning[i].number for i in
248+
xrange(enums.mjtWarning.mjNWARNING)]
256249
if any(warning_counts):
257250
warning_names = []
258251
for i in np.where(warning_counts)[0]:
259-
field_idx = _DIVERGENCE_WARNINGS[i]
260-
warning_names.append(enums.mjtWarning._fields[field_idx])
252+
warning_names.append(enums.mjtWarning._fields[i])
261253
raise _control.PhysicsError(
262-
'Physics state has diverged. Warning(s) raised: {}'.format(
254+
'Physics state is invalid. Warning(s) raised: {}'.format(
263255
', '.join(warning_names)))
264256

265257
def __getstate__(self):

dm_control/mujoco/engine_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,20 +236,20 @@ def testDivergenceException(self, warning_name):
236236
with self._physics.reset_context():
237237
self._physics.data.warning[warning_enum].number = 1
238238
with self.assertRaisesRegexp(control.PhysicsError, warning_name):
239-
self._physics.check_divergence()
239+
self._physics.check_invalid_state()
240240
self._physics.reset()
241-
self._physics.check_divergence()
241+
self._physics.check_invalid_state()
242242

243243
@parameterized.parameters(float('inf'), float('nan'), 1e15)
244244
def testBadQpos(self, bad_value):
245245
with self._physics.reset_context():
246246
self._physics.data.qpos[0] = bad_value
247247
mjlib.mj_checkPos(self._physics.model.ptr, self._physics.data.ptr)
248248
with self.assertRaises(control.PhysicsError):
249-
self._physics.check_divergence()
249+
self._physics.check_invalid_state()
250250
self._physics.reset()
251251
mjlib.mj_checkPos(self._physics.model.ptr, self._physics.data.ptr)
252-
self._physics.check_divergence()
252+
self._physics.check_invalid_state()
253253

254254
def testNanControl(self):
255255
with self._physics.reset_context():
@@ -258,7 +258,7 @@ def testNanControl(self):
258258
# Apply the controls.
259259
mjlib.mj_step(self._physics.model.ptr, self._physics.data.ptr)
260260
with self.assertRaisesRegexp(control.PhysicsError, 'mjWARN_BADCTRL'):
261-
self._physics.check_divergence()
261+
self._physics.check_invalid_state()
262262

263263
@parameterized.named_parameters(
264264
('_copy', lambda x: x.copy()),

0 commit comments

Comments
 (0)