99
1010__version__ = '1.4.4'
1111__all__ = [
12- 'Flavor' , 'Table' , 'Values' , 'Literal' , 'Column' , 'Grouping' , 'Rollup' ,
13- 'Cube' , 'Join' , 'Asc' , 'Desc' , 'NullsFirst' , 'NullsLast' , 'format2numeric' ]
12+ 'Flavor' , 'Table' , 'Values' , 'Literal' , 'Column' , 'Grouping' , 'Conflict' ,
13+ 'Rollup' , 'Cube' , 'Excluded' , 'Join' , 'Asc' , 'Desc' , 'NullsFirst' ,
14+ 'NullsLast' , 'format2numeric' ]
1415
1516
1617def _escape_identifier (name ):
@@ -664,17 +665,20 @@ def params(self):
664665
665666
666667class Insert (WithQuery ):
667- __slots__ = ('_table' , '_columns' , '_values' , '_returning' )
668+ __slots__ = ('_table' , '_columns' , '_values' , '_on_conflict' , ' _returning' )
668669
669- def __init__ (self , table , columns = None , values = None , returning = None ,
670- ** kwargs ):
670+ def __init__ (
671+ self , table , columns = None , values = None , returning = None ,
672+ on_conflict = None , ** kwargs ):
671673 self ._table = None
672674 self ._columns = None
673675 self ._values = None
676+ self ._on_conflict = None
674677 self ._returning = None
675678 self .table = table
676679 self .columns = columns
677680 self .values = values
681+ self .on_conflict = on_conflict
678682 self .returning = returning
679683 super (Insert , self ).__init__ (** kwargs )
680684
@@ -710,6 +714,17 @@ def values(self, value):
710714 value = Values (value )
711715 self ._values = value
712716
717+ @property
718+ def on_conflict (self ):
719+ return self ._on_conflict
720+
721+ @on_conflict .setter
722+ def on_conflict (self , value ):
723+ if value is not None :
724+ assert isinstance (value , Conflict )
725+ assert value .table == self .table
726+ self ._on_conflict = value
727+
713728 @property
714729 def returning (self ):
715730 return self ._returning
@@ -744,26 +759,166 @@ def __str__(self):
744759 # TODO manage DEFAULT
745760 elif self .values is None :
746761 values = ' DEFAULT VALUES'
762+ on_conflict = ''
763+ if self .on_conflict :
764+ on_conflict = ' %s' % self .on_conflict
747765 returning = ''
748766 if self .returning :
749767 returning = ' RETURNING ' + ', ' .join (
750768 map (self ._format , self .returning ))
751769 return (self ._with_str ()
752770 + 'INSERT INTO %s AS "%s"' % (self .table , self .table .alias )
753- + columns + values + returning )
771+ + columns + values + on_conflict + returning )
754772
755773 @property
756774 def params (self ):
757775 p = []
758776 p .extend (self ._with_params ())
759777 if isinstance (self .values , Query ):
760778 p .extend (self .values .params )
779+ if self .on_conflict :
780+ p .extend (self .on_conflict .params )
761781 if self .returning :
762782 for exp in self .returning :
763783 p .extend (exp .params )
764784 return tuple (p )
765785
766786
787+ class Conflict (object ):
788+ __slots__ = (
789+ '_table' , '_indexed_columns' , '_index_where' , '_columns' , '_values' ,
790+ '_where' )
791+
792+ def __init__ (
793+ self , table , indexed_columns = None , index_where = None ,
794+ columns = None , values = None , where = None ):
795+ self ._table = None
796+ self ._indexed_columns = None
797+ self ._index_where = None
798+ self ._columns = None
799+ self ._values = None
800+ self ._where = None
801+ self .table = table
802+ self .indexed_columns = indexed_columns
803+ self .index_where = index_where
804+ self .columns = columns
805+ self .values = values
806+ self .where = where
807+
808+ @property
809+ def table (self ):
810+ return self ._table
811+
812+ @table .setter
813+ def table (self , value ):
814+ assert isinstance (value , Table )
815+ self ._table = value
816+
817+ @property
818+ def indexed_columns (self ):
819+ return self ._indexed_columns
820+
821+ @indexed_columns .setter
822+ def indexed_columns (self , value ):
823+ if value is not None :
824+ assert all (isinstance (col , Column ) for col in value )
825+ assert all (col .table == self .table for col in value )
826+ self ._indexed_columns = value
827+
828+ @property
829+ def index_where (self ):
830+ return self ._index_where
831+
832+ @index_where .setter
833+ def index_where (self , value ):
834+ from sql .operators import And , Or
835+ if value is not None :
836+ assert isinstance (value , (Expression , And , Or ))
837+ self ._index_where = value
838+
839+ @property
840+ def columns (self ):
841+ return self ._columns
842+
843+ @columns .setter
844+ def columns (self , value ):
845+ if value is not None :
846+ assert all (isinstance (col , Column ) for col in value )
847+ assert all (col .table == self .table for col in value )
848+ self ._columns = value
849+
850+ @property
851+ def values (self ):
852+ return self ._values
853+
854+ @values .setter
855+ def values (self , value ):
856+ if value is not None :
857+ assert isinstance (value , (list , Select ))
858+ if isinstance (value , list ):
859+ value = Values ([value ])
860+ self ._values = value
861+
862+ @property
863+ def where (self ):
864+ return self ._where
865+
866+ @where .setter
867+ def where (self , value ):
868+ from sql .operators import And , Or
869+ if value is not None :
870+ assert isinstance (value , (Expression , And , Or ))
871+ self ._where = value
872+
873+ def __str__ (self ):
874+ indexed_columns = ''
875+ if self .indexed_columns :
876+ assert all (c .table == self .table for c in self .indexed_columns )
877+ # Get columns without alias
878+ indexed_columns = ', ' .join (
879+ c .column_name for c in self .indexed_columns )
880+ indexed_columns = ' (' + indexed_columns + ')'
881+ if self .index_where :
882+ indexed_columns += ' WHERE ' + str (self .index_where )
883+ else :
884+ assert not self .index_where
885+ do = ''
886+ if not self .columns :
887+ assert not self .values
888+ assert not self .where
889+ do = 'NOTHING'
890+ else :
891+ assert all (c .table == self .table for c in self .columns )
892+ # Get columns without alias
893+ do = ', ' .join (c .column_name for c in self .columns )
894+ # TODO manage DEFAULT
895+ values = str (self .values )
896+ if values .startswith ('VALUES' ):
897+ values = values [len ('VALUES' ):]
898+ else :
899+ values = ' (' + values + ')'
900+ if len (self .columns ) == 1 :
901+ # PostgreSQL would require ROW expression
902+ # with single column with parenthesis
903+ do = 'UPDATE SET ' + do + ' =' + values
904+ else :
905+ do = 'UPDATE SET (' + do + ') =' + values
906+ if self .where :
907+ do += ' WHERE %s' % self .where
908+ return 'ON CONFLICT' + indexed_columns + ' DO ' + do
909+
910+ @property
911+ def params (self ):
912+ p = []
913+ if self .index_where :
914+ p .extend (self .index_where .params )
915+ if self .values :
916+ p .extend (self .values .params )
917+ if self .where :
918+ p .extend (self .where .params )
919+ return p
920+
921+
767922class Update (Insert ):
768923 __slots__ = ('_where' , '_values' , 'from_' )
769924
@@ -990,9 +1145,11 @@ def __str__(self):
9901145 def params (self ):
9911146 return ()
9921147
993- def insert (self , columns = None , values = None , returning = None , with_ = None ):
1148+ def insert (
1149+ self , columns = None , values = None , returning = None , with_ = None ,
1150+ on_conflict = None ):
9941151 return Insert (self , columns = columns , values = values ,
995- returning = returning , with_ = with_ )
1152+ on_conflict = on_conflict , returning = returning , with_ = with_ )
9961153
9971154 def update (self , columns , values , from_ = None , where = None , returning = None ,
9981155 with_ = None ):
@@ -1005,6 +1162,22 @@ def delete(self, only=False, using=None, where=None, returning=None,
10051162 returning = returning , with_ = with_ )
10061163
10071164
1165+ class _Excluded (Table ):
1166+ def __init__ (self ):
1167+ super ().__init__ ('EXCLUDED' )
1168+
1169+ @property
1170+ def alias (self ):
1171+ return 'EXCLUDED'
1172+
1173+ @property
1174+ def has_alias (self ):
1175+ return False
1176+
1177+
1178+ Excluded = _Excluded ()
1179+
1180+
10081181class Join (FromItem ):
10091182 __slots__ = ('_left' , '_right' , '_condition' , '_type_' )
10101183
0 commit comments