Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
FIx TypedDict required and optional keys inheritance
  • Loading branch information
vemel committed Feb 7, 2020
commit ebfe32148fc3b01110c75ec76e3f499d4ddc64ad
35 changes: 22 additions & 13 deletions typing_extensions/src_py3/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,23 +1651,32 @@ def __new__(cls, name, bases, ns, total=True):
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns)

anns = ns.get('__annotations__', {})
annotations = {}
own_annotations = ns.get('__annotations__', {})
own_annotation_keys = set(annotations.keys())
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
anns = {n: typing._type_check(tp, msg) for n, tp in anns.items()}
required = set(anns if total else ())
optional = set(() if total else anns)
own_annotations = {
n: typing._type_check(tp, msg) for n, tp in own_annotations.items()
}
required_keys = set()
optional_keys = set()

for base in bases:
base_anns = base.__dict__.get('__annotations__', {})
anns.update(base_anns)
if getattr(base, '__total__', True):
required.update(base_anns)
else:
optional.update(base_anns)
annotations.update(base.__dict__.get('__annotations__', {}))
required_keys.update(base.__dict__.get('__required_keys__', ()))
optional_keys.update(base.__dict__.get('__optional_keys__', ()))

annotations.update(own_annotations)
if total:
required_keys.update(own_annotation_keys)
optional_keys.difference_update(own_annotation_keys)
else:
optional_keys.update(own_annotation_keys)
required_keys.difference_update(own_annotation_keys)

tp_dict.__annotations__ = anns
tp_dict.__required_keys__ = frozenset(required)
tp_dict.__optional_keys__ = frozenset(optional)
tp_dict.__annotations__ = annotations
tp_dict.__required_keys__ = frozenset(required_keys)
tp_dict.__optional_keys__ = frozenset(optional_keys)
if not hasattr(tp_dict, '__total__'):
tp_dict.__total__ = total
return tp_dict
Expand Down