2222# Internal dependencies.
2323
2424from absl .testing import absltest
25+ from absl .testing import parameterized
2526
2627from dm_control .mujoco .testing import decorators
2728import mock
@@ -42,7 +43,9 @@ def test_number_of_threads(self, mock_threading):
4243 mock_threading .Thread = mock .MagicMock (side_effect = mock_threads )
4344
4445 test_decorator = decorators .run_threaded (num_threads = num_threads )
45- test_runner = test_decorator (mock .MagicMock ())
46+ tested_method = mock .MagicMock ()
47+ tested_method .__name__ = "foo"
48+ test_runner = test_decorator (tested_method )
4649 test_runner (self )
4750
4851 for thread in mock_threads :
@@ -53,13 +56,39 @@ def test_number_of_iterations(self):
5356 calls_per_thread = 5
5457
5558 tested_method = mock .MagicMock ()
59+ tested_method .__name__ = "foo"
5660 test_decorator = decorators .run_threaded (
5761 num_threads = 1 , calls_per_thread = calls_per_thread )
5862 test_runner = test_decorator (tested_method )
5963 test_runner (self )
6064
6165 self .assertEqual (calls_per_thread , tested_method .call_count )
6266
67+ def test_works_with_named_parameters (self ):
68+
69+ func = mock .MagicMock ()
70+ names = ["foo" , "bar" , "baz" ]
71+ params = [1 , 2 , 3 ]
72+ calls_per_thread = 2
73+ num_threads = 4
74+
75+ class FakeTest (parameterized .TestCase ):
76+
77+ @parameterized .named_parameters (zip (names , params ))
78+ @decorators .run_threaded (calls_per_thread = calls_per_thread ,
79+ num_threads = num_threads )
80+ def test_method (self , param ):
81+ func (param )
82+
83+ suite = absltest .TestLoader ().loadTestsFromTestCase (FakeTest )
84+ suite .debug () # Run tests without collecting the output.
85+
86+ expected_call_count = len (params ) * calls_per_thread * num_threads
87+
88+ self .assertEqual (func .call_count , expected_call_count )
89+ actual_params = {call [0 ][0 ] for call in func .call_args_list }
90+ self .assertSetEqual (set (params ), actual_params )
91+
6392
6493if __name__ == "__main__" :
6594 absltest .main ()
0 commit comments