Skip to content

Commit 084eac6

Browse files
CaselITGerrit Code Review
authored andcommitted
Merge "improve overloads applied to generic functions" into main
2 parents 0c1824c + 5cc6a65 commit 084eac6

3 files changed

Lines changed: 68 additions & 57 deletions

File tree

lib/sqlalchemy/sql/functions.py

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This module is part of SQLAlchemy and is released under
66
# the MIT License: https://www.opensource.org/licenses/mit-license.php
77

8-
98
"""SQL function API, factories, and built-in functions."""
109

1110
from __future__ import annotations
@@ -153,7 +152,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
153152

154153
clause_expr: Grouping[Any]
155154

156-
def __init__(self, *clauses: _ColumnExpressionOrLiteralArgument[Any]):
155+
def __init__(
156+
self, *clauses: _ColumnExpressionOrLiteralArgument[Any]
157+
) -> None:
157158
r"""Construct a :class:`.FunctionElement`.
158159
159160
:param \*clauses: list of column expressions that form the arguments
@@ -777,7 +778,7 @@ def _gen_cache_key(self, anon_map: Any, bindparams: Any) -> Any:
777778

778779
def __init__(
779780
self, fn: FunctionElement[Any], left_index: int, right_index: int
780-
):
781+
) -> None:
781782
self.sql_function = fn
782783
self.left_index = left_index
783784
self.right_index = right_index
@@ -829,7 +830,7 @@ def __init__(
829830
fn: FunctionElement[_T],
830831
name: str,
831832
type_: Optional[_TypeEngineArgument[_T]] = None,
832-
):
833+
) -> None:
833834
self.fn = fn
834835
self.name = name
835836

@@ -928,7 +929,7 @@ class _FunctionGenerator:
928929
929930
""" # noqa
930931

931-
def __init__(self, **opts: Any):
932+
def __init__(self, **opts: Any) -> None:
932933
self.__names: List[str] = []
933934
self.opts = opts
934935

@@ -988,10 +989,10 @@ def aggregate_strings(self) -> Type[aggregate_strings]: ...
988989
@property
989990
def ansifunction(self) -> Type[AnsiFunction[Any]]: ...
990991

991-
# set ColumnElement[_T] as a separate overload, to appease mypy
992-
# which seems to not want to accept _T from _ColumnExpressionArgument.
993-
# this is even if all non-generic types are removed from it, so
994-
# reasons remain unclear for why this does not work
992+
# set ColumnElement[_T] as a separate overload, to appease
993+
# mypy which seems to not want to accept _T from
994+
# _ColumnExpressionArgument. Seems somewhat related to the covariant
995+
# _HasClauseElement as of mypy 1.15
995996

