Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update API
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Apr 13, 2025
commit 3a5cf921005c5a69e6611d6d38e5304f6ec5d4e4
34 changes: 23 additions & 11 deletions sdk/python/feast/infra/compute_engines/feature_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,28 +55,40 @@ def build_output_nodes(self, input_node):
def build_validation_node(self, input_node):
raise

def _should_aggregate(self):
    """Return True when the feature view declares at least one aggregation."""
    # getattr with a None default collapses the hasattr + None check into one step.
    aggregations = getattr(self.feature_view, "aggregations", None)
    return aggregations is not None and len(aggregations) > 0

def _should_transform(self):
    """Return the feature view's transformation if one is configured, else False.

    Mirrors ``hasattr(...) and <attr>``: a missing attribute yields False,
    while a present attribute yields its (possibly falsy) value unchanged.
    """
    return getattr(self.feature_view, "feature_transformation", False)

def _should_validate(self):
    """Return the feature view's validation flag, defaulting to False when unset."""
    if hasattr(self.feature_view, "enable_validation"):
        return self.feature_view.enable_validation
    return False

def build(self) -> ExecutionPlan:
last_node = self.build_source_node()

# PIT join entities to the feature data, and perform filtering
last_node = self.build_join_node(last_node)
if isinstance(self.task, HistoricalRetrievalTask):
last_node = self.build_join_node(last_node)

last_node = self.build_filter_node(last_node)

if (
hasattr(self.feature_view, "aggregations")
and self.feature_view.aggregations is not None
):
if self._should_aggregate():
last_node = self.build_aggregation_node(last_node)
else:
elif isinstance(self.task, HistoricalRetrievalTask):
last_node = self.build_dedup_node(last_node)

if (
hasattr(self.feature_view, "feature_transformation")
and self.feature_view.feature_transformation
):
if self._should_transform():
last_node = self.build_transformation_node(last_node)

if getattr(self.feature_view, "enable_validation", False):
if self._should_validate():
last_node = self.build_validation_node(last_node)

self.build_output_nodes(last_node)
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/compute_engines/spark/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob:
)

except Exception as e:
raise e
# 🛑 Handle failure
return SparkMaterializationJob(
job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def build_filter_node(self, input_node):
filter_expr = None
if hasattr(self.feature_view, "filter"):
filter_expr = self.feature_view.filter
node = SparkFilterNode("filter", input_node, self.feature_view, filter_expr)
node = SparkFilterNode(
"filter", self.spark_session, input_node, self.feature_view, filter_expr
)
self.nodes.append(node)
return node

Expand Down
49 changes: 33 additions & 16 deletions sdk/python/feast/infra/compute_engines/spark/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,29 @@
)
from feast.utils import _get_fields_with_aliases

ENTITY_TS_ALIAS = "__entity_event_timestamp"


def rename_entity_ts_column(
    spark_session: SparkSession, entity_df: DataFrame
) -> DataFrame:
    """Normalize the entity dataframe's event-timestamp column to ENTITY_TS_ALIAS.

    Idempotent: when the alias column is already present the input is returned
    untouched. Non-Spark inputs are converted with ``createDataFrame`` before
    the rename so downstream nodes always see a Spark DataFrame.
    """
    if ENTITY_TS_ALIAS in entity_df.columns:
        # Already renamed by an earlier step — nothing to do.
        return entity_df

    # Infer which column carries the entity event timestamp.
    schema = _get_entity_schema(spark_session=spark_session, entity_df=entity_df)
    ts_col = infer_event_timestamp_from_entity_df(entity_schema=schema)

    if not isinstance(entity_df, DataFrame):
        entity_df = spark_session.createDataFrame(entity_df)
    return entity_df.withColumnRenamed(ts_col, ENTITY_TS_ALIAS)


@dataclass
class SparkJoinContext:
Expand Down Expand Up @@ -233,19 +256,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
join_keys, feature_cols, ts_col, created_ts_col = context.column_info

# Rename entity_df event_timestamp_col to match feature_df
entity_schema = _get_entity_schema(
entity_df = rename_entity_ts_column(
spark_session=self.spark_session,
entity_df=entity_df,
)
event_timestamp_col = infer_event_timestamp_from_entity_df(
entity_schema=entity_schema,
)
entity_ts_alias = "__entity_event_timestamp"
if not isinstance(entity_df, DataFrame):
entity_df = self.spark_session.createDataFrame(entity_df)
entity_df = entity_df.withColumnRenamed(event_timestamp_col, entity_ts_alias)

# Perform left join + event timestamp filtering
# Perform left join on entity df
joined = feature_df.join(entity_df, on=join_keys, how="left")

return DAGValue(
Expand All @@ -257,11 +273,13 @@ class SparkFilterNode(DAGNode):
def __init__(
    self,
    name: str,
    spark_session: SparkSession,
    input_node: DAGNode,
    feature_view: Union[BatchFeatureView, StreamFeatureView],
    filter_condition: Optional[str] = None,
):
    """DAG node that filters feature rows (timestamp/TTL/custom predicate).

    Args:
        name: Node identifier within the DAG.
        spark_session: Session used when the node executes.
        input_node: Upstream node whose output is filtered.
        feature_view: View providing TTL and other filter metadata.
        filter_condition: Optional SQL-style predicate applied on top.
    """
    super().__init__(name)
    self.add_input(input_node)
    # Retain everything needed at execution time.
    self.spark_session = spark_session
    self.feature_view = feature_view
    self.filter_condition = filter_condition
Expand All @@ -274,22 +292,22 @@ def execute(self, context: ExecutionContext) -> DAGValue:
# Get timestamp fields from feature view
_, _, ts_col, _ = context.column_info

# Apply filter condition
entity_ts_alias = "__entity_event_timestamp"
# Optional filter: feature.ts <= entity.event_timestamp
filtered_df = input_df
filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(entity_ts_alias))
if ENTITY_TS_ALIAS in input_df.columns:
filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS))

# Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
if self.feature_view.ttl:
ttl_seconds = int(self.feature_view.ttl.total_seconds())
lower_bound = F.col(entity_ts_alias) - F.expr(
lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
f"INTERVAL {ttl_seconds} seconds"
)
filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)

# Optional custom filter condition
if self.filter_condition:
filtered_df = input_df.filter(self.filter_condition)
filtered_df = filtered_df.filter(self.filter_condition)

return DAGValue(
data=filtered_df,
Expand Down Expand Up @@ -321,8 +339,7 @@ def execute(self, context: ExecutionContext) -> DAGValue:

# Dedup based on join keys and event timestamp
# Dedup with row_number
entity_ts_alias = "__entity_event_timestamp"
partition_cols = join_keys + [entity_ts_alias]
partition_cols = join_keys + [ENTITY_TS_ALIAS]
ordering = [F.col(ts_col).desc()]
if created_ts_col:
ordering.append(F.col(created_ts_col).desc())
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def _convert_arrow_fv_to_proto(
if isinstance(table, pyarrow.Table):
table = table.to_batches()[0]

# TODO: This will break if the feature view has aggregations or transformations
columns = [
(field.name, field.dtype.to_value_type()) for field in feature_view.features
] + list(join_keys.items())
Expand Down

This file was deleted.

Loading
Loading