11"""Post-process datamodel-codegen output to fix known issues and prune the TypedDict file.
22
33Applied to `_models.py`:
4- - Fix discriminator field names that use camelCase instead of snake_case (known issue with
5- discriminators on schemas referenced from array items).
4+ - Fix discriminator field names that use camelCase instead of snake_case (known issue with discriminators on schemas
5+ referenced from array items).
66- Rewrite every `class X(StrEnum)` as `X = Literal[...]` so downstream code can pass plain strings
77 (and reuse the named alias in resource-client signatures) instead of enum members.
8- - Move the resulting `X = Literal[...]` definitions into `_literals.py`, leaving
9- `_models.py` importing them — so consumers can depend on a dedicated literals module
10- without pulling in every Pydantic model.
8+ - Move the resulting `X = Literal[...]` definitions into `_literals.py`, leaving `_models.py` importing them — so
9+ consumers can depend on a dedicated literals module without pulling in every Pydantic model.
1110- Add `@docs_group('Models')` to every model class (plus the required import).
1211
1312Applied to `_typeddicts.py`:
14- - Keep only the TypedDicts actually used as resource-client method inputs (plus their transitive
15- dependencies). The file is generated in full by datamodel-codegen; the trimming happens here.
13+ - Keep only the TypedDicts actually used as resource-client method inputs (plus their transitive dependencies).
14+ The file is generated in full by datamodel-codegen; the trimming happens here.
1615- Rename every kept class to add a `Dict` suffix so it doesn't clash with the Pydantic model name
1716 (e.g. `WebhookCreate` -> `WebhookCreateDict`) and rewire references.
17+ - Generate a camelCase sibling for every kept TypedDict (`FooDict` -> `FooCamelDict`) so users can pass API-shaped
18+ dicts and still satisfy the type checker. Field identifiers are looked up in the Pydantic alias map extracted
19+ from `_models.py`; nested TypedDict refs are rewired to the camel variant.
1820- Add `@docs_group('Typed dicts')` to every kept class.
1921"""
2022
@@ -391,6 +393,194 @@ def rename_with_dict_suffix(content: str, names: set[str]) -> str:
391393 return content
392394
393395
def _extract_alias_from_field_call(field_call: ast.Call) -> str | None:
    """Look up the string literal passed as `alias=` to a `Field(...)` call.

    Only a keyword argument named `alias` whose value is a string constant counts. A missing `alias` keyword,
    a non-literal expression, or a non-string constant all yield None.
    """
    string_aliases = (
        kwarg.value.value
        for kwarg in field_call.keywords
        if kwarg.arg == 'alias' and isinstance(kwarg.value, ast.Constant) and isinstance(kwarg.value.value, str)
    )
    # First qualifying keyword wins; None when there is none.
    return next(string_aliases, None)
402+
403+
def _find_field_alias(root: ast.expr) -> str | None:
    """Return the first string `alias=` found on a `Field(...)` call anywhere under `root`, or None.

    Both the bare `Field(...)` spelling and attribute-qualified calls such as `pydantic.Field(...)` are recognised.
    `Field` calls that carry no usable alias are skipped and the search continues.
    """
    for sub in ast.walk(root):
        if not isinstance(sub, ast.Call):
            continue
        func = sub.func
        is_field_call = (isinstance(func, ast.Name) and func.id == 'Field') or (
            isinstance(func, ast.Attribute) and func.attr == 'Field'
        )
        if not is_field_call:
            continue
        found = _extract_alias_from_field_call(sub)
        if found is not None:
            return found
    return None


def _extract_class_field_aliases(class_node: ast.ClassDef) -> dict[str, str]:
    """Return `{snake_field: api_field}` for every annotated field declared on `class_node`.

    Fields without a `Field(alias=...)` map to themselves (their declared Python name matches the API name — typical
    for single-word fields like `url`, `id`).

    The alias is looked up in the annotation first (`name: Annotated[T, Field(alias=...)]`) and, only when the
    annotation carries none, in the default value (`name: T = Field(alias=...)`). `model_config` and statements
    that are not simple annotated names are ignored.
    """
    aliases: dict[str, str] = {}
    for stmt in class_node.body:
        if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
            continue
        field_name = stmt.target.id
        if field_name == 'model_config':
            continue
        # Annotation goes first so the `Annotated[...]` form keeps its existing precedence; the default value is
        # only a fallback for declarations that put `Field(...)` on the right-hand side.
        search_roots = [stmt.annotation] if stmt.value is None else [stmt.annotation, stmt.value]
        # Default: no alias means snake name == API name.
        api_name = field_name
        for root in search_roots:
            found = _find_field_alias(root)
            if found is not None:
                api_name = found
                break
        aliases[field_name] = api_name
    return aliases