996997
@overload
997998
def array_agg(
@@ -1012,7 +1013,7 @@ def array_agg(
10121013
@overload
10131014
def array_agg(
10141015
self,
1015-
col: _ColumnExpressionOrLiteralArgument[_T],
1016+
col: _T,
10161017
*args: _ColumnExpressionOrLiteralArgument[Any],
10171018
**kwargs: Any,
10181019
) -> array_agg[_T]: ...
@@ -1030,10 +1031,10 @@ def cast(self) -> Type[Cast[Any]]: ...
10301031
@property
10311032
def char_length(self) -> Type[char_length]: ...
10321033

1033-
# set ColumnElement[_T] as a separate overload, to appease mypy
1034-
# which seems to not want to accept _T from _ColumnExpressionArgument.
1035-
# this is even if all non-generic types are removed from it, so
1036-
# reasons remain unclear for why this does not work
1034+
# set ColumnElement[_T] as a separate overload, to appease
1035+
# mypy which seems to not want to accept _T from
1036+
# _ColumnExpressionArgument. Seems somewhat related to the covariant
1037+
# _HasClauseElement as of mypy 1.15
10371038

10381039
@overload
10391040
def coalesce(
@@ -1054,7 +1055,7 @@ def coalesce(
10541055
@overload
10551056
def coalesce(
10561057
self,
1057-
col: _ColumnExpressionOrLiteralArgument[_T],
1058+
col: _T,
10581059
*args: _ColumnExpressionOrLiteralArgument[Any],
10591060
**kwargs: Any,
10601061
) -> coalesce[_T]: ...
@@ -1105,10 +1106,10 @@ def localtime(self) -> Type[localtime]: ...
11051106
@property
11061107
def localtimestamp(self) -> Type[localtimestamp]: ...
11071108

1108-
# set ColumnElement[_T] as a separate overload, to appease mypy
1109-
# which seems to not want to accept _T from _ColumnExpressionArgument.
1110-
# this is even if all non-generic types are removed from it, so
1111-
# reasons remain unclear for why this does not work
1109+
# set ColumnElement[_T] as a separate overload, to appease
1110+
# mypy which seems to not want to accept _T from
1111+
# _ColumnExpressionArgument. Seems somewhat related to the covariant
1112+
# _HasClauseElement as of mypy 1.15
11121113

11131114
@overload
11141115
def max( # noqa: A001
@@ -1129,7 +1130,7 @@ def max( # noqa: A001
11291130
@overload
11301131
def max( # noqa: A001
11311132
self,
1132-
col: _ColumnExpressionOrLiteralArgument[_T],
1133+
col: _T,
11331134
*args: _ColumnExpressionOrLiteralArgument[Any],
11341135
**kwargs: Any,
11351136
) -> max[_T]: ...
@@ -1141,10 +1142,10 @@ def max( # noqa: A001
11411142
**kwargs: Any,
11421143
) -> max[_T]: ...
11431144

1144-
# set ColumnElement[_T] as a separate overload, to appease mypy
1145-
# which seems to not want to accept _T from _ColumnExpressionArgument.
1146-
# this is even if all non-generic types are removed from it, so
1147-
# reasons remain unclear for why this does not work
1145+
# set ColumnElement[_T] as a separate overload, to appease
1146+
# mypy which seems to not want to accept _T from
1147+
# _ColumnExpressionArgument. Seems somewhat related to the covariant
1148+
# _HasClauseElement as of mypy 1.15
11481149

11491150
@overload
11501151
def min( # noqa: A001
@@ -1165,7 +1166,7 @@ def min( # noqa: A001
11651166
@overload
11661167
def min( # noqa: A001
11671168
self,
1168-
col: _ColumnExpressionOrLiteralArgument[_T],
1169+
col: _T,
11691170
*args: _ColumnExpressionOrLiteralArgument[Any],
11701171
**kwargs: Any,
11711172
) -> min[_T]: ...
@@ -1210,10 +1211,10 @@ def rollup(self) -> Type[rollup[Any]]: ...
12101211
@property
12111212
def session_user(self) -> Type[session_user]: ...
12121213

1213-
# set ColumnElement[_T] as a separate overload, to appease mypy
1214-
# which seems to not want to accept _T from _ColumnExpressionArgument.
1215-
# this is even if all non-generic types are removed from it, so
1216-
# reasons remain unclear for why this does not work
1214+
# set ColumnElement[_T] as a separate overload, to appease
1215+
# mypy which seems to not want to accept _T from
1216+
# _ColumnExpressionArgument. Seems somewhat related to the covariant
1217+
# _HasClauseElement as of mypy 1.15
12171218

12181219
@overload
12191220
def sum( # noqa: A001
@@ -1234,7 +1235,7 @@ def sum( # noqa: A001
12341235
@overload
12351236
def sum( # noqa: A001
12361237
self,
1237-
col: _ColumnExpressionOrLiteralArgument[_T],
1238+
col: _T,
12381239
*args: _ColumnExpressionOrLiteralArgument[Any],
12391240
**kwargs: Any,
12401241
) -> sum[_T]: ...
@@ -1330,7 +1331,7 @@ def __init__(
13301331
*clauses: _ColumnExpressionOrLiteralArgument[_T],
13311332
type_: None = ...,
13321333
packagenames: Optional[Tuple[str, ...]] = ...,
1333-
): ...
1334+
) -> None: ...
13341335

13351336
@overload
13361337
def __init__(
@@ -1339,15 +1340,15 @@ def __init__(
13391340
*clauses: _ColumnExpressionOrLiteralArgument[Any],
13401341
type_: _TypeEngineArgument[_T] = ...,
13411342
packagenames: Optional[Tuple[str, ...]] = ...,
1342-
): ...
1343+
) -> None: ...
13431344

13441345
def __init__(
13451346
self,
13461347
name: str,
13471348
*clauses: _ColumnExpressionOrLiteralArgument[Any],
13481349
type_: Optional[_TypeEngineArgument[_T]] = None,
13491350
packagenames: Optional[Tuple[str, ...]] = None,
1350-
):
1351+
) -> None:
13511352
"""Construct a :class:`.Function`.
13521353
13531354
The :data:`.func` construct is normally used to construct
@@ -1523,7 +1524,7 @@ def _register_generic_function(
15231524

15241525
def __init__(
15251526
self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
1526-
):
1527+
) -> None:
15271528
parsed_args = kwargs.pop("_parsed_args", None)
15281529
if parsed_args is None:
15291530
parsed_args = [
@@ -1570,7 +1571,7 @@ class next_value(GenericFunction[int]):
15701571
("sequence", InternalTraversal.dp_named_ddl_element)
15711572
]
15721573

1573-
def __init__(self, seq: schema.Sequence, **kw: Any):
1574+
def __init__(self, seq: schema.Sequence, **kw: Any) -> None:
15741575
assert isinstance(
15751576
seq, schema.Sequence
15761577
), "next_value() accepts a Sequence object as input."
@@ -1595,7 +1596,9 @@ class AnsiFunction(GenericFunction[_T]):
15951596

15961597
inherit_cache = True
15971598

1598-
def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
1599+
def __init__(
1600+
self, *args: _ColumnExpressionArgument[Any], **kwargs: Any
1601+
) -> None:
15991602
GenericFunction.__init__(self, *args, **kwargs)
16001603

16011604

@@ -1606,38 +1609,38 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
16061609

16071610
inherit_cache = True
16081611

1609-
# set ColumnElement[_T] as a separate overload, to appease mypy which seems
1610-
# to not want to accept _T from _ColumnExpressionArgument. this is even if
1611-
# all non-generic types are removed from it, so reasons remain unclear for
1612-
# why this does not work
1612+
# set ColumnElement[_T] as a separate overload, to appease
1613+
# mypy which seems to not want to accept _T from
1614+
# _ColumnExpressionArgument. Seems somewhat related to the covariant
1615+
# _HasClauseElement as of mypy 1.15
16131616

16141617
@overload
16151618
def __init__(
16161619
self,
16171620
col: ColumnElement[_T],
16181621
*args: _ColumnExpressionOrLiteralArgument[Any],
16191622
**kwargs: Any,
1620-
): ...
1623+
) -> None: ...
16211624

16221625
@overload
16231626
def __init__(
16241627
self,
16251628
col: _ColumnExpressionArgument[_T],
16261629
*args: _ColumnExpressionOrLiteralArgument[Any],
16271630
**kwargs: Any,
1628-
): ...
1631+
) -> None: ...
16291632

16301633
@overload
16311634
def __init__(
16321635
self,
1633-
col: _ColumnExpressionOrLiteralArgument[_T],
1636+
col: _T,
16341637
*args: _ColumnExpressionOrLiteralArgument[Any],
16351638
**kwargs: Any,
1636-
): ...
1639+
) -> None: ...
16371640

16381641
def __init__(
1639-
self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
1640-
):
1642+
self, *args: _ColumnExpressionOrLiteralArgument[_T], **kwargs: Any
1643+
) -> None:
16411644
fn_args: Sequence[ColumnElement[Any]] = [
16421645
coercions.expect(
16431646
roles.ExpressionElementRole,
@@ -1719,7 +1722,7 @@ class char_length(GenericFunction[int]):
17191722
type = sqltypes.Integer()
17201723
inherit_cache = True
17211724

1722-
def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any):
1725+
def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any) -> None:
17231726
# slight hack to limit to just one positional argument
17241727
# not sure why this one function has this special treatment
17251728
super().__init__(arg, **kw)
@@ -1765,7 +1768,7 @@ def __init__(
17651768
_ColumnExpressionArgument[Any], _StarOrOne, None
17661769
] = None,
17671770
**kwargs: Any,
1768-
):
1771+
) -> None:
17691772
if expression is None:
17701773
expression = literal_column("*")
17711774
super().__init__(expression, **kwargs)
@@ -1854,7 +1857,9 @@ class array_agg(ReturnTypeFromArgs[Sequence[_T]]):
18541857

18551858
inherit_cache = True
18561859

1857-
def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
1860+
def __init__(
1861+
self, *args: _ColumnExpressionArgument[Any], **kwargs: Any
1862+
) -> None:
18581863
fn_args: Sequence[ColumnElement[Any]] = [
18591864
coercions.expect(
18601865
roles.ExpressionElementRole, c, apply_propagate_attrs=self
@@ -2081,5 +2086,7 @@ class aggregate_strings(GenericFunction[str]):
20812086
_has_args = True
20822087
inherit_cache = True
20832088

2084-
def __init__(self, clause: _ColumnExpressionArgument[Any], separator: str):
2089+
def __init__(
2090+
self, clause: _ColumnExpressionArgument[Any], separator: str
2091+
) -> None:
20852092
super().__init__(clause, separator)

test/typing/plain_files/sql/functions_again.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from sqlalchemy import column
12
from sqlalchemy import func
3+
from sqlalchemy import Integer
24
from sqlalchemy import select
35
from sqlalchemy.orm import DeclarativeBase
46
from sqlalchemy.orm import Mapped
@@ -53,6 +55,10 @@ class Foo(Base):
5355
# test #10818
5456
# EXPECTED_TYPE: coalesce[str]
5557
reveal_type(func.coalesce(Foo.c, "a", "b"))
58+
# EXPECTED_TYPE: coalesce[str]
59+
reveal_type(func.coalesce("a", "b"))
60+
# EXPECTED_TYPE: coalesce[int]
61+
reveal_type(func.coalesce(column("x", Integer), 3))
5662

5763

5864
stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a)

tools/generate_sql_functions.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str:
6767
textwrap.indent(
6868
f"""
6969
70-
# set ColumnElement[_T] as a separate overload, to appease mypy
71-
# which seems to not want to accept _T from _ColumnExpressionArgument.
72-
# this is even if all non-generic types are removed from it, so
73-
# reasons remain unclear for why this does not work
70+
# set ColumnElement[_T] as a separate overload, to appease
71+
# mypy which seems to not want to accept _T from
72+
# _ColumnExpressionArgument. Seems somewhat related to the covariant
73+
# _HasClauseElement as of mypy 1.15
7474
7575
@overload
7676
def {key}( {' # noqa: A001' if is_reserved_word else ''}
@@ -90,17 +90,15 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''}
9090
) -> {fn_class.__name__}[_T]:
9191
...
9292
93-
9493
@overload
9594
def {key}( {' # noqa: A001' if is_reserved_word else ''}
9695
self,
97-
col: _ColumnExpressionOrLiteralArgument[_T],
96+
col: _T,
9897
*args: _ColumnExpressionOrLiteralArgument[Any],
9998
**kwargs: Any,
10099
) -> {fn_class.__name__}[_T]:
101100
...
102101
103-
104102
def {key}( {' # noqa: A001' if is_reserved_word else ''}
105103
self,
106104
col: _ColumnExpressionOrLiteralArgument[_T],

0 commit comments

Comments
 (0)