@@ -249,9 +249,19 @@ def _iter_patch_lines(
249249 cache = {}
250250 # Build per-class set of async method names (for Phase 2 to generate correct override)
251251 async_methods : dict [str , set [str ]] = {}
252+ # Track class bases for inherited async method lookup
253+ class_bases : dict [str , list [str ]] = {}
254+ all_classes = {
255+ node .name for node in tree .body if isinstance (node , ast .ClassDef )
256+ }
252257 for node in tree .body :
253258 if isinstance (node , ast .ClassDef ):
254259 cache [node .name ] = node .end_lineno
260+ class_bases [node .name ] = [
261+ base .id
262+ for base in node .bases
263+ if isinstance (base , ast .Name ) and base .id in all_classes
264+ ]
255265 cls_async : set [str ] = set ()
256266 for item in node .body :
257267 if isinstance (item , ast .AsyncFunctionDef ):
@@ -282,7 +292,19 @@ def _iter_patch_lines(
282292
283293 for test_name , specs in tests .items ():
284294 decorators = "\n " .join (spec .as_decorator () for spec in specs )
285- is_async = test_name in async_methods .get (cls_name , set ())
295+ # Check current class and ancestors for async method
296+ is_async = False
297+ queue = [cls_name ]
298+ visited : set [str ] = set ()
299+ while queue :
300+ cur = queue .pop (0 )
301+ if cur in visited :
302+ continue
303+ visited .add (cur )
304+ if test_name in async_methods .get (cur , set ()):
305+ is_async = True
306+ break
307+ queue .extend (class_bases .get (cur , []))
286308 if is_async :
287309 patch_lines = f"""
288310{ decorators }
0 commit comments