Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: Adding optional name to Aggregation (#5994)
Signed-off-by: Nick Quinn <nicholas_quinn@apple.com>
  • Loading branch information
nickquinn408 committed Mar 9, 2026
commit 02df4b21f787d0ef6f33c709f17374e7a49a393f
1 change: 1 addition & 0 deletions protos/feast/core/Aggregation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ message Aggregation {
string function = 2;
google.protobuf.Duration time_window = 3;
google.protobuf.Duration slide_interval = 4;
string name = 5;
}
9 changes: 8 additions & 1 deletion sdk/python/feast/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,22 @@ class Aggregation:
function: str # Provided built in aggregations sum, max, min, count mean
time_window: timedelta # The time window for this aggregation.
slide_interval: timedelta # The sliding window for these aggregations
name: str # Optional override for the output feature name (defaults to {function}_{column})
"""

column: str
function: str
time_window: Optional[timedelta]
slide_interval: Optional[timedelta]
name: str

def __init__(
self,
column: Optional[str] = "",
function: Optional[str] = "",
time_window: Optional[timedelta] = None,
slide_interval: Optional[timedelta] = None,
name: Optional[str] = None,
):
self.column = column or ""
self.function = function or ""
Expand All @@ -42,6 +45,7 @@ def __init__(
self.slide_interval = self.time_window
else:
self.slide_interval = slide_interval
self.name = name or ""
Comment thread
nquinn408 marked this conversation as resolved.

def to_proto(self) -> AggregationProto:
window_duration = None
Expand All @@ -59,6 +63,7 @@ def to_proto(self) -> AggregationProto:
function=self.function,
time_window=window_duration,
slide_interval=slide_interval_duration,
name=self.name,
)

@classmethod
Expand All @@ -79,6 +84,7 @@ def from_proto(cls, agg_proto: AggregationProto):
function=agg_proto.function,
time_window=time_window,
slide_interval=slide_interval,
name=agg_proto.name or None,
)
return aggregation

Expand All @@ -91,6 +97,7 @@ def __eq__(self, other):
or self.function != other.function
or self.time_window != other.time_window
or self.slide_interval != other.slide_interval
or self.name != other.name
):
return False

Expand All @@ -106,7 +113,7 @@ def aggregation_specs_to_agg_ops(
for agg in agg_specs:
if getattr(agg, "time_window", None) is not None:
raise ValueError(time_window_unsupported_error_message)
alias = f"{agg.function}_{agg.column}"
alias = getattr(agg, "name", None) or f"{agg.function}_{agg.column}"
agg_ops[alias] = (agg.function, agg.column)
return agg_ops

Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/protos/feast/core/Aggregation_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion sdk/python/feast/protos/feast/core/Aggregation_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ class Aggregation(google.protobuf.message.Message):
FUNCTION_FIELD_NUMBER: builtins.int
TIME_WINDOW_FIELD_NUMBER: builtins.int
SLIDE_INTERVAL_FIELD_NUMBER: builtins.int
NAME_FIELD_NUMBER: builtins.int
column: builtins.str
function: builtins.str
name: builtins.str
@property
def time_window(self) -> google.protobuf.duration_pb2.Duration: ...
@property
Expand All @@ -35,8 +37,9 @@ class Aggregation(google.protobuf.message.Message):
function: builtins.str = ...,
time_window: google.protobuf.duration_pb2.Duration | None = ...,
slide_interval: google.protobuf.duration_pb2.Duration | None = ...,
name: builtins.str = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["slide_interval", b"slide_interval", "time_window", b"time_window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["column", b"column", "function", b"function", "slide_interval", b"slide_interval", "time_window", b"time_window"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["column", b"column", "function", b"function", "name", b"name", "slide_interval", b"slide_interval", "time_window", b"time_window"]) -> None: ...

global___Aggregation = Aggregation
54 changes: 52 additions & 2 deletions sdk/python/tests/unit/test_aggregation_ops.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from datetime import timedelta

import pytest

from feast.aggregation import aggregation_specs_to_agg_ops
from feast.aggregation import Aggregation, aggregation_specs_to_agg_ops


class DummyAggregation:
def __init__(self, *, function: str, column: str, time_window=None):
def __init__(self, *, function: str, column: str, time_window=None, name: str = ""):
self.function = function
self.column = column
self.time_window = time_window
self.name = name


def test_aggregation_specs_to_agg_ops_success():
Expand Down Expand Up @@ -42,3 +45,50 @@ def test_aggregation_specs_to_agg_ops_time_window_unsupported(error_message: str
agg_specs,
time_window_unsupported_error_message=error_message,
)


def test_aggregation_specs_to_agg_ops_custom_name():
agg_specs = [
DummyAggregation(function="sum", column="seconds_watched", name="sum_seconds_watched_per_ad_1d"),
]

agg_ops = aggregation_specs_to_agg_ops(
agg_specs,
time_window_unsupported_error_message="Time window aggregation is not supported.",
)

assert agg_ops == {
"sum_seconds_watched_per_ad_1d": ("sum", "seconds_watched"),
}


def test_aggregation_specs_to_agg_ops_mixed_names():
agg_specs = [
DummyAggregation(function="sum", column="trips", name="total_trips"),
DummyAggregation(function="mean", column="fare"),
]

agg_ops = aggregation_specs_to_agg_ops(
agg_specs,
time_window_unsupported_error_message="Time window aggregation is not supported.",
)

assert agg_ops == {
"total_trips": ("sum", "trips"),
"mean_fare": ("mean", "fare"),
}


def test_aggregation_round_trip_with_name():
agg = Aggregation(
column="seconds_watched",
function="sum",
time_window=timedelta(days=1),
name="sum_seconds_watched_per_ad_1d",
)
proto = agg.to_proto()
assert proto.name == "sum_seconds_watched_per_ad_1d"

restored = Aggregation.from_proto(proto)
assert restored.name == "sum_seconds_watched_per_ad_1d"
assert restored == agg
Loading