Skip to content

Commit b2eff4a

Browse files
committed
add missing __class_getitem__ from cpython 3.10
Signed-off-by: snowapril <sinjihng@gmail.com>
1 parent 32ba09c commit b2eff4a

File tree

15 files changed

+203
-0
lines changed

15 files changed

+203
-0
lines changed

Lib/_collections_abc.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
from abc import ABCMeta, abstractmethod
1010
import sys
1111

12+
GenericAlias = type(list[int])
13+
EllipsisType = type(...)
14+
def _f(): pass
15+
FunctionType = type(_f)
16+
del _f
17+
1218
__all__ = ["Awaitable", "Coroutine",
1319
"AsyncIterable", "AsyncIterator", "AsyncGenerator",
1420
"Hashable", "Iterable", "Iterator", "Generator", "Reversible",
@@ -110,6 +116,8 @@ def __subclasshook__(cls, C):
110116
return _check_methods(C, "__await__")
111117
return NotImplemented
112118

119+
__class_getitem__ = classmethod(GenericAlias)
120+
113121

114122
class Coroutine(Awaitable):
115123

@@ -169,6 +177,8 @@ def __subclasshook__(cls, C):
169177
return _check_methods(C, "__aiter__")
170178
return NotImplemented
171179

180+
__class_getitem__ = classmethod(GenericAlias)
181+
172182

173183
class AsyncIterator(AsyncIterable):
174184

@@ -255,6 +265,8 @@ def __subclasshook__(cls, C):
255265
return _check_methods(C, "__iter__")
256266
return NotImplemented
257267

268+
__class_getitem__ = classmethod(GenericAlias)
269+
258270

259271
class Iterator(Iterable):
260272

@@ -384,6 +396,8 @@ def __subclasshook__(cls, C):
384396
return _check_methods(C, "__contains__")
385397
return NotImplemented
386398

399+
__class_getitem__ = classmethod(GenericAlias)
400+
387401
class Collection(Sized, Iterable, Container):
388402

389403
__slots__ = ()
@@ -394,6 +408,141 @@ def __subclasshook__(cls, C):
394408
return _check_methods(C, "__len__", "__iter__", "__contains__")
395409
return NotImplemented
396410

411+
412+
class _CallableGenericAlias(GenericAlias):
413+
""" Represent `Callable[argtypes, resulttype]`.
414+
415+
This sets ``__args__`` to a tuple containing the flattened ``argtypes``
416+
followed by ``resulttype``.
417+
418+
Example: ``Callable[[int, str], float]`` sets ``__args__`` to
419+
``(int, str, float)``.
420+
"""
421+
422+
__slots__ = ()
423+
424+
def __new__(cls, origin, args):
425+
if not (isinstance(args, tuple) and len(args) == 2):
426+
raise TypeError(
427+
"Callable must be used as Callable[[arg, ...], result].")
428+
t_args, t_result = args
429+
if isinstance(t_args, list):
430+
args = (*t_args, t_result)
431+
elif not _is_param_expr(t_args):
432+
raise TypeError(f"Expected a list of types, an ellipsis, "
433+
f"ParamSpec, or Concatenate. Got {t_args}")
434+
return super().__new__(cls, origin, args)
435+
436+
@property
437+
def __parameters__(self):
438+
params = []
439+
for arg in self.__args__:
440+
# Looks like a genericalias
441+
if hasattr(arg, "__parameters__") and isinstance(arg.__parameters__, tuple):
442+
params.extend(arg.__parameters__)
443+
else:
444+
if _is_typevarlike(arg):
445+
params.append(arg)
446+
return tuple(dict.fromkeys(params))
447+
448+
def __repr__(self):
449+
if len(self.__args__) == 2 and _is_param_expr(self.__args__[0]):
450+
return super().__repr__()
451+
return (f'collections.abc.Callable'
452+
f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], '
453+
f'{_type_repr(self.__args__[-1])}]')
454+
455+
def __reduce__(self):
456+
args = self.__args__
457+
if not (len(args) == 2 and _is_param_expr(args[0])):
458+
args = list(args[:-1]), args[-1]
459+
return _CallableGenericAlias, (Callable, args)
460+
461+
def __getitem__(self, item):
462+
# Called during TypeVar substitution, returns the custom subclass
463+
# rather than the default types.GenericAlias object. Most of the
464+
# code is copied from typing's _GenericAlias and the builtin
465+
# types.GenericAlias.
466+
467+
# A special case in PEP 612 where if X = Callable[P, int],
468+
# then X[int, str] == X[[int, str]].
469+
param_len = len(self.__parameters__)
470+
if param_len == 0:
471+
raise TypeError(f'{self} is not a generic class')
472+
if not isinstance(item, tuple):
473+
item = (item,)
474+
if (param_len == 1 and _is_param_expr(self.__parameters__[0])
475+
and item and not _is_param_expr(item[0])):
476+
item = (list(item),)
477+
item_len = len(item)
478+
if item_len != param_len:
479+
raise TypeError(f'Too {"many" if item_len > param_len else "few"}'
480+
f' arguments for {self};'
481+
f' actual {item_len}, expected {param_len}')
482+
subst = dict(zip(self.__parameters__, item))
483+
new_args = []
484+
for arg in self.__args__:
485+
if _is_typevarlike(arg):
486+
if _is_param_expr(arg):
487+
arg = subst[arg]
488+
if not _is_param_expr(arg):
489+
raise TypeError(f"Expected a list of types, an ellipsis, "
490+
f"ParamSpec, or Concatenate. Got {arg}")
491+
else:
492+
arg = subst[arg]
493+
# Looks like a GenericAlias
494+
elif hasattr(arg, '__parameters__') and isinstance(arg.__parameters__, tuple):
495+
subparams = arg.__parameters__
496+
if subparams:
497+
subargs = tuple(subst[x] for x in subparams)
498+
arg = arg[subargs]
499+
new_args.append(arg)
500+
501+
# args[0] occurs due to things like Z[[int, str, bool]] from PEP 612
502+
if not isinstance(new_args[0], list):
503+
t_result = new_args[-1]
504+
t_args = new_args[:-1]
505+
new_args = (t_args, t_result)
506+
return _CallableGenericAlias(Callable, tuple(new_args))
507+
508+
509+
def _is_typevarlike(arg):
510+
obj = type(arg)
511+
# looks like a TypeVar/ParamSpec
512+
return (obj.__module__ == 'typing'
513+
and obj.__name__ in {'ParamSpec', 'TypeVar'})
514+
515+
def _is_param_expr(obj):
516+
"""Checks if obj matches either a list of types, ``...``, ``ParamSpec`` or
517+
``_ConcatenateGenericAlias`` from typing.py
518+
"""
519+
if obj is Ellipsis:
520+
return True
521+
if isinstance(obj, list):
522+
return True
523+
obj = type(obj)
524+
names = ('ParamSpec', '_ConcatenateGenericAlias')
525+
return obj.__module__ == 'typing' and any(obj.__name__ == name for name in names)
526+
527+
def _type_repr(obj):
528+
"""Return the repr() of an object, special-casing types (internal helper).
529+
530+
Copied from :mod:`typing` since collections.abc
531+
shouldn't depend on that module.
532+
"""
533+
if isinstance(obj, GenericAlias):
534+
return repr(obj)
535+
if isinstance(obj, type):
536+
if obj.__module__ == 'builtins':
537+
return obj.__qualname__
538+
return f'{obj.__module__}.{obj.__qualname__}'
539+
if obj is Ellipsis:
540+
return '...'
541+
if isinstance(obj, FunctionType):
542+
return obj.__name__
543+
return repr(obj)
544+
545+
397546
class Callable(metaclass=ABCMeta):
398547

