Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix race in test (two threads need two ready markers!)
  • Loading branch information
ambv committed May 25, 2025
commit eb61f385c072bfc56f225395c42a53f10f72f5bf
9 changes: 5 additions & 4 deletions Lib/asyncio/tools.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
"""Tools to analyze tasks running in asyncio programs."""

from dataclasses import dataclass
from collections import defaultdict
from itertools import count
from enum import Enum
import sys
from _remote_debugging import RemoteUnwinder

def get_all_awaited_by(pid):
unwinder = RemoteUnwinder(pid)
return unwinder.get_all_awaited_by()

class NodeType(Enum):
COROUTINE = 1
Expand Down Expand Up @@ -121,6 +117,11 @@ def dfs(v):


# ─── PRINT TREE FUNCTION ───────────────────────────────────────
def get_all_awaited_by(pid):
unwinder = RemoteUnwinder(pid)
return unwinder.get_all_awaited_by()


def build_async_tree(result, task_emoji="(T)", cor_emoji=""):
"""
Build a list of strings for pretty-print an async call tree.
Expand Down
23 changes: 16 additions & 7 deletions Lib/test/test_external_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,24 @@ def _make_test_script(script_dir, script_basename, source):
"Test only runs on Linux, Windows and MacOS",
)


def get_stack_trace(pid):
unwinder = RemoteUnwinder(pid, all_threads=True)
return unwinder.get_stack_trace()


def get_async_stack_trace(pid):
unwinder = RemoteUnwinder(pid)
return unwinder.get_async_stack_trace()


def get_all_awaited_by(pid):
unwinder = RemoteUnwinder(pid)
return unwinder.get_all_awaited_by()


class TestGetStackTrace(unittest.TestCase):
maxDiff = None

@skip_if_not_supported
@unittest.skipIf(
Expand All @@ -65,13 +70,16 @@ def bar():
for x in range(100):
if x == 50:
baz()

def baz():
foo()

def foo():
sock.sendall(b"ready"); time.sleep(10_000) # same line number
sock.sendall(b"ready:thread\\n"); time.sleep(10_000) # same line number

t = threading.Thread(target=bar); t.start(); t.join()
t = threading.Thread(target=bar)
t.start()
sock.sendall(b"ready:main\\n"); t.join() # same line number
"""
)
stack_trace = None
Expand All @@ -92,8 +100,9 @@ def foo():
p = subprocess.Popen([sys.executable, script_name])
client_socket, _ = server_socket.accept()
server_socket.close()
response = client_socket.recv(1024)
self.assertEqual(response, b"ready")
response = b""
while b"ready:main" not in response or b"ready:thread" not in response:
response += client_socket.recv(1024)
stack_trace = get_stack_trace(p.pid)
except PermissionError:
self.skipTest("Insufficient permissions to read the stack trace")
Expand All @@ -105,14 +114,14 @@ def foo():
p.wait(timeout=SHORT_TIMEOUT)

thread_expected_stack_trace = [
("foo", script_name, 14),
("baz", script_name, 11),
("foo", script_name, 15),
("baz", script_name, 12),
("bar", script_name, 9),
('Thread.run', threading.__file__, ANY)
]
main_thread_stack_trace = [
(ANY, threading.__file__, ANY),
("<module>", script_name, 16),
("<module>", script_name, 19),
]
self.assertEqual(stack_trace, [
(ANY, thread_expected_stack_trace),
Expand Down
Loading