Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -922,35 +922,6 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
elif to_type == ibis_dtypes.time:
return x_converted.time()

if to_type == ibis_dtypes.json:
if x.type() == ibis_dtypes.string:
return parse_json_in_safe(x) if op.safe else parse_json(x)
if x.type() == ibis_dtypes.bool:
x_bool = typing.cast(
ibis_types.StringValue,
bigframes.core.compile.ibis_types.cast_ibis_value(
x, ibis_dtypes.string, safe=op.safe
),
).lower()
return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool)
if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64):
x_str = bigframes.core.compile.ibis_types.cast_ibis_value(
x, ibis_dtypes.string, safe=op.safe
)
return parse_json_in_safe(x_str) if op.safe else parse_json(x_str)

if x.type() == ibis_dtypes.json:
if to_type == ibis_dtypes.int64:
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
if to_type == ibis_dtypes.float64:
return (
cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
)
if to_type == ibis_dtypes.bool:
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
if to_type == ibis_dtypes.string:
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)

# TODO: either inline this function, or push rest of this op into the function
return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe)

Expand Down Expand Up @@ -1193,9 +1164,42 @@ def parse_json_op_impl(x: ibis_types.Value, op: ops.ParseJSON):
return parse_json(json_str=x)


@scalar_op_compiler.register_unary_op(ops.ToJSON)
def to_json_op_impl(json_obj: ibis_types.Value):
return to_json(json_obj=json_obj)
@scalar_op_compiler.register_unary_op(ops.ToJSON, pass_op=True)
def to_json_op_impl(x: ibis_types.Value, op: ops.ToJSON):
if x.type() == ibis_dtypes.string:
return parse_json_in_safe(x) if op.safe else parse_json(x)
if x.type() == ibis_dtypes.bool:
x_bool = typing.cast(
ibis_types.StringValue,
bigframes.core.compile.ibis_types.cast_ibis_value(
x, ibis_dtypes.string, safe=op.safe
),
).lower()
return parse_json_in_safe(x_bool) if op.safe else parse_json(x_bool)
if x.type() in (ibis_dtypes.int64, ibis_dtypes.float64):
x_str = bigframes.core.compile.ibis_types.cast_ibis_value(
x, ibis_dtypes.string, safe=op.safe
)
return parse_json_in_safe(x_str) if op.safe else parse_json(x_str)
raise TypeError(f"Cannot cast to JSON from type {x.type()}")


@scalar_op_compiler.register_unary_op(ops.JSONDecode, pass_op=True)
def json_decode_op_impl(x: ibis_types.Value, op: ops.JSONDecode):
to_type = bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(
op.to_type
)
if to_type == ibis_dtypes.int64:
return cast_json_to_int64_in_safe(x) if op.safe else cast_json_to_int64(x)
if to_type == ibis_dtypes.float64:
return (
cast_json_to_float64_in_safe(x) if op.safe else cast_json_to_float64(x)
)
if to_type == ibis_dtypes.bool:
return cast_json_to_bool_in_safe(x) if op.safe else cast_json_to_bool(x)
if to_type == ibis_dtypes.string:
return cast_json_to_string_in_safe(x) if op.safe else cast_json_to_string(x)
raise TypeError(f"Cannot cast from JSON to type {to_type}")