399548
__slots__ = ()
@@ -408,6 +557,8 @@ def __subclasshook__(cls, C):
408557
return _check_methods(C, "__call__")
409558
return NotImplemented
410559

560+
__class_getitem__ = classmethod(_CallableGenericAlias)
561+
411562

412563
### SETS ###
413564

@@ -703,6 +854,8 @@ def __len__(self):
703854
def __repr__(self):
704855
return '{0.__class__.__name__}({0._mapping!r})'.format(self)
705856

857+
__class_getitem__ = classmethod(GenericAlias)
858+
706859

707860
class KeysView(MappingView, Set):
708861

Lib/_weakrefset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# by abc.py to load everything else at startup.
44

55
from _weakref import ref
6+
from types import GenericAlias
67

78
__all__ = ['WeakSet']
89

@@ -197,3 +198,5 @@ def isdisjoint(self, other):
197198

198199
def __repr__(self):
199200
return repr(self.data)
201+
202+
__class_getitem__ = classmethod(GenericAlias)

Lib/asyncio/futures.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ def __del__(self):
184184
context['source_traceback'] = self._source_traceback
185185
self._loop.call_exception_handler(context)
186186

187+
def __class_getitem__(cls, type):
188+
return cls
189+
187190
def cancel(self):
188191
"""Cancel the future and schedule callbacks.
189192

Lib/asyncio/queues.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def __repr__(self):
8181
def __str__(self):
8282
return '<{} {}>'.format(type(self).__name__, self._format())
8383

84+
def __class_getitem__(cls, type):
85+
return cls
86+
8487
def _format(self):
8588
result = 'maxsize={!r}'.format(self._maxsize)
8689
if getattr(self, '_queue', None):

Lib/concurrent/futures/_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import threading
99
import time
10+
import types
1011

1112
FIRST_COMPLETED = 'FIRST_COMPLETED'
1213
FIRST_EXCEPTION = 'FIRST_EXCEPTION'
@@ -544,6 +545,8 @@ def set_exception(self, exception):
544545
self._condition.notify_all()
545546
self._invoke_callbacks()
546547

548+
__class_getitem__ = classmethod(types.GenericAlias)
549+
547550
class Executor(object):
548551
"""This is an abstract base class for concrete asynchronous executors."""
549552

Lib/concurrent/futures/thread.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import itertools
1111
import queue
1212
import threading
13+
import types
1314
import weakref
1415
import os
1516

@@ -62,6 +63,8 @@ def run(self):
6263
else:
6364
self.future.set_result(result)
6465

66+
__class_getitem__ = classmethod(types.GenericAlias)
67+
6568

6669
def _worker(executor_reference, work_queue, initializer, initargs):
6770
if initializer is not None:

Lib/dataclasses.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import builtins
88
import functools
99
import _thread
10+
from types import GenericAlias
1011

1112

1213
__all__ = ['dataclass',
@@ -217,6 +218,8 @@ def __repr__(self):
217218
type_name = repr(self.type)
218219
return f'dataclasses.InitVar[{type_name}]'
219220

221+
def __class_getitem__(cls, type):
222+
return InitVar(type)
220223

221224
# Instances of Field are only ever created from within this module,
222225
# and only from the field() function, although Field instances are
@@ -285,6 +288,8 @@ def __set_name__(self, owner, name):
285288
# it.
286289
func(self.default, owner, name)
287290

291+
__class_getitem__ = classmethod(GenericAlias)
292+
288293

289294
class _DataclassParams:
290295
__slots__ = ('init',

Lib/difflib.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from heapq import nlargest as _nlargest
3434
from collections import namedtuple as _namedtuple
35+
from types import GenericAlias
3536

3637
Match = _namedtuple('Match', 'a b size')
3738

@@ -685,6 +686,8 @@ def real_quick_ratio(self):
685686
# shorter sequence
686687
return _calculate_ratio(min(la, lb), la + lb)
687688

689+
__class_getitem__ = classmethod(GenericAlias)
690+
688691
def get_close_matches(word, possibilities, n=3, cutoff=0.6):
689692
"""Use SequenceMatcher to return list of the best "good enough" matches.
690693

