Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions protos/feast/core/Aggregation.proto
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ message Aggregation {
string function = 2;
google.protobuf.Duration time_window = 3;
google.protobuf.Duration slide_interval = 4;
string name = 5;
}
9 changes: 8 additions & 1 deletion sdk/python/feast/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,22 @@ class Aggregation:
function: str # Provided built in aggregations sum, max, min, count mean
time_window: timedelta # The time window for this aggregation.
slide_interval: timedelta # The sliding window for these aggregations
name: str # Optional override for the output feature name (defaults to {function}_{column})
"""

column: str
function: str
time_window: Optional[timedelta]
slide_interval: Optional[timedelta]
name: str

def __init__(
self,
column: Optional[str] = "",
function: Optional[str] = "",
time_window: Optional[timedelta] = None,
slide_interval: Optional[timedelta] = None,
name: Optional[str] = None,
):
self.column = column or ""
self.function = function or ""
Expand All @@ -42,6 +45,7 @@ def __init__(
self.slide_interval = self.time_window
else:
self.slide_interval = slide_interval
self.name = name or ""
Comment thread
nquinn408 marked this conversation as resolved.

def to_proto(self) -> AggregationProto:
window_duration = None
Expand All @@ -59,6 +63,7 @@ def to_proto(self) -> AggregationProto:
function=self.function,
time_window=window_duration,
slide_interval=slide_interval_duration,
name=self.name,
)

@classmethod
Expand All @@ -79,6 +84,7 @@ def from_proto(cls, agg_proto: AggregationProto):
function=agg_proto.function,
time_window=time_window,
slide_interval=slide_interval,
name=agg_proto.name or None,
)
return aggregation

Expand All @@ -91,6 +97,7 @@ def __eq__(self, other):
or self.function != other.function
or self.time_window != other.time_window
or self.slide_interval != other.slide_interval
or self.name != other.name
):
return False

Expand All @@ -106,7 +113,7 @@ def aggregation_specs_to_agg_ops(
for agg in agg_specs:
if getattr(agg, "time_window", None) is not None:
raise ValueError(time_window_unsupported_error_message)
alias = f"{agg.function}_{agg.column}"
alias = getattr(agg, "name", None) or f"{agg.function}_{agg.column}"
agg_ops[alias] = (agg.function, agg.column)
return agg_ops

Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/protos/feast/core/Aggregation_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion sdk/python/feast/protos/feast/core/Aggregation_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ class Aggregation(google.protobuf.message.Message):
FUNCTION_FIELD_NUMBER: builtins.int
TIME_WINDOW_FIELD_NUMBER: builtins.int
SLIDE_INTERVAL_FIELD_NUMBER: builtins.int
NAME_FIELD_NUMBER: builtins.int
column: builtins.str
function: builtins.str
name: builtins.str
@property
def time_window(self) -> google.protobuf.duration_pb2.Duration: ...
@property
Expand All @@ -35,8 +37,9 @@ class Aggregation(google.protobuf.message.Message):
function: builtins.str = ...,
time_window: google.protobuf.duration_pb2.Duration | None = ...,
slide_interval: google.protobuf.duration_pb2.Duration | None = ...,
name: builtins.str = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["slide_interval", b"slide_interval", "time_window", b"time_window"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["column", b"column", "function", b"function", "slide_interval", b"slide_interval", "time_window", b"time_window"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["column", b"column", "function", b"function", "name", b"name", "slide_interval", b"slide_interval", "time_window", b"time_window"]) -> None: ...

global___Aggregation = Aggregation
54 changes: 52 additions & 2 deletions sdk/python/tests/unit/test_aggregation_ops.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from datetime import timedelta

import pytest

from feast.aggregation import aggregation_specs_to_agg_ops
from feast.aggregation import Aggregation, aggregation_specs_to_agg_ops


class DummyAggregation:
def __init__(self, *, function: str, column: str, time_window=None):
def __init__(self, *, function: str, column: str, time_window=None, name: str = ""):
self.function = function
self.column = column
self.time_window = time_window
self.name = name


def test_aggregation_specs_to_agg_ops_success():
Expand Down Expand Up @@ -42,3 +45,50 @@ def test_aggregation_specs_to_agg_ops_time_window_unsupported(error_message: str
agg_specs,
time_window_unsupported_error_message=error_message,
)


def test_aggregation_specs_to_agg_ops_custom_name():
agg_specs = [
DummyAggregation(function="sum", column="seconds_watched", name="sum_seconds_watched_per_ad_1d"),
]

agg_ops = aggregation_specs_to_agg_ops(
agg_specs,
time_window_unsupported_error_message="Time window aggregation is not supported.",
)

assert agg_ops == {
"sum_seconds_watched_per_ad_1d": ("sum", "seconds_watched"),
}


def test_aggregation_specs_to_agg_ops_mixed_names():
agg_specs = [
DummyAggregation(function="sum", column="trips", name="total_trips"),
DummyAggregation(function="mean", column="fare"),
]

agg_ops = aggregation_specs_to_agg_ops(
agg_specs,
time_window_unsupported_error_message="Time window aggregation is not supported.",
)

assert agg_ops == {
"total_trips": ("sum", "trips"),
"mean_fare": ("mean", "fare"),
}


def test_aggregation_round_trip_with_name():
agg = Aggregation(
column="seconds_watched",
function="sum",
time_window=timedelta(days=1),
name="sum_seconds_watched_per_ad_1d",
)
proto = agg.to_proto()
assert proto.name == "sum_seconds_watched_per_ad_1d"

restored = Aggregation.from_proto(proto)
assert restored.name == "sum_seconds_watched_per_ad_1d"
assert restored == agg
Loading