diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 590005d8fbd..ce750d939c3 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -174,6 +174,7 @@ * [Snowflake](reference/compute-engine/snowflake.md) * [AWS Lambda (alpha)](reference/compute-engine/lambda.md) * [Spark (contrib)](reference/compute-engine/spark.md) + * [Apache Flink](reference/compute-engine/flink.md) * [Ray (contrib)](reference/compute-engine/ray.md) * [Feature repository](reference/feature-repository/README.md) * [feature\_store.yaml](reference/feature-repository/feature-store-yaml.md) diff --git a/docs/getting-started/components/compute-engine.md b/docs/getting-started/components/compute-engine.md index 60da1575932..d115ec5debb 100644 --- a/docs/getting-started/components/compute-engine.md +++ b/docs/getting-started/components/compute-engine.md @@ -24,7 +24,7 @@ engines. | SparkComputeEngine | Runs on Apache Spark, designed for large-scale distributed feature generation. | ✅ | | | SnowflakeComputeEngine | Runs on Snowflake, designed for scalable feature generation using Snowflake SQL. | ✅ | | | LambdaComputeEngine | Runs on AWS Lambda, designed for serverless feature generation. | ✅ | | -| FlinkComputeEngine | Runs on Apache Flink, designed for stream processing and real-time feature generation. | ❌ | | +| FlinkComputeEngine | Runs on Apache Flink, designed for distributed feature generation through PyFlink Table API. | ✅ | | | RayComputeEngine | Runs on Ray, designed for distributed feature generation and machine learning workloads. 
| ✅ | | ``` @@ -156,4 +156,4 @@ DAG nodes are defined as follows: +----------------+ +----------------+ | OnlineStoreWrite| OfflineStoreWrite| +----------------+ +----------------+ -``` \ No newline at end of file +``` diff --git a/docs/reference/compute-engine/README.md b/docs/reference/compute-engine/README.md index dad2ede75a6..920d5761d28 100644 --- a/docs/reference/compute-engine/README.md +++ b/docs/reference/compute-engine/README.md @@ -57,6 +57,14 @@ An example of built output from FeatureBuilder: - Supports point-in-time joins and large-scale materialization - Integrates with `SparkOfflineStore` and `SparkMaterializationJob` +### 🌊 FlinkComputeEngine + +{% page-ref page="flink.md" %} + +- Distributed DAG execution through Apache Flink's PyFlink Table API +- Supports materialization and historical retrieval with Feast offline stores +- Integrates with `FlinkMaterializationJob` and `FlinkDAGRetrievalJob` + ### ⚡ RayComputeEngine (contrib) - Distributed DAG execution via Ray diff --git a/docs/reference/compute-engine/flink.md b/docs/reference/compute-engine/flink.md new file mode 100644 index 00000000000..0f598996ffa --- /dev/null +++ b/docs/reference/compute-engine/flink.md @@ -0,0 +1,123 @@ +# Apache Flink + +## Description + +The Apache Flink compute engine provides a distributed execution engine for +feature pipelines through the PyFlink Table API. It implements Feast's unified +`ComputeEngine` interface and can be used for batch materialization operations +(`materialize` and `materialize-incremental`) and historical retrieval +(`get_historical_features`). + +The engine reads data through the configured Feast offline store and executes +the Feast DAG as PyFlink tables. Offline stores that expose a native +`to_flink_table(table_env)` retrieval job hand Flink tables directly to the +engine. 
The engine then uses Flink Table/SQL operations for join, filter, +aggregate, dedupe, and projection steps, and writes materialization results to +the configured online and/or offline store. + +## Configuration + +Install the Flink extra from a Feast source checkout with `uv` before using the +engine: + +```bash +uv sync --extra flink --no-dev +``` + +The `flink` extra installs PyFlink directly. PyFlink currently requires +`pyarrow<21`, while the default Feast install keeps `pyarrow>=21`; Feast's uv +lock resolves the Flink extra in a separate dependency fork so normal Feast +installs do not downgrade Arrow. + +Configure the engine in `feature_store.yaml`: + +```yaml +project: my_project +registry: data/registry.db +provider: local +offline_store: + type: file +online_store: + type: sqlite + path: data/online_store.db +batch_engine: + type: flink.engine + execution_mode: batch + parallelism: 4 + table_config: + pipeline.name: "Feast Flink Compute Engine" + pandas_split_num: 4 +``` + +## Configuration Options + +| Option | Type | Default | Description | +| --- | --- | --- | --- | +| `type` | string | `flink.engine` | Must be `flink.engine`. | +| `execution_mode` | string | `batch` | PyFlink execution mode: `batch` or `streaming`. | +| `parallelism` | integer | `null` | Default Flink parallelism for jobs created by the engine. | +| `table_config` | map | `null` | Additional PyFlink table configuration entries. | +| `pandas_split_num` | integer | `1` | Number of PyFlink Arrow source splits when converting pandas entity DataFrames into Flink tables. | + +## Flink Transformations + +Use `mode="flink"` when a `BatchFeatureView` transformation should receive and +return PyFlink table objects: + +```python +from feast import BatchFeatureView, Field +from feast.types import Float32 + + +def double_rates(table): + # In production this can use PyFlink Table API operations and return a table. 
+ return table + + +driver_stats = BatchFeatureView( + name="driver_stats", + entities=[driver], + mode="flink", + udf=double_rates, + schema=[Field(name="conv_rate", dtype=Float32)], + source=driver_stats_source, + online=True, +) +``` + +Flink transformations must return PyFlink table objects. pandas-returning UDFs +are not accepted by the Flink compute engine. + +## DAG Support + +The Flink engine implements Feast's compute DAG with Flink-specific nodes: + +- Source nodes read from Feast offline stores through retrieval jobs that + implement `to_flink_table(table_env)`; retrieval jobs without it are rejected. +- Transform nodes pass PyFlink tables to `mode="flink"` UDFs and preserve native + Flink table outputs. +- Join nodes use Flink SQL temporary views for feature joins and entity joins. +- Filter nodes apply point-in-time, TTL, and custom filter expressions in Flink + SQL. +- Aggregate nodes support non-windowed Feast aggregations using Flink SQL + aggregate functions. +- Dedupe nodes use `ROW_NUMBER()` over entity keys or internal entity-row ids so + historical retrieval keeps one latest feature row per entity row. +- Validation nodes check required output columns. JSON value validation must be + handled upstream in Flink SQL. +- Output nodes write only for materialization tasks; historical retrieval is + read-only. +- Historical retrieval accepts pandas entity DataFrames and SQL-string entity + DataFrames. SQL strings are interpreted as Flink SQL queries against the + configured TableEnvironment/catalog and must select an `event_timestamp` + column. + +## Current Limitations + +- Windowed aggregations are not yet implemented in the Flink compute engine. Use + non-windowed Feast aggregations or pre-window upstream in Flink. +- Offline store retrieval jobs must implement `to_flink_table(table_env)`. + Arrow/pandas-only retrieval jobs are rejected instead of converted. 
+- JSON value validation is not implemented inside the Flink compute engine + because the engine does not collect intermediate data out of Flink for + validation. diff --git a/pyproject.toml b/pyproject.toml index 21c48dd09e1..32049cd05ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "mmh3", "numpy>=2.0.0,<3", "pandas>=1.4.3,<3", - "pyarrow>=21.0.0", + "pyarrow>=21.0.0; extra != 'flink'", + "pyarrow>=16.1.0,<21.0.0; extra == 'flink'", "pydantic>=2.10.6", "pygments>=2.12.0,<3", "PyYAML>=5.4.0,<7", @@ -63,6 +64,7 @@ docling = ["docling==2.27.0"] duckdb = ["ibis-framework[duckdb]>=10.0.0"] elasticsearch = ["elasticsearch>=8.13.0"] faiss = ["faiss-cpu>=1.7.0,<=1.10.0"] +flink = ["apache-flink>=2.2.1,<3"] gcp = [ "google-api-core>=1.23.0,<3", "googleapis-common-protos>=1.52.0,<2", @@ -278,6 +280,26 @@ dev = [ "pytest-xdist>=3.8.0", ] +[tool.uv] +conflicts = [ + [ + { extra = "flink" }, + { extra = "ge" }, + ], + [ + { extra = "flink" }, + { extra = "ci" }, + ], + [ + { extra = "flink" }, + { extra = "dev" }, + ], + [ + { extra = "flink" }, + { extra = "docs" }, + ], +] + # Pixi configuration [tool.pixi.workspace] channels = ["conda-forge"] diff --git a/sdk/python/feast/batch_feature_view.py b/sdk/python/feast/batch_feature_view.py index 95385da1d91..0bfbdf9d936 100644 --- a/sdk/python/feast/batch_feature_view.py +++ b/sdk/python/feast/batch_feature_view.py @@ -169,13 +169,14 @@ def get_feature_transformation(self) -> Optional[Transformation]: TransformationMode.PYTHON, TransformationMode.SQL, TransformationMode.RAY, - ) or self.mode in ("pandas", "python", "sql", "ray"): + TransformationMode.FLINK, + ) or self.mode in ("pandas", "python", "sql", "ray", "flink"): return Transformation( mode=self.mode, udf=self.udf, udf_string=self.udf_string or "" ) else: raise ValueError( - f"Unsupported transformation mode: {self.mode} for StreamFeatureView" + f"Unsupported transformation mode: {self.mode} for BatchFeatureView" ) diff --git 
a/sdk/python/feast/infra/compute_engines/dag/model.py b/sdk/python/feast/infra/compute_engines/dag/model.py index 5990eea6141..263c5029f4f 100644 --- a/sdk/python/feast/infra/compute_engines/dag/model.py +++ b/sdk/python/feast/infra/compute_engines/dag/model.py @@ -6,3 +6,4 @@ class DAGFormat(str, Enum): PANDAS = "pandas" ARROW = "arrow" RAY = "ray" + FLINK = "flink" diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index 43f17ee2986..2a102bf9f2f 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -158,10 +158,13 @@ def get_column_info( # we need to read ALL source columns, not just the output feature columns. # This is specifically for transformations that create new columns or need raw data. mode = getattr(getattr(view, "feature_transformation", None), "mode", None) - if mode in ("ray", "pandas", "python") or getattr(mode, "value", None) in ( + if mode in ("ray", "pandas", "python", "flink") or getattr( + mode, "value", None + ) in ( "ray", "pandas", "python", + "flink", ): # Signal to read all columns by passing empty list for feature_cols. 
class FlinkComputeEngineConfig(FeastConfigBaseModel):
    """Configuration for the Apache Flink compute engine."""

    type: Literal["flink.engine"] = "flink.engine"
    """Flink compute engine type selector."""

    execution_mode: Literal["batch", "streaming"] = "batch"
    """PyFlink TableEnvironment execution mode."""

    parallelism: Optional[int] = None
    """Default Flink parallelism for jobs created by this engine."""

    table_config: Optional[Dict[str, str]] = None
    """Additional PyFlink table configuration entries."""

    pandas_split_num: int = 1
    """Number of PyFlink Arrow source splits for pandas entity DataFrames."""


class FlinkComputeEngine(ComputeEngine):
    """Compute engine that builds and runs Feast DAG plans on a PyFlink table environment.

    One table environment is held per engine instance and handed to every
    ``FlinkFeatureBuilder`` the engine creates.
    """

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: OfflineStore,
        online_store: OnlineStore,
        table_environment: Optional[Any] = None,
        **kwargs,
    ) -> None:
        """Initialise the engine and resolve its table environment.

        Args:
            repo_config: Repo configuration; its ``batch_engine`` must be a
                ``FlinkComputeEngineConfig``.
            offline_store: Forwarded to the ``ComputeEngine`` base class.
            online_store: Forwarded to the ``ComputeEngine`` base class.
            table_environment: Optional pre-built table environment. When it is
                ``None`` one is created from the engine config.
            **kwargs: Forwarded to the ``ComputeEngine`` base class.

        Raises:
            TypeError: If ``repo_config.batch_engine`` is not a
                ``FlinkComputeEngineConfig``.
        """
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )
        config = repo_config.batch_engine
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and a wrong config type would only surface later as an
        # AttributeError deep inside a job.
        if not isinstance(config, FlinkComputeEngineConfig):
            raise TypeError(
                "FlinkComputeEngine requires batch_engine to be a "
                f"FlinkComputeEngineConfig, got {type(config).__name__}."
            )
        self.config = config
        # Compare against None rather than relying on truthiness so that a
        # caller-supplied environment is never silently replaced.
        if table_environment is None:
            table_environment = create_flink_table_environment(self.config)
        self.table_env = table_environment

    def update(
        self,
        project: str,
        views_to_delete: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        views_to_keep: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView, OnDemandFeatureView]
        ],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
    ) -> None:
        """Flink compute engine does not provision Feast-managed infrastructure."""
        pass

    def teardown_infra(
        self,
        project: str,
        fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
        entities: Sequence[Entity],
    ) -> None:
        """Flink compute engine does not tear down Feast-managed infrastructure."""
        pass

    def _materialize_one(
        self, registry: BaseRegistry, task: MaterializationTask, **kwargs
    ) -> MaterializationJob:
        """Build and execute the DAG plan for one materialization task.

        Failures are not raised: any exception is logged with its traceback and
        returned inside a job whose status is ``ERROR``.

        Returns:
            A ``FlinkMaterializationJob`` with status ``SUCCEEDED`` or ``ERROR``.
        """
        job_id = f"{task.feature_view.name}-{task.start_time}-{task.end_time}"
        context = self.get_execution_context(registry, task)

        try:
            builder = FlinkFeatureBuilder(
                registry=registry,
                table_env=self.table_env,
                task=task,
                split_num=self.config.pandas_split_num,
            )
            plan = builder.build()
            plan.execute(context)
            return FlinkMaterializationJob(
                job_id=job_id,
                status=MaterializationJobStatus.SUCCEEDED,
            )
        except Exception as exc:
            # logger.exception keeps the ERROR level and the original message
            # but also records the traceback, which the job object alone hides.
            logger.exception("Flink materialization failed for %s: %s", job_id, exc)
            return FlinkMaterializationJob(
                job_id=job_id,
                status=MaterializationJobStatus.ERROR,
                error=exc,
            )

    def get_historical_features(
        self, registry: BaseRegistry, task: HistoricalRetrievalTask
    ) -> RetrievalJob:
        """Build the DAG plan for a historical retrieval and wrap it in a job.

        The plan is only built here, not executed. If building fails, the
        exception is logged and a ``FlinkDAGRetrievalJob`` carrying the error
        (and no plan) is returned instead of raising.
        """
        context = self.get_execution_context(registry, task)
        try:
            builder = FlinkFeatureBuilder(
                registry=registry,
                table_env=self.table_env,
                task=task,
                split_num=self.config.pandas_split_num,
            )
            plan = builder.build()
            return FlinkDAGRetrievalJob(
                plan=plan,
                context=context,
                full_feature_names=task.full_feature_name,
            )
        except Exception as exc:
            logger.exception(
                "Flink historical retrieval setup failed for %s: %s",
                task.feature_view.name,
                exc,
            )
            return FlinkDAGRetrievalJob(
                plan=None,
                context=context,
                full_feature_names=task.full_feature_name,
                error=exc,
            )
class FlinkFeatureBuilder(FeatureBuilder):
    """Builds the chain of Flink DAG nodes for one feature view and task.

    Every ``build_*`` method creates one Flink node, appends it to
    ``self.nodes`` and returns it. ``self.nodes`` and ``self.dag_root`` are not
    defined in this class (presumably provided by ``FeatureBuilder`` - confirm).
    """

    def __init__(
        self,
        registry: BaseRegistry,
        table_env: Any,
        task: Union[MaterializationTask, HistoricalRetrievalTask],
        split_num: int,
    ) -> None:
        """Store the table environment and split count used by every node.

        Args:
            registry: Passed to the base class.
            table_env: Table environment handed to each created node.
            task: Materialization or historical retrieval task; its
                ``feature_view`` is passed to the base class.
            split_num: Handed to each created node unchanged.
        """
        super().__init__(registry, task.feature_view, task)
        self.table_env = table_env
        self.split_num = split_num

    def _should_join_entity_df(self) -> bool:
        """Return True for a historical retrieval with a usable entity_df.

        "Usable" means a non-empty pandas DataFrame or a SQL string that is
        not blank after stripping whitespace.
        """
        return isinstance(self.task, HistoricalRetrievalTask) and (
            (
                isinstance(self.task.entity_df, pd.DataFrame)
                and not self.task.entity_df.empty
            )
            or (
                isinstance(self.task.entity_df, str)
                and bool(self.task.entity_df.strip())
            )
        )

    def _build(self, view: Any, input_nodes: list[DAGNode] | None) -> DAGNode:
        """Assemble the node chain for ``view`` and return its last node.

        Order: source (or upstream inputs) -> optional transform -> optional
        join -> filter -> aggregate or dedupe -> optional validation. The
        ``_should_*`` predicates other than ``_should_join_entity_df`` are not
        defined in this class (presumably inherited - confirm).

        Raises:
            ValueError: If the view has neither a data source nor input nodes.
        """
        if view.data_source:
            last_node = self.build_source_node(view)

            if self._should_transform(view):
                last_node = self.build_transformation_node(view, [last_node])

            # Only historical retrieval with an entity_df adds a join here.
            if self._should_join_entity_df():
                last_node = self.build_join_node(view, [last_node])

        elif input_nodes:
            # Derived view: either transform the upstream outputs or join them.
            if self._should_transform(view):
                last_node = self.build_transformation_node(view, input_nodes)
            else:
                last_node = self.build_join_node(view, input_nodes)
        else:
            raise ValueError(f"FeatureView {view.name} has no valid source or inputs")

        # The filter node is always added, regardless of the branch above.
        last_node = self.build_filter_node(view, last_node)

        # Aggregation and dedupe are mutually exclusive; aggregation wins.
        if self._should_aggregate(view):
            last_node = self.build_aggregation_node(view, last_node)
        elif self._should_dedupe(view):
            last_node = self.build_dedup_node(view, last_node)

        if self._should_validate(view):
            last_node = self.build_validation_node(view, last_node)

        return last_node

    def build_source_node(self, view: Any) -> FlinkSourceReadNode:
        """Create the node that reads ``view.batch_source`` for the task's time range."""
        source = view.batch_source
        column_info = self.get_column_info(view)
        node = FlinkSourceReadNode(
            f"{view.name}:source",
            source,
            column_info,
            self.table_env,
            self.split_num,
            self.task.start_time,
            self.task.end_time,
        )
        self.nodes.append(node)
        return node

    def build_aggregation_node(
        self, view: Any, input_node: DAGNode
    ) -> FlinkAggregationNode:
        """Create the aggregation node keyed on the view's join-key columns."""
        column_info = self.get_column_info(view)
        node = FlinkAggregationNode(
            f"{view.name}:agg",
            column_info.join_keys_columns,
            view.aggregations,
            self.table_env,
            self.split_num,
            inputs=[input_node],
        )
        self.nodes.append(node)
        return node

    def build_join_node(self, view: Any, input_nodes: list[DAGNode]) -> FlinkJoinNode:
        """Create the join node over ``input_nodes`` using the view's column info."""
        column_info = self.get_column_info(view)
        node = FlinkJoinNode(
            f"{view.name}:join",
            column_info,
            self.table_env,
            self.split_num,
            inputs=input_nodes,
        )
        self.nodes.append(node)
        return node

    def build_filter_node(self, view: Any, input_node: DAGNode) -> FlinkFilterNode:
        """Create the filter node from the view's optional ``filter`` and ``ttl``."""
        # Both attributes are optional on the view, hence getattr with a default.
        filter_expr = getattr(view, "filter", None)
        ttl = getattr(view, "ttl", None)
        column_info = self.get_column_info(view)
        node = FlinkFilterNode(
            f"{view.name}:filter",
            column_info,
            self.table_env,
            self.split_num,
            filter_expr,
            ttl,
            inputs=[input_node],
        )
        self.nodes.append(node)
        return node

    def build_dedup_node(self, view: Any, input_node: DAGNode) -> FlinkDedupNode:
        """Create the dedupe node for ``view``."""
        column_info = self.get_column_info(view)
        node = FlinkDedupNode(
            f"{view.name}:dedup",
            column_info,
            self.table_env,
            self.split_num,
            inputs=[input_node],
        )
        self.nodes.append(node)
        return node

    def build_transformation_node(
        self, view: Any, input_nodes: list[DAGNode]
    ) -> FlinkTransformationNode:
        """Create the transformation node for the view's feature transformation."""
        transform_config = view.feature_transformation
        # Use the wrapped ``udf`` when the transformation object exposes one;
        # otherwise the transformation object itself is passed to the node.
        transformation_fn = (
            transform_config.udf
            if hasattr(transform_config, "udf")
            else transform_config
        )
        node = FlinkTransformationNode(
            f"{view.name}:transform",
            transformation_fn,
            self.table_env,
            self.split_num,
            inputs=input_nodes,
        )
        self.nodes.append(node)
        return node

    def build_output_nodes(self, view: Any, input_node: DAGNode) -> FlinkOutputNode:
        """Create the output node, named after ``view`` but bound to the DAG root view."""
        node = FlinkOutputNode(
            f"{view.name}:output",
            # NOTE(review): the node receives the DAG root's view, not ``view``.
            self.dag_root.view,
            self.table_env,
            self.split_num,
            # True only for materialization tasks; presumably the node's
            # write-enable flag - confirm against FlinkOutputNode.
            isinstance(self.task, MaterializationTask),
            [input_node],
        )
        self.nodes.append(node)
        return node

    def build_validation_node(
        self, view: Any, input_node: DAGNode
    ) -> FlinkValidationNode:
        """Create the validation node from the view's declared features.

        Each feature name maps to its PyArrow type, or to ``None`` when the
        Feast dtype cannot be converted. Features whose dtype is the primitive
        ``JSON`` type are collected separately and passed to the node.
        """
        expected_columns: dict[str, Any] = {}
        json_columns: set[str] = set()
        if hasattr(view, "features"):
            for feature in view.features:
                try:
                    expected_columns[feature.name] = from_feast_to_pyarrow_type(
                        feature.dtype
                    )
                except (ValueError, KeyError):
                    logger.debug(
                        "Could not resolve PyArrow type for feature '%s' "
                        "(dtype=%s), skipping type check for this column.",
                        feature.name,
                        feature.dtype,
                    )
                    # Keep the column in the map so it is still expected to exist.
                    expected_columns[feature.name] = None
                if (
                    isinstance(feature.dtype, PrimitiveFeastType)
                    and feature.dtype.name == "JSON"
                ):
                    json_columns.add(feature.name)

        node = FlinkValidationNode(
            f"{view.name}:validate",
            expected_columns,
            json_columns,
            self.table_env,
            self.split_num,
            inputs=[input_node],
        )
        self.nodes.append(node)
        return node
class FlinkDAGRetrievalJob(RetrievalJob):
    """Retrieval job that runs a Flink execution plan on first access.

    The plan is executed at most once; its result is converted to an Arrow
    table and cached for later ``to_df`` / ``to_arrow`` calls.
    """

    def __init__(
        self,
        plan: Optional[ExecutionPlan],
        context: ExecutionContext,
        full_feature_names: bool,
        on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
        metadata: Optional[RetrievalMetadata] = None,
        error: Optional[BaseException] = None,
    ) -> None:
        self._plan = plan
        self._context = context
        self._error = error
        self._metadata = metadata
        self._full_feature_names = full_feature_names
        self._on_demand_feature_views = (
            on_demand_feature_views if on_demand_feature_views else []
        )
        # Filled in by _ensure_executed() the first time results are requested.
        self._arrow_table: Optional[pa.Table] = None

    def error(self) -> Optional[BaseException]:
        """Return the error captured when the job was created, if any."""
        return self._error

    def _ensure_executed(self) -> None:
        """Execute the plan once and cache the Arrow result.

        Raises:
            BaseException: The stored error, re-raised as-is, when one was set.
            RuntimeError: When there is neither a stored error nor a plan.
        """
        if self._arrow_table is not None:
            return
        if self._error is not None:
            raise self._error
        plan = self._plan
        if plan is None:
            raise RuntimeError("Execution plan is not set")
        outcome = plan.execute(self._context)
        self._arrow_table = flink_table_to_arrow(outcome.data)

    def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
        """Return the cached result as a pandas DataFrame (``timeout`` is unused)."""
        self._ensure_executed()
        table = self._arrow_table
        assert table is not None
        return table.to_pandas()

    def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
        """Return the cached Arrow table (``timeout`` is unused)."""
        self._ensure_executed()
        table = self._arrow_table
        assert table is not None
        return table

    @property
    def full_feature_names(self) -> bool:
        return self._full_feature_names

    @property
    def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
        return self._on_demand_feature_views

    @property
    def metadata(self) -> Optional[RetrievalMetadata]:
        return self._metadata

    def persist(
        self,
        storage: SavedDatasetStorage,
        allow_overwrite: bool = False,
        timeout: Optional[int] = None,
    ) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Persisting Flink retrieval jobs is not supported.")

    def to_remote_storage(self) -> List[str]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Remote storage is not supported in FlinkDAGRetrievalJob."
        )

    def to_sql(self) -> str:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("SQL generation is not supported for Flink DAGs.")


# NOTE(review): this class declares no dataclass fields of its own and defines
# __init__ by hand, so @dataclass contributes only generated __repr__/__eq__.
# Confirm the decorator is intended.
@dataclass
class FlinkMaterializationJob(MaterializationJob):
    """Finished materialization job holding a fixed id, status and optional error."""

    def __init__(
        self,
        job_id: str,
        status: MaterializationJobStatus,
        error: Optional[BaseException] = None,
    ) -> None:
        super().__init__()
        self._error = error
        self._status = status
        self._job_id = job_id

    def status(self) -> MaterializationJobStatus:
        return self._status

    def error(self) -> Optional[BaseException]:
        return self._error

    def should_be_retried(self) -> bool:
        # Retries are never requested for Flink jobs.
        return False

    def job_id(self) -> str:
        return self._job_id

    def url(self) -> Optional[str]:
        # No job URL is exposed.
        return None
ENTITY_TS_ALIAS = "__entity_event_timestamp"
ENTITY_ROW_ID = "__feast_entity_row_id"
DEDUP_ROW_NUMBER = "__feast_row_number"


def _quote_identifier(identifier: str) -> str:
    """Return ``identifier`` wrapped in backticks, doubling any embedded backtick."""
    return f"`{identifier.replace('`', '``')}`"


def _qualified_column(alias: str, column: str) -> str:
    """Return ``alias.`column``` with the column name quoted (the alias is not)."""
    return f"{alias}.{_quote_identifier(column)}"


def _select_column(alias: str, column: str, output_name: Optional[str] = None) -> str:
    """Return a SELECT-list expression for ``alias.column``.

    An ``AS`` clause is added only when ``output_name`` is given and differs
    from ``column``.
    """
    expr = _qualified_column(alias, column)
    if output_name and output_name != column:
        return f"{expr} AS {_quote_identifier(output_name)}"
    return expr


def _flink_interval_literal(value: timedelta) -> str:
    """Render ``value`` as a Flink SQL interval expression.

    The duration is truncated to whole seconds and split into DAY, HOUR,
    MINUTE and SECOND literals joined with ``+``; zero-valued units are
    omitted. Zero, negative and sub-second durations all yield
    ``INTERVAL '0' SECOND``.

    Args:
        value: Duration to render.

    Returns:
        A SQL fragment such as ``INTERVAL '1' DAY + INTERVAL '2' HOUR``.
    """
    total_seconds = int(value.total_seconds())
    if total_seconds <= 0:
        return "INTERVAL '0' SECOND"

    days, remainder = divmod(total_seconds, 24 * 60 * 60)
    hours, remainder = divmod(remainder, 60 * 60)
    minutes, seconds = divmod(remainder, 60)
    parts = []
    if days:
        # The leading DAY field defaults to two digits of precision in Flink
        # SQL, so three or more digits (e.g. a 365-day TTL) must state the
        # precision explicitly or the literal is rejected. Hours, minutes and
        # seconds are always below 100 here and never need it.
        digits = len(str(days))
        day_unit = "DAY" if digits <= 2 else f"DAY({digits})"
        parts.append(f"INTERVAL '{days}' {day_unit}")
    if hours:
        parts.append(f"INTERVAL '{hours}' HOUR")
    if minutes:
        parts.append(f"INTERVAL '{minutes}' MINUTE")
    if seconds:
        parts.append(f"INTERVAL '{seconds}' SECOND")
    return " + ".join(parts)


def _get_columns_from_schema(table: Any) -> Optional[List[str]]:
    """Return the column names exposed by ``table``'s schema, or ``None``.

    ``None`` is returned when ``table`` has no ``get_schema`` method or when
    its schema offers neither ``get_field_names`` nor the
    ``get_field_count`` / ``get_field_name`` pair.
    """
    if not hasattr(table, "get_schema"):
        return None
    schema = table.get_schema()
    if hasattr(schema, "get_field_names"):
        return list(schema.get_field_names())
    if hasattr(schema, "get_field_count") and hasattr(schema, "get_field_name"):
        return [schema.get_field_name(i) for i in range(schema.get_field_count())]
    return None
def _get_columns(value: DAGValue) -> List[str]:
    """Return the column names of a Flink DAG value.

    Column names recorded under ``metadata["columns"]`` win; otherwise they
    are read from the schema of ``value.data``.

    Raises:
        ValueError: If neither source yields a non-empty column list.
    """
    declared = None
    if value.metadata:
        declared = value.metadata.get("columns")
    if declared:
        return list(declared)

    inferred = _get_columns_from_schema(value.data)
    if not inferred:
        raise ValueError(
            "Could not infer columns for Flink DAG value from metadata or PyFlink schema."
        )
    return inferred


def _can_use_sql(table_env: Any) -> bool:
    """Return True when ``table_env`` offers both methods the SQL nodes call."""
    return all(
        hasattr(table_env, method)
        for method in ("create_temporary_view", "sql_query")
    )


def _require_sql(table_env: Any, node_name: str) -> None:
    """Raise ``RuntimeError`` naming ``node_name`` if ``table_env`` cannot run SQL."""
    if _can_use_sql(table_env):
        return
    raise RuntimeError(
        f"Flink node '{node_name}' requires a PyFlink TableEnvironment with "
        "create_temporary_view() and sql_query()."
    )


def _register_table(table_env: Any, table: Any, prefix: str) -> str:
    """Register ``table`` as a temporary view and return the generated view name.

    The name is ``__feast_<prefix>_<random hex>``, so each call creates a new
    view.
    """
    view_name = "_".join(("__feast", prefix, uuid.uuid4().hex))
    table_env.create_temporary_view(view_name, table)
    return view_name


def _sql_value(
    table_env: Any,
    query: str,
    columns: Iterable[str],
    metadata: Optional[dict] = None,
) -> DAGValue:
    """Run ``query`` and wrap the resulting table in a FLINK-format ``DAGValue``.

    The value's metadata is a copy of ``metadata`` with ``"columns"`` and
    ``"native_sql"`` set (overriding any same-named keys passed in).
    """
    table = table_env.sql_query(query)
    merged = dict(metadata) if metadata else {}
    merged["columns"] = list(columns)
    merged["native_sql"] = query
    return DAGValue(data=table, format=DAGFormat.FLINK, metadata=merged)


def _entity_timestamp_column_from_columns(columns: List[str]) -> str:
    """Pick the entity timestamp column, preferring the internal alias.

    Raises:
        ValueError: If neither the alias nor ``event_timestamp`` is present.
    """
    for candidate in (ENTITY_TS_ALIAS, "event_timestamp"):
        if candidate in columns:
            return candidate
    raise ValueError(
        "SQL-based entity_df for FlinkComputeEngine must select an "
        "`event_timestamp` column."
    )
def _entity_value_from_dataframe(
    table_env: Any,
    entity_df: pd.DataFrame,
    split_num: int,
) -> tuple[Any, List[str], str]:
    """Convert a pandas entity DataFrame into a Flink table.

    Works on a copy: adds a 0-based row-id column, then renames the inferred
    event-timestamp column to the internal alias.

    Returns:
        ``(flink_table, column_names_after_rename, original_timestamp_column)``.
    """
    frame = entity_df.copy()
    frame[ENTITY_ROW_ID] = range(len(frame))
    dtype_by_column = dict(zip(frame.columns, frame.dtypes))
    entity_ts_col = infer_event_timestamp_from_entity_df(dtype_by_column)
    if entity_ts_col != ENTITY_TS_ALIAS:
        frame = frame.rename(columns={entity_ts_col: ENTITY_TS_ALIAS})
    flink_table = pandas_to_flink_table(table_env, frame, split_num)
    return flink_table, list(frame.columns), entity_ts_col


def _entity_value_from_sql(
    table_env: Any,
    entity_sql: str,
    join_keys: List[str],
) -> tuple[Any, List[str], str]:
    """Turn a SQL entity query into a Flink table with alias and row-id columns.

    The query result is registered as a temporary view and re-selected so
    that the timestamp column carries the internal alias and a 0-based
    ``ROW_NUMBER()`` column is appended, ordered by the timestamp column and
    then by whichever join keys the query returned.

    Returns:
        ``(flink_table, output_column_names, original_timestamp_column)``.

    Raises:
        ValueError: If the query result's columns cannot be inferred.
    """
    _require_sql(table_env, "entity_df")
    entity_table = table_env.sql_query(entity_sql)
    entity_columns = _get_columns_from_schema(entity_table)
    if entity_columns is None:
        raise ValueError("Could not infer columns for SQL-based entity_df.")

    entity_ts_col = _entity_timestamp_column_from_columns(entity_columns)
    entity_view = _register_table(table_env, entity_table, "entity_sql")

    # One pass builds both the output names and the matching SELECT items.
    output_columns = []
    select_exprs = []
    for column in entity_columns:
        output_name = ENTITY_TS_ALIAS if column == entity_ts_col else column
        output_columns.append(output_name)
        select_exprs.append(_select_column("entity_src", column, output_name))

    order_terms = []
    for candidate in [entity_ts_col, *join_keys]:
        if candidate in entity_columns:
            order_terms.append(_qualified_column("entity_src", candidate))
    order_expr = ", ".join(order_terms)
    if not order_expr:
        # Fallback ordering on the first column of the query result.
        order_expr = _qualified_column("entity_src", entity_columns[0])

    select_exprs.append(
        f"ROW_NUMBER() OVER (ORDER BY {order_expr}) - 1 AS "
        f"{_quote_identifier(ENTITY_ROW_ID)}"
    )
    query = (
        f"SELECT {', '.join(select_exprs)} "
        f"FROM {_quote_identifier(entity_view)} AS entity_src"
    )
    output_columns.append(ENTITY_ROW_ID)
    wrapped = _sql_value(
        table_env,
        query,
        output_columns,
        metadata={"entity_timestamp_column": entity_ts_col},
    )
    return wrapped.data, output_columns, entity_ts_col
def _entity_value_from_context(
    table_env: Any,
    context: ExecutionContext,
    split_num: int,
    join_keys: List[str],
) -> tuple[Any, List[str], str]:
    """Dispatch on the type of ``context.entity_df`` to build the entity table.

    Raises:
        TypeError: If ``entity_df`` is neither a pandas DataFrame nor a string.
    """
    entity_df = context.entity_df
    if isinstance(entity_df, pd.DataFrame):
        return _entity_value_from_dataframe(table_env, entity_df, split_num)
    if isinstance(entity_df, str):
        return _entity_value_from_sql(table_env, entity_df, join_keys)
    raise TypeError(
        "FlinkComputeEngine entity_df must be a pandas DataFrame, SQL string, or None."
    )


class FlinkSourceReadNode(DAGNode):
    """DAG node that reads a data source as a Flink table via the offline store."""

    def __init__(
        self,
        name: str,
        source: DataSource,
        column_info: ColumnInfo,
        table_env: Any,
        split_num: int,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> None:
        super().__init__(name)
        self.source = source
        self.column_info = column_info
        self.table_env = table_env
        self.split_num = split_num
        self.start_time = start_time
        self.end_time = end_time

    def execute(self, context: ExecutionContext) -> DAGValue:
        """Read the source and return it as a FLINK-format ``DAGValue``.

        When the column info carries a field mapping, the table is re-selected
        through a temporary view with the mapped column names; otherwise the
        table from the retrieval job is returned as-is.

        Raises:
            TypeError: If the retrieval job has no ``to_flink_table`` method.
            ValueError: If the returned table's columns cannot be inferred.
        """
        retrieval_job = create_offline_store_retrieval_job(
            data_source=self.source,
            column_info=self.column_info,
            context=context,
            start_time=self.start_time,
            end_time=self.end_time,
        )
        if not hasattr(retrieval_job, "to_flink_table"):
            raise TypeError(
                "FlinkComputeEngine source reads require RetrievalJob.to_flink_table("
                "table_env). Configure an offline store retrieval job that returns "
                "native PyFlink tables instead of Arrow/pandas results."
            )

        flink_table = retrieval_job.to_flink_table(self.table_env)
        columns = _get_columns_from_schema(flink_table)
        if columns is None:
            raise ValueError(
                "Could not infer columns for source Flink table returned by "
                "RetrievalJob.to_flink_table(table_env)."
            )

        # Metadata shared by both return paths below.
        source_metadata = {
            "source": "feature_view_batch_source",
            "timestamp_field": self.column_info.timestamp_column,
            "created_timestamp_column": self.column_info.created_timestamp_column,
            "start_date": self.start_time,
            "end_date": self.end_time,
        }

        field_mapping = self.column_info.field_mapping
        if not field_mapping:
            return DAGValue(
                data=flink_table,
                format=DAGFormat.FLINK,
                metadata={**source_metadata, "columns": columns},
            )

        view_name = _register_table(self.table_env, flink_table, "source_read")
        select_exprs = []
        renamed_columns = []
        for col in columns:
            mapped_name = field_mapping.get(col, col)
            select_exprs.append(_select_column("src", col, mapped_name))
            renamed_columns.append(mapped_name)
        query = (
            f"SELECT {', '.join(select_exprs)} "
            f"FROM {_quote_identifier(view_name)} AS src"
        )
        return _sql_value(
            self.table_env,
            query,
            renamed_columns,
            metadata=source_metadata,
        )
self.column_info.join_keys_columns + view_names = [ + _register_table(self.table_env, value.data, f"join_{index}") + for index, value in enumerate(input_values) + ] + columns_by_input = [_get_columns(value) for value in input_values] + output_columns = list(columns_by_input[0]) + seen_columns = set(output_columns) + select_exprs = [_select_column("t0", column) for column in columns_by_input[0]] + + joins = [] + for index, view_name in enumerate(view_names[1:], start=1): + alias = f"t{index}" + on_clause = " AND ".join( + f"{_qualified_column('t0', key)} = {_qualified_column(alias, key)}" + for key in join_keys + ) + joins.append( + f"{self.how.upper()} JOIN {_quote_identifier(view_name)} AS {alias} " + f"ON {on_clause}" + ) + for column in columns_by_input[index]: + if column in join_keys or column in seen_columns: + continue + output_columns.append(column) + seen_columns.add(column) + select_exprs.append(_select_column(alias, column)) + + query = ( + f"SELECT {', '.join(select_exprs)} " + f"FROM {_quote_identifier(view_names[0])} AS t0 " + f"{' '.join(joins)}" + ) + joined_value = _sql_value( + self.table_env, + query, + output_columns, + metadata={"joined_on": join_keys, "join_type": self.how}, + ) + + if context.entity_df is None: + return joined_value + + entity_table, entity_columns, entity_ts_col = _entity_value_from_context( + self.table_env, context, self.split_num, join_keys + ) + entity_view = _register_table(self.table_env, entity_table, "entity") + feature_view = _register_table(self.table_env, joined_value.data, "features") + feature_columns = [ + column + for column in output_columns + if column not in join_keys and column not in entity_columns + ] + select_entity = [_select_column("e", column) for column in entity_columns] + select_features = [_select_column("f", column) for column in feature_columns] + on_clause = " AND ".join( + f"{_qualified_column('e', key)} = {_qualified_column('f', key)}" + for key in join_keys + ) + entity_join_query = ( + 
f"SELECT {', '.join(select_entity + select_features)} " + f"FROM {_quote_identifier(entity_view)} AS e " + f"LEFT JOIN {_quote_identifier(feature_view)} AS f ON {on_clause}" + ) + return _sql_value( + self.table_env, + entity_join_query, + entity_columns + feature_columns, + metadata={ + "joined_on": join_keys, + "join_type": "left", + "entity_timestamp_column": entity_ts_col, + }, + ) + + +class FlinkFilterNode(DAGNode): + def __init__( + self, + name: str, + column_info: ColumnInfo, + table_env: Any, + split_num: int, + filter_expr: Optional[str] = None, + ttl: Optional[timedelta] = None, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.column_info = column_info + self.table_env = table_env + self.split_num = split_num + self.filter_expr = filter_expr + self.ttl = ttl + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + _require_sql(self.table_env, self.name) + return self._execute_sql_filter(input_value) + + def _execute_sql_filter(self, input_value: DAGValue) -> DAGValue: + columns = _get_columns(input_value) + timestamp_column = self.column_info.timestamp_column + conditions = [] + + if ENTITY_TS_ALIAS in columns and timestamp_column in columns: + conditions.append( + f"{_quote_identifier(timestamp_column)} <= " + f"{_quote_identifier(ENTITY_TS_ALIAS)}" + ) + if self.ttl: + ttl_interval = _flink_interval_literal(self.ttl) + conditions.append( + f"{_quote_identifier(timestamp_column)} >= " + f"{_quote_identifier(ENTITY_TS_ALIAS)} - " + f"({ttl_interval})" + ) + + if self.filter_expr: + conditions.append(f"({self.filter_expr})") + + if not conditions: + return input_value + + view_name = _register_table(self.table_env, input_value.data, "filter") + query = ( + f"SELECT * FROM {_quote_identifier(view_name)} " + f"WHERE {' AND '.join(conditions)}" + ) + return _sql_value( + self.table_env, + query, + 
columns, + metadata={**(input_value.metadata or {}), "filter_applied": True}, + ) + + +class FlinkAggregationNode(DAGNode): + def __init__( + self, + name: str, + group_keys: List[str], + aggregations: List[Aggregation], + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.group_keys = group_keys + self.aggregations = aggregations + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + agg_ops = aggregation_specs_to_agg_ops( + self.aggregations, + time_window_unsupported_error_message=( + "Time window aggregation is not yet supported in the Flink compute " + "engine. Use non-windowed aggregations or pre-window upstream in Flink." + ), + ) + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + _require_sql(self.table_env, self.name) + return self._execute_sql_aggregation(input_value, agg_ops) + + def _execute_sql_aggregation( + self, input_value: DAGValue, agg_ops: Dict[str, tuple[str, str]] + ) -> DAGValue: + view_name = _register_table(self.table_env, input_value.data, "aggregate") + select_exprs = [_quote_identifier(key) for key in self.group_keys] + for alias, (function, column) in agg_ops.items(): + sql_function = { + "mean": "AVG", + "avg": "AVG", + "sum": "SUM", + "min": "MIN", + "max": "MAX", + "count": "COUNT", + "nunique": "COUNT_DISTINCT", + "std": "STDDEV_SAMP", + "var": "VAR_SAMP", + }.get(function, function.upper()) + if sql_function == "COUNT_DISTINCT": + expr = ( + f"COUNT(DISTINCT {_quote_identifier(column)}) " + f"AS {_quote_identifier(alias)}" + ) + else: + expr = ( + f"{sql_function}({_quote_identifier(column)}) " + f"AS {_quote_identifier(alias)}" + ) + select_exprs.append(expr) + + query = ( + f"SELECT {', '.join(select_exprs)} " + f"FROM {_quote_identifier(view_name)} " + f"GROUP BY {', '.join(_quote_identifier(key) for key in self.group_keys)}" + ) 
+ return _sql_value( + self.table_env, + query, + [*self.group_keys, *agg_ops.keys()], + metadata={"aggregated": True}, + ) + + +class FlinkDedupNode(DAGNode): + def __init__( + self, + name: str, + column_info: ColumnInfo, + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.column_info = column_info + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + _require_sql(self.table_env, self.name) + return self._execute_sql_dedup(input_value) + + def _execute_sql_dedup(self, input_value: DAGValue) -> DAGValue: + columns = _get_columns(input_value) + dedup_keys = ( + [ENTITY_ROW_ID] + if ENTITY_ROW_ID in columns + else self.column_info.join_keys_columns + ) + dedup_keys = [key for key in dedup_keys if key in columns] + if not dedup_keys: + return input_value + + order_columns = [ + self.column_info.timestamp_column, + self.column_info.created_timestamp_column, + ] + order_exprs = [ + f"{_quote_identifier(column)} DESC" + for column in order_columns + if column and column in columns + ] + if not order_exprs: + order_exprs = [f"{_quote_identifier(dedup_keys[0])} ASC"] + + view_name = _register_table(self.table_env, input_value.data, "dedup") + select_columns = ", ".join(_quote_identifier(column) for column in columns) + query = ( + f"SELECT {select_columns} FROM (" + f"SELECT *, ROW_NUMBER() OVER (" + f"PARTITION BY {', '.join(_quote_identifier(key) for key in dedup_keys)} " + f"ORDER BY {', '.join(order_exprs)}" + f") AS {_quote_identifier(DEDUP_ROW_NUMBER)} " + f"FROM {_quote_identifier(view_name)}" + f") WHERE {_quote_identifier(DEDUP_ROW_NUMBER)} = 1" + ) + return _sql_value( + self.table_env, + query, + columns, + metadata={**(input_value.metadata or {}), "deduped": True}, + ) + + +class 
FlinkTransformationNode(DAGNode): + def __init__( + self, + name: str, + transformation_fn: Callable[..., Any], + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.transformation_fn = transformation_fn + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + input_values = self.get_input_values(context) + for value in input_values: + value.assert_format(DAGFormat.FLINK) + + input_tables = [value.data for value in input_values] + transformed = self.transformation_fn(*input_tables) + + columns = _get_columns_from_schema(transformed) + if columns is None: + raise TypeError( + "Flink transformations must return a PyFlink Table with a schema." + ) + + return DAGValue( + data=transformed, + format=DAGFormat.FLINK, + metadata={"transformed": True, "columns": columns or []}, + ) + + +class FlinkValidationNode(DAGNode): + def __init__( + self, + name: str, + expected_columns: dict[str, Optional[pa.DataType]], + json_columns: Optional[Set[str]], + table_env: Any, + split_num: int, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.expected_columns = expected_columns + self.json_columns = json_columns or set() + self.table_env = table_env + self.split_num = split_num + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + + columns = _get_columns(input_value) + missing = set(self.expected_columns.keys()) - set(columns) + if missing: + raise ValueError( + f"[Validation: {self.name}] Missing expected columns: {missing}. 
" + f"Actual columns: {sorted(columns)}" + ) + if not self.json_columns: + return DAGValue( + data=input_value.data, + format=DAGFormat.FLINK, + metadata={**(input_value.metadata or {}), "validated": True}, + ) + + raise NotImplementedError( + "JSON value validation is not supported by FlinkComputeEngine without " + "collecting data out of Flink. Validate JSON upstream in Flink SQL or " + "disable JSON validation for this FeatureView." + ) + + +class FlinkOutputNode(DAGNode): + def __init__( + self, + name: str, + feature_view: Union[BatchFeatureView, StreamFeatureView], + table_env: Any, + split_num: int, + write_output: bool, + inputs: Optional[List[DAGNode]] = None, + ) -> None: + super().__init__(name, inputs=inputs) + self.feature_view = feature_view + self.table_env = table_env + self.split_num = split_num + self.write_output = write_output + + def execute(self, context: ExecutionContext) -> DAGValue: + input_value = self.get_single_input_value(context) + input_value.assert_format(DAGFormat.FLINK) + output_value = self._drop_internal_columns(input_value) + output_table = output_value.data + if not self.write_output: + return output_value + + output_df = flink_table_to_pandas(output_table) + output_arrow = pa.Table.from_pandas(output_df) + + if output_arrow.num_rows == 0: + return output_value + + if self.feature_view.online: + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in self.feature_view.entity_columns + } + batch_size = ( + context.repo_config.materialization_config.online_write_batch_size + ) + batches = ( + [output_arrow] + if batch_size is None + else output_arrow.to_batches(max_chunksize=batch_size) + ) + for batch in batches: + rows_to_write = _convert_arrow_to_proto( + batch, self.feature_view, join_key_to_value_type + ) + context.online_store.online_write_batch( + config=context.repo_config, + table=self.feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + + if self.feature_view.offline: + 
context.offline_store.offline_write_batch( + config=context.repo_config, + feature_view=self.feature_view, + table=output_arrow, + progress=lambda x: None, + ) + + return output_value + + def _drop_internal_columns(self, input_value: DAGValue) -> DAGValue: + columns = _get_columns(input_value) + output_columns = [column for column in columns if column != ENTITY_ROW_ID] + if output_columns == columns: + return input_value + + _require_sql(self.table_env, self.name) + view_name = _register_table(self.table_env, input_value.data, "output") + query = ( + f"SELECT {', '.join(_quote_identifier(column) for column in output_columns)} " + f"FROM {_quote_identifier(view_name)}" + ) + return _sql_value( + self.table_env, + query, + output_columns, + metadata={**(input_value.metadata or {}), "output_cleaned": True}, + ) diff --git a/sdk/python/feast/infra/compute_engines/flink/utils.py b/sdk/python/feast/infra/compute_engines/flink/utils.py new file mode 100644 index 00000000000..dc330d45d2a --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/flink/utils.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pandas as pd +import pyarrow as pa + +if TYPE_CHECKING: + from feast.infra.compute_engines.flink.compute import FlinkComputeEngineConfig + + +def create_flink_table_environment(config: FlinkComputeEngineConfig) -> Any: + """Create a PyFlink TableEnvironment from Feast engine config.""" + try: + from pyflink.common import Configuration + from pyflink.table import EnvironmentSettings, TableEnvironment + except ImportError as exc: + raise ImportError( + "FlinkComputeEngine requires PyFlink. Install the `flink` extra with " + "uv from a Feast source checkout, or otherwise make the `pyflink` " + "package available to Feast." 
+ ) from exc + + flink_conf = Configuration() + for key, value in (config.table_config or {}).items(): + flink_conf.set_string(key, value) + if config.parallelism is not None: + flink_conf.set_string("parallelism.default", str(config.parallelism)) + + builder = EnvironmentSettings.new_instance().with_configuration(flink_conf) + if config.execution_mode == "streaming": + builder = builder.in_streaming_mode() + else: + builder = builder.in_batch_mode() + return TableEnvironment.create(builder.build()) + + +def pandas_to_flink_table(table_env: Any, df: pd.DataFrame, split_num: int = 1) -> Any: + """Convert a pandas DataFrame to a PyFlink table.""" + schema = list(df.columns) + return table_env.from_pandas(df, schema=schema, splits_num=split_num) + + +def flink_table_to_pandas(table: Any) -> pd.DataFrame: + """Collect a PyFlink table into pandas.""" + if hasattr(table, "to_pandas"): + return table.to_pandas() + raise TypeError(f"Expected a PyFlink table, got {type(table)}") + + +def flink_table_to_arrow(table: Any) -> pa.Table: + """Collect a PyFlink table into Arrow.""" + value = flink_table_to_pandas(table) + return pa.Table.from_pandas(value) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 7518f613788..0f832ea0d69 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -50,6 +50,7 @@ "k8s": "feast.infra.compute_engines.kubernetes.k8s_engine.KubernetesComputeEngine", "spark.engine": "feast.infra.compute_engines.spark.compute.SparkComputeEngine", "ray.engine": "feast.infra.compute_engines.ray.compute.RayComputeEngine", + "flink.engine": "feast.infra.compute_engines.flink.compute.FlinkComputeEngine", } LEGACY_ONLINE_STORE_CLASS_FOR_TYPE = { diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 9ee07e6a199..c2b4625214a 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -207,7 +207,8 @@ def 
get_feature_transformation(self) -> Optional[Transformation]: TransformationMode.PYTHON, TransformationMode.SPARK_SQL, TransformationMode.SPARK, - ) or self.mode in ("pandas", "python", "spark_sql", "spark"): + TransformationMode.FLINK, + ) or self.mode in ("pandas", "python", "spark_sql", "spark", "flink"): return Transformation( mode=self.mode, udf=self.udf, udf_string=self.udf_string or "" ) diff --git a/sdk/python/feast/transformation/factory.py b/sdk/python/feast/transformation/factory.py index 16d7a7570d5..a181b7dea69 100644 --- a/sdk/python/feast/transformation/factory.py +++ b/sdk/python/feast/transformation/factory.py @@ -7,6 +7,7 @@ "sql": "feast.transformation.sql_transformation.SQLTransformation", "spark_sql": "feast.transformation.spark_transformation.SparkTransformation", "spark": "feast.transformation.spark_transformation.SparkTransformation", + "flink": "feast.transformation.flink_transformation.FlinkTransformation", "ray": "feast.transformation.ray_transformation.RayTransformation", } diff --git a/sdk/python/feast/transformation/flink_transformation.py b/sdk/python/feast/transformation/flink_transformation.py new file mode 100644 index 00000000000..83b929c4c96 --- /dev/null +++ b/sdk/python/feast/transformation/flink_transformation.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any, Callable, Optional, cast + +from feast.transformation.base import Transformation +from feast.transformation.mode import TransformationMode + + +class FlinkTransformation(Transformation): + """Transformation wrapper for Flink compute-engine UDFs. + + The UDF is expected to accept PyFlink Table objects and return a PyFlink + Table. 
+ """ + + def __new__( + cls, + udf: Optional[Callable[..., Any]] = None, + udf_string: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs, + ) -> "FlinkTransformation": + if udf is None and udf_string is None: + return cast("FlinkTransformation", object.__new__(cls)) + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + return cast( + "FlinkTransformation", + super(FlinkTransformation, cls).__new__( + cls, + mode=TransformationMode.FLINK, + udf=udf, + name=name, + udf_string=udf_string, + tags=tags, + description=description, + owner=owner, + ), + ) + + def __init__( + self, + udf: Optional[Callable[..., Any]] = None, + udf_string: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[dict[str, str]] = None, + description: str = "", + owner: str = "", + *args, + **kwargs, + ) -> None: + if udf is None and udf_string is None: + return + if udf is None: + raise ValueError("udf parameter cannot be None") + if udf_string is None: + raise ValueError("udf_string parameter cannot be None") + super().__init__( + mode=TransformationMode.FLINK, + udf=udf, + name=name, + udf_string=udf_string, + tags=tags, + description=description, + owner=owner, + ) + + def transform(self, *inputs: Any) -> Any: + return self.udf(*inputs) + + def infer_features(self, *args, **kwargs) -> Any: + pass diff --git a/sdk/python/feast/transformation/mode.py b/sdk/python/feast/transformation/mode.py index 44d38d8e99c..bd6fdf22424 100644 --- a/sdk/python/feast/transformation/mode.py +++ b/sdk/python/feast/transformation/mode.py @@ -6,6 +6,7 @@ class TransformationMode(Enum): PANDAS = "pandas" SPARK_SQL = "spark_sql" SPARK = "spark" + FLINK = "flink" RAY = "ray" SQL = "sql" SUBSTRAIT = "substrait" diff --git a/sdk/python/tests/unit/infra/compute_engines/flink/__init__.py 
b/sdk/python/tests/unit/infra/compute_engines/flink/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/flink/__init__.py @@ -0,0 +1 @@ + diff --git a/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py new file mode 100644 index 00000000000..f1d7ad2e8b6 --- /dev/null +++ b/sdk/python/tests/unit/infra/compute_engines/flink/test_flink_compute_engine.py @@ -0,0 +1,984 @@ +from __future__ import annotations + +import re +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, List, Optional +from unittest.mock import MagicMock + +import pandas as pd +import pyarrow as pa +import pytest +import toml # type: ignore[import-untyped] + +from feast import BatchFeatureView, Entity, Field, FileSource +from feast.aggregation import Aggregation +from feast.infra.common.materialization_job import ( + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext +from feast.infra.compute_engines.dag.model import DAGFormat +from feast.infra.compute_engines.dag.node import DAGNode +from feast.infra.compute_engines.dag.value import DAGValue +from feast.infra.compute_engines.flink.compute import ( + FlinkComputeEngine, + FlinkComputeEngineConfig, +) +from feast.infra.compute_engines.flink.nodes import ( + ENTITY_ROW_ID, + ENTITY_TS_ALIAS, + FlinkAggregationNode, + FlinkDedupNode, + FlinkFilterNode, + FlinkJoinNode, + FlinkSourceReadNode, + FlinkTransformationNode, + FlinkValidationNode, + _flink_interval_literal, +) +from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.repo_config import RepoConfig +from feast.saved_dataset 
import SavedDatasetStorage +from feast.types import Float32 +from feast.value_type import ValueType + + +def test_flink_extra_does_not_downgrade_default_pyarrow_dependency() -> None: + pyproject_path = Path(__file__).resolve().parents[7] / "pyproject.toml" + pyproject = toml.loads(pyproject_path.read_text()) + + dependencies = pyproject["project"]["dependencies"] + assert "pyarrow>=21.0.0; extra != 'flink'" in dependencies + assert "pyarrow>=16.1.0,<21.0.0; extra == 'flink'" in dependencies + assert "pyarrow>=16.1.0" not in dependencies + assert ( + "apache-flink>=2.2.1,<3" + in pyproject["project"]["optional-dependencies"]["flink"] + ) + + +class FakeFlinkTable: + def __init__(self, df: pd.DataFrame) -> None: + self._df = df.copy() + + def to_pandas(self) -> pd.DataFrame: + return self._df.copy() + + def get_schema(self) -> FakeFlinkSchema: + return FakeFlinkSchema(list(self._df.columns)) + + +class FakeTableEnvironment: + def __init__(self) -> None: + self.created_tables: List[pd.DataFrame] = [] + self.split_nums: List[int] = [] + self.views: dict[str, object] = {} + self.queries: List[str] = [] + + def from_pandas( + self, + df: pd.DataFrame, + schema: object = None, + splits_num: int = 1, + split_num: Optional[int] = None, + ) -> FakeFlinkTable: + self.created_tables.append(df.copy()) + self.split_nums.append(split_num if split_num is not None else splits_num) + return FakeFlinkTable(df) + + def create_temporary_view( + self, view_path: str, table_or_data_stream: object, *args: object + ) -> None: + self.views[view_path] = table_or_data_stream + + def sql_query(self, query: str) -> Any: + self.queries.append(query) + return FakeFlinkTable(self._evaluate_sql(query)) + + def _view_df(self, view_name: str) -> pd.DataFrame: + table = self.views[view_name] + if isinstance(table, FakeFlinkTable): + return table.to_pandas() + if isinstance(table, FakeNativeFlinkTable): + return pd.DataFrame(columns=table.get_schema().get_field_names()) + raise TypeError(f"Unsupported 
fake Flink table type: {type(table)}") + + def _evaluate_sql(self, query: str) -> pd.DataFrame: + if "ROW_NUMBER() OVER" in query: + return self._evaluate_row_number_query(query) + if " GROUP BY " in query: + return self._evaluate_group_by_query(query) + if " JOIN " in query: + return self._evaluate_join_query(query) + if " WHERE " in query: + return self._evaluate_where_query(query) + return self._evaluate_select_query(query) + + def _extract_views(self, query: str) -> List[str]: + return re.findall(r"(?:FROM|JOIN)\s+`([^`]+)`", query) + + def _evaluate_select_query(self, query: str) -> pd.DataFrame: + views = self._extract_views(query) + if not views: + if "FROM entities" in query: + return self._view_df("entities")[["driver_id", "event_timestamp"]] + raise ValueError(f"Could not infer source view from query: {query}") + source_df = self._view_df(views[-1]) + select_clause = query.split(" FROM ", 1)[0].removeprefix("SELECT ") + if select_clause == "*": + return source_df + result = pd.DataFrame() + for column_expr in select_clause.split(", "): + parts = re.findall(r"`([^`]+)`", column_expr) + if not parts: + continue + source_column = parts[0] + output_column = parts[-1] + result[output_column] = source_df[source_column] + return result + + def _evaluate_where_query(self, query: str) -> pd.DataFrame: + view_name = self._extract_views(query)[0] + df = self._view_df(view_name) + if "`event_timestamp` <= `__entity_event_timestamp`" in query: + df = df[df["event_timestamp"] <= df[ENTITY_TS_ALIAS]] + if "conv_rate > 0.15" in query: + df = df[df["conv_rate"] > 0.15] + return df.reset_index(drop=True) + + def _evaluate_group_by_query(self, query: str) -> pd.DataFrame: + view_name = self._extract_views(query)[0] + df = self._view_df(view_name) + return ( + df.groupby("driver_id") + .agg(sum_conv_rate=pd.NamedAgg(column="conv_rate", aggfunc="sum")) + .reset_index() + ) + + def _evaluate_row_number_query(self, query: str) -> pd.DataFrame: + view_name = 
self._extract_views(query)[-1] + df = self._view_df(view_name) + if " - 1 AS " in query: + df = df.copy() + df[ENTITY_ROW_ID] = range(len(df)) + if " AS `__entity_event_timestamp`" in query: + df = df.rename(columns={"event_timestamp": ENTITY_TS_ALIAS}) + return df.reset_index(drop=True) + + dedup_keys = [ENTITY_ROW_ID] if ENTITY_ROW_ID in df.columns else ["driver_id"] + sort_keys = [ + column for column in ["event_timestamp", "created"] if column in df + ] + return ( + df.sort_values(by=sort_keys, ascending=False) + .drop_duplicates(subset=dedup_keys) + .reset_index(drop=True) + ) + + def _evaluate_join_query(self, query: str) -> pd.DataFrame: + views = self._extract_views(query) + if " AS e LEFT JOIN " in query: + entity_df = self._view_df(views[-2]) + feature_df = self._view_df(views[-1]) + feature_columns = [ + column + for column in feature_df.columns + if column not in entity_df.columns and column != "driver_id" + ] + return entity_df.merge( + feature_df[["driver_id", *feature_columns]], + on="driver_id", + how="left", + ) + + joined_df = self._view_df(views[0]) + for view_name in views[1:]: + joined_df = joined_df.merge( + self._view_df(view_name), on="driver_id", how="left" + ) + return joined_df + + +class FakeFlinkSchema: + def __init__(self, columns: List[str]) -> None: + self._columns = columns + + def get_field_names(self) -> List[str]: + return list(self._columns) + + +class FakeNativeFlinkTable: + def __init__(self, columns: List[str]) -> None: + self._columns = columns + + def get_schema(self) -> FakeFlinkSchema: + return FakeFlinkSchema(self._columns) + + +class RecordingTableEnvironment(FakeTableEnvironment): + def __init__(self) -> None: + super().__init__() + + def create_temporary_view( + self, view_path: str, table_or_data_stream: object, *args: object + ) -> None: + self.views[view_path] = table_or_data_stream + + def sql_query(self, query: str) -> FakeNativeFlinkTable: + self.queries.append(query) + return FakeNativeFlinkTable([]) + + +class 
InputNode(DAGNode):
    # Stub DAG node for node-level tests: instead of computing anything it
    # returns whatever value the test stored in the context under its name.
    def execute(self, context: ExecutionContext) -> DAGValue:
        return context.node_outputs[self.name]


class FakeRetrievalJob(RetrievalJob):
    """In-memory ``RetrievalJob`` backed by a single pyarrow table.

    The class defines no ``to_flink_table`` method of its own; the source-read
    test below uses it to check that such jobs are rejected.
    """

    def __init__(self, table: pa.Table) -> None:
        self._table = table

    def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
        # ``timeout`` is accepted for interface compatibility and ignored.
        return self._table.to_pandas()

    def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
        # ``timeout`` is accepted for interface compatibility and ignored.
        return self._table

    @property
    def full_feature_names(self) -> bool:
        return False

    @property
    def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
        return []

    @property
    def metadata(self) -> Optional[RetrievalMetadata]:
        return None

    def persist(
        self,
        storage: SavedDatasetStorage,
        allow_overwrite: bool = False,
        timeout: Optional[int] = None,
    ) -> None:
        # Not needed by these tests.
        raise NotImplementedError

    def to_remote_storage(self) -> List[str]:
        # Not needed by these tests.
        raise NotImplementedError

    def to_sql(self) -> str:
        # Not needed by these tests.
        raise NotImplementedError


class FakeFlinkRetrievalJob:
    """Minimal retrieval-job stand-in that only exposes ``to_flink_table``.

    The DataFrame is wrapped in a ``FakeFlinkTable`` once, at construction.
    """

    def __init__(self, df: pd.DataFrame) -> None:
        self._table = FakeFlinkTable(df)

    def to_flink_table(self, table_env: object) -> FakeFlinkTable:
        # ``table_env`` is ignored; the pre-built fake table is always returned.
        return self._table


def _repo_config(tmp_path: Path, batch_engine: dict[str, object]) -> RepoConfig:
    """Build a local-provider ``RepoConfig`` using the given batch engine config.

    The registry and the sqlite online store are placed under ``tmp_path``; the
    offline store type is ``file``.
    """
    return RepoConfig(
        project="test_project",
        registry=str(tmp_path / "registry.db"),
        provider="local",
        offline_store={"type": "file"},
        online_store={"type": "sqlite", "path": str(tmp_path / "online.db")},
        batch_engine=batch_engine,
    )


def _driver() -> Entity:
    """Return the ``driver_id`` INT64 entity shared by all tests."""
    return Entity(name="driver_id", value_type=ValueType.INT64)


def _source() -> FileSource:
    """Return the ``FileSource`` used by the test feature views.

    The path is a placeholder ("unused.parquet"); the tests in this module
    supply data through mocked offline stores instead.
    """
    return FileSource(
        name="driver_stats_source",
        path="unused.parquet",
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
    )


def _feature_view(source: FileSource, **kwargs: Any) -> BatchFeatureView:
    """Return a ``driver_stats`` ``BatchFeatureView`` with one ``conv_rate`` field.

    Extra keyword arguments (e.g. ``online``, ``offline``, ``mode``, ``udf``)
    are forwarded unchanged to ``BatchFeatureView``.
    """
    return BatchFeatureView(
        name="driver_stats",
        entities=[_driver()],
        ttl=timedelta(days=2),
        schema=[Field(name="conv_rate", dtype=Float32)],
        source=source,
        **kwargs,
    )


def _feature_data() -> pd.DataFrame:
    """Return three feature rows: two for driver 1 (09:00, 10:00), one for driver 2."""
    return pd.DataFrame(
        {
            "driver_id": [1, 1, 2],
            "event_timestamp": [
                datetime(2024, 1, 1, 9, 0, 0),
                datetime(2024, 1, 1, 10, 0, 0),
                datetime(2024, 1, 1, 10, 0, 0),
            ],
            "created": [
                datetime(2024, 1, 1, 9, 1, 0),
                datetime(2024, 1, 1, 10, 1, 0),
                datetime(2024, 1, 1, 10, 1, 0),
            ],
            "conv_rate": [0.1, 0.2, 0.3],
        }
    )


def _offline_store(df: pd.DataFrame) -> MagicMock:
    """Return a mock offline store whose pull methods yield Flink-capable jobs.

    Both ``pull_all_from_table_or_query`` and ``pull_latest_from_table_or_query``
    return a (separate) ``FakeFlinkRetrievalJob`` wrapping ``df``.
    """
    store = MagicMock()
    store.pull_all_from_table_or_query.return_value = FakeFlinkRetrievalJob(df)
    store.pull_latest_from_table_or_query.return_value = FakeFlinkRetrievalJob(df)
    return store


def _registry(entity: Entity) -> MagicMock:
    """Return a mock registry whose ``get_entity`` always returns ``entity``."""
    registry = MagicMock()
    registry.get_entity.return_value = entity
    return registry


def _column_info() -> ColumnInfo:
    """Return the ``ColumnInfo`` matching the columns of ``_feature_data``."""
    return ColumnInfo(
        join_keys=["driver_id"],
        feature_cols=["conv_rate"],
        ts_col="event_timestamp",
        created_ts_col="created",
    )


def _execution_context(
    tmp_path: Path, node_outputs: dict[str, DAGValue]
) -> ExecutionContext:
    """Build an ``ExecutionContext`` with mocked stores and preset node outputs.

    ``node_outputs`` maps node names to the values that ``InputNode`` stubs
    will return when executed.
    """
    return ExecutionContext(
        project="test_project",
        repo_config=_repo_config(tmp_path, {"type": "flink.engine"}),
        offline_store=MagicMock(),
        online_store=MagicMock(),
        entity_defs=[_driver()],
        node_outputs=node_outputs,
    )


def _flink_value(df: pd.DataFrame) -> DAGValue:
    """Wrap ``df`` in a ``FakeFlinkTable`` and return it as a FLINK-format ``DAGValue``."""
    return DAGValue(
        data=FakeFlinkTable(df),
        format=DAGFormat.FLINK,
        metadata={"columns": list(df.columns)},
    )


def _native_flink_value(columns: List[str]) -> DAGValue:
    """Return a FLINK-format ``DAGValue`` holding a ``FakeNativeFlinkTable``.

    NOTE(review): ``FakeNativeFlinkTable`` is defined outside this chunk; it
    presumably mimics a real PyFlink table so nodes take their SQL code path.
    """
    return DAGValue(
        data=FakeNativeFlinkTable(columns),
        format=DAGFormat.FLINK,
        metadata={"columns": columns},
    )


def test_repo_config_loads_flink_batch_engine_config(tmp_path: Path) -> None:
    """A ``flink.engine`` batch engine dict is parsed into ``FlinkComputeEngineConfig``."""
    config = _repo_config(
        tmp_path,
        {
            "type": "flink.engine",
            "execution_mode": "streaming",
            "parallelism": 3,
            "table_config": {"pipeline.name": "feast-flink-test"},
            "pandas_split_num": 2,
        },
    )

    # Every option supplied above must survive parsing unchanged.
    assert isinstance(config.batch_engine, FlinkComputeEngineConfig)
    assert config.batch_engine.execution_mode == "streaming"
    assert config.batch_engine.parallelism == 3
    assert config.batch_engine.table_config == {"pipeline.name": "feast-flink-test"}
    assert config.batch_engine.pandas_split_num == 2


def test_flink_source_read_node_rejects_arrow_retrieval_jobs(tmp_path: Path) -> None:
    """``FlinkSourceReadNode`` raises ``TypeError`` mentioning ``to_flink_table``.

    The offline store returns a ``FakeRetrievalJob`` (arrow-backed) rather than
    a ``FakeFlinkRetrievalJob``.
    """
    offline_store = MagicMock()
    offline_store.pull_all_from_table_or_query.return_value = FakeRetrievalJob(
        pa.Table.from_pandas(_feature_data())
    )
    context = _execution_context(tmp_path, {})
    # Swap in the store that hands back the arrow-only retrieval job.
    context.offline_store = offline_store
    node = FlinkSourceReadNode(
        "source",
        _source(),
        _column_info(),
        FakeTableEnvironment(),
        split_num=1,
    )

    with pytest.raises(TypeError, match="to_flink_table"):
        node.execute(context)


def test_flink_historical_retrieval_executes_dag_with_transformation(
    tmp_path: Path,
) -> None:
    """Historical retrieval applies a ``mode="flink"`` UDF to the feature data.

    With an empty entity DataFrame the expected result has one row per driver
    and doubled ``conv_rate`` values (0.2 -> 0.4 for driver 1, 0.3 -> 0.6 for
    driver 2).
    """
    entity = _driver()
    source = _source()

    def double_conv_rate(table: FakeFlinkTable) -> FakeFlinkTable:
        # Round-trip through pandas so the fake table can be transformed.
        df = table.to_pandas()
        df["conv_rate"] = df["conv_rate"] * 2
        return FakeFlinkTable(df)

    feature_view = _feature_view(
        source,
        mode="flink",
        udf=double_conv_rate,
        udf_string="double_conv_rate",
        online=False,
        offline=False,
    )
    config = _repo_config(
        tmp_path,
        {"type": "flink.engine", "pandas_split_num": 4},
    )
    table_env = FakeTableEnvironment()
    engine = FlinkComputeEngine(
        repo_config=config,
        offline_store=_offline_store(_feature_data()),
        online_store=MagicMock(),
        table_environment=table_env,
    )
    task = HistoricalRetrievalTask(
        project=config.project,
        entity_df=pd.DataFrame(),
        feature_view=feature_view,
        full_feature_name=False,
        registry=_registry(entity),
    )

    job = engine.get_historical_features(_registry(entity), task)
    # Sort so the assertions do not depend on output row order.
    result = job.to_df().sort_values("driver_id").reset_index(drop=True)

    assert job.error() is None
    assert result["driver_id"].tolist() == [1, 2]
    assert result["conv_rate"].tolist() == [0.4, 0.6]


def test_flink_historical_retrieval_is_read_only_and_dedupes_per_entity_row(
    tmp_path: Path,
) -> None:
    """Historical retrieval returns one value per entity row and writes nothing.

    Two entity rows for the same driver (09:30 and 10:30) are expected to get
    ``conv_rate`` 0.1 and 0.2 respectively, even though the feature view is
    declared with ``online=True, offline=True``.
    """
    entity = _driver()
    source = _source()
    feature_view = _feature_view(source, online=True, offline=True)
    config = _repo_config(tmp_path, {"type": "flink.engine", "pandas_split_num": 4})
    feature_data = pd.DataFrame(
        {
            "driver_id": [1, 1],
            "event_timestamp": [
                datetime(2024, 1, 1, 9, 0, 0),
                datetime(2024, 1, 1, 10, 0, 0),
            ],
            "created": [
                datetime(2024, 1, 1, 9, 1, 0),
                datetime(2024, 1, 1, 10, 1, 0),
            ],
            "conv_rate": [0.1, 0.2],
        }
    )
    offline_store = _offline_store(feature_data)
    online_store = MagicMock()
    table_env = FakeTableEnvironment()
    engine = FlinkComputeEngine(
        repo_config=config,
        offline_store=offline_store,
        online_store=online_store,
        table_environment=table_env,
    )
    task = HistoricalRetrievalTask(
        project=config.project,
        entity_df=pd.DataFrame(
            {
                "driver_id": [1, 1],
                "event_timestamp": [
                    datetime(2024, 1, 1, 9, 30, 0),
                    datetime(2024, 1, 1, 10, 30, 0),
                ],
            }
        ),
        feature_view=feature_view,
        full_feature_name=False,
        registry=_registry(entity),
    )

    result = engine.get_historical_features(_registry(entity), task).to_df()
    result = result.sort_values(ENTITY_TS_ALIAS).reset_index(drop=True)

    assert result["conv_rate"].tolist() == [0.1, 0.2]
    # NOTE(review): ``split_nums`` is recorded by FakeTableEnvironment (defined
    # outside this chunk); presumably it captures the configured
    # ``pandas_split_num`` of 4 — confirm against the fake.
    assert table_env.split_nums == [4]
    # The helper row-id column must not be present in the returned frame.
    assert ENTITY_ROW_ID not in result.columns
    # Retrieval must not trigger any store writes.
    online_store.online_write_batch.assert_not_called()
    offline_store.offline_write_batch.assert_not_called()


def test_flink_historical_retrieval_supports_sql_entity_df(tmp_path: Path) -> None:
    """Historical retrieval accepts a SQL string as ``entity_df``.

    The entity rows are registered as the temporary view ``entities`` on the
    fake table environment, and the task's SQL text must show up in one of the
    queries the environment recorded.
    """
    entity = _driver()
    source = _source()
    feature_view = _feature_view(source, online=False, offline=False)
    config = _repo_config(tmp_path, {"type": "flink.engine"})
    table_env = FakeTableEnvironment()
    table_env.create_temporary_view(
        "entities",
        FakeFlinkTable(
            pd.DataFrame(
                {
                    "driver_id": [1, 1],
                    "event_timestamp": [
                        datetime(2024, 1, 1, 9, 30, 0),
                        datetime(2024, 1, 1, 10, 30, 0),
                    ],
                }
            )
        ),
    )
    engine = FlinkComputeEngine(
        repo_config=config,
        offline_store=_offline_store(_feature_data()),
        online_store=MagicMock(),
        table_environment=table_env,
    )
    task = HistoricalRetrievalTask(
        project=config.project,
        entity_df="SELECT driver_id, event_timestamp FROM entities",
        feature_view=feature_view,
        full_feature_name=False,
        registry=_registry(entity),
    )

    job = engine.get_historical_features(_registry(entity), task)
    result = job.to_df().sort_values(ENTITY_TS_ALIAS).reset_index(drop=True)

    assert job.error() is None
    assert result["conv_rate"].tolist() == [0.1, 0.2]
    assert any(
        "SELECT driver_id, event_timestamp FROM entities" in query
        for query in table_env.queries
    )


def test_flink_materialize_writes_online_and_offline(tmp_path: Path) -> None:
    """Materializing an ``online=True, offline=True`` view writes to both stores.

    Exactly one job is returned, it reports SUCCEEDED with no error, and each
    store's write-batch method is called exactly once.
    """
    entity = _driver()
    source = _source()
    feature_view = _feature_view(source, online=True, offline=True)
    config = _repo_config(tmp_path, {"type": "flink.engine"})
    # A single feature row is enough to exercise the write paths.
    offline_store = _offline_store(_feature_data().head(1))
    online_store = MagicMock()
    engine = FlinkComputeEngine(
        repo_config=config,
        offline_store=offline_store,
        online_store=online_store,
        table_environment=FakeTableEnvironment(),
    )
    task = MaterializationTask(
        project=config.project,
        feature_view=feature_view,
        start_time=datetime(2024, 1, 1),
        end_time=datetime(2024, 1, 2),
    )

    jobs = engine.materialize(_registry(entity), task)

    assert len(jobs) == 1
    assert jobs[0].status() == MaterializationJobStatus.SUCCEEDED
    assert jobs[0].error() is None
    online_store.online_write_batch.assert_called_once()
    offline_store.offline_write_batch.assert_called_once()


def test_flink_engine_reports_materialization_errors(tmp_path: Path) -> None:
    """A failure while reading the source is reported on the job, not raised.

    ``pull_all_from_table_or_query`` raises ``RuntimeError("boom")``; the
    returned job must have status ERROR and expose a ``RuntimeError``.
    """
    entity = _driver()
    source = _source()
    feature_view = _feature_view(source, online=False, offline=False)
    offline_store = MagicMock()
    offline_store.pull_all_from_table_or_query.side_effect = RuntimeError("boom")
    config = _repo_config(tmp_path, {"type": "flink.engine"})
    engine = FlinkComputeEngine(
        repo_config=config,
        offline_store=offline_store,
        online_store=MagicMock(),
        table_environment=FakeTableEnvironment(),
    )
    task = MaterializationTask(
        project=config.project,
        feature_view=feature_view,
        start_time=datetime(2024, 1, 1),
        end_time=datetime(2024, 1, 2),
    )

    jobs = engine.materialize(_registry(entity), task)

    assert jobs[0].status() == MaterializationJobStatus.ERROR
    assert isinstance(jobs[0].error(), RuntimeError)


def test_flink_join_node_merges_input_tables(tmp_path: Path) -> None:
    """``FlinkJoinNode`` combines two fake tables sharing ``driver_id``.

    The result must carry both ``conv_rate`` (left) and ``acc_rate`` (right).
    """
    left = InputNode("left")
    right = InputNode("right")
    node = FlinkJoinNode(
        "join",
        _column_info(),
        FakeTableEnvironment(),
        split_num=1,
        inputs=[left, right],
    )
    context = _execution_context(
        tmp_path,
        {
            "left": _flink_value(
                pd.DataFrame({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})
            ),
            "right": _flink_value(
                pd.DataFrame({"driver_id": [1, 2], "acc_rate": [0.3, 0.4]})
            ),
        },
    )

    result = node.execute(context).data.to_pandas().sort_values("driver_id")

    assert result["conv_rate"].tolist() == [0.1, 0.2]
    assert result["acc_rate"].tolist() == [0.3, 0.4]


def test_flink_join_node_uses_native_sql_when_available(tmp_path: Path) -> None:
    """With native-table inputs, ``FlinkJoinNode`` issues a SQL ``JOIN``.

    The output stays in FLINK format and its column metadata lists the join
    key once, followed by the left and right feature columns.
    """
    left = InputNode("left")
    right = InputNode("right")
    table_env = RecordingTableEnvironment()
    node = FlinkJoinNode(
        "join",
        _column_info(),
        table_env,
        split_num=1,
        inputs=[left, right],
    )
    context = _execution_context(
        tmp_path,
        {
            "left": _native_flink_value(["driver_id", "conv_rate"]),
            "right": _native_flink_value(["driver_id", "acc_rate"]),
        },
    )

    result = node.execute(context)

    assert result.format == DAGFormat.FLINK
    assert result.metadata["columns"] == ["driver_id", "conv_rate", "acc_rate"]
    assert any("JOIN" in query for query in table_env.queries)


def test_flink_filter_node_applies_filter_expression(tmp_path: Path) -> None:
    """``FlinkFilterNode`` keeps only rows matching ``conv_rate > 0.15``."""
    input_node = InputNode("input")
    node = FlinkFilterNode(
        "filter",
        _column_info(),
        FakeTableEnvironment(),
        split_num=1,
        filter_expr="conv_rate > 0.15",
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {
            "input": _flink_value(
                pd.DataFrame({"driver_id": [1, 2], "conv_rate": [0.1, 0.2]})
            )
        },
    )

    result = node.execute(context).data.to_pandas()

    # Only driver 2 (conv_rate 0.2) passes the filter.
    assert result["driver_id"].tolist() == [2]


def test_flink_filter_node_uses_native_sql_when_available(tmp_path: Path) -> None:
    """With a native-table input, the filter is emitted as a SQL ``WHERE`` clause."""
    input_node = InputNode("input")
    table_env = RecordingTableEnvironment()
    node = FlinkFilterNode(
        "filter",
        _column_info(),
        table_env,
        split_num=1,
        filter_expr="conv_rate > 0.15",
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {"input": _native_flink_value(["driver_id", "conv_rate"])},
    )

    result = node.execute(context)

    assert result.format == DAGFormat.FLINK
    assert any(
        "WHERE" in query and "conv_rate > 0.15" in query for query in table_env.queries
    )


def test_flink_filter_node_renders_ttl_as_valid_flink_interval(
    tmp_path: Path,
) -> None:
    """A TTL ``timedelta`` is rendered as a sum of single-unit SQL intervals.

    Checks both the helper ``_flink_interval_literal`` directly and that the
    day component appears in a query recorded by the table environment.
    """
    input_node = InputNode("input")
    table_env = RecordingTableEnvironment()
    node = FlinkFilterNode(
        "filter",
        _column_info(),
        table_env,
        split_num=1,
        ttl=timedelta(days=2, hours=3, minutes=4, seconds=5),
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {
            "input": _native_flink_value(
                ["driver_id", "conv_rate", "event_timestamp", ENTITY_TS_ALIAS]
            )
        },
    )

    node.execute(context)

    # Same timedelta as passed to the node above.
    assert _flink_interval_literal(
        timedelta(days=2, hours=3, minutes=4, seconds=5)
    ) == (
        "INTERVAL '2' DAY + INTERVAL '3' HOUR + "
        "INTERVAL '4' MINUTE + INTERVAL '5' SECOND"
    )
    assert any("INTERVAL '2' DAY" in query for query in table_env.queries)


def test_flink_aggregation_node_groups_features(tmp_path: Path) -> None:
    """``FlinkAggregationNode`` sums ``conv_rate`` per ``driver_id``.

    Driver 1 has 0.1 + 0.2 and driver 2 has 0.3, so both sums are ~0.3.
    """
    input_node = InputNode("input")
    node = FlinkAggregationNode(
        "agg",
        ["driver_id"],
        aggregations=[Aggregation(column="conv_rate", function="sum")],
        table_env=FakeTableEnvironment(),
        split_num=1,
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {
            "input": _flink_value(
                pd.DataFrame({"driver_id": [1, 1, 2], "conv_rate": [0.1, 0.2, 0.3]})
            )
        },
    )

    result = node.execute(context).data.to_pandas().sort_values("driver_id")

    # approx: 0.1 + 0.2 is not exactly 0.3 in binary floating point.
    assert result["sum_conv_rate"].tolist() == pytest.approx([0.3, 0.3])


def test_flink_aggregation_node_uses_native_sql_when_available(tmp_path: Path) -> None:
    """With a native-table input, aggregation is emitted as ``GROUP BY`` + ``SUM`` SQL.

    The output column metadata must be the group key followed by
    ``sum_conv_rate``.
    """
    input_node = InputNode("input")
    table_env = RecordingTableEnvironment()
    node = FlinkAggregationNode(
        "agg",
        ["driver_id"],
        aggregations=[Aggregation(column="conv_rate", function="sum")],
        table_env=table_env,
        split_num=1,
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {"input": _native_flink_value(["driver_id", "conv_rate"])},
    )

    result = node.execute(context)

    assert result.format == DAGFormat.FLINK
    assert result.metadata["columns"] == ["driver_id", "sum_conv_rate"]
    assert any("GROUP BY" in query and "SUM" in query for query in table_env.queries)


def test_flink_dedup_node_uses_entity_row_id_for_historical_retrieval(
    tmp_path: Path,
) -> None:
    """When ``ENTITY_ROW_ID`` is present, dedup keeps one row per row id.

    All three input rows share ``driver_id`` 1. Row id 0 has two candidates
    (09:00 -> 0.1, 10:00 -> 0.2) and row id 1 has one (0.3); the expected
    output is 0.2 for row id 0 and 0.3 for row id 1.
    """
    input_node = InputNode("input")
    node = FlinkDedupNode(
        "dedup",
        _column_info(),
        FakeTableEnvironment(),
        split_num=1,
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {
            "input": _flink_value(
                pd.DataFrame(
                    {
                        ENTITY_ROW_ID: [0, 0, 1],
                        "driver_id": [1, 1, 1],
                        "event_timestamp": [
                            datetime(2024, 1, 1, 9, 0, 0),
                            datetime(2024, 1, 1, 10, 0, 0),
                            datetime(2024, 1, 1, 10, 0, 0),
                        ],
                        "created": [
                            datetime(2024, 1, 1, 9, 1, 0),
                            datetime(2024, 1, 1, 10, 1, 0),
                            datetime(2024, 1, 1, 10, 1, 0),
                        ],
                        "conv_rate": [0.1, 0.2, 0.3],
                    }
                )
            )
        },
    )

    result = node.execute(context).data.to_pandas().sort_values(ENTITY_ROW_ID)

    assert result["conv_rate"].tolist() == [0.2, 0.3]


def test_flink_dedup_node_uses_native_row_number_when_available(
    tmp_path: Path,
) -> None:
    """With a native-table input, dedup is emitted as ``ROW_NUMBER() OVER`` SQL.

    ``ENTITY_ROW_ID`` must still be listed in the output column metadata.
    """
    input_node = InputNode("input")
    table_env = RecordingTableEnvironment()
    node = FlinkDedupNode(
        "dedup",
        _column_info(),
        table_env,
        split_num=1,
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {
            "input": _native_flink_value(
                [ENTITY_ROW_ID, "driver_id", "event_timestamp", "created", "conv_rate"]
            )
        },
    )

    result = node.execute(context)

    assert result.format == DAGFormat.FLINK
    assert any("ROW_NUMBER() OVER" in query for query in table_env.queries)
    assert ENTITY_ROW_ID in result.metadata["columns"]


def test_flink_transformation_node_keeps_native_flink_table(tmp_path: Path) -> None:
    """A UDF returning a native table has that exact object passed through.

    The identity check (``is``) verifies the node does not copy or convert the
    UDF's return value.
    """
    input_node = InputNode("input")
    native_result = FakeNativeFlinkTable(["driver_id", "conv_rate"])

    def native_udf(table: object) -> FakeNativeFlinkTable:
        # Ignores its input and always returns the pre-built table.
        return native_result

    node = FlinkTransformationNode(
        "transform",
        native_udf,
        RecordingTableEnvironment(),
        split_num=1,
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {"input": _native_flink_value(["driver_id", "conv_rate"])},
    )

    result = node.execute(context)

    assert result.data is native_result
    assert result.metadata["columns"] == ["driver_id", "conv_rate"]


def test_flink_validation_node_raises_for_missing_columns(tmp_path: Path) -> None:
    """``FlinkValidationNode`` raises ``ValueError`` when an expected column is absent.

    The input only has ``driver_id`` while ``missing_feature`` is expected.
    """
    input_node = InputNode("input")
    node = FlinkValidationNode(
        "validate",
        expected_columns={"missing_feature": pa.float32()},
        json_columns=set(),
        table_env=FakeTableEnvironment(),
        split_num=1,
        inputs=[input_node],
    )
    context = _execution_context(
        tmp_path,
        {"input": _flink_value(pd.DataFrame({"driver_id": [1]}))},
    )

    with pytest.raises(ValueError, match="Missing expected columns"):
        node.execute(context)


@pytest.mark.integration
@pytest.mark.slow
def test_flink_compute_engine_executes_with_real_pyflink_when_installed(
    tmp_path: Path,
) -> None:
    """Smoke test against a real batch-mode PyFlink ``TableEnvironment``.

    Skipped when ``pyflink.table`` cannot be imported. Runs historical
    retrieval first (expecting no online writes), then materialization
    (expecting one write to each store).
    """
    pyflink_table = pytest.importorskip(
        "pyflink.table", reason="PyFlink is required for this runtime smoke test"
    )
    entity = _driver()
    source = _source()
    feature_view = _feature_view(source, online=True, offline=True)
    config = _repo_config(tmp_path, {"type": "flink.engine"})
    offline_store = _offline_store(_feature_data())
    online_store = MagicMock()
    table_env = pyflink_table.TableEnvironment.create(
        pyflink_table.EnvironmentSettings.new_instance().in_batch_mode().build()
    )
    engine = FlinkComputeEngine(
        repo_config=config,
        offline_store=offline_store,
        online_store=online_store,
        table_environment=table_env,
    )
    task = HistoricalRetrievalTask(
        project=config.project,
        entity_df=pd.DataFrame(
            {
                "driver_id": [1, 1, 2],
                "event_timestamp": [
                    datetime(2024, 1, 1, 9, 30, 0),
                    datetime(2024, 1, 1, 10, 30, 0),
                    datetime(2024, 1, 1, 10, 30, 0),
                ],
            }
        ),
        feature_view=feature_view,
        full_feature_name=False,
        registry=_registry(entity),
    )

    result = engine.get_historical_features(_registry(entity), task).to_df()
    result = result.sort_values(["driver_id", ENTITY_TS_ALIAS]).reset_index(drop=True)

    assert result["conv_rate"].tolist() == [0.1, 0.2, 0.3]
    assert ENTITY_ROW_ID not in result.columns
    # Historical retrieval must not have written to the online store.
    online_store.online_write_batch.assert_not_called()

    materialization_task = MaterializationTask(
        project=config.project,
        feature_view=feature_view,
        start_time=datetime(2024, 1, 1),
        end_time=datetime(2024, 1, 2),
    )

    jobs = engine.materialize(_registry(entity), materialization_task)

    assert jobs[0].status() == MaterializationJobStatus.SUCCEEDED
    assert jobs[0].error() is None
    online_store.online_write_batch.assert_called_once()
    offline_store.offline_write_batch.assert_called_once()