@scalar_op_compiler.register_unary_op(ops.ToJSONString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,11 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
assert isinstance(op, json_ops.JSONDecode)
return input.str.json_decode(_DTYPE_MAPPING[op.to_type])

@compile_op.register(json_ops.ToJSON)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency and better type safety, annotate the op parameter with the specific json_ops.ToJSON type instead of the generic ops.ScalarOp, matching the pattern used in other registered operations (such as ops.ToArrayOp and json_ops.JSONDecode).

Suggested change
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
def _(self, op: json_ops.ToJSON, input: pl.Expr) -> pl.Expr:

# Polars represents JSON as string, so to_json is cast to String
return input.cast(pl.String())

@compile_op.register(arr_ops.ToArrayOp)
def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
return pl.concat_list(*inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,6 @@ def _coerce_comparables(
def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
if arg.output_type == cast_op.to_type:
return arg

if arg.output_type == dtypes.JSON_DTYPE:
return json_ops.JSONDecode(cast_op.to_type).as_expr(arg)
if (
arg.output_type == dtypes.STRING_DTYPE
and cast_op.to_type == dtypes.DATETIME_DTYPE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
sg_to_type = sqlglot_types.from_bigframes_dtype(to_type)
sg_expr = expr.expr

if to_type == dtypes.JSON_DTYPE:
return _cast_to_json(expr, op)

if from_type == dtypes.JSON_DTYPE:
return _cast_from_json(expr, op)

if to_type == dtypes.INT_DTYPE:
result = _cast_to_int(expr, op)
if result is not None:
Expand Down Expand Up @@ -251,35 +245,6 @@ def _(*values: TypedExpr) -> sge.Expression:


# Helper functions
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
from_type = expr.dtype
sg_expr = expr.expr

if from_type == dtypes.STRING_DTYPE:
func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
return sge.func(func_name, sg_expr)
if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
sg_expr = sge.Cast(this=sg_expr, to="STRING")
return sge.func("PARSE_JSON", sg_expr)
raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}")


def _cast_from_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
to_type = op.to_type
sg_expr = expr.expr
func_name = ""
if to_type == dtypes.INT_DTYPE:
func_name = "INT64"
elif to_type == dtypes.FLOAT_DTYPE:
func_name = "FLOAT64"
elif to_type == dtypes.BOOL_DTYPE:
func_name = "BOOL"
elif to_type == dtypes.STRING_DTYPE:
func_name = "STRING"
if func_name:
func_name = "SAFE." + func_name if op.safe else func_name
return sge.func(func_name, sg_expr)
raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")


def _cast_to_int(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import bigframes_vendored.sqlglot.expressions as sge

import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr

Expand Down Expand Up @@ -69,9 +70,37 @@ def _(expr: TypedExpr) -> sge.Expression:
return sge.func("PARSE_JSON", expr.expr)


@register_unary_op(ops.ToJSON)
def _(expr: TypedExpr) -> sge.Expression:
return sge.func("TO_JSON", expr.expr)
@register_unary_op(ops.ToJSON, pass_op=True)
def _(expr: TypedExpr, op: ops.ToJSON) -> sge.Expression:
from_type = expr.dtype
sg_expr = expr.expr

if from_type == dtypes.STRING_DTYPE:
func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
return sge.func(func_name, sg_expr)
if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
sg_expr = sge.Cast(this=sg_expr, to="STRING")
return sge.func("PARSE_JSON", sg_expr)
raise TypeError(f"Cannot cast from {from_type} to {dtypes.JSON_DTYPE}")


@register_unary_op(ops.JSONDecode, pass_op=True)
def _(expr: TypedExpr, op: ops.JSONDecode) -> sge.Expression:
to_type = op.to_type
sg_expr = expr.expr
func_name = ""
if to_type == dtypes.INT_DTYPE:
func_name = "INT64"
elif to_type == dtypes.FLOAT_DTYPE:
func_name = "FLOAT64"
elif to_type == dtypes.BOOL_DTYPE:
func_name = "BOOL"
elif to_type == dtypes.STRING_DTYPE:
func_name = "STRING"
if func_name:
func_name = "SAFE." + func_name if op.safe else func_name
return sge.func(func_name, sg_expr)
raise TypeError(f"Cannot cast from {dtypes.JSON_DTYPE} to {to_type}")


@register_unary_op(ops.ToJSONString)
Expand Down
11 changes: 5 additions & 6 deletions packages/bigframes/bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,17 +442,16 @@ def astype(
if errors not in ["raise", "null"]:
raise ValueError("Arg 'error' must be one of 'raise' or 'null'")

safe_cast = errors == "null"

if isinstance(dtype, dict):
result = self.copy()
for col, to_type in dtype.items():
result[col] = result[col].astype(to_type)
result[col] = result[col].astype(to_type, errors=errors)
return result

dtype = bigframes.dtypes.bigframes_type(dtype)

return self._apply_unary_op(ops.AsTypeOp(dtype, safe_cast))
result = self.copy()
for col in result.columns:
result[col] = result[col].astype(dtype, errors=errors)
return result
Comment on lines +451 to +454

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Looping over columns and performing individual series assignments (result[col] = result[col].astype(...)) introduces a significant performance regression for non-JSON data types. In bigframes, column-by-column assignment creates multiple intermediate DataFrame states and nested expression trees, which can lead to slow compilation times and complex SQL queries.

We can avoid this overhead by fast-pathing the common cases where no JSON columns are involved, or where all columns are being cast to/from JSON, applying the operation to all columns at once using _apply_unary_op.

        safe = errors == "null"
        dtype = bigframes.dtypes.bigframes_type(dtype)

        if dtype == bigframes.dtypes.JSON_DTYPE:
            return self._apply_unary_op(ops.ToJSON(safe=safe))

        if all(t == bigframes.dtypes.JSON_DTYPE for t in self.dtypes):
            return self._apply_unary_op(ops.JSONDecode(to_type=dtype, safe=safe))

        if not any(t == bigframes.dtypes.JSON_DTYPE for t in self.dtypes):
            return self._apply_unary_op(ops.AsTypeOp(to_type=dtype, safe=safe))

        result = self.copy()
        for col in result.columns:
            result[col] = result[col].astype(dtype, errors=errors)
        return result


def _should_sql_have_index(self) -> bool:
"""Should the SQL we pass to BQML and other I/O include the index?"""
Expand Down
1 change: 1 addition & 0 deletions packages/bigframes/bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
)
from bigframes.operations.googlesql import GoogleSqlScalarOp
from bigframes.operations.json_ops import (
JSONDecode,
JSONExtract,
JSONExtractArray,
JSONExtractStringArray,
Expand Down
33 changes: 0 additions & 33 deletions packages/bigframes/bigframes/operations/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,6 @@
dtypes.STRING_DTYPE,
dtypes.INT_DTYPE,
),
(
dtypes.JSON_DTYPE,
dtypes.INT_DTYPE,
),
# Float casts
(
dtypes.BOOL_DTYPE,
Expand All @@ -118,10 +114,6 @@
dtypes.STRING_DTYPE,
dtypes.FLOAT_DTYPE,
),
(
dtypes.JSON_DTYPE,
dtypes.FLOAT_DTYPE,
),
# Bool casts
(
dtypes.INT_DTYPE,
Expand All @@ -131,10 +123,6 @@
dtypes.FLOAT_DTYPE,
dtypes.BOOL_DTYPE,
),
(
dtypes.JSON_DTYPE,
dtypes.BOOL_DTYPE,
),
# String casts
(
dtypes.BYTES_DTYPE,
Expand Down Expand Up @@ -168,10 +156,6 @@
dtypes.DATE_DTYPE,
dtypes.STRING_DTYPE,
),
(
dtypes.JSON_DTYPE,
dtypes.STRING_DTYPE,
),
# bytes casts
(
dtypes.STRING_DTYPE,
Expand Down Expand Up @@ -276,23 +260,6 @@
dtypes.INT_DTYPE,
dtypes.TIMEDELTA_DTYPE,
),
# json casts
(
dtypes.BOOL_DTYPE,
dtypes.JSON_DTYPE,
),
(
dtypes.FLOAT_DTYPE,
dtypes.JSON_DTYPE,
),
(
dtypes.STRING_DTYPE,
dtypes.JSON_DTYPE,
),
(
dtypes.INT_DTYPE,
dtypes.JSON_DTYPE,
),
)
)

Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/operations/json_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def output_type(self, *input_types):
@dataclasses.dataclass(frozen=True)
class ToJSON(base_ops.UnaryOp):
name: typing.ClassVar[str] = "to_json"
safe: bool = True

def output_type(self, *input_types):
input_type = input_types[0]
Expand Down Expand Up @@ -220,6 +221,7 @@ def output_type(self, *input_types):
class JSONDecode(base_ops.UnaryOp):
name: typing.ClassVar[str] = "json_decode"
to_type: dtypes.Dtype
safe: bool = True

def output_type(self, *input_types):
input_type = input_types[0]
Expand Down
16 changes: 13 additions & 3 deletions packages/bigframes/bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,19 @@ def astype(
if errors not in ["raise", "null"]:
raise ValueError("Argument 'errors' must be one of 'raise' or 'null'")
dtype = bigframes.dtypes.bigframes_type(dtype)
return self._apply_unary_op(
bigframes.operations.AsTypeOp(to_type=dtype, safe=(errors == "null"))
)
safe = errors == "null"
if dtype == bigframes.dtypes.JSON_DTYPE:
return self._apply_unary_op(
bigframes.operations.ToJSON(safe=safe)
)
elif self.dtype == bigframes.dtypes.JSON_DTYPE:
return self._apply_unary_op(
bigframes.operations.JSONDecode(to_type=dtype, safe=safe)
)
else:
return self._apply_unary_op(
bigframes.operations.AsTypeOp(to_type=dtype, safe=safe)
)

def to_pandas(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,16 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine):
exprs = [
ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr(
ops.JSONDecode(to_type=bigframes.dtypes.INT_DTYPE).as_expr(
expression.const("5", bigframes.dtypes.JSON_DTYPE)
),
ops.AsTypeOp(to_type=bigframes.dtypes.FLOAT_DTYPE).as_expr(
ops.JSONDecode(to_type=bigframes.dtypes.FLOAT_DTYPE).as_expr(
expression.const("5", bigframes.dtypes.JSON_DTYPE)
),
ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE).as_expr(
ops.JSONDecode(to_type=bigframes.dtypes.BOOL_DTYPE).as_expr(
expression.const("true", bigframes.dtypes.JSON_DTYPE)
),
ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE).as_expr(
ops.JSONDecode(to_type=bigframes.dtypes.STRING_DTYPE).as_expr(
expression.const('"hello world"', bigframes.dtypes.JSON_DTYPE)
),
]
Expand All @@ -284,17 +284,17 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine):
exprs = [
ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr(
ops.ToJSON().as_expr(
expression.deref("int64_col")
),
ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr(
ops.ToJSON().as_expr(
# Use a const since float to json has precision issues
expression.const(5.2, bigframes.dtypes.FLOAT_DTYPE)
),
ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr(
ops.ToJSON().as_expr(
expression.deref("bool_col")
),
ops.AsTypeOp(to_type=bigframes.dtypes.JSON_DTYPE).as_expr(
ops.ToJSON().as_expr(
# Use a const since "str_col" has special chars.
expression.const('"hello world"', bigframes.dtypes.STRING_DTYPE)
),
Expand Down
Loading