1010# See C source code for _functools credits/copyright
1111
1212__all__ = ['update_wrapper' , 'wraps' , 'WRAPPER_ASSIGNMENTS' , 'WRAPPER_UPDATES' ,
13- 'total_ordering' , 'cmp_to_key ' , 'lru_cache ' , 'reduce ' , 'partial ' ,
14- 'partialmethod' , 'singledispatch' , 'singledispatchmethod' ,
15- " cached_property" ]
13+ 'total_ordering' , 'cache ' , 'cmp_to_key ' , 'lru_cache ' , 'reduce ' ,
14+ 'partial' , ' partialmethod' , 'singledispatch' , 'singledispatchmethod' ,
15+ ' cached_property' ]
1616
1717from abc import get_cache_token
1818from collections import namedtuple
@@ -86,82 +86,86 @@ def wraps(wrapped,
8686# infinite recursion that could occur when the operator dispatch logic
8787# detects a NotImplemented result and then calls a reflected method.
8888
89- def _gt_from_lt (self , other , NotImplemented = NotImplemented ):
89+ def _gt_from_lt (self , other ):
9090 'Return a > b. Computed by @total_ordering from (not a < b) and (a != b).'
91- op_result = self .__lt__ (other )
91+ op_result = type ( self ) .__lt__ (self , other )
9292 if op_result is NotImplemented :
9393 return op_result
9494 return not op_result and self != other
9595
96- def _le_from_lt (self , other , NotImplemented = NotImplemented ):
96+ def _le_from_lt (self , other ):
9797 'Return a <= b. Computed by @total_ordering from (a < b) or (a == b).'
98- op_result = self .__lt__ (other )
98+ op_result = type (self ).__lt__ (self , other )
99+ if op_result is NotImplemented :
100+ return op_result
99101 return op_result or self == other
100102
101- def _ge_from_lt (self , other , NotImplemented = NotImplemented ):
103+ def _ge_from_lt (self , other ):
102104 'Return a >= b. Computed by @total_ordering from (not a < b).'
103- op_result = self .__lt__ (other )
105+ op_result = type ( self ) .__lt__ (self , other )
104106 if op_result is NotImplemented :
105107 return op_result
106108 return not op_result
107109
108- def _ge_from_le (self , other , NotImplemented = NotImplemented ):
110+ def _ge_from_le (self , other ):
109111 'Return a >= b. Computed by @total_ordering from (not a <= b) or (a == b).'
110- op_result = self .__le__ (other )
112+ op_result = type ( self ) .__le__ (self , other )
111113 if op_result is NotImplemented :
112114 return op_result
113115 return not op_result or self == other
114116
115- def _lt_from_le (self , other , NotImplemented = NotImplemented ):
117+ def _lt_from_le (self , other ):
116118 'Return a < b. Computed by @total_ordering from (a <= b) and (a != b).'
117- op_result = self .__le__ (other )
119+ op_result = type ( self ) .__le__ (self , other )
118120 if op_result is NotImplemented :
119121 return op_result
120122 return op_result and self != other
121123
122- def _gt_from_le (self , other , NotImplemented = NotImplemented ):
124+ def _gt_from_le (self , other ):
123125 'Return a > b. Computed by @total_ordering from (not a <= b).'
124- op_result = self .__le__ (other )
126+ op_result = type ( self ) .__le__ (self , other )
125127 if op_result is NotImplemented :
126128 return op_result
127129 return not op_result
128130
129- def _lt_from_gt (self , other , NotImplemented = NotImplemented ):
131+ def _lt_from_gt (self , other ):
130132 'Return a < b. Computed by @total_ordering from (not a > b) and (a != b).'
131- op_result = self .__gt__ (other )
133+ op_result = type ( self ) .__gt__ (self , other )
132134 if op_result is NotImplemented :
133135 return op_result
134136 return not op_result and self != other
135137
136- def _ge_from_gt (self , other , NotImplemented = NotImplemented ):
138+ def _ge_from_gt (self , other ):
137139 'Return a >= b. Computed by @total_ordering from (a > b) or (a == b).'
138- op_result = self .__gt__ (other )
140+ op_result = type (self ).__gt__ (self , other )
141+ if op_result is NotImplemented :
142+ return op_result
139143 return op_result or self == other
140144
141- def _le_from_gt (self , other , NotImplemented = NotImplemented ):
145+ def _le_from_gt (self , other ):
142146 'Return a <= b. Computed by @total_ordering from (not a > b).'
143- op_result = self .__gt__ (other )
147+ op_result = type ( self ) .__gt__ (self , other )
144148 if op_result is NotImplemented :
145149 return op_result
146150 return not op_result
147151
148- def _le_from_ge (self , other , NotImplemented = NotImplemented ):
152+ def _le_from_ge (self , other ):
149153 'Return a <= b. Computed by @total_ordering from (not a >= b) or (a == b).'
150- op_result = self .__ge__ (other )
154+ op_result = type ( self ) .__ge__ (self , other )
151155 if op_result is NotImplemented :
152156 return op_result
153157 return not op_result or self == other
154158
155- def _gt_from_ge (self , other , NotImplemented = NotImplemented ):
159+ def _gt_from_ge (self , other ):
156160 'Return a > b. Computed by @total_ordering from (a >= b) and (a != b).'
157- op_result = self .__ge__ (other )
161+ op_result = type ( self ) .__ge__ (self , other )
158162 if op_result is NotImplemented :
159163 return op_result
160164 return op_result and self != other
161165
162- def _lt_from_ge (self , other , NotImplemented = NotImplemented ):
166+ def _lt_from_ge (self , other ):
163167 'Return a < b. Computed by @total_ordering from (not a >= b).'
164- op_result = self .__ge__ (other )
168+ op_result = type ( self ) .__ge__ (self , other )
165169 if op_result is NotImplemented :
166170 return op_result
167171 return not op_result
@@ -232,14 +236,14 @@ def __ge__(self, other):
232236
233237def reduce (function , sequence , initial = _initial_missing ):
234238 """
235- reduce(function, sequence [, initial]) -> value
239+ reduce(function, iterable [, initial]) -> value
236240
237- Apply a function of two arguments cumulatively to the items of a sequence,
238- from left to right, so as to reduce the sequence to a single value.
239- For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates
241+ Apply a function of two arguments cumulatively to the items of a sequence
242+ or iterable, from left to right, so as to reduce the iterable to a single
243+ value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates
240244 ((((1+2)+3)+4)+5). If initial is present, it is placed before the items
241- of the sequence in the calculation, and serves as a default when the
242- sequence is empty.
245+ of the iterable in the calculation, and serves as a default when the
246+ iterable is empty.
243247 """
244248
245249 it = iter (sequence )
@@ -248,7 +252,8 @@ def reduce(function, sequence, initial=_initial_missing):
248252 try :
249253 value = next (it )
250254 except StopIteration :
251- raise TypeError ("reduce() of empty sequence with no initial value" ) from None
255+ raise TypeError (
256+ "reduce() of empty iterable with no initial value" ) from None
252257 else :
253258 value = initial
254259
@@ -347,23 +352,7 @@ class partialmethod(object):
347352 callables as instance methods.
348353 """
349354
350- def __init__ (* args , ** keywords ):
351- if len (args ) >= 2 :
352- self , func , * args = args
353- elif not args :
354- raise TypeError ("descriptor '__init__' of partialmethod "
355- "needs an argument" )
356- elif 'func' in keywords :
357- func = keywords .pop ('func' )
358- self , * args = args
359- import warnings
360- warnings .warn ("Passing 'func' as keyword argument is deprecated" ,
361- DeprecationWarning , stacklevel = 2 )
362- else :
363- raise TypeError ("type 'partialmethod' takes at least one argument, "
364- "got %d" % (len (args )- 1 ))
365- args = tuple (args )
366-
355+ def __init__ (self , func , / , * args , ** keywords ):
367356 if not callable (func ) and not hasattr (func , "__get__" ):
368357 raise TypeError ("{!r} is not callable or a descriptor"
369358 .format (func ))
@@ -381,7 +370,6 @@ def __init__(*args, **keywords):
381370 self .func = func
382371 self .args = args
383372 self .keywords = keywords
384- __init__ .__text_signature__ = '($self, func, /, *args, **keywords)'
385373
386374 def __repr__ (self ):
387375 args = ", " .join (map (repr , self .args ))
@@ -427,6 +415,7 @@ def __isabstractmethod__(self):
427415
428416 __class_getitem__ = classmethod (GenericAlias )
429417
418+
430419# Helper functions
431420
432421def _unwrap_partial (func ):
@@ -503,7 +492,7 @@ def lru_cache(maxsize=128, typed=False):
503492 with f.cache_info(). Clear the cache and statistics with f.cache_clear().
504493 Access the underlying function with f.__wrapped__.
505494
506- See: http ://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
495+ See: https ://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
507496
508497 """
509498
@@ -520,13 +509,15 @@ def lru_cache(maxsize=128, typed=False):
520509 # The user_function was passed in directly via the maxsize argument
521510 user_function , maxsize = maxsize , 128
522511 wrapper = _lru_cache_wrapper (user_function , maxsize , typed , _CacheInfo )
512+ wrapper .cache_parameters = lambda : {'maxsize' : maxsize , 'typed' : typed }
523513 return update_wrapper (wrapper , user_function )
524514 elif maxsize is not None :
525515 raise TypeError (
526516 'Expected first argument to be an integer, a callable, or None' )
527517
528518 def decorating_function (user_function ):
529519 wrapper = _lru_cache_wrapper (user_function , maxsize , typed , _CacheInfo )
520+ wrapper .cache_parameters = lambda : {'maxsize' : maxsize , 'typed' : typed }
530521 return update_wrapper (wrapper , user_function )
531522
532523 return decorating_function
@@ -653,14 +644,23 @@ def cache_clear():
653644 pass
654645
655646
647+ ################################################################################
648+ ### cache -- simplified access to the infinity cache
649+ ################################################################################
650+
651+ def cache (user_function , / ):
652+ 'Simple lightweight unbounded cache. Sometimes called "memoize".'
653+ return lru_cache (maxsize = None )(user_function )
654+
655+
656656################################################################################
657657### singledispatch() - single-dispatch generic function decorator
658658################################################################################
659659
660660def _c3_merge (sequences ):
661661 """Merges MROs in *sequences* to a single MRO using the C3 algorithm.
662662
663- Adapted from http ://www.python.org/download/releases/2.3/mro/.
663+ Adapted from https ://www.python.org/download/releases/2.3/mro/.
664664
665665 """
666666 result = []
@@ -740,6 +740,7 @@ def _compose_mro(cls, types):
740740 # Remove entries which are already present in the __mro__ or unrelated.
741741 def is_related (typ ):
742742 return (typ not in bases and hasattr (typ , '__mro__' )
743+ and not isinstance (typ , GenericAlias )
743744 and issubclass (cls , typ ))
744745 types = [n for n in types if is_related (n )]
745746 # Remove entries which are strict bases of other entries (they will end up
@@ -837,16 +838,33 @@ def dispatch(cls):
837838 dispatch_cache [cls ] = impl
838839 return impl
839840
841+ def _is_union_type (cls ):
842+ from typing import get_origin , Union
843+ return get_origin (cls ) in {Union , types .UnionType }
844+
845+ def _is_valid_dispatch_type (cls ):
846+ if isinstance (cls , type ):
847+ return True
848+ from typing import get_args
849+ return (_is_union_type (cls ) and
850+ all (isinstance (arg , type ) for arg in get_args (cls )))
851+
840852 def register (cls , func = None ):
841853 """generic_func.register(cls, func) -> func
842854
843855 Registers a new implementation for the given *cls* on a *generic_func*.
844856
845857 """
846858 nonlocal cache_token
847- if func is None :
848- if isinstance ( cls , type ) :
859+ if _is_valid_dispatch_type ( cls ) :
860+ if func is None :
849861 return lambda f : register (cls , f )
862+ else :
863+ if func is not None :
864+ raise TypeError (
865+ f"Invalid first argument to `register()`. "
866+ f"{ cls !r} is not a class or union type."
867+ )
850868 ann = getattr (cls , '__annotations__' , {})
851869 if not ann :
852870 raise TypeError (
@@ -859,12 +877,25 @@ def register(cls, func=None):
859877 # only import typing if annotation parsing is necessary
860878 from typing import get_type_hints
861879 argname , cls = next (iter (get_type_hints (func ).items ()))
862- if not isinstance (cls , type ):
863- raise TypeError (
864- f"Invalid annotation for { argname !r} . "
865- f"{ cls !r} is not a class."
866- )
867- registry [cls ] = func
880+ if not _is_valid_dispatch_type (cls ):
881+ if _is_union_type (cls ):
882+ raise TypeError (
883+ f"Invalid annotation for { argname !r} . "
884+ f"{ cls !r} not all arguments are classes."
885+ )
886+ else :
887+ raise TypeError (
888+ f"Invalid annotation for { argname !r} . "
889+ f"{ cls !r} is not a class."
890+ )
891+
892+ if _is_union_type (cls ):
893+ from typing import get_args
894+
895+ for arg in get_args (cls ):
896+ registry [arg ] = func
897+ else :
898+ registry [cls ] = func
868899 if cache_token is None and hasattr (cls , '__abstractmethods__' ):
869900 cache_token = get_cache_token ()
870901 dispatch_cache .clear ()
0 commit comments