Skip to content

Commit 51a7678

Browse files
Polandia94zzzeek
andcommitted
Type mysql dialect
Closes: #12164 Pull-request: #12164 Pull-request-sha: 545e2c3 Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com> Change-Id: I37bd98049ff1a64d58e9490b0e5e2ea764dd1f73
1 parent 10ff201 commit 51a7678

29 files changed

Lines changed: 1446 additions & 672 deletions

lib/sqlalchemy/connectors/asyncio.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@
2020
from typing import Optional
2121
from typing import Protocol
2222
from typing import Sequence
23+
from typing import TYPE_CHECKING
2324

2425
from ..engine import AdaptedConnection
25-
from ..engine.interfaces import _DBAPICursorDescription
26-
from ..engine.interfaces import _DBAPIMultiExecuteParams
27-
from ..engine.interfaces import _DBAPISingleExecuteParams
2826
from ..util.concurrency import await_
29-
from ..util.typing import Self
27+
28+
if TYPE_CHECKING:
29+
from ..engine.interfaces import _DBAPICursorDescription
30+
from ..engine.interfaces import _DBAPIMultiExecuteParams
31+
from ..engine.interfaces import _DBAPISingleExecuteParams
32+
from ..engine.interfaces import DBAPIModule
33+
from ..util.typing import Self
3034

3135

3236
class AsyncIODBAPIConnection(Protocol):
@@ -36,14 +40,19 @@ class AsyncIODBAPIConnection(Protocol):
3640
3741
"""
3842

39-
async def close(self) -> None: ...
43+
# note that async DBAPIs dont agree if close() should be awaitable,
44+
# so it is omitted here and picked up by the __getattr__ hook below
4045

4146
async def commit(self) -> None: ...
4247

4348
def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ...
4449

4550
async def rollback(self) -> None: ...
4651

52+
def __getattr__(self, key: str) -> Any: ...
53+
54+
def __setattr__(self, key: str, value: Any) -> None: ...
55+
4756

4857
class AsyncIODBAPICursor(Protocol):
4958
"""protocol representing an async adapted version
@@ -101,6 +110,16 @@ async def nextset(self) -> Optional[bool]: ...
101110
def __aiter__(self) -> AsyncIterator[Any]: ...
102111

103112

113+
class AsyncAdapt_dbapi_module:
114+
if TYPE_CHECKING:
115+
Error = DBAPIModule.Error
116+
OperationalError = DBAPIModule.OperationalError
117+
InterfaceError = DBAPIModule.InterfaceError
118+
IntegrityError = DBAPIModule.IntegrityError
119+
120+
def __getattr__(self, key: str) -> Any: ...
121+
122+
104123
class AsyncAdapt_dbapi_cursor:
105124
server_side = False
106125
__slots__ = (

lib/sqlalchemy/connectors/pyodbc.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import annotations
99

1010
import re
11-
from types import ModuleType
1211
import typing
1312
from typing import Any
1413
from typing import Dict
@@ -28,6 +27,7 @@
2827
from ..sql.type_api import TypeEngine
2928

3029
if typing.TYPE_CHECKING:
30+
from ..engine.interfaces import DBAPIModule
3131
from ..engine.interfaces import IsolationLevel
3232

3333

@@ -47,15 +47,13 @@ class PyODBCConnector(Connector):
4747
# hold the desired driver name
4848
pyodbc_driver_name: Optional[str] = None
4949

50-
dbapi: ModuleType
51-
5250
def __init__(self, use_setinputsizes: bool = False, **kw: Any):
5351
super().__init__(**kw)
5452
if use_setinputsizes:
5553
self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
5654

5755
@classmethod
58-
def import_dbapi(cls) -> ModuleType:
56+
def import_dbapi(cls) -> DBAPIModule:
5957
return __import__("pyodbc")
6058

6159
def create_connect_args(self, url: URL) -> ConnectArgsType:
@@ -150,7 +148,7 @@ def is_disconnect(
150148
],
151149
cursor: Optional[interfaces.DBAPICursor],
152150
) -> bool:
153-
if isinstance(e, self.dbapi.ProgrammingError):
151+
if isinstance(e, self.loaded_dbapi.ProgrammingError):
154152
return "The cursor's connection has been closed." in str(
155153
e
156154
) or "Attempt to use a closed connection." in str(e)

