Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix linting
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Apr 4, 2025
commit 2825ee4592e392a019dcffae045084a4b01f3490
36 changes: 20 additions & 16 deletions sdk/python/feast/infra/compute_engines/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from abc import ABC
from dataclasses import dataclass
from typing import Union, List
from datetime import datetime
from typing import Union

import pandas as pd
import pyarrow as pa

from feast import RepoConfig, BatchFeatureView, StreamFeatureView
from feast.infra.materialization.batch_materialization_engine import MaterializationTask, MaterializationJob
from feast import BatchFeatureView, RepoConfig, StreamFeatureView
from feast.infra.materialization.batch_materialization_engine import (
MaterializationJob,
MaterializationTask,
)
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.registry.registry import Registry
Expand All @@ -15,10 +19,12 @@
@dataclass
class HistoricalRetrievalTask:
entity_df: Union[pd.DataFrame, str]
feature_views: List[Union[BatchFeatureView, StreamFeatureView]]
full_feature_names: bool
feature_view: Union[BatchFeatureView, StreamFeatureView]
full_feature_name: bool
registry: Registry
config: RepoConfig
start_time: datetime
end_time: datetime


class ComputeEngine(ABC):
Expand All @@ -27,23 +33,21 @@ class ComputeEngine(ABC):
"""

def __init__(
self,
*,
registry: Registry,
repo_config: RepoConfig,
offline_store: OfflineStore,
online_store: OnlineStore,
**kwargs,
self,
*,
registry: Registry,
repo_config: RepoConfig,
offline_store: OfflineStore,
online_store: OnlineStore,
**kwargs,
):
self.registry = registry
self.repo_config = repo_config
self.offline_store = offline_store
self.online_store = online_store

def materialize(self,
task: MaterializationTask) -> MaterializationJob:
def materialize(self, task: MaterializationTask) -> MaterializationJob:
raise NotImplementedError

def get_historical_features(self,
task: HistoricalRetrievalTask) -> pa.Table:
def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
raise NotImplementedError
39 changes: 18 additions & 21 deletions sdk/python/feast/infra/compute_engines/dag/builder.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,57 @@
from abc import ABC, abstractmethod
from typing import Union

from feast import BatchFeatureView, StreamFeatureView
from feast.infra.compute_engines.dag.plan import ExecutionPlan
from feast import BatchFeatureView, StreamFeatureView, FeatureView
from feast.infra.compute_engines.base import HistoricalRetrievalTask
from feast.infra.compute_engines.dag.plan import ExecutionPlan
from feast.infra.materialization.batch_materialization_engine import MaterializationTask
from feast.infra.compute_engines.dag.node import DAGNode


class DAGBuilder(ABC):
def __init__(self,
feature_view: Union[BatchFeatureView, StreamFeatureView],
task: Union[MaterializationTask, HistoricalRetrievalTask]
):
def __init__(
self,
feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
task: Union[MaterializationTask, HistoricalRetrievalTask],
):
self.feature_view = feature_view
self.task = task
self.nodes = []
self.nodes: list[DAGNode] = []

@abstractmethod
def build_source_node(self):
raise NotImplementedError

@abstractmethod
def build_aggregation_node(self,
input_node):
def build_aggregation_node(self, input_node):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should remove the 'node' suffix from all of these, as I think it's implicit.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah sounds good

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, I think it's better to keep the 'node' suffix, as it makes the output more descriptive and clearly indicates that it's a node, which can then be chained in the build method.

raise NotImplementedError

@abstractmethod
def build_join_node(self,
input_node):
def build_join_node(self, input_node):
raise NotImplementedError

@abstractmethod
def build_transformation_node(self,
input_node):
def build_transformation_node(self, input_node):
raise NotImplementedError

@abstractmethod
def build_output_nodes(self,
input_node):
def build_output_nodes(self, input_node):
raise NotImplementedError

@abstractmethod
def build_validation_node(self,
input_node):
def build_validation_node(self, input_node):
raise

def build(self) -> ExecutionPlan:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: build_dag(self)

Copy link
Copy Markdown
Collaborator Author

@HaoXuAI HaoXuAI Apr 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest keeping build(), as it is more consistent with the builder.build() pattern.

last_node = self.build_source_node()

if getattr(self.feature_view.transformation, "requires_aggregation", False):
if hasattr(self.feature_view, "aggregation") and self.feature_view.aggregation is not None:
last_node = self.build_aggregation_node(last_node)

if self._should_join():
last_node = self.build_join_node(last_node)

if self.feature_view.transformation:
if hasattr(self.feature_view, "feature_transformation") and self.feature_view.feature_transformation:
last_node = self.build_transformation_node(last_node)

if getattr(self.feature_view, "enable_validation", False):
Expand All @@ -65,6 +62,6 @@ def build(self) -> ExecutionPlan:

def _should_join(self):
return (
self.feature_view.compute_config.join_strategy == "engine"
or self.task.config.compute_engine.get("point_in_time_join") == "engine"
self.feature_view.compute_config.join_strategy == "engine"
or self.task.config.compute_engine.get("point_in_time_join") == "engine"
)
7 changes: 4 additions & 3 deletions sdk/python/feast/infra/compute_engines/dag/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from enum import Enum
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Union

import pandas as pd
from typing import Union, List, Dict

from feast.entity import Entity
from feast.infra.compute_engines.dag.value import DAGValue
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.online_stores.online_store import OnlineStore
from feast.repo_config import RepoConfig
from feast.infra.compute_engines.dag.value import DAGValue


class DAGFormat(str, Enum):
Expand Down
13 changes: 5 additions & 8 deletions sdk/python/feast/infra/compute_engines/dag/node.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
from abc import ABC, abstractmethod
from typing import List

from feast.infra.compute_engines.dag.model import ExecutionContext
from infra.compute_engines.dag.value import DAGValue

from feast.infra.compute_engines.dag.model import ExecutionContext


class DAGNode(ABC):
name: str
inputs: List["DAGNode"]
outputs: List["DAGNode"]

def __init__(self,
name: str):
def __init__(self, name: str):
self.name = name
self.inputs = []
self.outputs = []

def add_input(self,
node: "DAGNode"):
def add_input(self, node: "DAGNode"):
if node in self.inputs:
raise ValueError(f"Input node {node.name} already added to {self.name}")
self.inputs.append(node)
node.outputs.append(self)

@abstractmethod
def execute(self,
context: ExecutionContext) -> DAGValue:
...
def execute(self, context: ExecutionContext) -> DAGValue: ...
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/compute_engines/dag/plan.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List

from feast.infra.compute_engines.dag.model import ExecutionContext
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.dag.value import DAGValue
from feast.infra.compute_engines.dag.model import ExecutionContext


class ExecutionPlan:
Expand Down Expand Up @@ -39,6 +39,7 @@ class ExecutionPlan:
This approach is inspired by execution DAGs in systems like Apache Spark,
Apache Beam, and Dask — but specialized for Feast’s feature computation domain.
"""

