Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor automark and its test
  • Loading branch information
youknowone committed Feb 1, 2026
commit 3acbf0c6125c73264d4e1ca73aef0c8fa08c7a1d
70 changes: 33 additions & 37 deletions scripts/update_lib/cmd_auto_mark.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,19 +277,10 @@ def _expand_stripped_to_children(
class_bases, class_methods = _build_inheritance_info(tree)

for parent_cls, method_name in unmatched:
# parent must actually define this method
if method_name not in class_methods.get(parent_cls, set()):
continue
for cls in class_bases:
if cls == parent_cls:
continue
if method_name in class_methods.get(cls, set()):
continue
if (
_find_method_definition(cls, method_name, class_bases, class_methods)
== parent_cls
and (cls, method_name) in all_failing_tests
):
for cls in _find_all_inheritors(parent_cls, method_name, class_bases, class_methods):
if (cls, method_name) in all_failing_tests:
result.add((cls, method_name))

return result
Expand Down Expand Up @@ -331,16 +322,7 @@ def _consolidate_to_parent(
new_error_messages = dict(error_messages) if error_messages else {}

for (parent, method_name), failing_children in groups.items():
# Find ALL classes that inherit this method from parent
all_inheritors: set[str] = set()
for cls_name in class_bases:
if cls_name == parent:
continue
# Skip if this class defines the method itself
if method_name in class_methods.get(cls_name, set()):
continue
if _find_method_definition(cls_name, method_name, class_bases, class_methods) == parent:
all_inheritors.add(cls_name)
all_inheritors = _find_all_inheritors(parent, method_name, class_bases, class_methods)

if all_inheritors and failing_children >= all_inheritors:
# All inheritors fail → mark on parent instead
Expand Down Expand Up @@ -400,6 +382,20 @@ def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bo
return True


def _method_removal_range(
func_node: ast.FunctionDef | ast.AsyncFunctionDef, lines: list[str]
) -> range:
"""Line range covering an entire method including decorators and a preceding COMMENT line."""
first = (
func_node.decorator_list[0].lineno - 1
if func_node.decorator_list
else func_node.lineno - 1
)
if first > 0 and lines[first - 1].strip().startswith("#") and COMMENT in lines[first - 1]:
first -= 1
return range(first, func_node.end_lineno)


def _build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]:
"""
Build inheritance information from AST.
Expand Down Expand Up @@ -455,6 +451,20 @@ def _find_method_definition(
return None


def _find_all_inheritors(
parent: str, method_name: str, class_bases: dict, class_methods: dict
) -> set[str]:
"""Find all classes that inherit *method_name* from *parent* (not overriding it)."""
return {
cls
for cls in class_bases
if cls != parent
and method_name not in class_methods.get(cls, set())
and _find_method_definition(cls, method_name, class_bases, class_methods)
== parent
}


def remove_expected_failures(
contents: str, tests_to_remove: set[tuple[str, str]]
) -> str:
Expand Down Expand Up @@ -490,15 +500,7 @@ def remove_expected_failures(
remove_entire_method = _is_super_call_only(item)

if remove_entire_method:
first_line = item.lineno - 1
if item.decorator_list:
first_line = item.decorator_list[0].lineno - 1
if first_line > 0:
prev_line = lines[first_line - 1].strip()
if prev_line.startswith("#") and COMMENT in prev_line:
first_line -= 1
for i in range(first_line, item.end_lineno):
lines_to_remove.add(i)
lines_to_remove.update(_method_removal_range(item, lines))
else:
for dec in item.decorator_list:
dec_line = dec.lineno - 1
Expand Down Expand Up @@ -662,13 +664,7 @@ def strip_reasonless_expected_failures(
# exists only to apply the decorator; without it
# the override is pointless and blocks parent
# consolidation)
first_line = item.decorator_list[0].lineno - 1
if first_line > 0:
prev = lines[first_line - 1].strip()
if prev.startswith("#") and COMMENT in prev:
first_line -= 1
for i in range(first_line, item.end_lineno):
lines_to_remove.add(i)
lines_to_remove.update(_method_removal_range(item, lines))
else:
lines_to_remove.add(dec_line)

Expand Down
Loading
Loading