428+
429+
def build_alias_map(models_source: str) -> dict[str, dict[str, str]]:
    """Build `{ModelName: {snake_field: api_field}}` from the source text of the models module.

    Every class defined at the top level of `models_source` gets an entry; its value is whatever
    `_extract_class_field_aliases` reports for that class (explicit `Field(alias=...)` overrides as well as
    fields whose Python name is already the API name). The result feeds the camelCase TypedDict synthesis, which
    needs the exact API spelling of each field.
    """
    alias_map: dict[str, dict[str, str]] = {}
    for top_level in ast.parse(models_source).body:
        if isinstance(top_level, ast.ClassDef):
            alias_map[top_level.name] = _extract_class_field_aliases(top_level)
    return alias_map
439+
440+
def _camel_dict_name(snake_name: str) -> str:
    """Derive the camel sibling's name by putting `Camel` in front of the trailing `Dict`.

    Example: `RequestDict` -> `RequestCamelDict`.

    Raises:
        ValueError: If `snake_name` does not end with `Dict`.
    """
    suffix = 'Dict'
    if snake_name.endswith(suffix):
        stem = snake_name[: -len(suffix)]
        return f'{stem}Camel{suffix}'
    raise ValueError(f"Expected name to end with 'Dict': {snake_name!r}")
446+
447+
def _is_dict_str_any(node: ast.expr) -> bool:
    """Return True if `node` is exactly a `dict[str, Any]` subscript (casing-agnostic open mapping).

    The key argument must be the name `str` and the value argument must be `Any` (bare, or attribute-qualified
    such as `typing.Any`). Any other `dict[...]` parametrisation — e.g. `dict[str, FooDict]` — returns False so
    callers do not mistake a typed mapping for an open one and drop its value type.
    """
    if not (isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name) and node.value.id == 'dict'):
        return False
    # On Python 3.9+ the subscript arguments of `dict[K, V]` are an `ast.Tuple` stored directly in `slice`.
    args = node.slice
    if not isinstance(args, ast.Tuple) or len(args.elts) != 2:
        return False
    key, value = args.elts
    key_is_str = isinstance(key, ast.Name) and key.id == 'str'
    value_is_any = (isinstance(value, ast.Name) and value.id == 'Any') or (
        isinstance(value, ast.Attribute) and value.attr == 'Any'
    )
    return key_is_str and value_is_any
451+
452+
def _class_field_line_indexes(block: list[str]) -> set[int] | None:
    """Return the 0-based indexes of lines in `block` that start a class-level annotated field.

    The block is parsed as a standalone module; every `name: annotation` statement directly inside a top-level
    class contributes the line its target name sits on. Returns None when the block is not parseable on its own
    (e.g. an indented fragment), signalling the caller to fall back to purely textual matching.
    """
    try:
        tree = ast.parse('\n'.join(block))
    except SyntaxError:
        return None
    return {
        stmt.lineno - 1
        for node in tree.body
        if isinstance(node, ast.ClassDef)
        for stmt in node.body
        if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)
    }


