@@ -250,6 +250,99 @@ def path_to_test_parts(path: str) -> list[str]:
250250 return parts [- 2 :]
251251
252252
253+ def _expand_stripped_to_children (
254+ contents : str ,
255+ stripped_tests : set [tuple [str , str ]],
256+ all_failing_tests : set [tuple [str , str ]],
257+ ) -> set [tuple [str , str ]]:
258+ """Find child-class failures that correspond to stripped parent-class markers.
259+
260+ When ``strip_reasonless_expected_failures`` removes a marker from a parent
261+ (mixin) class, test failures are reported against the concrete subclasses,
262+ not the parent itself. This function maps those child failures back so
263+ they get re-marked (and later consolidated to the parent by
264+ ``_consolidate_to_parent``).
265+
266+ Returns the set of ``(class, method)`` pairs from *all_failing_tests* that
267+ should be re-marked.
268+ """
269+ # Direct matches (stripped test itself is a concrete TestCase)
270+ result = stripped_tests & all_failing_tests
271+
272+ unmatched = stripped_tests - all_failing_tests
273+ if not unmatched :
274+ return result
275+
276+ tree = ast .parse (contents )
277+ class_bases , class_methods = _build_inheritance_info (tree )
278+
279+ for parent_cls , method_name in unmatched :
280+ if method_name not in class_methods .get (parent_cls , set ()):
281+ continue
282+ for cls in _find_all_inheritors (
283+ parent_cls , method_name , class_bases , class_methods
284+ ):
285+ if (cls , method_name ) in all_failing_tests :
286+ result .add ((cls , method_name ))
287+
288+ return result
289+
290+
291+ def _consolidate_to_parent (
292+ contents : str ,
293+ failing_tests : set [tuple [str , str ]],
294+ error_messages : dict [tuple [str , str ], str ] | None = None ,
295+ ) -> tuple [set [tuple [str , str ]], dict [tuple [str , str ], str ] | None ]:
296+ """Move failures to the parent class when ALL inheritors fail.
297+
298+ If every concrete subclass that inherits a method from a parent class
299+ appears in *failing_tests*, replace those per-subclass entries with a
300+ single entry on the parent. This avoids creating redundant super-call
301+ overrides in every child.
302+
303+ Returns:
304+ (consolidated_failing_tests, consolidated_error_messages)
305+ """
306+ tree = ast .parse (contents )
307+ class_bases , class_methods = _build_inheritance_info (tree )
308+
309+ # Group by (defining_parent, method) → set of failing children
310+ from collections import defaultdict
311+
312+ groups : dict [tuple [str , str ], set [str ]] = defaultdict (set )
313+ for class_name , method_name in failing_tests :
314+ defining = _find_method_definition (
315+ class_name , method_name , class_bases , class_methods
316+ )
317+ if defining and defining != class_name :
318+ groups [(defining , method_name )].add (class_name )
319+
320+ if not groups :
321+ return failing_tests , error_messages
322+
323+ result = set (failing_tests )
324+ new_error_messages = dict (error_messages ) if error_messages else {}
325+
326+ for (parent , method_name ), failing_children in groups .items ():
327+ all_inheritors = _find_all_inheritors (
328+ parent , method_name , class_bases , class_methods
329+ )
330+
331+ if all_inheritors and failing_children >= all_inheritors :
332+ # All inheritors fail → mark on parent instead
333+ children_keys = {(child , method_name ) for child in failing_children }
334+ result -= children_keys
335+ result .add ((parent , method_name ))
336+ # Pick any child's error message for the parent
337+ if new_error_messages :
338+ for child in failing_children :
339+ msg = new_error_messages .pop ((child , method_name ), "" )
340+ if msg :
341+ new_error_messages [(parent , method_name )] = msg
342+
343+ return result , new_error_messages or error_messages
344+
345+
253346def build_patches (
254347 test_parts_set : set [tuple [str , str ]],
255348 error_messages : dict [tuple [str , str ], str ] | None = None ,
@@ -293,6 +386,24 @@ def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bo
293386 return True
294387
295388
389+ def _method_removal_range (
390+ func_node : ast .FunctionDef | ast .AsyncFunctionDef , lines : list [str ]
391+ ) -> range :
392+ """Line range covering an entire method including decorators and a preceding COMMENT line."""
393+ first = (
394+ func_node .decorator_list [0 ].lineno - 1
395+ if func_node .decorator_list
396+ else func_node .lineno - 1
397+ )
398+ if (
399+ first > 0
400+ and lines [first - 1 ].strip ().startswith ("#" )
401+ and COMMENT in lines [first - 1 ]
402+ ):
403+ first -= 1
404+ return range (first , func_node .end_lineno )
405+
406+
296407def _build_inheritance_info (tree : ast .Module ) -> tuple [dict , dict ]:
297408 """
298409 Build inheritance information from AST.
@@ -348,6 +459,20 @@ def _find_method_definition(
348459 return None
349460
350461
462+ def _find_all_inheritors (
463+ parent : str , method_name : str , class_bases : dict , class_methods : dict
464+ ) -> set [str ]:
465+ """Find all classes that inherit *method_name* from *parent* (not overriding it)."""
466+ return {
467+ cls
468+ for cls in class_bases
469+ if cls != parent
470+ and method_name not in class_methods .get (cls , set ())
471+ and _find_method_definition (cls , method_name , class_bases , class_methods )
472+ == parent
473+ }
474+
475+
351476def remove_expected_failures (
352477 contents : str , tests_to_remove : set [tuple [str , str ]]
353478) -> str :
@@ -383,15 +508,7 @@ def remove_expected_failures(
383508 remove_entire_method = _is_super_call_only (item )
384509
385510 if remove_entire_method :
386- first_line = item .lineno - 1
387- if item .decorator_list :
388- first_line = item .decorator_list [0 ].lineno - 1
389- if first_line > 0 :
390- prev_line = lines [first_line - 1 ].strip ()
391- if prev_line .startswith ("#" ) and COMMENT in prev_line :
392- first_line -= 1
393- for i in range (first_line , item .end_lineno ):
394- lines_to_remove .add (i )
511+ lines_to_remove .update (_method_removal_range (item , lines ))
395512 else :
396513 for dec in item .decorator_list :
397514 dec_line = dec .lineno - 1
@@ -406,11 +523,18 @@ def remove_expected_failures(
406523 and lines [dec_line - 1 ].strip ().startswith ("#" )
407524 and COMMENT in lines [dec_line - 1 ]
408525 )
526+ has_comment_after = (
527+ dec_line + 1 < len (lines )
528+ and lines [dec_line + 1 ].strip ().startswith ("#" )
529+ and COMMENT not in lines [dec_line + 1 ]
530+ )
409531
410532 if has_comment_on_line or has_comment_before :
411533 lines_to_remove .add (dec_line )
412534 if has_comment_before :
413535 lines_to_remove .add (dec_line - 1 )
536+ if has_comment_after and has_comment_on_line :
537+ lines_to_remove .add (dec_line + 1 )
414538
415539 for line_idx in sorted (lines_to_remove , reverse = True ):
416540 del lines [line_idx ]
@@ -481,12 +605,98 @@ def apply_test_changes(
481605 contents = remove_expected_failures (contents , unexpected_successes )
482606
483607 if failing_tests :
608+ failing_tests , error_messages = _consolidate_to_parent (
609+ contents , failing_tests , error_messages
610+ )
484611 patches = build_patches (failing_tests , error_messages )
485612 contents = apply_patches (contents , patches )
486613
487614 return contents
488615
489616
617+ def strip_reasonless_expected_failures (
618+ contents : str ,
619+ ) -> tuple [str , set [tuple [str , str ]]]:
620+ """Strip @expectedFailure decorators that have no failure reason.
621+
622+ Markers like ``@unittest.expectedFailure # TODO: RUSTPYTHON`` (without a
623+ reason after the semicolon) are removed so the tests fail normally during
624+ the next test run and error messages can be captured.
625+
626+ Returns:
627+ (modified_contents, stripped_tests) where stripped_tests is a set of
628+ (class_name, method_name) tuples whose markers were removed.
629+ """
630+ tree = ast .parse (contents )
631+ lines = contents .splitlines ()
632+ stripped_tests : set [tuple [str , str ]] = set ()
633+ lines_to_remove : set [int ] = set ()
634+
635+ for node in ast .walk (tree ):
636+ if not isinstance (node , ast .ClassDef ):
637+ continue
638+ for item in node .body :
639+ if not isinstance (item , (ast .FunctionDef , ast .AsyncFunctionDef )):
640+ continue
641+ for dec in item .decorator_list :
642+ dec_line = dec .lineno - 1
643+ line_content = lines [dec_line ]
644+
645+ if "expectedFailure" not in line_content :
646+ continue
647+
648+ has_comment_on_line = COMMENT in line_content
649+ has_comment_before = (
650+ dec_line > 0
651+ and lines [dec_line - 1 ].strip ().startswith ("#" )
652+ and COMMENT in lines [dec_line - 1 ]
653+ )
654+
655+ if not has_comment_on_line and not has_comment_before :
656+ continue # not our marker
657+
658+ # Check if there's a reason (on either the decorator or before)
659+ for check_line in (
660+ line_content ,
661+ lines [dec_line - 1 ] if has_comment_before else "" ,
662+ ):
663+ match = re .search (rf"{ COMMENT } (.*)" , check_line )
664+ if match and match .group (1 ).strip (";:, " ):
665+ break # has a reason, keep it
666+ else :
667+ # No reason found — strip this decorator
668+ stripped_tests .add ((node .name , item .name ))
669+
670+ if _is_super_call_only (item ):
671+ # Remove entire super-call override (the method
672+ # exists only to apply the decorator; without it
673+ # the override is pointless and blocks parent
674+ # consolidation)
675+ lines_to_remove .update (_method_removal_range (item , lines ))
676+ else :
677+ lines_to_remove .add (dec_line )
678+
679+ if has_comment_before :
680+ lines_to_remove .add (dec_line - 1 )
681+
682+ # Also remove a reason-comment on the line after (old format)
683+ if (
684+ has_comment_on_line
685+ and dec_line + 1 < len (lines )
686+ and lines [dec_line + 1 ].strip ().startswith ("#" )
687+ and COMMENT not in lines [dec_line + 1 ]
688+ ):
689+ lines_to_remove .add (dec_line + 1 )
690+
691+ if not lines_to_remove :
692+ return contents , stripped_tests
693+
694+ for idx in sorted (lines_to_remove , reverse = True ):
695+ del lines [idx ]
696+
697+ return "\n " .join (lines ) + "\n " if lines else "" , stripped_tests
698+
699+
490700def extract_test_methods (contents : str ) -> set [tuple [str , str ]]:
491701 """
492702 Extract all test method names from file contents.
@@ -529,6 +739,13 @@ def auto_mark_file(
529739 if not test_path .exists ():
530740 raise FileNotFoundError (f"File not found: { test_path } " )
531741
742+ # Strip reason-less markers so those tests fail normally and we capture
743+ # their error messages during the test run.
744+ contents = test_path .read_text (encoding = "utf-8" )
745+ contents , stripped_tests = strip_reasonless_expected_failures (contents )
746+ if stripped_tests :
747+ test_path .write_text (contents , encoding = "utf-8" )
748+
532749 test_name = get_test_module_name (test_path )
533750 if verbose :
534751 print (f"Running test: { test_name } " )
@@ -559,6 +776,13 @@ def auto_mark_file(
559776 else :
560777 failing_tests = set ()
561778
779+ # Re-mark stripped tests that still fail (to restore markers with reasons).
780+ # Uses inheritance expansion: if a parent marker was stripped, child
781+ # failures are included so _consolidate_to_parent can re-mark the parent.
782+ failing_tests |= _expand_stripped_to_children (
783+ contents , stripped_tests , all_failing_tests
784+ )
785+
562786 regressions = all_failing_tests - failing_tests
563787
564788 if verbose :
@@ -626,6 +850,19 @@ def auto_mark_directory(
626850 if not test_dir .is_dir ():
627851 raise ValueError (f"Not a directory: { test_dir } " )
628852
853+ # Get all .py files in directory
854+ test_files = sorted (test_dir .glob ("**/*.py" ))
855+
856+ # Strip reason-less markers from ALL files before running tests so those
857+ # tests fail normally and we capture their error messages.
858+ stripped_per_file : dict [pathlib .Path , set [tuple [str , str ]]] = {}
859+ for test_file in test_files :
860+ contents = test_file .read_text (encoding = "utf-8" )
861+ contents , stripped = strip_reasonless_expected_failures (contents )
862+ if stripped :
863+ test_file .write_text (contents , encoding = "utf-8" )
864+ stripped_per_file [test_file ] = stripped
865+
629866 test_name = get_test_module_name (test_dir )
630867 if verbose :
631868 print (f"Running test: { test_name } " )
@@ -644,9 +881,6 @@ def auto_mark_directory(
644881 total_regressions = 0
645882 all_regressions : list [tuple [str , str , str , str ]] = []
646883
647- # Get all .py files in directory
648- test_files = sorted (test_dir .glob ("**/*.py" ))
649-
650884 for test_file in test_files :
651885 # Get module prefix for this file (e.g., "test_inspect.test_inspect")
652886 module_prefix = get_test_module_name (test_file )
@@ -671,6 +905,15 @@ def auto_mark_directory(
671905 else :
672906 failing_tests = set ()
673907
908+ # Re-mark stripped tests that still fail (restore markers with reasons).
909+ # Uses inheritance expansion for parent→child mapping.
910+ stripped = stripped_per_file .get (test_file , set ())
911+ if stripped :
912+ file_contents = test_file .read_text (encoding = "utf-8" )
913+ failing_tests |= _expand_stripped_to_children (
914+ file_contents , stripped , all_failing_tests
915+ )
916+
674917 regressions = all_failing_tests - failing_tests
675918
676919 if failing_tests or unexpected_successes :
0 commit comments