Skip to content

Commit 787a538

Browse files
rivudhk and mihaibudiu
authored and committed
py: add remaining illegal argument tests
Signed-off-by: rivudhk <rivudhkr@gmail.com>
1 parent 9496f67 commit 787a538

2 files changed

Lines changed: 189 additions & 31 deletions

File tree

python/tests/runtime_aggtest/aggtst_base.py

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,29 @@ def __init__(self):
195195
self.tables = []
196196
self.views = []
197197

198+
@staticmethod
def get_object_type_name(obj: SqlObject) -> str:
    """Get the type name of a SQL object (Table or View)."""
    # Only View is checked explicitly; anything else is reported as a Table.
    if isinstance(obj, View):
        return "View"
    return "Table"
202+
203+
@staticmethod
def has_expected_error(obj: SqlObject) -> bool:
    """Check whether a table or view carries a non-empty ``expected_error`` attribute."""
    marker = getattr(obj, "expected_error", None)
    if marker is None:
        return False
    # Whitespace-only markers do not count as an expected error.
    return str(marker).strip() != ""
208+
209+
@staticmethod
def filter_by_expected_error(
    objects: list[SqlObject], should_fail: bool
) -> list[SqlObject]:
    """Filter tables and views by whether they are expected to fail or not.

    Returns the objects whose ``has_expected_error`` status equals
    ``should_fail`` (True -> expected failures, False -> expected passes).
    """
    # has_expected_error returns a bool, so a direct equality test with the
    # should_fail flag selects the requested partition in one pass.
    return [
        candidate
        for candidate in objects
        if TstAccumulator.has_expected_error(candidate) == should_fail
    ]
220+
198221
def add_table(self, table: Table):
199222
"""Add a new table to the program"""
200223
if DEBUG:
@@ -222,8 +245,17 @@ def generate_sql(tables: list[Table], views: list[View]) -> str:
222245
print("Generated sql\n" + result)
223246
return result
224247

225-
def run_pipeline(self, pipeline_name_prefix: str, sql: str, views: list[View]):
248+
def run_pipeline(
249+
self,
250+
pipeline_name_prefix: str,
251+
sql: str,
252+
views: list[View],
253+
tables: list[Table] = None,
254+
):
226255
"""Run pipeline with the given SQL, load tables, validate views, and shutdown"""
256+
if tables is None:
257+
tables = self.tables
258+
227259
pipeline = None
228260
sql_id = sql_hash(sql)
229261
pipeline_name = unique_pipeline_name(f"{pipeline_name_prefix}_{sql_id}")
@@ -246,7 +278,7 @@ def run_pipeline(self, pipeline_name_prefix: str, sql: str, views: list[View]):
246278

247279
pipeline.start()
248280