def _rename_fields_in_class_block(block: list[str], field_aliases: dict[str, str]) -> list[str]:
    """Rewrite each field declaration line in `block` using `field_aliases`.

    Matches lines of the form `<indent><snake_ident>:` and substitutes the identifier when `field_aliases` maps it
    to a different name. Everything after the colon is kept verbatim, so multi-line annotations and trailing
    default values are preserved.

    When the block parses as Python, only lines that really start a class-level annotated field are candidates;
    docstring prose or annotation continuation lines that merely look like `name: ...` are left untouched. If
    the block cannot be parsed on its own, every line matching the pattern is treated as a candidate.

    Returns a new list; `block` itself is not modified.
    """
    field_decl = re.compile(r'^(\s+)([a-z_][a-z0-9_]*)(\s*:)')
    field_lines = _class_field_line_indexes(block)
    out: list[str] = []
    for index, line in enumerate(block):
        m = field_decl.match(line)
        if m is None or (field_lines is not None and index not in field_lines):
            out.append(line)
            continue
        indent, name, colon = m.group(1), m.group(2), m.group(3)
        api_name = field_aliases.get(name)
        if api_name is None or api_name == name:
            # Unknown field, or the API spelling already equals the Python name.
            out.append(line)
            continue
        out.append(f'{indent}{api_name}{colon}{line[m.end():]}')
    return out
474+
475+
def _rename_typeddict_refs_in_block(block: list[str], rename_set: set[str]) -> list[str]:
    """Replace whole-word occurrences of every name in `rename_set` with its camel-variant name.

    The lines are joined into one string before substituting and split again afterwards, so a reference that sits
    on a continuation line of a wrapped annotation is handled like any other. An empty `rename_set` returns
    `block` unchanged (the same list object).
    """
    if not rename_set:
        return block
    joined = '\n'.join(block)
    # Sorted iteration keeps the output independent of set ordering / hash seeds; the `\b` anchors make sure a
    # name is never rewritten inside a longer identifier.
    for original_name in sorted(rename_set):
        camel_name = _camel_dict_name(original_name)
        whole_word = re.compile(rf'\b{re.escape(original_name)}\b')
        joined = whole_word.sub(camel_name, joined)
    return joined.split('\n')
490+
491+
def add_camel_case_typeddicts(content: str, alias_map: dict[str, dict[str, str]]) -> str:
    """Emit a `<Name>CamelDict` sibling right after every `<Name>Dict` class and `TypeAlias` in `content`.

    What gets cloned:
    - Top-level classes whose name ends with `Dict`: the class block is copied, references to other `*Dict`
      symbols are rewired to their camel names, and field identifiers are renamed via `alias_map[<Name>]`.
    - `<Name>Dict: TypeAlias = ...` assignments: the statement is copied with references rewired.
    - Aliases recognised by `_is_dict_str_any` get a one-line `<Name>CamelDict: TypeAlias = dict[str, Any]`
      sibling, so camel TypedDicts that reference them still resolve.

    Symbols already ending in `CamelDict`, and symbols whose camel sibling is already present in `content`, are
    skipped, which keeps repeated runs from adding duplicates.

    Args:
        content: Source text of the TypedDict module.
        alias_map: `{ModelName: {snake_field: api_field}}`, keyed by the class name without its `Dict` suffix.

    Returns:
        The source with camel siblings inserted, passed through `_collapse_blank_lines`.
    """
    module = ast.parse(content)
    source_lines = content.split('\n')

    symbol_spans = _extract_top_level_symbols(module)
    end_by_name: dict[str, int] = {symbol: stop for symbol, _, stop in symbol_spans}
    existing_symbols: set[str] = set(end_by_name)

    def is_snake_dict(symbol: str) -> bool:
        # Snake-side symbol: carries the `Dict` suffix but is not itself a camel sibling.
        return symbol.endswith('Dict') and not symbol.endswith('CamelDict')

    def wants_sibling(symbol: str) -> bool:
        # Skip symbols whose camel sibling is already in the file (re-run case).
        return is_snake_dict(symbol) and _camel_dict_name(symbol) not in existing_symbols

    # Step 1: collect what has to be cloned. Line bounds are 0-based slice bounds into `source_lines`.
    class_spans: list[tuple[str, int, int]] = []  # class name, first line, end (exclusive)
    alias_spans: list[tuple[int, int]] = []  # first line, end (exclusive)
    trivial_aliases: list[tuple[int, str]] = []  # end, alias name

    for node in module.body:
        if isinstance(node, ast.ClassDef):
            # The `Dict` suffix alone decides whether a class is cloned; its bases are not inspected.
            if wants_sibling(node.name):
                first = node.lineno - 1
                # Pull in a decorator line sitting directly above the `class` statement.
                if first > 0 and source_lines[first - 1].lstrip().startswith('@'):
                    first -= 1
                stop = end_by_name.get(node.name, node.end_lineno or node.lineno)
                class_spans.append((node.name, first, stop))
            continue

        is_type_alias = (
            isinstance(node, ast.AnnAssign)
            and isinstance(node.target, ast.Name)
            and isinstance(node.annotation, ast.Name)
            and node.annotation.id == 'TypeAlias'
        )
        if not is_type_alias:
            continue
        alias_name = node.target.id
        if node.value is None or not wants_sibling(alias_name):
            continue
        first = node.lineno - 1
        stop = end_by_name.get(alias_name, node.end_lineno or node.lineno)
        if _is_dict_str_any(node.value):
            trivial_aliases.append((stop, alias_name))
        else:
            alias_spans.append((first, stop))

    # References are rewired for every snake-side `*Dict` symbol in the file, not only the ones cloned in this
    # run, so a cloned block points at camel names even when the referenced sibling already existed.
    rename_set: set[str] = {symbol for symbol in existing_symbols if is_snake_dict(symbol)}

    # Step 2: build the camel blocks, each preceded by a blank separator line.
    insertions: list[tuple[int, list[str]]] = []

    for class_name, first, stop in class_spans:
        rewired = _rename_typeddict_refs_in_block(source_lines[first:stop], rename_set)
        field_aliases = alias_map.get(class_name[: -len('Dict')], {})
        insertions.append((stop, ['', *_rename_fields_in_class_block(rewired, field_aliases)]))

    for first, stop in alias_spans:
        rewired = _rename_typeddict_refs_in_block(source_lines[first:stop], rename_set)
        insertions.append((stop, ['', *rewired]))

    for stop, alias_name in trivial_aliases:
        insertions.append((stop, ['', f'{_camel_dict_name(alias_name)}: TypeAlias = dict[str, Any]']))

    # Step 3: splice from the bottom of the file upwards so pending insertion points are not shifted.
    insertions.sort(key=lambda entry: entry[0], reverse=True)
    result = list(source_lines)
    for insert_at, new_block in insertions:
        result[insert_at:insert_at] = new_block

    return _collapse_blank_lines('\n'.join(result))
