Skip to content

Commit 922b644

Browse files
committed
[update_lib] fix async detection
1 parent 528d657 commit 922b644

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

scripts/update_lib/patch_spec.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,19 @@ def _iter_patch_lines(
249249
cache = {}
250250
# Build per-class set of async method names (for Phase 2 to generate correct override)
251251
async_methods: dict[str, set[str]] = {}
252+
# Track class bases for inherited async method lookup
253+
class_bases: dict[str, list[str]] = {}
254+
all_classes = {
255+
node.name for node in tree.body if isinstance(node, ast.ClassDef)
256+
}
252257
for node in tree.body:
253258
if isinstance(node, ast.ClassDef):
254259
cache[node.name] = node.end_lineno
260+
class_bases[node.name] = [
261+
base.id
262+
for base in node.bases
263+
if isinstance(base, ast.Name) and base.id in all_classes
264+
]
255265
cls_async: set[str] = set()
256266
for item in node.body:
257267
if isinstance(item, ast.AsyncFunctionDef):
@@ -282,7 +292,19 @@ def _iter_patch_lines(
282292

283293
for test_name, specs in tests.items():
284294
decorators = "\n".join(spec.as_decorator() for spec in specs)
285-
is_async = test_name in async_methods.get(cls_name, set())
295+
# Check current class and ancestors for async method
296+
is_async = False
297+
queue = [cls_name]
298+
visited: set[str] = set()
299+
while queue:
300+
cur = queue.pop(0)
301+
if cur in visited:
302+
continue
303+
visited.add(cur)
304+
if test_name in async_methods.get(cur, set()):
305+
is_async = True
306+
break
307+
queue.extend(class_bases.get(cur, []))
286308
if is_async:
287309
patch_lines = f"""
288310
{decorators}

0 commit comments

Comments
 (0)