Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 182 additions & 5 deletions packages/bigframes/bigframes/core/pyformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,160 @@ def _parse_fields(sql_template: str) -> list[str]:
]


def _is_escaped_open_brace(sql_template: str, idx: int, literal_char: str) -> bool:
"""Checks if the character at idx in sql_template is an escaped open brace '{{'."""
return sql_template[idx : idx + 2] == "{{" and literal_char == "{"


def _is_escaped_close_brace(sql_template: str, idx: int, literal_char: str) -> bool:
"""Checks if the character at idx in sql_template is an escaped close brace '}}'."""
return sql_template[idx : idx + 2] == "}}" and literal_char == "}"


def _consume_literal(sql_template: str, current_idx: int, literal_text: str) -> int:
    """Walks past one parsed literal segment inside the raw sql_template.

    string.Formatter hands back literal text with escaped braces already
    collapsed ('{{' -> '{', '}}' -> '}'), so the literal can be shorter than
    the span it occupies in the raw template. This maps each literal character
    back onto the template: an escaped brace consumes two template characters,
    any other character consumes exactly one.

    Args:
        sql_template: The original, unparsed template text.
        current_idx: Offset in sql_template where the literal begins.
        literal_text: The literal segment as reported by the parser.

    Returns:
        int: the offset in sql_template just past the literal.

    Raises:
        RuntimeError: if the template text does not line up with literal_text.
    """
    cursor = current_idx
    for char in literal_text:
        came_from_escape = _is_escaped_open_brace(
            sql_template, cursor, char
        ) or _is_escaped_close_brace(sql_template, cursor, char)
        if came_from_escape:
            # A doubled brace in the template collapsed to this one character.
            cursor += 2
            continue

        if cursor < len(sql_template) and sql_template[cursor] == char:
            cursor += 1
            continue

        raise RuntimeError(
            "Internal error: failed to align parsed SQL template with original query. "
            f"Expected {char!r} at position {cursor} in template, "
            f"but found {sql_template[cursor : cursor + 2]!r}."
        )
    return cursor


def _is_escaped_brace(sql_template: str, idx: int) -> bool:
"""Checks if the template has an escaped brace ('{{' or '}}') at the given index."""
return sql_template[idx : idx + 2] in ("{{", "}}")


def _advance_past_field(sql_template: str, current_idx: int) -> int:
"""Advances current_idx past the format field starting at current_idx.

A **field** (or replacement field) is a placeholder in the template enclosed
in braces (e.g., "{my_var}" or "{json_col: { "val": 1 } }").

This function assumes current_idx points to the opening '{' of a field.
It parses forward, tracking nested braces to find the matching closing '}'
that terminates the field, while ignoring escaped braces ('{{' and '}}')
which do not affect the nesting level.

Returns:
int: the index immediately after the closing '}' of the field.
"""
assert sql_template[current_idx] == "{"
brace_count = 1
current_idx += 1 # past '{'

while brace_count > 0 and current_idx < len(sql_template):
if _is_escaped_brace(sql_template, current_idx):
current_idx += 2
elif sql_template[current_idx] == "{":
brace_count += 1
current_idx += 1
elif sql_template[current_idx] == "}":
brace_count -= 1
current_idx += 1
else:
current_idx += 1

return current_idx


def _find_all_field_positions(sql_template: str) -> dict[tuple[str, int], int]:
    """Locates the opening brace of every replacement field in sql_template.

    The template is parsed with string.Formatter while a cursor is kept in
    step with the raw text, so each field can be mapped back to the offset of
    its opening '{'. Repeated field names are told apart by a zero-based
    occurrence counter.

    Returns:
        dict: a dict mapping (field_name, occurrence_idx) to character index.
    """
    cursor = 0
    occurrences: dict[str, int] = {}
    field_positions: dict[tuple[str, int], int] = {}

    for literal, name, _spec, _conversion in string.Formatter().parse(sql_template):
        cursor = _consume_literal(sql_template, cursor, literal)
        if name is None:
            # Trailing literal with no field after it.
            continue

        nth = occurrences.get(name, 0)
        occurrences[name] = nth + 1
        field_positions[(name, nth)] = cursor

        cursor = _advance_past_field(sql_template, cursor)

    return field_positions


