117117 from .elements import Null
118118 from .elements import True_
119119 from .functions import Function
120+ from .schema import CheckConstraint
120121 from .schema import Column
121122 from .schema import Constraint
122123 from .schema import ForeignKeyConstraint
@@ -7366,26 +7367,14 @@ def visit_table_or_column_check_constraint(self, constraint, **kw):
73667367 return self .visit_check_constraint (constraint )
73677368
73687369 def visit_check_constraint (self , constraint , ** kw ):
7369- text = ""
7370- if constraint .name is not None :
7371- formatted_name = self .preparer .format_constraint (constraint )
7372- if formatted_name is not None :
7373- text += "CONSTRAINT %s " % formatted_name
7374- text += "CHECK (%s)" % self .sql_compiler .process (
7375- constraint .sqltext , include_table = False , literal_binds = True
7376- )
7370+ text = self .define_constraint_preamble (constraint , ** kw )
7371+ text += self .define_check_body (constraint , ** kw )
73777372 text += self .define_constraint_deferrability (constraint )
73787373 return text
73797374
73807375 def visit_column_check_constraint (self , constraint , ** kw ):
7381- text = ""
7382- if constraint .name is not None :
7383- formatted_name = self .preparer .format_constraint (constraint )
7384- if formatted_name is not None :
7385- text += "CONSTRAINT %s " % formatted_name
7386- text += "CHECK (%s)" % self .sql_compiler .process (
7387- constraint .sqltext , include_table = False , literal_binds = True
7388- )
7376+ text = self .define_constraint_preamble (constraint , ** kw )
7377+ text += self .define_check_body (constraint , ** kw )
73897378 text += self .define_constraint_deferrability (constraint )
73907379 return text
73917380
@@ -7394,11 +7383,50 @@ def visit_primary_key_constraint(
73947383 ) -> str :
73957384 if len (constraint ) == 0 :
73967385 return ""
7386+ text = self .define_constraint_preamble (constraint , ** kw )
7387+ text += self .define_primary_key_body (constraint , ** kw )
7388+ text += self .define_constraint_deferrability (constraint )
7389+ return text
7390+
7391+ def visit_foreign_key_constraint (
7392+ self , constraint : ForeignKeyConstraint , ** kw : Any
7393+ ) -> str :
7394+ text = self .define_constraint_preamble (constraint , ** kw )
7395+ text += self .define_foreign_key_body (constraint , ** kw )
7396+ text += self .define_constraint_match (constraint )
7397+ text += self .define_constraint_cascades (constraint )
7398+ text += self .define_constraint_deferrability (constraint )
7399+ return text
7400+
7401+ def define_constraint_remote_table (self , constraint , table , preparer ):
7402+ """Format the remote table clause of a CREATE CONSTRAINT clause."""
7403+
7404+ return preparer .format_table (table )
7405+
7406+ def visit_unique_constraint (
7407+ self , constraint : UniqueConstraint , ** kw : Any
7408+ ) -> str :
7409+ if len (constraint ) == 0 :
7410+ return ""
7411+ text = self .define_constraint_preamble (constraint , ** kw )
7412+ text += self .define_unique_body (constraint , ** kw )
7413+ text += self .define_constraint_deferrability (constraint )
7414+ return text
7415+
7416+ def define_constraint_preamble (
7417+ self , constraint : Constraint , ** kw : Any
7418+ ) -> str :
73977419 text = ""
73987420 if constraint .name is not None :
73997421 formatted_name = self .preparer .format_constraint (constraint )
74007422 if formatted_name is not None :
74017423 text += "CONSTRAINT %s " % formatted_name
7424+ return text
7425+
7426+ def define_primary_key_body (
7427+ self , constraint : PrimaryKeyConstraint , ** kw : Any
7428+ ) -> str :
7429+ text = ""
74027430 text += "PRIMARY KEY "
74037431 text += "(%s)" % ", " .join (
74047432 self .preparer .quote (c .name )
@@ -7408,18 +7436,14 @@ def visit_primary_key_constraint(
74087436 else constraint .columns
74097437 )
74107438 )
7411- text += self .define_constraint_deferrability (constraint )
74127439 return text
74137440
7414- def visit_foreign_key_constraint (self , constraint , ** kw ):
7441+ def define_foreign_key_body (
7442+ self , constraint : ForeignKeyConstraint , ** kw : Any
7443+ ) -> str :
74157444 preparer = self .preparer
7416- text = ""
7417- if constraint .name is not None :
7418- formatted_name = self .preparer .format_constraint (constraint )
7419- if formatted_name is not None :
7420- text += "CONSTRAINT %s " % formatted_name
74217445 remote_table = list (constraint .elements )[0 ].column .table
7422- text + = "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
7446+ text = "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
74237447 ", " .join (
74247448 preparer .quote (f .parent .name ) for f in constraint .elements
74257449 ),
@@ -7430,31 +7454,21 @@ def visit_foreign_key_constraint(self, constraint, **kw):
74307454 preparer .quote (f .column .name ) for f in constraint .elements
74317455 ),
74327456 )
7433- text += self .define_constraint_match (constraint )
7434- text += self .define_constraint_cascades (constraint )
7435- text += self .define_constraint_deferrability (constraint )
74367457 return text
74377458
7438- def define_constraint_remote_table (self , constraint , table , preparer ):
7439- """Format the remote table clause of a CREATE CONSTRAINT clause."""
7440-
7441- return preparer .format_table (table )
7442-
7443- def visit_unique_constraint (
7459+ def define_unique_body (
74447460 self , constraint : UniqueConstraint , ** kw : Any
74457461 ) -> str :
7446- if len (constraint ) == 0 :
7447- return ""
7448- text = ""
7449- if constraint .name is not None :
7450- formatted_name = self .preparer .format_constraint (constraint )
7451- if formatted_name is not None :
7452- text += "CONSTRAINT %s " % formatted_name
7453- text += "UNIQUE %s(%s)" % (
7462+ text = "UNIQUE %s(%s)" % (
74547463 self .define_unique_constraint_distinct (constraint , ** kw ),
74557464 ", " .join (self .preparer .quote (c .name ) for c in constraint ),
74567465 )
7457- text += self .define_constraint_deferrability (constraint )
7466+ return text
7467+
7468+ def define_check_body (self , constraint : CheckConstraint , ** kw : Any ) -> str :
7469+ text = "CHECK (%s)" % self .sql_compiler .process (
7470+ constraint .sqltext , include_table = False , literal_binds = True
7471+ )
74587472 return text
74597473
74607474 def define_unique_constraint_distinct (
@@ -7500,7 +7514,7 @@ def define_constraint_deferrability(self, constraint: Constraint) -> str:
75007514 )
75017515 return text
75027516
7503- def define_constraint_match (self , constraint ) :
7517+ def define_constraint_match (self , constraint : ForeignKeyConstraint ) -> str :
75047518 text = ""
75057519 if constraint .match is not None :
75067520 text += " MATCH %s" % constraint .match
0 commit comments