Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,12 @@ def create_task(self, coro, *, name=None, context=None):

tasks._set_task_name(task, name)

return task
try:
return task
finally:
# gh-128552: prevent a refcycle of
# task.exception().__traceback__->BaseEventLoop.create_task->task
del task

def set_task_factory(self, factory):
"""Set a task factory that will be used by loop.create_task().
Expand Down
7 changes: 6 additions & 1 deletion Lib/asyncio/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,12 @@ def create_task(self, coro, *, name=None, context=None):
else:
self._tasks.add(task)
task.add_done_callback(self._on_task_done)
return task
try:
return task
finally:
# gh-128552: prevent a refcycle of
# task.exception().__traceback__->TaskGroup.create_task->task
del task

# Since Python 3.8 Tasks propagate all exceptions correctly,
# except for KeyboardInterrupt and SystemExit which are
Expand Down
150 changes: 147 additions & 3 deletions Lib/test/test_asyncio/test_taskgroups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Adapted with permission from the EdgeDB project;
# license: PSFL.

import weakref
import sys
import gc
import asyncio
import contextvars
Expand All @@ -27,7 +29,25 @@ def get_error_types(eg):
return {type(exc) for exc in eg.exceptions}


class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
def set_gc_state(enabled):
was_enabled = gc.isenabled()
if enabled:
gc.enable()
else:
gc.disable()
return was_enabled


@contextlib.contextmanager
def disable_gc():
was_enabled = set_gc_state(enabled=False)
try:
yield
finally:
set_gc_state(enabled=was_enabled)


class BaseTestTaskGroup:

async def test_taskgroup_01(self):

Expand Down Expand Up @@ -820,8 +840,82 @@ async def test_taskgroup_without_parent_task(self):
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
tg.create_task(coro)
# We still have to await coro to avoid a warning
await coro

Comment thread
graingert marked this conversation as resolved.
Outdated
async def test_coro_closed_when_tg_closed(self):
async def run_coro_after_tg_closes():
async with taskgroups.TaskGroup() as tg:
pass
coro = asyncio.sleep(0)
with self.assertRaisesRegex(RuntimeError, "is finished"):
tg.create_task(coro)

await run_coro_after_tg_closes()

async def test_cancelling_level_preserved(self):
async def raise_after(t, e):
await asyncio.sleep(t)
raise e()

try:
async with asyncio.TaskGroup() as tg:
tg.create_task(raise_after(0.0, RuntimeError))
except* RuntimeError:
pass
self.assertEqual(asyncio.current_task().cancelling(), 0)

async def test_nested_groups_both_cancelled(self):
async def raise_after(t, e):
await asyncio.sleep(t)
raise e()

try:
async with asyncio.TaskGroup() as outer_tg:
try:
async with asyncio.TaskGroup() as inner_tg:
inner_tg.create_task(raise_after(0, RuntimeError))
outer_tg.create_task(raise_after(0, ValueError))
except* RuntimeError:
pass
else:
self.fail("RuntimeError not raised")
self.assertEqual(asyncio.current_task().cancelling(), 1)
except* ValueError:
pass
else:
self.fail("ValueError not raised")
self.assertEqual(asyncio.current_task().cancelling(), 0)

async def test_error_and_cancel(self):
event = asyncio.Event()

async def raise_error():
event.set()
await asyncio.sleep(0)
raise RuntimeError()

async def inner():
try:
async with taskgroups.TaskGroup() as tg:
tg.create_task(raise_error())
await asyncio.sleep(1)
self.fail("Sleep in group should have been cancelled")
except* RuntimeError:
self.assertEqual(asyncio.current_task().cancelling(), 1)
self.assertEqual(asyncio.current_task().cancelling(), 1)
await asyncio.sleep(1)
self.fail("Sleep after group should have been cancelled")

async def outer():
t = asyncio.create_task(inner())
await event.wait()
self.assertEqual(t.cancelling(), 0)
t.cancel()
self.assertEqual(t.cancelling(), 1)
with self.assertRaises(asyncio.CancelledError):
await t
self.assertTrue(t.cancelled())

await outer()

async def test_exception_refcycles_direct(self):
"""Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
Expand Down Expand Up @@ -880,6 +974,30 @@ async def coro_fn():
self.assertIsInstance(exc, _Done)
self.assertListEqual(gc.get_referrers(exc), [])


async def test_exception_refcycles_parent_task_wr(self):
"""Test that TaskGroup deletes self._parent_task and create_task() deletes task"""
tg = asyncio.TaskGroup()
exc = None

class _Done(Exception):
pass

async def coro_fn():
async with tg:
raise _Done

with disable_gc():
try:
async with asyncio.TaskGroup() as tg2:
task_wr = weakref.ref(tg2.create_task(coro_fn()))
except* _Done as excs:
Comment thread
graingert marked this conversation as resolved.
exc = excs.exceptions[0].exceptions[0]

self.assertIsNone(task_wr())
self.assertIsInstance(exc, _Done)
self.assertListEqual(gc.get_referrers(exc), [])

async def test_exception_refcycles_propagate_cancellation_error(self):
"""Test that TaskGroup deletes propagate_cancellation_error"""
tg = asyncio.TaskGroup()
Expand Down Expand Up @@ -912,6 +1030,32 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
self.assertIsNotNone(exc)
self.assertListEqual(gc.get_referrers(exc), [])

if sys.platform == "win32":
EventLoop = asyncio.ProactorEventLoop
else:
EventLoop = asyncio.SelectorEventLoop


class IsolatedAsyncioTestCase(unittest.IsolatedAsyncioTestCase):
loop_factory = None

def _setupAsyncioRunner(self):
assert self._asyncioRunner is None, 'asyncio runner is already initialized'
runner = asyncio.Runner(debug=True, loop_factory=self.loop_factory)
self._asyncioRunner = runner


class TestTaskGroup(BaseTestTaskGroup, IsolatedAsyncioTestCase):
loop_factory = EventLoop


class TestEagerTaskTaskGroup(BaseTestTaskGroup, IsolatedAsyncioTestCase):
@staticmethod
def loop_factory():
loop = EventLoop()
loop.set_task_factory(asyncio.eager_task_factory)
return loop


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix cyclic garbage introduced by :meth:`asyncio.loop.create_task` and :meth:`asyncio.TaskGroup.create_task` holding a reference to the created task if it is eager.