def get_error_context_at_pos(sql_template: str, pos: int) -> str:
    """Create a helpful 'pointer' to where the problematic position is
    in the original SQL.

    The result shows the offending line with up to two lines of context on
    either side, each prefixed with its 1-based line number, and a '^' on its
    own line directly beneath the character at pos.

    Args:
        sql_template: The original SQL template text.
        pos: Zero-based character offset into sql_template. Negative values
            (including the -1 "not found" sentinel) mean there is no position.

    Returns:
        str: the annotated excerpt, or "" if pos does not fall inside
        sql_template.
    """
    if pos < 0:
        return ""

    lines = sql_template.splitlines(keepends=True)

    line_start = 0
    target_line_idx = -1
    for i, line in enumerate(lines):
        if line_start <= pos < line_start + len(line):
            target_line_idx = i
            break
        line_start += len(line)

    if target_line_idx == -1:
        return ""

    col_offset = pos - line_start
    first_line = max(0, target_line_idx - 2)
    last_line = min(len(lines), target_line_idx + 3)

    context_lines = []
    for i in range(first_line, last_line):
        prefix = f"{i + 1:4d}: "
        content = lines[i].rstrip("\r\n")
        context_lines.append(prefix + content)

        if i == target_line_idx:
            # Copy tabs from the source line so the caret stays aligned
            # however wide the viewer renders a tab.
            padding = "".join(
                "\t" if char == "\t" else " " for char in content[:col_offset]
            )
            # pos may point at the line terminator, which was stripped above.
            padding = padding.ljust(col_offset)
            # Indent by the real prefix width: line numbers of five or more
            # digits make the prefix wider than the usual six columns.
            context_lines.append(" " * len(prefix) + padding + "^")

    return "\n".join(context_lines)