def __init__(self, nodes: List[DAGNode]):
self.nodes = nodes

Expand All @@ -60,4 +61,3 @@ def execute(self, context: ExecutionContext) -> DAGValue:

# Return output of final node
return node_outputs[self.nodes[-1].name]

8 changes: 2 additions & 6 deletions sdk/python/feast/infra/compute_engines/dag/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@


class DAGValue:
def __init__(self,
data: Any,
format: DAGFormat,
metadata: Optional[dict] = None):
def __init__(self, data: Any, format: DAGFormat, metadata: Optional[dict] = None):
self.data = data
self.format = format
self.metadata = metadata or {}

def assert_format(self,
expected: DAGFormat):
def assert_format(self, expected: DAGFormat):
if self.format != expected:
raise ValueError(f"Expected format {expected}, but got {self.format}")
30 changes: 15 additions & 15 deletions sdk/python/feast/infra/compute_engines/spark/compute.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import pyarrow as pa

from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask
from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder
from feast.infra.materialization.batch_materialization_engine import MaterializationTask, MaterializationJob, \
MaterializationJobStatus
from feast.infra.materialization.contrib.spark.spark_materialization_engine import SparkMaterializationJob
from feast.infra.compute_engines.dag.model import ExecutionContext
from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder
from feast.infra.materialization.batch_materialization_engine import (
MaterializationJob,
MaterializationJobStatus,
MaterializationTask,
)
from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
SparkMaterializationJob,
)


class SparkComputeEngine(ComputeEngine):
Expand All @@ -22,7 +28,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob:
repo_config=self.repo_config,
offline_store=self.offline_store,
online_store=self.online_store,
entity_defs=entities
entity_defs=entities,
)

# ✅ 2. Construct DAG and run it
Expand All @@ -35,30 +41,24 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob:

# ✅ 3. Report success
return SparkMaterializationJob(
job_id=job_id,
status=MaterializationJobStatus.SUCCEEDED
job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
)

except Exception as e:
# 🛑 Handle failure
return SparkMaterializationJob(
job_id=job_id,
status=MaterializationJobStatus.ERROR,
error=e
job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
)

def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
# ✅ 1. Validate input
assert len(task.feature_views) == 1, "Multi-view support not yet implemented"
feature_view = task.feature_views[0]

if isinstance(task.entity_df, str):
raise NotImplementedError("SQL-based entity_df is not yet supported in DAG")

# ✅ 2. Build typed execution context
entity_defs = [
task.registry.get_entity(name, task.config.project)
for name in feature_view.entities
for name in task.feature_view.entities
]

context = ExecutionContext(
Expand All @@ -71,7 +71,7 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
)

# ✅ 3. Construct and execute DAG
builder = SparkDAGBuilder(feature_view=feature_view, task=task)
builder = SparkDAGBuilder(feature_view=task.feature_view, task=task)
plan = builder.build()

result = plan.execute(context=context)
Expand Down
Loading