Skip to content

Commit f88ceee

Browse files
committed
Add MERGE
Closes #57
1 parent 2ce2ecb commit f88ceee

File tree

3 files changed

+318
-0
lines changed

3 files changed

+318
-0
lines changed

CHANGELOG

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Add MERGE
12
* Support UPSERT
23
* Remove default escape char on LIKE and ILIKE
34
* Add GROUPING SETS, CUBE, and ROLLUP

sql/__init__.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
__version__ = '1.4.4'
1111
__all__ = [
1212
'Flavor', 'Table', 'Values', 'Literal', 'Column', 'Grouping', 'Conflict',
13+
'Matched', 'MatchedUpdate', 'MatchedDelete',
14+
'NotMatched', 'NotMatchedInsert',
1315
'Rollup', 'Cube', 'Excluded', 'Join', 'Asc', 'Desc', 'NullsFirst',
1416
'NullsLast', 'format2numeric']
1517

@@ -1075,6 +1077,207 @@ def params(self):
10751077
return tuple(p)
10761078

10771079

1080+
class Merge(WithQuery):
1081+
__slots__ = ('_target', '_source', '_condition', '_whens')
1082+
1083+
def __init__(self, target, source, condition, *whens, **kwargs):
1084+
self._target = None
1085+
self._source = None
1086+
self._condition = None
1087+
self._whens = None
1088+
self.target = target
1089+
self.source = source
1090+
self.condition = condition
1091+
self.whens = whens
1092+
super().__init__(**kwargs)
1093+
1094+
@property
1095+
def target(self):
1096+
return self._target
1097+
1098+
@target.setter
1099+
def target(self, value):
1100+
assert isinstance(value, Table)
1101+
self._target = value
1102+
1103+
@property
1104+
def source(self):
1105+
return self._source
1106+
1107+
@source.setter
1108+
def source(self, value):
1109+
assert isinstance(value, (Table, SelectQuery, Values))
1110+
self._source = value
1111+
1112+
@property
1113+
def condition(self):
1114+
return self._condition
1115+
1116+
@condition.setter
1117+
def condition(self, value):
1118+
assert isinstance(value, Expression)
1119+
self._condition = value
1120+
1121+
@property
1122+
def whens(self):
1123+
return self._whens
1124+
1125+
@whens.setter
1126+
def whens(self, value):
1127+
assert all(isinstance(w, Matched) for w in value)
1128+
self._whens = tuple(value)
1129+
1130+
def __str__(self):
1131+
with AliasManager():
1132+
if isinstance(self.source, (Select, Values)):
1133+
source = '(%s)' % self.source
1134+
else:
1135+
source = self.source
1136+
if self.condition:
1137+
condition = 'ON %s' % self.condition
1138+
else:
1139+
condition = ''
1140+
return (self._with_str()
1141+
+ 'MERGE INTO %s AS "%s" ' % (self.target, self.target.alias)
1142+
+ 'USING %s AS "%s" ' % (source, self.source.alias)
1143+
+ condition + ' ' + ' '.join(map(str, self.whens)))
1144+
1145+
@property
1146+
def params(self):
1147+
p = []
1148+
p.extend(self._with_params())
1149+
if isinstance(self.source, (SelectQuery, Values)):
1150+
p.extend(self.source.params)
1151+
if self.condition:
1152+
p.extend(self.condition.params)
1153+
for match in self.whens:
1154+
p.extend(match.params)
1155+
return tuple(p)
1156+
1157+
1158+
class Matched(object):
1159+
__slots__ = ('_condition',)
1160+
_when = 'MATCHED'
1161+
1162+
def __init__(self, condition=None):
1163+
self._condition = None
1164+
self.condition = condition
1165+
1166+
@property
1167+
def condition(self):
1168+
return self._condition
1169+
1170+
@condition.setter
1171+
def condition(self, value):
1172+
if value is not None:
1173+
assert isinstance(value, Expression)
1174+
self._condition = value
1175+
1176+
def _then_str(self):
1177+
return 'DO NOTHING'
1178+
1179+
def __str__(self):
1180+
if self.condition is not None:
1181+
condition = ' AND ' + str(self.condition)
1182+
else:
1183+
condition = ''
1184+
return 'WHEN ' + self._when + condition + ' THEN ' + self._then_str()
1185+
1186+
@property
1187+
def params(self):
1188+
p = []
1189+
if self.condition:
1190+
p.extend(self.condition.params)
1191+
return tuple(p)
1192+
1193+
1194+
class _MatchedValues(Matched):
1195+
__slots__ = ('_columns', '_values')
1196+
1197+
def __init__(self, columns, values, **kwargs):
1198+
self._columns = columns
1199+
self._values = values
1200+
self.columns = columns
1201+
self.values = values
1202+
super().__init__(**kwargs)
1203+
1204+
@property
1205+
def columns(self):
1206+
return self._columns
1207+
1208+
@columns.setter
1209+
def columns(self, value):
1210+
assert all(isinstance(col, Column) for col in value)
1211+
self._columns = value
1212+
1213+
1214+
class MatchedUpdate(_MatchedValues, Matched):
1215+
__slots__ = ()
1216+
1217+
@property
1218+
def values(self):
1219+
return self._values
1220+
1221+
@values.setter
1222+
def values(self, value):
1223+
self._values = value
1224+
1225+
def _then_str(self):
1226+
columns = [c.column_name for c in self.columns]
1227+
return 'UPDATE SET ' + ', '.join(
1228+
'%s = %s' % (c, Update._format(v))
1229+
for c, v in zip(columns, self.values))
1230+
1231+
@property
1232+
def params(self):
1233+
p = list(super().params)
1234+
for value in self.values:
1235+
if isinstance(value, (Expression, Select)):
1236+
p.extend(value.params)
1237+
else:
1238+
p.append(value)
1239+
return tuple(p)
1240+
1241+
1242+
class MatchedDelete(Matched):
1243+
__slots__ = ()
1244+
1245+
def _then_str(self):
1246+
return 'DELETE'
1247+
1248+
1249+
class NotMatched(Matched):
1250+
__slots__ = ()
1251+
_when = 'NOT MATCHED'
1252+
1253+
1254+
class NotMatchedInsert(_MatchedValues, NotMatched):
1255+
__slots__ = ()
1256+
1257+
@property
1258+
def values(self):
1259+
return self._values
1260+
1261+
@values.setter
1262+
def values(self, value):
1263+
self._values = Values([value])
1264+
1265+
def _then_str(self):
1266+
columns = ', '.join(c.column_name for c in self.columns)
1267+
columns = '(' + columns + ')'
1268+
if self.values is None:
1269+
values = ' DEFAULT VALUES '
1270+
else:
1271+
values = ' ' + str(self.values)
1272+
return 'INSERT ' + columns + values
1273+
1274+
@property
1275+
def params(self):
1276+
p = list(super().params)
1277+
p.extend(self.values.params)
1278+
return tuple(p)
1279+
1280+
10781281
class CombiningQuery(FromItem, SelectQuery):
10791282
__slots__ = ('queries', 'all_')
10801283
_operator = ''
@@ -1161,6 +1364,9 @@ def delete(self, only=False, using=None, where=None, returning=None,
11611364
return Delete(self, only=only, using=using, where=where,
11621365
returning=returning, with_=with_)
11631366

1367+
def merge(self, source, condition, *whens, with_=None):
1368+
return Merge(self, source, condition, *whens, with_=with_)
1369+
11641370

11651371
class _Excluded(Table):
11661372
def __init__(self):

sql/tests/test_merge.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# This file is part of python-sql. The COPYRIGHT file at the top level of
2+
# this repository contains the full copyright notices and license terms.
3+
4+
import unittest
5+
6+
from sql import (
7+
Matched, MatchedDelete, MatchedUpdate, NotMatched, NotMatchedInsert, Table,
8+
With)
9+
10+
11+
class TestMerge(unittest.TestCase):
12+
target = Table('t')
13+
source = Table('s')
14+
15+
def test_merge(self):
16+
query = self.target.merge(
17+
self.source, self.target.c1 == self.source.c2, Matched())
18+
self.assertEqual(
19+
str(query),
20+
'MERGE INTO "t" AS "a" USING "s" AS "b" '
21+
'ON ("a"."c1" = "b"."c2") '
22+
'WHEN MATCHED THEN DO NOTHING')
23+
self.assertEqual(query.params, ())
24+
25+
def test_condition(self):
26+
query = self.target.merge(
27+
self.source,
28+
(self.target.c1 == self.source.c2) & (self.target.c3 == 42),
29+
Matched())
30+
self.assertEqual(
31+
str(query),
32+
'MERGE INTO "t" AS "a" USING "s" AS "b" '
33+
'ON (("a"."c1" = "b"."c2") AND ("a"."c3" = %s)) '
34+
'WHEN MATCHED THEN DO NOTHING')
35+
self.assertEqual(query.params, (42,))
36+
37+
def test_matched(self):
38+
query = self.target.merge(
39+
self.source, self.target.c1 == self.source.c2,
40+
Matched((self.source.c3 == 42)
41+
& (self.target.c4 == self.source.c5)))
42+
self.assertEqual(
43+
str(query),
44+
'MERGE INTO "t" AS "a" USING "s" AS "b" '
45+
'ON ("a"."c1" = "b"."c2") '
46+
'WHEN MATCHED '
47+
'AND (("b"."c3" = %s) AND ("a"."c4" = "b"."c5")) '
48+
'THEN DO NOTHING')
49+
self.assertEqual(query.params, (42,))
50+
51+
def test_matched_update(self):
52+
query = self.target.merge(
53+
self.source, self.target.c1 == self.source.c2,
54+
MatchedUpdate([self.target.c1], [self.target.c1 + self.source.c2]))
55+
self.assertEqual(
56+
str(query),
57+
'MERGE INTO "t" AS "a" USING "s" AS "b" '
58+
'ON ("a"."c1" = "b"."c2") '
59+
'WHEN MATCHED THEN '
60+
'UPDATE SET "c1" = ("a"."c1" + "b"."c2")')
61+
self.assertEqual(query.params, ())
62+
63+
def test_matched_delete(self):
64+
query = self.target.merge(
65+
self.source, self.target.c1 == self.source.c2, MatchedDelete())
66+
self.assertEqual(
67+
str(query),
68+
'MERGE INTO "t" AS "a" USING "s" AS "b" '
69+
'ON ("a"."c1" = "b"."c2") '
70+
'WHEN MATCHED THEN DELETE')
71+
self.assertEqual(query.params, ())
72+
73+
def test_not_matched(self):
74+
query = self.target.merge(
75+
self.source, self.target.c1 == self.source.c2, NotMatched())
76+
self.assertEqual(
77+
str(query),
78+
'MERGE INTO "t" AS "a" USING "s" AS "b" '
79+
'ON ("a"."c1" = "b"."c2") '
80+
'WHEN NOT MATCHED THEN DO NOTHING')
81+
self.assertEqual(query.params, ())
82+
83+
def test_not_matched_insert(self):
84+
query = self.target.merge(
85+
self.source, self.target.c1 == self.source.c2,
86+
NotMatchedInsert(
87+
[self.target.c1, self.target.c2],
88+
[self.source.c3, self.source.c4]))
89+
self.assertEqual(
90+
str(query),
91+
'MERGE INTO "t" AS "a" USING "s" AS "b" '
92+
'ON ("a"."c1" = "b"."c2") '
93+
'WHEN NOT MATCHED THEN '
94+
'INSERT ("c1", "c2") VALUES ("b"."c3", "b"."c4")')
95+
self.assertEqual(query.params, ())
96+
97+
def test_with(self):
98+
t1 = Table('t1')
99+
w = With(query=t1.select(where=t1.c2 == 42))
100+
source = w.select()
101+
102+
query = self.target.merge(
103+
source, self.target.c1 == source.c2, Matched(), with_=[w])
104+
self.assertEqual(
105+
str(query),
106+
'WITH "a" AS (SELECT * FROM "t1" AS "d" WHERE ("d"."c2" = %s)) '
107+
'MERGE INTO "t" AS "b" '
108+
'USING (SELECT * FROM "a" AS "a") AS "c" '
109+
'ON ("b"."c1" = "c"."c2") '
110+
'WHEN MATCHED THEN DO NOTHING')
111+
self.assertEqual(query.params, (42,))

0 commit comments

Comments
 (0)