fix testing
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
HaoXuAI committed Jul 8, 2025
commit ea50f43da4b7cac5a58a8e4396db349729f83b9d
6 changes: 4 additions & 2 deletions sdk/python/feast/infra/compute_engines/feature_builder.py
@@ -59,7 +59,7 @@ def build_transformation_node(self, view, input_nodes):
         raise NotImplementedError

     @abstractmethod
-    def build_output_nodes(self, final_node):
+    def build_output_nodes(self, view, final_node):
         raise NotImplementedError

     @abstractmethod
@@ -131,7 +131,9 @@ def build(self) -> ExecutionPlan:
             view_to_node[view.name] = dag_node

         # Step 3: Build output node
-        final_node = self.build_output_nodes(view_to_node[self.feature_view.name])
+        final_node = self.build_output_nodes(
+            self.feature_view, view_to_node[self.feature_view.name]
+        )

         # Step 4: Topo sort the final DAG from the output node (Physical DAG)
         sorted_nodes = topo_sort(final_node)
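To orient readers on the signature change: `build_output_nodes` now receives the owning `FeatureView` along with the terminal DAG node, so a concrete engine can derive node names and sinks from the view. A minimal sketch of the updated contract (`MyFeatureBuilder` and `MyOutputNode` are hypothetical, not part of this PR):

```python
# Hypothetical subclass; only the two-argument hook and the
# view-derived node name reflect this PR's change.
class MyFeatureBuilder(FeatureBuilder):
    def build_output_nodes(self, view, final_node):
        node = MyOutputNode(f"{view.name}:output", inputs=[final_node])
        self.nodes.append(node)
        return node
```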
@@ -84,7 +84,7 @@ def build_validation_node(self, view, input_node):
         self.nodes.append(node)
         return node

-    def build_output_nodes(self, input_node):
+    def build_output_nodes(self, view, input_node):
         node = LocalOutputNode("output", self.dag_root.view, inputs=[input_node])
         self.nodes.append(node)
         return node
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/compute_engines/local/nodes.py
@@ -63,7 +63,7 @@ def __init__(
         column_info: ColumnInfo,
         backend: DataFrameBackend,
         inputs: Optional[List["DAGNode"]] = None,
-        how: str = "left",
+        how: str = "inner",
     ):
         super().__init__(name, inputs or [])
         self.column_info = column_info
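For context on the default-join change in the local engine: an inner join keeps only entities present in every input view, whereas the previous left join kept all rows of the first input and padded missing features with nulls. A minimal pandas sketch with toy data (assuming the local backend's usual pandas semantics):

```python
import pandas as pd

left = pd.DataFrame({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})
right = pd.DataFrame({"driver_id": [2, 3], "acc_rate": [0.9, 0.8]})

# how="left" (old default): driver 1 survives with NaN acc_rate
left.merge(right, on="driver_id", how="left")

# how="inner" (new default): only driver 2, present in both inputs
left.merge(right, on="driver_id", how="inner")
```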
25 changes: 18 additions & 7 deletions sdk/python/feast/infra/compute_engines/spark/feature_builder.py
@@ -33,7 +33,12 @@ def build_source_node(self, view):
         source = view.batch_source
         column_info = self.get_column_info(view)
         node = SparkReadNode(
-            "source", source, column_info, self.spark_session, start_time, end_time
+            f"{view.name}:source",
+            source,
+            column_info,
+            self.spark_session,
+            start_time,
+            end_time,
         )
         self.nodes.append(node)
         return node
@@ -43,15 +48,19 @@ def build_aggregation_node(self, view, input_node):
         group_by_keys = view.entities
         timestamp_col = view.batch_source.timestamp_field
         node = SparkAggregationNode(
-            "agg", agg_specs, group_by_keys, timestamp_col, inputs=[input_node]
+            f"{view.name}:agg",
+            agg_specs,
+            group_by_keys,
+            timestamp_col,
+            inputs=[input_node],
         )
         self.nodes.append(node)
         return node

     def build_join_node(self, view, input_nodes):
         column_info = self.get_column_info(view)
         node = SparkJoinNode(
-            name=f"{view.name}_join",
+            name=f"{view.name}:join",
             column_info=column_info,
             spark_session=self.spark_session,
             inputs=input_nodes,
@@ -65,7 +74,7 @@ def build_filter_node(self, view, input_node):
         ttl = getattr(view, "ttl", None)
         column_info = self.get_column_info(view)
         node = SparkFilterNode(
-            "filter",
+            f"{view.name}:filter",
             column_info,
             self.spark_session,
             ttl,
@@ -78,7 +87,7 @@ def build_dedup_node(self, view, input_node):
     def build_dedup_node(self, view, input_node):
         column_info = self.get_column_info(view)
         node = SparkDedupNode(
-            "dedup", column_info, self.spark_session, inputs=[input_node]
+            f"{view.name}:dedup", column_info, self.spark_session, inputs=[input_node]
         )
         self.nodes.append(node)
         return node
@@ -90,8 +99,10 @@ def build_transformation_node(self, view, input_nodes):
         self.nodes.append(node)
         return node

-    def build_output_nodes(self, input_node):
-        node = SparkWriteNode("output", self.dag_root.view, inputs=[input_node])
+    def build_output_nodes(self, view, input_node):
+        node = SparkWriteNode(
+            f"{view.name}:output", self.dag_root.view, inputs=[input_node]
+        )
         self.nodes.append(node)
         return node
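The renames above give every Spark node a `{view.name}:{stage}` name. Beyond making names unique across views, the convention is load-bearing: `SparkJoinNode` in the next file recovers the originating feature view from a node's name in order to prefix that view's columns. A minimal sketch of the parsing it relies on:

```python
# Node names now encode their originating FeatureView.
node_name = "daily_driver_stats:agg"
fv_name = node_name.split(":")[0]  # -> "daily_driver_stats"
prefix = fv_name + "__"            # column prefix applied by SparkJoinNode
```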
36 changes: 28 additions & 8 deletions sdk/python/feast/infra/compute_engines/spark/nodes.py
@@ -156,7 +156,7 @@ def __init__(
         column_info: ColumnInfo,
         spark_session: SparkSession,
         inputs: Optional[List[DAGNode]] = None,
-        how: str = "left",
+        how: str = "inner",
     ):
         super().__init__(name, inputs=inputs or [])
         self.column_info = column_info
@@ -169,11 +169,28 @@ def execute(self, context: ExecutionContext) -> DAGValue:
             val.assert_format(DAGFormat.SPARK)

         # Join all input DataFrames on join_keys
-        joined_df = input_values[0].data
-        for dag_value in input_values[1:]:
-            joined_df = joined_df.join(
-                dag_value.data, on=self.column_info.join_keys, how=self.how
-            )
+        joined_df = None
+        for i, dag_value in enumerate(input_values):
+            df = dag_value.data
+
+            # Use original FeatureView name if available
+            fv_name = self.inputs[i].name.split(":")[0]
+            prefix = fv_name + "__"
+
+            # Skip renaming join keys to preserve join compatibility
+            renamed_cols = [
+                F.col(c).alias(f"{prefix}{c}")
+                if c not in self.column_info.join_keys
+                else F.col(c)
+                for c in df.columns
+            ]
+            df = df.select(*renamed_cols)
+            if joined_df is None:
+                joined_df = df
+            else:
+                joined_df = joined_df.join(
+                    df, on=self.column_info.join_keys, how=self.how
+                )

         # If entity_df is provided, join it in last
         entity_df = context.entity_df
@@ -182,8 +199,11 @@ def execute(self, context: ExecutionContext) -> DAGValue:
                 spark_session=self.spark_session,
                 entity_df=entity_df,
             )
-            joined_df = joined_df.join(
-                entity_df, on=self.column_info.join_keys, how=self.how
+            if joined_df is None:
+                raise RuntimeError("No input features available to join with entity_df")
+
+            joined_df = entity_df.join(
+                joined_df, on=self.column_info.join_keys, how="left"
             )

         return DAGValue(
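A runnable sketch of the renaming behaviour `SparkJoinNode.execute` now applies, with toy frames and a local session (hypothetical data): non-key columns get a `{view}__` prefix so two views sharing a feature name no longer collide, join keys stay unprefixed so the equi-join still works, and the entity frame, when present, is joined last with `how="left"` so every requested entity row is preserved.

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
join_keys = ["driver_id"]
view_a = spark.createDataFrame([(1001, 0.8)], ["driver_id", "conv_rate"])
view_b = spark.createDataFrame([(1001, 0.9)], ["driver_id", "conv_rate"])

def prefixed(df, fv_name):
    # Alias non-key columns with "<view>__"; leave join keys untouched.
    return df.select(
        *[
            F.col(c).alias(f"{fv_name}__{c}") if c not in join_keys else F.col(c)
            for c in df.columns
        ]
    )

joined = prefixed(view_a, "view_a").join(
    prefixed(view_b, "view_b"), on=join_keys, how="inner"
)
# Columns: driver_id, view_a__conv_rate, view_b__conv_rate; no ambiguous
# duplicate "conv_rate" column.
```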
@@ -6,6 +6,7 @@
 from tqdm import tqdm

 from feast import BatchFeatureView, Field
+from feast.aggregation import Aggregation
 from feast.infra.common.materialization_job import (
     MaterializationJobStatus,
     MaterializationTask,
@@ -45,6 +46,26 @@ def create_base_feature_view(source):
     )


+def create_agg_feature_view(source):
+    return BatchFeatureView(
+        name="agg_hourly_driver_stats",
+        entities=[driver],
+        schema=[
+            Field(name="conv_rate", dtype=Float32),
+            Field(name="acc_rate", dtype=Float32),
+            Field(name="avg_daily_trips", dtype=Int64),
+            Field(name="driver_id", dtype=Int32),
+        ],
+        online=True,
+        offline=True,
+        source=source,
+        aggregations=[
+            Aggregation(column="conv_rate", function="sum"),
+            Aggregation(column="acc_rate", function="avg"),
+        ],
+    )
+
+
 def create_chained_feature_view(base_fv: BatchFeatureView):
     def transform_feature(df: DataFrame) -> DataFrame:
         df = df.withColumn("conv_rate", df["conv_rate"] * 2)
@@ -126,3 +147,78 @@ def tqdm_builder(length):
         )
     finally:
         spark_env.teardown()
+
+
+@pytest.mark.integration
+def test_spark_dag_materialize_multi_views():
+    spark_env = create_spark_environment()
+    fs = spark_env.feature_store
+    registry = fs.registry
+    source = create_feature_dataset(spark_env)
+
+    base_fv = create_base_feature_view(source)
+    chained_fv = create_chained_feature_view(base_fv)
+
+    multi_view = BatchFeatureView(
+        name="multi_view",
+        entities=[driver],
+        schema=[
+            Field(name="driver_id", dtype=Int32),
+            Field(name="daily_driver_stats__conv_rate", dtype=Float32),
+            Field(name="daily_driver_stats__acc_rate", dtype=Float32),
+        ],
+        online=True,
+        offline=True,
+        source=[base_fv, chained_fv],
+        sink_source=SparkSource(
+            name="multi_view_sink",
+            path="/tmp/multi_view_sink",
+            file_format="parquet",
+            timestamp_field="daily_driver_stats__event_timestamp",
+            created_timestamp_column="daily_driver_stats__created",
+        ),
+    )
+
+    def tqdm_builder(length):
+        return tqdm(total=length, ncols=100)
+
+    try:
+        fs.apply([driver, base_fv, chained_fv, multi_view])
+
+        # 🧪 Materialize multi-view
+        task = MaterializationTask(
+            project=fs.project,
+            feature_view=multi_view,
+            start_time=now - timedelta(days=2),
+            end_time=now,
+            tqdm_builder=tqdm_builder,
+        )
+
+        engine = SparkComputeEngine(
+            repo_config=spark_env.config,
+            offline_store=SparkOfflineStore(),
+            online_store=MagicMock(),
+            registry=registry,
+        )
+
+        jobs = engine.materialize(registry, task)
+
+        # ✅ Validate jobs ran
+        assert len(jobs) == 1
+        assert jobs[0].status() == MaterializationJobStatus.SUCCEEDED
+
+        _check_online_features(
+            fs=fs,
+            driver_id=1001,
+            feature="multi_view:daily_driver_stats__conv_rate",
+            expected_value=1.6,
+            full_feature_names=True,
+        )
+
+        entity_df = create_entity_df()
+
+        _check_offline_features(
+            fs=fs, feature="hourly_driver_stats:conv_rate", entity_df=entity_df, size=2
+        )
+    finally:
+        spark_env.teardown()
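For reference, a hedged sketch of the lookup `_check_online_features` presumably performs (the helper itself lives elsewhere in the test suite): with `full_feature_names=True`, the standard `FeatureStore` API surfaces the prefixed feature under `<view>__<feature>` keys.

```python
resp = fs.get_online_features(
    features=["multi_view:daily_driver_stats__conv_rate"],
    entity_rows=[{"driver_id": 1001}],
    full_feature_names=True,
).to_dict()

# With full feature names, response keys are "<view>__<feature>".
assert resp["multi_view__daily_driver_stats__conv_rate"][0] == pytest.approx(1.6)
```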
@@ -117,7 +117,7 @@ def build_transformation_node(self, view, input_nodes):
     def build_validation_node(self, view, input_node):
         return MockDAGNode(f"Validate({view.name})", inputs=[input_node])

-    def build_output_nodes(self, final_node):
+    def build_output_nodes(self, view, final_node):
         output_node = MockDAGNode(f"Output({final_node.name})", inputs=[final_node])
         self.nodes.append(output_node)
         return output_node