diff --git a/Lib/test/test_warnings/__init__.py b/Lib/test/test_warnings/__init__.py index 4b1b4e193cb165b..ef7e332a323c5fe 100644 --- a/Lib/test/test_warnings/__init__.py +++ b/Lib/test/test_warnings/__init__.py @@ -545,6 +545,29 @@ class NonWarningSubclass: self.module.warn('good warning category', MyWarningClass) self.assertIsInstance(cm.warning, Warning) + def test_simplefilter_invalid_category(self): + class MyWarningClass(Warning): + pass + + class NonWarningSubclass: + pass + + msg_regex = 'category must be a Warning subclass, not (.*)' + + with self.assertRaisesRegex(TypeError, msg_regex): + self.module.simplefilter('always', '') + + with self.assertRaisesRegex(TypeError, msg_regex): + self.module.simplefilter('always', NonWarningSubclass) + + with self.assertRaisesRegex(TypeError, msg_regex): + self.module.simplefilter('always', MyWarningClass()) + + with original_warnings.catch_warnings(module=self.module, record=True) as w: + self.module.simplefilter('always', MyWarningClass) + self.assertEqual(len(w), 0) + + class CWarnTests(WarnTests, unittest.TestCase): module = c_warnings diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 0d550204a7687fa..19664087566dc2e 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -274,7 +274,11 @@ def __enter__(self): v.__warningregistry__ = {} self.warnings_manager = warnings.catch_warnings(record=True) self.warnings = self.warnings_manager.__enter__() - warnings.simplefilter("always", self.expected) + if isinstance(self.expected, tuple): + for expected in self.expected: + warnings.simplefilter("always", expected) + else: + warnings.simplefilter("always", self.expected) return self def __exit__(self, exc_type, exc_value, tb): diff --git a/Lib/warnings.py b/Lib/warnings.py index 691ccddfa450ad2..ad46d5e349439d8 100644 --- a/Lib/warnings.py +++ b/Lib/warnings.py @@ -176,6 +176,9 @@ def simplefilter(action, category=Warning, lineno=0, append=False): "once"), "invalid action: %r" % (action,) assert isinstance(lineno, int) and lineno >= 0, \ "lineno must be an int >= 0" + if not (isinstance(category, type) and issubclass(category, Warning)): + raise TypeError('category must be a Warning subclass, ' + 'not {!r}'.format(category)) _add_filter(action, None, category, None, lineno, append=append) def _add_filter(*item, append): diff --git a/Misc/NEWS.d/next/Library/2022-03-18-20-48-24.bpo-16845.gWkP5A.rst b/Misc/NEWS.d/next/Library/2022-03-18-20-48-24.bpo-16845.gWkP5A.rst new file mode 100644 index 000000000000000..9d8b12044d576ae --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-03-18-20-48-24.bpo-16845.gWkP5A.rst @@ -0,0 +1 @@ +Validate the category of warnings.simplefilter sooner