Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update API
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Apr 13, 2025
commit e9362de5623ca333e6db8cbc61003853be50dcc4
4 changes: 4 additions & 0 deletions sdk/python/feast/batch_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dill

from feast import flags_helper
from feast.aggregation import Aggregation
from feast.data_source import DataSource
from feast.entity import Entity
from feast.feature_view import FeatureView
Expand Down Expand Up @@ -65,6 +66,7 @@ class BatchFeatureView(FeatureView):
udf_string: Optional[str]
feature_transformation: Transformation
batch_engine: Optional[Field]
aggregations: Optional[List[Aggregation]]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm to be honest, I don't love putting the transformation in the FeatureView.

I think it'd be more intuitive to put Aggregation under Transformation and make the FeatureViews purely represent schemas available for online or offline.

Copy link
Copy Markdown
Collaborator Author

@HaoXuAI HaoXuAI Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. This is only to be consistent with the current StreamFeatureView and Aggregation class. It makes sense to make Aggregation one type of Transformation (as well as Filter), but that will need some work to refactor.
For Tecton, they put Aggregate in features (what Feast calls schema), as

features=[
        Aggregate(
            input_column=Field("transaction", Int64),
            function="count",
            time_window=TimeWindow(window_size=timedelta(days=1)),
        )
    ],

In some ways that also makes sense, but for Feast it would need much more work to refactor the schema API.

For Chrono they put it into a GroupBy API, which is similar to our FeatureView:

v1 = GroupBy(
    sources=[source],
    keys=["user_id"], # We are aggregating by user
    online=True,
    aggregations=[Aggregation(
            input_column="refund_amt",
            operation=Operation.SUM,
            windows=window_sizes
        ), # The sum of purchases prices in various windows
    ],
)

I can look into how to merge Aggregation with Transformation in the next PR; I've added it to a TODO for now.


