1+ import sys
2+ import collections .abc
13import concurrent .futures
24import contextvars
35import functools
46import gc
57import random
6- import sys
78import time
89import unittest
910import weakref
@@ -26,8 +27,7 @@ def wrapper(*args, **kwargs):
2627
2728
2829class ContextTest (unittest .TestCase ):
29- # TODO: RUSTPYTHON
30- @unittest .expectedFailure
30+ @unittest .expectedFailure # TODO: RUSTPYTHON
3131 def test_context_var_new_1 (self ):
3232 with self .assertRaisesRegex (TypeError , 'takes exactly 1' ):
3333 contextvars .ContextVar ()
@@ -63,6 +63,14 @@ def test_context_var_repr_1(self):
6363 c .reset (t )
6464 self .assertIn (' used ' , repr (t ))
6565
66+ @isolated_context
67+ def test_token_repr_1 (self ):
68+ c = contextvars .ContextVar ('a' )
69+ tok = c .set (1 )
70+ self .assertRegex (repr (tok ),
71+ r"^<Token var=<ContextVar name='a' "
72+ r"at 0x[0-9a-fA-F]+> at 0x[0-9a-fA-F]+>$" )
73+
6674 def test_context_subclassing_1 (self ):
6775 with self .assertRaisesRegex (TypeError , 'not an acceptable base type' ):
6876 class MyContextVar (contextvars .ContextVar ):
@@ -77,8 +85,7 @@ class MyContext(contextvars.Context):
7785 class MyToken (contextvars .Token ):
7886 pass
7987
80- # TODO: RUSTPYTHON
81- @unittest .expectedFailure
88+ @unittest .expectedFailure # TODO: RUSTPYTHON
8289 def test_context_new_1 (self ):
8390 with self .assertRaisesRegex (TypeError , 'any arguments' ):
8491 contextvars .Context (1 )
@@ -88,8 +95,17 @@ def test_context_new_1(self):
8895 contextvars .Context (a = 1 )
8996 contextvars .Context (** {})
9097
91- # TODO: RUSTPYTHON
92- @unittest .expectedFailure
98+ @unittest .expectedFailure # TODO: RUSTPYTHON; AssertionError: TypeError not raised
99+ def test_context_new_unhashable_str_subclass (self ):
100+ # gh-132002: it used to crash on unhashable str subtypes.
101+ class weird_str (str ):
102+ def __eq__ (self , other ):
103+ pass
104+
105+ with self .assertRaisesRegex (TypeError , 'unhashable type' ):
106+ contextvars .ContextVar (weird_str ())
107+
108+ @unittest .expectedFailure # TODO: RUSTPYTHON
93109 def test_context_typerrors_1 (self ):
94110 ctx = contextvars .Context ()
95111
@@ -104,8 +120,7 @@ def test_context_get_context_1(self):
104120 ctx = contextvars .copy_context ()
105121 self .assertIsInstance (ctx , contextvars .Context )
106122
107- # TODO: RUSTPYTHON
108- @unittest .expectedFailure
123+ @unittest .expectedFailure # TODO: RUSTPYTHON
109124 def test_context_run_1 (self ):
110125 ctx = contextvars .Context ()
111126
@@ -153,8 +168,7 @@ def func(*args, **kwargs):
153168 with self .assertRaises (ZeroDivisionError ):
154169 ctx .run (func , 1 , 2 , a = 123 )
155170
156- # TODO: RUSTPYTHON
157- @unittest .expectedFailure
171+ @unittest .expectedFailure # TODO: RUSTPYTHON
158172 @isolated_context
159173 def test_context_run_4 (self ):
160174 ctx1 = contextvars .Context ()
@@ -353,9 +367,22 @@ def ctx2_fun():
353367
354368 ctx1 .run (ctx1_fun )
355369
370+ def test_context_isinstance (self ):
371+ ctx = contextvars .Context ()
372+ self .assertIsInstance (ctx , collections .abc .Mapping )
373+ self .assertTrue (issubclass (contextvars .Context , collections .abc .Mapping ))
374+
375+ mapping_methods = (
376+ '__contains__' , '__eq__' , '__getitem__' , '__iter__' , '__len__' ,
377+ '__ne__' , 'get' , 'items' , 'keys' , 'values' ,
378+ )
379+ for name in mapping_methods :
380+ with self .subTest (name = name ):
381+ self .assertTrue (callable (getattr (ctx , name )))
382+
383+ @unittest .skipIf (sys .platform == "darwin" , "TODO: RUSTPYTHON; Flaky on Mac, self.assertEqual(cvar.get(), num + i) AssertionError: 8 != 12" )
356384 @isolated_context
357385 @threading_helper .requires_working_threading ()
358- @unittest .skipIf (sys .platform == 'darwin' , 'TODO: RUSTPYTHON; Flaky on Mac, self.assertEqual(cvar.get(), num + i) AssertionError: 8 != 12' )
359386 def test_context_threads_1 (self ):
360387 cvar = contextvars .ContextVar ('cvar' )
361388
@@ -373,6 +400,199 @@ def sub(num):
373400 tp .shutdown ()
374401 self .assertEqual (results , list (range (10 )))
375402
403+ @isolated_context
404+ @threading_helper .requires_working_threading ()
405+ def test_context_thread_inherit (self ):
406+ import threading
407+
408+ cvar = contextvars .ContextVar ('cvar' )
409+
410+ def run_context_none ():
411+ if sys .flags .thread_inherit_context :
412+ expected = 1
413+ else :
414+ expected = None
415+ self .assertEqual (cvar .get (None ), expected )
416+
417+ # By default, context is inherited based on the
418+ # sys.flags.thread_inherit_context option.
419+ cvar .set (1 )
420+ thread = threading .Thread (target = run_context_none )
421+ thread .start ()
422+ thread .join ()
423+
424+ # Passing 'None' explicitly should have same behaviour as not
425+ # passing parameter.
426+ thread = threading .Thread (target = run_context_none , context = None )
427+ thread .start ()
428+ thread .join ()
429+
430+ # An explicit Context value can also be passed
431+ custom_ctx = contextvars .Context ()
432+ custom_var = None
433+
434+ def setup_context ():
435+ nonlocal custom_var
436+ custom_var = contextvars .ContextVar ('custom' )
437+ custom_var .set (2 )
438+
439+ custom_ctx .run (setup_context )
440+
441+ def run_custom ():
442+ self .assertEqual (custom_var .get (), 2 )
443+
444+ thread = threading .Thread (target = run_custom , context = custom_ctx )
445+ thread .start ()
446+ thread .join ()
447+
448+ # You can also pass a new Context() object to start with an empty context
449+ def run_empty ():
450+ with self .assertRaises (LookupError ):
451+ cvar .get ()
452+
453+ thread = threading .Thread (target = run_empty , context = contextvars .Context ())
454+ thread .start ()
455+ thread .join ()
456+
457+ def test_token_contextmanager_with_default (self ):
458+ ctx = contextvars .Context ()
459+ c = contextvars .ContextVar ('c' , default = 42 )
460+
461+ def fun ():
462+ with c .set (36 ):
463+ self .assertEqual (c .get (), 36 )
464+
465+ self .assertEqual (c .get (), 42 )
466+
467+ ctx .run (fun )
468+
469+ def test_token_contextmanager_without_default (self ):
470+ ctx = contextvars .Context ()
471+ c = contextvars .ContextVar ('c' )
472+
473+ def fun ():
474+ with c .set (36 ):
475+ self .assertEqual (c .get (), 36 )
476+
477+ with self .assertRaisesRegex (LookupError , "<ContextVar name='c'" ):
478+ c .get ()
479+
480+ ctx .run (fun )
481+
482+ def test_token_contextmanager_on_exception (self ):
483+ ctx = contextvars .Context ()
484+ c = contextvars .ContextVar ('c' , default = 42 )
485+
486+ def fun ():
487+ with c .set (36 ):
488+ self .assertEqual (c .get (), 36 )
489+ raise ValueError ("custom exception" )
490+
491+ self .assertEqual (c .get (), 42 )
492+
493+ with self .assertRaisesRegex (ValueError , "custom exception" ):
494+ ctx .run (fun )
495+
496+ def test_token_contextmanager_reentrant (self ):
497+ ctx = contextvars .Context ()
498+ c = contextvars .ContextVar ('c' , default = 42 )
499+
500+ def fun ():
501+ token = c .set (36 )
502+ with self .assertRaisesRegex (
503+ RuntimeError ,
504+ "<Token .+ has already been used once"
505+ ):
506+ with token :
507+ with token :
508+ self .assertEqual (c .get (), 36 )
509+
510+ self .assertEqual (c .get (), 42 )
511+
512+ ctx .run (fun )
513+
514+ def test_token_contextmanager_multiple_c_set (self ):
515+ ctx = contextvars .Context ()
516+ c = contextvars .ContextVar ('c' , default = 42 )
517+
518+ def fun ():
519+ with c .set (36 ):
520+ self .assertEqual (c .get (), 36 )
521+ c .set (24 )
522+ self .assertEqual (c .get (), 24 )
523+ c .set (12 )
524+ self .assertEqual (c .get (), 12 )
525+
526+ self .assertEqual (c .get (), 42 )
527+
528+ ctx .run (fun )
529+
530+ def test_token_contextmanager_with_explicit_reset_the_same_token (self ):
531+ ctx = contextvars .Context ()
532+ c = contextvars .ContextVar ('c' , default = 42 )
533+
534+ def fun ():
535+ with self .assertRaisesRegex (
536+ RuntimeError ,
537+ "<Token .+ has already been used once"
538+ ):
539+ with c .set (36 ) as token :
540+ self .assertEqual (c .get (), 36 )
541+ c .reset (token )
542+
543+ self .assertEqual (c .get (), 42 )
544+
545+ self .assertEqual (c .get (), 42 )
546+
547+ ctx .run (fun )
548+
549+ def test_token_contextmanager_with_explicit_reset_another_token (self ):
550+ ctx = contextvars .Context ()
551+ c = contextvars .ContextVar ('c' , default = 42 )
552+
553+ def fun ():
554+ with c .set (36 ):
555+ self .assertEqual (c .get (), 36 )
556+
557+ token = c .set (24 )
558+ self .assertEqual (c .get (), 24 )
559+ c .reset (token )
560+ self .assertEqual (c .get (), 36 )
561+
562+ self .assertEqual (c .get (), 42 )
563+
564+ ctx .run (fun )
565+
566+ def test_context_eq_reentrant_contextvar_set (self ):
567+ var = contextvars .ContextVar ("v" )
568+ ctx1 = contextvars .Context ()
569+ ctx2 = contextvars .Context ()
570+
571+ class ReentrantEq :
572+ def __eq__ (self , other ):
573+ ctx1 .run (lambda : var .set (object ()))
574+ return True
575+
576+ ctx1 .run (var .set , ReentrantEq ())
577+ ctx2 .run (var .set , object ())
578+ ctx1 == ctx2
579+
580+ def test_context_eq_reentrant_contextvar_set_in_hash (self ):
581+ var = contextvars .ContextVar ("v" )
582+ ctx1 = contextvars .Context ()
583+ ctx2 = contextvars .Context ()
584+
585+ class ReentrantHash :
586+ def __hash__ (self ):
587+ ctx1 .run (lambda : var .set (object ()))
588+ return 0
589+ def __eq__ (self , other ):
590+ return isinstance (other , ReentrantHash )
591+
592+ ctx1 .run (var .set , ReentrantHash ())
593+ ctx2 .run (var .set , ReentrantHash ())
594+ ctx1 == ctx2
595+
376596
377597# HAMT Tests
378598
0 commit comments