@@ -318,6 +318,82 @@ async def recursive():
318318 self .assertEqual (ncols , 10 )
319319 self .assertEqual (depth , 0 )
320320
321+ @_async_test
322+ async def test_decorator (self ):
323+ entered = False
324+
325+ @asynccontextmanager
326+ async def context ():
327+ nonlocal entered
328+ entered = True
329+ yield
330+ entered = False
331+
332+ @context ()
333+ async def test ():
334+ self .assertTrue (entered )
335+
336+ self .assertFalse (entered )
337+ await test ()
338+ self .assertFalse (entered )
339+
340+ @_async_test
341+ async def test_decorator_with_exception (self ):
342+ entered = False
343+
344+ @asynccontextmanager
345+ async def context ():
346+ nonlocal entered
347+ try :
348+ entered = True
349+ yield
350+ finally :
351+ entered = False
352+
353+ @context ()
354+ async def test ():
355+ self .assertTrue (entered )
356+ raise NameError ('foo' )
357+
358+ self .assertFalse (entered )
359+ with self .assertRaisesRegex (NameError , 'foo' ):
360+ await test ()
361+ self .assertFalse (entered )
362+
363+ @_async_test
364+ async def test_decorating_method (self ):
365+
366+ @asynccontextmanager
367+ async def context ():
368+ yield
369+
370+
371+ class Test (object ):
372+
373+ @context ()
374+ async def method (self , a , b , c = None ):
375+ self .a = a
376+ self .b = b
377+ self .c = c
378+
379+ # these tests are for argument passing when used as a decorator
380+ test = Test ()
381+ await test .method (1 , 2 )
382+ self .assertEqual (test .a , 1 )
383+ self .assertEqual (test .b , 2 )
384+ self .assertEqual (test .c , None )
385+
386+ test = Test ()
387+ await test .method ('a' , 'b' , 'c' )
388+ self .assertEqual (test .a , 'a' )
389+ self .assertEqual (test .b , 'b' )
390+ self .assertEqual (test .c , 'c' )
391+
392+ test = Test ()
393+ await test .method (a = 1 , b = 2 )
394+ self .assertEqual (test .a , 1 )
395+ self .assertEqual (test .b , 2 )
396+
321397
322398class AclosingTestCase (unittest .TestCase ):
323399
0 commit comments