|
1 | | -# Copyright 2018 The dm_control Authors. |
| 1 | +# Copyright 2018-2019 The dm_control Authors. |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
23 | 23 | from dm_control.viewer import runtime |
24 | 24 | import mock |
25 | 25 | import numpy as np |
| 26 | +from six.moves import zip |
26 | 27 | from dm_control.rl import environment |
27 | 28 | from dm_control.rl import specs |
28 | 29 |
|
@@ -288,8 +289,7 @@ def test_step_without_policy(self): |
288 | 289 | runtime.__name__ + '._get_default_action') as mock_get_default_action: |
289 | 290 | this_runtime = runtime.Runtime(environment=self.env, policy=None) |
290 | 291 | this_runtime._step() |
291 | | - self.env.step.assert_called_once_with( |
292 | | - mock_get_default_action.return_value.copy()) |
| 292 | + self.env.step.assert_called_once_with(mock_get_default_action.return_value) |
293 | 293 |
|
294 | 294 | def test_stepping_paused(self): |
295 | 295 | with mock.patch(runtime.__name__ + '.mjlib') as mjlib: |
@@ -336,5 +336,48 @@ def raise_exception(*unused_args, **unused_kwargs): |
336 | 336 | self.assertTrue(finished) |
337 | 337 |
|
338 | 338 |
|
| 339 | +class DefaultActionFromSpecTest(parameterized.TestCase): |
| 340 | + |
| 341 | + def assertNestedArraysEqual(self, expected, actual): |
| 342 | + """Asserts that two potentially nested structures of arrays are equal.""" |
| 343 | + if isinstance(expected, (list, tuple)): |
| 344 | + self.assertIsInstance(actual, (list, tuple)) |
| 345 | + self.assertLen(actual, len(expected)) |
| 346 | + for expected_item, actual_item in zip(expected, actual): |
| 347 | + self.assertNestedArraysEqual(expected_item, actual_item) |
| 348 | + else: |
| 349 | + np.testing.assert_array_equal(expected, actual) |
| 350 | + |
| 351 | + _SHAPE = (2,) |
| 352 | + _DTYPE = np.float64 |
| 353 | + _ACTION = np.zeros(_SHAPE) |
| 354 | + _ACTION_SPEC = specs.BoundedArraySpec(_SHAPE, np.float64, -1, 1) |
| 355 | + |
| 356 | + @parameterized.named_parameters( |
| 357 | + ('single_array', _ACTION_SPEC, _ACTION), |
| 358 | + ('tuple', (_ACTION_SPEC, _ACTION_SPEC), (_ACTION, _ACTION)), |
| 359 | + ('list', [_ACTION_SPEC, _ACTION_SPEC], (_ACTION, _ACTION))) |
| 360 | + def test_action_structure(self, action_spec, expected_action): |
| 361 | + self.assertNestedArraysEqual(expected_action, |
| 362 | + runtime._get_default_action(action_spec)) |
| 363 | + |
| 364 | + @parameterized.named_parameters( |
| 365 | + ('closed', |
| 366 | + specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=1., maximum=2.), |
| 367 | + np.full(_SHAPE, fill_value=1.5, dtype=_DTYPE)), |
| 368 | + ('left_open', |
| 369 | + specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=-np.inf, maximum=2.), |
| 370 | + np.full(_SHAPE, fill_value=2., dtype=_DTYPE)), |
| 371 | + ('right_open', |
| 372 | + specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=1., maximum=np.inf), |
| 373 | + np.full(_SHAPE, fill_value=1., dtype=_DTYPE)), |
| 374 | + ('unbounded', |
| 375 | + specs.BoundedArraySpec(_SHAPE, _DTYPE, minimum=-np.inf, maximum=np.inf), |
| 376 | + np.full(_SHAPE, fill_value=0., dtype=_DTYPE))) |
| 377 | + def test_action_spec_interval(self, action_spec, expected_action): |
| 378 | + self.assertNestedArraysEqual(expected_action, |
| 379 | + runtime._get_default_action(action_spec)) |
| 380 | + |
| 381 | + |
339 | 382 | if __name__ == '__main__': |
340 | 383 | absltest.main() |
0 commit comments