diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md
index 50aaa5befab..75f29890046 100644
--- a/docs/reference/compute-engine/README.md
+++ b/docs/reference/compute-engine/README.md
@@ -31,10 +31,11 @@ This system builds and executes DAGs (Directed Acyclic Graphs) of typed operatio
 - Supports point-in-time joins and large-scale materialization
 - Integrates with `SparkOfflineStore` and `SparkMaterializationJob`
 
-### 🧪 LocalComputeEngine (WIP)
+### 🧪 LocalComputeEngine
 
-- Runs on Arrow + Pandas (or optionally DuckDB)
+- Runs on Arrow + a specified backend (e.g., Pandas, Polars)
 - Designed for local dev, testing, or lightweight feature generation
+- Supports `LocalMaterializationJob` and `LocalHistoricalRetrievalJob`
 
 ---
 
diff --git a/sdk/python/feast/infra/common/__init__.py b/sdk/python/feast/infra/common/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/sdk/python/feast/infra/common/materialization_job.py b/sdk/python/feast/infra/common/materialization_job.py
new file mode 100644
index 00000000000..60ded6735a6
--- /dev/null
+++ b/sdk/python/feast/infra/common/materialization_job.py
@@ -0,0 +1,59 @@
+import enum
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Callable, Optional, Union
+
+from tqdm import tqdm
+
+from feast import BatchFeatureView, FeatureView, StreamFeatureView
+
+
+@dataclass
+class MaterializationTask:
+    """
+    A MaterializationTask represents a unit of data that needs to be materialized from an
+    offline store to an online store.
+    """
+
+    project: str
+    feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView]
+    start_time: datetime
+    end_time: datetime
+    tqdm_builder: Callable[[int], tqdm]
+
+
+class MaterializationJobStatus(enum.Enum):
+    WAITING = 1
+    RUNNING = 2
+    AVAILABLE = 3
+    ERROR = 4
+    CANCELLING = 5
+    CANCELLED = 6
+    SUCCEEDED = 7
+    PAUSED = 8
+    RETRYING = 9
+
+
+class MaterializationJob(ABC):
+    """
+    A MaterializationJob represents an ongoing or executed process that materializes data as per the
+    definition of a materialization task.
+    """
+
+    task: MaterializationTask
+
+    @abstractmethod
+    def status(self) -> MaterializationJobStatus: ...
+
+    @abstractmethod
+    def error(self) -> Optional[BaseException]: ...
+
+    @abstractmethod
+    def should_be_retried(self) -> bool: ...
+
+    @abstractmethod
+    def job_id(self) -> str: ...
+
+    @abstractmethod
+    def url(self) -> Optional[str]: ...
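For orientation, the `MaterializationJob` contract introduced above is small enough to satisfy in a few lines. The sketch below is illustrative only: the class name and stored fields are assumptions, not part of this change, but the shape mirrors how `LocalMaterializationJob` is used later in the diff.

from typing import Optional

from feast.infra.common.materialization_job import (
    MaterializationJob,
    MaterializationJobStatus,
)


class InMemoryMaterializationJob(MaterializationJob):
    # Hypothetical minimal job object that simply records an outcome.

    def __init__(
        self,
        job_id: str,
        status: MaterializationJobStatus,
        error: Optional[BaseException] = None,
    ):
        self._job_id = job_id
        self._status = status
        self._error = error

    def status(self) -> MaterializationJobStatus:
        return self._status

    def error(self) -> Optional[BaseException]:
        return self._error

    def should_be_retried(self) -> bool:
        # Only failed runs are worth retrying; success and cancellation are terminal.
        return self._status == MaterializationJobStatus.ERROR

    def job_id(self) -> str:
        return self._job_id

    def url(self) -> Optional[str]:
        # No UI URL exists for an in-memory job.
        return None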
diff --git a/sdk/python/feast/infra/compute_engines/tasks.py b/sdk/python/feast/infra/common/retrieval_task.py similarity index 100% rename from sdk/python/feast/infra/compute_engines/tasks.py rename to sdk/python/feast/infra/common/retrieval_task.py diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index d5372d246aa..6e1a90f45b8 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -4,12 +4,12 @@ import pyarrow as pa from feast import RepoConfig -from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationTask, ) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.registry import Registry diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index cab32d47d26..324f82e7500 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod from typing import Union -from feast import BatchFeatureView, FeatureView, StreamFeatureView +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.plan import ExecutionPlan -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import MaterializationTask class FeatureBuilder(ABC): @@ -16,10 +15,9 @@ class FeatureBuilder(ABC): def __init__( self, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], task: Union[MaterializationTask, HistoricalRetrievalTask], ): - self.feature_view = feature_view + self.feature_view = task.feature_view self.task = task self.nodes: list[DAGNode] = [] diff --git a/sdk/python/feast/infra/compute_engines/local/__init__.py b/sdk/python/feast/infra/compute_engines/local/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py b/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py new file mode 100644 index 00000000000..cbd0c9f37ec --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py @@ -0,0 +1,12 @@ +import pyarrow as pa + +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.value import DAGValue + + +class ArrowTableValue(DAGValue): + def __init__(self, data: pa.Table): + super().__init__(data, DAGFormat.ARROW) + + def __repr__(self): + return f"ArrowTableValue(schema={self.data.schema}, rows={self.data.num_rows})" diff --git a/sdk/python/feast/infra/compute_engines/local/backends/__init__.py b/sdk/python/feast/infra/compute_engines/local/backends/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git 
a/sdk/python/feast/infra/compute_engines/local/backends/base.py b/sdk/python/feast/infra/compute_engines/local/backends/base.py new file mode 100644 index 00000000000..3c8d25abe00 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/base.py @@ -0,0 +1,79 @@ +from abc import ABC, abstractmethod +from datetime import timedelta + + +class DataFrameBackend(ABC): + """ + Abstract interface for DataFrame operations used by the LocalComputeEngine. + + This interface defines the contract for implementing pluggable DataFrame backends + such as Pandas, Polars, or DuckDB. Each backend must support core table operations + such as joins, filtering, aggregation, conversion to/from Arrow, and deduplication. + + The purpose of this abstraction is to allow seamless swapping of execution backends + without changing DAGNode or ComputeEngine logic. All nodes operate on pyarrow.Table + as the standard input/output format, while the backend defines how the computation + is actually performed. + + Expected implementations include: + - PandasBackend + - PolarsBackend + - DuckDBBackend (future) + + Methods + ------- + from_arrow(table: pa.Table) -> Any + Convert a pyarrow.Table to the backend-native DataFrame format. + + to_arrow(df: Any) -> pa.Table + Convert a backend-native DataFrame to pyarrow.Table. + + join(left: Any, right: Any, on: List[str], how: str) -> Any + Join two dataframes on specified keys with given join type. + + groupby_agg(df: Any, group_keys: List[str], agg_ops: Dict[str, Tuple[str, str]]) -> Any + Group and aggregate the dataframe. `agg_ops` maps output column names + to (aggregation function, source column name) pairs. + + filter(df: Any, expr: str) -> Any + Apply a filter expression (string-based) to the DataFrame. + + to_timedelta_value(delta: timedelta) -> Any + Convert a Python timedelta object to a backend-compatible value + that can be subtracted from a timestamp column. + + drop_duplicates(df: Any, keys: List[str], sort_by: List[str], ascending: bool = False) -> Any + Deduplicate the DataFrame by key columns, keeping the first row + by descending or ascending sort order. + + rename_columns(df: Any, columns: Dict[str, str]) -> Any + Rename columns in the DataFrame according to the provided mapping. + """ + + @abstractmethod + def columns(self, df): ... + + @abstractmethod + def from_arrow(self, table): ... + + @abstractmethod + def join(self, left, right, on, how): ... + + @abstractmethod + def groupby_agg(self, df, group_keys, agg_ops): ... + + @abstractmethod + def filter(self, df, expr): ... + + @abstractmethod + def to_arrow(self, df): ... + + @abstractmethod + def to_timedelta_value(self, delta: timedelta): ... + + @abstractmethod + def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): + pass + + @abstractmethod + def rename_columns(self, df, columns: dict[str, str]): ... diff --git a/sdk/python/feast/infra/compute_engines/local/backends/factory.py b/sdk/python/feast/infra/compute_engines/local/backends/factory.py new file mode 100644 index 00000000000..0a5f40cccf2 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/factory.py @@ -0,0 +1,49 @@ +from typing import Optional + +import pandas as pd +import pyarrow + +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.backends.pandas_backend import PandasBackend + + +class BackendFactory: + """ + Factory class for constructing DataFrameBackend implementations based on backend name + or runtime entity_df type. 
+ """ + + @staticmethod + def from_name(name: str) -> DataFrameBackend: + if name == "pandas": + return PandasBackend() + if name == "polars": + return BackendFactory._get_polars_backend() + raise ValueError(f"Unsupported backend name: {name}") + + @staticmethod + def infer_from_entity_df(entity_df) -> Optional[DataFrameBackend]: + if isinstance(entity_df, pyarrow.Table) or isinstance(entity_df, pd.DataFrame): + return PandasBackend() + + if BackendFactory._is_polars(entity_df): + return BackendFactory._get_polars_backend() + return None + + @staticmethod + def _is_polars(entity_df) -> bool: + try: + import polars as pl + except ImportError: + raise ImportError( + "Polars is not installed. Please install it to use Polars backend." + ) + return isinstance(entity_df, pl.DataFrame) + + @staticmethod + def _get_polars_backend(): + from feast.infra.compute_engines.local.backends.polars_backend import ( + PolarsBackend, + ) + + return PolarsBackend() diff --git a/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py b/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py new file mode 100644 index 00000000000..76ddd688424 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py @@ -0,0 +1,46 @@ +from datetime import timedelta + +import pandas as pd +import pyarrow as pa + +from feast.infra.compute_engines.local.backends.base import DataFrameBackend + + +class PandasBackend(DataFrameBackend): + def columns(self, df): + return df.columns.tolist() + + def from_arrow(self, table): + return table.to_pandas() + + def join(self, left, right, on, how): + return left.merge(right, on=on, how=how) + + def groupby_agg(self, df, group_keys, agg_ops): + return ( + df.groupby(group_keys) + .agg( + **{ + alias: pd.NamedAgg(column=col, aggfunc=func) + for alias, (func, col) in agg_ops.items() + } + ) + .reset_index() + ) + + def filter(self, df, expr): + return df.query(expr) + + def to_arrow(self, df): + return pa.Table.from_pandas(df) + + def to_timedelta_value(self, delta: timedelta): + return pd.to_timedelta(delta) + + def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): + return df.sort_values(by=sort_by, ascending=ascending).drop_duplicates( + subset=keys + ) + + def rename_columns(self, df, columns: dict[str, str]): + return df.rename(columns=columns) diff --git a/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py b/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py new file mode 100644 index 00000000000..5c874aa4433 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py @@ -0,0 +1,47 @@ +from datetime import timedelta + +import polars as pl +import pyarrow as pa + +from feast.infra.compute_engines.local.backends.base import DataFrameBackend + + +class PolarsBackend(DataFrameBackend): + def columns(self, df): + pass + + def from_arrow(self, table: pa.Table) -> pl.DataFrame: + return pl.from_arrow(table) + + def to_arrow(self, df: pl.DataFrame) -> pa.Table: + return df.to_arrow() + + def join(self, left: pl.DataFrame, right: pl.DataFrame, on, how) -> pl.DataFrame: + return left.join(right, on=on, how=how) + + def groupby_agg(self, df: pl.DataFrame, group_keys, agg_ops) -> pl.DataFrame: + agg_exprs = [ + getattr(pl.col(col), func)().alias(alias) + for alias, (func, col) in agg_ops.items() + ] + return df.groupby(group_keys).agg(agg_exprs) + + def filter(self, df: pl.DataFrame, expr: str) -> pl.DataFrame: + return df.filter(pl.sql_expr(expr)) 
+ + def to_timedelta_value(self, delta: timedelta): + return pl.duration(milliseconds=delta.total_seconds() * 1000) + + def drop_duplicates( + self, + df: pl.DataFrame, + keys: list[str], + sort_by: list[str], + ascending: bool = False, + ) -> pl.DataFrame: + return df.sort(by=sort_by, descending=not ascending).unique( + subset=keys, keep="first" + ) + + def rename_columns(self, df: pl.DataFrame, columns: dict[str, str]) -> pl.DataFrame: + return df.rename(columns) diff --git a/sdk/python/feast/infra/compute_engines/local/compute.py b/sdk/python/feast/infra/compute_engines/local/compute.py new file mode 100644 index 00000000000..5b5fa7c06ab --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/compute.py @@ -0,0 +1,72 @@ +from typing import Optional + +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.backends.factory import BackendFactory +from feast.infra.compute_engines.local.feature_builder import LocalFeatureBuilder +from feast.infra.compute_engines.local.job import LocalRetrievalJob +from feast.infra.materialization.local_engine import LocalMaterializationJob + + +class LocalComputeEngine(ComputeEngine): + def __init__(self, backend: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.backend_name = backend + self._backend = BackendFactory.from_name(backend) if backend else None + + def _get_backend(self, context: ExecutionContext) -> DataFrameBackend: + if self._backend: + return self._backend + backend = BackendFactory.infer_from_entity_df(context.entity_df) + if backend is not None: + return backend + raise ValueError("Could not infer backend from context.entity_df") + + def materialize(self, task: MaterializationTask) -> LocalMaterializationJob: + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" + context = self.get_execution_context(task) + backend = self._get_backend(context) + + try: + builder = LocalFeatureBuilder(task, backend=backend) + plan = builder.build() + plan.execute(context) + return LocalMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED, + ) + + except Exception as e: + return LocalMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=e, + ) + + def get_historical_features( + self, task: HistoricalRetrievalTask + ) -> LocalRetrievalJob: + context = self.get_execution_context(task) + backend = self._get_backend(context) + + try: + builder = LocalFeatureBuilder(task=task, backend=backend) + plan = builder.build() + return LocalRetrievalJob( + plan=plan, + context=context, + full_feature_names=task.full_feature_name, + ) + except Exception as e: + return LocalRetrievalJob( + plan=plan, + context=context, + full_feature_names=task.full_feature_name, + error=e, + ) diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py new file mode 100644 index 00000000000..bf755ed96d0 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -0,0 +1,119 @@ +from typing import Union + +from feast.infra.common.materialization_job import MaterializationTask +from 
feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.feature_builder import FeatureBuilder +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.nodes import ( + LocalAggregationNode, + LocalDedupNode, + LocalFilterNode, + LocalJoinNode, + LocalOutputNode, + LocalSourceReadNode, + LocalTransformationNode, + LocalValidationNode, +) + + +class LocalFeatureBuilder(FeatureBuilder): + def __init__( + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + backend: DataFrameBackend, + ): + super().__init__(task) + self.backend = backend + + def build_source_node(self): + node = LocalSourceReadNode("source", self.feature_view, self.task) + self.nodes.append(node) + return node + + def build_join_node(self, input_node): + node = LocalJoinNode("join", self.backend) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_filter_node(self, input_node): + filter_expr = None + if hasattr(self.feature_view, "filter"): + filter_expr = self.feature_view.filter + ttl = self.feature_view.ttl + node = LocalFilterNode("filter", self.backend, filter_expr, ttl) + node.add_input(input_node) + self.nodes.append(node) + return node + + @staticmethod + def _get_aggregate_operations(agg_specs): + agg_ops = {} + for agg in agg_specs: + if agg.time_window is not None: + raise ValueError( + "Time window aggregation is not supported in local compute engine. Please use a different compute " + "engine." + ) + alias = f"{agg.function}_{agg.column}" + agg_ops[alias] = (agg.function, agg.column) + return agg_ops + + def build_aggregation_node(self, input_node): + agg_specs = self.feature_view.aggregations + agg_ops = self._get_aggregate_operations(agg_specs) + group_by_keys = self.feature_view.entities + node = LocalAggregationNode("agg", self.backend, group_by_keys, agg_ops) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_dedup_node(self, input_node): + node = LocalDedupNode("dedup", self.backend) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_transformation_node(self, input_node): + node = LocalTransformationNode( + "transform", self.feature_view.feature_transformation, self.backend + ) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_validation_node(self, input_node): + node = LocalValidationNode( + "validate", self.feature_view.validation_config, self.backend + ) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_output_nodes(self, input_node): + node = LocalOutputNode("output") + node.add_input(input_node) + self.nodes.append(node) + + def build(self) -> ExecutionPlan: + last_node = self.build_source_node() + + if isinstance(self.task, HistoricalRetrievalTask): + last_node = self.build_join_node(last_node) + + last_node = self.build_filter_node(last_node) + + if self._should_aggregate(): + last_node = self.build_aggregation_node(last_node) + elif isinstance(self.task, HistoricalRetrievalTask): + last_node = self.build_dedup_node(last_node) + + if self._should_transform(): + last_node = self.build_transformation_node(last_node) + + if self._should_validate(): + last_node = self.build_validation_node(last_node) + + self.build_output_nodes(last_node) + return ExecutionPlan(self.nodes) diff --git a/sdk/python/feast/infra/compute_engines/local/job.py 
b/sdk/python/feast/infra/compute_engines/local/job.py new file mode 100644 index 00000000000..530bee8d59b --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/job.py @@ -0,0 +1,77 @@ +from typing import List, Optional, cast + +import pandas as pd +import pyarrow + +from feast import OnDemandFeatureView +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue +from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.saved_dataset import SavedDatasetStorage + + +class LocalRetrievalJob(RetrievalJob): + def __init__( + self, + plan: Optional[ExecutionPlan], + context: ExecutionContext, + full_feature_names: bool = True, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, + error: Optional[BaseException] = None, + ): + self._plan = plan + self._context = context + self._arrow_table = None + self._error = error + self._metadata = metadata + self._full_feature_names = full_feature_names + self._on_demand_feature_views = on_demand_feature_views or [] + + def error(self) -> Optional[BaseException]: + return self._error + + def _ensure_executed(self): + if self._arrow_table is None: + result = cast(ArrowTableValue, self._plan.execute(self._context)) + self._arrow_table = result.data + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + self._ensure_executed() + assert self._arrow_table is not None + return self._arrow_table.to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table: + self._ensure_executed() + return self._arrow_table + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return self._on_demand_feature_views + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: bool = False, + timeout: Optional[int] = None, + ): + pass + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + def to_remote_storage(self) -> List[str]: + raise NotImplementedError( + "Remote storage is not supported in LocalRetrievalJob" + ) + + def to_sql(self) -> str: + raise NotImplementedError( + "SQL generation is not supported in LocalRetrievalJob" + ) diff --git a/sdk/python/feast/infra/compute_engines/local/local_node.py b/sdk/python/feast/infra/compute_engines/local/local_node.py new file mode 100644 index 00000000000..c9bf9aa5668 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/local_node.py @@ -0,0 +1,14 @@ +from abc import ABC +from typing import List, cast + +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue + + +class LocalNode(DAGNode, ABC): + def get_single_table(self, context: ExecutionContext) -> ArrowTableValue: + return cast(ArrowTableValue, self.get_single_input_value(context)) + + def get_input_tables(self, context: ExecutionContext) -> List[ArrowTableValue]: + return [cast(ArrowTableValue, val) for val in self.get_input_values(context)] diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py new file mode 100644 index 00000000000..4e1d2c3362f --- /dev/null 
+++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -0,0 +1,190 @@ +from datetime import timedelta +from typing import Optional + +import pyarrow as pa + +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.local_node import LocalNode +from feast.infra.offline_stores.offline_utils import ( + infer_event_timestamp_from_entity_df, +) + +ENTITY_TS_ALIAS = "__entity_event_timestamp" + + +class LocalSourceReadNode(LocalNode): + def __init__(self, name: str, feature_view, task): + super().__init__(name) + self.feature_view = feature_view + self.task = task + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + # TODO : Implement the logic to read from offline store + return ArrowTableValue(data=pa.Table.from_pandas(context.entity_df)) + + +class LocalJoinNode(LocalNode): + def __init__(self, name: str, backend: DataFrameBackend): + super().__init__(name) + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + feature_table = self.get_single_table(context).data + + if context.entity_df is None: + context.node_outputs[self.name] = feature_table + return feature_table + + entity_table = pa.Table.from_pandas(context.entity_df) + feature_df = self.backend.from_arrow(feature_table) + entity_df = self.backend.from_arrow(entity_table) + + entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) + entity_df_event_timestamp_col = infer_event_timestamp_from_entity_df( + entity_schema + ) + + join_keys, feature_cols, ts_col, created_ts_col = context.column_info + + entity_df = self.backend.rename_columns( + entity_df, {entity_df_event_timestamp_col: ENTITY_TS_ALIAS} + ) + + joined_df = self.backend.join(feature_df, entity_df, on=join_keys, how="left") + result = self.backend.to_arrow(joined_df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalFilterNode(LocalNode): + def __init__( + self, + name: str, + backend: DataFrameBackend, + filter_expr: Optional[str] = None, + ttl: Optional[timedelta] = None, + ): + super().__init__(name) + self.backend = backend + self.filter_expr = filter_expr + self.ttl = ttl + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + + _, _, ts_col, _ = context.column_info + + if ENTITY_TS_ALIAS in self.backend.columns(df): + # filter where feature.ts <= entity.event_timestamp + df = df[df[ts_col] <= df[ENTITY_TS_ALIAS]] + + # TTL: feature.ts >= entity.event_timestamp - ttl + if self.ttl: + lower_bound = df[ENTITY_TS_ALIAS] - self.backend.to_timedelta_value( + self.ttl + ) + df = df[df[ts_col] >= lower_bound] + + # Optional user-defined filter expression (e.g., "value > 0") + if self.filter_expr: + df = self.backend.filter(df, self.filter_expr) + + result = self.backend.to_arrow(df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalAggregationNode(LocalNode): + def __init__( + self, name: str, backend: DataFrameBackend, group_keys: list[str], agg_ops: dict + ): + super().__init__(name) + self.backend = backend + self.group_keys = group_keys + self.agg_ops = agg_ops + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = 
self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + grouped_df = self.backend.groupby_agg(df, self.group_keys, self.agg_ops) + result = self.backend.to_arrow(grouped_df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalDedupNode(LocalNode): + def __init__(self, name: str, backend: DataFrameBackend): + super().__init__(name) + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + + # Extract join_keys, timestamp, and created_ts from context + join_keys, _, ts_col, created_ts_col = context.column_info + + # Dedup strategy: sort and drop_duplicates + sort_keys = [ts_col] + if created_ts_col: + sort_keys.append(created_ts_col) + + dedup_keys = join_keys + [ENTITY_TS_ALIAS] + df = self.backend.drop_duplicates( + df, keys=dedup_keys, sort_by=sort_keys, ascending=False + ) + result = self.backend.to_arrow(df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalTransformationNode(LocalNode): + def __init__(self, name: str, transformation_fn, backend): + super().__init__(name) + self.transformation_fn = transformation_fn + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + transformed_df = self.transformation_fn(df) + result = self.backend.to_arrow(transformed_df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalValidationNode(LocalNode): + def __init__(self, name: str, validation_config, backend): + super().__init__(name) + self.validation_config = validation_config + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + # Placeholder for actual validation logic + if self.validation_config: + print(f"[Validation: {self.name}] Passed.") + result = self.backend.to_arrow(df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalOutputNode(LocalNode): + def __init__(self, name: str): + super().__init__(name) + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + context.node_outputs[self.name] = input_table + # TODO: implement the logic to write to offline store + return input_table diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index e6e6cc52971..981e786cf7f 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,13 +1,13 @@ -from feast.infra.compute_engines.base import ComputeEngine -from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder -from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob -from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from 
feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob +from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( SparkMaterializationJob, ) @@ -42,7 +42,6 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: # ✅ 2. Construct Feature Builder and run it builder = SparkFeatureBuilder( spark_session=self.spark_session, - feature_view=task.feature_view, task=task, ) plan = builder.build() @@ -70,7 +69,6 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob # ✅ 2. Construct Feature Builder and run it builder = SparkFeatureBuilder( spark_session=self.spark_session, - feature_view=task.feature_view, task=task, ) plan = builder.build() diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index e7efbfe1195..453cee7fda5 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -2,8 +2,8 @@ from pyspark.sql import SparkSession -from feast import BatchFeatureView, FeatureView, StreamFeatureView -from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.spark.node import ( SparkAggregationNode, @@ -15,17 +15,15 @@ SparkTransformationNode, SparkWriteNode, ) -from feast.infra.materialization.batch_materialization_engine import MaterializationTask class SparkFeatureBuilder(FeatureBuilder): def __init__( self, spark_session: SparkSession, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], task: Union[MaterializationTask, HistoricalRetrievalTask], ): - super().__init__(feature_view, task) + super().__init__(task) self.spark_session = spark_session def build_source_node(self): @@ -42,17 +40,14 @@ def build_aggregation_node(self, input_node): agg_specs = self.feature_view.aggregations group_by_keys = self.feature_view.entities timestamp_col = self.feature_view.batch_source.timestamp_field - node = SparkAggregationNode( - "agg", input_node, agg_specs, group_by_keys, timestamp_col - ) + node = SparkAggregationNode("agg", agg_specs, group_by_keys, timestamp_col) + node.add_input(input_node) self.nodes.append(node) return node def build_join_node(self, input_node): - join_keys = self.feature_view.entities - node = SparkJoinNode( - "join", input_node, join_keys, self.feature_view, self.spark_session - ) + node = SparkJoinNode("join", self.spark_session) + node.add_input(input_node) self.nodes.append(node) return node @@ -60,23 +55,23 @@ def build_filter_node(self, input_node): filter_expr = None if hasattr(self.feature_view, "filter"): filter_expr = self.feature_view.filter - node = SparkFilterNode( - "filter", self.spark_session, input_node, self.feature_view, filter_expr - ) + ttl = self.feature_view.ttl + node = SparkFilterNode("filter", self.spark_session, ttl, filter_expr) + node.add_input(input_node) self.nodes.append(node) return node def build_dedup_node(self, input_node): 
- node = SparkDedupNode( - "dedup", input_node, self.feature_view, self.spark_session - ) + node = SparkDedupNode("dedup", self.spark_session) + node.add_input(input_node) self.nodes.append(node) return node def build_transformation_node(self, input_node): udf_name = self.feature_view.feature_transformation.name udf = self.feature_view.feature_transformation.udf - node = SparkTransformationNode(udf_name, input_node, udf) + node = SparkTransformationNode(udf_name, udf) + node.add_input(input_node) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index e3f737a4fa6..e0215081bcf 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -1,18 +1,17 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Dict, List, Optional, Union, cast +from datetime import timedelta +from typing import List, Optional, Union, cast from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql import functions as F from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import MaterializationTask from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( _map_by_partition, _SparkSerializedArtifacts, @@ -50,20 +49,6 @@ def rename_entity_ts_column( return entity_df -@dataclass -class SparkJoinContext: - name: str # feature view name or alias - join_keys: List[str] - feature_columns: List[str] - timestamp_field: str - created_timestamp_column: Optional[str] - ttl_seconds: Optional[int] - min_event_timestamp: Optional[datetime] - max_event_timestamp: Optional[datetime] - field_mapping: Dict[str, str] # original_column_name -> renamed_column - full_feature_names: bool = False # apply feature view name prefix - - class SparkMaterializationReadNode(DAGNode): def __init__( self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask] @@ -179,13 +164,11 @@ class SparkAggregationNode(DAGNode): def __init__( self, name: str, - input_node: DAGNode, aggregations: List[Aggregation], group_by_keys: List[str], timestamp_col: str, ): super().__init__(name) - self.add_input(input_node) self.aggregations = aggregations self.group_by_keys = group_by_keys self.timestamp_col = timestamp_col @@ -233,15 +216,9 @@ class SparkJoinNode(DAGNode): def __init__( self, name: str, - feature_node: DAGNode, - join_keys: List[str], - feature_view: Union[BatchFeatureView, StreamFeatureView], spark_session: SparkSession, ): super().__init__(name) - self.join_keys = join_keys - self.add_input(feature_node) - self.feature_view = feature_view self.spark_session = spark_session def execute(self, context: ExecutionContext) -> DAGValue: @@ -274,14 +251,12 @@ def __init__( self, name: str, spark_session: SparkSession, - input_node: DAGNode, - feature_view: Union[BatchFeatureView, StreamFeatureView], + ttl: Optional[timedelta] = None, 
filter_condition: Optional[str] = None, ): super().__init__(name) self.spark_session = spark_session - self.feature_view = feature_view - self.add_input(input_node) + self.ttl = ttl self.filter_condition = filter_condition def execute(self, context: ExecutionContext) -> DAGValue: @@ -298,8 +273,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS)) # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl - if self.feature_view.ttl: - ttl_seconds = int(self.feature_view.ttl.total_seconds()) + if self.ttl: + ttl_seconds = int(self.ttl.total_seconds()) lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr( f"INTERVAL {ttl_seconds} seconds" ) @@ -320,13 +295,9 @@ class SparkDedupNode(DAGNode): def __init__( self, name: str, - input_node: DAGNode, - feature_view: Union[BatchFeatureView, StreamFeatureView], spark_session: SparkSession, ): super().__init__(name) - self.add_input(input_node) - self.feature_view = feature_view self.spark_session = spark_session def execute(self, context: ExecutionContext) -> DAGValue: @@ -398,9 +369,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: class SparkTransformationNode(DAGNode): - def __init__(self, name: str, input_node: DAGNode, udf): + def __init__(self, name: str, udf): super().__init__(name) - self.add_input(input_node) self.udf = udf def execute(self, context: ExecutionContext) -> DAGValue: diff --git a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py b/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py index 2864012055b..d686ba99394 100644 --- a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py +++ b/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py @@ -15,12 +15,14 @@ from feast.constants import FEATURE_STORE_YAML_ENV_NAME from feast.entity import Entity from feast.feature_view import FeatureView -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from feast.infra.materialization.batch_materialization_engine import ( + BatchMaterializationEngine, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry diff --git a/sdk/python/feast/infra/materialization/batch_materialization_engine.py b/sdk/python/feast/infra/materialization/batch_materialization_engine.py index af92b95d175..17bc6134cdb 100644 --- a/sdk/python/feast/infra/materialization/batch_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/batch_materialization_engine.py @@ -1,14 +1,13 @@ -import enum from abc import ABC, abstractmethod -from dataclasses import dataclass -from datetime import datetime -from typing import Callable, List, Optional, Sequence, Union - -from tqdm import tqdm +from typing import List, Sequence, Union from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationTask, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry @@ -17,54 +16,6 @@ from feast.stream_feature_view import 
StreamFeatureView -@dataclass -class MaterializationTask: - """ - A MaterializationTask represents a unit of data that needs to be materialized from an - offline store to an online store. - """ - - project: str - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView] - start_time: datetime - end_time: datetime - tqdm_builder: Callable[[int], tqdm] - - -class MaterializationJobStatus(enum.Enum): - WAITING = 1 - RUNNING = 2 - AVAILABLE = 3 - ERROR = 4 - CANCELLING = 5 - CANCELLED = 6 - SUCCEEDED = 7 - - -class MaterializationJob(ABC): - """ - A MaterializationJob represents an ongoing or executed process that materializes data as per the - definition of a materialization task. - """ - - task: MaterializationTask - - @abstractmethod - def status(self) -> MaterializationJobStatus: ... - - @abstractmethod - def error(self) -> Optional[BaseException]: ... - - @abstractmethod - def should_be_retried(self) -> bool: ... - - @abstractmethod - def job_id(self) -> str: ... - - @abstractmethod - def url(self) -> Optional[str]: ... - - class BatchMaterializationEngine(ABC): """ The interface that Feast uses to control the compute system that handles batch materialization. diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 53b29cdfc0f..246297cc1d5 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -10,12 +10,14 @@ from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import FeatureView -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from feast.infra.materialization.batch_materialization_engine import ( + BatchMaterializationEngine, +) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkOfflineStore, SparkRetrievalJob, diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py index 96064409459..adf14eaf419 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py @@ -15,12 +15,14 @@ from feast import FeatureView, RepoConfig from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from feast.infra.materialization.batch_materialization_engine import ( + BatchMaterializationEngine, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py index 612b20155d4..2e46d2ad49d 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py +++ 
b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py @@ -2,7 +2,7 @@ from kubernetes import client -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, ) diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py index 607dcb5b260..0372f162f02 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py +++ b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py @@ -1,4 +1,4 @@ -from feast.infra.materialization.batch_materialization_engine import MaterializationTask +from feast.infra.common.materialization_job import MaterializationTask class KubernetesMaterializationTask(MaterializationTask): diff --git a/sdk/python/feast/infra/materialization/local_engine.py b/sdk/python/feast/infra/materialization/local_engine.py index fa60950f298..ed71d11586d 100644 --- a/sdk/python/feast/infra/materialization/local_engine.py +++ b/sdk/python/feast/infra/materialization/local_engine.py @@ -7,6 +7,11 @@ from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry @@ -21,9 +26,6 @@ from .batch_materialization_engine import ( BatchMaterializationEngine, - MaterializationJob, - MaterializationJobStatus, - MaterializationTask, ) DEFAULT_BATCH_SIZE = 10_000 diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index 2b18515ae44..9c535c334e3 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -14,12 +14,14 @@ from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import DUMMY_ENTITY_ID, FeatureView -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from feast.infra.materialization.batch_materialization_engine import ( + BatchMaterializationEngine, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index f5df0f2eb1a..b30e695de52 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -24,11 +24,13 @@ from feast.feature_logging import FeatureServiceLoggingSource from feast.feature_service import FeatureService from feast.feature_view import FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) from feast.infra.infra_object import Infra, InfraObject from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, - 
MaterializationJobStatus, - MaterializationTask, ) from feast.infra.offline_stores.offline_store import RetrievalJob from feast.infra.offline_stores.offline_utils import get_offline_store_from_config diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index b8046c12296..c6aef9e5701 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -10,13 +10,13 @@ from feast import BatchFeatureView, Entity, Field from feast.aggregation import Aggregation from feast.data_source import DataSource -from feast.infra.compute_engines.spark.compute import SparkComputeEngine -from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJobStatus, MaterializationTask, ) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.spark.compute import SparkComputeEngine +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkOfflineStore, ) diff --git a/sdk/python/tests/unit/infra/compute_engines/local/__init__.py b/sdk/python/tests/unit/infra/compute_engines/local/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py new file mode 100644 index 00000000000..d5f6085a67c --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py @@ -0,0 +1,201 @@ +from datetime import timedelta +from unittest.mock import MagicMock + +import pandas as pd +import pyarrow as pa + +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue +from feast.infra.compute_engines.local.backends.pandas_backend import PandasBackend +from feast.infra.compute_engines.local.nodes import ( + LocalAggregationNode, + LocalDedupNode, + LocalFilterNode, + LocalJoinNode, + LocalOutputNode, + LocalTransformationNode, +) + +backend = PandasBackend() +now = pd.Timestamp.utcnow() + +sample_df = pd.DataFrame( + { + "entity_id": [1, 1, 2, 2], + "value": [10, 20, 30, 40], + "event_timestamp": [ + now, + now - timedelta(minutes=1), + now, + now - timedelta(minutes=5), + ], + } +) + +entity_df = pd.DataFrame({"entity_id": [1, 2], "event_timestamp": [now, now]}) + + +def create_context(node_outputs): + # Setup execution context + return ExecutionContext( + project="test_proj", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=MagicMock(), + entity_df=entity_df, + node_outputs=node_outputs, + column_info=ColumnInfo( + join_keys=["entity_id"], + feature_cols=["value"], + ts_col="event_timestamp", + created_ts_col=None, + ), + ) + + +def test_local_filter_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create filter node and connect input + filter_node = LocalFilterNode( + name="filter", + backend=backend, + filter_expr="value > 15", + ) + filter_node.add_input(MagicMock()) + filter_node.inputs[0].name = "source" + + # 
Execute and validate + result = filter_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 3 + + +def test_local_aggregation_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create aggregation node and connect input + agg_ops = { + "sum_value": ("sum", "value"), + } + agg_node = LocalAggregationNode( + name="agg", + backend=backend, + group_keys=["entity_id"], + agg_ops=agg_ops, + ) + agg_node.add_input(MagicMock()) + agg_node.inputs[0].name = "source" + + # Execute and validate + result = agg_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 2 + result_df = result.data.to_pandas() + assert result_df["sum_value"].iloc[0] == 30 + assert result_df["sum_value"].iloc[1] == 70 + + +def test_local_join_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create join node and connect input + join_node = LocalJoinNode( + name="join", + backend=backend, + ) + join_node.add_input(MagicMock()) + join_node.inputs[0].name = "source" + + # Execute and validate + result = join_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 4 + result_df = result.data.to_pandas() + assert all(result_df["entity_id"].isin([1, 2])) + assert "__entity_event_timestamp" in result_df.columns + + +def test_local_dedup_node(): + # Duplicate rows for each entity with different event and created timestamps + df = pd.DataFrame( + { + "entity_id": [1, 1, 2, 2], + "value": [100, 200, 300, 400], + "event_timestamp": [ + now - timedelta(seconds=1), + now, + now - timedelta(seconds=1), + now, + ], + "created_ts": [ + now - timedelta(seconds=1), + now, + now, + now - timedelta(seconds=2), + ], + "__entity_event_timestamp": [ + now, + now, + now - timedelta(seconds=1), + now - timedelta(seconds=1), + ], + } + ) + + # Register DataFrame in context + table = pa.Table.from_pandas(df) + context = create_context(node_outputs={"source": ArrowTableValue(table)}) + context.entity_timestamp_col = "event_timestamp" + + # Build node + node = LocalDedupNode(name="dedup", backend=backend) + node.add_input(MagicMock()) + node.inputs[0].name = "source" + + result = node.execute(context) + + # Validate: only latest row per entity remains + df_result = result.data.to_pandas() + assert df_result.shape[0] == 2 + assert set(df_result["entity_id"]) == {1, 2} + + +def test_local_transformation_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create transformation node and connect input + transform_node = LocalTransformationNode( + name="transform", + backend=backend, + transformation_fn=lambda df: df.assign(value=df["value"] * 2), + ) + transform_node.add_input(MagicMock()) + transform_node.inputs[0].name = "source" + + # Execute and validate + result = transform_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 4 + result_df = result.data.to_pandas() + assert all(result_df["value"] == sample_df["value"] * 2) + + +def test_local_output_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + node = LocalOutputNode("output") + node.add_input(MagicMock()) + node.inputs[0].name = "source" + result = node.execute(context) + assert result.num_rows == 4 diff --git 
a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py index afeea82008a..ae69c0a6fcd 100644 --- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -16,7 +16,6 @@ ) from tests.example_repos.example_feature_repo_with_bfvs import ( driver, - driver_hourly_stats_view, ) @@ -69,10 +68,8 @@ def strip_extra_spaces(df): ) # Create and run the node - node = SparkTransformationNode( - "transform", input_node=MagicMock(), udf=strip_extra_spaces - ) - + node = SparkTransformationNode("transform", udf=strip_extra_spaces) + node.add_input(MagicMock()) node.inputs[0].name = "source" result = node.execute(context) @@ -119,11 +116,11 @@ def test_spark_aggregation_node_executes_correctly(spark_session): # Create and configure node node = SparkAggregationNode( name="agg", - input_node=MagicMock(), aggregations=agg_specs, group_by_keys=["user_id"], timestamp_col="", ) + node.add_input(MagicMock()) node.inputs[0].name = "source" # Execute @@ -182,9 +179,6 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): # Wrap as DAGValues feature_val = DAGValue(data=feature_df, format=DAGFormat.SPARK) - # Setup FeatureView mock with batch_source metadata - feature_view = driver_hourly_stats_view - # Set up context context = ExecutionContext( project="test_project", @@ -207,12 +201,10 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): # Create the node and add input join_node = SparkJoinNode( name="join", - feature_node=MagicMock(name="feature_node"), - join_keys=["user_id"], - feature_view=feature_view, spark_session=spark_session, ) - join_node.inputs[0].name = "feature_node" # must match key in node_outputs + join_node.add_input(MagicMock()) + join_node.inputs[0].name = "feature_node" # Execute the node output = join_node.execute(context) @@ -220,11 +212,10 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): dedup_node = SparkDedupNode( name="dedup", - input_node=join_node, - feature_view=feature_view, spark_session=spark_session, ) - dedup_node.inputs[0].name = "join" # must match key in node_outputs + dedup_node.add_input(MagicMock()) + dedup_node.inputs[0].name = "join" dedup_output = dedup_node.execute(context) result_df = dedup_output.data.orderBy("driver_id").collect()
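Taken together, the backend abstraction and factory added by this change can be exercised on their own, without a registry or feature views. The snippet below is a rough usage sketch (module paths as introduced in this PR, sample data invented here) of the `DataFrameBackend.groupby_agg` contract that `LocalAggregationNode` relies on.

import pandas as pd

from feast.infra.compute_engines.local.backends.factory import BackendFactory

# A pandas entity dataframe is inferred to the PandasBackend.
entity_df = pd.DataFrame({"driver_id": [1, 1, 2], "trips": [3.0, 4.0, 5.0]})
backend = BackendFactory.infer_from_entity_df(entity_df)

# agg_ops maps output column name -> (aggregation function, source column),
# matching the groupby_agg contract documented in DataFrameBackend.
result = backend.groupby_agg(
    entity_df,
    group_keys=["driver_id"],
    agg_ops={"sum_trips": ("sum", "trips")},
)
print(backend.to_arrow(result))

Because every node converts to and from Arrow at its boundary, swapping PandasBackend for PolarsBackend (or a future DuckDB backend) should require no changes to the DAG nodes themselves, per the DataFrameBackend docstring above.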