-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsqlvalidator.py
More file actions
129 lines (106 loc) · 5.34 KB
/
Copy pathsqlvalidator.py
File metadata and controls
129 lines (106 loc) · 5.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json

from sqlglot import parse, exp
from sqlglot.errors import ParseError, TokenError
class SQLValidator:
    """
    A class to validate SQL queries against a predefined schema.

    The validation process includes:
    1. Syntax Check: Verifies the SQL parses and holds exactly one statement.
    2. Safety Check: Ensures the query is a read-only SELECT. Write operations
       nested inside the SELECT (e.g. a data-modifying CTE or SELECT ... INTO)
       are rejected too, not only a top-level DELETE/UPDATE/etc.
    3. Schema Check: Confirms all tables and columns exist in the schema.
    """

    UNSAFE_COMMANDS = {"DELETE", "UPDATE", "INSERT", "DROP", "ALTER", "TRUNCATE", "GRANT", "REVOKE"}

    # sqlglot expression class names for nodes that write data or change the
    # schema/permissions. Resolved with getattr at validation time because not
    # every name exists in every sqlglot version (e.g. AlterTable vs. Alter).
    _WRITE_NODE_NAMES = (
        "Insert",
        "Update",
        "Delete",
        "Merge",
        "Create",
        "Drop",
        "Alter",
        "AlterTable",
        "TruncateTable",
        "Grant",
        "Revoke",
        "Command",
        "Into",
    )

    def __init__(self, schema_json: dict):
        """
        Initializes the validator with a database schema.

        Args:
            schema_json (dict): The already-loaded schema. It must contain a
                "tables" list whose entries each have a "table_name" and a
                "columns" mapping keyed by column name.
        """
        # Pre-process the schema into sets for O(1) membership lookups.
        self.tables = {table['table_name'] for table in schema_json['tables']}
        self.columns = {
            table['table_name']: set(table['columns'].keys())
            for table in schema_json['tables']
        }
        print("Validator initialized. Known tables:", self.tables)

    def validate(self, sql_query: str) -> dict:
        """
        Validates a single SQL query through a multi-step process.

        Args:
            sql_query (str): The SQL query string to validate.

        Returns:
            dict: {"is_valid": bool, "errors": list[str]}; "errors" is empty
            when the query is valid.

        Never raises for bad input: any failure is reported as invalid
        (fail closed), so an internal error cannot let a query through.
        """
        errors = []
        try:
            # parse() returns one entry per ';'-separated statement. Empty or
            # comment-only input can come back as None entries, so drop them.
            statements = [stmt for stmt in parse(sql_query) if stmt is not None]
            if not statements:
                errors.append("Empty SQL query.")
                return {"is_valid": False, "errors": errors}
            if len(statements) > 1:
                errors.append("Multiple SQL statements are not allowed.")
                return {"is_valid": False, "errors": errors}
            parsed = statements[0]

            # 1a. Safety: the top-level statement must be a SELECT ...
            if not isinstance(parsed, exp.Select):
                command_name = self._command_label(parsed)
                if command_name in self.UNSAFE_COMMANDS:
                    errors.append(f"Unsafe command '{command_name}' is not allowed.")
                else:
                    errors.append(f"Only SELECT statements are allowed. Found '{command_name}'.")
                return {"is_valid": False, "errors": errors}

            # 1b. ... and nothing nested inside it may write, e.g.
            # "WITH d AS (DELETE ... RETURNING *) SELECT ..." or "SELECT ... INTO t".
            write_node = self._find_write_operation(parsed)
            if write_node is not None:
                errors.append(f"Unsafe command '{self._command_label(write_node)}' is not allowed.")
                return {"is_valid": False, "errors": errors}

            # 2a. Tables. Column checks are meaningless if a table is unknown.
            errors.extend(self._check_tables(parsed))
            if errors:
                return {"is_valid": False, "errors": errors}

            # 2b. Columns.
            errors.extend(self._check_columns(parsed))
        except (ParseError, TokenError) as e:
            errors.append(f"Invalid SQL syntax: {e}")
            return {"is_valid": False, "errors": errors}
        except Exception as e:
            # Fail closed on unexpected internal errors, but do not mislabel
            # them as syntax errors (that hides validator bugs).
            errors.append(f"Validation failed: {e}")
            return {"is_valid": False, "errors": errors}

        if errors:
            return {"is_valid": False, "errors": errors}
        return {"is_valid": True, "errors": []}

    def _command_label(self, node: exp.Expression) -> str:
        """Return an upper-case statement label for *node*.

        When the label starts with an UNSAFE_COMMANDS entry it is normalized to
        that entry. sqlglot class names do not always equal the SQL keyword
        (e.g. TruncateTable, AlterTable), and its generic Command node
        typically keeps the statement keyword in `this` (exposed as `.name`).
        """
        label = type(node).__name__.upper()
        if isinstance(node, exp.Command) and node.name:
            label = node.name.upper()
        return next((cmd for cmd in self.UNSAFE_COMMANDS if label.startswith(cmd)), label)

    def _find_write_operation(self, parsed_query: exp.Expression):
        """Return the first node anywhere in *parsed_query* that writes data or
        changes schema/permissions, or None if the tree is read-only."""
        write_types = tuple(
            node_type
            for node_type in (getattr(exp, name, None) for name in self._WRITE_NODE_NAMES)
            if isinstance(node_type, type)
        )
        return next(parsed_query.find_all(*write_types), None)

    def _check_tables(self, parsed_query: exp.Expression) -> list:
        """Return one error per table reference that is not in the schema."""
        return [
            f"Table '{table.this.name}' does not exist."
            for table in parsed_query.find_all(exp.Table)
            if table.this.name not in self.tables
        ]

    def _check_columns(self, parsed_query: exp.Expression) -> list:
        """Return one error per column reference the schema cannot resolve.

        Qualified columns (u.name) are checked against the table their
        qualifier maps to. Unqualified columns must exist in at least one table
        referenced anywhere in the query.
        """
        errors = []
        table_context = self._get_table_context(parsed_query)
        for column in parsed_query.find_all(exp.Column):
            col_name = column.this.name
            if col_name == '*':  # Ignore '*' wildcard (e.g. 'u.*')
                continue
            table_alias = column.table  # The alias or table name used, e.g., 'u' in 'u.name'
            if table_alias:
                real_table = table_context.get(table_alias)
                if not real_table:
                    # Parsing does not resolve qualifiers, so an unknown or
                    # mistyped alias reaches this point.
                    errors.append(f"Table alias or name '{table_alias}' not found in query context.")
                elif col_name not in self.columns.get(real_table, set()):
                    errors.append(f"Column '{col_name}' does not exist in table '{real_table}'.")
            elif not any(col_name in self.columns.get(name, set()) for name in table_context.values()):
                errors.append(f"Unqualified column '{col_name}' could not be found in any of the query's tables.")
        return errors

    def _get_table_context(self, parsed_query: exp.Expression) -> dict:
        """Helper to create a mapping from table aliases to real table names.

        Every table reference in the statement is included: FROM, JOINs, and
        subqueries in any clause. This lets columns used inside subqueries
        resolve. Tables without an alias map their own name to themselves.
        The traversal is breadth-first, so when the same alias appears at
        several nesting levels the shallowest binding wins.
        """
        context = {}
        for table in parsed_query.find_all(exp.Table):
            table_name = table.this.name
            context.setdefault(table.alias or table_name, table_name)
        return context