Skip to content

Commit 153b0da

Browse files
authored
automark add reasons to reason-less expectedFailures and handles parent better (#6936)
1 parent dd09f41 commit 153b0da

File tree

3 files changed

+692
-510
lines changed

3 files changed

+692
-510
lines changed

crates/codegen/src/ir.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1219,7 +1219,10 @@ pub(crate) fn label_exception_targets(blocks: &mut [Block]) {
12191219
preserve_lasti,
12201220
});
12211221
} else if is_pop {
1222-
debug_assert!(!stack.is_empty(), "POP_BLOCK with empty except stack at block {bi} instruction {i}");
1222+
debug_assert!(
1223+
!stack.is_empty(),
1224+
"POP_BLOCK with empty except stack at block {bi} instruction {i}"
1225+
);
12231226
stack.pop();
12241227
// POP_BLOCK → NOP
12251228
blocks[bi].instructions[i].instr = Instruction::Nop.into();

scripts/update_lib/cmd_auto_mark.py

Lines changed: 255 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,99 @@ def path_to_test_parts(path: str) -> list[str]:
250250
return parts[-2:]
251251

252252

253+
def _expand_stripped_to_children(
    contents: str,
    stripped_tests: set[tuple[str, str]],
    all_failing_tests: set[tuple[str, str]],
) -> set[tuple[str, str]]:
    """Select the failing tests that correspond to stripped markers.

    A stripped ``(class, method)`` pair that shows up in *all_failing_tests*
    is taken as-is.  A stripped pair that does not show up there is looked up
    in the class hierarchy parsed from *contents*: when the class defines the
    method itself, every class that inherits that method from it and is
    listed in *all_failing_tests* is selected instead.

    Args:
        contents: Source text of the test file; parsed with ``ast`` only when
            at least one stripped test has no direct match.
        stripped_tests: ``(class, method)`` pairs whose markers were removed.
        all_failing_tests: ``(class, method)`` pairs that failed.

    Returns:
        A new set holding the ``(class, method)`` pairs from
        *all_failing_tests* that should be re-marked.
    """
    # Stripped tests that fail under their own class name.
    remarked = stripped_tests & all_failing_tests

    pending = stripped_tests - all_failing_tests
    if pending:
        class_bases, class_methods = _build_inheritance_info(ast.parse(contents))
        for owner, method in pending:
            # Only expand when the stripped class defines the method itself.
            if method in class_methods.get(owner, set()):
                remarked.update(
                    (heir, method)
                    for heir in _find_all_inheritors(
                        owner, method, class_bases, class_methods
                    )
                    if (heir, method) in all_failing_tests
                )

    return remarked
289+
290+
291+
def _consolidate_to_parent(
    contents: str,
    failing_tests: set[tuple[str, str]],
    error_messages: dict[tuple[str, str], str] | None = None,
) -> tuple[set[tuple[str, str]], dict[tuple[str, str], str] | None]:
    """Move failures to the parent class when ALL inheritors fail.

    For each failing ``(class, method)`` whose method is defined in another
    class, the failure is grouped under that defining class.  When every
    class that inherits the method from the defining class is present in the
    group, the per-child entries are replaced by a single entry on the
    defining class.

    Args:
        contents: Source text of the test file (parsed with ``ast``).
        failing_tests: ``(class, method)`` pairs that failed.  Not mutated.
        error_messages: Optional mapping from ``(class, method)`` to an error
            message.  Not mutated; a copy is edited instead.

    Returns:
        ``(consolidated_failing_tests, consolidated_error_messages)``.  When
        nothing can be grouped the inputs are returned unchanged.  When the
        edited message mapping ends up empty, the original *error_messages*
        object is returned in its place.
    """
    from collections import defaultdict

    tree = ast.parse(contents)
    class_bases, class_methods = _build_inheritance_info(tree)

    # Group by (defining_parent, method) -> set of failing children.
    groups: dict[tuple[str, str], set[str]] = defaultdict(set)
    for class_name, method_name in failing_tests:
        defining = _find_method_definition(
            class_name, method_name, class_bases, class_methods
        )
        if defining and defining != class_name:
            groups[(defining, method_name)].add(class_name)

    if not groups:
        return failing_tests, error_messages

    result = set(failing_tests)
    new_error_messages = dict(error_messages) if error_messages else {}

    for (parent, method_name), failing_children in groups.items():
        all_inheritors = _find_all_inheritors(
            parent, method_name, class_bases, class_methods
        )
        if not all_inheritors or not failing_children >= all_inheritors:
            continue

        # All inheritors fail -> mark on the parent instead of each child.
        result -= {(child, method_name) for child in failing_children}
        result.add((parent, method_name))

        if not new_error_messages:
            continue

        # Drop every child's message and hand one of them to the parent.
        # Children are visited in sorted order and the first non-empty
        # message wins, so the chosen reason does not depend on set
        # iteration order (which varies with string-hash randomisation).
        child_messages = [
            new_error_messages.pop((child, method_name), "")
            for child in sorted(failing_children)
        ]
        chosen = next((msg for msg in child_messages if msg), "")
        if chosen:
            new_error_messages[(parent, method_name)] = chosen

    return result, new_error_messages or error_messages
344+
345+
253346
def build_patches(
254347
test_parts_set: set[tuple[str, str]],
255348
error_messages: dict[tuple[str, str], str] | None = None,
@@ -293,6 +386,24 @@ def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bo
293386
return True
294387

295388

389+
def _method_removal_range(
    func_node: ast.FunctionDef | ast.AsyncFunctionDef, lines: list[str]
) -> range:
    """Return the 0-based line indices that make up a whole method.

    The range starts at the first decorator (or at the ``def`` line when the
    method has no decorators) and stops after the method's last line.  When
    the line directly above that start is a ``#`` comment containing
    ``COMMENT``, the range is extended upwards by one line to include it.

    Args:
        func_node: The function node whose source span is wanted.
        lines: The file's lines, indexed from 0.

    Returns:
        A ``range`` of indices into *lines*.
    """
    decorators = func_node.decorator_list
    # AST line numbers are 1-based; ``lines`` is 0-based.
    start = (decorators[0].lineno if decorators else func_node.lineno) - 1

    if start > 0:
        above = lines[start - 1]
        if above.strip().startswith("#") and COMMENT in above:
            start -= 1

    # ``end_lineno`` is 1-based and inclusive, so it doubles as the exclusive
    # 0-based stop value.
    return range(start, func_node.end_lineno)
405+
406+
296407
def _build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]:
297408
"""
298409
Build inheritance information from AST.
@@ -348,6 +459,20 @@ def _find_method_definition(
348459
return None
349460

350461

462+
def _find_all_inheritors(
    parent: str, method_name: str, class_bases: dict, class_methods: dict
) -> set[str]:
    """Collect the classes that take *method_name* from *parent* unchanged.

    A class qualifies when it is a key of *class_bases*, is not *parent*
    itself, does not list *method_name* among its own methods, and
    ``_find_method_definition`` resolves the method to *parent*.

    Args:
        parent: Name of the class expected to define the method.
        method_name: Name of the method being traced.
        class_bases: Mapping keyed by class name, passed through to
            ``_find_method_definition``.
        class_methods: Mapping from class name to the method names that the
            class defines itself.

    Returns:
        The set of qualifying class names (possibly empty).
    """
    inheritors: set[str] = set()
    for candidate in class_bases:
        if candidate == parent:
            continue
        # A class that defines the method itself overrides it.
        if method_name in class_methods.get(candidate, set()):
            continue
        resolved = _find_method_definition(
            candidate, method_name, class_bases, class_methods
        )
        if resolved == parent:
            inheritors.add(candidate)
    return inheritors
474+
475+
351476
def remove_expected_failures(
352477
contents: str, tests_to_remove: set[tuple[str, str]]
353478
) -> str:
@@ -383,15 +508,7 @@ def remove_expected_failures(
383508
remove_entire_method = _is_super_call_only(item)
384509

385510
if remove_entire_method:
386-
first_line = item.lineno - 1
387-
if item.decorator_list:
388-
first_line = item.decorator_list[0].lineno - 1
389-
if first_line > 0:
390-
prev_line = lines[first_line - 1].strip()
391-
if prev_line.startswith("#") and COMMENT in prev_line:
392-
first_line -= 1
393-
for i in range(first_line, item.end_lineno):
394-
lines_to_remove.add(i)
511+
lines_to_remove.update(_method_removal_range(item, lines))
395512
else:
396513
for dec in item.decorator_list:
397514
dec_line = dec.lineno - 1
@@ -406,11 +523,18 @@ def remove_expected_failures(
406523
and lines[dec_line - 1].strip().startswith("#")
407524
and COMMENT in lines[dec_line - 1]
408525
)
526+
has_comment_after = (
527+
dec_line + 1 < len(lines)
528+
and lines[dec_line + 1].strip().startswith("#")
529+
and COMMENT not in lines[dec_line + 1]
530+
)
409531

410532
if has_comment_on_line or has_comment_before:
411533
lines_to_remove.add(dec_line)
412534
if has_comment_before:
413535
lines_to_remove.add(dec_line - 1)
536+
if has_comment_after and has_comment_on_line:
537+
lines_to_remove.add(dec_line + 1)
414538

415539
for line_idx in sorted(lines_to_remove, reverse=True):
416540
del lines[line_idx]
@@ -481,12 +605,98 @@ def apply_test_changes(
481605
contents = remove_expected_failures(contents, unexpected_successes)
482606

483607
if failing_tests:
608+
failing_tests, error_messages = _consolidate_to_parent(
609+
contents, failing_tests, error_messages
610+
)
484611
patches = build_patches(failing_tests, error_messages)
485612
contents = apply_patches(contents, patches)
486613

487614
return contents
488615

489616

617+
def strip_reasonless_expected_failures(
618+
contents: str,
619+
) -> tuple[str, set[tuple[str, str]]]:
620+
"""Strip @expectedFailure decorators that have no failure reason.
621+
622+
Markers like ``@unittest.expectedFailure # TODO: RUSTPYTHON`` (without a
623+
reason after the semicolon) are removed so the tests fail normally during
624+
the next test run and error messages can be captured.
625+
626+
Returns:
627+
(modified_contents, stripped_tests) where stripped_tests is a set of
628+
(class_name, method_name) tuples whose markers were removed.
629+
"""
630+
tree = ast.parse(contents)
631+
lines = contents.splitlines()
632+
stripped_tests: set[tuple[str, str]] = set()
633+
lines_to_remove: set[int] = set()
634+
635+
for node in ast.walk(tree):
636+
if not isinstance(node, ast.ClassDef):
637+
continue
638+
for item in node.body:
639+
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
640+
continue
641+
for dec in item.decorator_list:
642+
dec_line = dec.lineno - 1
643+
line_content = lines[dec_line]
644+
645+
if "expectedFailure" not in line_content:
646+
continue
647+
648+
has_comment_on_line = COMMENT in line_content
649+
has_comment_before = (
650+
dec_line > 0
651+
and lines[dec_line - 1].strip().startswith("#")
652+
and COMMENT in lines[dec_line - 1]
653+
)
654+
655+
if not has_comment_on_line and not has_comment_before:
656+
continue # not our marker
657+
658+
# Check if there's a reason (on either the decorator or before)
659+
for check_line in (
660+
line_content,
661+
lines[dec_line - 1] if has_comment_before else "",
662+
):
663+
match = re.search(rf"{COMMENT}(.*)", check_line)
664+
if match and match.group(1).strip(";:, "):
665+
break # has a reason, keep it
666+
else:
667+
# No reason found — strip this decorator
668+
stripped_tests.add((node.name, item.name))
669+
670+
if _is_super_call_only(item):
671+
# Remove entire super-call override (the method
672+
# exists only to apply the decorator; without it
673+
# the override is pointless and blocks parent
674+
# consolidation)
675+
lines_to_remove.update(_method_removal_range(item, lines))
676+
else:
677+
lines_to_remove.add(dec_line)
678+
679+
if has_comment_before:
680+
lines_to_remove.add(dec_line - 1)
681+
682+
# Also remove a reason-comment on the line after (old format)
683+
if (
684+
has_comment_on_line
685+
and dec_line + 1 < len(lines)
686+
and lines[dec_line + 1].strip().startswith("#")
687+
and COMMENT not in lines[dec_line + 1]
688+
):
689+
lines_to_remove.add(dec_line + 1)
690+
691+
if not lines_to_remove:
692+
return contents, stripped_tests
693+
694+
for idx in sorted(lines_to_remove, reverse=True):
695+
del lines[idx]
696+
697+
return "\n".join(lines) + "\n" if lines else "", stripped_tests
698+
699+
490700
def extract_test_methods(contents: str) -> set[tuple[str, str]]:
491701
"""
492702
Extract all test method names from file contents.
@@ -529,6 +739,13 @@ def auto_mark_file(
529739
if not test_path.exists():
530740
raise FileNotFoundError(f"File not found: {test_path}")
531741

742+
# Strip reason-less markers so those tests fail normally and we capture
743+
# their error messages during the test run.
744+
contents = test_path.read_text(encoding="utf-8")
745+
contents, stripped_tests = strip_reasonless_expected_failures(contents)
746+
if stripped_tests:
747+
test_path.write_text(contents, encoding="utf-8")
748+
532749
test_name = get_test_module_name(test_path)
533750
if verbose:
534751
print(f"Running test: {test_name}")
@@ -559,6 +776,13 @@ def auto_mark_file(
559776
else:
560777
failing_tests = set()
561778

779+
# Re-mark stripped tests that still fail (to restore markers with reasons).
780+
# Uses inheritance expansion: if a parent marker was stripped, child
781+
# failures are included so _consolidate_to_parent can re-mark the parent.
782+
failing_tests |= _expand_stripped_to_children(
783+
contents, stripped_tests, all_failing_tests
784+
)
785+
562786
regressions = all_failing_tests - failing_tests
563787

564788
if verbose:
@@ -626,6 +850,19 @@ def auto_mark_directory(
626850
if not test_dir.is_dir():
627851
raise ValueError(f"Not a directory: {test_dir}")
628852

853+
# Get all .py files in directory
854+
test_files = sorted(test_dir.glob("**/*.py"))
855+
856+
# Strip reason-less markers from ALL files before running tests so those
857+
# tests fail normally and we capture their error messages.
858+
stripped_per_file: dict[pathlib.Path, set[tuple[str, str]]] = {}
859+
for test_file in test_files:
860+
contents = test_file.read_text(encoding="utf-8")
861+
contents, stripped = strip_reasonless_expected_failures(contents)
862+
if stripped:
863+
test_file.write_text(contents, encoding="utf-8")
864+
stripped_per_file[test_file] = stripped
865+
629866
test_name = get_test_module_name(test_dir)
630867
if verbose:
631868
print(f"Running test: {test_name}")
@@ -644,9 +881,6 @@ def auto_mark_directory(
644881
total_regressions = 0
645882
all_regressions: list[tuple[str, str, str, str]] = []
646883

647-
# Get all .py files in directory
648-
test_files = sorted(test_dir.glob("**/*.py"))
649-
650884
for test_file in test_files:
651885
# Get module prefix for this file (e.g., "test_inspect.test_inspect")
652886
module_prefix = get_test_module_name(test_file)
@@ -671,6 +905,15 @@ def auto_mark_directory(
671905
else:
672906
failing_tests = set()
673907

908+
# Re-mark stripped tests that still fail (restore markers with reasons).
909+
# Uses inheritance expansion for parent→child mapping.
910+
stripped = stripped_per_file.get(test_file, set())
911+
if stripped:
912+
file_contents = test_file.read_text(encoding="utf-8")
913+
failing_tests |= _expand_stripped_to_children(
914+
file_contents, stripped, all_failing_tests
915+
)
916+
674917
regressions = all_failing_tests - failing_tests
675918

676919
if failing_tests or unexpected_successes:

0 commit comments

Comments
 (0)