Skip to content

Commit 9fe3c3c

Browse files
g-80zzzeek
authored andcommitted
Factor out constraints into separate methods
Fixed issue where PostgreSQL dialect options such as ``postgresql_include`` on :class:`.PrimaryKeyConstraint` and :class:`.UniqueConstraint` were rendered in the wrong position when combined with constraint deferrability options like ``deferrable=True``. Pull request courtesy G Allajmi. Fixes: #12867 Closes: #13003 Pull-request: #13003 Pull-request-sha: 1a92160 Change-Id: I8c55d8faae25d56ff63c9126d569c01d8ee6c7dd
1 parent 6785a09 commit 9fe3c3c

5 files changed

Lines changed: 430 additions & 45 deletions

File tree

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.. change::
2+
:tags: bug, postgresql
3+
:tickets: 12867
4+
5+
Fixed issue where PostgreSQL dialect options such as ``postgresql_include``
6+
on :class:`.PrimaryKeyConstraint` and :class:`.UniqueConstraint` were
7+
rendered in the wrong position when combined with constraint deferrability
8+
options like ``deferrable=True``. Pull request courtesy G Allajmi.

lib/sqlalchemy/dialects/postgresql/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2580,13 +2580,19 @@ def visit_foreign_key_constraint(self, constraint, **kw):
25802580
return text
25812581

25822582
def visit_primary_key_constraint(self, constraint, **kw):
2583-
text = super().visit_primary_key_constraint(constraint)
2583+
text = self.define_constraint_preamble(constraint, **kw)
2584+
text += self.define_primary_key_body(constraint, **kw)
25842585
text += self._define_include(constraint)
2586+
text += self.define_constraint_deferrability(constraint)
25852587
return text
25862588

25872589
def visit_unique_constraint(self, constraint, **kw):
2588-
text = super().visit_unique_constraint(constraint)
2590+
if len(constraint) == 0:
2591+
return ""
2592+
text = self.define_constraint_preamble(constraint, **kw)
2593+
text += self.define_unique_body(constraint, **kw)
25892594
text += self._define_include(constraint)
2595+
text += self.define_constraint_deferrability(constraint)
25902596
return text
25912597

25922598
@util.memoized_property

lib/sqlalchemy/sql/compiler.py

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
from .elements import Null
118118
from .elements import True_
119119
from .functions import Function
120+
from .schema import CheckConstraint
120121
from .schema import Column
121122
from .schema import Constraint
122123
from .schema import ForeignKeyConstraint
@@ -7366,26 +7367,14 @@ def visit_table_or_column_check_constraint(self, constraint, **kw):
73667367
return self.visit_check_constraint(constraint)
73677368

73687369
def visit_check_constraint(self, constraint, **kw):
7369-
text = ""
7370-
if constraint.name is not None:
7371-
formatted_name = self.preparer.format_constraint(constraint)
7372-
if formatted_name is not None:
7373-
text += "CONSTRAINT %s " % formatted_name
7374-
text += "CHECK (%s)" % self.sql_compiler.process(
7375-
constraint.sqltext, include_table=False, literal_binds=True
7376-
)
7370+
text = self.define_constraint_preamble(constraint, **kw)
7371+
text += self.define_check_body(constraint, **kw)
73777372
text += self.define_constraint_deferrability(constraint)
73787373
return text
73797374

73807375
def visit_column_check_constraint(self, constraint, **kw):
7381-
text = ""
7382-
if constraint.name is not None:
7383-
formatted_name = self.preparer.format_constraint(constraint)
7384-
if formatted_name is not None:
7385-
text += "CONSTRAINT %s " % formatted_name
7386-
text += "CHECK (%s)" % self.sql_compiler.process(
7387-
constraint.sqltext, include_table=False, literal_binds=True
7388-
)
7376+
text = self.define_constraint_preamble(constraint, **kw)
7377+
text += self.define_check_body(constraint, **kw)
73897378
text += self.define_constraint_deferrability(constraint)
73907379
return text
73917380

