diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/sql/base.py b/packages/bigframes/bigframes/core/compile/sqlglot/sql/base.py
index f77dcbee4d93..7ec660db9c80 100644
--- a/packages/bigframes/bigframes/core/compile/sqlglot/sql/base.py
+++ b/packages/bigframes/bigframes/core/compile/sqlglot/sql/base.py
@@ -49,6 +49,13 @@
 
 def to_sql(expr: sge.Expression) -> str:
     """Generate SQL string from the given expression."""
+
+    def _flatten_null_casts(node: sge.Expression) -> sge.Expression:
+        if isinstance(node, (sge.Cast, sge.TryCast)) and str(node.to).upper() == "NULL":
+            return sge.Null()
+        return node
+
+    expr = expr.transform(_flatten_null_casts)
     return expr.sql(dialect=DIALECT, pretty=PRETTY)
 
 
@@ -121,8 +128,10 @@ def literal(value: typing.Any, dtype: dtypes.Dtype | None = None) -> sge.Express
         return sge.convert(value)
 
 
-def cast(arg: typing.Any, to: str, safe: bool = False) -> sge.Cast | sge.TryCast:
+def cast(arg: typing.Any, to: str | sge.DataType, safe: bool = False) -> sge.Expression:
     """Return a SQL expression that casts the given argument to the specified type."""
+    if str(to).upper() == "NULL":
+        return sge.Null()
     if safe:
         return sge.TryCast(this=arg, to=to)
     else:
diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/sql/test_base.py b/packages/bigframes/tests/unit/core/compile/sqlglot/sql/test_base.py
index 617f3636d403..edbb77b32c62 100644
--- a/packages/bigframes/tests/unit/core/compile/sqlglot/sql/test_base.py
+++ b/packages/bigframes/tests/unit/core/compile/sqlglot/sql/test_base.py
@@ -15,7 +15,9 @@
 import datetime
 import decimal
 import re
+import unittest.mock as mock
 
+import bigframes_vendored.sqlglot.expressions as sge
 import numpy as np
 import pandas as pd
 import pyarrow as pa
@@ -162,8 +164,6 @@ def test_literal_for_list(value: list, expected: str):
 
 
 def test_literal_null_type():
-    import unittest.mock as mock
-
     mock_dtype = mock.Mock()
     with mock.patch(
         "bigframes.core.compile.sqlglot.sql.base.sgt.from_bigframes_dtype",
@@ -171,3 +171,21 @@ def test_literal_null_type():
     ):
         got = sql.to_sql(sql.literal(None, dtype=mock_dtype))
         assert got == "NULL"
+
+
+@pytest.mark.parametrize(
+    ("arg", "safe"),
+    (
+        pytest.param("abc", False, id="string"),
+        pytest.param(None, False, id="none"),
+        pytest.param("abc", True, id="safe_cast"),
+    ),
+)
+def test_cast_to_null_type_returns_flat_null(arg, safe):
+    assert sql.to_sql(sql.cast(arg, "NULL", safe=safe)) == "NULL"
+
+
+def test_nested_cast_to_null_type_is_flattened():
+    nested = sge.Cast(this=sge.Cast(this=sge.Null(), to="NULL"), to="INT64")
+
+    assert sql.to_sql(nested) == "CAST(NULL AS INT64)"