Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions packages/bigframes/bigframes/bigquery/_operations/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def st_area(
bigframes.pandas.Series:
Series of float representing the areas.
"""
series = series._apply_unary_op(ops.geo_area_op)
series = series._apply_nary_op(ops.googlesql.ST_AREA, [])
series.name = None
return series

Expand Down Expand Up @@ -223,7 +223,7 @@ def st_centroid(
bigframes.pandas.Series:
A series of geography objects representing the centroids.
"""
series = series._apply_unary_op(ops.geo_st_centroid_op)
series = series._apply_nary_op(ops.googlesql.ST_CENTROID, [])
series.name = None
return series

Expand Down Expand Up @@ -753,6 +753,4 @@ def st_simplify(
Returns:
a Series containing the simplified GEOGRAPHY data.
"""
return geography._apply_unary_op(
ops.GeoStSimplifyOp(tolerance_meters=tolerance_meters)
)
return geography._apply_nary_op(ops.googlesql.ST_SIMPLIFY, [tolerance_meters])
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import bigframes.core.expression
from bigframes import dtypes
from bigframes import operations as ops
from bigframes.operations import googlesql


def rand() -> bigframes.core.col.Expression:
Expand Down Expand Up @@ -47,12 +48,9 @@ def rand() -> bigframes.core.col.Expression:
:func:`~bigframes.pandas.DataFrame.assign` and other methods. See
:func:`bigframes.pandas.col`.
"""
op = ops.SqlScalarOp(
_output_type=dtypes.FLOAT_DTYPE,
sql_template="RAND()",
is_deterministic=False,
return bigframes.core.col.Expression(
bigframes.core.expression.OpExpression(googlesql.RAND, ())
)
return bigframes.core.col.Expression(bigframes.core.expression.OpExpression(op, ()))


def hparam_range(min: float, max: float) -> bigframes.core.col.Expression:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def _(
]
return self.compile_row_op(expression.op, inputs)

@compile_expression.register
def _(
    self,
    expression: ex.Omitted,
    bindings: typing.Dict[str, ibis_types.Value],
) -> ibis_types.Value:
    """Compile an omitted optional argument into the ibis "omitted" value.

    Neither ``expression`` nor ``bindings`` is consulted; every ``Omitted``
    expression compiles to whatever ``bigframes_vendored.ibis.omitted()``
    returns.
    """
    omitted_value = bigframes_vendored.ibis.omitted()
    return omitted_value

def compile_row_op(
self, op: ops.RowOp, inputs: typing.Sequence[ibis_types.Value]
) -> ibis_types.Value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from bigframes.core.compile.ibis_compiler.scalar_op_compiler import (
scalar_op_compiler, # TODO(tswast): avoid import of variables
)
from bigframes.operations.googlesql import CallingConvention

_ZERO = typing.cast(ibis_types.NumericValue, ibis_types.literal(0))
_NAN = typing.cast(ibis_types.NumericValue, ibis_types.literal(np.nan))
Expand Down Expand Up @@ -1890,6 +1891,67 @@ def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value:
return case_val.end() # type: ignore


@scalar_op_compiler.register_nary_op(ops.GoogleSqlScalarOp, pass_op=True)
def googlesql_scalar_op_impl(*operands: ibis_types.Value, op: ops.GoogleSqlScalarOp):
    """Compile a ``GoogleSqlScalarOp`` into an ibis ``SqlScalar`` expression.

    The SQL template depends on ``op.calling_convention``:

    * ``FUNCTION``: ``NAME(arg, name => arg, ...)``. Operands whose ibis op is
      flagged ``omitted`` are dropped from the call, and arguments whose spec
      has an ``arg_name`` are emitted in ``name => value`` form.
    * ``PREFIX``: ``NAME {0}`` with exactly one operand.
    * ``INFIX``: ``{0} NAME {1}`` with exactly two operands.

    Args:
        *operands: Compiled ibis values, in positional order.
        op: The GoogleSQL op describing the SQL name, calling convention and
            argument specs.

    Returns:
        The ibis expression wrapping the templated SQL scalar.

    Raises:
        AssertionError: If the operand count does not fit the op, or a
            non-optional argument is omitted.
        NotImplementedError: If the calling convention is not handled here.
    """
    convention = op.calling_convention
    final_operands: typing.Sequence[ibis_types.Value]

    if convention == CallingConvention.FUNCTION:
        kept_operands = []
        arg_templates = []
        for i, operand in enumerate(operands):
            if i < len(op.args):
                arg_spec = op.args[i]
            else:
                # Surplus operands are only legal when the last declared
                # argument is variadic. Testing `op.args` first keeps an op
                # that declares no arguments from failing with a bare
                # IndexError instead of this message.
                assert op.args and op.args[-1].is_vararg, (
                    f"Too many arguments, for {op.sql_name}, expected {len(op.args)}"
                )
                arg_spec = op.args[-1]
            if operand.op().omitted:
                assert arg_spec.optional, "Argument omitted, but not optional"
                continue

            # Placeholders index into the operands that are actually kept,
            # not into the original positional list.
            target_idx = len(kept_operands)
            kept_operands.append(operand)
            if arg_spec.arg_name:
                arg_templates.append(f"{arg_spec.arg_name} => {{{target_idx}}}")
            else:
                arg_templates.append(f"{{{target_idx}}}")
        args_template = ", ".join(arg_templates)
        sql_template = f"{op.sql_name}({args_template})"
        final_operands = kept_operands
    elif convention == CallingConvention.PREFIX:
        assert len(operands) == 1, "prefix op expects exactly 1 arg"
        sql_template = f"{op.sql_name} {{0}}"
        final_operands = operands
    elif convention == CallingConvention.INFIX:
        assert len(operands) == 2, "infix op expects exactly 2 args"
        sql_template = f"{{0}} {op.sql_name} {{1}}"
        final_operands = operands
    else:
        raise NotImplementedError(
            f"Calling convention {op.calling_convention} not supported for {op}"
        )

    # Single construction site shared by all calling conventions.
    return ibis_generic.SqlScalar(
        sql_template,
        values=tuple(
            typing.cast(ibis_generic.Value, expr.op()) for expr in final_operands
        ),
        output_type=bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(
            op.output_type()
        ),
    ).to_expr()


@scalar_op_compiler.register_nary_op(ops.SqlScalarOp, pass_op=True)
def sql_scalar_op_impl(*operands: ibis_types.Value, op: ops.SqlScalarOp):
return ibis_generic.SqlScalar(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ def _(self, expr: agg_exprs.WindowExpression) -> sge.Expression:

@compile_expression.register
def _(self, expr: ex.OpExpression) -> sge.Expression:
    """Compile an ``OpExpression`` by compiling its inputs, then its row op.

    Each input is compiled through ``self.compile_expression`` and paired with
    its output type. ``ex.Omitted`` inputs are not compiled; they become a
    placeholder ``TypedExpr`` flagged ``is_omitted=True``.
    """

    def _compile_input(sub_expr) -> TypedExpr:
        if isinstance(sub_expr, ex.Omitted):
            # The placeholder must be an Expression *instance*; the original
            # passed the `sge.Null` class itself, which is not an Expression.
            return TypedExpr(sge.Null(), None, is_omitted=True)
        return TypedExpr(self.compile_expression(sub_expr), sub_expr.output_type)

    inputs = tuple(_compile_input(sub_expr) for sub_expr in expr.inputs)
    return self.compile_row_op(expr.op, inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
from bigframes import dtypes
from bigframes import operations as ops
from bigframes.operations.googlesql import CallingConvention
from bigframes.core.compile.sqlglot import sql, sqlglot_types
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr

Expand Down Expand Up @@ -82,6 +83,35 @@ def _(expr: TypedExpr) -> sge.Expression:
return sge.BitwiseNot(this=sge.paren(expr.expr))


@register_nary_op(ops.GoogleSqlScalarOp, pass_op=True)
def _(*operands: TypedExpr, op: ops.GoogleSqlScalarOp) -> sge.Expression:
    """Compile a ``GoogleSqlScalarOp`` into a sqlglot expression.

    The operands are rendered to BigQuery SQL text, spliced into a call shaped
    by ``op.calling_convention``, and the result is re-parsed with sqlglot:

    * ``FUNCTION``: ``NAME(arg, name => arg, ...)``. Operands flagged
      ``is_omitted`` are dropped, and arguments whose spec has an ``arg_name``
      are emitted in ``name => value`` form.
    * ``PREFIX``: ``NAME (operand)`` with exactly one operand.
    * ``INFIX``: ``(left) NAME (right)`` with exactly two operands.

    Args:
        *operands: Compiled operands, in positional order.
        op: The GoogleSQL op describing the SQL name, calling convention and
            argument specs.

    Returns:
        The parsed sqlglot expression for the call.

    Raises:
        AssertionError: If the operand count does not fit the op, or a
            non-optional argument is omitted.
        NotImplementedError: If the calling convention is not handled here.
    """

    def _render(operand: TypedExpr) -> str:
        return operand.expr.sql(dialect="bigquery")

    def _render_parenthesized(operand: TypedExpr) -> str:
        # Operator operands are spliced in as text and re-parsed, so without
        # parentheses a compound operand could rebind under the new operator's
        # precedence.
        return sge.paren(operand.expr).sql(dialect="bigquery")

    convention = op.calling_convention

    if convention == CallingConvention.FUNCTION:
        arg_templates = []
        for i, operand in enumerate(operands):
            if i < len(op.args):
                arg_spec = op.args[i]
            else:
                # Surplus operands are only legal when the last declared
                # argument is variadic. Testing `op.args` first keeps an op
                # that declares no arguments from failing with a bare
                # IndexError instead of this message.
                assert op.args and op.args[-1].is_vararg, (
                    f"Too many arguments, for {op.sql_name}, expected {len(op.args)}"
                )
                arg_spec = op.args[-1]
            if operand.is_omitted:
                assert arg_spec.optional, "Argument omitted, but not optional"
                continue
            if arg_spec.arg_name:
                arg_templates.append(f"{arg_spec.arg_name} => {_render(operand)}")
            else:
                arg_templates.append(_render(operand))
        args_template = ", ".join(arg_templates)
        return sg.parse_one(f"{op.sql_name}({args_template})", dialect="bigquery")

    if convention == CallingConvention.PREFIX:
        assert len(operands) == 1, "prefix op expects exactly 1 arg"
        (only,) = operands
        return sg.parse_one(
            f"{op.sql_name} {_render_parenthesized(only)}", dialect="bigquery"
        )

    if convention == CallingConvention.INFIX:
        assert len(operands) == 2, "infix op expects exactly 2 args"
        left, right = operands
        return sg.parse_one(
            f"{_render_parenthesized(left)} {op.sql_name} {_render_parenthesized(right)}",
            dialect="bigquery",
        )

    raise NotImplementedError(
        f"Calling convention {op.calling_convention} not supported for {op}"
    )


@register_nary_op(ops.SqlScalarOp, pass_op=True)
def _(*operands: TypedExpr, op: ops.SqlScalarOp) -> sge.Expression:
return sg.parse_one(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ class TypedExpr:

expr: sge.Expression
dtype: dtypes.ExpressionType

# kludge to support optional args in argument lists
is_omitted: bool = False
50 changes: 50 additions & 0 deletions packages/bigframes/bigframes/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,56 @@ def output_type(self) -> dtypes.ExpressionType:
return self.dtype


@dataclasses.dataclass(frozen=True)
class Omitted(Expression):
    """Represents an omitted optional argument when calling a function.

    The node has no children, no free variables and no column references, so
    every binding or transformation method returns ``self`` unchanged.
    """

    @property
    def free_variables(self) -> typing.Tuple[Hashable, ...]:
        return ()

    @property
    def is_const(self) -> bool:
        return True

    @property
    def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
        return ()

    @property
    def is_resolved(self) -> bool:
        return True  # vacuously

    @property
    def output_type(self) -> dtypes.ExpressionType:
        # An omitted argument carries no value, hence no type.
        return None

    def bind_refs(
        self,
        bindings: Mapping[ids.ColumnId, Expression],
        allow_partial_bindings: bool = False,
    ) -> Expression:
        # Nothing to bind. The original annotation named
        # UnboundVariableExpression, but the value returned is this Omitted
        # node itself.
        return self

    def bind_variables(
        self,
        bindings: Mapping[Hashable, Expression],
        allow_partial_bindings: bool = False,
    ) -> Expression:
        return self

    @property
    def is_bijective(self) -> bool:
        return True

    @property
    def is_identity(self) -> bool:
        # NOTE(review): reports True although there is no input to pass
        # through; confirm callers never treat an Omitted as a passthrough.
        return True

    def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
        # No children to transform.
        return self


@dataclasses.dataclass(frozen=True)
class OpExpression(Expression):
"""An expression representing a scalar operation applied to 1 or more argument sub-expressions."""
Expand Down
3 changes: 3 additions & 0 deletions packages/bigframes/bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@
timestamp_add_op,
timestamp_sub_op,
)
from bigframes.operations.googlesql import GoogleSqlScalarOp

__all__ = [
# Base ops
Expand Down Expand Up @@ -446,4 +447,6 @@
"ToArrayOp",
"ArrayReduceOp",
"ArrayMapOp",
# GoogleSql
"GoogleSqlScalarOp",
]
21 changes: 11 additions & 10 deletions packages/bigframes/bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from __future__ import annotations

import dataclasses
from typing import ClassVar, Literal, Tuple
import typing
from typing import Literal, Tuple

import pandas as pd
import pyarrow as pa
Expand All @@ -26,7 +27,7 @@

@dataclasses.dataclass(frozen=True)
class AIGenerate(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate"
name: typing.ClassVar[str] = "ai_generate"

prompt_context: Tuple[str | None, ...]
connection_id: str | None
Expand Down Expand Up @@ -54,7 +55,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AIGenerateBool(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate_bool"
name: typing.ClassVar[str] = "ai_generate_bool"

prompt_context: Tuple[str | None, ...]
connection_id: str | None
Expand All @@ -76,7 +77,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AIGenerateInt(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate_int"
name: typing.ClassVar[str] = "ai_generate_int"

prompt_context: Tuple[str | None, ...]
connection_id: str | None
Expand All @@ -98,7 +99,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AIGenerateDouble(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate_double"
name: typing.ClassVar[str] = "ai_generate_double"

prompt_context: Tuple[str | None, ...]
connection_id: str | None
Expand All @@ -120,7 +121,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AIEmbed(base_ops.UnaryOp):
name: ClassVar[str] = "ai_embed"
name: typing.ClassVar[str] = "ai_embed"

endpoint: str | None
model: str | None
Expand All @@ -142,7 +143,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AIIf(base_ops.NaryOp):
name: ClassVar[str] = "ai_if"
name: typing.ClassVar[str] = "ai_if"

prompt_context: Tuple[str | None, ...]
connection_id: str | None
Expand All @@ -156,7 +157,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AIClassify(base_ops.NaryOp):
name: ClassVar[str] = "ai_classify"
name: typing.ClassVar[str] = "ai_classify"

prompt_context: Tuple[str | None, ...]
categories: tuple[str, ...]
Expand All @@ -172,7 +173,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AIScore(base_ops.NaryOp):
name: ClassVar[str] = "ai_score"
name: typing.ClassVar[str] = "ai_score"

prompt_context: Tuple[str | None, ...]
connection_id: str | None
Expand All @@ -185,7 +186,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT

@dataclasses.dataclass(frozen=True)
class AISimilarity(base_ops.BinaryOp):
name: ClassVar[str] = "ai_similarity"
name: typing.ClassVar[str] = "ai_similarity"

endpoint: str | None
model: str | None
Expand Down
Loading
Loading