11import asyncio
2+ import contextvars
23import inspect
4+ import warnings
35
46from .case import TestCase
57
@@ -32,8 +34,8 @@ class IsolatedAsyncioTestCase(TestCase):
3234
3335 def __init__ (self , methodName = 'runTest' ):
3436 super ().__init__ (methodName )
35- self ._asyncioTestLoop = None
36- self ._asyncioCallsQueue = None
37+ self ._asyncioRunner = None
38+ self ._asyncioTestContext = contextvars . copy_context ()
3739
3840 async def asyncSetUp (self ):
3941 pass
@@ -56,115 +58,85 @@ def addAsyncCleanup(self, func, /, *args, **kwargs):
5658 # 3. Regular "def func()" that returns awaitable object
5759 self .addCleanup (* (func , * args ), ** kwargs )
5860
61+ async def enterAsyncContext (self , cm ):
62+ """Enters the supplied asynchronous context manager.
63+
64+ If successful, also adds its __aexit__ method as a cleanup
65+ function and returns the result of the __aenter__ method.
66+ """
67+ # We look up the special methods on the type to match the with
68+ # statement.
69+ cls = type (cm )
70+ try :
71+ enter = cls .__aenter__
72+ exit = cls .__aexit__
73+ except AttributeError :
74+ raise TypeError (f"'{ cls .__module__ } .{ cls .__qualname__ } ' object does "
75+ f"not support the asynchronous context manager protocol"
76+ ) from None
77+ result = await enter (cm )
78+ self .addAsyncCleanup (exit , cm , None , None , None )
79+ return result
80+
5981 def _callSetUp (self ):
60- self .setUp ()
82+ # Force loop to be initialized and set as the current loop
83+ # so that setUp functions can use get_event_loop() and get the
84+ # correct loop instance.
85+ self ._asyncioRunner .get_loop ()
86+ self ._asyncioTestContext .run (self .setUp )
6187 self ._callAsync (self .asyncSetUp )
6288
6389 def _callTestMethod (self , method ):
64- self ._callMaybeAsync (method )
90+ if self ._callMaybeAsync (method ) is not None :
91+ warnings .warn (f'It is deprecated to return a value that is not None from a '
92+ f'test case ({ method } )' , DeprecationWarning , stacklevel = 4 )
6593
6694 def _callTearDown (self ):
6795 self ._callAsync (self .asyncTearDown )
68- self .tearDown ( )
96+ self ._asyncioTestContext . run ( self . tearDown )
6997
7098 def _callCleanup (self , function , * args , ** kwargs ):
7199 self ._callMaybeAsync (function , * args , ** kwargs )
72100
73101 def _callAsync (self , func , / , * args , ** kwargs ):
74- assert self ._asyncioTestLoop is not None , 'asyncio test loop is not initialized'
75- ret = func ( * args , ** kwargs )
76- assert inspect . isawaitable ( ret ), f' { func !r } returned non-awaitable'
77- fut = self . _asyncioTestLoop . create_future ()
78- self ._asyncioCallsQueue . put_nowait (( fut , ret ))
79- return self . _asyncioTestLoop . run_until_complete ( fut )
102+ assert self ._asyncioRunner is not None , 'asyncio runner is not initialized'
103+ assert inspect . iscoroutinefunction ( func ), f' { func !r } is not an async function'
104+ return self . _asyncioRunner . run (
105+ func ( * args , ** kwargs ),
106+ context = self ._asyncioTestContext
107+ )
80108
81109 def _callMaybeAsync (self , func , / , * args , ** kwargs ):
82- assert self ._asyncioTestLoop is not None , 'asyncio test loop is not initialized'
83- ret = func ( * args , ** kwargs )
84- if inspect . isawaitable ( ret ):
85- fut = self . _asyncioTestLoop . create_future ()
86- self ._asyncioCallsQueue . put_nowait (( fut , ret ))
87- return self . _asyncioTestLoop . run_until_complete ( fut )
110+ assert self ._asyncioRunner is not None , 'asyncio runner is not initialized'
111+ if inspect . iscoroutinefunction ( func ):
112+ return self . _asyncioRunner . run (
113+ func ( * args , ** kwargs ),
114+ context = self ._asyncioTestContext ,
115+ )
88116 else :
89- return ret
90-
91- async def _asyncioLoopRunner (self , fut ):
92- self ._asyncioCallsQueue = queue = asyncio .Queue ()
93- fut .set_result (None )
94- while True :
95- query = await queue .get ()
96- queue .task_done ()
97- if query is None :
98- return
99- fut , awaitable = query
100- try :
101- ret = await awaitable
102- if not fut .cancelled ():
103- fut .set_result (ret )
104- except (SystemExit , KeyboardInterrupt ):
105- raise
106- except (BaseException , asyncio .CancelledError ) as ex :
107- if not fut .cancelled ():
108- fut .set_exception (ex )
109-
110- def _setupAsyncioLoop (self ):
111- assert self ._asyncioTestLoop is None , 'asyncio test loop already initialized'
112- loop = asyncio .new_event_loop ()
113- asyncio .set_event_loop (loop )
114- loop .set_debug (True )
115- self ._asyncioTestLoop = loop
116- fut = loop .create_future ()
117- self ._asyncioCallsTask = loop .create_task (self ._asyncioLoopRunner (fut ))
118- loop .run_until_complete (fut )
119-
120- def _tearDownAsyncioLoop (self ):
121- assert self ._asyncioTestLoop is not None , 'asyncio test loop is not initialized'
122- loop = self ._asyncioTestLoop
123- self ._asyncioTestLoop = None
124- self ._asyncioCallsQueue .put_nowait (None )
125- loop .run_until_complete (self ._asyncioCallsQueue .join ())
117+ return self ._asyncioTestContext .run (func , * args , ** kwargs )
126118
127- try :
128- # cancel all tasks
129- to_cancel = asyncio .all_tasks (loop )
130- if not to_cancel :
131- return
132-
133- for task in to_cancel :
134- task .cancel ()
135-
136- loop .run_until_complete (
137- asyncio .gather (* to_cancel , return_exceptions = True ))
138-
139- for task in to_cancel :
140- if task .cancelled ():
141- continue
142- if task .exception () is not None :
143- loop .call_exception_handler ({
144- 'message' : 'unhandled exception during test shutdown' ,
145- 'exception' : task .exception (),
146- 'task' : task ,
147- })
148- # shutdown asyncgens
149- loop .run_until_complete (loop .shutdown_asyncgens ())
150- finally :
151- # Prevent our executor environment from leaking to future tests.
152- loop .run_until_complete (loop .shutdown_default_executor ())
153- asyncio .set_event_loop (None )
154- loop .close ()
119+ def _setupAsyncioRunner (self ):
120+ assert self ._asyncioRunner is None , 'asyncio runner is already initialized'
121+ runner = asyncio .Runner (debug = True )
122+ self ._asyncioRunner = runner
123+
124+ def _tearDownAsyncioRunner (self ):
125+ runner = self ._asyncioRunner
126+ runner .close ()
155127
156128 def run (self , result = None ):
157- self ._setupAsyncioLoop ()
129+ self ._setupAsyncioRunner ()
158130 try :
159131 return super ().run (result )
160132 finally :
161- self ._tearDownAsyncioLoop ()
133+ self ._tearDownAsyncioRunner ()
162134
163135 def debug (self ):
164- self ._setupAsyncioLoop ()
136+ self ._setupAsyncioRunner ()
165137 super ().debug ()
166- self ._tearDownAsyncioLoop ()
138+ self ._tearDownAsyncioRunner ()
167139
168140 def __del__ (self ):
169- if self ._asyncioTestLoop is not None :
170- self ._tearDownAsyncioLoop ()
141+ if self ._asyncioRunner is not None :
142+ self ._tearDownAsyncioRunner ()
0 commit comments