@@ -195,6 +195,29 @@ def __init__(self):
195195 self .tables = []
196196 self .views = []
197197
198+ @staticmethod
199+ def get_object_type_name (obj : SqlObject ) -> str :
200+ """Get the type name of a SQL object (Table or View)"""
201+ return "View" if isinstance (obj , View ) else "Table"
202+
203+ @staticmethod
204+ def has_expected_error (obj : SqlObject ) -> bool :
205+ """Check if a table or a view has a non-empty expected_error attribute"""
206+ expected_error = getattr (obj , "expected_error" , None )
207+ return expected_error is not None and str (expected_error ).strip () != ""
208+
209+ @staticmethod
210+ def filter_by_expected_error (
211+ objects : list [SqlObject ], should_fail : bool
212+ ) -> list [SqlObject ]:
213+ """Filter tables and views based on whether they are expected to fail or not"""
214+ if should_fail :
215+ return [obj for obj in objects if TstAccumulator .has_expected_error (obj )]
216+ else :
217+ return [
218+ obj for obj in objects if not TstAccumulator .has_expected_error (obj )
219+ ]
220+
198221 def add_table (self , table : Table ):
199222 """Add a new table to the program"""
200223 if DEBUG :
@@ -222,8 +245,17 @@ def generate_sql(tables: list[Table], views: list[View]) -> str:
222245 print ("Generated sql\n " + result )
223246 return result
224247
225- def run_pipeline (self , pipeline_name_prefix : str , sql : str , views : list [View ]):
248+ def run_pipeline (
249+ self ,
250+ pipeline_name_prefix : str ,
251+ sql : str ,
252+ views : list [View ],
253+ tables : list [Table ] = None ,
254+ ):
226255 """Run pipeline with the given SQL, load tables, validate views, and shutdown"""
256+ if tables is None :
257+ tables = self .tables
258+
227259 pipeline = None
228260 sql_id = sql_hash (sql )
229261 pipeline_name = unique_pipeline_name (f"{ pipeline_name_prefix } _{ sql_id } " )
@@ -246,7 +278,7 @@ def run_pipeline(self, pipeline_name_prefix: str, sql: str, views: list[View]):
246278
247279 pipeline .start ()
248280
249- for table in self . tables :
281+ for table in tables :
250282 if table .get_data () != []:
251283 pipeline .input_json (
252284 table .name , table .get_data (), update_format = "insert_delete"
@@ -289,77 +321,131 @@ def run_pipeline(self, pipeline_name_prefix: str, sql: str, views: list[View]):
289321 pipeline .stop (force = True )
290322 pipeline .delete (True )
291323
292- def assert_expected_error (self , view : View , actual_exception : Exception ):
324+ def assert_expected_error (self , obj : SqlObject , actual_exception : Exception ):
293325 """Validate the error produced by the failing pipeline with the expected error type"""
294326 expected_substring = (
295- str (getattr (view , "expected_error" , "" ) or "" ).strip ().lower ()
327+ str (getattr (obj , "expected_error" , "" ) or "" ).strip ().lower ()
296328 )
297329 actual_message = str (actual_exception ).strip ().lower ()
298330
331+ obj_type = self .get_object_type_name (obj )
332+
299333 if DEBUG :
300334 print (
301- f"[DEBUG] View `{ view .name } ` expected error substring: '{ expected_substring } '"
335+ f"[DEBUG] { obj_type } `{ obj .name } ` expected error substring: '{ expected_substring } '"
302336 )
303337 print (
304- f"[DEBUG] View `{ view .name } ` received error message:\n { actual_message } "
338+ f"[DEBUG] { obj_type } `{ obj .name } ` received error message:\n { actual_message } "
305339 )
306340
307341 if expected_substring not in actual_message :
308342 raise AssertionError (
309- f"\n [FAIL] failed view : { view .name } did not produce expected error substring.\n "
343+ f"\n [FAIL] failed { obj_type . lower () } : { obj .name } did not produce expected error substring.\n "
310344 f"Expected to find: '{ expected_substring } '\n "
311345 f"Received error message:\n { actual_message } "
312- ) # Validate based on: does the error received contain the expected substring?
346+ )
313347
314348 if DEBUG :
315- print (f"[PASS] View `{ view .name } ` failed as expected." )
349+ print (f"[PASS] { obj_type } `{ obj .name } ` failed as expected." )
350+
351+ def run_failing_object_test (
352+ self ,
353+ obj : SqlObject ,
354+ pipeline_name_prefix : str ,
355+ sql : str ,
356+ views : list [View ],
357+ tables : list [Table ],
358+ ):
359+ """Run a test for a single object(view, table) expected to fail and verify it produces the expected error"""
360+ obj_type = self .get_object_type_name (obj )
361+ if DEBUG :
362+ print (f"Testing failing { obj_type .lower ()} : { obj .name } ..." )
363+
364+ try :
365+ self .run_pipeline (pipeline_name_prefix , sql , views = views , tables = tables )
366+ raise AssertionError (
367+ f"{ obj_type } : `{ obj .name } ` was expected to fail, but it passed."
368+ )
369+ except AssertionError :
370+ raise
371+ except Exception as e :
372+ self .assert_expected_error (obj , e )
373+
374+ def run_table_tests (self , pipeline_name_prefix : str ):
375+ """Test passing tables together in a single pipeline, failing tables separately in individual pipelines"""
376+ # Separate tables by whether they have a non-empty expected_error attribute
377+ failing_tables = self .filter_by_expected_error (self .tables , should_fail = True )
378+ passing_tables = self .filter_by_expected_error (self .tables , should_fail = False )
379+
380+ # Test all passing tables together
381+ if passing_tables :
382+ if DEBUG :
383+ print (f"Testing { len (passing_tables )} passing tables together..." )
384+ sql = TstAccumulator .generate_sql (
385+ passing_tables , []
386+ ) # Contains SQL for all passing tables
387+ self .run_pipeline (
388+ pipeline_name_prefix , sql , views = [], tables = passing_tables
389+ )
390+
391+ # Test each failing table individually
392+ for table in failing_tables :
393+ sql = table .get_sql () # Contains SQL for the failing tables
394+ self .run_failing_object_test (
395+ table , pipeline_name_prefix , sql , views = [], tables = [table ]
396+ )
316397
317398 def run_expected_failures (self , pipeline_name_prefix : str ):
318- """Run each view that is expected to fail in a separate pipeline"""
319- # List of views that contain the attribute: expected error type
320- failing_views = [v for v in self .views if v .expected_error ]
399+ """Loop through each view that is expected to fail in a separate pipeline"""
400+ # Only use passing tables when testing views
401+ passing_tables = self .filter_by_expected_error (self .tables , should_fail = False )
402+ failing_views = self .filter_by_expected_error (self .views , should_fail = True )
403+
321404 for view in failing_views :
322- if DEBUG :
323- print (f"Running failing view: { view .name } ..." )
324405 sql = TstAccumulator .generate_sql (
325- self .tables , [view ]
326- ) # Contains SQL for the failing view and its related tables only
327- try :
328- self .run_pipeline (pipeline_name_prefix , sql , views = [view ])
329- raise AssertionError (
330- f"View: `{ view .name } ` was expected to fail, but it passed."
331- )
332- except AssertionError :
333- raise # Re-raise assertion errors about unexpected success
334- except Exception as e :
335- self .assert_expected_error (view , e )
406+ passing_tables , [view ]
407+ ) # Contains SQL for the failing views and tables
408+ self .run_failing_object_test (
409+ view , pipeline_name_prefix , sql , views = [view ], tables = passing_tables
410+ )
336411
337412 def run_expected_successes (self , pipeline_name_prefix : str ):
338413 """Run all views that are expected to pass in a single pipeline"""
339- # List of views that don't contain the attribute: expected error type
340- passing_views = [v for v in self .views if not v .expected_error ]
414+ # Use only passing tables when testing views
415+ passing_tables = self .filter_by_expected_error (self .tables , should_fail = False )
416+ passing_views = self .filter_by_expected_error (self .views , should_fail = False )
417+
341418 if not passing_views :
342419 return
343420 sql = TstAccumulator .generate_sql (
344- self . tables , passing_views
421+ passing_tables , passing_views
345422 ) # Contains SQL for all passing views and their related tables
346- self .run_pipeline (pipeline_name_prefix , sql , views = passing_views )
423+ self .run_pipeline (
424+ pipeline_name_prefix , sql , views = passing_views , tables = passing_tables
425+ )
347426
348427 def run_tests (self , pipeline_name_prefix : str ):
349428 """Run all tests registered"""
429+ # Test tables (passing tables together, failing tables individually)
430+ self .run_table_tests (pipeline_name_prefix )
431+ # Test views (failing views individually, passing views together)
350432 self .run_expected_failures (pipeline_name_prefix )
351433 self .run_expected_successes (pipeline_name_prefix )
352434
353435
354436class TstTable :
355437 """Base class for defining tables"""
356438
439+ expected_error = None
440+
357441 def __init__ (self ):
358442 self .sql = ""
359443 self .data = []
360444
361445 def register (self , ta : TstAccumulator ):
362- ta .add_table (Table (self .sql , self .data ))
446+ table = Table (self .sql , self .data )
447+ table .expected_error = getattr (self , "expected_error" , None )
448+ ta .add_table (table )
363449
364450
365451class TstView :
0 commit comments