From a967dd9b0799ba430aef743728b7fc0cf4015d50 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 16 Jun 2026 01:18:53 +0000 Subject: [PATCH] refactor(bigframes): Extract json conversions to distinct ops --- .../ibis_compiler/scalar_op_registry.py | 68 ++++++++++--------- .../bigframes/core/compile/polars/compiler.py | 5 ++ .../bigframes/core/compile/polars/lowering.py | 3 - .../sqlglot/expressions/generic_ops.py | 35 ---------- .../compile/sqlglot/expressions/json_ops.py | 35 +++++++++- packages/bigframes/bigframes/dataframe.py | 11 ++- .../bigframes/operations/__init__.py | 1 + .../bigframes/operations/generic_ops.py | 33 --------- .../bigframes/operations/json_ops.py | 2 + packages/bigframes/bigframes/series.py | 16 ++++- .../system/small/engines/test_generic_ops.py | 16 ++--- 11 files changed, 102 insertions(+), 123 deletions(-) diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 5172d1e7c602..9386e6f25228 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -922,35 +922,6 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp): elif to_type == ibis_dtypes.time: return x_converted.time() - if to_type == ibis_dtypes.json: - if x.type() == ibis_dtypes.string: - return parse_json_in_safe(x) if op.safe else parse_json(x) - if x.type() == ibis_dtypes.bool: - x_bool = typing.cast( - ibis_types.StringValue, - bigframes.core.compile.ibis_types.cast_ibis_value( - x, ibis_dtypes.string, safe=op.safe - ), - ).lower() - return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool) - if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64): - x_str = bigframes.core.compile.ibis_types.cast_ibis_value( - x, ibis_dtypes.string, safe=op.safe - ) - return parse_json_in_safe(x_str) if op.safe else parse_json(x_str) - - if x.type() == ibis_dtypes.json: - if to_type == ibis_dtypes.int64: - return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x) - if to_type == ibis_dtypes.float64: - return ( - cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x) - ) - if to_type == ibis_dtypes.bool: - return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x) - if to_type == ibis_dtypes.string: - return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x) - # TODO: either inline this function, or push rest of this op into the function return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe) @@ -1193,9 +1164,42 @@ def parse_json_op_impl(x: ibis_types.Value, op: ops.ParseJSON): return parse_json(json_str=x) -@scalar_op_compiler.register_unary_op(ops.ToJSON) -def to_json_op_impl(json_obj: ibis_types.Value): - return to_json(json_obj=json_obj) +@scalar_op_compiler.register_unary_op(ops.ToJSON, pass_op=True) +def to_json_op_impl(x: ibis_types.Value, op: ops.ToJSON): + if x.type() == ibis_dtypes.string: + return parse_json_in_safe(x) if op.safe else parse_json(x) + if x.type() == ibis_dtypes.bool: + x_bool = typing.cast( + ibis_types.StringValue, + bigframes.core.compile.ibis_types.cast_ibis_value( + x, ibis_dtypes.string, safe=op.safe + ), + ).lower() + return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool) + if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64): + x_str = bigframes.core.compile.ibis_types.cast_ibis_value( + x, ibis_dtypes.string, safe=op.safe + ) + return parse_json_in_safe(x_str) if op.safe else parse_json(x_str) + raise TypeError(f"Cannot cast to JSON from type {x.type()}") + + +@scalar_op_compiler.register_unary_op(ops.JSONDecode, pass_op=True) +def json_decode_op_impl(x: ibis_types.Value, op: ops.JSONDecode): + to_type = bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype( + op.to_type + ) + if to_type == ibis_dtypes.int64: + return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x) + if to_type == ibis_dtypes.float64: + return ( + cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x) + ) + if to_type == ibis_dtypes.bool: + return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x) + if to_type == ibis_dtypes.string: + return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x) + raise TypeError(f"Cannot cast from JSON to type {to_type}") @scalar_op_compiler.register_unary_op(ops.ToJSONString) diff --git a/packages/bigframes/bigframes/core/compile/polars/compiler.py b/packages/bigframes/bigframes/core/compile/polars/compiler.py index 6f24929eeb4e..b93b1403c56a 100644 --- a/packages/bigframes/bigframes/core/compile/polars/compiler.py +++ b/packages/bigframes/bigframes/core/compile/polars/compiler.py @@ -482,6 +482,11 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: assert isinstance(op, json_ops.JSONDecode) return input.str.json_decode(_DTYPE_MAPPING[op.to_type]) + @compile_op.register(json_ops.ToJSON) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + # Polars represents JSON as string, so to_json is cast to String + return input.cast(pl.String()) + @compile_op.register(arr_ops.ToArrayOp) def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr: return pl.concat_list(*inputs) diff --git a/packages/bigframes/bigframes/core/compile/polars/lowering.py b/packages/bigframes/bigframes/core/compile/polars/lowering.py index 7416ebc963b4..a56c44a49071 100644 --- a/packages/bigframes/bigframes/core/compile/polars/lowering.py +++ b/packages/bigframes/bigframes/core/compile/polars/lowering.py @@ -412,9 +412,6 @@ def _coerce_comparables( def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): if arg.output_type == cast_op.to_type: return arg - - if arg.output_type == dtypes.JSON_DTYPE: - return json_ops.JSONDecode(cast_op.to_type).as_expr(arg) if ( arg.output_type == dtypes.STRING_DTYPE and cast_op.to_type == dtypes.DATETIME_DTYPE diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 22dcd8bf51ac..2cc27cb8e5a2 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -36,12 +36,6 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: sg_to_type = sqlglot_types.from_bigframes_dtype(to_type) sg_expr = expr.expr - if to_type == dtypes.JSON_DTYPE: - return _cast_to_json(expr, op) - - if from_type == dtypes.JSON_DTYPE: - return _cast_from_json(expr, op) - if to_type == dtypes.INT_DTYPE: result = _cast_to_int(expr, op) if result is not None: @@ -251,35 +245,6 @@ def _(*values: TypedExpr) -> sge.Expression: # Helper functions -def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: - from_type = expr.dtype - sg_expr = expr.expr - - if from_type == dtypes.STRING_DTYPE: - func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON" - return sge.func(func_name, sg_expr) - if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE): - sg_expr = sge.Cast(this=sg_expr, to="STRING") - return sge.func("PARSE_JSON", sg_expr) - raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}") - - -def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression: - to_type = op.to_type - sg_expr = expr.expr - func_name = "" - if to_type == dtypes.INT_DTYPE: - func_name = "INT64" - elif to_type == dtypes.FLOAT_DTYPE: - func_name = "FLOAT64" - elif to_type == dtypes.BOOL_DTYPE: - func_name = "BOOL" - elif to_type == dtypes.STRING_DTYPE: - func_name = "STRING" - if func_name: - func_name = "SAFE." + func_name if op.safe else func_name - return sge.func(func_name, sg_expr) - raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}") def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None: diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/json_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/json_ops.py index f27b1f138d70..6ef0940306b9 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/json_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/json_ops.py @@ -17,6 +17,7 @@ import bigframes_vendored.sqlglot.expressions as sge import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler +from bigframes import dtypes from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr @@ -69,9 +70,37 @@ def _(expr: TypedExpr) -> sge.Expression: return sge.func("PARSE_JSON", expr.expr) -@register_unary_op(ops.ToJSON) -def _(expr: TypedExpr) -> sge.Expression: - return sge.func("TO_JSON", expr.expr) +@register_unary_op(ops.ToJSON, pass_op=True) +def _(expr: TypedExpr, op: ops.ToJSON) -> sge.Expression: + from_type = expr.dtype + sg_expr = expr.expr + + if from_type == dtypes.STRING_DTYPE: + func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON" + return sge.func(func_name, sg_expr) + if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE): + sg_expr = sge.Cast(this=sg_expr, to="STRING") + return sge.func("PARSE_JSON", sg_expr) + raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}") + + +@register_unary_op(ops.JSONDecode, pass_op=True) +def _(expr: TypedExpr, op: ops.JSONDecode) -> sge.Expression: + to_type = op.to_type + sg_expr = expr.expr + func_name = "" + if to_type == dtypes.INT_DTYPE: + func_name = "INT64" + elif to_type == dtypes.FLOAT_DTYPE: + func_name = "FLOAT64" + elif to_type == dtypes.BOOL_DTYPE: + func_name = "BOOL" + elif to_type == dtypes.STRING_DTYPE: + func_name = "STRING" + if func_name: + func_name = "SAFE." + func_name if op.safe else func_name + return sge.func(func_name, sg_expr) + raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}") @register_unary_op(ops.ToJSONString) diff --git a/packages/bigframes/bigframes/dataframe.py b/packages/bigframes/bigframes/dataframe.py index 6b7922fe9753..e071151baabe 100644 --- a/packages/bigframes/bigframes/dataframe.py +++ b/packages/bigframes/bigframes/dataframe.py @@ -442,17 +442,16 @@ def astype( if errors not in ["raise", "null"]: raise ValueError("Arg 'error' must be one of 'raise' or 'null'") - safe_cast = errors == "null" - if isinstance(dtype, dict): result = self.copy() for col, to_type in dtype.items(): - result[col] = result[col].astype(to_type) + result[col] = result[col].astype(to_type, errors=errors) return result - dtype = bigframes.dtypes.bigframes_type(dtype) - - return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast)) + result = self.copy() + for col in result.columns: + result[col] = result[col].astype(dtype, errors=errors) + return result def _should_sql_have_index(self) -> bool: """Should the SQL we pass to BQML and other I/O include the index?""" diff --git a/packages/bigframes/bigframes/operations/__init__.py b/packages/bigframes/bigframes/operations/__init__.py index b8d860029a0f..d6a6193c7579 100644 --- a/packages/bigframes/bigframes/operations/__init__.py +++ b/packages/bigframes/bigframes/operations/__init__.py @@ -128,6 +128,7 @@ ) from bigframes.operations.googlesql import GoogleSqlScalarOp from bigframes.operations.json_ops import ( + JSONDecode, JSONExtract, JSONExtractArray, JSONExtractStringArray, diff --git a/packages/bigframes/bigframes/operations/generic_ops.py b/packages/bigframes/bigframes/operations/generic_ops.py index 9a58f4b8ef33..99cda5fc095f 100644 --- a/packages/bigframes/bigframes/operations/generic_ops.py +++ b/packages/bigframes/bigframes/operations/generic_ops.py @@ -93,10 +93,6 @@ dtypes.STRING_DTYPE, dtypes.INT_DTYPE, ), - ( - dtypes.JSON_DTYPE, - dtypes.INT_DTYPE, - ), # Float casts ( dtypes.BOOL_DTYPE, @@ -118,10 +114,6 @@ dtypes.STRING_DTYPE, dtypes.FLOAT_DTYPE, ), - ( - dtypes.JSON_DTYPE, - dtypes.FLOAT_DTYPE, - ), # Bool casts ( dtypes.INT_DTYPE, @@ -131,10 +123,6 @@ dtypes.FLOAT_DTYPE, dtypes.BOOL_DTYPE, ), - ( - dtypes.JSON_DTYPE, - dtypes.BOOL_DTYPE, - ), # String casts ( dtypes.BYTES_DTYPE, @@ -168,10 +156,6 @@ dtypes.DATE_DTYPE, dtypes.STRING_DTYPE, ), - ( - dtypes.JSON_DTYPE, - dtypes.STRING_DTYPE, - ), # bytes casts ( dtypes.STRING_DTYPE, @@ -276,23 +260,6 @@ dtypes.INT_DTYPE, dtypes.TIMEDELTA_DTYPE, ), - # json casts - ( - dtypes.BOOL_DTYPE, - dtypes.JSON_DTYPE, - ), - ( - dtypes.FLOAT_DTYPE, - dtypes.JSON_DTYPE, - ), - ( - dtypes.STRING_DTYPE, - dtypes.JSON_DTYPE, - ), - ( - dtypes.INT_DTYPE, - dtypes.JSON_DTYPE, - ), ) ) diff --git a/packages/bigframes/bigframes/operations/json_ops.py b/packages/bigframes/bigframes/operations/json_ops.py index 7260a7922305..4aaaa43eac9b 100644 --- a/packages/bigframes/bigframes/operations/json_ops.py +++ b/packages/bigframes/bigframes/operations/json_ops.py @@ -105,6 +105,7 @@ def output_type(self, *input_types): @dataclasses.dataclass(frozen=True) class ToJSON(base_ops.UnaryOp): name: typing.ClassVar[str] = "to_json" + safe: bool = True def output_type(self, *input_types): input_type = input_types[0] @@ -220,6 +221,7 @@ def output_type(self, *input_types): class JSONDecode(base_ops.UnaryOp): name: typing.ClassVar[str] = "json_decode" to_type: dtypes.Dtype + safe: bool = True def output_type(self, *input_types): input_type = input_types[0] diff --git a/packages/bigframes/bigframes/series.py b/packages/bigframes/bigframes/series.py index 181bc4f63b2f..99268350566a 100644 --- a/packages/bigframes/bigframes/series.py +++ b/packages/bigframes/bigframes/series.py @@ -646,9 +646,19 @@ def astype( if errors not in ["raise", "null"]: raise ValueError("Argument 'errors' must be one of 'raise' or 'null'") dtype = bigframes.dtypes.bigframes_type(dtype) - return self._apply_unary_op( - bigframes.operations.AsTypeOp(to_type=dtype, safe=(errors == "null")) - ) + safe = errors == "null" + if dtype == bigframes.dtypes.JSON_DTYPE: + return self._apply_unary_op( + bigframes.operations.ToJSON(safe=safe) + ) + elif self.dtype == bigframes.dtypes.JSON_DTYPE: + return self._apply_unary_op( + bigframes.operations.JSONDecode(to_type=dtype, safe=safe) + ) + else: + return self._apply_unary_op( + bigframes.operations.AsTypeOp(to_type=dtype, safe=safe) + ) def to_pandas( self, diff --git a/packages/bigframes/tests/system/small/engines/test_generic_ops.py b/packages/bigframes/tests/system/small/engines/test_generic_ops.py index 22ad1bfefa4e..6237caf14bad 100644 --- a/packages/bigframes/tests/system/small/engines/test_generic_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_generic_ops.py @@ -263,16 +263,16 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ - ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr( + ops.JSONDecode(to_type=bigframes.dtypes.INT_DTYPE).as_expr( expression.const("5", bigframes.dtypes.JSON_DTYPE) ), - ops.AsTypeOp(to_type=bigframes.dtypes.FLOAT_DTYPE).as_expr( + ops.JSONDecode(to_type=bigframes.dtypes.FLOAT_DTYPE).as_expr( expression.const("5", bigframes.dtypes.JSON_DTYPE) ), - ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE).as_expr( + ops.JSONDecode(to_type=bigframes.dtypes.BOOL_DTYPE).as_expr( expression.const("true", bigframes.dtypes.JSON_DTYPE) ), - ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE).as_expr( + ops.JSONDecode(to_type=bigframes.dtypes.STRING_DTYPE).as_expr( expression.const('"hello world"', bigframes.dtypes.JSON_DTYPE) ), ] @@ -284,17 +284,17 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ - ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr( + ops.ToJSON().as_expr( expression.deref("int64_col") ), - ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr( + ops.ToJSON().as_expr( # Use a const since float to json has precision issues expression.const(5.2, bigframes.dtypes.FLOAT_DTYPE) ), - ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr( + ops.ToJSON().as_expr( expression.deref("bool_col") ), - ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr( + ops.ToJSON().as_expr( # Use a const since "str_col" has special chars. expression.const('"hello world"', bigframes.dtypes.STRING_DTYPE) ),