lib/sqlalchemy/dialects/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
from typing import Any
1011
from typing import Callable
1112
from typing import Optional
1213
from typing import Type
@@ -39,7 +40,7 @@ def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]:
3940
# hardcoded. if mysql / mariadb etc were third party dialects
4041
# they would just publish all the entrypoints, which would actually
4142
# look much nicer.
42-
module = __import__(
43+
module: Any = __import__(
4344
"sqlalchemy.dialects.mysql.mariadb"
4445
).dialects.mysql.mariadb
4546
return module.loader(driver) # type: ignore

lib/sqlalchemy/dialects/mysql/aiomysql.py

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#
55
# This module is part of SQLAlchemy and is released under
66
# the MIT License: https://www.opensource.org/licenses/mit-license.php
7-
# mypy: ignore-errors
87

98
r"""
109
.. dialect:: mysql+aiomysql
@@ -29,17 +28,39 @@
2928
)
3029
3130
""" # noqa
31+
from __future__ import annotations
32+
33+
from types import ModuleType
34+
from typing import Any
35+
from typing import Optional
36+
from typing import TYPE_CHECKING
37+
from typing import Union
38+
3239
from .pymysql import MySQLDialect_pymysql
3340
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
3441
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
42+
from ...connectors.asyncio import AsyncAdapt_dbapi_module
3543
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
3644
from ...util.concurrency import await_
3745

46+
if TYPE_CHECKING:
47+
48+
from ...connectors.asyncio import AsyncIODBAPIConnection
49+
from ...connectors.asyncio import AsyncIODBAPICursor
50+
from ...engine.interfaces import ConnectArgsType
51+
from ...engine.interfaces import DBAPIConnection
52+
from ...engine.interfaces import DBAPICursor
53+
from ...engine.interfaces import DBAPIModule
54+
from ...engine.interfaces import PoolProxiedConnection
55+
from ...engine.url import URL
56+
3857

3958
class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
4059
__slots__ = ()
4160

42-
def _make_new_cursor(self, connection):
61+
def _make_new_cursor(
62+
self, connection: AsyncIODBAPIConnection
63+
) -> AsyncIODBAPICursor:
4364
return connection.cursor(self._adapt_connection.dbapi.Cursor)
4465

4566

@@ -48,7 +69,9 @@ class AsyncAdapt_aiomysql_ss_cursor(
4869
):
4970
__slots__ = ()
5071

51-
def _make_new_cursor(self, connection):
72+
def _make_new_cursor(
73+
self, connection: AsyncIODBAPIConnection
74+
) -> AsyncIODBAPICursor:
5275
return connection.cursor(
5376
self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
5477
)
@@ -60,33 +83,33 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
6083
_cursor_cls = AsyncAdapt_aiomysql_cursor
6184
_ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
6285

63-
def ping(self, reconnect):
86+
def ping(self, reconnect: bool) -> None:
6487
assert not reconnect
65-
return await_(self._connection.ping(reconnect))
88+
await_(self._connection.ping(reconnect))
6689

67-
def character_set_name(self):
68-
return self._connection.character_set_name()
90+
def character_set_name(self) -> Optional[str]:
91+
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
6992

70-
def autocommit(self, value):
93+
def autocommit(self, value: Any) -> None:
7194
await_(self._connection.autocommit(value))
7295

73-
def terminate(self):
96+
def terminate(self) -> None:
7497
# it's not awaitable.
7598
self._connection.close()
7699

77100
def close(self) -> None:
78101
await_(self._connection.ensure_closed())
79102

80103

81-
class AsyncAdapt_aiomysql_dbapi:
82-
def __init__(self, aiomysql, pymysql):
104+
class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
105+
def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
83106
self.aiomysql = aiomysql
84107
self.pymysql = pymysql
85108
self.paramstyle = "format"
86109
self._init_dbapi_attributes()
87110
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
88111

89-
def _init_dbapi_attributes(self):
112+
def _init_dbapi_attributes(self) -> None:
90113
for name in (
91114
"Warning",
92115
"Error",
@@ -112,65 +135,80 @@ def _init_dbapi_attributes(self):
112135
):
113136
setattr(self, name, getattr(self.pymysql, name))
114137

115-
def connect(self, *arg, **kw):
138+
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
116139
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
117140

118141
return AsyncAdapt_aiomysql_connection(
119142
self,
120143
await_(creator_fn(*arg, **kw)),
121144
)
122145

123-
def _init_cursors_subclasses(self):
146+
def _init_cursors_subclasses(
147+
self,
148+
) -> tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]:
124149
# suppress unconditional warning emitted by aiomysql
125-
class Cursor(self.aiomysql.Cursor):
126-
async def _show_warnings(self, conn):
150+
class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined]
151+
async def _show_warnings(
152+
self, conn: AsyncIODBAPIConnection
153+
) -> None:
127154
pass
128155

