Skip to content

Commit 5a02150

Browse files
authored
Fix required and optional keys inheritance for TypedDict (#700)
(For a complete description of the issue see python/typing#700.)
1 parent 4fe38bc commit 5a02150

2 files changed

Lines changed: 54 additions & 13 deletions

File tree

src_py3/test_typing_extensions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,16 @@ class LabelPoint2D(Point2D, Label): ...
434434
class Options(TypedDict, total=False):
435435
log_level: int
436436
log_path: str
437+
438+
class BaseAnimal(TypedDict):
439+
name: str
440+
441+
class Animal(BaseAnimal, total=False):
442+
voice: str
443+
tail: bool
444+
445+
class Cat(Animal):
446+
fur_color: str
437447
"""
438448

439449
if PY36:
@@ -444,6 +454,7 @@ class Options(TypedDict, total=False):
444454
A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object
445455
XMeth = XRepr = HasCallProtocol = NoneAndForward = Loop = object
446456
Point2D = Point2Dor3D = LabelPoint2D = Options = object
457+
BaseAnimal = Animal = Cat = object
447458

448459
gth = get_type_hints
449460

@@ -1549,6 +1560,29 @@ def test_optional_keys(self):
15491560
assert Point2Dor3D.__required_keys__ == frozenset(['x', 'y'])
15501561
assert Point2Dor3D.__optional_keys__ == frozenset(['z'])
15511562

1563+
@skipUnless(PY36, 'Python 3.6 required')
1564+
def test_keys_inheritance(self):
1565+
assert BaseAnimal.__required_keys__ == frozenset(['name'])
1566+
assert BaseAnimal.__optional_keys__ == frozenset([])
1567+
assert BaseAnimal.__annotations__ == {'name': str}
1568+
1569+
assert Animal.__required_keys__ == frozenset(['name'])
1570+
assert Animal.__optional_keys__ == frozenset(['tail', 'voice'])
1571+
assert Animal.__annotations__ == {
1572+
'name': str,
1573+
'tail': bool,
1574+
'voice': str,
1575+
}
1576+
1577+
assert Cat.__required_keys__ == frozenset(['name', 'fur_color'])
1578+
assert Cat.__optional_keys__ == frozenset(['tail', 'voice'])
1579+
assert Cat.__annotations__ == {
1580+
'fur_color': str,
1581+
'name': str,
1582+
'tail': bool,
1583+
'voice': str,
1584+
}
1585+
15521586

15531587
@skipUnless(TYPING_3_5_3, "Python >= 3.5.3 required")
15541588
class AnnotatedTests(BaseTestCase):

src_py3/typing_extensions.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,23 +1651,30 @@ def __new__(cls, name, bases, ns, total=True):
16511651
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
16521652
tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns)
16531653

1654-
anns = ns.get('__annotations__', {})
1654+
annotations = {}
1655+
own_annotations = ns.get('__annotations__', {})
1656+
own_annotation_keys = set(own_annotations.keys())
16551657
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
1656-
anns = {n: typing._type_check(tp, msg) for n, tp in anns.items()}
1657-
required = set(anns if total else ())
1658-
optional = set(() if total else anns)
1658+
own_annotations = {
1659+
n: typing._type_check(tp, msg) for n, tp in own_annotations.items()
1660+
}
1661+
required_keys = set()
1662+
optional_keys = set()
16591663

16601664
for base in bases:
1661-
base_anns = base.__dict__.get('__annotations__', {})
1662-
anns.update(base_anns)
1663-
if getattr(base, '__total__', True):
1664-
required.update(base_anns)
1665-
else:
1666-
optional.update(base_anns)
1665+
annotations.update(base.__dict__.get('__annotations__', {}))
1666+
required_keys.update(base.__dict__.get('__required_keys__', ()))
1667+
optional_keys.update(base.__dict__.get('__optional_keys__', ()))
1668+
1669+
annotations.update(own_annotations)
1670+
if total:
1671+
required_keys.update(own_annotation_keys)
1672+
else:
1673+
optional_keys.update(own_annotation_keys)
16671674

1668-
tp_dict.__annotations__ = anns
1669-
tp_dict.__required_keys__ = frozenset(required)
1670-
tp_dict.__optional_keys__ = frozenset(optional)
1675+
tp_dict.__annotations__ = annotations
1676+
tp_dict.__required_keys__ = frozenset(required_keys)
1677+
tp_dict.__optional_keys__ = frozenset(optional_keys)
16711678
if not hasattr(tp_dict, '__total__'):
16721679
tp_dict.__total__ = total
16731680
return tp_dict

0 commit comments

Comments
 (0)