def pyformat(
sql_template: str,
*,
Expand All @@ -185,13 +339,36 @@ def pyformat(

Raises:
TypeError: if a referenced variable is not of a supported type.
KeyError: if a referenced variable is not found.
ValueError:
if a referenced variable is not found (KeyError is caught and raised
as ValueError with context).
"""
fields = _parse_fields(sql_template)

format_kwargs = {}
try:
fields = _parse_fields(sql_template)
except ValueError as e:
raise ValueError(
"Failed to parse SQL template. "
"Did you mean to escape '{' and '}' by doubling them?\n"
f"Error details: {e}"
) from e

format_kwargs: dict[str, str] = {}
seen_counts: dict[str, int] = {}
for name in fields:
value = pyformat_args[name]
seen_counts[name] = seen_counts.get(name, 0) + 1
try:
value = pyformat_args[name]
except KeyError as e:
positions = _find_all_field_positions(sql_template)
occurrence_idx = seen_counts[name] - 1
pos = positions.get((name, occurrence_idx), -1)
context = get_error_context_at_pos(sql_template, pos)
raise ValueError(
f"Undetected variable {name!r} in SQL template. "
"Did you mean to escape '{' and '}' by doubling them?\n"
f"{context}"
) from e
Comment on lines +365 to +370
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If get_error_context_at_pos returns an empty string (e.g., when the position is -1 or not found), the error message will end with a trailing newline, which looks untidy. We should only append the context if it is non-empty.

            context = get_error_context_at_pos(sql_template, pos)
            msg = (
                f"Undetected variable {name!r} in SQL template. "
                "Did you mean to escape '{' and '}' by doubling them?"
            )
            if context:
                msg += f"\n{context}"
            raise ValueError(msg) from e

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we don't expect this case to happen, I would prefer not to add logic for it.


format_kwargs[name] = _field_to_template_value(
name, value, session=session, dry_run=dry_run
)
Expand Down
124 changes: 122 additions & 2 deletions packages/bigframes/tests/unit/core/test_pyformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,72 @@ def test_parse_fields(sql_template: str, expected: List[str]):
assert fields == expected


def test_get_error_context_at_pos_invalid_pos():
    """Positions outside the template (-1 sentinel, past the end) yield no context."""
    for invalid_pos in (-1, 100):
        assert pyformat.get_error_context_at_pos("SELECT 1", invalid_pos) == ""


def test_get_error_context_at_pos_single_line():
    """A single-line template is echoed with a caret under the given position."""
    sql = "SELECT {foo}"
    # pos of '{' is 7
    context = pyformat.get_error_context_at_pos(sql, 7)
    # The line number is right-aligned to width 4 and followed by ": ", so
    # the caret sits 6 prefix columns + 7 content columns from the left.
    expected = "   1: SELECT {foo}\n" + " " * 13 + "^"
    assert context == expected


def test_get_error_context_at_pos_multi_line():
    """The target line is shown with its neighbours and a caret beneath it."""
    sql = "SELECT 1\nFROM my_table\nWHERE col = {foo}\nAND active = True\nLIMIT 10"
    # Lines:
    # 1: SELECT 1 (len 9 including \n)
    # 2: FROM my_table (len 14 including \n) -> total 23
    # 3: WHERE col = {foo} -> '{' is at 23 + 12 = 35

    context = pyformat.get_error_context_at_pos(sql, 35)
    # Caret indent: 6 prefix columns + 12 columns into line 3.
    expected = "\n".join(
        [
            "   1: SELECT 1",
            "   2: FROM my_table",
            "   3: WHERE col = {foo}",
            " " * 18 + "^",
            "   4: AND active = True",
            "   5: LIMIT 10",
        ]
    )
    assert context == expected


def test_get_error_context_at_pos_multi_line_limits():
    """Only two lines before and two lines after the target line are shown."""
    # Test that it only shows at most 2 lines before and 2 lines after
    sql = (
        "LINE 1\n"
        "LINE 2\n"
        "LINE 3\n"
        "LINE 4\n"
        "LINE 5\n"
        "TARGET {foo}\n"
        "LINE 7\n"
        "LINE 8\n"
        "LINE 9\n"
        "LINE 10"
    )
    # Line lengths:
    # LINE 1\n (7)
    # LINE 2\n (7) -> 14
    # LINE 3\n (7) -> 21
    # LINE 4\n (7) -> 28
    # LINE 5\n (7) -> 35
    # TARGET {foo}\n -> '{' is at 35 + 7 = 42

    context = pyformat.get_error_context_at_pos(sql, 42)
    # Caret indent: 6 prefix columns + 7 columns into the target line.
    expected = "\n".join(
        [
            "   4: LINE 4",
            "   5: LINE 5",
            "   6: TARGET {foo}",
            " " * 13 + "^",
            "   7: LINE 7",
            "   8: LINE 8",
        ]
    )
    assert context == expected


def test_pyformat_with_unsupported_type_raises_typeerror(session):
pyformat_args = {"my_object": object()}
sql = "SELECT {my_object}"
Expand All @@ -70,13 +136,67 @@ def test_pyformat_with_unsupported_type_raises_typeerror(session):
pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session)


def test_pyformat_with_missing_variable_raises_valueerror(session):
    """A field with no matching argument raises ValueError with a pointer excerpt."""
    pyformat_args: Dict[str, Any] = {}
    sql = "SELECT {my_object}"

    # pyformat catches the KeyError from the argument lookup and re-raises it
    # as ValueError, so ValueError is the only exception to expect here.
    with pytest.raises(ValueError) as exc_info:
        pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session)

    err_msg = str(exc_info.value)
    assert "Undetected variable 'my_object' in SQL template" in err_msg
    assert "Did you mean to escape '{' and '}'" in err_msg
    assert "   1: SELECT {my_object}" in err_msg
    # Caret indent: 6 prefix columns + 7 columns up to the '{'.
    assert "\n" + " " * 13 + "^" in err_msg


def test_pyformat_with_unescaped_braces_raises_valueerror_with_context(session):
    """Unescaped JSON braces are reported as an unknown variable, with context."""
    pyformat_args = {"active": True}
    sql = """SELECT * FROM my_table
WHERE json_col = { "generation_config": { "temperature": 0.9 } }
AND active = {active}
"""

    with pytest.raises(ValueError) as exc_info:
        pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session)

    # The triple quote string starts with SELECT immediately, so lines are:
    # 1: SELECT * FROM my_table
    # 2: WHERE json_col = { "generation_config": { "temperature": 0.9 } }
    # 3: AND active = {active}
    expected_fragments = (
        "Undetected variable ' \"generation_config\"' in SQL template",
        "Did you mean to escape '{' and '}'",
        " 1: SELECT * FROM my_table",
        ' 2: WHERE json_col = { "generation_config": { "temperature": 0.9 } }',
        " ^",
        " 3: AND active = {active}",
    )
    err_msg = str(exc_info.value)
    for fragment in expected_fragments:
        assert fragment in err_msg


def test_pyformat_with_malformed_template_raises_valueerror(session):
    """Unbalanced braces surface as ValueError carrying the parser's own detail."""
    pyformat_args: Dict[str, Any] = {}

    malformed_cases = (
        # Case 1: Single '{' (unmatched)
        ("SELECT {foo", "expected '}' before end of string"),
        # Case 2: Single '}' (unmatched)
        ("SELECT foo}", "Single '}' encountered in format string"),
    )

    for sql, parser_detail in malformed_cases:
        with pytest.raises(ValueError) as exc_info:
            pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session)

        err_msg = str(exc_info.value)
        assert "Failed to parse SQL template" in err_msg
        assert "Did you mean to escape '{' and '}'" in err_msg
        assert parser_detail in err_msg


def test_pyformat_with_no_variables(session):
pyformat_args: Dict[str, Any] = {}
Expand Down
Loading