Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Checkpoint
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Jul 8, 2025
commit 8ef4f4f9bf4cd339832f2db66b197aefd73c5b63
8 changes: 5 additions & 3 deletions sdk/python/feast/batch_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.field import Field
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.transformation.base import Transformation
from feast.transformation.mode import TransformationMode

Expand Down Expand Up @@ -53,6 +52,7 @@ class BatchFeatureView(FeatureView):
entities: List[str]
ttl: Optional[timedelta]
source: DataSource
sink_source: Optional[DataSource] = None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should just call it sink?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, will update it

schema: List[Field]
entity_columns: List[Field]
features: List[Field]
Expand All @@ -75,6 +75,7 @@ def __init__(
name: str,
mode: Union[TransformationMode, str] = TransformationMode.PYTHON,
source: Union[DataSource, "BatchFeatureView", List["BatchFeatureView"]],
sink_source: Optional[DataSource] = None,
entities: Optional[List[Entity]] = None,
ttl: Optional[timedelta] = None,
tags: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -115,12 +116,13 @@ def __init__(
description=description,
owner=owner,
schema=schema,
source=source,
source=source, # type: ignore[arg-type]
sink_source=sink_source,
)

def get_feature_transformation(self) -> Optional[Transformation]:
if not self.udf:
return
return None
if self.mode in (
TransformationMode.PANDAS,
TransformationMode.PYTHON,
Expand Down
33 changes: 22 additions & 11 deletions sdk/python/feast/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
*,
name: str,
source: Union[DataSource, "FeatureView", List["FeatureView"]],
sink_source: Optional[DataSource] = None,
schema: Optional[List[Field]] = None,
entities: Optional[List[Entity]] = None,
ttl: Optional[timedelta] = timedelta(days=0),
Expand Down Expand Up @@ -146,34 +147,44 @@ def __init__(
schema = schema or []

# Normalize source
self.stream_source = None
self.data_source: Optional[DataSource] = None
self.source_views: List[FeatureView] = []

if isinstance(source, DataSource):
self.data_source = source
elif isinstance(source, FeatureView):
self.source_views = [source]
elif isinstance(source, list) and all(isinstance(sv, FeatureView) for sv in source):
elif isinstance(source, list) and all(
isinstance(sv, FeatureView) for sv in source
):
self.source_views = source
else:
raise TypeError("source must be a DataSource, a FeatureView, or a list of FeatureViews.")
raise TypeError(
"source must be a DataSource, a FeatureView, or a list of FeatureViews."
)

# Set up stream/batch sources
# Set up stream, batch and derived view sources
if (
isinstance(self.data_source, PushSource)
or isinstance(self.data_source, KafkaSource)
or isinstance(self.data_source, KinesisSource)
):
self.stream_source = source
# Stream source definition
self.stream_source = self.data_source
if not self.data_source.batch_source:
raise ValueError(
f"A batch_source needs to be specified for stream source `{source.name}`"
f"A batch_source needs to be specified for stream source `{self.data_source.name}`"
)
else:
self.batch_source = self.data_source.batch_source
else:
self.stream_source = None
self.batch_source = self.data_source.batch_source
elif self.data_source:
# Batch source definition
self.batch_source = self.data_source
else:
# Derived view source definition
if not sink_source:
raise ValueError("Derived FeatureView must specify `sink_source`.")
self.batch_source = sink_source

# Initialize features and entity columns.
features: List[Field] = []
Expand Down Expand Up @@ -215,7 +226,7 @@ def __init__(
)

# TODO(felixwang9817): Add more robust validation of features.
if source is not None:
if self.batch_source is not None:
cols = [field.name for field in schema]
for col in cols:
if (
Expand Down Expand Up @@ -451,7 +462,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto):
if feature_view_proto.spec.ttl.ToNanoseconds() == 0
else feature_view_proto.spec.ttl.ToTimedelta()
),
source=batch_source if batch_source else source_views
source=batch_source if batch_source else source_views,
)
if stream_source:
feature_view.stream_source = stream_source
Expand Down
27 changes: 7 additions & 20 deletions sdk/python/feast/infra/compute_engines/dag/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,16 @@ def to_sql(self, context: ExecutionContext) -> str:
"""
raise NotImplementedError("SQL generation is not implemented yet.")

def to_dag(self):
def to_dag(self) -> str:
"""
Generate a textual DAG representation for debugging.

Returns:
str: A multi-line string showing the DAG structure.
Render the DAG as a multiline string with full node expansion (no visited shortcut).
"""
lines = []
seen = set()

def dfs(node: DAGNode, indent=0):
def walk(node: DAGNode, indent: int = 0) -> List[str]:
prefix = " " * indent
if node.name in seen:
lines.append(f"{prefix}- {node.name} (visited)")
return
seen.add(node.name)
lines.append(f"{prefix}- {node.name}")
lines = [f"{prefix}- {node.name}"]
for input_node in node.inputs:
dfs(input_node, indent + 1)

for node in self.nodes:
dfs(node)

return "\n".join(lines)

lines.extend(walk(input_node, indent + 1))
return lines

return "\n".join(walk(self.nodes[-1]))
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/compute_engines/feature_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Dict
from typing import Dict, List, Optional, Union

from feast import BatchFeatureView, FeatureView, StreamFeatureView
from feast.infra.common.materialization_job import MaterializationTask
from feast.infra.common.retrieval_task import HistoricalRetrievalTask
from feast.infra.compute_engines.algorithms.topo import topo_sort
from feast.infra.compute_engines.dag.context import ColumnInfo
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.dag.plan import ExecutionPlan
Expand All @@ -11,7 +13,6 @@
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.utils import _get_column_names
from feast.infra.compute_engines.algorithms.topo import topo_sort


class FeatureBuilder(ABC):
Expand Down Expand Up @@ -78,7 +79,6 @@ def _should_dedupe(self, view):
return isinstance(self.task, HistoricalRetrievalTask) or self.task.only_latest

def _build(self, view, input_nodes: Optional[List[DAGNode]]) -> DAGNode:

# Step 1: build source node
if view.data_source:
last_node = self.build_source_node(view)
Expand Down
19 changes: 11 additions & 8 deletions sdk/python/feast/infra/compute_engines/feature_resolver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Optional, Set

from feast.feature_view import FeatureView
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.algorithms.topo import topo_sort
from feast.infra.compute_engines.dag.context import ExecutionContext
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.dag.value import DAGValue


Expand All @@ -12,10 +12,12 @@ class FeatureViewNode(DAGNode):
Logical representation of a node in the FeatureView dependency DAG.
"""

def __init__(self, view: FeatureView):
def __init__(
self, view: FeatureView, inputs: Optional[List["FeatureViewNode"]] = None
):
super().__init__(name=view.name)
self.view: FeatureView = view
self.inputs: List["FeatureViewNode"] = []
self.inputs: List["FeatureViewNode"] = inputs or [] # type: ignore

def execute(self, context: ExecutionContext) -> DAGValue:
raise NotImplementedError(
Expand Down Expand Up @@ -68,15 +70,16 @@ def _walk(self, view: FeatureView):
self._node_cache[view.name] = node

self._resolution_path.append(view.name)
for upstream_view in view.source_views:
input_node = self._walk(upstream_view)
node.inputs.append(input_node)
if view.source_views:
for upstream_view in view.source_views:
input_node = self._walk(upstream_view)
node.inputs.append(input_node)
self._resolution_path.pop()

return node

def topo_sort(self, root: FeatureViewNode) -> List[FeatureViewNode]:
return topo_sort(root)
return topo_sort(root) # type: ignore

def debug_dag(self, node: FeatureViewNode, depth=0):
"""
Expand All @@ -89,4 +92,4 @@ def debug_dag(self, node: FeatureViewNode, depth=0):
indent = " " * depth
print(f"{indent}- {node.view.name}")
for input_node in node.inputs:
self.debug_dag(input_node, depth + 1)
self.debug_dag(input_node, depth + 1) # type: ignore
1 change: 1 addition & 0 deletions sdk/python/feast/infra/compute_engines/local/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:
context=context,
start_time=self.start_time,
end_time=self.end_time,
column_info=self.column_info,
)
arrow_table = retrieval_job.to_arrow()
if self.column_info.field_mapping:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def build_aggregation_node(self, view, input_node):
self.nodes.append(node)
return node

def build_join_node(self,
view,
input_nodes):
def build_join_node(self, view, input_nodes):
column_info = self.get_column_info(view)
node = SparkJoinNode(
name=f"{view.name}_join",
Expand Down
18 changes: 12 additions & 6 deletions sdk/python/feast/infra/compute_engines/spark/nodes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import List, Optional, Union, cast
from typing import Callable, List, Optional, Union, cast

from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql import functions as F
Expand Down Expand Up @@ -156,7 +156,7 @@ def __init__(
column_info: ColumnInfo,
spark_session: SparkSession,
inputs: Optional[List[DAGNode]] = None,
how: str = "left"
how: str = "left",
):
super().__init__(name, inputs=inputs or [])
self.column_info = column_info
Expand All @@ -171,7 +171,9 @@ def execute(self, context: ExecutionContext) -> DAGValue:
# Join all input DataFrames on join_keys
joined_df = input_values[0].data
for dag_value in input_values[1:]:
joined_df = joined_df.join(dag_value.data, on=self.column_info.join_keys, how=self.how)
joined_df = joined_df.join(
dag_value.data, on=self.column_info.join_keys, how=self.how
)

# If entity_df is provided, join it in last
entity_df = context.entity_df
Expand All @@ -180,7 +182,9 @@ def execute(self, context: ExecutionContext) -> DAGValue:
spark_session=self.spark_session,
entity_df=entity_df,
)
joined_df = joined_df.join(entity_df, on=self.column_info.join_keys, how=self.how)
joined_df = joined_df.join(
entity_df, on=self.column_info.join_keys, how=self.how
)

return DAGValue(
data=joined_df,
Expand Down Expand Up @@ -332,7 +336,7 @@ def execute(self, context: ExecutionContext) -> DAGValue:


class SparkTransformationNode(DAGNode):
def __init__(self, name: str, udf: callable, inputs: List[DAGNode]):
def __init__(self, name: str, udf: Callable, inputs: List[DAGNode]):
super().__init__(name, inputs)
self.udf = udf

Expand All @@ -343,7 +347,9 @@ def execute(self, context: ExecutionContext) -> DAGValue:

input_dfs: List[DataFrame] = [val.data for val in input_values]

print(f"[SparkTransformationNode] Executing transform on {len(input_dfs)} input(s).")
print(
f"[SparkTransformationNode] Executing transform on {len(input_dfs)} input(s)."
)

transformed_df = self.udf(*input_dfs)

Expand Down
Loading
Loading