@@ -7394,11 +7383,50 @@ def visit_primary_key_constraint(
73947383
) -> str:
73957384
if len(constraint) == 0:
73967385
return ""
7386+
text = self.define_constraint_preamble(constraint, **kw)
7387+
text += self.define_primary_key_body(constraint, **kw)
7388+
text += self.define_constraint_deferrability(constraint)
7389+
return text
7390+
7391+
def visit_foreign_key_constraint(
7392+
self, constraint: ForeignKeyConstraint, **kw: Any
7393+
) -> str:
7394+
text = self.define_constraint_preamble(constraint, **kw)
7395+
text += self.define_foreign_key_body(constraint, **kw)
7396+
text += self.define_constraint_match(constraint)
7397+
text += self.define_constraint_cascades(constraint)
7398+
text += self.define_constraint_deferrability(constraint)
7399+
return text
7400+
7401+
def define_constraint_remote_table(self, constraint, table, preparer):
7402+
"""Format the remote table clause of a CREATE CONSTRAINT clause."""
7403+
7404+
return preparer.format_table(table)
7405+
7406+
def visit_unique_constraint(
7407+
self, constraint: UniqueConstraint, **kw: Any
7408+
) -> str:
7409+
if len(constraint) == 0:
7410+
return ""
7411+
text = self.define_constraint_preamble(constraint, **kw)
7412+
text += self.define_unique_body(constraint, **kw)
7413+
text += self.define_constraint_deferrability(constraint)
7414+
return text
7415+
7416+
def define_constraint_preamble(
7417+
self, constraint: Constraint, **kw: Any
7418+
) -> str:
73977419
text = ""
73987420
if constraint.name is not None:
73997421
formatted_name = self.preparer.format_constraint(constraint)
74007422
if formatted_name is not None:
74017423
text += "CONSTRAINT %s " % formatted_name
7424+
return text
7425+
7426+
def define_primary_key_body(
7427+
self, constraint: PrimaryKeyConstraint, **kw: Any
7428+
) -> str:
7429+
text = ""
74027430
text += "PRIMARY KEY "
74037431
text += "(%s)" % ", ".join(
74047432
self.preparer.quote(c.name)
@@ -7408,18 +7436,14 @@ def visit_primary_key_constraint(
74087436
else constraint.columns
74097437
)
74107438
)
7411-
text += self.define_constraint_deferrability(constraint)
74127439
return text
74137440

7414-
def visit_foreign_key_constraint(self, constraint, **kw):
7441+
def define_foreign_key_body(
7442+
self, constraint: ForeignKeyConstraint, **kw: Any
7443+
) -> str:
74157444
preparer = self.preparer
7416-
text = ""
7417-
if constraint.name is not None:
7418-
formatted_name = self.preparer.format_constraint(constraint)
7419-
if formatted_name is not None:
7420-
text += "CONSTRAINT %s " % formatted_name
74217445
remote_table = list(constraint.elements)[0].column.table
7422-
text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
7446+
text = "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
74237447
", ".join(
74247448
preparer.quote(f.parent.name) for f in constraint.elements
74257449
),
@@ -7430,31 +7454,21 @@ def visit_foreign_key_constraint(self, constraint, **kw):
74307454
preparer.quote(f.column.name) for f in constraint.elements
74317455
),
74327456
)
7433-
text += self.define_constraint_match(constraint)
7434-
text += self.define_constraint_cascades(constraint)
7435-
text += self.define_constraint_deferrability(constraint)
74367457
return text
74377458

7438-
def define_constraint_remote_table(self, constraint, table, preparer):
7439-
"""Format the remote table clause of a CREATE CONSTRAINT clause."""
7440-
7441-
return preparer.format_table(table)
7442-
7443-
def visit_unique_constraint(
7459+
def define_unique_body(
74447460
self, constraint: UniqueConstraint, **kw: Any
74457461
) -> str:
7446-
if len(constraint) == 0:
7447-
return ""
7448-
text = ""
7449-
if constraint.name is not None:
7450-
formatted_name = self.preparer.format_constraint(constraint)
7451-
if formatted_name is not None:
7452-
text += "CONSTRAINT %s " % formatted_name
7453-
text += "UNIQUE %s(%s)" % (
7462+
text = "UNIQUE %s(%s)" % (
74547463
self.define_unique_constraint_distinct(constraint, **kw),
74557464
", ".join(self.preparer.quote(c.name) for c in constraint),
74567465
)
7457-
text += self.define_constraint_deferrability(constraint)
7466+
return text
7467+
7468+
def define_check_body(self, constraint: CheckConstraint, **kw: Any) -> str:
7469+
text = "CHECK (%s)" % self.sql_compiler.process(
7470+
constraint.sqltext, include_table=False, literal_binds=True
7471+
)
74587472
return text
74597473

74607474
def define_unique_constraint_distinct(
@@ -7500,7 +7514,7 @@ def define_constraint_deferrability(self, constraint: Constraint) -> str:
75007514
)
75017515
return text
75027516

7503-
def define_constraint_match(self, constraint):
7517+
def define_constraint_match(self, constraint: ForeignKeyConstraint) -> str:
75047518
text = ""
75057519
if constraint.match is not None:
75067520
text += " MATCH %s" % constraint.match

0 commit comments

Comments
 (0)