Skip to content

Commit 4b13923

Browse files
committed
Deprecate composer.Environment.step_spec, add reward_spec and discount_spec
PiperOrigin-RevId: 239807451
1 parent f2cf9ca commit 4b13923

2 files changed

Lines changed: 68 additions & 5 deletions

File tree

dm_control/composer/environment.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,11 @@ def _reset_attempt(self):
331331
discount=None,
332332
observation=self._observation_updater.get_observation())
333333

334+
# TODO(b/129061424): Remove this method.
334335
def step_spec(self):
336+
"""DEPRECATED: please use `reward_spec` and `discount_spec` instead."""
337+
warnings.warn('`step_spec` is deprecated, please use `reward_spec` and '
338+
'`discount_spec` instead.', DeprecationWarning)
335339
if (self._task.get_reward_spec() is None or
336340
self._task.get_discount_spec() is None):
337341
raise NotImplementedError
@@ -400,6 +404,38 @@ def action_spec(self):
400404
"""Returns the action specification for this environment."""
401405
return self._task.action_spec(self._physics_proxy)
402406

407+
def reward_spec(self):
408+
"""Describes the reward returned by this environment.
409+
410+
This will be the output of `self.task.reward_spec()` if it is not None,
411+
otherwise it will be the default spec returned by
412+
`environment.Base.reward_spec()`.
413+
414+
Returns:
415+
An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
416+
"""
417+
task_reward_spec = self._task.get_reward_spec()
418+
if task_reward_spec is not None:
419+
return task_reward_spec
420+
else:
421+
return super(Environment, self).reward_spec()
422+
423+
def discount_spec(self):
424+
"""Describes the discount returned by this environment.
425+
426+
This will be the output of `self.task.discount_spec()` if it is not None,
427+
otherwise it will be the default spec returned by
428+
`environment.Base.discount_spec()`.
429+
430+
Returns:
431+
An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
432+
"""
433+
task_discount_spec = self._task.get_discount_spec()
434+
if task_discount_spec is not None:
435+
return task_discount_spec
436+
else:
437+
return super(Environment, self).discount_spec()
438+
403439
def observation_spec(self):
404440
"""Returns the observation specification for this environment.
405441

dm_control/composer/environment_test.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,26 @@
2121

2222
# Internal dependencies.
2323
from absl.testing import absltest
24+
from absl.testing import parameterized
2425
from dm_control import composer
2526
from dm_control import mjcf
27+
import mock
2628
from six.moves import range
2729

2830

29-
class TaskWithResetFailures(composer.NullTask):
31+
class DummyTask(composer.NullTask):
32+
33+
def __init__(self):
34+
null_entity = composer.ModelWrapperEntity(mjcf.RootElement())
35+
super(DummyTask, self).__init__(null_entity)
36+
37+
38+
class DummyTaskWithResetFailures(DummyTask):
3039

3140
def __init__(self, num_reset_failures):
41+
super(DummyTaskWithResetFailures, self).__init__()
3242
self.num_reset_failures = num_reset_failures
3343
self.reset_counter = 0
34-
null_entity = composer.ModelWrapperEntity(mjcf.RootElement())
35-
super(TaskWithResetFailures, self).__init__(null_entity)
3644

3745
def initialize_episode_mjcf(self, random_state):
3846
self.reset_counter += 1
@@ -42,19 +50,38 @@ def initialize_episode(self, physics, random_state):
4250
raise composer.EpisodeInitializationError()
4351

4452

45-
class EnvironmentTest(absltest.TestCase):
53+
class EnvironmentTest(parameterized.TestCase):
4654

4755
def test_failed_resets(self):
4856
total_reset_failures = 5
4957
env_reset_attempts = 2
50-
task = TaskWithResetFailures(num_reset_failures=total_reset_failures)
58+
task = DummyTaskWithResetFailures(num_reset_failures=total_reset_failures)
5159
env = composer.Environment(task, max_reset_attempts=env_reset_attempts)
5260
for _ in range(total_reset_failures // env_reset_attempts):
5361
with self.assertRaises(composer.EpisodeInitializationError):
5462
env.reset()
5563
env.reset() # should not raise an exception
5664
self.assertEqual(task.reset_counter, total_reset_failures + 1)
5765

66+
@parameterized.parameters(
67+
dict(name='reward_spec', defined_in_task=True),
68+
dict(name='reward_spec', defined_in_task=False),
69+
dict(name='discount_spec', defined_in_task=True),
70+
dict(name='discount_spec', defined_in_task=False))
71+
def test_get_spec(self, name, defined_in_task):
72+
task = DummyTask()
73+
env = composer.Environment(task)
74+
with mock.patch.object(task, 'get_' + name) as mock_task_get_spec:
75+
if defined_in_task:
76+
expected_spec = mock.Mock()
77+
mock_task_get_spec.return_value = expected_spec
78+
else:
79+
expected_spec = getattr(super(composer.Environment, env), name)()
80+
mock_task_get_spec.return_value = None
81+
spec = getattr(env, name)()
82+
mock_task_get_spec.assert_called_once_with()
83+
self.assertSameStructure(spec, expected_spec)
84+
5885

5986
if __name__ == '__main__':
6087
absltest.main()

0 commit comments

Comments
 (0)