|
1 | 1 | from datetime import timedelta |
2 | 2 |
|
| 3 | +import pandas as pd |
3 | 4 | import pytest |
4 | 5 |
|
5 | 6 | from feast.aggregation import Aggregation, aggregation_specs_to_agg_ops |
| 7 | +from feast.aggregation.tiling.base import get_ir_metadata_for_aggregation |
6 | 8 |
|
7 | 9 |
|
8 | 10 | class DummyAggregation: |
@@ -96,3 +98,40 @@ def test_aggregation_round_trip_with_name(): |
96 | 98 | restored = Aggregation.from_proto(proto) |
97 | 99 | assert restored.name == "sum_seconds_watched_per_ad_1d" |
98 | 100 | assert restored == agg |
| 101 | + |
| 102 | + |
def test_count_distinct_agg_ops():
    """A count_distinct spec is translated into a pandas ``nunique`` agg op."""
    spec = DummyAggregation(function="count_distinct", column="item_id")
    expected_ops = {"count_distinct_item_id": ("nunique", "item_id")}

    actual_ops = aggregation_specs_to_agg_ops(
        [spec],
        time_window_unsupported_error_message="no windows",
    )

    assert actual_ops == expected_ops
| 113 | + |
| 114 | + |
def test_count_distinct_result():
    """Grouping with the count_distinct op yields the unique-value count per group."""
    from feast.infra.compute_engines.backends.pandas_backend import PandasBackend

    ops = aggregation_specs_to_agg_ops(
        [DummyAggregation(function="count_distinct", column="item_id")],
        time_window_unsupported_error_message="no windows",
    )
    frame = pd.DataFrame({"user": ["A", "A", "B"], "item_id": [1, 2, 1]})

    grouped = PandasBackend().groupby_agg(frame, ["user"], ops).set_index("user")

    # User "A" has item_ids {1, 2}; user "B" has only item_id 1.
    expected_counts = {"A": 2, "B": 1}
    for user, expected in expected_counts.items():
        assert grouped.loc[user, "count_distinct_item_id"] == expected
| 131 | + |
| 132 | + |
def test_count_distinct_tiling_raises():
    """Requesting tiling IR metadata for count_distinct fails with ValueError."""
    unsupported = Aggregation(column="item_id", function="count_distinct")
    expected_message = "count_distinct does not support tiling"

    with pytest.raises(ValueError, match=expected_message):
        get_ir_metadata_for_aggregation(unsupported, "count_distinct_item_id")
0 commit comments