Skip to content

Commit 2ce2ecb

Browse files
committed
Add support for UPSERT
Closes #57
1 parent 4c9be69 commit 2ce2ecb

File tree

3 files changed

+276
-9
lines changed

3 files changed

+276
-9
lines changed

CHANGELOG

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Support UPSERT
12
* Remove default escape char on LIKE and ILIKE
23
* Add GROUPING SETS, CUBE, and ROLLUP
34

sql/__init__.py

Lines changed: 181 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
__version__ = '1.4.4'
1111
__all__ = [
12-
'Flavor', 'Table', 'Values', 'Literal', 'Column', 'Grouping', 'Rollup',
13-
'Cube', 'Join', 'Asc', 'Desc', 'NullsFirst', 'NullsLast', 'format2numeric']
12+
'Flavor', 'Table', 'Values', 'Literal', 'Column', 'Grouping', 'Conflict',
13+
'Rollup', 'Cube', 'Excluded', 'Join', 'Asc', 'Desc', 'NullsFirst',
14+
'NullsLast', 'format2numeric']
1415

1516

1617
def _escape_identifier(name):
@@ -664,17 +665,20 @@ def params(self):
664665

665666

666667
class Insert(WithQuery):
667-
__slots__ = ('_table', '_columns', '_values', '_returning')
668+
__slots__ = ('_table', '_columns', '_values', '_on_conflict', '_returning')
668669

669-
def __init__(self, table, columns=None, values=None, returning=None,
670-
**kwargs):
670+
def __init__(
671+
self, table, columns=None, values=None, returning=None,
672+
on_conflict=None, **kwargs):
671673
self._table = None
672674
self._columns = None
673675
self._values = None
676+
self._on_conflict = None
674677
self._returning = None
675678
self.table = table
676679
self.columns = columns
677680
self.values = values
681+
self.on_conflict = on_conflict
678682
self.returning = returning
679683
super(Insert, self).__init__(**kwargs)
680684

@@ -710,6 +714,17 @@ def values(self, value):
710714
value = Values(value)
711715
self._values = value
712716

717+
@property
718+
def on_conflict(self):
719+
return self._on_conflict
720+
721+
@on_conflict.setter
722+
def on_conflict(self, value):
723+
if value is not None:
724+
assert isinstance(value, Conflict)
725+
assert value.table == self.table
726+
self._on_conflict = value
727+
713728
@property
714729
def returning(self):
715730
return self._returning
@@ -744,26 +759,166 @@ def __str__(self):
744759
# TODO manage DEFAULT
745760
elif self.values is None:
746761
values = ' DEFAULT VALUES'
762+
on_conflict = ''
763+
if self.on_conflict:
764+
on_conflict = ' %s' % self.on_conflict
747765
returning = ''
748766
if self.returning:
749767
returning = ' RETURNING ' + ', '.join(
750768
map(self._format, self.returning))
751769
return (self._with_str()
752770
+ 'INSERT INTO %s AS "%s"' % (self.table, self.table.alias)
753-
+ columns + values + returning)
771+
+ columns + values + on_conflict + returning)
754772

755773
@property
756774
def params(self):
757775
p = []
758776
p.extend(self._with_params())
759777
if isinstance(self.values, Query):
760778
p.extend(self.values.params)
779+
if self.on_conflict:
780+
p.extend(self.on_conflict.params)
761781
if self.returning:
762782
for exp in self.returning:
763783
p.extend(exp.params)
764784
return tuple(p)
765785

766786

