2121
2222# Internal dependencies.
2323from absl .testing import absltest
24+ from absl .testing import parameterized
2425from dm_control import composer
2526from dm_control import mjcf
27+ import mock
2628from six .moves import range
2729
2830
29- class TaskWithResetFailures (composer .NullTask ):
31+ class DummyTask (composer .NullTask ):
32+
33+ def __init__ (self ):
34+ null_entity = composer .ModelWrapperEntity (mjcf .RootElement ())
35+ super (DummyTask , self ).__init__ (null_entity )
36+
37+
38+ class DummyTaskWithResetFailures (DummyTask ):
3039
3140 def __init__ (self , num_reset_failures ):
41+ super (DummyTaskWithResetFailures , self ).__init__ ()
3242 self .num_reset_failures = num_reset_failures
3343 self .reset_counter = 0
34- null_entity = composer .ModelWrapperEntity (mjcf .RootElement ())
35- super (TaskWithResetFailures , self ).__init__ (null_entity )
3644
3745 def initialize_episode_mjcf (self , random_state ):
3846 self .reset_counter += 1
@@ -42,19 +50,38 @@ def initialize_episode(self, physics, random_state):
4250 raise composer .EpisodeInitializationError ()
4351
4452
45- class EnvironmentTest (absltest .TestCase ):
53+ class EnvironmentTest (parameterized .TestCase ):
4654
4755 def test_failed_resets (self ):
4856 total_reset_failures = 5
4957 env_reset_attempts = 2
50- task = TaskWithResetFailures (num_reset_failures = total_reset_failures )
58+ task = DummyTaskWithResetFailures (num_reset_failures = total_reset_failures )
5159 env = composer .Environment (task , max_reset_attempts = env_reset_attempts )
5260 for _ in range (total_reset_failures // env_reset_attempts ):
5361 with self .assertRaises (composer .EpisodeInitializationError ):
5462 env .reset ()
5563 env .reset () # should not raise an exception
5664 self .assertEqual (task .reset_counter , total_reset_failures + 1 )
5765
66+ @parameterized .parameters (
67+ dict (name = 'reward_spec' , defined_in_task = True ),
68+ dict (name = 'reward_spec' , defined_in_task = False ),
69+ dict (name = 'discount_spec' , defined_in_task = True ),
70+ dict (name = 'discount_spec' , defined_in_task = False ))
71+ def test_get_spec (self , name , defined_in_task ):
72+ task = DummyTask ()
73+ env = composer .Environment (task )
74+ with mock .patch .object (task , 'get_' + name ) as mock_task_get_spec :
75+ if defined_in_task :
76+ expected_spec = mock .Mock ()
77+ mock_task_get_spec .return_value = expected_spec
78+ else :
79+ expected_spec = getattr (super (composer .Environment , env ), name )()
80+ mock_task_get_spec .return_value = None
81+ spec = getattr (env , name )()
82+ mock_task_get_spec .assert_called_once_with ()
83+ self .assertSameStructure (spec , expected_spec )
84+
5885
5986if __name__ == '__main__' :
6087 absltest .main ()
0 commit comments