Skip to content

Commit a62ba8c

Browse files
[ty] Fix overloaded callable assignability for unary Callable targets (#23277)
## Summary Fixes astral-sh/ty#2546 Improves assignability checking from overloaded callables to a single `Callable[...]` target with an explicit union parameter domain. Previously, this was handled with per-overload `when_any` checks. This PR replaces that with an aggregate probe over the overload set that: - filters down to overlapping overload arms, - unions their parameter domains and return types, - checks parameter coverage and return compatibility against the target callable. The aggregate probe is accept-only ie. if it isn't definitively satisfied, we fall back to the existing `when_any` behavior. This change is intentionally scoped to unary targets with explicit union domains and excludes dynamic/typevar candidates. General `n > 1` overload-set assignability is a way larger problem left for later. ## Test Plan - Added/updated mdtests for explicit [ #2546 ](astral-sh/ty#2546) repros and negative cases (missing domain coverage, incompatible return union) in legacy generic-callables. - Added mdtests for overloaded generic-callable argument handling in legacy and PEP 695 callables, including `Callable[[T], T]` under union-constrained inference. - Added dataclass-transform converter coverage (`overloaded_converter`, `ConverterClass`) with an explicit TODO expectation for the still unhandled `converter=dict` case. - Added a reduced SymPy one-import MRE to lock the overload/protocol panic shape. - Updated Liskov tests for unannotated overrides of overloaded dunder methods. --------- Co-authored-by: Douglas Creager <dcreager@dcreager.net>
1 parent e5f2f36 commit a62ba8c

9 files changed

Lines changed: 562 additions & 10 deletions

File tree

crates/ty_python_semantic/resources/mdtest/annotations/callable.md

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,100 @@ def _(c: Callable[[int, str], int]):
187187
reveal_type(c) # revealed: (int, str, /) -> int
188188
```
189189

190+
## Overloaded callable assignability
191+
192+
An overloaded callable should be assignable to a non-overloaded callable type when the overload set
193+
as a whole is compatible with the target callable.
194+
195+
```py
196+
from typing import Callable, overload
197+
198+
@overload
199+
def foo(x: int) -> str: ...
200+
@overload
201+
def foo(x: str) -> str: ...
202+
def foo(x: int | str) -> str:
203+
return str(x)
204+
205+
def expects(c: Callable[[int | str], str]) -> None:
206+
pass
207+
208+
expects(foo)
209+
```
210+
211+
```py
212+
from typing import overload
213+
214+
@overload
215+
def foo(x: int) -> str: ...
216+
@overload
217+
def foo(x: str) -> str: ...
218+
def foo(x: int | str) -> str:
219+
return str(x)
220+
221+
def errors() -> None:
222+
for x in map(foo, range(1, 10)):
223+
print(x)
224+
```
225+
226+
```py
227+
from typing import Callable, overload
228+
229+
@overload
230+
def converter(x: int) -> str: ...
231+
@overload
232+
def converter(x: bytes) -> bytes: ...
233+
def converter(x: int | bytes) -> str | bytes:
234+
if isinstance(x, int):
235+
return str(x)
236+
return x
237+
238+
def expects_int_str(c: Callable[[int], str]) -> None:
239+
pass
240+
241+
expects_int_str(converter)
242+
```
243+
244+
The overload set must cover the full target parameter domain.
245+
246+
```py
247+
from typing import Callable, overload
248+
249+
@overload
250+
def partial_converter(x: int) -> str: ...
251+
@overload
252+
def partial_converter(x: bytes) -> str: ...
253+
def partial_converter(x: int | bytes) -> str:
254+
return str(x)
255+
256+
def expects_int_or_str(c: Callable[[int | str], str]) -> None:
257+
pass
258+
259+
# error: [invalid-argument-type]
260+
expects_int_or_str(partial_converter)
261+
```
262+
263+
Even when the parameter domain is covered, return compatibility must still hold.
264+
265+
```py
266+
from typing import Callable, overload
267+
268+
@overload
269+
def wide_return_converter(x: int) -> str: ...
270+
@overload
271+
def wide_return_converter(x: str) -> bytes: ...
272+
def wide_return_converter(x: int | str) -> str | bytes:
273+
if isinstance(x, int):
274+
return str(x)
275+
return x.encode()
276+
277+
def expects_str_return(c: Callable[[int | str], str]) -> None:
278+
pass
279+
280+
# error: [invalid-argument-type]
281+
expects_str_return(wide_return_converter)
282+
```
283+
190284
## Union
191285

192286
```py

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,46 @@ class Person:
894894
reveal_type(Person.__init__) # revealed: (self: Person, name: str, *, age: int | None) -> None
895895
```
896896

897+
### Converter field specifier with overloaded callables
898+
899+
```py
900+
from typing import Callable, TypeVar, overload
901+
from typing_extensions import dataclass_transform
902+
903+
T = TypeVar("T")
904+
S = TypeVar("S")
905+
906+
def model_field(*, converter: Callable[[S], T], default: S | None = None) -> T:
907+
raise NotImplementedError
908+
909+
@dataclass_transform(field_specifiers=(model_field,))
910+
class ModelBase: ...
911+
912+
@overload
913+
def overloaded_converter(s: str) -> int: ...
914+
@overload
915+
def overloaded_converter(s: list[str]) -> int: ...
916+
def overloaded_converter(s: str | list[str], *args: str) -> int | str:
917+
return 0
918+
919+
class ConverterClass:
920+
@overload
921+
def __init__(self, val: str) -> None: ...
922+
@overload
923+
def __init__(self, val: bytes) -> None: ...
924+
def __init__(self, val: str | bytes) -> None:
925+
pass
926+
927+
class Model(ModelBase):
928+
field3: ConverterClass = model_field(converter=ConverterClass)
929+
field4: int = model_field(converter=overloaded_converter)
930+
# TODO: This should be accepted once overloaded class callables with richer signatures are
931+
# modeled in callable assignability.
932+
# error: [invalid-assignment]
933+
# error: [invalid-argument-type]
934+
field5: dict[str, str] = model_field(converter=dict, default=())
935+
```
936+
897937
### Nested dataclass-transformers
898938

899939
Make sure that models are only affected by the field specifiers of their own transformer:

crates/ty_python_semantic/resources/mdtest/generics/legacy/callables.md

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,155 @@ reveal_type(generic_context(outside_callable(int_identity)))
245245
# error: [invalid-argument-type]
246246
outside_callable(int_identity)("string")
247247
```
248+
249+
## Overloaded callable as generic `Callable` argument
250+
251+
An overloaded callable should be assignable to a non-overloaded callable type when the overload set
252+
as a whole is compatible with the target callable.
253+
254+
The type variable should be inferred from the first matching overload, rather than unioning
255+
parameter types across all overloads (which would create an unsatisfiable expected type for
256+
contravariant type variables).
257+
258+
```py
259+
from typing import Callable, TypeVar, overload
260+
261+
T = TypeVar("T")
262+
263+
def accepts_callable(converter: Callable[[T], None]) -> T:
264+
raise NotImplementedError
265+
266+
@overload
267+
def f(val: str) -> None: ...
268+
@overload
269+
def f(val: bytes) -> None: ...
270+
def f(val: str | bytes) -> None:
271+
pass
272+
273+
reveal_type(accepts_callable(f)) # revealed: str | bytes
274+
```
275+
276+
When `T` is constrained to a union by other arguments, the overloaded callable must still be treated
277+
as a whole to satisfy `Callable[[T], T]`.
278+
279+
```py
280+
from typing import Callable, TypeVar, overload
281+
282+
T = TypeVar("T")
283+
284+
def apply_twice(converter: Callable[[T], T], left: T, right: T) -> tuple[T, T]:
285+
return converter(left), converter(right)
286+
287+
@overload
288+
def f(val: int) -> int: ...
289+
@overload
290+
def f(val: str) -> str: ...
291+
def f(val: int | str) -> int | str:
292+
return val
293+
294+
x: int | str = 1
295+
y: int | str = "a"
296+
297+
result = apply_twice(f, x, y)
298+
# revealed: tuple[int | str, int | str]
299+
reveal_type(result)
300+
```
301+
302+
An overloaded callable returned from a generic callable factory should still be assignable to the
303+
declared generic callable return type.
304+
305+
```py
306+
from collections.abc import Callable, Coroutine
307+
from typing import Any, TypeVar, overload
308+
309+
S = TypeVar("S")
310+
T = TypeVar("T")
311+
U = TypeVar("U")
312+
313+
def singleton(flag: bool = False) -> Callable[[Callable[[int], S]], Callable[[int], S]]:
314+
@overload
315+
def wrapper(func: Callable[[int], Coroutine[Any, Any, T]]) -> Callable[[int], Coroutine[Any, Any, T]]: ...
316+
@overload
317+
def wrapper(func: Callable[[int], U]) -> Callable[[int], U]: ...
318+
def wrapper(func: Callable[[int], Coroutine[Any, Any, T] | U]) -> Callable[[int], Coroutine[Any, Any, T] | U]:
319+
return func
320+
321+
return wrapper
322+
```
323+
324+
## SymPy one-import MRE scaffold (multi-file)
325+
326+
Reduced regression lock for a SymPy overload/protocol shape that can panic in the
327+
overload-assignability path.
328+
329+
```py
330+
from __future__ import annotations
331+
332+
from sympy.polys.compatibility import Domain, IPolys
333+
from typing import Generic, TypeVar, overload
334+
335+
T = TypeVar("T")
336+
337+
class DefaultPrinting:
338+
pass
339+
340+
class PolyRing(DefaultPrinting, IPolys[T], Generic[T]):
341+
symbols: tuple[object, ...]
342+
domain: Domain[T]
343+
344+
def clone(
345+
self,
346+
symbols: object | None = None,
347+
domain: object | None = None,
348+
order: object | None = None,
349+
) -> PolyRing[T]:
350+
return self
351+
352+
@overload
353+
def __getitem__(self, key: int) -> PolyRing[T]: ...
354+
@overload
355+
def __getitem__(self, key: slice) -> PolyRing[T] | Domain[T]: ...
356+
def __getitem__(self, key: slice | int) -> PolyRing[T] | Domain[T]:
357+
symbols = self.symbols[key]
358+
if not symbols:
359+
return self.domain
360+
return self.clone(symbols=symbols)
361+
362+
def takes_ring(x: PolyRing[int]) -> None:
363+
reveal_type(x[0]) # revealed: PolyRing[int]
364+
reveal_type(x[:]) # revealed: PolyRing[int] | Domain[int]
365+
```
366+
367+
`sympy/polys/compatibility.pyi`:
368+
369+
```pyi
370+
from __future__ import annotations
371+
372+
from typing import Generic, Protocol, TypeVar, overload
373+
374+
T = TypeVar("T")
375+
S = TypeVar("S")
376+
377+
class Domain(Generic[T]): ...
378+
379+
class IPolys(Protocol[T]):
380+
@overload
381+
def clone(
382+
self,
383+
symbols: object | None = None,
384+
domain: None = None,
385+
order: None = None,
386+
) -> IPolys[T]: ...
387+
@overload
388+
def clone(
389+
self,
390+
symbols: object | None = None,
391+
*,
392+
domain: Domain[S],
393+
order: None = None,
394+
) -> IPolys[S]: ...
395+
@overload
396+
def __getitem__(self, key: int) -> IPolys[T]: ...
397+
@overload
398+
def __getitem__(self, key: slice) -> IPolys[T] | Domain[T]: ...
399+
```

0 commit comments

Comments
 (0)