Skip to content

Commit 3639570

Browse files
authored
feat: Support distinct count aggregation [#6116]
1 parent 1e5b60f commit 3639570

File tree

6 files changed

+65 additions, -5 deletions

sdk/python/feast/aggregation/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ class Aggregation:
1818
1919
Attributes:
2020
column: str # Column name of the feature we are aggregating.
21-
function: str # Provided built in aggregations sum, max, min, count mean
21+
function: str # Provided built in aggregations sum, max, min, count, mean, count_distinct
2222
time_window: timedelta # The time window for this aggregation.
2323
slide_interval: timedelta # The sliding window for these aggregations
2424
name: str # Optional override for the output feature name (defaults to {function}_{column})
@@ -118,6 +118,11 @@ def resolved_name(self, time_window: Optional[timedelta] = None) -> str:
118118
return base
119119

120120

121+
_FUNCTION_ALIASES: Dict[str, str] = {
122+
"count_distinct": "nunique",
123+
}
124+
125+
121126
def aggregation_specs_to_agg_ops(
122127
agg_specs: Iterable[Any],
123128
*,
@@ -128,7 +133,8 @@ def aggregation_specs_to_agg_ops(
128133
if getattr(agg, "time_window", None) is not None:
129134
raise ValueError(time_window_unsupported_error_message)
130135
alias = getattr(agg, "name", None) or f"{agg.function}_{agg.column}"
131-
agg_ops[alias] = (agg.function, agg.column)
136+
func_name = _FUNCTION_ALIASES.get(agg.function, agg.function)
137+
agg_ops[alias] = (func_name, agg.column)
132138
return agg_ops
133139

134140

sdk/python/feast/aggregation/tiling/base.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,12 @@ def get_ir_metadata_for_aggregation(
8686
),
8787
)
8888

89+
elif agg_type == "count_distinct":
90+
raise ValueError(
91+
"count_distinct does not support tiling. "
92+
"Use enable_tiling=False or choose an algebraic aggregation (sum, count, min, max)."
93+
)
94+
8995
else:
9096
# Unknown aggregation: treat as algebraic
9197
return ([], IRMetadata(type="algebraic"))

sdk/python/feast/infra/compute_engines/backends/polars_backend.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88

99
class PolarsBackend(DataFrameBackend):
10+
_POLARS_FUNC_MAP = {"nunique": "n_unique"}
11+
1012
def columns(self, df: pl.DataFrame) -> list[str]:
1113
return df.columns
1214

@@ -21,7 +23,7 @@ def join(self, left: pl.DataFrame, right: pl.DataFrame, on, how) -> pl.DataFrame
2123

2224
def groupby_agg(self, df: pl.DataFrame, group_keys, agg_ops) -> pl.DataFrame:
2325
agg_exprs = [
24-
getattr(pl.col(col), func)().alias(alias)
26+
getattr(pl.col(col), self._POLARS_FUNC_MAP.get(func, func))().alias(alias)
2527
for alias, (func, col) in agg_ops.items()
2628
]
2729
return df.groupby(group_keys).agg(agg_exprs)

sdk/python/feast/infra/compute_engines/ray/nodes.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -477,6 +477,8 @@ def _execute_standard_aggregation(self, dataset: Dataset) -> DAGValue:
477477
agg_dict[feature_name] = (agg.column, "std")
478478
elif agg.function == "var":
479479
agg_dict[feature_name] = (agg.column, "var")
480+
elif agg.function == "count_distinct":
481+
agg_dict[feature_name] = (agg.column, "nunique")
480482
else:
481483
raise ValueError(f"Unknown aggregation function: {agg.function}.")
482484

@@ -531,6 +533,8 @@ def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Data
531533
result = grouped[column].std()
532534
elif function == "var":
533535
result = grouped[column].var()
536+
elif function == "nunique":
537+
result = grouped[column].nunique()
534538
else:
535539
raise ValueError(f"Unknown aggregation function: {function}.")
536540

sdk/python/feast/infra/compute_engines/spark/nodes.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -371,8 +371,11 @@ def _execute_standard_aggregation(self, input_df: DataFrame) -> DAGValue:
371371
"""Execute standard Spark aggregation (existing logic)."""
372372
agg_exprs = []
373373
for agg in self.aggregations:
374-
func = getattr(F, agg.function)
375-
expr = func(agg.column).alias(agg.resolved_name(agg.time_window))
374+
if agg.function == "count_distinct":
375+
func_expr = F.countDistinct(agg.column)
376+
else:
377+
func_expr = getattr(F, agg.function)(agg.column)
378+
expr = func_expr.alias(agg.resolved_name(agg.time_window))
376379
agg_exprs.append(expr)
377380

378381
if any(agg.time_window for agg in self.aggregations):

sdk/python/tests/unit/test_aggregation_ops.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
from datetime import timedelta
22

3+
import pandas as pd
34
import pytest
45

56
from feast.aggregation import Aggregation, aggregation_specs_to_agg_ops
7+
from feast.aggregation.tiling.base import get_ir_metadata_for_aggregation
68

79

810
class DummyAggregation:
@@ -96,3 +98,40 @@ def test_aggregation_round_trip_with_name():
9698
restored = Aggregation.from_proto(proto)
9799
assert restored.name == "sum_seconds_watched_per_ad_1d"
98100
assert restored == agg
101+
102+
103+
def test_count_distinct_agg_ops():
104+
"""aggregation_specs_to_agg_ops maps count_distinct to the nunique pandas function."""
105+
agg_specs = [DummyAggregation(function="count_distinct", column="item_id")]
106+
107+
agg_ops = aggregation_specs_to_agg_ops(
108+
agg_specs,
109+
time_window_unsupported_error_message="no windows",
110+
)
111+
112+
assert agg_ops == {"count_distinct_item_id": ("nunique", "item_id")}
113+
114+
115+
def test_count_distinct_result():
116+
"""count_distinct via nunique returns the number of unique values per group."""
117+
from feast.infra.compute_engines.backends.pandas_backend import PandasBackend
118+
119+
agg_specs = [DummyAggregation(function="count_distinct", column="item_id")]
120+
agg_ops = aggregation_specs_to_agg_ops(
121+
agg_specs,
122+
time_window_unsupported_error_message="no windows",
123+
)
124+
125+
df = pd.DataFrame({"user": ["A", "A", "B"], "item_id": [1, 2, 1]})
126+
result = PandasBackend().groupby_agg(df, ["user"], agg_ops)
127+
result = result.set_index("user")
128+
129+
assert result.loc["A", "count_distinct_item_id"] == 2
130+
assert result.loc["B", "count_distinct_item_id"] == 1
131+
132+
133+
def test_count_distinct_tiling_raises():
134+
"""get_ir_metadata_for_aggregation raises ValueError for count_distinct."""
135+
agg = Aggregation(column="item_id", function="count_distinct")
136+
with pytest.raises(ValueError, match="count_distinct does not support tiling"):
137+
get_ir_metadata_for_aggregation(agg, "count_distinct_item_id")

0 commit comments

Comments
 (0)