Skip to content

Commit ea433b4

Browse files
DeepMindalimuldal
authored andcommitted
Add support for multi-array action_specs in viewer.
PiperOrigin-RevId: 230801344
1 parent 314ce15 commit ea433b4

2 files changed

Lines changed: 58 additions & 8 deletions

File tree

dm_control/viewer/runtime.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2018 The dm_control Authors.
1+
# Copyright 2018-2019 The dm_control Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -42,11 +42,16 @@ def _get_default_action(action_spec):
4242
* For unbounded intervals this will be zero.
4343
4444
Args:
45-
action_spec: An instance of `BoundedArraySpec`.
45+
action_spec: An instance of `BoundedArraySpec` or a list or tuple
46+
containing these.
4647
4748
Returns:
48-
A numpy array of actions.
49+
A numpy array of actions if `action_spec` is a single `BoundedArraySpec`, or
50+
a tuple of such arrays if `action_spec` is a list or tuple.
4951
"""
52+
if isinstance(action_spec, (list, tuple)):
53+
return tuple(_get_default_action(spec) for spec in action_spec)
54+
5055
minimum = np.broadcast_to(action_spec.minimum, action_spec.shape)
5156
maximum = np.broadcast_to(action_spec.maximum, action_spec.shape)
5257
left_bounded = np.isfinite(minimum)
@@ -55,7 +60,9 @@ def _get_default_action(action_spec):
5560
condlist=[left_bounded & right_bounded, left_bounded, right_bounded],
5661
choicelist=[0.5 * (minimum + maximum), minimum, maximum],
5762
default=0.)
58-
return action.astype(action_spec.dtype, copy=False)
63+
action = action.astype(action_spec.dtype, copy=False)
64+
action.flags.writeable = False
65+
return action
5966

6067

6168
class State(enum.Enum):
@@ -233,7 +240,7 @@ def _step(self):
233240
if self._policy:
234241
action = self._policy(self._time_step)
235242
else:
236-
action = self._default_action.copy()
243+
action = self._default_action
237244
self._time_step = self._env.step(action)
238245
self._last_action = action
239246
finished = self._time_step.last()

dm_control/viewer/runtime_test.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2018 The dm_control Authors.
1+
# Copyright 2018-2019 The dm_control Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
2323
from dm_control.viewer import runtime
2424
import mock
2525
import numpy as np
26+
from six.moves import zip
2627
from dm_control.rl import environment
2728
from dm_control.rl import specs
2829

@@ -288,8 +289,7 @@ def test_step_without_policy(self):
288289
runtime.__name__ + '._get_default_action') as mock_get_default_action:
289290
this_runtime = runtime.Runtime(environment=self.env, policy=None)
290291
this_runtime._step()
291-
self.env.step.assert_called_once_with(
292-
mock_get_default_action.return_value.copy())
292+
self.env.step.assert_called_once_with(mock_get_default_action.return_value)
293293

294294
def test_stepping_paused(self):
295295
with mock.patch(runtime.__name__ + '.mjlib') as mjlib:
@@ -336,5 +336,48 @@ def raise_exception(*unused_args, **unused_kwargs):
336336
self.assertTrue(finished)
337337

338338

339+
class DefaultActionFromSpecTest(parameterized.TestCase):
340+
341+
def assertNestedArraysEqual(self, expected, actual):
342+
"""Asserts that two potentially nested structures of arrays are equal."""
343+
if isinstance(expected, (list, tuple)):
344+
self.assertIsInstance(actual, (list, tuple))
345+
self.assertLen(actual, len(expected))
346+
for expected_item, actual_item in zip(expected, actual):
347+
self.assertNestedArraysEqual(expected_item, actual_item)
348+
else:
349+
np.testing.assert_array_equal(expected, actual)
350+
351+
_SHAPE = (2,)
352+
_DTYPE = np.float64
353+
_ACTION = np.zeros(_SHAPE)
354+
_ACTION_SPEC = specs.BoundedArraySpec(_SHAPE, np.float64, -1, 1)
355+
356+
@parameterized.named_parameters(
357+
('single_array', _ACTION_SPEC, _ACTION),
358+
('tuple', (_ACTION_SPEC, _ACTION_SPEC), (_ACTION, _ACTION)),
359+
('list', [_ACTION_SPEC, _ACTION_SPEC], (_ACTION, _ACTION)))
360+
def test_action_structure(self, action_spec, expected_action):
361+
self.assertNestedArraysEqual(expected_action,
362+
runtime._get_default_action(action_spec))
363+
364+
@parameterized.named_parameters(
365+
('closed',
366+
specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=1., maximum=2.),
367+
np.full(_SHAPE, fill_value=1.5, dtype=_DTYPE)),
368+
('left_open',
369+
specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=-np.inf, maximum=2.),
370+
np.full(_SHAPE, fill_value=2., dtype=_DTYPE)),
371+
('right_open',
372+
specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=1., maximum=np.inf),
373+
np.full(_SHAPE, fill_value=1., dtype=_DTYPE)),
374+
('unbounded',
375+
specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=-np.inf, maximum=np.inf),
376+
np.full(_SHAPE, fill_value=0., dtype=_DTYPE)))
377+
def test_action_spec_interval(self, action_spec, expected_action):
378+
self.assertNestedArraysEqual(expected_action,
379+
runtime._get_default_action(action_spec))
380+
381+
339382
if __name__ == '__main__':
340383
absltest.main()

0 commit comments

Comments
 (0)