Skip to content

Commit b4d7bf7

Browse files
jkrejchazzzeek
authored andcommitted
typing: pg: type NamedType create/drops (fixes #12557)
Type the `create` and `drop` functions for `NamedType`s Also partially type the SchemaType create/drop functions more generally One change to this is that the default parameter of `None` is removed. It doesn't work and will fail with a `AttributeError` at runtime since it immediately tries to access a property of `None` which doesn't exist. Fixes #12557 This pull request is: - [X] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [X] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #<issue number>` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #<issue number>` in the commit message - please include tests. **Have a nice day!** Closes: #12558 Pull-request: #12558 Pull-request-sha: 75c8d81 Change-Id: I173771d365f34f54ab474b9661e1cdc70cc4de84
1 parent 912ddc0 commit b4d7bf7

10 files changed

Lines changed: 105 additions & 39 deletions

File tree

lib/sqlalchemy/dialects/postgresql/named_types.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
# mypy: ignore-errors
88
from __future__ import annotations
99

10+
from types import ModuleType
1011
from typing import Any
12+
from typing import Dict
1113
from typing import Optional
1214
from typing import Type
1315
from typing import TYPE_CHECKING
@@ -25,18 +27,21 @@
2527
from ...sql.ddl import InvokeDropDDLBase
2628

2729
if TYPE_CHECKING:
30+
from ...sql._typing import _CreateDropBind
2831
from ...sql._typing import _TypeEngineArgument
2932

3033

31-
class NamedType(sqltypes.TypeEngine):
34+
class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
3235
"""Base for named types."""
3336

3437
__abstract__ = True
3538
DDLGenerator: Type[NamedTypeGenerator]
3639
DDLDropper: Type[NamedTypeDropper]
3740
create_type: bool
3841

39-
def create(self, bind, checkfirst=True, **kw):
42+
def create(
43+
self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
44+
) -> None:
4045
"""Emit ``CREATE`` DDL for this type.
4146
4247
:param bind: a connectable :class:`_engine.Engine`,
@@ -50,7 +55,9 @@ def create(self, bind, checkfirst=True, **kw):
5055
"""
5156
bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
5257

53-
def drop(self, bind, checkfirst=True, **kw):
58+
def drop(
59+
self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
60+
) -> None:
5461
"""Emit ``DROP`` DDL for this type.
5562
5663
:param bind: a connectable :class:`_engine.Engine`,
@@ -63,7 +70,9 @@ def drop(self, bind, checkfirst=True, **kw):
6370
"""
6471
bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
6572

66-
def _check_for_name_in_memos(self, checkfirst, kw):
73+
def _check_for_name_in_memos(
74+
self, checkfirst: bool, kw: Dict[str, Any]
75+
) -> bool:
6776
"""Look in the 'ddl runner' for 'memos', then
6877
note our name in that collection.
6978
@@ -87,7 +96,13 @@ def _check_for_name_in_memos(self, checkfirst, kw):
8796
else:
8897
return False
8998

90-
def _on_table_create(self, target, bind, checkfirst=False, **kw):
99+
def _on_table_create(
100+
self,
101+
target: Any,
102+
bind: _CreateDropBind,
103+
checkfirst: bool = False,
104+
**kw: Any,
105+
) -> None:
91106
if (
92107
checkfirst
93108
or (
@@ -97,19 +112,37 @@ def _on_table_create(self, target, bind, checkfirst=False, **kw):
97112
) and not self._check_for_name_in_memos(checkfirst, kw):
98113
self.create(bind=bind, checkfirst=checkfirst)
99114

100-
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
115+
def _on_table_drop(
116+
self,
117+
target: Any,
118+
bind: _CreateDropBind,
119+
checkfirst: bool = False,
120+
**kw: Any,
121+
) -> None:
101122
if (
102123
not self.metadata
103124
and not kw.get("_is_metadata_operation", False)
104125
and not self._check_for_name_in_memos(checkfirst, kw)
105126
):
106127
self.drop(bind=bind, checkfirst=checkfirst)
107128

108-
def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
129+
def _on_metadata_create(
130+
self,
131+
target: Any,
132+
bind: _CreateDropBind,
133+
checkfirst: bool = False,
134+
**kw: Any,
135+
) -> None:
109136
if not self._check_for_name_in_memos(checkfirst, kw):
110137
self.create(bind=bind, checkfirst=checkfirst)
111138

112-
def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
139+
def _on_metadata_drop(
140+
self,
141+
target: Any,
142+
bind: _CreateDropBind,
143+
checkfirst: bool = False,
144+
**kw: Any,
145+
) -> None:
113146
if not self._check_for_name_in_memos(checkfirst, kw):
114147
self.drop(bind=bind, checkfirst=checkfirst)
115148

@@ -314,7 +347,7 @@ def adapt_emulated_to_native(cls, impl, **kw):
314347

315348
return cls(**kw)
316349

317-
def create(self, bind=None, checkfirst=True):
350+
def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
318351
"""Emit ``CREATE TYPE`` for this
319352
:class:`_postgresql.ENUM`.
320353
@@ -335,7 +368,7 @@ def create(self, bind=None, checkfirst=True):
335368

336369
super().create(bind, checkfirst=checkfirst)
337370

338-
def drop(self, bind=None, checkfirst=True):
371+
def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
339372
"""Emit ``DROP TYPE`` for this
340373
:class:`_postgresql.ENUM`.
341374
@@ -355,7 +388,7 @@ def drop(self, bind=None, checkfirst=True):
355388

356389
super().drop(bind, checkfirst=checkfirst)
357390

358-
def get_dbapi_type(self, dbapi):
391+
def get_dbapi_type(self, dbapi: ModuleType) -> None:
359392
"""dont return dbapi.STRING for ENUM in PostgreSQL, since that's
360393
a different type"""
361394

lib/sqlalchemy/engine/base.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,11 @@
7373
from ..sql._typing import _InfoType
7474
from ..sql.compiler import Compiled
7575
from ..sql.ddl import ExecutableDDLElement
76-
from ..sql.ddl import SchemaDropper
77-
from ..sql.ddl import SchemaGenerator
76+
from ..sql.ddl import InvokeDDLBase
7877
from ..sql.functions import FunctionElement
7978
from ..sql.schema import DefaultGenerator
8079
from ..sql.schema import HasSchemaAttr
81-
from ..sql.schema import SchemaItem
80+
from ..sql.schema import SchemaVisitable
8281
from ..sql.selectable import TypedReturnsRows
8382

8483

@@ -2450,8 +2449,8 @@ def _handle_dbapi_exception_noconnection(
24502449

24512450
def _run_ddl_visitor(
24522451
self,
2453-
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
2454-
element: SchemaItem,
2452+
visitorcallable: Type[InvokeDDLBase],
2453+
element: SchemaVisitable,
24552454
**kwargs: Any,
24562455
) -> None:
24572456
"""run a DDL visitor.
@@ -2460,7 +2459,9 @@ def _run_ddl_visitor(
24602459
options given to the visitor so that "checkfirst" is skipped.
24612460
24622461
"""
2463-
visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
2462+
visitorcallable(
2463+
dialect=self.dialect, connection=self, **kwargs
2464+
).traverse_single(element)
24642465

24652466

24662467
class ExceptionContextImpl(ExceptionContext):
@@ -3246,8 +3247,8 @@ def begin(self) -> Iterator[Connection]:
32463247

32473248
def _run_ddl_visitor(
32483249
self,
3249-
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
3250-
element: SchemaItem,
3250+
visitorcallable: Type[InvokeDDLBase],
3251+
element: SchemaVisitable,
32513252
**kwargs: Any,
32523253
) -> None:
32533254
with self.begin() as conn:

lib/sqlalchemy/engine/mock.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
from .interfaces import Dialect
2828
from .url import URL
2929
from ..sql.base import Executable
30-
from ..sql.ddl import SchemaDropper
31-
from ..sql.ddl import SchemaGenerator
30+
from ..sql.ddl import InvokeDDLBase
3231
from ..sql.schema import HasSchemaAttr
33-
from ..sql.schema import SchemaItem
32+
from ..sql.visitors import Visitable
3433

3534

3635
class MockConnection:
@@ -53,12 +52,14 @@ def execution_options(self, **kw: Any) -> MockConnection:
5352

5453
def _run_ddl_visitor(
5554
self,
56-
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
57-
element: SchemaItem,
55+
visitorcallable: Type[InvokeDDLBase],
56+
element: Visitable,
5857
**kwargs: Any,
5958
) -> None:
6059
kwargs["checkfirst"] = False
61-
visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
60+
visitorcallable(
61+
dialect=self.dialect, connection=self, **kwargs
62+
).traverse_single(element)
6263

6364
def execute(
6465
self,

lib/sqlalchemy/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
6666
from .sql.schema import SchemaConst as SchemaConst
6767
from .sql.schema import SchemaItem as SchemaItem
68+
from .sql.schema import SchemaVisitable as SchemaVisitable
6869
from .sql.schema import Sequence as Sequence
6970
from .sql.schema import Table as Table
7071
from .sql.schema import UniqueConstraint as UniqueConstraint

lib/sqlalchemy/sql/_typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@
7272
from .sqltypes import TableValueType
7373
from .sqltypes import TupleType
7474
from .type_api import TypeEngine
75+
from ..engine import Connection
7576
from ..engine import Dialect
77+
from ..engine import Engine
78+
from ..engine.mock import MockConnection
7679
from ..util.typing import TypeGuard
7780

7881
_T = TypeVar("_T", bound=Any)
@@ -304,6 +307,8 @@ def dialect(self) -> Dialect: ...
304307

305308
_AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]]
306309

310+
_CreateDropBind = Union["Engine", "Connection", "MockConnection"]
311+
307312
if TYPE_CHECKING:
308313

309314
def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ...

lib/sqlalchemy/sql/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1540,8 +1540,19 @@ def _set_parent_with_dispatch(
15401540
self.dispatch.after_parent_attach(self, parent)
15411541

15421542

1543+
class SchemaVisitable(SchemaEventTarget, visitors.Visitable):
1544+
"""Base class for elements that are targets of a :class:`.SchemaVisitor`.
1545+
1546+
.. versionadded:: 2.0.41
1547+
1548+
"""
1549+
1550+
15431551
class SchemaVisitor(ClauseVisitor):
1544-
"""Define the visiting for ``SchemaItem`` objects."""
1552+
"""Define the visiting for ``SchemaItem`` and more
1553+
generally ``SchemaVisitable`` objects.
1554+
1555+
"""
15451556

15461557
__traverse_options__ = {"schema_visitor": True}
15471558

lib/sqlalchemy/sql/ddl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,8 +865,9 @@ class DropConstraintComment(_CreateDropBase["Constraint"]):
865865

866866

867867
class InvokeDDLBase(SchemaVisitor):
868-
def __init__(self, connection):
868+
def __init__(self, connection, **kw):
869869
self.connection = connection
870+
assert not kw, f"Unexpected keywords: {kw.keys()}"
870871

871872
@contextlib.contextmanager
872873
def with_ddl_events(self, target, **kw):

lib/sqlalchemy/sql/schema.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from .base import DialectKWArgs
7272
from .base import Executable
7373
from .base import SchemaEventTarget as SchemaEventTarget
74+
from .base import SchemaVisitable as SchemaVisitable
7475
from .coercions import _document_text_coercion
7576
from .elements import ClauseElement
7677
from .elements import ColumnClause
@@ -91,6 +92,7 @@
9192

9293
if typing.TYPE_CHECKING:
9394
from ._typing import _AutoIncrementType
95+
from ._typing import _CreateDropBind
9496
from ._typing import _DDLColumnArgument
9597
from ._typing import _DDLColumnReferenceArgument
9698
from ._typing import _InfoType
@@ -109,7 +111,6 @@
109111
from ..engine.interfaces import _CoreMultiExecuteParams
110112
from ..engine.interfaces import CoreExecuteOptionsParameter
111113
from ..engine.interfaces import ExecutionContext
112-
from ..engine.mock import MockConnection
113114
from ..engine.reflection import _ReflectionInfo
114115
from ..sql.selectable import FromClause
115116

@@ -118,8 +119,6 @@
118119
_TAB = TypeVar("_TAB", bound="Table")
119120

120121

121-
_CreateDropBind = Union["Engine", "Connection", "MockConnection"]
122-
123122
_ConstraintNameArgument = Optional[Union[str, _NoneName]]
124123

125124
_ServerDefaultArgument = Union[
@@ -213,7 +212,7 @@ def replace(
213212

214213

215214
@inspection._self_inspects
216-
class SchemaItem(SchemaEventTarget, visitors.Visitable):
215+
class SchemaItem(SchemaVisitable):
217216
"""Base class for items that define a database schema."""
218217

219218
__visit_name__ = "schema_item"

0 commit comments

Comments
 (0)