Lib/functools.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from _thread import RLock
2323
except ModuleNotFoundError:
2424
from _dummy_thread import RLock
25+
from types import GenericAlias
2526

2627

2728
################################################################################
@@ -427,6 +428,8 @@ def __get__(self, obj, cls=None):
427428
def __isabstractmethod__(self):
428429
return getattr(self.func, "__isabstractmethod__", False)
429430

431+
__class_getitem__ = classmethod(GenericAlias)
432+
430433
# Helper functions
431434

432435
def _unwrap_partial(func):
@@ -977,3 +980,5 @@ def __get__(self, instance, owner=None):
977980
)
978981
raise TypeError(msg) from None
979982
return val
983+
984+
__class_getitem__ = classmethod(GenericAlias)

Lib/multiprocessing/managers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import threading
1919
import array
2020
import queue
21+
import types
2122

2223
from time import time as _time
2324
from traceback import format_exc
@@ -1078,6 +1079,8 @@ def set(self, value):
10781079
return self._callmethod('set', (value,))
10791080
value = property(get, set)
10801081

1082+
__class_getitem__ = classmethod(types.GenericAlias)
1083+
10811084

10821085
BaseListProxy = MakeProxyType('BaseListProxy', (
10831086
'__add__', '__contains__', '__delitem__', '__getitem__', '__len__',

0 commit comments

Comments
 (0)