787+
class Conflict(object):
788+
__slots__ = (
789+
'_table', '_indexed_columns', '_index_where', '_columns', '_values',
790+
'_where')
791+
792+
def __init__(
793+
self, table, indexed_columns=None, index_where=None,
794+
columns=None, values=None, where=None):
795+
self._table = None
796+
self._indexed_columns = None
797+
self._index_where = None
798+
self._columns = None
799+
self._values = None
800+
self._where = None
801+
self.table = table
802+
self.indexed_columns = indexed_columns
803+
self.index_where = index_where
804+
self.columns = columns
805+
self.values = values
806+
self.where = where
807+
808+
@property
809+
def table(self):
810+
return self._table
811+
812+
@table.setter
813+
def table(self, value):
814+
assert isinstance(value, Table)
815+
self._table = value
816+
817+
@property
818+
def indexed_columns(self):
819+
return self._indexed_columns
820+
821+
@indexed_columns.setter
822+
def indexed_columns(self, value):
823+
if value is not None:
824+
assert all(isinstance(col, Column) for col in value)
825+
assert all(col.table == self.table for col in value)
826+
self._indexed_columns = value
827+
828+
@property
829+
def index_where(self):
830+
return self._index_where
831+
832+
@index_where.setter
833+
def index_where(self, value):
834+
from sql.operators import And, Or
835+
if value is not None:
836+
assert isinstance(value, (Expression, And, Or))
837+
self._index_where = value
838+
839+
@property
840+
def columns(self):
841+
return self._columns
842+
843+
@columns.setter
844+
def columns(self, value):
845+
if value is not None:
846+
assert all(isinstance(col, Column) for col in value)
847+
assert all(col.table == self.table for col in value)
848+
self._columns = value
849+
850+
@property
851+
def values(self):
852+
return self._values
853+
854+
@values.setter
855+
def values(self, value):
856+
if value is not None:
857+
assert isinstance(value, (list, Select))
858+
if isinstance(value, list):
859+
value = Values([value])
860+
self._values = value
861+
862+
@property
863+
def where(self):
864+
return self._where
865+
866+
@where.setter
867+
def where(self, value):
868+
from sql.operators import And, Or
869+
if value is not None:
870+
assert isinstance(value, (Expression, And, Or))
871+
self._where = value
872+
873+
def __str__(self):
874+
indexed_columns = ''
875+
if self.indexed_columns:
876+
assert all(c.table == self.table for c in self.indexed_columns)
877+
# Get columns without alias
878+
indexed_columns = ', '.join(
879+
c.column_name for c in self.indexed_columns)
880+
indexed_columns = ' (' + indexed_columns + ')'
881+
if self.index_where:
882+
indexed_columns += ' WHERE ' + str(self.index_where)
883+
else:
884+
assert not self.index_where
885+
do = ''
886+
if not self.columns:
887+
assert not self.values
888+
assert not self.where
889+
do = 'NOTHING'
890+
else:
891+
assert all(c.table == self.table for c in self.columns)
892+
# Get columns without alias
893+
do = ', '.join(c.column_name for c in self.columns)
894+
# TODO manage DEFAULT
895+
values = str(self.values)
896+
if values.startswith('VALUES'):
897+
values = values[len('VALUES'):]
898+
else:
899+
values = ' (' + values + ')'
900+
if len(self.columns) == 1:
901+
# PostgreSQL would require ROW expression
902+
# with single column with parenthesis
903+
do = 'UPDATE SET ' + do + ' =' + values
904+
else:
905+
do = 'UPDATE SET (' + do + ') =' + values
906+
if self.where:
907+
do += ' WHERE %s' % self.where
908+
return 'ON CONFLICT' + indexed_columns + ' DO ' + do
909+
910+
@property
911+
def params(self):
912+
p = []
913+
if self.index_where:
914+
p.extend(self.index_where.params)
915+
if self.values:
916+
p.extend(self.values.params)
917+
if self.where:
918+
p.extend(self.where.params)
919+
return p
920+
921+
767922
class Update(Insert):
768923
__slots__ = ('_where', '_values', 'from_')
769924

@@ -990,9 +1145,11 @@ def __str__(self):
9901145
def params(self):
9911146
return ()
9921147

993-
def insert(self, columns=None, values=None, returning=None, with_=None):
1148+
def insert(
1149+
self, columns=None, values=None, returning=None, with_=None,
1150+
on_conflict=None):
9941151
return Insert(self, columns=columns, values=values,
995-
returning=returning, with_=with_)
1152+
on_conflict=on_conflict, returning=returning, with_=with_)
9961153

9971154
def update(self, columns, values, from_=None, where=None, returning=None,
9981155
with_=None):
@@ -1005,6 +1162,22 @@ def delete(self, only=False, using=None, where=None, returning=None,
10051162
returning=returning, with_=with_)
10061163

10071164

1165+
class _Excluded(Table):
1166+
def __init__(self):
1167+
super().__init__('EXCLUDED')
1168+
1169+
@property
1170+
def alias(self):
1171+
return 'EXCLUDED'
1172+
1173+
@property
1174+
def has_alias(self):
1175+
return False
1176+
1177+
1178+
Excluded = _Excluded()
1179+
1180+
10081181
class Join(FromItem):
10091182
__slots__ = ('_left', '_right', '_condition', '_type_')
10101183

sql/tests/test_insert.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# this repository contains the full copyright notices and license terms.
33
import unittest
44

5-
from sql import Table, With
5+
from sql import Conflict, Excluded, Table, With
66
from sql.functions import Abs
77