582+
583+
394584def postprocess_models (models_path : Path , literals_path : Path ) -> list [Path ]:
395585 """Apply `_models.py`-specific fixes and emit `_literals.py`.
396586
@@ -414,13 +604,14 @@ def postprocess_models(models_path: Path, literals_path: Path) -> list[Path]:
414604 return changed
415605
416606
417- def postprocess_typeddicts (path : Path ) -> bool :
607+ def postprocess_typeddicts (path : Path , alias_map : dict [ str , dict [ str , str ]] ) -> bool :
418608 """Apply `_typeddicts.py`-specific fixes. Returns True if the file changed."""
419609 original = path .read_text ()
420610 pruned , kept = prune_typeddicts (original , RESOURCE_INPUT_TYPEDDICTS )
421611 renamed = rename_with_dict_suffix (pruned , kept )
422612 flattened = flatten_empty_typeddicts (renamed )
423- final = add_docs_group_decorators (flattened , 'Typed dicts' )
613+ camelized = add_camel_case_typeddicts (flattened , alias_map )
614+ final = add_docs_group_decorators (camelized , 'Typed dicts' )
424615 if final == original :
425616 return False
426617 path .write_text (final )
@@ -442,9 +633,10 @@ def main() -> None:
442633 else :
443634 print ('No fixes needed for _models.py / _literals.py' )
444635
445- if postprocess_typeddicts (TYPEDDICTS_PATH ):
636+ alias_map = build_alias_map (MODELS_PATH .read_text ())
637+ if postprocess_typeddicts (TYPEDDICTS_PATH , alias_map ):
446638 changed .append (TYPEDDICTS_PATH )
447- print (f'Pruned and renamed TypedDicts in { TYPEDDICTS_PATH } ' )
639+ print (f'Pruned, renamed, and camelized TypedDicts in { TYPEDDICTS_PATH } ' )
448640 else :
449641 print ('No fixes needed for _typeddicts.py' )
450642
0 commit comments