from __future__ import annotations import keyword import re from pathlib import Path from jinja2 import Environment, FileSystemLoader, StrictUndefined from .extensions import GeneratorExtensions from .model import ( AnyAnnotation, DictAnnotation, EnumDef, FieldDef, GeneratedArtifact, ListAnnotation, LiteralAnnotation, MappingAnnotation, NamedAnnotation, NormalizedSpec, OperationDef, TupleAnnotation, TypeAliasDef, TypeAnnotation, TypedDictDef, UnionAnnotation, ) def _render_annotation(annotation: TypeAnnotation) -> str: match annotation: case AnyAnnotation(): return "Any" case DictAnnotation(key, value): return f"dict[{_render_annotation(key)}, {_render_annotation(value)}]" case ListAnnotation(item): return f"list[{_render_annotation(item)}]" case MappingAnnotation(key, value): return f"Mapping[{_render_annotation(key)}, {_render_annotation(value)}]" case LiteralAnnotation(values): return f"Literal[{', '.join(repr(value) for value in values)}]" case NamedAnnotation(name): return name case TupleAnnotation(items): return f"tuple[{', '.join(_render_annotation(item) for item in items)}]" case UnionAnnotation(items): return " | ".join(_render_annotation(item) for item in items) or "Any" def _field_annotation(field: FieldDef) -> str: """ Jinja2 filter to format a FieldDef's annotation. """ annotation = repr(_render_field_annotation(field)) if not field.required: annotation = f"NotRequired[{annotation}]" return annotation def _class_field_annotation(field: FieldDef, total_optional: bool) -> str: annotation = _render_field_annotation(field) if not field.required and not total_optional: annotation = f"NotRequired[{annotation}]" return annotation def _render_field_annotation(field: FieldDef) -> str: annotation = _render_annotation(field.annotation) if field.description is None: return annotation return f"Annotated[{annotation}, _openapi_python_field({field.description!r})]" def _string_literal(value: str) -> str: return repr(value) def _comment(value: str, spaces: int = 0) -> str: prefix = " " * spaces return "\n".join( f"{prefix}# {line}" if line else f"{prefix}#" for line in value.splitlines() ) def _supports_typeddict_class_syntax(defn: TypedDictDef) -> bool: return all( field.name.isidentifier() and not keyword.iskeyword(field.name) and not field.name.startswith("__") for field in defn.fields ) _TEMPLATE_DIR = Path(__file__).with_name("templates") _JINJA_ENV = Environment( loader=FileSystemLoader(_TEMPLATE_DIR), trim_blocks=True, lstrip_blocks=True, undefined=StrictUndefined, ) _JINJA_ENV.filters["repr"] = repr _JINJA_ENV.filters["annotation"] = _render_annotation _JINJA_ENV.filters["field_annotation"] = _field_annotation _JINJA_ENV.filters["class_field_annotation"] = _class_field_annotation _JINJA_ENV.filters["comment"] = _comment _JINJA_ENV.filters["string_literal"] = _string_literal def _render_template(name: str, **context: object) -> str: return _JINJA_ENV.get_template(name).render(**context) def _indent(text: str, spaces: int = 4) -> str: prefix = " " * spaces return "\n".join((prefix + line) if line else "" for line in text.splitlines()) def _format_typeddict(defn: TypedDictDef) -> str: total_optional = bool(defn.fields) and all( not field.required for field in defn.fields ) return _render_template( "typeddict.py.j2", defn=defn, class_syntax=_supports_typeddict_class_syntax(defn), total_optional=total_optional, ) def _format_alias(alias: TypeAliasDef) -> str: return _render_template("alias.py.j2", alias=alias) def _enum_member_name(value: object) -> str: text = str(value).upper() text = re.sub(r"[^a-zA-Z0-9_]", "_", text).strip("_") or "VALUE" text = re.sub(r"_+", "_", text) if text[0].isdigit(): text = f"_{text}" if keyword.iskeyword(text.lower()): text = f"{text}_" return text def _enum_base(defn: EnumDef) -> str: if all(isinstance(value, str) for value in defn.values): return "str, Enum" if all( isinstance(value, int) and not isinstance(value, bool) for value in defn.values ): return "int, Enum" return "Enum" def _format_enum(defn: EnumDef) -> str: members: list[tuple[str, object]] = [] used: set[str] = set() for value in defn.values: name = _enum_member_name(value) base_name = name index = 2 while name in used: name = f"{base_name}_{index}" index += 1 used.add(name) members.append((name, value)) return _render_template( "enum.py.j2", defn=defn, enum_base=_enum_base(defn), members=members, ) def _annotation_dependencies(annotation: TypeAnnotation, names: set[str]) -> set[str]: match annotation: case AnyAnnotation() | LiteralAnnotation(): return set() case DictAnnotation(key, value): return _annotation_dependencies(key, names) | _annotation_dependencies( value, names ) case MappingAnnotation(key, value): return _annotation_dependencies(key, names) | _annotation_dependencies( value, names ) case ListAnnotation(item): return _annotation_dependencies(item, names) case NamedAnnotation(name): return {name} if name in names else set() case TupleAnnotation(items) | UnionAnnotation(items): dependencies: set[str] = set() for item in items: dependencies.update(_annotation_dependencies(item, names)) return dependencies def _typed_dict_dependencies(defn: TypedDictDef, names: set[str]) -> set[str]: dependencies: set[str] = set() for field in defn.fields: dependencies.update(_annotation_dependencies(field.annotation, names)) dependencies.discard(defn.name) return dependencies def _alias_dependencies(defn: TypeAliasDef, names: set[str]) -> set[str]: dependencies = _annotation_dependencies(defn.annotation, names) dependencies.discard(defn.name) return dependencies def _type_dependencies(defn: TypeAliasDef | TypedDictDef, names: set[str]) -> set[str]: match defn: case TypeAliasDef(): return _alias_dependencies(defn, names) case TypedDictDef(): return _typed_dict_dependencies(defn, names) def _order_type_definitions( aliases: tuple[TypeAliasDef, ...], typed_dicts: tuple[TypedDictDef, ...] ) -> list[TypeAliasDef | TypedDictDef]: by_name: dict[str, TypeAliasDef | TypedDictDef] = { item.name: item for item in (*aliases, *typed_dicts) } names = set(by_name) ordered: list[TypeAliasDef | TypedDictDef] = [] temporary: set[str] = set() permanent: set[str] = set() def visit(name: str) -> None: if name in permanent or name in temporary: return temporary.add(name) for dependency in sorted(_type_dependencies(by_name[name], names)): visit(dependency) temporary.remove(name) permanent.add(name) ordered.append(by_name[name]) for name in sorted(by_name): visit(name) return ordered def _format_type_definition(defn: TypeAliasDef | TypedDictDef) -> str: match defn: case TypeAliasDef(): return _format_alias(defn) case TypedDictDef(): return _format_typeddict(defn) def _has_field_descriptions(defns: tuple[TypedDictDef, ...]) -> bool: return any(field.description is not None for defn in defns for field in defn.fields) def _call_parameters(op: OperationDef, *, generate_requests: bool) -> dict[str, str]: if not generate_requests: return { "params": "params: dict[str, Any] | None = None", "query": "query: dict[str, Any] | None = None", "headers": "headers: dict[str, Any] | None = None", "body": "body: object | None = None", } params = "params: " + _render_annotation(op.params_type) if not op.params_required: params += " | None = None" query = "query: " + _render_annotation(op.query_type) if not op.query_required: query += " | None = None" headers = "headers: " + _render_annotation(op.headers_type) if not op.headers_required: headers += " | None = None" body = "body: object | None = None" if op.body_type is not None: body = f"body: {_render_annotation(op.body_type)}" if not op.body_required: body += " | None = None" return { "params": params, "query": query, "headers": headers, "body": body, } def _protocol_name(op: OperationDef, *, is_async: bool = False) -> str: return f"Async{op.protocol_name}" if is_async else op.protocol_name def _protocol_block( op: OperationDef, *, generate_requests: bool, generate_responses: bool, is_async: bool = False, ) -> str: return _render_template( "protocol.py.j2", op=op, is_async=is_async, protocol_name=_protocol_name(op, is_async=is_async), call_parameters=_call_parameters(op, generate_requests=generate_requests), response_type=( _render_annotation(op.response_type) if generate_responses else "Any" ), ) def _method_overload_line( op: OperationDef, *, return_type: str, is_async: bool = False ) -> str: return _render_template( "method_overload.py.j2", op=op, return_type=return_type, protocol_name=_protocol_name(op, is_async=is_async), ) def _fallback_method_block( method: str, overloads: list[str], *, is_async: bool = False ) -> str: return _render_template( "method_block.py.j2", method=method, overloads="\n".join(overloads), callable_return="Awaitable[Any]" if is_async else "object", call_return="Any" if is_async else "object", is_async=is_async, ) def _included_annotations( spec: NormalizedSpec, *, generate_requests: bool, generate_responses: bool ) -> tuple[TypeAnnotation, ...]: """ Returns the set of type annotations that be included in the rendered output. """ roots: list[TypeAnnotation] = [] for op in spec.operations: if generate_requests: roots.extend((op.params_type, op.query_type, op.headers_type)) if op.body_type is not None: roots.append(op.body_type) if generate_responses: roots.append(op.response_type) return tuple(roots) def _used_type_names( spec: NormalizedSpec, *, generate_requests: bool, generate_responses: bool ) -> set[str]: """ Returns the set of type definition names that are transitively referenced by the client protocols. """ by_name: dict[str, TypeAliasDef | TypedDictDef | EnumDef] = { item.name: item for item in (*spec.aliases, *spec.typed_dicts, *spec.enums) } all_names = set(by_name) used: set[str] = set() pending: list[str] = [] for annotation in _included_annotations( spec, generate_requests=generate_requests, generate_responses=generate_responses, ): pending.extend(_annotation_dependencies(annotation, all_names) - used) while pending: name = pending.pop() if name in used: continue used.add(name) item = by_name[name] match item: case TypeAliasDef() | TypedDictDef(): pending.extend(_type_dependencies(item, all_names) - used) case EnumDef(): pass return used def _render_types( spec: NormalizedSpec, *, generate_routes: bool, generate_requests: bool, generate_responses: bool, ) -> str: route_aliases = _route_aliases(spec, generate_routes=generate_routes) used_names = _used_type_names( spec, generate_requests=generate_requests, generate_responses=generate_responses, ) aliases = tuple(item for item in spec.aliases if item.name in used_names) typed_dicts = tuple(item for item in spec.typed_dicts if item.name in used_names) enums = tuple(item for item in spec.enums if item.name in used_names) type_definitions = _order_type_definitions((*route_aliases, *aliases), typed_dicts) blocks = [_format_enum(item) for item in enums] + [ _format_type_definition(item) for item in type_definitions ] return _render_template( "types.py.j2", has_field_descriptions=_has_field_descriptions(typed_dicts), type_blocks="\n".join(blocks).strip() + "\n", ) def rendered_type_definition_count( spec: NormalizedSpec, *, generate_routes: bool, generate_requests: bool, generate_responses: bool, ) -> int: used_names = _used_type_names( spec, generate_requests=generate_requests, generate_responses=generate_responses, ) return ( len(_route_aliases(spec, generate_routes=generate_routes)) + sum(1 for item in spec.aliases if item.name in used_names) + sum(1 for item in spec.typed_dicts if item.name in used_names) + sum(1 for item in spec.enums if item.name in used_names) ) def _literal_annotation(values: set[str]) -> LiteralAnnotation: return LiteralAnnotation(tuple(sorted(values))) def _route_aliases( spec: NormalizedSpec, *, generate_routes: bool ) -> tuple[TypeAliasDef, ...]: if not generate_routes: return (TypeAliasDef(name="RouteLiteral", annotation=NamedAnnotation("str")),) routes_by_method: dict[str, set[str]] = {} for op in spec.operations: routes_by_method.setdefault(op.method.upper(), set()).add(op.route_literal) aliases = [ TypeAliasDef( name=f"{method}_RouteLiteral", annotation=_literal_annotation(routes) ) for method, routes in sorted(routes_by_method.items()) ] if aliases: aliases.append( TypeAliasDef( name="RouteLiteral", annotation=UnionAnnotation( tuple(NamedAnnotation(alias.name) for alias in aliases) ), ) ) else: aliases.append( TypeAliasDef(name="RouteLiteral", annotation=NamedAnnotation("str")) ) return tuple(aliases) def _render_transport(spec: NormalizedSpec, *, protocol_only: bool) -> str: return _render_template( "transport.py.j2", typing_imports=("Protocol" if protocol_only else "TYPE_CHECKING, Protocol"), include_default_transport=not protocol_only, ) def _render_client( spec: NormalizedSpec, *, protocol_only: bool, generate_routes: bool, generate_requests: bool, generate_responses: bool, ) -> str: protocols: list[str] = [] async_protocols: list[str] = [] method_overloads: dict[str, list[str]] = {} async_method_overloads: dict[str, list[str]] = {} generate_operation_protocols = generate_routes and ( generate_requests or generate_responses ) for op in spec.operations: if generate_operation_protocols: protocols.append( _protocol_block( op, generate_requests=generate_requests, generate_responses=generate_responses, ) ) async_protocols.append( _protocol_block( op, generate_requests=generate_requests, generate_responses=generate_responses, is_async=True, ) ) method_overloads.setdefault(op.method, []).append( _method_overload_line( op, return_type=_protocol_name(op), is_async=False ) ) async_method_overloads.setdefault(op.method, []).append( _method_overload_line( op, return_type=_protocol_name(op, is_async=True), is_async=True, ) ) else: overloads = method_overloads.setdefault(op.method, []) async_overloads = async_method_overloads.setdefault(op.method, []) if generate_routes: overloads.append( _method_overload_line( op, return_type="Callable[..., object]", is_async=False ) ) async_overloads.append( _method_overload_line( op, return_type="Callable[..., Awaitable[Any]]", is_async=True ) ) method_blocks: list[str] = [] for method in sorted(method_overloads): method_blocks.append( _fallback_method_block( method, method_overloads[method], ) ) async_method_blocks: list[str] = [] for method in sorted(async_method_overloads): async_method_blocks.append( _fallback_method_block( method, async_method_overloads[method], is_async=True, ) ) if not protocol_only: transport_imports = ( "from .transport import AsyncTransport, DefaultAsyncTransport, " "DefaultTransport, Transport" ) sync_transport_type = "Transport | None = None" sync_transport_assignment = "transport or DefaultTransport()" async_transport_type = "AsyncTransport | None = None" async_transport_assignment = "transport or DefaultAsyncTransport()" else: transport_imports = "from .transport import AsyncTransport, Transport" sync_transport_type = "Transport" sync_transport_assignment = "transport" async_transport_type = "AsyncTransport" async_transport_assignment = "transport" return _render_template( "client.py.j2", transport_imports=transport_imports, sync_transport_type=sync_transport_type, sync_transport_assignment=sync_transport_assignment, async_transport_type=async_transport_type, async_transport_assignment=async_transport_assignment, protocol_blocks="\n".join(protocols).strip() + "\n", async_protocol_blocks="\n".join(async_protocols).strip() + "\n", method_blocks=_indent("\n".join(method_blocks).strip() + "\n"), async_method_blocks=_indent("\n".join(async_method_blocks).strip() + "\n"), ) def render_package( spec: NormalizedSpec, extensions: GeneratorExtensions | None = None, *, protocol_only: bool = False, generate_routes: bool = True, generate_requests: bool = True, generate_responses: bool = True, ) -> list[GeneratedArtifact]: context = { "types": _render_types( spec, generate_routes=generate_routes, generate_requests=generate_requests, generate_responses=generate_responses, ), "transport": _render_transport(spec, protocol_only=protocol_only), "client": _render_client( spec, protocol_only=protocol_only, generate_routes=generate_routes, generate_requests=generate_requests, generate_responses=generate_responses, ), } if extensions: for hook in extensions.render_context_hooks: context = hook(spec, context) init_content = _render_template( "init.py.j2", include_default_transport=not protocol_only ) return [ GeneratedArtifact( relative_path=f"{spec.package_name}/__init__.py", content=init_content ), GeneratedArtifact( relative_path=f"{spec.package_name}/types.py", content=context["types"] ), GeneratedArtifact( relative_path=f"{spec.package_name}/transport.py", content=context["transport"], ), GeneratedArtifact( relative_path=f"{spec.package_name}/client.py", content=context["client"] ), GeneratedArtifact(relative_path=f"{spec.package_name}/py.typed", content="\n"), ]