88

@@ -103,3 +103,96 @@ def test_schema(self):
103103
self.assertEqual(str(query),
104104
'INSERT INTO "default"."t1" AS "a" ("c1") VALUES (%s)')
105105
self.assertEqual(tuple(query.params), ('foo',))
106+
107+
def test_upsert_nothing(self):
108+
query = self.table.insert(
109+
[self.table.c1], [['foo']],
110+
on_conflict=Conflict(self.table))
111+
112+
self.assertEqual(str(query),
113+
'INSERT INTO "t" AS "a" ("c1") VALUES (%s) '
114+
'ON CONFLICT DO NOTHING')
115+
self.assertEqual(tuple(query.params), ('foo',))
116+
117+
def test_upsert_indexed_column(self):
118+
query = self.table.insert(
119+
[self.table.c1], [['foo']],
120+
on_conflict=Conflict(
121+
self.table,
122+
indexed_columns=[self.table.c1, self.table.c2]))
123+
124+
self.assertEqual(str(query),
125+
'INSERT INTO "t" AS "a" ("c1") VALUES (%s) '
126+
'ON CONFLICT ("c1", "c2") DO NOTHING')
127+
self.assertEqual(tuple(query.params), ('foo',))
128+
129+
def test_upsert_indexed_column_index_where(self):
130+
query = self.table.insert(
131+
[self.table.c1], [['foo']],
132+
on_conflict=Conflict(
133+
self.table,
134+
indexed_columns=[self.table.c1],
135+
index_where=self.table.c2 == 'bar'))
136+
137+
self.assertEqual(str(query),
138+
'INSERT INTO "t" AS "a" ("c1") VALUES (%s) '
139+
'ON CONFLICT ("c1") WHERE ("a"."c2" = %s) DO NOTHING')
140+
self.assertEqual(tuple(query.params), ('foo', 'bar'))
141+
142+
def test_upsert_update(self):
143+
query = self.table.insert(
144+
[self.table.c1], [['baz']],
145+
on_conflict=Conflict(
146+
self.table,
147+
columns=[self.table.c1, self.table.c2],
148+
values=['foo', 'bar']))
149+
150+
self.assertEqual(str(query),
151+
'INSERT INTO "t" AS "a" ("c1") VALUES (%s) '
152+
'ON CONFLICT DO UPDATE SET ("c1", "c2") = (%s, %s)')
153+
self.assertEqual(tuple(query.params), ('baz', 'foo', 'bar'))
154+
155+
def test_upsert_update_where(self):
156+
query = self.table.insert(
157+
[self.table.c1], [['baz']],
158+
on_conflict=Conflict(
159+
self.table,
160+
columns=[self.table.c1],
161+
values=['foo'],
162+
where=self.table.c2 == 'bar'))
163+
164+
self.assertEqual(str(query),
165+
'INSERT INTO "t" AS "a" ("c1") VALUES (%s) '
166+
'ON CONFLICT DO UPDATE SET "c1" = (%s) '
167+
'WHERE ("a"."c2" = %s)')
168+
self.assertEqual(tuple(query.params), ('baz', 'foo', 'bar'))
169+
170+
def test_upsert_update_subquery(self):
171+
t1 = Table('t1')
172+
t2 = Table('t2')
173+
subquery = t2.select(t2.c1, t2.c2)
174+
query = t1.insert(
175+
[t1.c1], [['baz']],
176+
on_conflict=Conflict(
177+
t1,
178+
columns=[t1.c1, t1.c2],
179+
values=subquery))
180+
181+
self.assertEqual(str(query),
182+
'INSERT INTO "t1" AS "b" ("c1") VALUES (%s) '
183+
'ON CONFLICT DO UPDATE SET ("c1", "c2") = '
184+
'(SELECT "a"."c1", "a"."c2" FROM "t2" AS "a")')
185+
self.assertEqual(tuple(query.params), ('baz',))
186+
187+
def test_upsert_update_excluded(self):
188+
query = self.table.insert(
189+
[self.table.c1], [[1]],
190+
on_conflict=Conflict(
191+
self.table,
192+
columns=[self.table.c1],
193+
values=[Excluded.c1 + 2]))
194+
195+
self.assertEqual(str(query),
196+
'INSERT INTO "t" AS "a" ("c1") VALUES (%s) '
197+
'ON CONFLICT DO UPDATE SET "c1" = (("EXCLUDED"."c1" + %s))')
198+
self.assertEqual(tuple(query.params), (1, 2))

0 commit comments

Comments
 (0)