Skip to content

Commit 409a057

Browse files
yuvaltassacopybara-github
authored andcommitted
Suppress exceptions due to physics warnings when initializing observation updaters
In order to initialize the buffers used by the observation updater we need the shape and dtype of the array returned by each observation callable. We usually get this by calling the observation callable and inspecting the array it returns. For `MJCFFeature` observations, this may involve calling `physics.forward` if the feature depends on a field that requires recalculation due to other changes to the physics state. However, at the point when this happens there is no guarantee that the physics state will be valid, since it has not yet been fully initialized by `initialize_episode`. Because of this, the call to `physics.forward` might raise a `PhysicsError`. In practice it does not matter whether or not the physics state is valid, since we only care about the shape and dtype of the observation array, not its contents. We therefore suppress `PhysicsError`s originating from this initial call to the observation callable. PiperOrigin-RevId: 283497708 Change-Id: I0c7b1bf1e968dadfe8101fbcbe072b1687f5a4de
1 parent 4cb1c61 commit 409a057

5 files changed

Lines changed: 68 additions & 8 deletions

File tree

dm_control/composer/observation/fake_physics.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import contextlib
23+
2224
from dm_control.composer.observation import observable
2325
from dm_control.rl import control
2426
import numpy as np
@@ -68,3 +70,7 @@ def reset(self):
6870

6971
def after_reset(self):
7072
pass
73+
74+
@contextlib.contextmanager
75+
def suppress_physics_errors(self):
76+
yield

dm_control/composer/observation/updater.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,16 @@ def __init__(self, observable, physics, random_state,
4646

4747
obs_spec = self.observable.array_spec
4848
if obs_spec is None:
49-
# We take an observation here to determine the shape and size.
49+
# We take an observation to determine the shape and dtype of the array.
5050
# This occurs outside of an episode and doesn't affect environment
51-
# behavior.
52-
obs_value = np.array(self.observation_callable())
53-
obs_spec = specs.Array(shape=obs_value.shape, dtype=obs_value.dtype)
51+
# behavior. At this point the physics state is not guaranteed to be valid,
52+
# so we might get a `PhysicsError` if the observation callable calls
53+
# `physics.forward`. We suppress such errors since they do not matter as
54+
# far as the shape and dtype of the observation are concerned.
55+
with physics.suppress_physics_errors():
56+
obs_array = self.observation_callable()
57+
obs_array = np.asarray(obs_array)
58+
obs_spec = specs.Array(shape=obs_array.shape, dtype=obs_array.dtype)
5459
self.buffer = obs_buffer.Buffer(
5560
buffer_size=(observable.buffer_size or DEFAULT_BUFFER_SIZE),
5661
shape=obs_spec.shape, dtype=obs_spec.dtype,

dm_control/composer/observation/updater_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def testNestedSpecsAndValues(self, list_or_tuple):
6565
observables[1]['five'].enabled = True
6666

6767
observation_updater = updater.Updater(observables)
68-
observation_updater.reset(physics=None, random_state=None)
68+
observation_updater.reset(physics=fake_physics.FakePhysics(),
69+
random_state=None)
6970

7071
def make_spec(obs):
7172
array = np.array(obs.observation_callable(None, None)())

dm_control/mujoco/engine.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,19 @@ def __init__(self, data):
120120
Args:
121121
data: Instance of `wrapper.MjData`.
122122
"""
123+
self._warnings_cause_exception = True
123124
self._reload_from_data(data)
124125

126+
@contextlib.contextmanager
127+
def suppress_physics_errors(self):
128+
"""Physics warnings will be logged rather than raise exceptions."""
129+
prev_state = self._warnings_cause_exception
130+
self._warnings_cause_exception = False
131+
try:
132+
yield
133+
finally:
134+
self._warnings_cause_exception = prev_state
135+
125136
def enable_profiling(self):
126137
"""Enables Mujoco timing profiling."""
127138
wrapper.enable_timer(True)
@@ -263,15 +274,28 @@ def forward(self):
263274

264275
@contextlib.contextmanager
265276
def check_invalid_state(self):
266-
"""Raises a `base.PhysicsError` if the simulation state is invalid."""
277+
"""Checks whether the physics state is invalid at exit.
278+
279+
Yields:
280+
None
281+
282+
Raises:
283+
PhysicsError: if the simulation state is invalid at exit, unless this
284+
context is nested inside a `suppress_physics_errors` context, in which
285+
case a warning will be logged instead.
286+
"""
267287
# `np.copyto(dst, src)` is marginally faster than `dst[:] = src`.
268288
np.copyto(self._warnings_before, self._warnings)
269289
yield
270290
np.greater(self._warnings, self._warnings_before, out=self._new_warnings)
271291
if any(self._new_warnings):
272292
warning_names = np.compress(self._new_warnings, enums.mjtWarning._fields)
273-
raise _control.PhysicsError(
274-
_INVALID_PHYSICS_STATE.format(warning_names=', '.join(warning_names)))
293+
message = _INVALID_PHYSICS_STATE.format(
294+
warning_names=', '.join(warning_names))
295+
if self._warnings_cause_exception:
296+
raise _control.PhysicsError(message)
297+
else:
298+
logging.warn(message)
275299

276300
def __getstate__(self):
277301
return self.data # All state is assumed to reside within `self.data`.

dm_control/mujoco/engine_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,30 @@ def testNanControl(self):
346346
self._physics.data.ctrl[0] = float('nan')
347347
self._physics.step()
348348

349+
def testSuppressPhysicsError(self):
350+
bad_value = float('nan')
351+
message = engine._INVALID_PHYSICS_STATE.format(
352+
warning_names='mjWARN_BADCTRL')
353+
354+
def assert_physics_error():
355+
self._physics.data.ctrl[0] = bad_value
356+
with self.assertRaisesWithLiteralMatch(control.PhysicsError, message):
357+
self._physics.forward()
358+
359+
def assert_warning():
360+
self._physics.data.ctrl[0] = bad_value
361+
with mock.patch.object(engine.logging, 'warn') as mock_warn:
362+
self._physics.forward()
363+
mock_warn.assert_called_once_with(message)
364+
365+
assert_physics_error()
366+
with self._physics.suppress_physics_errors():
367+
assert_warning()
368+
with self._physics.suppress_physics_errors():
369+
assert_warning()
370+
assert_warning()
371+
assert_physics_error()
372+
349373
@parameterized.named_parameters(
350374
('_copy', lambda x: x.copy()),
351375
('_pickle_and_unpickle', lambda x: cPickle.loads(cPickle.dumps(x))),

0 commit comments

Comments
 (0)