249-
for table in self.tables:
281+
for table in tables:
250282
if table.get_data() != []:
251283
pipeline.input_json(
252284
table.name, table.get_data(), update_format="insert_delete"
@@ -289,77 +321,131 @@ def run_pipeline(self, pipeline_name_prefix: str, sql: str, views: list[View]):
289321
pipeline.stop(force=True)
290322
pipeline.delete(True)
291323

292-
def assert_expected_error(self, obj: SqlObject, actual_exception: Exception):
    """Validate the error produced by a failing pipeline against the object's expectation.

    The object's ``expected_error`` text must appear (case-insensitively) as a
    substring of the exception message; otherwise an AssertionError is raised.
    """
    expected_substring = (
        str(getattr(obj, "expected_error", "") or "").strip().lower()
    )
    actual_message = str(actual_exception).strip().lower()

    obj_type = self.get_object_type_name(obj)

    if DEBUG:
        print(
            f"[DEBUG] {obj_type} `{obj.name}` expected error substring: '{expected_substring}'"
        )
        print(
            f"[DEBUG] {obj_type} `{obj.name}` received error message:\n{actual_message}"
        )

    # Validate based on: does the error received contain the expected substring?
    if expected_substring not in actual_message:
        raise AssertionError(
            f"\n[FAIL] failed {obj_type.lower()}: {obj.name} did not produce expected error substring.\n"
            f"Expected to find: '{expected_substring}'\n"
            f"Received error message:\n{actual_message}"
        )

    if DEBUG:
        print(f"[PASS] {obj_type} `{obj.name}` failed as expected.")
350+
351+
def run_failing_object_test(
    self,
    obj: SqlObject,
    pipeline_name_prefix: str,
    sql: str,
    views: list[View],
    tables: list[Table],
):
    """Run a test for a single object (view or table) expected to fail.

    The pipeline run must raise; the raised error is then checked against the
    object's ``expected_error``. A clean run is itself a test failure.
    """
    obj_type = self.get_object_type_name(obj)
    if DEBUG:
        print(f"Testing failing {obj_type.lower()}: {obj.name}...")

    try:
        self.run_pipeline(pipeline_name_prefix, sql, views=views, tables=tables)
        raise AssertionError(
            f"{obj_type}: `{obj.name}` was expected to fail, but it passed."
        )
    except AssertionError:
        # Re-raise assertion errors about unexpected success unchanged.
        raise
    except Exception as e:
        self.assert_expected_error(obj, e)
373+
374+
def run_table_tests(self, pipeline_name_prefix: str):
    """Test passing tables together in one pipeline and failing tables individually."""
    # Partition the registered tables by their expected_error attribute.
    failing_tables = self.filter_by_expected_error(self.tables, should_fail=True)
    passing_tables = self.filter_by_expected_error(self.tables, should_fail=False)

    # All passing tables share a single pipeline run.
    if passing_tables:
        if DEBUG:
            print(f"Testing {len(passing_tables)} passing tables together...")
        # Contains SQL for all passing tables; no views are involved here.
        sql = TstAccumulator.generate_sql(passing_tables, [])
        self.run_pipeline(
            pipeline_name_prefix, sql, views=[], tables=passing_tables
        )

    # Each failing table gets its own pipeline so its error is isolated.
    for bad_table in failing_tables:
        sql = bad_table.get_sql()  # SQL for just this failing table
        self.run_failing_object_test(
            bad_table, pipeline_name_prefix, sql, views=[], tables=[bad_table]
        )
316397

317398
def run_expected_failures(self, pipeline_name_prefix: str):
    """Loop through each view that is expected to fail, one pipeline per view."""
    # Views are always tested against the passing tables only, so a broken
    # table definition cannot mask a view's own expected error.
    passing_tables = self.filter_by_expected_error(self.tables, should_fail=False)
    failing_views = self.filter_by_expected_error(self.views, should_fail=True)

    for view in failing_views:
        # SQL for this failing view plus the (passing) tables it may read.
        sql = TstAccumulator.generate_sql(passing_tables, [view])
        self.run_failing_object_test(
            view, pipeline_name_prefix, sql, views=[view], tables=passing_tables
        )
336411

337412
def run_expected_successes(self, pipeline_name_prefix: str):
    """Run all views that are expected to pass in a single pipeline."""
    # Only passing tables participate when validating passing views.
    passing_tables = self.filter_by_expected_error(self.tables, should_fail=False)
    passing_views = self.filter_by_expected_error(self.views, should_fail=False)

    if not passing_views:
        return

    # One SQL program covering every passing view and its tables.
    sql = TstAccumulator.generate_sql(passing_tables, passing_views)
    self.run_pipeline(
        pipeline_name_prefix, sql, views=passing_views, tables=passing_tables
    )
347426

348427
def run_tests(self, pipeline_name_prefix: str):
    """Run all tests registered."""
    # Tables first: passing tables together, failing tables individually.
    self.run_table_tests(pipeline_name_prefix)
    # Then views: failing views individually, passing views together.
    self.run_expected_failures(pipeline_name_prefix)
    self.run_expected_successes(pipeline_name_prefix)
352434

353435

354436
class TstTable:
    """Base class for defining tables."""

    # Default: the table is expected to compile cleanly. Subclasses that
    # define an intentionally-broken table override this with the error
    # substring the compiler is expected to emit.
    expected_error = None

    def __init__(self):
        self.sql = ""
        self.data = []

    def register(self, ta: TstAccumulator):
        """Wrap this definition in a runtime Table and add it to the accumulator."""
        table = Table(self.sql, self.data)
        # Propagate the expected-failure marker onto the runtime object so the
        # accumulator can partition tables into passing/failing sets.
        table.expected_error = getattr(self, "expected_error", None)
        ta.add_table(table)
363449

364450

365451
class TstView:

python/tests/runtime_aggtest/illarg_tests/test_grammar_tbl_fn.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,76 @@
1-
from tests.runtime_aggtest.aggtst_base import TstView
1+
from tests.runtime_aggtest.aggtst_base import TstTable, TstView
2+
3+
4+
# CONNECTOR_METADATA
class illarg_connector_metadata_legal_tbl(TstTable):
    """Define the table used by connector_metadata passing test."""

    def __init__(self):
        # CONNECTOR_METADATA() with zero arguments is the legal form; the
        # datagen connector named "check" supplies two generated rows.
        self.sql = """CREATE TABLE connector_metadata_legal_tbl(
                      id INT,
                      meta_val VARCHAR
                          DEFAULT CAST(CONNECTOR_METADATA()['check'] AS VARCHAR)
                      ) WITH (
                          'connectors' = '[{
                              "name": "check",
                              "transport": {
                                  "name": "datagen",
                                  "config": {
                                      "plan": [{"limit": 2}]
                                  }
                              }
                          }]'
                      );"""
        self.data = []
25+
26+
27+
class illarg_connector_metadata_legal(TstView):
    """View over the legal connector_metadata table; expects the two datagen rows."""

    def __init__(self):
        self.data = [{"id": 0, "meta_val": "0"}, {"id": 1, "meta_val": "1"}]
        self.sql = """CREATE MATERIALIZED VIEW illarg_connector_metadata_legal AS SELECT
                      *
                      FROM connector_metadata_legal_tbl"""
33+
34+
35+
# Negative Test
class illarg_connector_metadata_illegal_tbl(TstTable):
    """Define the table used by connector_metadata failing test."""

    def __init__(self):
        # CONNECTOR_METADATA('check') passes one argument where none are
        # allowed, so compiling this table must fail with the error below.
        self.sql = """CREATE TABLE connector_metadata_illegal_tbl(
                      id INT,
                      meta_val VARCHAR
                          DEFAULT CAST(CONNECTOR_METADATA('check')['check'] AS VARCHAR)
                      ) WITH (
                          'connectors' = '[{
                              "name": "check",
                              "transport": {
                                  "name": "datagen",
                                  "config": {
                                      "plan": [{"limit": 2}]
                                  }
                              }
                          }]'
                      );"""
        self.data = []
        self.expected_error = "Invalid number of arguments to function 'connector_metadata'. Was expecting 0 arguments"
57+
58+
59+
# DEFAULT
class illarg_default_legal(TstTable):
    """Table with a legal constant DEFAULT expression."""

    def __init__(self):
        self.sql = """CREATE TABLE default_legal (
                      x INT DEFAULT CAST(42 AS INTEGER))"""
        self.data = []
65+
66+
67+
# Negative Test
class illarg_default_illegal(TstTable):
    """Table whose DEFAULT refers to its own column, which must be rejected."""

    def __init__(self):
        # A DEFAULT expression may not reference a column, so compilation is
        # expected to fail with the message below.
        self.sql = """CREATE TABLE default_illegal(
                      intt INT DEFAULT CAST(intt AS INTEGER))"""
        self.data = []
        self.expected_error = "column 'intt' not found in any table"
274

375

476
# AS

0 commit comments

Comments
 (0)