Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Support distinct count aggregation [#6116]
Signed-off-by: Nick Quinn <nicholas_quinn@apple.com>
  • Loading branch information
nquinn408 authored and nickquinn408 committed Mar 17, 2026
commit 04d55a03eca8a4a1f3b9ace50d60570b4e51aca0
10 changes: 8 additions & 2 deletions sdk/python/feast/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Aggregation:

Attributes:
column: str # Column name of the feature we are aggregating.
function: str # Provided built in aggregations sum, max, min, count mean
function: str # Provided built-in aggregations: sum, max, min, count, mean, count_distinct
time_window: timedelta # The time window for this aggregation.
slide_interval: timedelta # The sliding window for these aggregations
name: str # Optional override for the output feature name (defaults to {function}_{column})
Expand Down Expand Up @@ -118,6 +118,11 @@ def resolved_name(self, time_window: Optional[timedelta] = None) -> str:
return base


_FUNCTION_ALIASES: Dict[str, str] = {
"count_distinct": "nunique",
}


def aggregation_specs_to_agg_ops(
agg_specs: Iterable[Any],
*,
Expand All @@ -128,7 +133,8 @@ def aggregation_specs_to_agg_ops(
if getattr(agg, "time_window", None) is not None:
raise ValueError(time_window_unsupported_error_message)
alias = getattr(agg, "name", None) or f"{agg.function}_{agg.column}"
agg_ops[alias] = (agg.function, agg.column)
func_name = _FUNCTION_ALIASES.get(agg.function, agg.function)
agg_ops[alias] = (func_name, agg.column)
return agg_ops


Expand Down
6 changes: 6 additions & 0 deletions sdk/python/feast/aggregation/tiling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def get_ir_metadata_for_aggregation(
),
)

elif agg_type == "count_distinct":
raise ValueError(
"count_distinct does not support tiling. "
"Use enable_tiling=False or choose an algebraic aggregation (sum, count, min, max)."
)

else:
# Unknown aggregation: treat as algebraic
return ([], IRMetadata(type="algebraic"))
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@


class PolarsBackend(DataFrameBackend):
_POLARS_FUNC_MAP = {"nunique": "n_unique"}

def columns(self, df: pl.DataFrame) -> list[str]:
    """Return the list of column names of *df*."""
    names = df.columns
    return names

Expand All @@ -21,7 +23,7 @@ def join(self, left: pl.DataFrame, right: pl.DataFrame, on, how) -> pl.DataFrame

def groupby_agg(self, df: pl.DataFrame, group_keys, agg_ops) -> pl.DataFrame:
    """Group *df* by *group_keys* and evaluate the requested aggregations.

    ``agg_ops`` maps an output alias to a ``(func, col)`` pair; the function
    name is translated through ``_POLARS_FUNC_MAP`` so backend-neutral names
    (e.g. ``nunique``) resolve to the polars method (``n_unique``).
    """
    exprs = []
    for alias, (func, col) in agg_ops.items():
        polars_func = self._POLARS_FUNC_MAP.get(func, func)
        exprs.append(getattr(pl.col(col), polars_func)().alias(alias))
    return df.groupby(group_keys).agg(exprs)
Expand Down
4 changes: 4 additions & 0 deletions sdk/python/feast/infra/compute_engines/ray/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ def _execute_standard_aggregation(self, dataset: Dataset) -> DAGValue:
agg_dict[feature_name] = (agg.column, "std")
elif agg.function == "var":
agg_dict[feature_name] = (agg.column, "var")
elif agg.function == "count_distinct":
agg_dict[feature_name] = (agg.column, "nunique")
else:
raise ValueError(f"Unknown aggregation function: {agg.function}.")

Expand Down Expand Up @@ -531,6 +533,8 @@ def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Data
result = grouped[column].std()
elif function == "var":
result = grouped[column].var()
elif function == "nunique":
result = grouped[column].nunique()
else:
raise ValueError(f"Unknown aggregation function: {function}.")

Expand Down
7 changes: 5 additions & 2 deletions sdk/python/feast/infra/compute_engines/spark/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,11 @@ def _execute_standard_aggregation(self, input_df: DataFrame) -> DAGValue:
"""Execute standard Spark aggregation (existing logic)."""
agg_exprs = []
for agg in self.aggregations:
func = getattr(F, agg.function)
expr = func(agg.column).alias(agg.resolved_name(agg.time_window))
if agg.function == "count_distinct":
func_expr = F.countDistinct(agg.column)
else:
func_expr = getattr(F, agg.function)(agg.column)
expr = func_expr.alias(agg.resolved_name(agg.time_window))
agg_exprs.append(expr)

if any(agg.time_window for agg in self.aggregations):
Expand Down
39 changes: 39 additions & 0 deletions sdk/python/tests/unit/test_aggregation_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from datetime import timedelta

import pandas as pd
import pytest

from feast.aggregation import Aggregation, aggregation_specs_to_agg_ops
from feast.aggregation.tiling.base import get_ir_metadata_for_aggregation


class DummyAggregation:
Expand Down Expand Up @@ -96,3 +98,40 @@ def test_aggregation_round_trip_with_name():
restored = Aggregation.from_proto(proto)
assert restored.name == "sum_seconds_watched_per_ad_1d"
assert restored == agg


def test_count_distinct_agg_ops():
    """aggregation_specs_to_agg_ops maps count_distinct to the nunique pandas function."""
    specs = [DummyAggregation(function="count_distinct", column="item_id")]

    ops = aggregation_specs_to_agg_ops(
        specs, time_window_unsupported_error_message="no windows"
    )

    expected = {"count_distinct_item_id": ("nunique", "item_id")}
    assert ops == expected


def test_count_distinct_result():
    """count_distinct via nunique returns the number of unique values per group."""
    from feast.infra.compute_engines.backends.pandas_backend import PandasBackend

    specs = [DummyAggregation(function="count_distinct", column="item_id")]
    ops = aggregation_specs_to_agg_ops(
        specs, time_window_unsupported_error_message="no windows"
    )

    frame = pd.DataFrame({"user": ["A", "A", "B"], "item_id": [1, 2, 1]})
    aggregated = PandasBackend().groupby_agg(frame, ["user"], ops).set_index("user")

    # User A saw two distinct items; user B saw one.
    expected_counts = {"A": 2, "B": 1}
    for user, count in expected_counts.items():
        assert aggregated.loc[user, "count_distinct_item_id"] == count


def test_count_distinct_tiling_raises():
    """Tiling metadata lookup rejects count_distinct with a ValueError."""
    aggregation = Aggregation(column="item_id", function="count_distinct")
    pattern = "count_distinct does not support tiling"
    with pytest.raises(ValueError, match=pattern):
        get_ir_metadata_for_aggregation(aggregation, "count_distinct_item_id")
Loading