Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix tests
  • Loading branch information
kumaraditya303 committed Jan 9, 2025
commit 7984147606696f10da83722255fc18778d3043ca
53 changes: 11 additions & 42 deletions Lib/test/test_asyncio/test_free_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

threading_helper.requires_working_threading(module=True)


class MyException(Exception):
pass

Expand Down Expand Up @@ -79,60 +80,28 @@ async def main():
loop.set_task_factory(self.factory)
r.run(main())

def test_run_coroutine_threadsafe_exception_caught(self) -> None:
def test_run_coroutine_threadsafe_exception(self) -> None:
exc = MyException("test")

async def coro():
await asyncio.sleep(0.1)
await asyncio.sleep(0)
raise exc

def in_thread(loop: asyncio.AbstractEventLoop):
fut = asyncio.run_coroutine_threadsafe(coro(), loop)
self.assertEqual(fut.exception(), exc)
return exc

async def main():
loop = asyncio.get_running_loop()
tasks = []
async with asyncio.TaskGroup() as tg:
for _ in range(10):
task = tg.create_task(asyncio.to_thread(in_thread, loop))
tasks.append(task)
for task in tasks:
self.assertEqual(await task, exc)

with asyncio.Runner() as r:
loop = r.get_loop()
loop.set_task_factory(self.factory)
r.run(main())

def test_run_coroutine_threadsafe_exception_uncaught(self) -> None:
async def coro():
await asyncio.sleep(1)
raise MyException("test")

def in_thread(loop: asyncio.AbstractEventLoop):
fut = asyncio.run_coroutine_threadsafe(coro(), loop)
return fut.result()

async def main():
loop = asyncio.get_running_loop()
tasks = []
try:
async with asyncio.TaskGroup() as tg:
for _ in range(10):
task = tg.create_task(asyncio.to_thread(in_thread, loop))
tasks.append(task)
except ExceptionGroup:
for task in tasks:
try:
await task
except (MyException, asyncio.CancelledError):
pass
else:
self.fail("Task should have raised an exception")
else:
self.fail("TaskGroup should have raised an exception")
for _ in range(10):
task = loop.create_task(asyncio.to_thread(in_thread, loop))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)

self.assertEqual(len(results), 10)
for result in results:
self.assertIs(result, exc)

with asyncio.Runner() as r:
loop = r.get_loop()
Expand Down