129-
class SSCursor(self.aiomysql.SSCursor):
130-
async def _show_warnings(self, conn):
156+
class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501
157+
async def _show_warnings(
158+
self, conn: AsyncIODBAPIConnection
159+
) -> None:
131160
pass
132161

133-
return Cursor, SSCursor
162+
return Cursor, SSCursor # type: ignore[return-value]
134163

135164

136165
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
137166
driver = "aiomysql"
138167
supports_statement_cache = True
139168

140-
supports_server_side_cursors = True
169+
supports_server_side_cursors = True # type: ignore[assignment]
141170
_sscursor = AsyncAdapt_aiomysql_ss_cursor
142171

143172
is_async = True
144173
has_terminate = True
145174

146175
@classmethod
147-
def import_dbapi(cls):
176+
def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi:
148177
return AsyncAdapt_aiomysql_dbapi(
149178
__import__("aiomysql"), __import__("pymysql")
150179
)
151180

152-
def do_terminate(self, dbapi_connection) -> None:
181+
def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
153182
dbapi_connection.terminate()
154183

155-
def create_connect_args(self, url):
184+
def create_connect_args(
185+
self, url: URL, _translate_args: Optional[dict[str, Any]] = None
186+
) -> ConnectArgsType:
156187
return super().create_connect_args(
157188
url, _translate_args=dict(username="user", database="db")
158189
)
159190

160-
def is_disconnect(self, e, connection, cursor):
191+
def is_disconnect(
192+
self,
193+
e: DBAPIModule.Error,
194+
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
195+
cursor: Optional[DBAPICursor],
196+
) -> bool:
161197
if super().is_disconnect(e, connection, cursor):
162198
return True
163199
else:
164200
str_e = str(e).lower()
165201
return "not connected" in str_e
166202

167-
def _found_rows_client_flag(self):
168-
from pymysql.constants import CLIENT
203+
def _found_rows_client_flag(self) -> int:
204+
from pymysql.constants import CLIENT # type: ignore
169205

170-
return CLIENT.FOUND_ROWS
206+
return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
171207

172-
def get_driver_connection(self, connection):
173-
return connection._connection
208+
def get_driver_connection(
209+
self, connection: DBAPIConnection
210+
) -> AsyncIODBAPIConnection:
211+
return connection._connection # type: ignore[no-any-return]
174212

175213

176214
dialect = MySQLDialect_aiomysql

0 commit comments

Comments
 (0)