This repository was archived by the owner on May 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 304
Expand file tree
/
Copy pathtest_query.py
More file actions
349 lines (253 loc) · 11.5 KB
/
test_query.py
File metadata and controls
349 lines (253 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
from datetime import datetime
from typing import List, Optional
import unittest
from data_diff.abcs.database_types import FractionalType, TemporalType
from data_diff.databases.base import Database, BaseDialect
from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict
from data_diff.databases.base import Compiler, CompileError
from data_diff.queries.api import outerjoin, cte, when, coalesce
from data_diff.queries.ast_classes import Random
from data_diff.queries.api import code, this, table
def normalize_spaces(s: str):
    """Collapse every run of whitespace in *s* into a single space.

    Leading and trailing whitespace is dropped, because ``str.split()`` with
    no arguments discards empty fields at both ends.
    """
    words = s.split()
    return " ".join(words)
class MockDialect(BaseDialect):
    """Dialect stub that renders fixed, easily recognizable SQL fragments.

    The fragment-producing hooks return plain strings built from their
    arguments; the hashing and normalization hooks raise
    ``NotImplementedError``.
    """

    name = "MockDialect"

    PLACEHOLDER_TABLE = None
    ROUNDS_ON_PREC_LOSS = False

    def quote(self, s: str) -> str:
        """Return the identifier as-is; this dialect applies no quoting."""
        return s

    def concat(self, l: List[str]) -> str:
        """Render the items as ``concat(a, b, ...)``."""
        return "concat({})".format(", ".join(l))

    def to_comparable(self, s: str) -> str:
        """Return the expression unchanged."""
        return s

    def to_string(self, s: str) -> str:
        """Wrap the expression in a cast to ``varchar``."""
        return "cast({} as varchar)".format(s)

    def is_distinct_from(self, a: str, b: str) -> str:
        """Render ``a is distinct from b``."""
        return "{} is distinct from {}".format(a, b)

    def random(self) -> str:
        """Return the SQL text for a random-number call."""
        return "random()"

    def current_timestamp(self) -> str:
        """Return the SQL text for the current timestamp."""
        return "now()"

    def current_database(self) -> str:
        """Return the SQL text naming the current database."""
        return "current_database()"

    def current_schema(self) -> str:
        """Return the SQL text naming the current schema."""
        return "current_schema()"

    def limit_select(
        self,
        select_query: str,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        has_order_by: Optional[bool] = None,
    ) -> str:
        """Wrap *select_query* in an outer select carrying OFFSET/LIMIT clauses.

        A clause is emitted only when its value is truthy, so ``None`` and ``0``
        both omit it. ``has_order_by`` is accepted but not used here.
        """
        clauses = []
        if offset:
            clauses.append(f"OFFSET {offset}")
        if limit:
            clauses.append(f"LIMIT {limit}")
        suffix = " ".join(clauses)
        # The space before the suffix is always present, even when no clause is emitted.
        return f"SELECT * FROM ({select_query}) AS LIMITED_SELECT {suffix}"

    def explain_as_text(self, query: str) -> str:
        """Prefix the query with ``explain``."""
        return "explain " + query

    def timestamp_value(self, t: datetime) -> str:
        """Render *t* as a ``timestamp '...'`` literal using its default string form."""
        return f"timestamp '{t}'"

    def set_timezone_to_utc(self) -> str:
        """Return the statement that switches the session timezone to UTC."""
        return "set timezone 'UTC'"

    def optimizer_hints(self, s: str):
        """Wrap *s* in a ``/*+ ... */`` hint comment, followed by one space."""
        return f"/*+ {s} */ "

    # The hooks below are deliberately left unimplemented in this mock.

    def md5_as_int(self, s: str) -> str:
        raise NotImplementedError

    def md5_as_hex(self, s: str) -> str:
        raise NotImplementedError

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        raise NotImplementedError

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        raise NotImplementedError

    # Placeholder binding; NOTE(review): presumably only needed so the base
    # class's requirement for `parse_type` is satisfied - confirm against BaseDialect.
    parse_type = NotImplemented
class MockDatabase(Database):
    """Database stub whose only functional member is its ``MockDialect`` instance.

    Every other member defined here is bound to the ``NotImplemented``
    placeholder rather than to a working implementation.
    """

    CONNECT_URI_HELP = "mock://"
    CONNECT_URI_PARAMS = []

    dialect = MockDialect()

    # Placeholders, not real implementations.
    # NOTE(review): presumably these exist so the class can be instantiated
    # despite requirements declared on `Database` - confirm against the base class.
    _normalize_table_path = NotImplemented
    _process_table_schema = NotImplemented
    _query = NotImplemented
    close = NotImplemented
    is_autocommit = NotImplemented
    parse_table_name = NotImplemented
    query_table_schema = NotImplemented
    select_table_schema = NotImplemented
class TestQuery(unittest.TestCase):
    """Checks the SQL text produced by ``Compiler`` for queries built with the query API.

    Every test builds its compiler from ``MockDatabase()``. Comparisons go
    through ``unittest`` assertion methods rather than bare ``assert``
    statements, which are skipped when Python runs with ``-O`` and give no
    diff on failure.

    Statement order inside each test matters: the expected strings contain
    generated aliases (``tmp1``, ``tmp2``, ...) whose numbering follows the
    order in which queries are compiled with the same ``Compiler`` object.
    """

    def test_basic(self):
        """Plain select with computed columns, and where-clauses with int and str literals."""
        c = Compiler(MockDatabase())

        t = table("point")
        t2 = t.select(x=this.x + 1, y=t["y"] + this.x)
        self.assertEqual(c.compile(t2), "SELECT (x + 1) AS x, (y + x) AS y FROM point")

        t = table("point").where(this.x == 1, this.y == 2)
        self.assertEqual(c.compile(t), "SELECT * FROM point WHERE (x = 1) AND (y = 2)")

        t = table("person").where(this.name == "Albert")
        self.assertEqual(c.compile(t), "SELECT * FROM person WHERE (name = 'Albert')")

    def test_outerjoin(self):
        """Full outer join on several key columns, with generated table aliases."""
        c = Compiler(MockDatabase())

        a = table("a")
        b = table("b")
        keys = ["x", "y"]

        j = outerjoin(a, b).on(a[k] == b[k] for k in keys)

        self.assertEqual(
            c.compile(j), "SELECT * FROM a tmp1 FULL OUTER JOIN b tmp2 ON (tmp1.x = tmp2.x) AND (tmp1.y = tmp2.y)"
        )

    def test_schema(self):
        """Column lookup honours the schema's case sensitivity on tables, selects and joins."""
        c = Compiler(MockDatabase())
        schema = dict(id="int", comment="varchar")

        # test table
        t = table("a", schema=CaseInsensitiveDict(schema))
        q = t.select(this.Id, t["COMMENT"])
        self.assertEqual(c.compile(q), "SELECT id, comment FROM a")

        t = table("a", schema=CaseSensitiveDict(schema))
        self.assertRaises(KeyError, t.__getitem__, "Id")
        self.assertRaises(KeyError, t.select, this.Id)

        # test select
        q = t.select(this.id)
        self.assertRaises(KeyError, q.__getitem__, "comment")

        # test join
        s = CaseInsensitiveDict({"x": int, "y": int})
        a = table("a", schema=s)
        b = table("b", schema=s)
        keys = ["x", "y"]
        j = outerjoin(a, b).on(a[k] == b[k] for k in keys).select(a["x"], b["y"], xsum=a["x"] + b["x"])
        # These lookups are the check: each selected name must resolve without raising.
        j["x"], j["y"], j["xsum"]
        self.assertRaises(KeyError, j.__getitem__, "ysum")

    def test_commutable_select(self):
        """``select`` followed by ``where`` builds the same query as ``where`` followed by ``select``."""
        t = table("a")
        q1 = t.select("a").where("b")
        q2 = t.where("b").select("a")
        self.assertEqual(q1, q2)

    def test_cte(self):
        """Single, nested and parameterized common table expressions."""
        c = Compiler(MockDatabase())

        t = table("a")

        # single cte
        t2 = cte(t.select(this.x))
        t3 = t2.select(this.x)

        expected = "WITH tmp1 AS (SELECT x FROM a) SELECT x FROM tmp1"
        self.assertEqual(normalize_spaces(c.compile(t3)), expected)

        # nested cte (a fresh compiler is used, and the expected aliases start at tmp1 again)
        c = Compiler(MockDatabase())
        t4 = cte(t3).select(this.x)

        expected = "WITH tmp1 AS (SELECT x FROM a), tmp2 AS (SELECT x FROM tmp1) SELECT x FROM tmp2"
        self.assertEqual(normalize_spaces(c.compile(t4)), expected)

        # parameterized cte
        c = Compiler(MockDatabase())
        t2 = cte(t.select(this.x), params=["y"])
        t3 = t2.select(this.y)

        expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1"
        self.assertEqual(normalize_spaces(c.compile(t3)), expected)

    def test_funcs(self):
        """``Random`` in an order-by with a limit, and ``coalesce`` in a select."""
        c = Compiler(MockDatabase())
        t = table("a")

        q = c.compile(t.order_by(Random()).limit(10))
        self.assertEqual(q, "SELECT * FROM (SELECT * FROM a ORDER BY random()) AS LIMITED_SELECT LIMIT 10")

        q = c.compile(t.select(coalesce(this.a, this.b)))
        self.assertEqual(q, "SELECT COALESCE(a, b) FROM a")

    def test_select_distinct(self):
        """``distinct=True`` on a select, both when selects merge and when they nest."""
        c = Compiler(MockDatabase())
        t = table("a")

        q = c.compile(t.select(this.b, distinct=True))
        self.assertEqual(q, "SELECT DISTINCT b FROM a")

        # selects merge
        q = c.compile(t.where(this.b > 10).select(this.b, distinct=True))
        self.assertEqual(q, "SELECT DISTINCT b FROM a WHERE (b > 10)")

        # selects stay apart
        q = c.compile(t.limit(10).select(this.b, distinct=True))
        self.assertEqual(q, "SELECT DISTINCT b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1")

        q = c.compile(t.select(this.b, distinct=True).select(distinct=False))
        self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2")

    def test_select_with_optimizer_hints(self):
        """``optimizer_hints`` is rendered right after ``SELECT``, for merged and nested selects."""
        c = Compiler(MockDatabase())
        t = table("a")

        q = c.compile(t.select(this.b, optimizer_hints="PARALLEL(a 16)"))
        self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a")

        q = c.compile(t.where(this.b > 10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
        self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)")

        q = c.compile(t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
        self.assertEqual(
            q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM (SELECT * FROM a) AS LIMITED_SELECT LIMIT 10) tmp1"
        )

        q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
        self.assertEqual(
            q, "SELECT /*+ PARALLEL(a 16) */ * FROM (SELECT b, c FROM (SELECT a FROM a) tmp2 GROUP BY 1) tmp3"
        )

    def test_table_ops(self):
        """Set operations between two selects: union, union all, minus and intersect."""
        c = Compiler(MockDatabase())
        a = table("a").select(this.x)
        b = table("b").select(this.y)

        q = c.compile(a.union(b))
        self.assertEqual(q, "SELECT x FROM a UNION SELECT y FROM b")

        q = c.compile(a.union_all(b))
        self.assertEqual(q, "SELECT x FROM a UNION ALL SELECT y FROM b")

        # minus() is expected to compile to EXCEPT with this dialect.
        q = c.compile(a.minus(b))
        self.assertEqual(q, "SELECT x FROM a EXCEPT SELECT y FROM b")

        q = c.compile(a.intersect(b))
        self.assertEqual(q, "SELECT x FROM a INTERSECT SELECT y FROM b")

    def test_ops(self):
        """Binary ``+``, ``like`` and unary minus applied to an aggregate."""
        c = Compiler(MockDatabase())
        t = table("a")

        q = c.compile(t.select(this.b + this.c))
        self.assertEqual(q, "SELECT (b + c) FROM a")

        q = c.compile(t.select(this.b.like(this.c)))
        self.assertEqual(q, "SELECT (b LIKE c) FROM a")

        q = c.compile(t.select(-this.b.sum()))
        self.assertEqual(q, "SELECT (-SUM(b)) FROM a")

    def test_group_by(self):
        """Group-by with aggregates, ``having`` clauses, and interaction with surrounding selects."""
        c = Compiler(MockDatabase())
        t = table("a")

        q = c.compile(t.group_by(this.b).agg(this.c))
        self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1")

        q = c.compile(t.where(this.b > 1).group_by(this.b).agg(this.c))
        self.assertEqual(q, "SELECT b, c FROM a WHERE (b > 1) GROUP BY 1")

        # A group-by without a following agg() is expected to fail to compile.
        self.assertRaises(CompileError, c.compile, t.select(this.b).group_by(this.b))

        q = c.compile(t.select(this.b).group_by(this.b).agg())
        self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp1 GROUP BY 1")

        q = c.compile(t.group_by(this.b, this.c).agg(this.d, this.e))
        self.assertEqual(q, "SELECT b, c, d, e FROM a GROUP BY 1, 2")

        # Having
        q = c.compile(t.group_by(this.b).agg(this.c).having(this.b > 1))
        self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)")

        # having() before agg() is expected to yield the same SQL as agg() before having().
        q = c.compile(t.group_by(this.b).having(this.b > 1).agg(this.c))
        self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)")

        q = c.compile(t.select(this.b).group_by(this.b).agg().having(this.b > 1))
        self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp2 GROUP BY 1 HAVING (b > 1)")

        # Having sum
        q = c.compile(t.group_by(this.b).agg(this.c, this.d).having(this.b.sum() > 1))
        self.assertEqual(q, "SELECT b, c, d FROM a GROUP BY 1 HAVING (SUM(b) > 1)")

        # Select interaction
        q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(this.c + 1))
        self.assertEqual(q, "SELECT (c + 1) FROM (SELECT b, c FROM (SELECT a FROM a) tmp3 GROUP BY 1) tmp4")

    def test_case_when(self):
        """``when().then()`` chains, with and without an ``else_`` branch."""
        c = Compiler(MockDatabase())
        t = table("a")

        q = c.compile(t.select(when(this.b).then(this.c)))
        self.assertEqual(q, "SELECT CASE WHEN b THEN c END FROM a")

        q = c.compile(t.select(when(this.b).then(this.c).else_(this.d)))
        self.assertEqual(q, "SELECT CASE WHEN b THEN c ELSE d END FROM a")

        q = c.compile(
            t.select(
                when(this.type == "text")
                .then(this.text)
                .when(this.type == "number")
                .then(this.number)
                .else_("unknown type")
            )
        )
        self.assertEqual(
            q,
            "SELECT CASE WHEN (type = 'text') THEN text WHEN (type = 'number') THEN number ELSE 'unknown type' END FROM a",
        )

    def test_code(self):
        """Raw ``code`` fragments, both verbatim and with interpolated sub-expressions."""
        c = Compiler(MockDatabase())
        t = table("a")

        q = c.compile(t.select(this.b, code("<x>")).where(code("<y>")))
        self.assertEqual(q, "SELECT b, <x> FROM a WHERE <y>")

        def tablesample(t, size):
            # {t} and {size} are filled in by code() from the keyword arguments.
            return code("{t} TABLESAMPLE BERNOULLI ({size})", t=t, size=size)

        nonzero = table("points").where(this.x > 0, this.y > 0)
        q = c.compile(tablesample(nonzero, 10))
        self.assertEqual(q, "SELECT * FROM points WHERE (x > 0) AND (y > 0) TABLESAMPLE BERNOULLI (10)")