3131from six .moves import zip
3232
3333
34+ _DOMAINS_AND_TASKS = [
35+ dict (domain = domain , task = task ) for domain , task in suite .ALL_TASKS
36+ ]
37+
38+
3439def uniform_random_policy (action_spec , random = None ):
3540 lower_bounds = action_spec .minimum
3641 upper_bounds = action_spec .maximum
@@ -64,7 +69,7 @@ def make_trajectory(domain, task, seed, **trajectory_kwargs):
6469 return step_environment (env , policy , ** trajectory_kwargs )
6570
6671
67- class DomainTest (parameterized .TestCase ):
72+ class SuiteTest (parameterized .TestCase ):
6873 """Tests run on all the tasks registered."""
6974
7075 def test_constants (self ):
@@ -107,7 +112,7 @@ def _validate_control_range(self, lower_bounds, upper_bounds):
107112 for b in upper_bounds :
108113 self .assertEqual (b , 1.0 )
109114
110- @parameterized .parameters (* suite . ALL_TASKS )
115+ @parameterized .parameters (_DOMAINS_AND_TASKS )
111116 def test_components_have_names (self , domain , task ):
112117 env = suite .load (domain , task )
113118 model = env .physics .model
@@ -138,15 +143,15 @@ def test_components_have_names(self, domain, task):
138143 msg = 'Model {!r} contains unnamed {!r} with ID {}.'
139144 .format (model .name , object_type , idx ))
140145
141- @parameterized .parameters (* suite . ALL_TASKS )
146+ @parameterized .parameters (_DOMAINS_AND_TASKS )
142147 def test_model_has_at_least_2_cameras (self , domain , task ):
143148 env = suite .load (domain , task )
144149 model = env .physics .model
145150 self .assertGreaterEqual (model .ncam , 2 ,
146151 'Model {!r} should have at least 2 cameras, has {}.'
147152 .format (model .name , model .ncam ))
148153
149- @parameterized .parameters (* suite . ALL_TASKS )
154+ @parameterized .parameters (_DOMAINS_AND_TASKS )
150155 def test_task_conforms_to_spec (self , domain , task ):
151156 """Tests that the environment timesteps conform to specifications."""
152157 is_benchmark = (domain , task ) in suite .BENCHMARKING
@@ -167,7 +172,7 @@ def test_task_conforms_to_spec(self, domain, task):
167172 if is_benchmark :
168173 self ._validate_reward_range (time_step )
169174
170- @parameterized .parameters (* suite . ALL_TASKS )
175+ @parameterized .parameters (_DOMAINS_AND_TASKS )
171176 def test_environment_is_deterministic (self , domain , task ):
172177 """Tests that identical seeds and actions produce identical trajectories."""
173178 seed = 0
@@ -227,15 +232,15 @@ def test_visualize_reward(self, domain, task):
227232 mock_get_reward .assert_called_with (env .physics )
228233 self .assertCorrectColors (env .physics , reward = mock_get_reward .return_value )
229234
230- @parameterized .parameters (* suite . ALL_TASKS )
235+ @parameterized .parameters (_DOMAINS_AND_TASKS )
231236 def test_task_supports_environment_kwargs (self , domain , task ):
232237 env = suite .load (domain , task ,
233238 environment_kwargs = dict (flat_observation = True ))
234239 # Check that the kwargs are actually passed through to the environment.
235240 self .assertSetEqual (set (env .observation_spec ()),
236241 {control .FLAT_OBSERVATION_KEY })
237242
238- @parameterized .parameters (* suite . ALL_TASKS )
243+ @parameterized .parameters (_DOMAINS_AND_TASKS )
239244 def test_observation_arrays_dont_share_memory (self , domain , task ):
240245 env = suite .load (domain , task )
241246 first_timestep = env .reset ()
@@ -247,7 +252,7 @@ def test_observation_arrays_dont_share_memory(self, domain, task):
247252 np .may_share_memory (first_array , second_array ),
248253 msg = 'Consecutive observations of {!r} may share memory.' .format (name ))
249254
250- @parameterized .parameters (* suite . ALL_TASKS )
255+ @parameterized .parameters (_DOMAINS_AND_TASKS )
251256 def test_observations_dont_contain_constant_elements (self , domain , task ):
252257 env = suite .load (domain , task )
253258 trajectory = make_trajectory (domain = domain , task = task , seed = 0 ,
@@ -278,7 +283,7 @@ def test_observations_dont_contain_constant_elements(self, domain, task):
278283 .format ('\n ' .join (':\t ' .join ([name , str (is_constant )])
279284 for (name , is_constant ) in failures )))
280285
281- @parameterized .parameters (* suite . ALL_TASKS )
286+ @parameterized .parameters (_DOMAINS_AND_TASKS )
282287 def test_initial_state_is_randomized (self , domain , task ):
283288 env = suite .load (domain , task , task_kwargs = {'random' : 42 })
284289 obs1 = env .reset ().observation
0 commit comments