@@ -995,7 +995,7 @@ def filter(self, predicate: scalars.Expression):
995995
996996 def aggregate_all_and_stack (
997997 self ,
998- operation : agg_ops .UnaryAggregateOp ,
998+ operation : typing . Union [ agg_ops .UnaryAggregateOp , agg_ops . NullaryAggregateOp ] ,
999999 * ,
10001000 axis : int | str = 0 ,
10011001 value_col_id : str = "values" ,
@@ -1004,7 +1004,12 @@ def aggregate_all_and_stack(
10041004 axis_n = utils .get_axis_number (axis )
10051005 if axis_n == 0 :
10061006 aggregations = [
1007- (ex .UnaryAggregation (operation , ex .free_var (col_id )), col_id )
1007+ (
1008+ ex .UnaryAggregation (operation , ex .free_var (col_id ))
1009+ if isinstance (operation , agg_ops .UnaryAggregateOp )
1010+ else ex .NullaryAggregation (operation ),
1011+ col_id ,
1012+ )
10081013 for col_id in self .value_columns
10091014 ]
10101015 index_id = guid .generate_guid ()
@@ -1033,6 +1038,11 @@ def aggregate_all_and_stack(
10331038 (ex .UnaryAggregation (agg_ops .AnyValueOp (), ex .free_var (col_id )), col_id )
10341039 for col_id in [* self .index_columns ]
10351040 ]
1041+ # TODO: may need add NullaryAggregation in main_aggregation
1042+ # when agg add support for axis=1, needed for agg("size", axis=1)
1043+ assert isinstance (
1044+ operation , agg_ops .UnaryAggregateOp
1045+ ), f"Expected a unary operation, but got { operation } . Please report this error and how you got here to the BigQuery DataFrames team (bit.ly/bigframes-feedback)."
10361046 main_aggregation = (
10371047 ex .UnaryAggregation (operation , ex .free_var (value_col_id )),
10381048 value_col_id ,
@@ -1125,7 +1135,11 @@ def remap_f(x):
11251135 def aggregate (
11261136 self ,
11271137 by_column_ids : typing .Sequence [str ] = (),
1128- aggregations : typing .Sequence [typing .Tuple [str , agg_ops .UnaryAggregateOp ]] = (),
1138+ aggregations : typing .Sequence [
1139+ typing .Tuple [
1140+ str , typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
1141+ ]
1142+ ] = (),
11291143 * ,
11301144 dropna : bool = True ,
11311145 ) -> typing .Tuple [Block , typing .Sequence [str ]]:
@@ -1139,7 +1153,9 @@ def aggregate(
11391153 """
11401154 agg_specs = [
11411155 (
1142- ex .UnaryAggregation (operation , ex .free_var (input_id )),
1156+ ex .UnaryAggregation (operation , ex .free_var (input_id ))
1157+ if isinstance (operation , agg_ops .UnaryAggregateOp )
1158+ else ex .NullaryAggregation (operation ),
11431159 guid .generate_guid (),
11441160 )
11451161 for input_id , operation in aggregations
@@ -1175,18 +1191,32 @@ def aggregate(
11751191 output_col_ids ,
11761192 )
11771193
1178- def get_stat (self , column_id : str , stat : agg_ops .UnaryAggregateOp ):
1194+ def get_stat (
1195+ self ,
1196+ column_id : str ,
1197+ stat : typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ],
1198+ ):
11791199 """Gets aggregates immediately, and caches it"""
11801200 if stat .name in self ._stats_cache [column_id ]:
11811201 return self ._stats_cache [column_id ][stat .name ]
11821202
11831203 # TODO: Convert nonstandard stats into standard stats where possible (popvar, etc.)
11841204 # if getting a standard stat, just go get the rest of them
1185- standard_stats = self ._standard_stats (column_id )
1205+ standard_stats = typing .cast (
1206+ typing .Sequence [
1207+ typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
1208+ ],
1209+ self ._standard_stats (column_id ),
1210+ )
11861211 stats_to_fetch = standard_stats if stat in standard_stats else [stat ]
11871212
11881213 aggregations = [
1189- (ex .UnaryAggregation (stat , ex .free_var (column_id )), stat .name )
1214+ (
1215+ ex .UnaryAggregation (stat , ex .free_var (column_id ))
1216+ if isinstance (stat , agg_ops .UnaryAggregateOp )
1217+ else ex .NullaryAggregation (stat ),
1218+ stat .name ,
1219+ )
11901220 for stat in stats_to_fetch
11911221 ]
11921222 expr = self .expr .aggregate (aggregations )
@@ -1231,13 +1261,20 @@ def get_binary_stat(
12311261 def summarize (
12321262 self ,
12331263 column_ids : typing .Sequence [str ],
1234- stats : typing .Sequence [agg_ops .UnaryAggregateOp ],
1264+ stats : typing .Sequence [
1265+ typing .Union [agg_ops .UnaryAggregateOp , agg_ops .NullaryAggregateOp ]
1266+ ],
12351267 ):
12361268 """Get a list of stats as a deferred block object."""
12371269 label_col_id = guid .generate_guid ()
12381270 labels = [stat .name for stat in stats ]
12391271 aggregations = [
1240- (ex .UnaryAggregation (stat , ex .free_var (col_id )), f"{ col_id } -{ stat .name } " )
1272+ (
1273+ ex .UnaryAggregation (stat , ex .free_var (col_id ))
1274+ if isinstance (stat , agg_ops .UnaryAggregateOp )
1275+ else ex .NullaryAggregation (stat ),
1276+ f"{ col_id } -{ stat .name } " ,
1277+ )
12411278 for stat in stats
12421279 for col_id in column_ids
12431280 ]
0 commit comments