def __init__(
self,
Expand All @@ -84,6 +86,7 @@ def __init__(
udf_string: Optional[str] = "",
feature_transformation: Optional[Transformation] = None,
batch_engine: Optional[Field] = None,
aggregations: Optional[List[Aggregation]] = None,
):
if not flags_helper.is_test():
warnings.warn(
Expand All @@ -108,6 +111,7 @@ def __init__(
feature_transformation or self.get_feature_transformation()
)
self.batch_engine = batch_engine
self.aggregations = aggregations or []

super().__init__(
name=name,
Expand Down
40 changes: 40 additions & 0 deletions sdk/python/feast/infra/compute_engines/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC
from typing import Union

import pyarrow as pa

from feast import RepoConfig
from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext
from feast.infra.compute_engines.tasks import HistoricalRetrievalTask
from feast.infra.materialization.batch_materialization_engine import (
MaterializationJob,
Expand All @@ -11,6 +13,7 @@
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.registry.registry import Registry
from feast.utils import _get_column_names


class ComputeEngine(ABC):
Expand Down Expand Up @@ -41,3 +44,40 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob:

def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
raise NotImplementedError

def get_execution_context(
    self,
    task: Union[MaterializationTask, HistoricalRetrievalTask],
) -> ExecutionContext:
    """Assemble the ExecutionContext shared by materialization and
    historical-retrieval DAG runs.

    Resolves the feature view's entity definitions from the registry,
    derives column metadata via ``get_column_info``, and carries the
    task's ``entity_df`` through for point-in-time joins (``None`` when
    the task has none).
    """
    entity_defs = [
        self.registry.get_entity(name, task.project)
        for name in task.feature_view.entities
    ]
    # Use getattr with a default: a MaterializationTask may not define
    # entity_df at all, and plain attribute access would raise
    # AttributeError. NOTE(review): confirm MaterializationTask's interface.
    entity_df = getattr(task, "entity_df", None)

    column_info = self.get_column_info(task)
    return ExecutionContext(
        project=task.project,
        repo_config=self.repo_config,
        offline_store=self.offline_store,
        online_store=self.online_store,
        entity_defs=entity_defs,
        column_info=column_info,
        entity_df=entity_df,
    )

def get_column_info(
    self,
    task: Union[MaterializationTask, HistoricalRetrievalTask],
) -> ColumnInfo:
    """Derive join keys, feature columns, and timestamp columns for the
    task's feature view."""
    entities = self.registry.list_entities(task.project)
    # _get_column_names returns (join_keys, feature_cols, ts_col,
    # created_ts_col), which matches ColumnInfo's field order exactly.
    return ColumnInfo(*_get_column_names(task.feature_view, entities))
17 changes: 16 additions & 1 deletion sdk/python/feast/infra/compute_engines/dag/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

import pandas as pd

Expand All @@ -10,6 +10,20 @@
from feast.repo_config import RepoConfig


@dataclass
class ColumnInfo:
    """Column metadata for a feature view: entity join keys, feature
    columns, and the event/created timestamp column names."""

    join_keys: List[str]
    feature_cols: List[str]
    ts_col: str
    created_ts_col: Optional[str]

    def __iter__(self):
        # Allow tuple-style unpacking in field-declaration order, e.g.
        # join_keys, feature_cols, ts_col, created_ts_col = column_info
        return iter(
            (self.join_keys, self.feature_cols, self.ts_col, self.created_ts_col)
        )


@dataclass
class ExecutionContext:
"""
Expand Down Expand Up @@ -47,6 +61,7 @@ class ExecutionContext:
repo_config: RepoConfig
offline_store: OfflineStore
online_store: OnlineStore
column_info: ColumnInfo
entity_defs: List[Entity]
entity_df: Union[pd.DataFrame, None] = None
node_outputs: Dict[str, DAGValue] = field(default_factory=dict)
20 changes: 16 additions & 4 deletions sdk/python/feast/infra/compute_engines/feature_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def build_aggregation_node(self, input_node):
def build_join_node(self, input_node):
raise NotImplementedError

@abstractmethod
def build_filter_node(self, input_node):
    """Build a DAG node that filters rows of `input_node`; concrete engines must implement."""
    raise NotImplementedError

@abstractmethod
def build_dedup_node(self, input_node):
    """Build a DAG node that deduplicates rows of `input_node`; concrete engines must implement."""
    raise NotImplementedError

@abstractmethod
def build_transformation_node(self, input_node):
raise NotImplementedError
Expand All @@ -50,13 +58,17 @@ def build_validation_node(self, input_node):
def build(self) -> ExecutionPlan:
last_node = self.build_source_node()

# PIT join entities to the feature data, and perform filtering
last_node = self.build_join_node(last_node)
last_node = self.build_filter_node(last_node)

if (
hasattr(self.feature_view, "aggregation")
and self.feature_view.aggregation is not None
hasattr(self.feature_view, "aggregations")
and self.feature_view.aggregations is not None
):
last_node = self.build_aggregation_node(last_node)

last_node = self.build_join_node(last_node)
else:
last_node = self.build_dedup_node(last_node)

if (
hasattr(self.feature_view, "feature_transformation")
Expand Down
73 changes: 31 additions & 42 deletions sdk/python/feast/infra/compute_engines/spark/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,11 @@ def __init__(
def materialize(self, task: MaterializationTask) -> MaterializationJob:
job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}"

try:
# ✅ 1. Build typed execution context
entities = []
for entity_name in task.feature_view.entities:
entities.append(self.registry.get_entity(entity_name, task.project))

context = ExecutionContext(
project=task.project,
repo_config=self.repo_config,
offline_store=self.offline_store,
online_store=self.online_store,
entity_defs=entities,
)
# ✅ 1. Build typed execution context
context = self.get_execution_context(task)

# ✅ 2. Construct DAG and run it
try:
# ✅ 2. Construct Feature Builder and run it
builder = SparkFeatureBuilder(
spark_session=self.spark_session,
feature_view=task.feature_view,
Expand All @@ -74,33 +64,32 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob
if isinstance(task.entity_df, str):
raise NotImplementedError("SQL-based entity_df is not yet supported in DAG")

# ✅ 2. Build typed execution context
entity_defs = [
task.registry.get_entity(name, task.config.project)
for name in task.feature_view.entities
]

context = ExecutionContext(
project=task.config.project,
repo_config=task.config,
offline_store=self.offline_store,
online_store=self.online_store,
entity_defs=entity_defs,
entity_df=task.entity_df,
)
# ✅ 1. Build typed execution context
context = self.get_execution_context(task)

# ✅ 3. Construct and execute DAG
builder = SparkFeatureBuilder(
spark_session=self.spark_session,
feature_view=task.feature_view,
task=task,
)
plan = builder.build()
try:
# ✅ 2. Construct Feature Builder and run it
builder = SparkFeatureBuilder(
spark_session=self.spark_session,
feature_view=task.feature_view,
task=task,
)
plan = builder.build()

return SparkDAGRetrievalJob(
plan=plan,
spark_session=self.spark_session,
context=context,
config=task.config,
full_feature_names=task.full_feature_name,
)
return SparkDAGRetrievalJob(
plan=plan,
spark_session=self.spark_session,
context=context,
config=self.repo_config,
full_feature_names=task.full_feature_name,
)
except Exception as e:
# 🛑 Handle failure
return SparkDAGRetrievalJob(
plan=None,
spark_session=self.spark_session,
context=context,
config=self.repo_config,
full_feature_names=task.full_feature_name,
error=e,
)
17 changes: 17 additions & 0 deletions sdk/python/feast/infra/compute_engines/spark/feature_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from feast.infra.compute_engines.feature_builder import FeatureBuilder
from feast.infra.compute_engines.spark.node import (
SparkAggregationNode,
SparkDedupNode,
SparkFilterNode,
SparkHistoricalRetrievalReadNode,
SparkJoinNode,
SparkMaterializationReadNode,
Expand Down Expand Up @@ -54,6 +56,21 @@ def build_join_node(self, input_node):
self.nodes.append(node)
return node

def build_filter_node(self, input_node):
    """Append a SparkFilterNode that applies the feature view's optional
    filter expression (None when the view defines no filter)."""
    filter_expr = getattr(self.feature_view, "filter", None)
    filter_node = SparkFilterNode("filter", input_node, self.feature_view, filter_expr)
    self.nodes.append(filter_node)
    return filter_node

def build_dedup_node(self, input_node):
    """Append a SparkDedupNode for the feature view (deduplication
    semantics live in the node implementation)."""
    dedup_node = SparkDedupNode(
        "dedup", input_node, self.feature_view, self.spark_session
    )
    self.nodes.append(dedup_node)
    return dedup_node

def build_transformation_node(self, input_node):
udf_name = self.feature_view.feature_transformation.name
udf = self.feature_view.feature_transformation.udf
Expand Down
9 changes: 7 additions & 2 deletions sdk/python/feast/infra/compute_engines/spark/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ class SparkDAGRetrievalJob(SparkRetrievalJob):
def __init__(
self,
spark_session: SparkSession,
plan: ExecutionPlan,
context: ExecutionContext,
full_feature_names: bool,
config: RepoConfig,
plan: Optional[ExecutionPlan] = None,
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
metadata: Optional[RetrievalMetadata] = None,
error: Optional[BaseException] = None,
):
super().__init__(
spark_session=spark_session,
Expand All @@ -34,7 +35,11 @@ def __init__(
self._plan = plan
self._context = context
self._metadata = metadata
self._spark_df = None # Will be populated on first access
self._spark_df = None
self._error = error

def error(self) -> Optional[BaseException]:
    """Return the exception captured when building/running the DAG failed, or None on success."""
    return self._error

def _ensure_executed(self):
if self._spark_df is None:
Expand Down
Loading