diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 3b64596d8f..55dd93b600 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -41,8 +41,6 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult: """Compiles a BigFrameNode according to the request into SQL using SQLGlot.""" - # Generator for unique identifiers. - uid_gen = guid.SequentialUIDGenerator() output_names = tuple((expression.DerefOp(id), id.sql) for id in request.node.ids) result_node = nodes.ResultNode( request.node, @@ -61,11 +59,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult: ) if request.sort_rows: result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node)) - result_node = _remap_variables(result_node, uid_gen) - result_node = typing.cast( - nodes.ResultNode, rewrite.defer_selection(result_node) - ) - sql = _compile_result_node(result_node, uid_gen) + sql = _compile_result_node(result_node) return configs.CompileResult( sql, result_node.schema.to_bigquery(), result_node.order_by ) @@ -73,10 +67,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult: ordering: typing.Optional[bf_ordering.RowOrdering] = result_node.order_by result_node = dataclasses.replace(result_node, order_by=None) result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node)) + sql = _compile_result_node(result_node) - result_node = _remap_variables(result_node, uid_gen) - result_node = typing.cast(nodes.ResultNode, rewrite.defer_selection(result_node)) - sql = _compile_result_node(result_node, uid_gen) # Return the ordering iff no extra columns are needed to define the row order if ordering is not None: output_order = ( @@ -97,16 +89,22 @@ def _remap_variables( return typing.cast(nodes.ResultNode, result_node) -def _compile_result_node( - root: nodes.ResultNode, uid_gen: guid.SequentialUIDGenerator -) -> str: +def _compile_result_node(root: nodes.ResultNode) -> str: + # Create UIDs to standardize variable names and ensure consistent compilation + # of nodes using the same generator. + uid_gen = guid.SequentialUIDGenerator() + root = _remap_variables(root, uid_gen) + root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root)) + # Have to bind schema as the final step before compilation. root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root)) + selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( (name, scalar_compiler.scalar_op_compiler.compile_expression(ref)) for ref, name in root.output_cols ) - sqlglot_ir = compile_node(root.child, uid_gen).select(selected_cols) + sqlglot_ir = compile_node(root.child, uid_gen) + sqlglot_ir = sqlglot_ir.select(selected_cols, sqlglot_ir.uid_gen) if root.order_by is not None: ordering_cols = tuple( @@ -119,10 +117,10 @@ def _compile_result_node( ) for ordering in root.order_by.all_ordering_columns ) - sqlglot_ir = sqlglot_ir.order_by(ordering_cols) + sqlglot_ir = sqlglot_ir.order_by(ordering_cols, sqlglot_ir.uid_gen) if root.limit is not None: - sqlglot_ir = sqlglot_ir.limit(root.limit) + sqlglot_ir = sqlglot_ir.limit(root.limit, sqlglot_ir.uid_gen) return sqlglot_ir.sql @@ -135,15 +133,14 @@ def compile_node( bf_to_sqlglot: dict[nodes.BigFrameNode, ir.SQLGlotIR] = {} child_results: tuple[ir.SQLGlotIR, ...] = () for current_node in list(node.iter_nodes_topo()): - if current_node.child_nodes == (): - # For leaf node, generates a dumpy child to pass the UID generator. - child_results = tuple([ir.SQLGlotIR(uid_gen=uid_gen)]) - else: - # Child nodes should have been compiled in the reverse topological order. - child_results = tuple( - bf_to_sqlglot[child] for child in current_node.child_nodes - ) - result = _compile_node(current_node, *child_results) + # Child nodes should have been compiled in the reverse topological order. + child_results = tuple( + bf_to_sqlglot[child] for child in current_node.child_nodes + ) + result = _compile_node(current_node, uid_gen, *child_results) + + # Update the uid_gen to be used in the next nodes compilation. + uid_gen = result.uid_gen bf_to_sqlglot[current_node] = result return bf_to_sqlglot[node] @@ -151,14 +148,20 @@ def compile_node( @functools.singledispatch def _compile_node( - node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR + node: nodes.BigFrameNode, + uid_gen: guid.SequentialUIDGenerator, + *compiled_children: ir.SQLGlotIR, ) -> ir.SQLGlotIR: """Defines transformation but isn't cached, always use compile_node instead""" raise ValueError(f"Can't compile unrecognized node: {node}") @_compile_node.register -def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_readlocal( + node: nodes.ReadLocalNode, + uid_gen: guid.SequentialUIDGenerator, + *child: ir.SQLGlotIR, +) -> ir.SQLGlotIR: pa_table = node.local_data_source.data pa_table = pa_table.select([item.source_id for item in node.scan_list.items]) pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items]) @@ -167,11 +170,15 @@ def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLG if offsets: pa_table = pyarrow_utils.append_offsets(pa_table, offsets) - return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=child.uid_gen) + return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=uid_gen) @_compile_node.register -def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR): +def compile_readtable( + node: nodes.ReadTableNode, + uid_gen: guid.SequentialUIDGenerator, + *child: ir.SQLGlotIR, +): table = node.source.table return ir.SQLGlotIR.from_table( table.project_id, @@ -179,39 +186,50 @@ def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR): table.table_id, col_names=[col.source_id for col in node.scan_list.items], alias_names=[col.id.sql for col in node.scan_list.items], - uid_gen=child.uid_gen, + uid_gen=uid_gen, sql_predicate=node.source.sql_predicate, system_time=node.source.at_time, ) @_compile_node.register -def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_selection( + node: nodes.SelectionNode, uid_gen: guid.SequentialUIDGenerator, child: ir.SQLGlotIR +) -> ir.SQLGlotIR: selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) for expr, id in node.input_output_pairs ) - return child.select(selected_cols) + return child.select(selected_cols, uid_gen=uid_gen) @_compile_node.register -def compile_projection(node: nodes.ProjectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_projection( + node: nodes.ProjectionNode, + uid_gen: guid.SequentialUIDGenerator, + child: ir.SQLGlotIR, +) -> ir.SQLGlotIR: projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) for expr, id in node.assignments ) - return child.project(projected_cols) + return child.project(projected_cols, uid_gen=uid_gen) @_compile_node.register -def compile_filter(node: nodes.FilterNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_filter( + node: nodes.FilterNode, uid_gen: guid.SequentialUIDGenerator, child: ir.SQLGlotIR +) -> ir.SQLGlotIR: condition = scalar_compiler.scalar_op_compiler.compile_expression(node.predicate) - return child.filter(tuple([condition])) + return child.filter(tuple([condition]), uid_gen=uid_gen) @_compile_node.register def compile_join( - node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR + node: nodes.JoinNode, + uid_gen: guid.SequentialUIDGenerator, + left: ir.SQLGlotIR, + right: ir.SQLGlotIR, ) -> ir.SQLGlotIR: conditions = tuple( ( @@ -232,12 +250,16 @@ def compile_join( join_type=node.type, conditions=conditions, joins_nulls=node.joins_nulls, + uid_gen=uid_gen, ) @_compile_node.register def compile_isin_join( - node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR + node: nodes.InNode, + uid_gen: guid.SequentialUIDGenerator, + left: ir.SQLGlotIR, + right: ir.SQLGlotIR, ) -> ir.SQLGlotIR: right_field = node.right_child.fields[0] conditions = ( @@ -255,6 +277,7 @@ def compile_isin_join( return left.isin_join( right, + uid_gen=uid_gen, indicator_col=node.indicator_col.sql, conditions=conditions, joins_nulls=node.joins_nulls, @@ -262,9 +285,15 @@ def compile_isin_join( @_compile_node.register -def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_concat( + node: nodes.ConcatNode, + uid_gen: guid.SequentialUIDGenerator, + *children: ir.SQLGlotIR, +) -> ir.SQLGlotIR: assert len(children) >= 1 - uid_gen = children[0].uid_gen + + if len(children) == 1: + return children[0] # BigQuery `UNION` query takes the column names from the first `SELECT` clause. default_output_ids = [field.id.sql for field in node.child_nodes[0].fields] @@ -281,21 +310,27 @@ def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlo @_compile_node.register -def compile_explode(node: nodes.ExplodeNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_explode( + node: nodes.ExplodeNode, uid_gen: guid.SequentialUIDGenerator, child: ir.SQLGlotIR +) -> ir.SQLGlotIR: offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None columns = tuple(ref.id.sql for ref in node.column_ids) - return child.explode(columns, offsets_col) + return child.explode(columns, offsets_col, uid_gen=uid_gen) @_compile_node.register def compile_random_sample( - node: nodes.RandomSampleNode, child: ir.SQLGlotIR + node: nodes.RandomSampleNode, + uid_gen: guid.SequentialUIDGenerator, + child: ir.SQLGlotIR, ) -> ir.SQLGlotIR: - return child.sample(node.fraction) + return child.sample(node.fraction, uid_gen=uid_gen) @_compile_node.register -def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_aggregate( + node: nodes.AggregateNode, uid_gen: guid.SequentialUIDGenerator, child: ir.SQLGlotIR +) -> ir.SQLGlotIR: # The BigQuery ordered aggregation cannot support for NULL FIRST/LAST, # so we need to add extra expressions to enforce the null ordering. ordering_cols = windows.get_window_order_by(node.order_by, override_null_order=True) @@ -319,11 +354,13 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLG if node.child.field_by_id[key.id].nullable: dropna_cols.append(by_col) - return child.aggregate(aggregations, by_cols, tuple(dropna_cols)) + return child.aggregate(aggregations, by_cols, tuple(dropna_cols), uid_gen=uid_gen) @_compile_node.register -def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: +def compile_window( + node: nodes.WindowOpNode, uid_gen: guid.SequentialUIDGenerator, child: ir.SQLGlotIR +) -> ir.SQLGlotIR: window_spec = node.window_spec result = child for cdef in node.agg_exprs: @@ -391,6 +428,7 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI result = result.window( window_op=window_op, output_column_id=cdef.id.sql, + uid_gen=uid_gen, ) return result diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 544d46b832..b844debb46 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -234,6 +234,7 @@ def from_union( def select( self, selected_cols: tuple[tuple[str, sge.Expression], ...], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Replaces new selected columns of the current SELECT clause.""" selections = [ @@ -249,15 +250,16 @@ def select( new_expr = _select_to_cte( self.expr, sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ), ) new_expr = new_expr.select(*selections, append=False) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def project( self, projected_cols: tuple[tuple[str, sge.Expression], ...], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Adds new columns to the SELECT clause.""" projected_cols_expr = [ @@ -270,66 +272,67 @@ def project( new_expr = _select_to_cte( self.expr, sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ), ) new_expr = new_expr.select(*projected_cols_expr, append=True) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def order_by( self, ordering: tuple[sge.Ordered, ...], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Adds an ORDER BY clause to the query.""" if len(ordering) == 0: - return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen) + return SQLGlotIR(expr=self.expr.copy(), uid_gen=uid_gen) new_expr = self.expr.order_by(*ordering) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def limit( self, limit: int | None, + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Adds a LIMIT clause to the query.""" if limit is not None: new_expr = self.expr.limit(limit) else: new_expr = self.expr.copy() - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def filter( self, conditions: tuple[sge.Expression, ...], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Filters the query by adding a WHERE clause.""" condition = _and(conditions) if condition is None: - return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen) + return SQLGlotIR(expr=self.expr.copy(), uid_gen=uid_gen) new_expr = _select_to_cte( self.expr, sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ), ) - return SQLGlotIR( - expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen - ) + return SQLGlotIR(expr=new_expr.where(condition, append=False), uid_gen=uid_gen) def join( self, right: SQLGlotIR, join_type: typing.Literal["inner", "outer", "left", "right", "cross"], conditions: tuple[tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], ...], - *, - joins_nulls: bool = True, + joins_nulls: bool, + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Joins the current query with another SQLGlotIR instance.""" left_cte_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ) right_cte_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ) left_select = _select_to_cte(self.expr, left_cte_name) @@ -354,18 +357,19 @@ def join( ) new_expr = _set_query_ctes(new_expr, merged_ctes) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def isin_join( self, right: SQLGlotIR, indicator_col: str, conditions: tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], - joins_nulls: bool = True, + joins_nulls: bool, + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Joins the current query with another SQLGlotIR instance.""" left_cte_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ) left_select = _select_to_cte(self.expr, left_cte_name) @@ -384,7 +388,7 @@ def isin_join( new_column: sge.Expression if joins_nulls: right_table_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bft_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bft_")), quoted=self.quoted ) right_condition = typed_expr.TypedExpr( sge.Column(this=conditions[1].expr, table=right_table_name), @@ -416,25 +420,32 @@ def isin_join( ) new_expr = _set_query_ctes(new_expr, merged_ctes) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def explode( self, column_names: tuple[str, ...], offsets_col: typing.Optional[str], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Unnests one or more array columns.""" num_columns = len(list(column_names)) assert num_columns > 0, "At least one column must be provided for explode." if num_columns == 1: - return self._explode_single_column(column_names[0], offsets_col) + return self._explode_single_column( + column_names[0], offsets_col, uid_gen=uid_gen + ) else: - return self._explode_multiple_columns(column_names, offsets_col) + return self._explode_multiple_columns( + column_names, offsets_col, uid_gen=uid_gen + ) - def sample(self, fraction: float) -> SQLGlotIR: + def sample( + self, fraction: float, uid_gen: guid.SequentialUIDGenerator + ) -> SQLGlotIR: """Uniform samples a fraction of the rows.""" uuid_col = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted ) uuid_expr = sge.Alias(this=sge.func("RAND"), alias=uuid_col) condition = sge.LT( @@ -443,18 +454,19 @@ def sample(self, fraction: float) -> SQLGlotIR: ) new_cte_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ) new_expr = _select_to_cte( self.expr.select(uuid_expr, append=True), new_cte_name ).where(condition, append=False) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def aggregate( self, aggregations: tuple[tuple[str, sge.Expression], ...], by_cols: tuple[sge.Expression, ...], dropna_cols: tuple[sge.Expression, ...], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Applies the aggregation expressions. @@ -474,7 +486,7 @@ def aggregate( new_expr = _select_to_cte( self.expr, sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ), ) new_expr = new_expr.group_by(*by_cols).select( @@ -489,14 +501,15 @@ def aggregate( ) if condition is not None: new_expr = new_expr.where(condition, append=False) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def window( self, window_op: sge.Expression, output_column_id: str, + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: - return self.project(((output_column_id, window_op),)) + return self.project(((output_column_id, window_op),), uid_gen=uid_gen) def insert( self, @@ -533,7 +546,10 @@ def replace( return f"{merge_str}\n{whens_str}" def _explode_single_column( - self, column_name: str, offsets_col: typing.Optional[str] + self, + column_name: str, + offsets_col: typing.Optional[str], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Helper method to handle the case of exploding a single column.""" offset = ( @@ -541,7 +557,7 @@ def _explode_single_column( ) column = sge.to_identifier(column_name, quoted=self.quoted) unnested_column_alias = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted ) unnest_expr = sge.Unnest( expressions=[column], @@ -553,19 +569,20 @@ def _explode_single_column( new_expr = _select_to_cte( self.expr, sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ), ) # Use LEFT JOIN to preserve rows when unnesting empty arrays. new_expr = new_expr.select(selection, append=False).join( unnest_expr, join_type="LEFT" ) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def _explode_multiple_columns( self, column_names: tuple[str, ...], offsets_col: typing.Optional[str], + uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Helper method to handle the case of exploding multiple columns.""" offset = ( @@ -588,7 +605,7 @@ def _explode_multiple_columns( sge.func("LEAST", *column_lengths), ) unnested_offset_alias = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted ) unnest_expr = sge.Unnest( expressions=[generate_array], @@ -609,14 +626,14 @@ def _explode_multiple_columns( new_expr = _select_to_cte( self.expr, sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + next(uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted ), ) # Use LEFT JOIN to preserve rows when unnesting empty arrays. new_expr = new_expr.select(selection, append=False).join( unnest_expr, join_type="LEFT" ) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + return SQLGlotIR(expr=new_expr, uid_gen=uid_gen) def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: