add doc
Signed-off-by: HaoXuAI <sduxuhao@gmail.com>
HaoXuAI committed Apr 8, 2025
commit 1c9ae31bf73598a57345242afc08266afaf8438a
@@ -8,7 +8,7 @@ Feature transformations can be executed by three types of "transformation engines":

1. The Feast Feature Server
2. An Offline Store (e.g., Snowflake, BigQuery, DuckDB, Spark, etc.)
3. A Stream processor (e.g., Flink or Spark Streaming)
3. A Compute Engine (see more [here](../../reference/compute-engine/README.md))

The three transformation engines are coupled with the [communication pattern used for writes](write-patterns.md).

87 changes: 87 additions & 0 deletions docs/reference/compute-engine/README.md
@@ -0,0 +1,87 @@
# 🧠 ComputeEngine (WIP)

The `ComputeEngine` is Feast’s pluggable abstraction for executing feature pipelines, including transformations, aggregations, joins, and the `materialize()`/`get_historical_features()` paths, on a backend of your choice (e.g., Spark, PyArrow, Pandas, Ray).

It powers both:

- `materialize()` – for batch and stream generation of features to offline/online stores
- `get_historical_features()` – for point-in-time correct training dataset retrieval

This system builds and executes DAGs (Directed Acyclic Graphs) of typed operations, enabling modular and scalable workflows.
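
As a rough usage sketch (hedged: the store, registry, config, and task objects are assumed to come from an applied Feast repo; only `SparkComputeEngine` and its two methods are taken from this PR):

```python
# A usage sketch, not a verbatim API. offline_store, online_store,
# registry, repo_config, mat_task, and hist_task are assumed to be
# built elsewhere (e.g., from an applied Feast repo).
from feast.infra.compute_engines.spark.compute import SparkComputeEngine

engine = SparkComputeEngine(
    offline_store=offline_store,
    online_store=online_store,
    registry=registry,
    repo_config=repo_config,
)

# Batch path: build the DAG and write features to the stores.
job = engine.materialize(mat_task)  # MaterializationTask -> MaterializationJob

# Retrieval path: point-in-time correct join, executed lazily.
retrieval_job = engine.get_historical_features(hist_task)
training_df = retrieval_job.to_spark_df()
```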

---

## 🧠 Core Concepts

| Component | Description |
|--------------------|--------------------------------------------------------------------|
| `ComputeEngine` | Interface for executing materialization and retrieval tasks |
| `DAGBuilder` | Constructs a DAG for a specific backend |
| `DAGNode` | Represents a logical operation (read, aggregate, join, etc.) |
| `ExecutionPlan` | Executes nodes in dependency order and stores intermediate outputs |
| `ExecutionContext` | Holds config, registry, stores, entity data, and node outputs |
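
For orientation, an illustrative shape for the context object is sketched below; the fields mirror usages visible in this PR (`context.offline_store`, `context.entity_df`, `context.entity_defs`, `context.repo_config`), and anything beyond those is an assumption rather than Feast's actual definition.

```python
# Illustrative only; not Feast's real ExecutionContext.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ExecutionContextSketch:
    repo_config: Any                 # RepoConfig for the repo
    registry: Any                    # feature registry
    offline_store: Any
    online_store: Any
    entity_defs: List[Any]           # entity definitions used for joins
    entity_df: Optional[Any] = None  # entity dataframe for retrieval
    node_outputs: Dict[str, Any] = field(default_factory=dict)  # cached DAGValues
```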

---

## ✨ Available Engines

### 🔥 SparkComputeEngine

- Distributed DAG execution via Apache Spark
- Supports point-in-time joins and large-scale materialization
- Integrates with `SparkOfflineStore` and `SparkMaterializationJob`

### 🧪 LocalComputeEngine (WIP)

- Runs on Arrow + Pandas (or optionally DuckDB)
- Designed for local dev, testing, or lightweight feature generation

---

## 🛠️ Example DAG Flow
`Read → Aggregate → Join → Transform → Write`

Each step is implemented as a `DAGNode`. An `ExecutionPlan` executes these nodes in topological order, caching `DAGValue` outputs.
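
A toy, self-contained sketch of that contract (simplified stand-ins that mirror the `execute(context) -> DAGValue` shape used in this PR, not Feast's actual classes):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Sequence


@dataclass
class DAGValue:
    data: Any
    metadata: Dict[str, Any] = field(default_factory=dict)


class DAGNode:
    def __init__(self, name: str, inputs: Sequence["DAGNode"] = ()) -> None:
        self.name = name
        self.inputs = list(inputs)

    def execute(self, context: Dict[str, Any]) -> DAGValue:
        raise NotImplementedError


class ReadNode(DAGNode):
    """Reads raw rows from a toy 'batch source' held in the context."""

    def execute(self, context: Dict[str, Any]) -> DAGValue:
        return DAGValue(data=context["source_rows"], metadata={"source": "toy"})


class AggregateNode(DAGNode):
    """Sums an `amount` column over the upstream node's cached output."""

    def execute(self, context: Dict[str, Any]) -> DAGValue:
        upstream = context["node_outputs"][self.inputs[0].name]
        total = sum(row["amount"] for row in upstream.data)
        return DAGValue(data=[{"total_amount": total}])


# Manual wiring of Read -> Aggregate (a plan would normally do this).
read = ReadNode("read")
agg = AggregateNode("agg", inputs=[read])
ctx: Dict[str, Any] = {"source_rows": [{"amount": 3}, {"amount": 4}]}
ctx["node_outputs"] = {read.name: read.execute(ctx)}
print(agg.execute(ctx).data)  # [{'total_amount': 7}]
```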

---

## 🧩 Implementing a Custom Compute Engine

To create your own compute engine:

1. **Implement the interface**

```python
import pyarrow as pa

from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask
from feast.infra.materialization.batch_materialization_engine import MaterializationJob, MaterializationTask


class MyComputeEngine(ComputeEngine):
    def materialize(self, task: MaterializationTask) -> MaterializationJob:
        ...

    def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
        ...
```

2. **Create a `DAGBuilder`**

```python
from feast.infra.compute_engines.dag.builder import DAGBuilder


class MyDAGBuilder(DAGBuilder):
    def build_source_node(self): ...
    def build_aggregation_node(self, input_node): ...
    def build_join_node(self, input_node): ...
    def build_transformation_node(self, input_node): ...
    def build_output_nodes(self, input_node): ...
```

3. **Define `DAGNode` subclasses**
   * `ReadNode`, `AggregationNode`, `JoinNode`, `WriteNode`, etc.
   * Each node implements `execute(context) -> DAGValue`

4. **Return an `ExecutionPlan`**
   * `ExecutionPlan` stores DAG nodes in topological order
   * It automatically caches intermediate values (a minimal sketch follows this list)
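
A minimal, self-contained sketch of that execution loop (topological order plus cached intermediate values, in the spirit of `ExecutionPlan`; not Feast's implementation):

```python
from typing import Any, Callable, Dict, List


class Node:
    """Toy stand-in for a DAG node with an execute(context) method."""

    def __init__(self, name: str, fn: Callable[[Dict[str, Any]], Any]) -> None:
        self.name = name
        self.fn = fn

    def execute(self, context: Dict[str, Any]) -> Any:
        return self.fn(context)


class ExecutionPlanSketch:
    def __init__(self, nodes: List[Node]) -> None:
        # Nodes are assumed to already be in topological order,
        # as described above for ExecutionPlan.
        self.nodes = nodes

    def execute(self, context: Dict[str, Any]) -> Any:
        outputs: Dict[str, Any] = {}
        context["node_outputs"] = outputs
        for node in self.nodes:
            if node.name not in outputs:  # cache intermediate values
                outputs[node.name] = node.execute(context)
        return outputs[self.nodes[-1].name]  # final node's output


# Usage: a two-node read -> sum plan.
plan = ExecutionPlanSketch(
    nodes=[
        Node("read", lambda ctx: [1, 2, 3]),
        Node("sum", lambda ctx: sum(ctx["node_outputs"]["read"])),
    ]
)
print(plan.execute({}))  # 6
```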

## 🚧 Roadmap
- [x] Modular, backend-agnostic DAG execution framework
- [x] Spark engine with native support for materialization + PIT joins
- [ ] PyArrow + Pandas engine for local compute
- [ ] Native multi-feature-view DAG optimization
- [ ] DAG validation, metrics, and debug output
- [ ] Scalable distributed backend via Ray or Polars
4 changes: 4 additions & 0 deletions sdk/python/feast/infra/compute_engines/base.py
@@ -30,6 +30,10 @@ class HistoricalRetrievalTask:
class ComputeEngine(ABC):
"""
The interface that Feast uses to control the compute system that handles materialization and get_historical_features.
Each engine must implement:
- materialize(): to generate and persist features
- get_historical_features(): to perform point-in-time correct joins
Engines should use DAGBuilder and DAGNode abstractions to build modular, pluggable workflows.
"""

def __init__(
2 changes: 2 additions & 0 deletions sdk/python/feast/infra/compute_engines/dag/builder.py
@@ -9,6 +9,8 @@


class DAGBuilder(ABC):
""" """

def __init__(
self,
feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
7 changes: 7 additions & 0 deletions sdk/python/feast/infra/compute_engines/dag/plan.py
@@ -61,3 +61,10 @@ def execute(self, context: ExecutionContext) -> DAGValue:

# Return output of final node
return node_outputs[self.nodes[-1].name]

def to_sql(self, context: ExecutionContext) -> str:
"""
Generate SQL query for the entire execution plan.
This is a placeholder and should be implemented in subclasses.
"""
raise NotImplementedError("SQL generation is not implemented yet.")
43 changes: 34 additions & 9 deletions sdk/python/feast/infra/compute_engines/spark/compute.py
@@ -1,8 +1,8 @@
import pyarrow as pa

from feast.infra.compute_engines.base import ComputeEngine, HistoricalRetrievalTask
from feast.infra.compute_engines.dag.context import ExecutionContext
from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob
from feast.infra.compute_engines.spark.spark_dag_builder import SparkDAGBuilder
from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session
from feast.infra.materialization.batch_materialization_engine import (
MaterializationJob,
MaterializationJobStatus,
@@ -11,9 +11,27 @@
from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
SparkMaterializationJob,
)
from feast.infra.offline_stores.offline_store import RetrievalJob


class SparkComputeEngine(ComputeEngine):
def __init__(
self,
offline_store,
online_store,
registry,
repo_config,
**kwargs,
):
super().__init__(
offline_store=offline_store,
online_store=online_store,
registry=registry,
repo_config=repo_config,
**kwargs,
)
self.spark_session = get_or_create_new_spark_session()

def materialize(self, task: MaterializationTask) -> MaterializationJob:
job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}"

@@ -33,6 +51,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob:

# ✅ 2. Construct DAG and run it
builder = SparkDAGBuilder(
spark_session=self.spark_session,
feature_view=task.feature_view,
task=task,
)
@@ -50,7 +69,7 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob:
job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
)

def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob:
if isinstance(task.entity_df, str):
raise NotImplementedError("SQL-based entity_df is not yet supported in DAG")

@@ -70,11 +89,17 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> pa.Table:
)

# ✅ 3. Construct and execute DAG
builder = SparkDAGBuilder(feature_view=task.feature_view, task=task)
builder = SparkDAGBuilder(
spark_session=self.spark_session,
feature_view=task.feature_view,
task=task,
)
plan = builder.build()

result = plan.execute(context=context)
spark_df = result.data # should be a Spark DataFrame

# ✅ 4. Return as Arrow
return spark_df.toPandas().to_arrow()
return SparkDAGRetrievalJob(
plan=plan,
spark_session=self.spark_session,
context=context,
config=task.config,
full_feature_names=task.full_feature_name,
)
51 changes: 51 additions & 0 deletions sdk/python/feast/infra/compute_engines/spark/job.py
@@ -0,0 +1,51 @@
from typing import List, Optional

import pyspark
from pyspark.sql import SparkSession

from feast import OnDemandFeatureView, RepoConfig
from feast.infra.compute_engines.dag.context import ExecutionContext
from feast.infra.compute_engines.dag.plan import ExecutionPlan
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkRetrievalJob,
)
from feast.infra.offline_stores.offline_store import RetrievalMetadata


class SparkDAGRetrievalJob(SparkRetrievalJob):
def __init__(
self,
spark_session: SparkSession,
plan: ExecutionPlan,
context: ExecutionContext,
full_feature_names: bool,
config: RepoConfig,
on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
metadata: Optional[RetrievalMetadata] = None,
):
super().__init__(
spark_session=spark_session,
query="",
full_feature_names=full_feature_names,
config=config,
on_demand_feature_views=on_demand_feature_views,
metadata=metadata,
)
self._plan = plan
self._context = context
self._metadata = metadata
self._spark_df = None # Will be populated on first access

def _ensure_executed(self):
if self._spark_df is None:
result = self._plan.execute(self._context)
self._spark_df = result.data

def to_spark_df(self) -> pyspark.sql.DataFrame:
self._ensure_executed()
assert self._spark_df is not None, "Execution plan did not produce a DataFrame"
return self._spark_df

def to_sql(self) -> str:
self._ensure_executed()
return self._plan.to_sql(self._context)
78 changes: 75 additions & 3 deletions sdk/python/feast/infra/compute_engines/spark/node.py
@@ -2,13 +2,13 @@
from datetime import datetime
from typing import Dict, List, Optional, Union, cast

from pyspark.sql import DataFrame, Window
from feast.infra.compute_engines.dag.context import ExecutionContext
from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql import functions as F

from feast import BatchFeatureView, StreamFeatureView
from feast.aggregation import Aggregation
from feast.infra.compute_engines.base import HistoricalRetrievalTask
from feast.infra.compute_engines.dag.context import ExecutionContext
from feast.infra.compute_engines.dag.model import DAGFormat
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.dag.value import DAGValue
Expand All @@ -19,6 +19,11 @@
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkRetrievalJob,
_get_entity_df_event_timestamp_range,
_get_entity_schema,
)
from feast.infra.offline_stores.offline_utils import (
infer_event_timestamp_from_entity_df,
)
from feast.utils import _get_column_names

@@ -37,7 +42,7 @@ class SparkJoinContext:
full_feature_names: bool = False # apply feature view name prefix


class SparkReadNode(DAGNode):
class SparkMaterializationReadNode(DAGNode):
def __init__(
self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask]
):
@@ -82,6 +87,73 @@ def execute(self, context: ExecutionContext) -> DAGValue:
)


class SparkHistoricalRetrievalReadNode(DAGNode):
def __init__(
self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession
):
super().__init__(name)
self.task = task
self.spark_session = spark_session

def execute(self, context: ExecutionContext) -> DAGValue:
"""
Read data from the offline store on the Spark engine.
TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features.
Args:
context: SparkExecutionContext
Returns: DAGValue
"""
offline_store = context.offline_store
fv = self.task.feature_view
entity_df = context.entity_df
source = fv.batch_source
entities = context.entity_defs

(
join_key_columns,
feature_name_columns,
timestamp_field,
_,
) = _get_column_names(fv, entities)

entity_schema = _get_entity_schema(
spark_session=self.spark_session,
entity_df=entity_df,
)
event_timestamp_col = infer_event_timestamp_from_entity_df(
entity_schema=entity_schema,
)
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df,
event_timestamp_col,
self.spark_session,
)
min_ts = entity_df_event_timestamp_range[0]
max_ts = entity_df_event_timestamp_range[1]

retrieval_job = offline_store.pull_all_from_table_or_query(
config=context.repo_config,
data_source=source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
start_date=min_ts,
end_date=max_ts,
)
spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()

return DAGValue(
data=spark_df,
format=DAGFormat.SPARK,
metadata={
"source": "feature_view_batch_source",
"timestamp_field": timestamp_field,
"start_date": min_ts,
"end_date": max_ts,
},
)


class SparkAggregationNode(DAGNode):
def __init__(
self,