From ce6b022f5e4e9e82da47d2373f77d36146c78fc4 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 12:11:25 -0700 Subject: [PATCH 1/9] Create Local Compute Engine Signed-off-by: HaoXuAI --- .../infra/compute_engines/feature_builder.py | 4 +- .../infra/compute_engines/local/__init__.py | 0 .../local/arrow_table_value.py | 12 ++ .../local/backends/__init__.py | 0 .../compute_engines/local/backends/base.py | 29 +++ .../compute_engines/local/backends/factory.py | 44 +++++ .../local/backends/pandas_backend.py | 34 ++++ .../local/backends/polars_backend.py | 44 +++++ .../infra/compute_engines/local/compute.py | 72 +++++++ .../infra/compute_engines/local/config.py | 20 ++ .../compute_engines/local/feature_builder.py | 113 +++++++++++ .../feast/infra/compute_engines/local/job.py | 77 ++++++++ .../infra/compute_engines/local/local_node.py | 14 ++ .../feast/infra/compute_engines/local/node.py | 176 ++++++++++++++++++ .../infra/compute_engines/spark/compute.py | 2 - .../compute_engines/spark/feature_builder.py | 26 ++- .../feast/infra/compute_engines/spark/node.py | 17 +- .../infra/compute_engines/spark/test_nodes.py | 21 +-- 18 files changed, 655 insertions(+), 50 deletions(-) create mode 100644 sdk/python/feast/infra/compute_engines/local/__init__.py create mode 100644 sdk/python/feast/infra/compute_engines/local/arrow_table_value.py create mode 100644 sdk/python/feast/infra/compute_engines/local/backends/__init__.py create mode 100644 sdk/python/feast/infra/compute_engines/local/backends/base.py create mode 100644 sdk/python/feast/infra/compute_engines/local/backends/factory.py create mode 100644 sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py create mode 100644 sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py create mode 100644 sdk/python/feast/infra/compute_engines/local/compute.py create mode 100644 sdk/python/feast/infra/compute_engines/local/config.py create mode 100644 sdk/python/feast/infra/compute_engines/local/feature_builder.py create mode 100644 sdk/python/feast/infra/compute_engines/local/job.py create mode 100644 sdk/python/feast/infra/compute_engines/local/local_node.py create mode 100644 sdk/python/feast/infra/compute_engines/local/node.py diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index cab32d47d26..927d4daf2a4 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from typing import Union -from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.compute_engines.tasks import HistoricalRetrievalTask @@ -16,10 +15,9 @@ class FeatureBuilder(ABC): def __init__( self, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], task: Union[MaterializationTask, HistoricalRetrievalTask], ): - self.feature_view = feature_view + self.feature_view = task.feature_view self.task = task self.nodes: list[DAGNode] = [] diff --git a/sdk/python/feast/infra/compute_engines/local/__init__.py b/sdk/python/feast/infra/compute_engines/local/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py b/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py new file mode 100644 index 00000000000..52315ac7d4b --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py @@ -0,0 +1,12 @@ +import pyarrow as pa +from infra.compute_engines.dag.model import DAGFormat + +from feast.infra.compute_engines.dag.value import DAGValue + + +class ArrowTableValue(DAGValue): + def __init__(self, data: pa.Table): + super().__init__(data, DAGFormat.ARROW) + + def __repr__(self): + return f"ArrowTableValue(schema={self.data.schema}, rows={self.data.num_rows})" diff --git a/sdk/python/feast/infra/compute_engines/local/backends/__init__.py b/sdk/python/feast/infra/compute_engines/local/backends/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/compute_engines/local/backends/base.py b/sdk/python/feast/infra/compute_engines/local/backends/base.py new file mode 100644 index 00000000000..6fde7c8b0d7 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/base.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from datetime import timedelta + + +class DataFrameBackend(ABC): + @abstractmethod + def columns(self, df): ... + + @abstractmethod + def from_arrow(self, table): ... + + @abstractmethod + def join(self, left, right, on, how): ... + + @abstractmethod + def groupby_agg(self, df, group_keys, agg_ops): ... + + @abstractmethod + def filter(self, df, expr): ... + + @abstractmethod + def to_arrow(self, df): ... + + @abstractmethod + def to_timedelta_value(self, delta: timedelta): ... + + @abstractmethod + def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): + pass diff --git a/sdk/python/feast/infra/compute_engines/local/backends/factory.py b/sdk/python/feast/infra/compute_engines/local/backends/factory.py new file mode 100644 index 00000000000..c34d64237ea --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/factory.py @@ -0,0 +1,44 @@ +from typing import Optional + +import pandas as pd +import pyarrow + +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.backends.pandas_backend import PandasBackend + + +class BackendFactory: + @staticmethod + def from_name(name: str) -> DataFrameBackend: + if name == "pandas": + return PandasBackend() + if name == "polars": + return BackendFactory._get_polars_backend() + raise ValueError(f"Unsupported backend name: {name}") + + @staticmethod + def infer_from_entity_df(entity_df) -> Optional[DataFrameBackend]: + if isinstance(entity_df, pyarrow.Table) or isinstance(entity_df, pd.DataFrame): + return PandasBackend() + + if BackendFactory._is_polars(entity_df): + return BackendFactory._get_polars_backend() + return None + + @staticmethod + def _is_polars(entity_df) -> bool: + try: + import polars as pl + except ImportError: + raise ImportError( + "Polars is not installed. Please install it to use Polars backend." + ) + return isinstance(entity_df, pl.DataFrame) + + @staticmethod + def _get_polars_backend(): + from feast.infra.compute_engines.local.backends.polars_backend import ( + PolarsBackend, + ) + + return PolarsBackend() diff --git a/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py b/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py new file mode 100644 index 00000000000..cf67d46e70e --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py @@ -0,0 +1,34 @@ +from datetime import timedelta + +import pandas as pd +import pyarrow as pa + +from feast.infra.compute_engines.local.backends.base import DataFrameBackend + + +class PandasBackend(DataFrameBackend): + def columns(self, df): + return df.columns.tolist() + + def from_arrow(self, table): + return table.to_pandas() + + def join(self, left, right, on, how): + return left.merge(right, on=on, how=how) + + def groupby_agg(self, df, group_keys, agg_ops): + return df.groupby(group_keys).agg(agg_ops).reset_index() + + def filter(self, df, expr): + return df.query(expr) + + def to_arrow(self, df): + return pa.Table.from_pandas(df) + + def to_timedelta_value(self, delta: timedelta): + return pd.to_timedelta(delta) + + def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): + return df.sort_values(by=sort_by, ascending=ascending).drop_duplicates( + subset=keys + ) diff --git a/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py b/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py new file mode 100644 index 00000000000..bac780d7d47 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py @@ -0,0 +1,44 @@ +from datetime import timedelta + +import polars as pl +import pyarrow as pa + +from feast.infra.compute_engines.local.backends.base import DataFrameBackend + + +class PolarsBackend(DataFrameBackend): + def columns(self, df): + pass + + def from_arrow(self, table: pa.Table) -> pl.DataFrame: + return pl.from_arrow(table) + + def to_arrow(self, df: pl.DataFrame) -> pa.Table: + return df.to_arrow() + + def join(self, left: pl.DataFrame, right: pl.DataFrame, on, how) -> pl.DataFrame: + return left.join(right, on=on, how=how) + + def groupby_agg(self, df: pl.DataFrame, group_keys, agg_ops) -> pl.DataFrame: + agg_exprs = [ + getattr(pl.col(col), func)().alias(alias) + for alias, (func, col) in agg_ops.items() + ] + return df.groupby(group_keys).agg(agg_exprs) + + def filter(self, df: pl.DataFrame, expr: str) -> pl.DataFrame: + return df.filter(pl.sql_expr(expr)) + + def to_timedelta_value(self, delta: timedelta): + return pl.duration(milliseconds=delta.total_seconds() * 1000) + + def drop_duplicates( + self, + df: pl.DataFrame, + keys: list[str], + sort_by: list[str], + ascending: bool = False, + ) -> pl.DataFrame: + return df.sort(by=sort_by, descending=not ascending).unique( + subset=keys, keep="first" + ) diff --git a/sdk/python/feast/infra/compute_engines/local/compute.py b/sdk/python/feast/infra/compute_engines/local/compute.py new file mode 100644 index 00000000000..787439c84ea --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/compute.py @@ -0,0 +1,72 @@ +from typing import Optional + +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.backends.factory import BackendFactory +from feast.infra.compute_engines.local.feature_builder import LocalFeatureBuilder +from feast.infra.compute_engines.local.job import LocalRetrievalJob +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.materialization.local_engine import LocalMaterializationJob + + +class LocalComputeEngine(ComputeEngine): + def __init__(self, backend: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.backend_name = backend + self._backend = BackendFactory.from_name(backend) if backend else None + + def _get_backend(self, context: ExecutionContext) -> DataFrameBackend: + if self._backend: + return self._backend + backend = BackendFactory.infer_from_entity_df(context.entity_df) + if backend is not None: + return backend + raise ValueError("Could not infer backend from context.entity_df") + + def materialize(self, task: MaterializationTask) -> LocalMaterializationJob: + job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}" + context = self.get_execution_context(task) + backend = self._get_backend(context) + + try: + builder = LocalFeatureBuilder(task, backend=backend) + plan = builder.build() + plan.execute(context) + return LocalMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.SUCCEEDED, + ) + + except Exception as e: + return LocalMaterializationJob( + job_id=job_id, + status=MaterializationJobStatus.ERROR, + error=e, + ) + + def get_historical_features( + self, task: HistoricalRetrievalTask + ) -> LocalRetrievalJob: + context = self.get_execution_context(task) + backend = self._get_backend(context) + + try: + builder = LocalFeatureBuilder(task=task, backend=backend) + plan = builder.build() + return LocalRetrievalJob( + plan=plan, + context=context, + full_feature_names=task.full_feature_name, + ) + except Exception as e: + return LocalRetrievalJob( + plan=plan, + context=context, + full_feature_names=task.full_feature_name, + error=e, + ) diff --git a/sdk/python/feast/infra/compute_engines/local/config.py b/sdk/python/feast/infra/compute_engines/local/config.py new file mode 100644 index 00000000000..070cf204dce --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/config.py @@ -0,0 +1,20 @@ +from typing import Dict, Optional + +from pydantic import StrictStr + +from feast.repo_config import FeastConfigBaseModel + + +class SparkComputeConfig(FeastConfigBaseModel): + type: StrictStr = "spark" + """ Spark Compute type selector""" + + spark_conf: Optional[Dict[str, str]] = None + """ Configuration overlay for the spark session """ + # sparksession is not serializable and we dont want to pass it around as an argument + + staging_location: Optional[StrictStr] = None + """ Remote path for batch materialization jobs""" + + region: Optional[StrictStr] = None + """ AWS Region if applicable for s3-based staging locations""" diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py new file mode 100644 index 00000000000..5006e97163a --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -0,0 +1,113 @@ +from typing import Union + +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.feature_builder import FeatureBuilder +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.node import ( + LocalAggregationNode, + LocalDedupNode, + LocalFilterNode, + LocalJoinNode, + LocalOutputNode, + LocalSourceReadNode, + LocalTransformationNode, + LocalValidationNode, +) +from feast.infra.compute_engines.tasks import HistoricalRetrievalTask +from feast.infra.materialization.batch_materialization_engine import MaterializationTask + + +class LocalFeatureBuilder(FeatureBuilder): + def __init__( + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + backend: DataFrameBackend, + ): + super().__init__(task) + self.backend = backend + + def build_source_node(self): + node = LocalSourceReadNode("source", self.feature_view, self.task) + self.nodes.append(node) + return node + + def build_join_node(self, input_node): + node = LocalJoinNode("join", self.backend) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_filter_node(self, input_node): + filter_expr = None + if hasattr(self.feature_view, "filter"): + filter_expr = self.feature_view.filter + ttl = self.feature_view.ttl + node = LocalFilterNode("filter", self.backend, filter_expr, ttl) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_aggregation_node(self, input_node): + agg_specs = self.feature_view.aggregations + agg_ops = {} + for agg in agg_specs: + if agg.time_window is not None: + raise ValueError( + "Time window aggregation is not supported in local compute engine. Please use a different compute engine." + ) + alias = f"{agg.function}_{agg.column}" + agg_ops[alias] = (agg.function, agg.column) + group_by_keys = self.feature_view.entities + node = LocalAggregationNode("agg", group_by_keys, agg_ops, self.backend) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_dedup_node(self, input_node): + node = LocalDedupNode("dedup", self.backend) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_transformation_node(self, input_node): + node = LocalTransformationNode( + "transform", self.feature_view.feature_transformation, self.backend + ) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_validation_node(self, input_node): + node = LocalValidationNode( + "validate", self.feature_view.validation_config, self.backend + ) + node.add_input(input_node) + self.nodes.append(node) + return node + + def build_output_nodes(self, input_node): + node = LocalOutputNode("output") + node.add_input(input_node) + self.nodes.append(node) + + def build(self) -> ExecutionPlan: + last_node = self.build_source_node() + + if isinstance(self.task, HistoricalRetrievalTask): + last_node = self.build_join_node(last_node) + + last_node = self.build_filter_node(last_node) + + if self._should_aggregate(): + last_node = self.build_aggregation_node(last_node) + elif isinstance(self.task, HistoricalRetrievalTask): + last_node = self.build_dedup_node(last_node) + + if self._should_transform(): + last_node = self.build_transformation_node(last_node) + + if self._should_validate(): + last_node = self.build_validation_node(last_node) + + self.build_output_nodes(last_node) + return ExecutionPlan(self.nodes) diff --git a/sdk/python/feast/infra/compute_engines/local/job.py b/sdk/python/feast/infra/compute_engines/local/job.py new file mode 100644 index 00000000000..530bee8d59b --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/job.py @@ -0,0 +1,77 @@ +from typing import List, Optional, cast + +import pandas as pd +import pyarrow + +from feast import OnDemandFeatureView +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.plan import ExecutionPlan +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue +from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.saved_dataset import SavedDatasetStorage + + +class LocalRetrievalJob(RetrievalJob): + def __init__( + self, + plan: Optional[ExecutionPlan], + context: ExecutionContext, + full_feature_names: bool = True, + on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, + metadata: Optional[RetrievalMetadata] = None, + error: Optional[BaseException] = None, + ): + self._plan = plan + self._context = context + self._arrow_table = None + self._error = error + self._metadata = metadata + self._full_feature_names = full_feature_names + self._on_demand_feature_views = on_demand_feature_views or [] + + def error(self) -> Optional[BaseException]: + return self._error + + def _ensure_executed(self): + if self._arrow_table is None: + result = cast(ArrowTableValue, self._plan.execute(self._context)) + self._arrow_table = result.data + + def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame: + self._ensure_executed() + assert self._arrow_table is not None + return self._arrow_table.to_pandas() + + def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table: + self._ensure_executed() + return self._arrow_table + + @property + def full_feature_names(self) -> bool: + return self._full_feature_names + + @property + def on_demand_feature_views(self) -> List[OnDemandFeatureView]: + return self._on_demand_feature_views + + def persist( + self, + storage: SavedDatasetStorage, + allow_overwrite: bool = False, + timeout: Optional[int] = None, + ): + pass + + @property + def metadata(self) -> Optional[RetrievalMetadata]: + return self._metadata + + def to_remote_storage(self) -> List[str]: + raise NotImplementedError( + "Remote storage is not supported in LocalRetrievalJob" + ) + + def to_sql(self) -> str: + raise NotImplementedError( + "SQL generation is not supported in LocalRetrievalJob" + ) diff --git a/sdk/python/feast/infra/compute_engines/local/local_node.py b/sdk/python/feast/infra/compute_engines/local/local_node.py new file mode 100644 index 00000000000..c9bf9aa5668 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/local_node.py @@ -0,0 +1,14 @@ +from abc import ABC +from typing import List, cast + +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue + + +class LocalNode(DAGNode, ABC): + def get_single_table(self, context: ExecutionContext) -> ArrowTableValue: + return cast(ArrowTableValue, self.get_single_input_value(context)) + + def get_input_tables(self, context: ExecutionContext) -> List[ArrowTableValue]: + return [cast(ArrowTableValue, val) for val in self.get_input_values(context)] diff --git a/sdk/python/feast/infra/compute_engines/local/node.py b/sdk/python/feast/infra/compute_engines/local/node.py new file mode 100644 index 00000000000..0da4358e9bc --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/local/node.py @@ -0,0 +1,176 @@ +from datetime import timedelta +from typing import Optional + +import pyarrow as pa + +from feast.infra.compute_engines.dag.context import ExecutionContext +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue +from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.local.local_node import LocalNode + +ENTITY_TS_ALIAS = "__entity_event_timestamp" + + +class LocalSourceReadNode(LocalNode): + def __init__(self, name: str, feature_view, task): + super().__init__(name) + self.feature_view = feature_view + self.task = task + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + df = self.feature_view.source.to_pandas() + table = pa.Table.from_pandas(df) + output = ArrowTableValue(table) + context.node_outputs[self.name] = output + return output + + +class LocalJoinNode(LocalNode): + def __init__(self, name: str, backend: DataFrameBackend): + super().__init__(name) + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + feature_table = self.get_single_table(context).data + entity_table = pa.Table(context.entity_df) + feature_df = self.backend.from_arrow(feature_table) + entity_df = self.backend.from_arrow(entity_table) + + join_keys, feature_cols, ts_col, created_ts_col = context.column_info + + # Rename entity timestamp if needed + if ENTITY_TS_ALIAS in entity_df.columns and ENTITY_TS_ALIAS != ts_col: + entity_df = entity_df.rename(columns={ENTITY_TS_ALIAS: ts_col}) + joined_df = self.backend.join(feature_df, entity_df, on=join_keys, how="left") + result = self.backend.to_arrow(joined_df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalFilterNode(LocalNode): + def __init__( + self, + name: str, + backend: DataFrameBackend, + filter_expr: Optional[str] = None, + ttl: Optional[timedelta] = None, + ): + super().__init__(name) + self.backend = backend + self.filter_expr = filter_expr + self.ttl = ttl # in seconds + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + + _, _, ts_col, _ = context.column_info + + if ENTITY_TS_ALIAS in self.backend.columns(df): + # filter where feature.ts <= entity.event_timestamp + df = df[df[ts_col] <= df[ENTITY_TS_ALIAS]] + + # TTL: feature.ts >= entity.event_timestamp - ttl + if self.ttl: + lower_bound = df[ENTITY_TS_ALIAS] - self.backend.to_timedelta_value( + self.ttl + ) + df = df[df[ts_col] >= lower_bound] + + # Optional user-defined filter expression (e.g., "value > 0") + if self.filter_expr: + df = self.backend.filter(df, self.filter_expr) + + result = self.backend.to_arrow(df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalAggregationNode(LocalNode): + def __init__(self, name: str, group_keys: list[str], agg_ops: dict, backend): + super().__init__(name) + self.group_keys = group_keys + self.agg_ops = agg_ops + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + grouped_df = self.backend.groupby_agg(df, self.group_keys, self.agg_ops) + result = self.backend.to_arrow(grouped_df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalDedupNode(LocalNode): + def __init__(self, name: str, backend: DataFrameBackend): + super().__init__(name) + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + + # Extract join_keys, timestamp, and created_ts from context + join_keys, _, ts_col, created_ts_col = context.column_info + + # Dedup strategy: sort and drop_duplicates + sort_keys = [ts_col] + if created_ts_col: + sort_keys.append(created_ts_col) + + dedup_keys = join_keys + [ENTITY_TS_ALIAS] + df = self.backend.drop_duplicates( + df, keys=dedup_keys, sort_by=sort_keys, ascending=False + ) + result = self.backend.to_arrow(df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalTransformationNode(LocalNode): + def __init__(self, name: str, transformation_fn, backend): + super().__init__(name) + self.transformation_fn = transformation_fn + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + transformed_df = self.transformation_fn(df) + result = self.backend.to_arrow(transformed_df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalValidationNode(LocalNode): + def __init__(self, name: str, validation_config, backend): + super().__init__(name) + self.validation_config = validation_config + self.backend = backend + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + df = self.backend.from_arrow(input_table) + # Placeholder for actual validation logic + if self.validation_config: + print(f"[Validation: {self.name}] Passed.") + result = self.backend.to_arrow(df) + output = ArrowTableValue(result) + context.node_outputs[self.name] = output + return output + + +class LocalOutputNode(LocalNode): + def __init__(self, name: str): + super().__init__(name) + + def execute(self, context: ExecutionContext) -> ArrowTableValue: + input_table = self.get_single_table(context).data + context.node_outputs[self.name] = input_table + return input_table diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index e6e6cc52971..73256078efe 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -42,7 +42,6 @@ def materialize(self, task: MaterializationTask) -> MaterializationJob: # ✅ 2. Construct Feature Builder and run it builder = SparkFeatureBuilder( spark_session=self.spark_session, - feature_view=task.feature_view, task=task, ) plan = builder.build() @@ -70,7 +69,6 @@ def get_historical_features(self, task: HistoricalRetrievalTask) -> RetrievalJob # ✅ 2. Construct Feature Builder and run it builder = SparkFeatureBuilder( spark_session=self.spark_session, - feature_view=task.feature_view, task=task, ) plan = builder.build() diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index e7efbfe1195..f9f0ac50023 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -2,7 +2,6 @@ from pyspark.sql import SparkSession -from feast import BatchFeatureView, FeatureView, StreamFeatureView from feast.infra.compute_engines.base import HistoricalRetrievalTask from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.spark.node import ( @@ -22,10 +21,9 @@ class SparkFeatureBuilder(FeatureBuilder): def __init__( self, spark_session: SparkSession, - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], task: Union[MaterializationTask, HistoricalRetrievalTask], ): - super().__init__(feature_view, task) + super().__init__(task) self.spark_session = spark_session def build_source_node(self): @@ -42,17 +40,14 @@ def build_aggregation_node(self, input_node): agg_specs = self.feature_view.aggregations group_by_keys = self.feature_view.entities timestamp_col = self.feature_view.batch_source.timestamp_field - node = SparkAggregationNode( - "agg", input_node, agg_specs, group_by_keys, timestamp_col - ) + node = SparkAggregationNode("agg", agg_specs, group_by_keys, timestamp_col) + node.add_input(input_node) self.nodes.append(node) return node def build_join_node(self, input_node): - join_keys = self.feature_view.entities - node = SparkJoinNode( - "join", input_node, join_keys, self.feature_view, self.spark_session - ) + node = SparkJoinNode("join", self.spark_session) + node.add_input(input_node) self.nodes.append(node) return node @@ -61,22 +56,23 @@ def build_filter_node(self, input_node): if hasattr(self.feature_view, "filter"): filter_expr = self.feature_view.filter node = SparkFilterNode( - "filter", self.spark_session, input_node, self.feature_view, filter_expr + "filter", self.spark_session, self.feature_view, filter_expr ) + node.add_input(input_node) self.nodes.append(node) return node def build_dedup_node(self, input_node): - node = SparkDedupNode( - "dedup", input_node, self.feature_view, self.spark_session - ) + node = SparkDedupNode("dedup", self.spark_session) + node.add_input(input_node) self.nodes.append(node) return node def build_transformation_node(self, input_node): udf_name = self.feature_view.feature_transformation.name udf = self.feature_view.feature_transformation.udf - node = SparkTransformationNode(udf_name, input_node, udf) + node = SparkTransformationNode(udf_name, udf) + node.add_input(input_node) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index e3f737a4fa6..63232a5ccad 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -179,13 +179,11 @@ class SparkAggregationNode(DAGNode): def __init__( self, name: str, - input_node: DAGNode, aggregations: List[Aggregation], group_by_keys: List[str], timestamp_col: str, ): super().__init__(name) - self.add_input(input_node) self.aggregations = aggregations self.group_by_keys = group_by_keys self.timestamp_col = timestamp_col @@ -233,15 +231,9 @@ class SparkJoinNode(DAGNode): def __init__( self, name: str, - feature_node: DAGNode, - join_keys: List[str], - feature_view: Union[BatchFeatureView, StreamFeatureView], spark_session: SparkSession, ): super().__init__(name) - self.join_keys = join_keys - self.add_input(feature_node) - self.feature_view = feature_view self.spark_session = spark_session def execute(self, context: ExecutionContext) -> DAGValue: @@ -274,14 +266,12 @@ def __init__( self, name: str, spark_session: SparkSession, - input_node: DAGNode, feature_view: Union[BatchFeatureView, StreamFeatureView], filter_condition: Optional[str] = None, ): super().__init__(name) self.spark_session = spark_session self.feature_view = feature_view - self.add_input(input_node) self.filter_condition = filter_condition def execute(self, context: ExecutionContext) -> DAGValue: @@ -320,13 +310,9 @@ class SparkDedupNode(DAGNode): def __init__( self, name: str, - input_node: DAGNode, - feature_view: Union[BatchFeatureView, StreamFeatureView], spark_session: SparkSession, ): super().__init__(name) - self.add_input(input_node) - self.feature_view = feature_view self.spark_session = spark_session def execute(self, context: ExecutionContext) -> DAGValue: @@ -398,9 +384,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: class SparkTransformationNode(DAGNode): - def __init__(self, name: str, input_node: DAGNode, udf): + def __init__(self, name: str, udf): super().__init__(name) - self.add_input(input_node) self.udf = udf def execute(self, context: ExecutionContext) -> DAGValue: diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py index afeea82008a..5dfc110499c 100644 --- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -16,7 +16,6 @@ ) from tests.example_repos.example_feature_repo_with_bfvs import ( driver, - driver_hourly_stats_view, ) @@ -70,9 +69,9 @@ def strip_extra_spaces(df): # Create and run the node node = SparkTransformationNode( - "transform", input_node=MagicMock(), udf=strip_extra_spaces + "transform", udf=strip_extra_spaces ) - + node.add_input(MagicMock()) node.inputs[0].name = "source" result = node.execute(context) @@ -119,11 +118,11 @@ def test_spark_aggregation_node_executes_correctly(spark_session): # Create and configure node node = SparkAggregationNode( name="agg", - input_node=MagicMock(), aggregations=agg_specs, group_by_keys=["user_id"], timestamp_col="", ) + node.add_input(MagicMock()) node.inputs[0].name = "source" # Execute @@ -182,9 +181,6 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): # Wrap as DAGValues feature_val = DAGValue(data=feature_df, format=DAGFormat.SPARK) - # Setup FeatureView mock with batch_source metadata - feature_view = driver_hourly_stats_view - # Set up context context = ExecutionContext( project="test_project", @@ -207,12 +203,10 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): # Create the node and add input join_node = SparkJoinNode( name="join", - feature_node=MagicMock(name="feature_node"), - join_keys=["user_id"], - feature_view=feature_view, spark_session=spark_session, ) - join_node.inputs[0].name = "feature_node" # must match key in node_outputs + join_node.add_input(MagicMock()) + join_node.inputs[0].name = "feature_node" # Execute the node output = join_node.execute(context) @@ -220,11 +214,10 @@ def test_spark_join_node_executes_point_in_time_join(spark_session): dedup_node = SparkDedupNode( name="dedup", - input_node=join_node, - feature_view=feature_view, spark_session=spark_session, ) - dedup_node.inputs[0].name = "join" # must match key in node_outputs + dedup_node.add_input(MagicMock()) + dedup_node.inputs[0].name = "join" dedup_output = dedup_node.execute(context) result_df = dedup_output.data.orderBy("driver_id").collect() From daa8f55111849b5caef40fc9f18452e04d447be1 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 12:18:39 -0700 Subject: [PATCH 2/9] format code Signed-off-by: HaoXuAI --- sdk/python/feast/infra/compute_engines/local/node.py | 7 ++----- .../tests/unit/infra/compute_engines/spark/test_nodes.py | 4 +--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/local/node.py b/sdk/python/feast/infra/compute_engines/local/node.py index 0da4358e9bc..fa23facccf4 100644 --- a/sdk/python/feast/infra/compute_engines/local/node.py +++ b/sdk/python/feast/infra/compute_engines/local/node.py @@ -18,11 +18,8 @@ def __init__(self, name: str, feature_view, task): self.task = task def execute(self, context: ExecutionContext) -> ArrowTableValue: - df = self.feature_view.source.to_pandas() - table = pa.Table.from_pandas(df) - output = ArrowTableValue(table) - context.node_outputs[self.name] = output - return output + # TODO : Implement the logic to read from offline store + return ArrowTableValue(data=pa.Table.from_pandas(context.entity_df)) class LocalJoinNode(LocalNode): diff --git a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py index 5dfc110499c..ae69c0a6fcd 100644 --- a/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/spark/test_nodes.py @@ -68,9 +68,7 @@ def strip_extra_spaces(df): ) # Create and run the node - node = SparkTransformationNode( - "transform", udf=strip_extra_spaces - ) + node = SparkTransformationNode("transform", udf=strip_extra_spaces) node.add_input(MagicMock()) node.inputs[0].name = "source" result = node.execute(context) From d5f847a26e8cb3463179568584098a7d261f0d73 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 12:40:02 -0700 Subject: [PATCH 3/9] format code Signed-off-by: HaoXuAI --- sdk/python/feast/infra/common/__init__.py | 0 .../feast/infra/common/materialization_job.py | 57 ++++++++++++++++++ .../tasks.py => common/retrieval_task.py} | 0 .../infra/compute_engines/local/compute.py | 10 ++-- .../compute_engines/local/feature_builder.py | 4 +- .../infra/compute_engines/spark/compute.py | 12 ++-- .../compute_engines/spark/feature_builder.py | 4 +- .../aws_lambda/lambda_engine.py | 6 +- .../batch_materialization_engine.py | 59 ++----------------- .../spark/spark_materialization_engine.py | 6 +- .../kubernetes/k8s_materialization_engine.py | 6 +- .../kubernetes/k8s_materialization_job.py | 2 +- .../infra/materialization/local_engine.py | 8 ++- .../infra/materialization/snowflake_engine.py | 6 +- .../feast/infra/passthrough_provider.py | 6 +- .../compute_engines/spark/test_compute.py | 8 +-- 16 files changed, 107 insertions(+), 87 deletions(-) create mode 100644 sdk/python/feast/infra/common/__init__.py create mode 100644 sdk/python/feast/infra/common/materialization_job.py rename sdk/python/feast/infra/{compute_engines/tasks.py => common/retrieval_task.py} (100%) diff --git a/sdk/python/feast/infra/common/__init__.py b/sdk/python/feast/infra/common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/feast/infra/common/materialization_job.py b/sdk/python/feast/infra/common/materialization_job.py new file mode 100644 index 00000000000..2d3105d22d8 --- /dev/null +++ b/sdk/python/feast/infra/common/materialization_job.py @@ -0,0 +1,57 @@ +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime +from typing import Callable, Optional, Union + +from tqdm import tqdm + +from feast import BatchFeatureView, FeatureView, StreamFeatureView + + +@dataclass +class MaterializationTask: + """ + A MaterializationTask represents a unit of data that needs to be materialized from an + offline store to an online store. + """ + + project: str + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView] + start_time: datetime + end_time: datetime + tqdm_builder: Callable[[int], tqdm] + + +class MaterializationJobStatus(enum.Enum): + WAITING = 1 + RUNNING = 2 + AVAILABLE = 3 + ERROR = 4 + CANCELLING = 5 + CANCELLED = 6 + SUCCEEDED = 7 + + +class MaterializationJob(ABC): + """ + A MaterializationJob represents an ongoing or executed process that materializes data as per the + definition of a materialization task. + """ + + task: MaterializationTask + + @abstractmethod + def status(self) -> MaterializationJobStatus: ... + + @abstractmethod + def error(self) -> Optional[BaseException]: ... + + @abstractmethod + def should_be_retried(self) -> bool: ... + + @abstractmethod + def job_id(self) -> str: ... + + @abstractmethod + def url(self) -> Optional[str]: ... diff --git a/sdk/python/feast/infra/compute_engines/tasks.py b/sdk/python/feast/infra/common/retrieval_task.py similarity index 100% rename from sdk/python/feast/infra/compute_engines/tasks.py rename to sdk/python/feast/infra/common/retrieval_task.py diff --git a/sdk/python/feast/infra/compute_engines/local/compute.py b/sdk/python/feast/infra/compute_engines/local/compute.py index 787439c84ea..5b5fa7c06ab 100644 --- a/sdk/python/feast/infra/compute_engines/local/compute.py +++ b/sdk/python/feast/infra/compute_engines/local/compute.py @@ -1,16 +1,16 @@ from typing import Optional +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.base import ComputeEngine from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.local.backends.base import DataFrameBackend from feast.infra.compute_engines.local.backends.factory import BackendFactory from feast.infra.compute_engines.local.feature_builder import LocalFeatureBuilder from feast.infra.compute_engines.local.job import LocalRetrievalJob -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import ( - MaterializationJobStatus, - MaterializationTask, -) from feast.infra.materialization.local_engine import LocalMaterializationJob diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index 5006e97163a..0cf22735c04 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -1,5 +1,7 @@ from typing import Union +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_job import HistoricalRetrievalTask from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.local.backends.base import DataFrameBackend @@ -13,8 +15,6 @@ LocalTransformationNode, LocalValidationNode, ) -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import MaterializationTask class LocalFeatureBuilder(FeatureBuilder): diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 73256078efe..67ff535d8ee 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -1,13 +1,13 @@ -from feast.infra.compute_engines.base import ComputeEngine -from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder -from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob -from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from feast.infra.common.retrieval_job import HistoricalRetrievalTask +from feast.infra.compute_engines.base import ComputeEngine +from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob +from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( SparkMaterializationJob, ) diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index f9f0ac50023..d3f04c38a5d 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -2,7 +2,8 @@ from pyspark.sql import SparkSession -from feast.infra.compute_engines.base import HistoricalRetrievalTask +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_job import HistoricalRetrievalTask from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.spark.node import ( SparkAggregationNode, @@ -14,7 +15,6 @@ SparkTransformationNode, SparkWriteNode, ) -from feast.infra.materialization.batch_materialization_engine import MaterializationTask class SparkFeatureBuilder(FeatureBuilder): diff --git a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py b/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py index 2864012055b..d686ba99394 100644 --- a/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py +++ b/sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py @@ -15,12 +15,14 @@ from feast.constants import FEATURE_STORE_YAML_ENV_NAME from feast.entity import Entity from feast.feature_view import FeatureView -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from feast.infra.materialization.batch_materialization_engine import ( + BatchMaterializationEngine, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry diff --git a/sdk/python/feast/infra/materialization/batch_materialization_engine.py b/sdk/python/feast/infra/materialization/batch_materialization_engine.py index af92b95d175..17bc6134cdb 100644 --- a/sdk/python/feast/infra/materialization/batch_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/batch_materialization_engine.py @@ -1,14 +1,13 @@ -import enum from abc import ABC, abstractmethod -from dataclasses import dataclass -from datetime import datetime -from typing import Callable, List, Optional, Sequence, Union - -from tqdm import tqdm +from typing import List, Sequence, Union from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationTask, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry @@ -17,54 +16,6 @@ from feast.stream_feature_view import StreamFeatureView -@dataclass -class MaterializationTask: - """ - A MaterializationTask represents a unit of data that needs to be materialized from an - offline store to an online store. - """ - - project: str - feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView] - start_time: datetime - end_time: datetime - tqdm_builder: Callable[[int], tqdm] - - -class MaterializationJobStatus(enum.Enum): - WAITING = 1 - RUNNING = 2 - AVAILABLE = 3 - ERROR = 4 - CANCELLING = 5 - CANCELLED = 6 - SUCCEEDED = 7 - - -class MaterializationJob(ABC): - """ - A MaterializationJob represents an ongoing or executed process that materializes data as per the - definition of a materialization task. - """ - - task: MaterializationTask - - @abstractmethod - def status(self) -> MaterializationJobStatus: ... - - @abstractmethod - def error(self) -> Optional[BaseException]: ... - - @abstractmethod - def should_be_retried(self) -> bool: ... - - @abstractmethod - def job_id(self) -> str: ... - - @abstractmethod - def url(self) -> Optional[str]: ... - - class BatchMaterializationEngine(ABC): """ The interface that Feast uses to control the compute system that handles batch materialization. diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index 53b29cdfc0f..a885dffe48c 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -10,11 +10,13 @@ from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, MaterializationJob, - MaterializationJobStatus, - MaterializationTask, ) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkOfflineStore, diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py index 96064409459..adf14eaf419 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_engine.py @@ -15,12 +15,14 @@ from feast import FeatureView, RepoConfig from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity -from feast.infra.materialization.batch_materialization_engine import ( - BatchMaterializationEngine, +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationJobStatus, MaterializationTask, ) +from feast.infra.materialization.batch_materialization_engine import ( + BatchMaterializationEngine, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py index 612b20155d4..01cd26ccfc7 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py +++ b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py @@ -2,9 +2,9 @@ from kubernetes import client +from feast.infra.common.materialization_job import MaterializationJobStatus from feast.infra.materialization.batch_materialization_engine import ( MaterializationJob, - MaterializationJobStatus, ) diff --git a/sdk/python/feast/infra/materialization/local_engine.py b/sdk/python/feast/infra/materialization/local_engine.py index fa60950f298..ed71d11586d 100644 --- a/sdk/python/feast/infra/materialization/local_engine.py +++ b/sdk/python/feast/infra/materialization/local_engine.py @@ -7,6 +7,11 @@ from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry @@ -21,9 +26,6 @@ from .batch_materialization_engine import ( BatchMaterializationEngine, - MaterializationJob, - MaterializationJobStatus, - MaterializationTask, ) DEFAULT_BATCH_SIZE = 10_000 diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index 2b18515ae44..5c4f30ec206 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -14,11 +14,13 @@ from feast.batch_feature_view import BatchFeatureView from feast.entity import Entity from feast.feature_view import DUMMY_ENTITY_ID, FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, MaterializationJob, - MaterializationJobStatus, - MaterializationTask, ) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index f5df0f2eb1a..b30e695de52 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -24,11 +24,13 @@ from feast.feature_logging import FeatureServiceLoggingSource from feast.feature_service import FeatureService from feast.feature_view import FeatureView +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) from feast.infra.infra_object import Infra, InfraObject from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, - MaterializationJobStatus, - MaterializationTask, ) from feast.infra.offline_stores.offline_store import RetrievalJob from feast.infra.offline_stores.offline_utils import get_offline_store_from_config diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index b8046c12296..f3134395c3b 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -10,13 +10,13 @@ from feast import BatchFeatureView, Entity, Field from feast.aggregation import Aggregation from feast.data_source import DataSource -from feast.infra.compute_engines.spark.compute import SparkComputeEngine -from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJobStatus, MaterializationTask, ) +from feast.infra.common.retrieval_job import HistoricalRetrievalTask +from feast.infra.compute_engines.spark.compute import SparkComputeEngine +from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkOfflineStore, ) From 107cc33dfab52605269f7fc4507050c929cbc062 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 12:46:55 -0700 Subject: [PATCH 4/9] format code Signed-off-by: HaoXuAI --- .../feast/infra/compute_engines/local/feature_builder.py | 2 +- sdk/python/feast/infra/compute_engines/spark/compute.py | 2 +- .../feast/infra/compute_engines/spark/feature_builder.py | 2 +- sdk/python/feast/infra/compute_engines/spark/node.py | 4 ++-- .../contrib/spark/spark_materialization_engine.py | 2 +- .../materialization/kubernetes/k8s_materialization_job.py | 4 ++-- .../materialization/kubernetes/k8s_materialization_task.py | 2 +- sdk/python/feast/infra/materialization/snowflake_engine.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index 0cf22735c04..582a7e2ba96 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -1,7 +1,7 @@ from typing import Union from feast.infra.common.materialization_job import MaterializationTask -from feast.infra.common.retrieval_job import HistoricalRetrievalTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.local.backends.base import DataFrameBackend diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py index 67ff535d8ee..981e786cf7f 100644 --- a/sdk/python/feast/infra/compute_engines/spark/compute.py +++ b/sdk/python/feast/infra/compute_engines/spark/compute.py @@ -3,7 +3,7 @@ MaterializationJobStatus, MaterializationTask, ) -from feast.infra.common.retrieval_job import HistoricalRetrievalTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.base import ComputeEngine from feast.infra.compute_engines.spark.feature_builder import SparkFeatureBuilder from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index d3f04c38a5d..b6b4f2164d1 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -3,7 +3,7 @@ from pyspark.sql import SparkSession from feast.infra.common.materialization_job import MaterializationTask -from feast.infra.common.retrieval_job import HistoricalRetrievalTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.spark.node import ( SparkAggregationNode, diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 63232a5ccad..7b0f2b2b574 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -7,12 +7,12 @@ from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.value import DAGValue -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import MaterializationTask from feast.infra.materialization.contrib.spark.spark_materialization_engine import ( _map_by_partition, _SparkSerializedArtifacts, diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py index a885dffe48c..246297cc1d5 100644 --- a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py +++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py @@ -11,12 +11,12 @@ from feast.entity import Entity from feast.feature_view import FeatureView from feast.infra.common.materialization_job import ( + MaterializationJob, MaterializationJobStatus, MaterializationTask, ) from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, - MaterializationJob, ) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkOfflineStore, diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py index 01cd26ccfc7..2e46d2ad49d 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py +++ b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_job.py @@ -2,9 +2,9 @@ from kubernetes import client -from feast.infra.common.materialization_job import MaterializationJobStatus -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJob, + MaterializationJobStatus, ) diff --git a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py index 607dcb5b260..0372f162f02 100644 --- a/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py +++ b/sdk/python/feast/infra/materialization/kubernetes/k8s_materialization_task.py @@ -1,4 +1,4 @@ -from feast.infra.materialization.batch_materialization_engine import MaterializationTask +from feast.infra.common.materialization_job import MaterializationTask class KubernetesMaterializationTask(MaterializationTask): diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index 5c4f30ec206..9c535c334e3 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -15,12 +15,12 @@ from feast.entity import Entity from feast.feature_view import DUMMY_ENTITY_ID, FeatureView from feast.infra.common.materialization_job import ( + MaterializationJob, MaterializationJobStatus, MaterializationTask, ) from feast.infra.materialization.batch_materialization_engine import ( BatchMaterializationEngine, - MaterializationJob, ) from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore From 2dd267b3d0eae23565a7b57913d483fae6d316ec Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 21:17:52 -0700 Subject: [PATCH 5/9] update backend Signed-off-by: HaoXuAI --- .../local/arrow_table_value.py | 2 +- .../compute_engines/local/backends/base.py | 3 + .../local/backends/pandas_backend.py | 14 +- .../local/backends/polars_backend.py | 3 + .../compute_engines/local/feature_builder.py | 4 +- .../local/{node.py => nodes.py} | 31 ++- .../infra/compute_engines/local/__init__.py | 0 .../infra/compute_engines/local/test_nodes.py | 201 ++++++++++++++++++ 8 files changed, 247 insertions(+), 11 deletions(-) rename sdk/python/feast/infra/compute_engines/local/{node.py => nodes.py} (88%) create mode 100644 sdk/python/tests/unit/infra/compute_engines/local/__init__.py create mode 100644 sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py diff --git a/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py b/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py index 52315ac7d4b..cbd0c9f37ec 100644 --- a/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py +++ b/sdk/python/feast/infra/compute_engines/local/arrow_table_value.py @@ -1,6 +1,6 @@ import pyarrow as pa -from infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.value import DAGValue diff --git a/sdk/python/feast/infra/compute_engines/local/backends/base.py b/sdk/python/feast/infra/compute_engines/local/backends/base.py index 6fde7c8b0d7..279a434b577 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/base.py +++ b/sdk/python/feast/infra/compute_engines/local/backends/base.py @@ -27,3 +27,6 @@ def to_timedelta_value(self, delta: timedelta): ... @abstractmethod def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): pass + + @abstractmethod + def rename_columns(self, df, columns: dict[str, str]): ... diff --git a/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py b/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py index cf67d46e70e..76ddd688424 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py +++ b/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py @@ -17,7 +17,16 @@ def join(self, left, right, on, how): return left.merge(right, on=on, how=how) def groupby_agg(self, df, group_keys, agg_ops): - return df.groupby(group_keys).agg(agg_ops).reset_index() + return ( + df.groupby(group_keys) + .agg( + **{ + alias: pd.NamedAgg(column=col, aggfunc=func) + for alias, (func, col) in agg_ops.items() + } + ) + .reset_index() + ) def filter(self, df, expr): return df.query(expr) @@ -32,3 +41,6 @@ def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): return df.sort_values(by=sort_by, ascending=ascending).drop_duplicates( subset=keys ) + + def rename_columns(self, df, columns: dict[str, str]): + return df.rename(columns=columns) diff --git a/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py b/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py index bac780d7d47..5c874aa4433 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py +++ b/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py @@ -42,3 +42,6 @@ def drop_duplicates( return df.sort(by=sort_by, descending=not ascending).unique( subset=keys, keep="first" ) + + def rename_columns(self, df: pl.DataFrame, columns: dict[str, str]) -> pl.DataFrame: + return df.rename(columns) diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index 582a7e2ba96..74f1b248222 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -5,7 +5,7 @@ from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.local.backends.base import DataFrameBackend -from feast.infra.compute_engines.local.node import ( +from feast.infra.compute_engines.local.nodes import ( LocalAggregationNode, LocalDedupNode, LocalFilterNode, @@ -58,7 +58,7 @@ def build_aggregation_node(self, input_node): alias = f"{agg.function}_{agg.column}" agg_ops[alias] = (agg.function, agg.column) group_by_keys = self.feature_view.entities - node = LocalAggregationNode("agg", group_by_keys, agg_ops, self.backend) + node = LocalAggregationNode("agg", self.backend, group_by_keys, agg_ops) node.add_input(input_node) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/local/node.py b/sdk/python/feast/infra/compute_engines/local/nodes.py similarity index 88% rename from sdk/python/feast/infra/compute_engines/local/node.py rename to sdk/python/feast/infra/compute_engines/local/nodes.py index fa23facccf4..4e1d2c3362f 100644 --- a/sdk/python/feast/infra/compute_engines/local/node.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -7,6 +7,9 @@ from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue from feast.infra.compute_engines.local.backends.base import DataFrameBackend from feast.infra.compute_engines.local.local_node import LocalNode +from feast.infra.offline_stores.offline_utils import ( + infer_event_timestamp_from_entity_df, +) ENTITY_TS_ALIAS = "__entity_event_timestamp" @@ -29,15 +32,26 @@ def __init__(self, name: str, backend: DataFrameBackend): def execute(self, context: ExecutionContext) -> ArrowTableValue: feature_table = self.get_single_table(context).data - entity_table = pa.Table(context.entity_df) + + if context.entity_df is None: + context.node_outputs[self.name] = feature_table + return feature_table + + entity_table = pa.Table.from_pandas(context.entity_df) feature_df = self.backend.from_arrow(feature_table) entity_df = self.backend.from_arrow(entity_table) + entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) + entity_df_event_timestamp_col = infer_event_timestamp_from_entity_df( + entity_schema + ) + join_keys, feature_cols, ts_col, created_ts_col = context.column_info - # Rename entity timestamp if needed - if ENTITY_TS_ALIAS in entity_df.columns and ENTITY_TS_ALIAS != ts_col: - entity_df = entity_df.rename(columns={ENTITY_TS_ALIAS: ts_col}) + entity_df = self.backend.rename_columns( + entity_df, {entity_df_event_timestamp_col: ENTITY_TS_ALIAS} + ) + joined_df = self.backend.join(feature_df, entity_df, on=join_keys, how="left") result = self.backend.to_arrow(joined_df) output = ArrowTableValue(result) @@ -56,7 +70,7 @@ def __init__( super().__init__(name) self.backend = backend self.filter_expr = filter_expr - self.ttl = ttl # in seconds + self.ttl = ttl def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data @@ -86,11 +100,13 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: class LocalAggregationNode(LocalNode): - def __init__(self, name: str, group_keys: list[str], agg_ops: dict, backend): + def __init__( + self, name: str, backend: DataFrameBackend, group_keys: list[str], agg_ops: dict + ): super().__init__(name) + self.backend = backend self.group_keys = group_keys self.agg_ops = agg_ops - self.backend = backend def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data @@ -170,4 +186,5 @@ def __init__(self, name: str): def execute(self, context: ExecutionContext) -> ArrowTableValue: input_table = self.get_single_table(context).data context.node_outputs[self.name] = input_table + # TODO: implement the logic to write to offline store return input_table diff --git a/sdk/python/tests/unit/infra/compute_engines/local/__init__.py b/sdk/python/tests/unit/infra/compute_engines/local/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py new file mode 100644 index 00000000000..d5f6085a67c --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py @@ -0,0 +1,201 @@ +from datetime import timedelta +from unittest.mock import MagicMock + +import pandas as pd +import pyarrow as pa + +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue +from feast.infra.compute_engines.local.backends.pandas_backend import PandasBackend +from feast.infra.compute_engines.local.nodes import ( + LocalAggregationNode, + LocalDedupNode, + LocalFilterNode, + LocalJoinNode, + LocalOutputNode, + LocalTransformationNode, +) + +backend = PandasBackend() +now = pd.Timestamp.utcnow() + +sample_df = pd.DataFrame( + { + "entity_id": [1, 1, 2, 2], + "value": [10, 20, 30, 40], + "event_timestamp": [ + now, + now - timedelta(minutes=1), + now, + now - timedelta(minutes=5), + ], + } +) + +entity_df = pd.DataFrame({"entity_id": [1, 2], "event_timestamp": [now, now]}) + + +def create_context(node_outputs): + # Setup execution context + return ExecutionContext( + project="test_proj", + repo_config=MagicMock(), + offline_store=MagicMock(), + online_store=MagicMock(), + entity_defs=MagicMock(), + entity_df=entity_df, + node_outputs=node_outputs, + column_info=ColumnInfo( + join_keys=["entity_id"], + feature_cols=["value"], + ts_col="event_timestamp", + created_ts_col=None, + ), + ) + + +def test_local_filter_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create filter node and connect input + filter_node = LocalFilterNode( + name="filter", + backend=backend, + filter_expr="value > 15", + ) + filter_node.add_input(MagicMock()) + filter_node.inputs[0].name = "source" + + # Execute and validate + result = filter_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 3 + + +def test_local_aggregation_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create aggregation node and connect input + agg_ops = { + "sum_value": ("sum", "value"), + } + agg_node = LocalAggregationNode( + name="agg", + backend=backend, + group_keys=["entity_id"], + agg_ops=agg_ops, + ) + agg_node.add_input(MagicMock()) + agg_node.inputs[0].name = "source" + + # Execute and validate + result = agg_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 2 + result_df = result.data.to_pandas() + assert result_df["sum_value"].iloc[0] == 30 + assert result_df["sum_value"].iloc[1] == 70 + + +def test_local_join_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create join node and connect input + join_node = LocalJoinNode( + name="join", + backend=backend, + ) + join_node.add_input(MagicMock()) + join_node.inputs[0].name = "source" + + # Execute and validate + result = join_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 4 + result_df = result.data.to_pandas() + assert all(result_df["entity_id"].isin([1, 2])) + assert "__entity_event_timestamp" in result_df.columns + + +def test_local_dedup_node(): + # Duplicate rows for each entity with different event and created timestamps + df = pd.DataFrame( + { + "entity_id": [1, 1, 2, 2], + "value": [100, 200, 300, 400], + "event_timestamp": [ + now - timedelta(seconds=1), + now, + now - timedelta(seconds=1), + now, + ], + "created_ts": [ + now - timedelta(seconds=1), + now, + now, + now - timedelta(seconds=2), + ], + "__entity_event_timestamp": [ + now, + now, + now - timedelta(seconds=1), + now - timedelta(seconds=1), + ], + } + ) + + # Register DataFrame in context + table = pa.Table.from_pandas(df) + context = create_context(node_outputs={"source": ArrowTableValue(table)}) + context.entity_timestamp_col = "event_timestamp" + + # Build node + node = LocalDedupNode(name="dedup", backend=backend) + node.add_input(MagicMock()) + node.inputs[0].name = "source" + + result = node.execute(context) + + # Validate: only latest row per entity remains + df_result = result.data.to_pandas() + assert df_result.shape[0] == 2 + assert set(df_result["entity_id"]) == {1, 2} + + +def test_local_transformation_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + # Create transformation node and connect input + transform_node = LocalTransformationNode( + name="transform", + backend=backend, + transformation_fn=lambda df: df.assign(value=df["value"] * 2), + ) + transform_node.add_input(MagicMock()) + transform_node.inputs[0].name = "source" + + # Execute and validate + result = transform_node.execute(context) + assert isinstance(result, ArrowTableValue) + assert result.data.num_rows == 4 + result_df = result.data.to_pandas() + assert all(result_df["value"] == sample_df["value"] * 2) + + +def test_local_output_node(): + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + node = LocalOutputNode("output") + node.add_input(MagicMock()) + node.inputs[0].name = "source" + result = node.execute(context) + assert result.num_rows == 4 From 4b801ba032eb3c8f09b3801760f7ee4bf7fe86ba Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 21:33:26 -0700 Subject: [PATCH 6/9] update backend Signed-off-by: HaoXuAI --- docs/reference/compute-engine/README.md | 5 +- .../feast/infra/compute_engines/base.py | 6 +-- .../infra/compute_engines/feature_builder.py | 4 +- .../compute_engines/local/backends/base.py | 47 +++++++++++++++++++ .../compute_engines/local/backends/factory.py | 5 ++ .../infra/compute_engines/local/config.py | 20 -------- .../compute_engines/spark/test_compute.py | 2 +- 7 files changed, 61 insertions(+), 28 deletions(-) delete mode 100644 sdk/python/feast/infra/compute_engines/local/config.py diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md index 50aaa5befab..75f29890046 100644 --- a/docs/reference/compute-engine/README.md +++ b/docs/reference/compute-engine/README.md @@ -31,10 +31,11 @@ This system builds and executes DAGs (Directed Acyclic Graphs) of typed operatio - Supports point-in-time joins and large-scale materialization - Integrates with `SparkOfflineStore` and `SparkMaterializationJob` -### 🧪 LocalComputeEngine (WIP) +### 🧪 LocalComputeEngine -- Runs on Arrow + Pandas (or optionally DuckDB) +- Runs on Arrow + Specified backend (e.g., Pandas, Polars) - Designed for local dev, testing, or lightweight feature generation +- Supports `LocalMaterializationJob` and `LocalHistoricalRetrievalJob` --- diff --git a/sdk/python/feast/infra/compute_engines/base.py b/sdk/python/feast/infra/compute_engines/base.py index d5372d246aa..6e1a90f45b8 100644 --- a/sdk/python/feast/infra/compute_engines/base.py +++ b/sdk/python/feast/infra/compute_engines/base.py @@ -4,12 +4,12 @@ import pyarrow as pa from feast import RepoConfig -from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import ( +from feast.infra.common.materialization_job import ( MaterializationJob, MaterializationTask, ) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.registry import Registry diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index 927d4daf2a4..324f82e7500 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from typing import Union +from feast.infra.common.materialization_job import MaterializationTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.dag.node import DAGNode from feast.infra.compute_engines.dag.plan import ExecutionPlan -from feast.infra.compute_engines.tasks import HistoricalRetrievalTask -from feast.infra.materialization.batch_materialization_engine import MaterializationTask class FeatureBuilder(ABC): diff --git a/sdk/python/feast/infra/compute_engines/local/backends/base.py b/sdk/python/feast/infra/compute_engines/local/backends/base.py index 279a434b577..3c8d25abe00 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/base.py +++ b/sdk/python/feast/infra/compute_engines/local/backends/base.py @@ -3,6 +3,53 @@ class DataFrameBackend(ABC): + """ + Abstract interface for DataFrame operations used by the LocalComputeEngine. + + This interface defines the contract for implementing pluggable DataFrame backends + such as Pandas, Polars, or DuckDB. Each backend must support core table operations + such as joins, filtering, aggregation, conversion to/from Arrow, and deduplication. + + The purpose of this abstraction is to allow seamless swapping of execution backends + without changing DAGNode or ComputeEngine logic. All nodes operate on pyarrow.Table + as the standard input/output format, while the backend defines how the computation + is actually performed. + + Expected implementations include: + - PandasBackend + - PolarsBackend + - DuckDBBackend (future) + + Methods + ------- + from_arrow(table: pa.Table) -> Any + Convert a pyarrow.Table to the backend-native DataFrame format. + + to_arrow(df: Any) -> pa.Table + Convert a backend-native DataFrame to pyarrow.Table. + + join(left: Any, right: Any, on: List[str], how: str) -> Any + Join two dataframes on specified keys with given join type. + + groupby_agg(df: Any, group_keys: List[str], agg_ops: Dict[str, Tuple[str, str]]) -> Any + Group and aggregate the dataframe. `agg_ops` maps output column names + to (aggregation function, source column name) pairs. + + filter(df: Any, expr: str) -> Any + Apply a filter expression (string-based) to the DataFrame. + + to_timedelta_value(delta: timedelta) -> Any + Convert a Python timedelta object to a backend-compatible value + that can be subtracted from a timestamp column. + + drop_duplicates(df: Any, keys: List[str], sort_by: List[str], ascending: bool = False) -> Any + Deduplicate the DataFrame by key columns, keeping the first row + by descending or ascending sort order. + + rename_columns(df: Any, columns: Dict[str, str]) -> Any + Rename columns in the DataFrame according to the provided mapping. + """ + @abstractmethod def columns(self, df): ... diff --git a/sdk/python/feast/infra/compute_engines/local/backends/factory.py b/sdk/python/feast/infra/compute_engines/local/backends/factory.py index c34d64237ea..0a5f40cccf2 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/factory.py +++ b/sdk/python/feast/infra/compute_engines/local/backends/factory.py @@ -8,6 +8,11 @@ class BackendFactory: + """ + Factory class for constructing DataFrameBackend implementations based on backend name + or runtime entity_df type. + """ + @staticmethod def from_name(name: str) -> DataFrameBackend: if name == "pandas": diff --git a/sdk/python/feast/infra/compute_engines/local/config.py b/sdk/python/feast/infra/compute_engines/local/config.py deleted file mode 100644 index 070cf204dce..00000000000 --- a/sdk/python/feast/infra/compute_engines/local/config.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Dict, Optional - -from pydantic import StrictStr - -from feast.repo_config import FeastConfigBaseModel - - -class SparkComputeConfig(FeastConfigBaseModel): - type: StrictStr = "spark" - """ Spark Compute type selector""" - - spark_conf: Optional[Dict[str, str]] = None - """ Configuration overlay for the spark session """ - # sparksession is not serializable and we dont want to pass it around as an argument - - staging_location: Optional[StrictStr] = None - """ Remote path for batch materialization jobs""" - - region: Optional[StrictStr] = None - """ AWS Region if applicable for s3-based staging locations""" diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index f3134395c3b..c6aef9e5701 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -14,7 +14,7 @@ MaterializationJobStatus, MaterializationTask, ) -from feast.infra.common.retrieval_job import HistoricalRetrievalTask +from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.spark.compute import SparkComputeEngine from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( From f254b807873f2398599f9acee7f43ce186d371e3 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 21:37:39 -0700 Subject: [PATCH 7/9] update status Signed-off-by: HaoXuAI --- sdk/python/feast/infra/common/materialization_job.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdk/python/feast/infra/common/materialization_job.py b/sdk/python/feast/infra/common/materialization_job.py index 2d3105d22d8..60ded6735a6 100644 --- a/sdk/python/feast/infra/common/materialization_job.py +++ b/sdk/python/feast/infra/common/materialization_job.py @@ -31,6 +31,8 @@ class MaterializationJobStatus(enum.Enum): CANCELLING = 5 CANCELLED = 6 SUCCEEDED = 7 + PAUSED = 8 + RETRYING = 9 class MaterializationJob(ABC): From f89ebb16ac61e8f32508315541228685dc852622 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 21:53:46 -0700 Subject: [PATCH 8/9] update doc Signed-off-by: HaoXuAI --- .../compute_engines/local/feature_builder.py | 37 +++++++++++++------ .../compute_engines/spark/feature_builder.py | 3 +- .../feast/infra/compute_engines/spark/node.py | 27 +++----------- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index 74f1b248222..b00623db978 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -19,9 +19,9 @@ class LocalFeatureBuilder(FeatureBuilder): def __init__( - self, - task: Union[MaterializationTask, HistoricalRetrievalTask], - backend: DataFrameBackend, + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + backend: DataFrameBackend, ): super().__init__(task) self.backend = backend @@ -31,13 +31,15 @@ def build_source_node(self): self.nodes.append(node) return node - def build_join_node(self, input_node): + def build_join_node(self, + input_node): node = LocalJoinNode("join", self.backend) node.add_input(input_node) self.nodes.append(node) return node - def build_filter_node(self, input_node): + def build_filter_node(self, + input_node): filter_expr = None if hasattr(self.feature_view, "filter"): filter_expr = self.feature_view.filter @@ -47,29 +49,38 @@ def build_filter_node(self, input_node): self.nodes.append(node) return node - def build_aggregation_node(self, input_node): - agg_specs = self.feature_view.aggregations + @staticmethod + def _get_aggregate_operations(agg_specs): agg_ops = {} for agg in agg_specs: if agg.time_window is not None: raise ValueError( - "Time window aggregation is not supported in local compute engine. Please use a different compute engine." + "Time window aggregation is not supported in local compute engine. Please use a different compute " + "engine." ) alias = f"{agg.function}_{agg.column}" agg_ops[alias] = (agg.function, agg.column) + return agg_ops + + def build_aggregation_node(self, + input_node): + agg_specs = self.feature_view.aggregations + agg_ops = self._get_aggregate_operations(agg_specs) group_by_keys = self.feature_view.entities node = LocalAggregationNode("agg", self.backend, group_by_keys, agg_ops) node.add_input(input_node) self.nodes.append(node) return node - def build_dedup_node(self, input_node): + def build_dedup_node(self, + input_node): node = LocalDedupNode("dedup", self.backend) node.add_input(input_node) self.nodes.append(node) return node - def build_transformation_node(self, input_node): + def build_transformation_node(self, + input_node): node = LocalTransformationNode( "transform", self.feature_view.feature_transformation, self.backend ) @@ -77,7 +88,8 @@ def build_transformation_node(self, input_node): self.nodes.append(node) return node - def build_validation_node(self, input_node): + def build_validation_node(self, + input_node): node = LocalValidationNode( "validate", self.feature_view.validation_config, self.backend ) @@ -85,7 +97,8 @@ def build_validation_node(self, input_node): self.nodes.append(node) return node - def build_output_nodes(self, input_node): + def build_output_nodes(self, + input_node): node = LocalOutputNode("output") node.add_input(input_node) self.nodes.append(node) diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index b6b4f2164d1..59b4ebafca3 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -55,8 +55,9 @@ def build_filter_node(self, input_node): filter_expr = None if hasattr(self.feature_view, "filter"): filter_expr = self.feature_view.filter + ttl = self.feature_view.ttl node = SparkFilterNode( - "filter", self.spark_session, self.feature_view, filter_expr + "filter", self.spark_session, ttl, filter_expr ) node.add_input(input_node) self.nodes.append(node) diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index 7b0f2b2b574..e0215081bcf 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Dict, List, Optional, Union, cast +from datetime import timedelta +from typing import List, Optional, Union, cast from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql import functions as F @@ -50,20 +49,6 @@ def rename_entity_ts_column( return entity_df -@dataclass -class SparkJoinContext: - name: str # feature view name or alias - join_keys: List[str] - feature_columns: List[str] - timestamp_field: str - created_timestamp_column: Optional[str] - ttl_seconds: Optional[int] - min_event_timestamp: Optional[datetime] - max_event_timestamp: Optional[datetime] - field_mapping: Dict[str, str] # original_column_name -> renamed_column - full_feature_names: bool = False # apply feature view name prefix - - class SparkMaterializationReadNode(DAGNode): def __init__( self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask] @@ -266,12 +251,12 @@ def __init__( self, name: str, spark_session: SparkSession, - feature_view: Union[BatchFeatureView, StreamFeatureView], + ttl: Optional[timedelta] = None, filter_condition: Optional[str] = None, ): super().__init__(name) self.spark_session = spark_session - self.feature_view = feature_view + self.ttl = ttl self.filter_condition = filter_condition def execute(self, context: ExecutionContext) -> DAGValue: @@ -288,8 +273,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS)) # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl - if self.feature_view.ttl: - ttl_seconds = int(self.feature_view.ttl.total_seconds()) + if self.ttl: + ttl_seconds = int(self.ttl.total_seconds()) lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr( f"INTERVAL {ttl_seconds} seconds" ) From d6f0ed3506a4fb9a76f9cb710c3541fc7c82dce6 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Wed, 16 Apr 2025 21:54:17 -0700 Subject: [PATCH 9/9] format Signed-off-by: HaoXuAI --- .../compute_engines/local/feature_builder.py | 27 +++++++------------ .../compute_engines/spark/feature_builder.py | 4 +-- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index b00623db978..bf755ed96d0 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -19,9 +19,9 @@ class LocalFeatureBuilder(FeatureBuilder): def __init__( - self, - task: Union[MaterializationTask, HistoricalRetrievalTask], - backend: DataFrameBackend, + self, + task: Union[MaterializationTask, HistoricalRetrievalTask], + backend: DataFrameBackend, ): super().__init__(task) self.backend = backend @@ -31,15 +31,13 @@ def build_source_node(self): self.nodes.append(node) return node - def build_join_node(self, - input_node): + def build_join_node(self, input_node): node = LocalJoinNode("join", self.backend) node.add_input(input_node) self.nodes.append(node) return node - def build_filter_node(self, - input_node): + def build_filter_node(self, input_node): filter_expr = None if hasattr(self.feature_view, "filter"): filter_expr = self.feature_view.filter @@ -62,8 +60,7 @@ def _get_aggregate_operations(agg_specs): agg_ops[alias] = (agg.function, agg.column) return agg_ops - def build_aggregation_node(self, - input_node): + def build_aggregation_node(self, input_node): agg_specs = self.feature_view.aggregations agg_ops = self._get_aggregate_operations(agg_specs) group_by_keys = self.feature_view.entities @@ -72,15 +69,13 @@ def build_aggregation_node(self, self.nodes.append(node) return node - def build_dedup_node(self, - input_node): + def build_dedup_node(self, input_node): node = LocalDedupNode("dedup", self.backend) node.add_input(input_node) self.nodes.append(node) return node - def build_transformation_node(self, - input_node): + def build_transformation_node(self, input_node): node = LocalTransformationNode( "transform", self.feature_view.feature_transformation, self.backend ) @@ -88,8 +83,7 @@ def build_transformation_node(self, self.nodes.append(node) return node - def build_validation_node(self, - input_node): + def build_validation_node(self, input_node): node = LocalValidationNode( "validate", self.feature_view.validation_config, self.backend ) @@ -97,8 +91,7 @@ def build_validation_node(self, self.nodes.append(node) return node - def build_output_nodes(self, - input_node): + def build_output_nodes(self, input_node): node = LocalOutputNode("output") node.add_input(input_node) self.nodes.append(node) diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index 59b4ebafca3..453cee7fda5 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -56,9 +56,7 @@ def build_filter_node(self, input_node): if hasattr(self.feature_view, "filter"): filter_expr = self.feature_view.filter ttl = self.feature_view.ttl - node = SparkFilterNode( - "filter", self.spark_session, ttl, filter_expr - ) + node = SparkFilterNode("filter", self.spark_session, ttl, filter_expr) node.add_input(input_node) self.nodes.append(node) return node