Skip to content

Commit 96db220

Browse files
committed
Fix run_threaded decorator so that it works with named_parameters
PiperOrigin-RevId: 182068715
1 parent 65a17c6 commit 96db220

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

dm_control/mujoco/testing/decorators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import functools
2223
import sys
2324
import threading
2425

@@ -42,6 +43,7 @@ def run_threaded(num_threads=4, calls_per_thread=10):
4243
"""
4344
def decorator(test_method):
4445
"""Decorator around the test method."""
46+
@functools.wraps(test_method) # Needed for `named_parameters` to work.
4547
def decorated_method(self, *args, **kwargs):
4648
"""Actual method this factory will return."""
4749
exceptions = []

dm_control/mujoco/testing/decorators_test.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# Internal dependencies.
2323

2424
from absl.testing import absltest
25+
from absl.testing import parameterized
2526

2627
from dm_control.mujoco.testing import decorators
2728
import mock
@@ -42,7 +43,9 @@ def test_number_of_threads(self, mock_threading):
4243
mock_threading.Thread = mock.MagicMock(side_effect=mock_threads)
4344

4445
test_decorator = decorators.run_threaded(num_threads=num_threads)
45-
test_runner = test_decorator(mock.MagicMock())
46+
tested_method = mock.MagicMock()
47+
tested_method.__name__ = "foo"
48+
test_runner = test_decorator(tested_method)
4649
test_runner(self)
4750

4851
for thread in mock_threads:
@@ -53,13 +56,39 @@ def test_number_of_iterations(self):
5356
calls_per_thread = 5
5457

5558
tested_method = mock.MagicMock()
59+
tested_method.__name__ = "foo"
5660
test_decorator = decorators.run_threaded(
5761
num_threads=1, calls_per_thread=calls_per_thread)
5862
test_runner = test_decorator(tested_method)
5963
test_runner(self)
6064

6165
self.assertEqual(calls_per_thread, tested_method.call_count)
6266

67+
def test_works_with_named_parameters(self):
68+
69+
func = mock.MagicMock()
70+
names = ["foo", "bar", "baz"]
71+
params = [1, 2, 3]
72+
calls_per_thread = 2
73+
num_threads = 4
74+
75+
class FakeTest(parameterized.TestCase):
76+
77+
@parameterized.named_parameters(zip(names, params))
78+
@decorators.run_threaded(calls_per_thread=calls_per_thread,
79+
num_threads=num_threads)
80+
def test_method(self, param):
81+
func(param)
82+
83+
suite = absltest.TestLoader().loadTestsFromTestCase(FakeTest)
84+
suite.debug() # Run tests without collecting the output.
85+
86+
expected_call_count = len(params) * calls_per_thread * num_threads
87+
88+
self.assertEqual(func.call_count, expected_call_count)
89+
actual_params = {call[0][0] for call in func.call_args_list}
90+
self.assertSetEqual(set(params), actual_params)
91+
6392

6493
if __name__ == "__main__":
6594
absltest.main()

0 commit comments

Comments
 (0)