Skip to content

Commit d906f87

Browse files
committed
Update functools lib, tests and Add singledispatch
1 parent f51b130 commit d906f87

File tree

2 files changed

+439
-72
lines changed

2 files changed

+439
-72
lines changed

Lib/functools.py

Lines changed: 93 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
# See C source code for _functools credits/copyright
1111

1212
__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
13-
'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
14-
'partialmethod', 'singledispatch', 'singledispatchmethod',
15-
"cached_property"]
13+
'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce',
14+
'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod',
15+
'cached_property']
1616

1717
from abc import get_cache_token
1818
from collections import namedtuple
@@ -86,82 +86,86 @@ def wraps(wrapped,
8686
# infinite recursion that could occur when the operator dispatch logic
8787
# detects a NotImplemented result and then calls a reflected method.
8888

89-
def _gt_from_lt(self, other, NotImplemented=NotImplemented):
89+
def _gt_from_lt(self, other):
9090
'Return a > b. Computed by @total_ordering from (not a < b) and (a != b).'
91-
op_result = self.__lt__(other)
91+
op_result = type(self).__lt__(self, other)
9292
if op_result is NotImplemented:
9393
return op_result
9494
return not op_result and self != other
9595

96-
def _le_from_lt(self, other, NotImplemented=NotImplemented):
96+
def _le_from_lt(self, other):
9797
'Return a <= b. Computed by @total_ordering from (a < b) or (a == b).'
98-
op_result = self.__lt__(other)
98+
op_result = type(self).__lt__(self, other)
99+
if op_result is NotImplemented:
100+
return op_result
99101
return op_result or self == other
100102

101-
def _ge_from_lt(self, other, NotImplemented=NotImplemented):
103+
def _ge_from_lt(self, other):
102104
'Return a >= b. Computed by @total_ordering from (not a < b).'
103-
op_result = self.__lt__(other)
105+
op_result = type(self).__lt__(self, other)
104106
if op_result is NotImplemented:
105107
return op_result
106108
return not op_result
107109

108-
def _ge_from_le(self, other, NotImplemented=NotImplemented):
110+
def _ge_from_le(self, other):
109111
'Return a >= b. Computed by @total_ordering from (not a <= b) or (a == b).'
110-
op_result = self.__le__(other)
112+
op_result = type(self).__le__(self, other)
111113
if op_result is NotImplemented:
112114
return op_result
113115
return not op_result or self == other
114116

115-
def _lt_from_le(self, other, NotImplemented=NotImplemented):
117+
def _lt_from_le(self, other):
116118
'Return a < b. Computed by @total_ordering from (a <= b) and (a != b).'
117-
op_result = self.__le__(other)
119+
op_result = type(self).__le__(self, other)
118120
if op_result is NotImplemented:
119121
return op_result
120122
return op_result and self != other
121123

122-
def _gt_from_le(self, other, NotImplemented=NotImplemented):
124+
def _gt_from_le(self, other):
123125
'Return a > b. Computed by @total_ordering from (not a <= b).'
124-
op_result = self.__le__(other)
126+
op_result = type(self).__le__(self, other)
125127
if op_result is NotImplemented:
126128
return op_result
127129
return not op_result
128130

129-
def _lt_from_gt(self, other, NotImplemented=NotImplemented):
131+
def _lt_from_gt(self, other):
130132
'Return a < b. Computed by @total_ordering from (not a > b) and (a != b).'
131-
op_result = self.__gt__(other)
133+
op_result = type(self).__gt__(self, other)
132134
if op_result is NotImplemented:
133135
return op_result
134136
return not op_result and self != other
135137

136-
def _ge_from_gt(self, other, NotImplemented=NotImplemented):
138+
def _ge_from_gt(self, other):
137139
'Return a >= b. Computed by @total_ordering from (a > b) or (a == b).'
138-
op_result = self.__gt__(other)
140+
op_result = type(self).__gt__(self, other)
141+
if op_result is NotImplemented:
142+
return op_result
139143
return op_result or self == other
140144

141-
def _le_from_gt(self, other, NotImplemented=NotImplemented):
145+
def _le_from_gt(self, other):
142146
'Return a <= b. Computed by @total_ordering from (not a > b).'
143-
op_result = self.__gt__(other)
147+
op_result = type(self).__gt__(self, other)
144148
if op_result is NotImplemented:
145149
return op_result
146150
return not op_result
147151

148-
def _le_from_ge(self, other, NotImplemented=NotImplemented):
152+
def _le_from_ge(self, other):
149153
'Return a <= b. Computed by @total_ordering from (not a >= b) or (a == b).'
150-
op_result = self.__ge__(other)
154+
op_result = type(self).__ge__(self, other)
151155
if op_result is NotImplemented:
152156
return op_result
153157
return not op_result or self == other
154158

155-
def _gt_from_ge(self, other, NotImplemented=NotImplemented):
159+
def _gt_from_ge(self, other):
156160
'Return a > b. Computed by @total_ordering from (a >= b) and (a != b).'
157-
op_result = self.__ge__(other)
161+
op_result = type(self).__ge__(self, other)
158162
if op_result is NotImplemented:
159163
return op_result
160164
return op_result and self != other
161165

162-
def _lt_from_ge(self, other, NotImplemented=NotImplemented):
166+
def _lt_from_ge(self, other):
163167
'Return a < b. Computed by @total_ordering from (not a >= b).'
164-
op_result = self.__ge__(other)
168+
op_result = type(self).__ge__(self, other)
165169
if op_result is NotImplemented:
166170
return op_result
167171
return not op_result
@@ -232,14 +236,14 @@ def __ge__(self, other):
232236

233237
def reduce(function, sequence, initial=_initial_missing):
234238
"""
235-
reduce(function, sequence[, initial]) -> value
239+
reduce(function, iterable[, initial]) -> value
236240
237-
Apply a function of two arguments cumulatively to the items of a sequence,
238-
from left to right, so as to reduce the sequence to a single value.
239-
For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates
241+
Apply a function of two arguments cumulatively to the items of a sequence
242+
or iterable, from left to right, so as to reduce the iterable to a single
243+
value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates
240244
((((1+2)+3)+4)+5). If initial is present, it is placed before the items
241-
of the sequence in the calculation, and serves as a default when the
242-
sequence is empty.
245+
of the iterable in the calculation, and serves as a default when the
246+
iterable is empty.
243247
"""
244248

245249
it = iter(sequence)
@@ -248,7 +252,8 @@ def reduce(function, sequence, initial=_initial_missing):
248252
try:
249253
value = next(it)
250254
except StopIteration:
251-
raise TypeError("reduce() of empty sequence with no initial value") from None
255+
raise TypeError(
256+
"reduce() of empty iterable with no initial value") from None
252257
else:
253258
value = initial
254259

@@ -347,23 +352,7 @@ class partialmethod(object):
347352
callables as instance methods.
348353
"""
349354

350-
def __init__(*args, **keywords):
351-
if len(args) >= 2:
352-
self, func, *args = args
353-
elif not args:
354-
raise TypeError("descriptor '__init__' of partialmethod "
355-
"needs an argument")
356-
elif 'func' in keywords:
357-
func = keywords.pop('func')
358-
self, *args = args
359-
import warnings
360-
warnings.warn("Passing 'func' as keyword argument is deprecated",
361-
DeprecationWarning, stacklevel=2)
362-
else:
363-
raise TypeError("type 'partialmethod' takes at least one argument, "
364-
"got %d" % (len(args)-1))
365-
args = tuple(args)
366-
355+
def __init__(self, func, /, *args, **keywords):
367356
if not callable(func) and not hasattr(func, "__get__"):
368357
raise TypeError("{!r} is not callable or a descriptor"
369358
.format(func))
@@ -381,7 +370,6 @@ def __init__(*args, **keywords):
381370
self.func = func
382371
self.args = args
383372
self.keywords = keywords
384-
__init__.__text_signature__ = '($self, func, /, *args, **keywords)'
385373

386374
def __repr__(self):
387375
args = ", ".join(map(repr, self.args))
@@ -427,6 +415,7 @@ def __isabstractmethod__(self):
427415

428416
__class_getitem__ = classmethod(GenericAlias)
429417

418+
430419
# Helper functions
431420

432421
def _unwrap_partial(func):
@@ -503,7 +492,7 @@ def lru_cache(maxsize=128, typed=False):
503492
with f.cache_info(). Clear the cache and statistics with f.cache_clear().
504493
Access the underlying function with f.__wrapped__.
505494
506-
See: http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
495+
See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
507496
508497
"""
509498

@@ -520,13 +509,15 @@ def lru_cache(maxsize=128, typed=False):
520509
# The user_function was passed in directly via the maxsize argument
521510
user_function, maxsize = maxsize, 128
522511
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
512+
wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
523513
return update_wrapper(wrapper, user_function)
524514
elif maxsize is not None:
525515
raise TypeError(
526516
'Expected first argument to be an integer, a callable, or None')
527517

528518
def decorating_function(user_function):
529519
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
520+
wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
530521
return update_wrapper(wrapper, user_function)
531522

532523
return decorating_function
@@ -653,14 +644,23 @@ def cache_clear():
653644
pass
654645

655646

647+
################################################################################
648+
### cache -- simplified access to the infinity cache
649+
################################################################################
650+
651+
def cache(user_function, /):
652+
'Simple lightweight unbounded cache. Sometimes called "memoize".'
653+
return lru_cache(maxsize=None)(user_function)
654+
655+
656656
################################################################################
657657
### singledispatch() - single-dispatch generic function decorator
658658
################################################################################
659659

660660
def _c3_merge(sequences):
661661
"""Merges MROs in *sequences* to a single MRO using the C3 algorithm.
662662
663-
Adapted from http://www.python.org/download/releases/2.3/mro/.
663+
Adapted from https://www.python.org/download/releases/2.3/mro/.
664664
665665
"""
666666
result = []
@@ -740,6 +740,7 @@ def _compose_mro(cls, types):
740740
# Remove entries which are already present in the __mro__ or unrelated.
741741
def is_related(typ):
742742
return (typ not in bases and hasattr(typ, '__mro__')
743+
and not isinstance(typ, GenericAlias)
743744
and issubclass(cls, typ))
744745
types = [n for n in types if is_related(n)]
745746
# Remove entries which are strict bases of other entries (they will end up
@@ -837,16 +838,33 @@ def dispatch(cls):
837838
dispatch_cache[cls] = impl
838839
return impl
839840

841+
def _is_union_type(cls):
842+
from typing import get_origin, Union
843+
return get_origin(cls) in {Union, types.UnionType}
844+
845+
def _is_valid_dispatch_type(cls):
846+
if isinstance(cls, type):
847+
return True
848+
from typing import get_args
849+
return (_is_union_type(cls) and
850+
all(isinstance(arg, type) for arg in get_args(cls)))
851+
840852
def register(cls, func=None):
841853
"""generic_func.register(cls, func) -> func
842854
843855
Registers a new implementation for the given *cls* on a *generic_func*.
844856
845857
"""
846858
nonlocal cache_token
847-
if func is None:
848-
if isinstance(cls, type):
859+
if _is_valid_dispatch_type(cls):
860+
if func is None:
849861
return lambda f: register(cls, f)
862+
else:
863+
if func is not None:
864+
raise TypeError(
865+
f"Invalid first argument to `register()`. "
866+
f"{cls!r} is not a class or union type."
867+
)
850868
ann = getattr(cls, '__annotations__', {})
851869
if not ann:
852870
raise TypeError(
@@ -859,12 +877,25 @@ def register(cls, func=None):
859877
# only import typing if annotation parsing is necessary
860878
from typing import get_type_hints
861879
argname, cls = next(iter(get_type_hints(func).items()))
862-
if not isinstance(cls, type):
863-
raise TypeError(
864-
f"Invalid annotation for {argname!r}. "
865-
f"{cls!r} is not a class."
866-
)
867-
registry[cls] = func
880+
if not _is_valid_dispatch_type(cls):
881+
if _is_union_type(cls):
882+
raise TypeError(
883+
f"Invalid annotation for {argname!r}. "
884+
f"{cls!r} not all arguments are classes."
885+
)
886+
else:
887+
raise TypeError(
888+
f"Invalid annotation for {argname!r}. "
889+
f"{cls!r} is not a class."
890+
)
891+
892+
if _is_union_type(cls):
893+
from typing import get_args
894+
895+
for arg in get_args(cls):
896+
registry[arg] = func
897+
else:
898+
registry[cls] = func
868899
if cache_token is None and hasattr(cls, '__abstractmethods__'):
869900
cache_token = get_cache_token()
870901
dispatch_cache.clear()

0 commit comments

Comments
 (0)