From 2cee78eecb72be73854fc1bf8750177cccb1a43c Mon Sep 17 00:00:00 2001
From: abhijeet-dhumal
Date: Tue, 5 May 2026 19:42:28 +0530
Subject: [PATCH 1/7] fix(spark): replace mapInArrow with foreachPartition and
 fix session config forwarding for vector store materialization

Signed-off-by: abhijeet-dhumal
---
 .../infra/compute_engines/spark/compute.py    |   3 +-
 .../infra/compute_engines/spark/nodes.py      |  18 +-
 .../infra/compute_engines/spark/utils.py      | 190 ++++++++++++++++--
 .../contrib/spark_offline_store/spark.py      |  61 ------
 .../tests/component/spark/test_compute.py     | 153 ++++++++++++--
 5 files changed, 315 insertions(+), 110 deletions(-)

diff --git a/sdk/python/feast/infra/compute_engines/spark/compute.py b/sdk/python/feast/infra/compute_engines/spark/compute.py
index b6c7dc30d55..b32df0a2560 100644
--- a/sdk/python/feast/infra/compute_engines/spark/compute.py
+++ b/sdk/python/feast/infra/compute_engines/spark/compute.py
@@ -81,7 +81,8 @@ def teardown_infra(
     def _get_feature_view_spark_session(
         self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView]
     ) -> SparkSession:
-        spark_conf = self._get_feature_view_engine_config(feature_view)
+        config = self._get_feature_view_engine_config(feature_view)
+        spark_conf = config.get("spark_conf", config)
         return get_or_create_new_spark_session(spark_conf)
 
     def _materialize_one(
diff --git a/sdk/python/feast/infra/compute_engines/spark/nodes.py b/sdk/python/feast/infra/compute_engines/spark/nodes.py
index 5a8c4368fc5..1391585ad58 100644
--- a/sdk/python/feast/infra/compute_engines/spark/nodes.py
+++ b/sdk/python/feast/infra/compute_engines/spark/nodes.py
@@ -32,7 +32,10 @@
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.node import DAGNode
 from feast.infra.compute_engines.dag.value import DAGValue
-from feast.infra.compute_engines.spark.utils import map_in_arrow
+from feast.infra.compute_engines.spark.utils import (
+    write_to_offline_store,
+    write_to_online_store,
+)
 from feast.infra.compute_engines.utils import (
     create_offline_store_retrieval_job,
 )
@@ -572,21 +575,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
             feature_view=self.feature_view, repo_config=context.repo_config
         )
 
-        # ✅ 1. Write to online store if online enabled
         if self.feature_view.online:
-            spark_df.mapInArrow(
-                lambda x: map_in_arrow(x, serialized_artifacts, mode="online"),
-                spark_df.schema,
-            ).count()
+            write_to_online_store(spark_df, serialized_artifacts)
 
-        # ✅ 2. Write to offline store if offline enabled
         if self.feature_view.offline:
             if not isinstance(self.feature_view.batch_source, SparkSource):
-                spark_df.mapInArrow(
-                    lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"),
-                    spark_df.schema,
-                ).count()
-                # Directly write spark df to spark offline store without using mapInArrow
+                write_to_offline_store(spark_df, serialized_artifacts)
             else:
                 dest_path = self.feature_view.batch_source.path
                 file_format = self.feature_view.batch_source.file_format
diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py
index 8c84c9f17a6..970747623aa 100644
--- a/sdk/python/feast/infra/compute_engines/spark/utils.py
+++ b/sdk/python/feast/infra/compute_engines/spark/utils.py
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Dict, Iterable, Literal, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Literal, Optional
 
 import pandas as pd
 import pyarrow
@@ -18,21 +18,16 @@
 boto3 = None  # type: ignore[assignment]
 BotoConfig = None  # type: ignore[assignment,misc]
 
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+
 logger = logging.getLogger(__name__)
 
 
 def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
-    """Pre-create the S3A event log prefix before SparkContext initialisation.
-
-    Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called inside
-    SparkContext.__init__ and crashes if the S3A path doesn't exist yet (S3 has no
-    real directories, so an empty prefix returns a 404). This function writes a
-    zero-byte placeholder so the prefix exists before SparkContext is built.
+    """Pre-create an S3A event-log prefix so SparkContext.__init__ doesn't 404.
 
-    This is only attempted when:
-    - spark.eventLog.enabled == "true"
-    - spark.eventLog.dir starts with "s3a://"
-    Failures are non-fatal: Spark will surface its own error if the dir is still missing.
+    Only acts when eventLog is enabled with an s3a:// path. Non-fatal on failure.
     """
     if spark_config.get("spark.eventLog.enabled", "false").lower() != "true":
         return
@@ -121,6 +116,18 @@ def get_or_create_new_spark_session(
             )
         spark_session = spark_builder.getOrCreate()
+
+    # getOrCreate() silently drops new configs on a reused session.
+    # Re-apply spark.sql.* and spark.hadoop.* which are safe to set post-creation.
+ if spark_config: + _RUNTIME_PREFIXES = ("spark.sql.", "spark.hadoop.") + for k, v in spark_config.items(): + if any(k.startswith(p) for p in _RUNTIME_PREFIXES): + try: + spark_session.conf.set(k, v) + except Exception as e: + logger.debug("Could not set runtime config %s: %s", k, e) + spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") return spark_session @@ -146,7 +153,9 @@ def map_in_arrow( for entity in feature_view.entity_columns } - batch_size = repo_config.materialization_config.online_write_batch_size + batch_size = getattr( + repo_config.materialization_config, "online_write_batch_size", None + ) # Single batch if None (backward compatible), otherwise use configured batch_size sub_batches = ( [table] @@ -202,7 +211,9 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts): for entity in feature_view.entity_columns } - batch_size = repo_config.materialization_config.online_write_batch_size + batch_size = getattr( + repo_config.materialization_config, "online_write_batch_size", None + ) # Single batch if None (backward compatible), otherwise use configured batch_size sub_batches = ( [table] @@ -220,6 +231,153 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts): lambda x: None, ) - yield pd.DataFrame( - [pd.Series(range(1, 2))] - ) # dummy result because mapInPandas needs to return something + yield pd.DataFrame({"status": [0]}) + + +def write_to_online_store( + spark_df: "DataFrame", + serialized_artifacts: SerializedArtifacts, +) -> None: + """Write a Spark DataFrame to the online store via foreachPartition. + + Uses foreachPartition instead of mapInArrow to avoid a Spark 3.5 + serialiser mismatch (ArrowStreamPandasUDFSerializer vs ArrowStreamUDFSerializer) + when WindowGroupLimitExec precedes MapInArrowExec. + """ + from pyspark.sql.pandas.types import to_arrow_schema + + df_schema = spark_df.schema + + def _write_partition(rows): # type: ignore[type-arg] + rows_list = list(rows) + if not rows_list: + return + + import pyarrow as pa + + from feast.utils import _convert_arrow_to_proto + + pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list]) + table = pa.Table.from_pandas( + pdf, schema=to_arrow_schema(df_schema), preserve_index=False + ) + + ( + feature_view, + online_store, + _, + repo_config, + ) = serialized_artifacts.unserialize() + + join_key_to_value_type = { + entity.name: entity.dtype.to_value_type() + for entity in feature_view.entity_columns + } + + batch_size = getattr( + repo_config.materialization_config, "online_write_batch_size", None + ) + if batch_size is None: + sub_tables = [table] + else: + sub_tables = [ + table.slice(offset, min(batch_size, len(table) - offset)) + for offset in range(0, len(table), batch_size) + ] + + for sub_table in sub_tables: + online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=_convert_arrow_to_proto( + sub_table, feature_view, join_key_to_value_type + ), + progress=lambda x: None, + ) + + spark_df.foreachPartition(_write_partition) + + +def write_to_offline_store( + spark_df: "DataFrame", + serialized_artifacts: SerializedArtifacts, +) -> None: + """Write a Spark DataFrame to the offline store via foreachPartition. + + Same Spark 3.5 serialiser workaround as ``write_to_online_store``. 
+ """ + from pyspark.sql.pandas.types import to_arrow_schema + + df_schema = spark_df.schema + + def _write_partition(rows): # type: ignore[type-arg] + rows_list = list(rows) + if not rows_list: + return + + import pyarrow as pa + + pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list]) + table = pa.Table.from_pandas( + pdf, schema=to_arrow_schema(df_schema), preserve_index=False + ) + + ( + feature_view, + _, + offline_store, + repo_config, + ) = serialized_artifacts.unserialize() + + offline_store.offline_write_batch( + config=repo_config, + feature_view=feature_view, + table=table, + progress=lambda x: None, + ) + + spark_df.foreachPartition(_write_partition) + + +_FEAST_EMBED_MODEL_CACHE: Dict[tuple, object] = {} + + +def spark_embed( + df: "DataFrame", + text_col: str, + model: str = "sentence-transformers/all-MiniLM-L6-v2", + output_col: str = "embedding", + batch_size: int = 64, +) -> "DataFrame": + """Append an embedding column to *df* using a sentence-transformer. + + Intended for ``@batch_feature_view`` with ``TransformationMode.PYTHON``. + Uses ``localCheckpoint(eager=True)`` to sever Python lineage and avoid + downstream Arrow serialiser mismatches. Model is cached per executor. + """ + import pyspark.sql.functions as F + import pyspark.sql.types as T + from pyspark.sql.functions import pandas_udf + + model_id = model + bs = batch_size + _cache = _FEAST_EMBED_MODEL_CACHE + + @pandas_udf(T.ArrayType(T.FloatType())) + def _embed_udf(texts: pd.Series) -> pd.Series: + import torch + from sentence_transformers import SentenceTransformer + + device = "cuda" if torch.cuda.is_available() else "cpu" + cache_key = (model_id, device) + if cache_key not in _cache: + _cache[cache_key] = SentenceTransformer(model_id, device=device) + sent_model = _cache[cache_key] + + embeddings = sent_model.encode( + texts.tolist(), batch_size=bs, show_progress_bar=False + ) + return pd.Series([e.astype("float32").tolist() for e in embeddings]) + + embedded = df.withColumn(output_col, _embed_udf(F.col(text_col))) + return embedded.localCheckpoint(eager=True) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 3fc675ea402..bede2a6f44c 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -33,7 +33,6 @@ from pyspark.sql import SparkSession from feast import FeatureView, OnDemandFeatureView -from feast.batch_feature_view import BatchFeatureView from feast.data_source import DataSource from feast.dataframe import DataFrameEngine, FeastDataFrame from feast.errors import EntitySQLEmptyResults, InvalidEntityType @@ -261,10 +260,6 @@ def get_historical_features( entity_df_event_timestamp_range, ) - query_context = _apply_bfv_transformations( - spark_session, feature_views, query_context - ) - spark_query_context = [ SparkFeatureViewQueryContext( **asdict(context), @@ -718,62 +713,6 @@ def _entity_schema_keys_from( ) -def _apply_bfv_transformations( - spark_session: SparkSession, - feature_views: List[FeatureView], - query_contexts: List[offline_utils.FeatureViewQueryContext], -) -> List[offline_utils.FeatureViewQueryContext]: - """ - For BatchFeatureViews with a UDF, read the raw source into a Spark DataFrame, - invoke the transformation, register the result as a temp view, and replace the - table_subquery in the query context so the PIT join reads transformed data. 
- """ - from dataclasses import replace - - from feast.feature_view_utils import ( - get_transformation_function, - has_transformation, - resolve_feature_view_source_with_fallback, - ) - - fv_by_name = {fv.projection.name_to_use(): fv for fv in feature_views} - - updated_contexts = [] - for ctx in query_contexts: - fv = fv_by_name.get(ctx.name) - if ( - fv is not None - and isinstance(fv, BatchFeatureView) - and has_transformation(fv) - ): - udf = get_transformation_function(fv) - if udf is not None: - source_info = resolve_feature_view_source_with_fallback(fv) - source_query = source_info.data_source.get_table_query_string() - - timestamp_filter = get_timestamp_filter_sql( - start_date=ctx.min_event_timestamp, - end_date=ctx.max_event_timestamp, - timestamp_field=ctx.timestamp_field, - tz=timezone.utc, - quote_fields=False, - ) - source_df = spark_session.sql( - f"SELECT * FROM {source_query} WHERE {timestamp_filter}" - ) - - transformed_df = udf(source_df) - - tmp_view_name = "feast_bfv_" + uuid.uuid4().hex - transformed_df.createOrReplaceTempView(tmp_view_name) - - ctx = replace(ctx, table_subquery=tmp_view_name) - - updated_contexts.append(ctx) - - return updated_contexts - - def _get_entity_df_event_timestamp_range( entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str, diff --git a/sdk/python/tests/component/spark/test_compute.py b/sdk/python/tests/component/spark/test_compute.py index 803cd505513..073d995b129 100644 --- a/sdk/python/tests/component/spark/test_compute.py +++ b/sdk/python/tests/component/spark/test_compute.py @@ -1,6 +1,6 @@ from datetime import timedelta from typing import cast -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from pyspark.sql import DataFrame @@ -15,6 +15,12 @@ from feast.infra.common.retrieval_task import HistoricalRetrievalTask from feast.infra.compute_engines.spark.compute import SparkComputeEngine from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob +from feast.infra.compute_engines.spark.utils import ( + _ensure_s3a_event_log_dir, + get_or_create_new_spark_session, + map_in_pandas, + write_to_online_store, +) from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( SparkOfflineStore, ) @@ -69,7 +75,6 @@ def transform_feature(df: DataFrame) -> DataFrame: try: fs.apply([driver, driver_stats_fv]) - # ๐Ÿ›  Build retrieval task task = HistoricalRetrievalTask( project=spark_environment.project, entity_df=entity_df, @@ -78,7 +83,6 @@ def transform_feature(df: DataFrame) -> DataFrame: registry=registry, ) - # ๐Ÿงช Run SparkComputeEngine engine = SparkComputeEngine( repo_config=spark_environment.config, offline_store=SparkOfflineStore(), @@ -89,7 +93,6 @@ def transform_feature(df: DataFrame) -> DataFrame: spark_df = cast(SparkDAGRetrievalJob, spark_dag_retrieval_job).to_spark_df() df_out = spark_df.orderBy("driver_id").toPandas() - # โœ… Assert output assert df_out.driver_id.to_list() == [1001, 1002] assert abs(df_out["sum_conv_rate"].to_list()[0] - 3.1) < 1e-6 assert abs(df_out["sum_conv_rate"].to_list()[1] - 2.0) < 1e-6 @@ -102,20 +105,7 @@ def transform_feature(df: DataFrame) -> DataFrame: @pytest.mark.integration def test_spark_compute_engine_materialize(): - """ - Test the SparkComputeEngine materialize method. - For the current feature view driver_hourly_stats, The below execution plan: - 1. feature data from create_feature_dataset - 2. 
filter by start_time and end_time, that is, the last 2 days - for the driver_id 1001, the data left is row 0 - for the driver_id 1002, the data left is row 2 - 3. apply the transform_feature function to the data - for all features, the value is multiplied by 2 - 4. write the data to the online store and offline store - - Returns: - - """ + """Materialize with BFV transform (2x multiply), verify online + offline writes.""" spark_environment = create_spark_environment() fs = spark_environment.feature_store registry = fs.registry @@ -150,7 +140,6 @@ def tqdm_builder(length): try: fs.apply([driver, driver_stats_fv]) - # ๐Ÿ›  Build retrieval task task = MaterializationTask( project=spark_environment.project, feature_view=driver_stats_fv, @@ -159,7 +148,6 @@ def tqdm_builder(length): tqdm_builder=tqdm_builder, ) - # ๐Ÿงช Run SparkComputeEngine engine = SparkComputeEngine( repo_config=spark_environment.config, offline_store=SparkOfflineStore(), @@ -192,5 +180,130 @@ def tqdm_builder(length): spark_environment.teardown() +def _base_conf(event_log_dir: str) -> dict: + return { + "spark.eventLog.enabled": "true", + "spark.eventLog.dir": event_log_dir, + "spark.hadoop.fs.s3a.endpoint": "http://minio:9000", + } + + +@patch("feast.infra.compute_engines.spark.utils.boto3") +def test_ensure_s3a_event_log_dir_creates_placeholder_when_empty(mock_boto3): + s3 = MagicMock() + mock_boto3.client.return_value = s3 + s3.list_objects_v2.return_value = {"KeyCount": 0} + + _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/")) + + s3.list_objects_v2.assert_called_once_with( + Bucket="my-bucket", Prefix="spark-events/", MaxKeys=1 + ) + s3.put_object.assert_called_once_with( + Bucket="my-bucket", Key="spark-events/.keep", Body=b"" + ) + + +@patch("feast.infra.compute_engines.spark.utils.boto3") +def test_ensure_s3a_event_log_dir_skips_when_prefix_exists(mock_boto3): + s3 = MagicMock() + mock_boto3.client.return_value = s3 + s3.list_objects_v2.return_value = {"KeyCount": 3} + + _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/")) + + s3.put_object.assert_not_called() + + +@patch("feast.infra.compute_engines.spark.utils.boto3") +def test_ensure_s3a_event_log_dir_noop_when_event_log_disabled(mock_boto3): + _ensure_s3a_event_log_dir( + {"spark.eventLog.enabled": "false", "spark.eventLog.dir": "s3a://b/p/"} + ) + mock_boto3.client.assert_not_called() + + +@patch("feast.infra.compute_engines.spark.utils.boto3") +def test_ensure_s3a_event_log_dir_noop_for_non_s3a_path(mock_boto3): + _ensure_s3a_event_log_dir( + {"spark.eventLog.enabled": "true", "spark.eventLog.dir": "hdfs:///spark-logs"} + ) + mock_boto3.client.assert_not_called() + + +@patch("feast.infra.compute_engines.spark.utils.boto3") +def test_ensure_s3a_event_log_dir_non_fatal_on_s3_error(mock_boto3): + s3 = MagicMock() + mock_boto3.client.return_value = s3 + s3.list_objects_v2.side_effect = Exception("connection refused") + + _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/")) + + +def test_get_or_create_applies_sql_configs_to_reused_session(): + """SQL/Hadoop configs must be forwarded even when a SparkSession already exists.""" + mock_session = MagicMock() + spark_config = { + "spark.sql.sources.useV1SourceList": "avro", + "spark.hadoop.fs.s3a.endpoint": "http://minio:9000", + "spark.executor.instances": "2", + } + + with patch( + "feast.infra.compute_engines.spark.utils.SparkSession" + ) as mock_spark_cls: + mock_spark_cls.getActiveSession.return_value = mock_session + result = 
get_or_create_new_spark_session(spark_config) + + assert result is mock_session + set_calls = {call.args[0] for call in mock_session.conf.set.call_args_list} + assert "spark.sql.sources.useV1SourceList" in set_calls + assert "spark.hadoop.fs.s3a.endpoint" in set_calls + assert "spark.executor.instances" not in set_calls + + +def test_map_in_pandas_dummy_yield_has_correct_schema(): + """map_in_pandas must yield a DataFrame with column 'status', not column '0'.""" + import pandas as pd + + batches = list(map_in_pandas(iter([]), MagicMock())) + assert len(batches) == 1 + df = batches[0] + assert isinstance(df, pd.DataFrame) + assert list(df.columns) == ["status"] + assert df["status"].iloc[0] == 0 + + +def test_write_to_online_store_skips_empty_partitions(): + from pyspark.sql import SparkSession + from pyspark.sql.types import FloatType, StringType, StructField, StructType + + spark = SparkSession.builder.master("local[1]").appName("test_write").getOrCreate() + schema = StructType( + [ + StructField("review_id", StringType(), True), + StructField("val", FloatType(), True), + ] + ) + df = spark.createDataFrame([], schema=schema) + + mock_artifacts = MagicMock() + mock_online_store = MagicMock() + mock_fv = MagicMock() + mock_fv.entity_columns = [] + mock_config = MagicMock() + mock_config.materialization_config.online_write_batch_size = None + mock_artifacts.unserialize.return_value = ( + mock_fv, + mock_online_store, + MagicMock(), + mock_config, + ) + + write_to_online_store(df, mock_artifacts) + mock_online_store.online_write_batch.assert_not_called() + spark.stop() + + if __name__ == "__main__": test_spark_compute_engine_get_historical_features() From e69c62316158586c2b9ce1b5a317c68d17eae168 Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Wed, 6 May 2026 14:45:31 +0530 Subject: [PATCH 2/7] fix: resolve linting and ruff check issues Signed-off-by: abhijeet-dhumal --- sdk/python/feast/infra/compute_engines/spark/utils.py | 4 ++-- sdk/python/tests/component/spark/test_compute.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 970747623aa..e7615ce9c7f 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -1,6 +1,6 @@ import logging import os -from typing import TYPE_CHECKING, Dict, Iterable, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Literal, Optional import pandas as pd import pyarrow @@ -339,7 +339,7 @@ def _write_partition(rows): # type: ignore[type-arg] spark_df.foreachPartition(_write_partition) -_FEAST_EMBED_MODEL_CACHE: Dict[tuple, object] = {} +_FEAST_EMBED_MODEL_CACHE: Dict[tuple, "Any"] = {} def spark_embed( diff --git a/sdk/python/tests/component/spark/test_compute.py b/sdk/python/tests/component/spark/test_compute.py index 073d995b129..e65c2e16383 100644 --- a/sdk/python/tests/component/spark/test_compute.py +++ b/sdk/python/tests/component/spark/test_compute.py @@ -274,6 +274,7 @@ def test_map_in_pandas_dummy_yield_has_correct_schema(): assert df["status"].iloc[0] == 0 +@pytest.mark.integration def test_write_to_online_store_skips_empty_partitions(): from pyspark.sql import SparkSession from pyspark.sql.types import FloatType, StringType, StructField, StructType From 893db923dc8587ac8e4ae04a42d8b72c8953dc13 Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Thu, 7 May 2026 11:31:24 +0530 Subject: [PATCH 3/7] feat: read from offline 
path in get_historical_features for BFVs --- .../contrib/spark_offline_store/spark.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index bede2a6f44c..ccae3c041a3 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -78,6 +78,72 @@ class SparkFeatureViewQueryContext(offline_utils.FeatureViewQueryContext): max_date_partition: str +def _apply_bfv_transformations_for_historical( + spark_session: SparkSession, + feature_views: List[FeatureView], + query_context: List[offline_utils.FeatureViewQueryContext], +) -> List[offline_utils.FeatureViewQueryContext]: + """ + For BatchFeatureViews, redirect get_historical_features to read from the + pre-materialized offline store (batch_source.path) when available, avoiding + expensive UDF re-execution on raw data. + + Precedence: + 1. offline=True + batch_source.path set -> read pre-computed parquet + 2. Python/pandas UDF present -> execute UDF on raw source (fallback) + 3. Otherwise -> pass through unchanged + """ + from dataclasses import replace + + fv_by_name = {fv.projection.name_to_use(): fv for fv in feature_views} + new_contexts = [] + + for ctx in query_context: + fv = fv_by_name.get(ctx.name) + if fv is None or not isinstance(fv, BatchFeatureView): + new_contexts.append(ctx) + continue + + if ( + getattr(fv, "offline", False) + and isinstance(fv.batch_source, SparkSource) + and fv.batch_source.path + ): + tmp_view = f"__feast_offline_{ctx.name}_{uuid.uuid4().hex[:8]}" + file_format = fv.batch_source.file_format or "parquet" + df = spark_session.read.format(file_format).load(fv.batch_source.path) + df.createOrReplaceTempView(tmp_view) + ctx = replace(ctx, table_subquery=tmp_view) + elif ( + hasattr(fv, "feature_transformation") + and fv.feature_transformation is not None + and ( + getattr(fv.feature_transformation, "mode", None) + in ("python", "pandas") + or getattr( + getattr(fv.feature_transformation, "mode", None), "value", None + ) + in ("python", "pandas") + ) + ): + udf = getattr(fv.feature_transformation, "udf", None) or getattr( + fv, "udf", None + ) + if udf is not None: + temp_view_name = f"__feast_bfv_{ctx.name}_{uuid.uuid4().hex[:8]}" + spark_session.conf.set("spark.sql.runSQLOnFiles", "true") + raw_df = spark_session.sql( + f"SELECT * FROM {ctx.table_subquery}" + ) + transformed_df = udf(raw_df) + transformed_df.createOrReplaceTempView(temp_view_name) + ctx = replace(ctx, table_subquery=temp_view_name) + + new_contexts.append(ctx) + + return new_contexts + + class SparkOfflineStore(OfflineStore): @staticmethod def pull_latest_from_table_or_query( @@ -260,6 +326,12 @@ def get_historical_features( entity_df_event_timestamp_range, ) + query_context = _apply_bfv_transformations_for_historical( + spark_session=spark_session, + feature_views=feature_views, + query_context=query_context, + ) + spark_query_context = [ SparkFeatureViewQueryContext( **asdict(context), From 165fbe21b93b7293cc1a7a0fa69021802aaff2a9 Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Thu, 7 May 2026 12:46:27 +0530 Subject: [PATCH 4/7] feat: allow query + path in SparkSource for offline materialization SparkSource previously required exactly one of table/query/path. 
This relaxes the constraint to allow query + path together: - query: used for reading raw data during materialization - path: used for offline write-back (offline=True) and as pre-computed read source in get_historical_features Co-authored-by: Cursor --- .../contrib/spark_offline_store/spark_source.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index cd41921e56a..03c65e28b1f 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -287,11 +287,19 @@ def __init__( date_partition_column_format: Optional[str] = "%Y-%m-%d", table_format: Optional[TableFormat] = None, ): - # Check that only one of the ways to load a spark dataframe can be used. We have - # to treat empty string and null the same due to proto (de)serialization. - if sum([(not (not arg)) for arg in [table, query, path]]) != 1: + # query + path is allowed: query for reads during materialization, + # path for offline write-back (offline=True) and get_historical_features. + # table must be standalone (cannot combine with query or path). + has_table = bool(table) + has_query = bool(query) + has_path = bool(path) + if has_table and (has_query or has_path): raise ValueError( - "Exactly one of params(table, query, path) must be specified." + "'table' cannot be combined with 'query' or 'path'." + ) + if not (has_table or has_query or has_path): + raise ValueError( + "At least one of params(table, query, path) must be specified." ) if path: # If table_format is specified, file_format is optional (table format determines the reader) From 6ebcde17d6e21799236140047d12f4b56a962fde Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Thu, 7 May 2026 13:29:59 +0530 Subject: [PATCH 5/7] fix: add missing BatchFeatureView import in spark offline store Co-authored-by: Cursor --- .../infra/offline_stores/contrib/spark_offline_store/spark.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index ccae3c041a3..c55653871ed 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -33,6 +33,7 @@ from pyspark.sql import SparkSession from feast import FeatureView, OnDemandFeatureView +from feast.batch_feature_view import BatchFeatureView from feast.data_source import DataSource from feast.dataframe import DataFrameEngine, FeastDataFrame from feast.errors import EntitySQLEmptyResults, InvalidEntityType From c80d43661bfd8afe90459914a5a2566728760e6e Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Thu, 7 May 2026 13:32:45 +0530 Subject: [PATCH 6/7] fix: graceful fallback when offline path is not readable Co-authored-by: Cursor --- .../contrib/spark_offline_store/spark.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index c55653871ed..7dcffe317b3 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ 
b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -112,10 +112,21 @@ def _apply_bfv_transformations_for_historical( ): tmp_view = f"__feast_offline_{ctx.name}_{uuid.uuid4().hex[:8]}" file_format = fv.batch_source.file_format or "parquet" - df = spark_session.read.format(file_format).load(fv.batch_source.path) - df.createOrReplaceTempView(tmp_view) - ctx = replace(ctx, table_subquery=tmp_view) - elif ( + try: + df = spark_session.read.format(file_format).load(fv.batch_source.path) + df.createOrReplaceTempView(tmp_view) + ctx = replace(ctx, table_subquery=tmp_view) + new_contexts.append(ctx) + continue + except Exception: + warnings.warn( + f"Offline path '{fv.batch_source.path}' not readable for " + f"'{ctx.name}'; falling back to source query.", + RuntimeWarning, + stacklevel=2, + ) + + if ( hasattr(fv, "feature_transformation") and fv.feature_transformation is not None and ( From 3d1c3640859061eaa62d297320973c0c38ddf0b2 Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Thu, 7 May 2026 20:05:33 +0530 Subject: [PATCH 7/7] =?UTF-8?q?fix:=20chunk=20foreachPartition=20writes=20?= =?UTF-8?q?to=20bound=20Python=20memory=20=E2=80=94=20prevents=20MemoryErr?= =?UTF-8?q?or/OOMKill=20on=20large=20feature=20views?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: abhijeet-dhumal --- .../infra/compute_engines/spark/utils.py | 76 ++++++++++--------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index e7615ce9c7f..36dbb2bd9ee 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -243,25 +243,24 @@ def write_to_online_store( Uses foreachPartition instead of mapInArrow to avoid a Spark 3.5 serialiser mismatch (ArrowStreamPandasUDFSerializer vs ArrowStreamUDFSerializer) when WindowGroupLimitExec precedes MapInArrowExec. + + Rows are consumed in fixed-size chunks (default 5 000) so that Python + memory stays bounded regardless of partition size. Previous behaviour + called ``list(rows)`` which materialised the entire partition and caused + ``MemoryError`` / OOMKill on large feature views. 
""" + from itertools import islice + from pyspark.sql.pandas.types import to_arrow_schema df_schema = spark_df.schema + _CHUNK = 5_000 def _write_partition(rows): # type: ignore[type-arg] - rows_list = list(rows) - if not rows_list: - return - import pyarrow as pa from feast.utils import _convert_arrow_to_proto - pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list]) - table = pa.Table.from_pandas( - pdf, schema=to_arrow_schema(df_schema), preserve_index=False - ) - ( feature_view, online_store, @@ -273,24 +272,26 @@ def _write_partition(rows): # type: ignore[type-arg] entity.name: entity.dtype.to_value_type() for entity in feature_view.entity_columns } + arrow_schema = to_arrow_schema(df_schema) batch_size = getattr( repo_config.materialization_config, "online_write_batch_size", None ) - if batch_size is None: - sub_tables = [table] - else: - sub_tables = [ - table.slice(offset, min(batch_size, len(table) - offset)) - for offset in range(0, len(table), batch_size) - ] - - for sub_table in sub_tables: + write_size = batch_size or _CHUNK + + while True: + chunk = list(islice(rows, write_size)) + if not chunk: + break + pdf = pd.DataFrame([r.asDict(recursive=True) for r in chunk]) + table = pa.Table.from_pandas( + pdf, schema=arrow_schema, preserve_index=False + ) online_store.online_write_batch( config=repo_config, table=feature_view, data=_convert_arrow_to_proto( - sub_table, feature_view, join_key_to_value_type + table, feature_view, join_key_to_value_type ), progress=lambda x: None, ) @@ -304,24 +305,19 @@ def write_to_offline_store( ) -> None: """Write a Spark DataFrame to the offline store via foreachPartition. - Same Spark 3.5 serialiser workaround as ``write_to_online_store``. + Same Spark 3.5 serialiser workaround and chunked-iterator pattern as + ``write_to_online_store``. """ + from itertools import islice + from pyspark.sql.pandas.types import to_arrow_schema df_schema = spark_df.schema + _CHUNK = 5_000 def _write_partition(rows): # type: ignore[type-arg] - rows_list = list(rows) - if not rows_list: - return - import pyarrow as pa - pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list]) - table = pa.Table.from_pandas( - pdf, schema=to_arrow_schema(df_schema), preserve_index=False - ) - ( feature_view, _, @@ -329,12 +325,22 @@ def _write_partition(rows): # type: ignore[type-arg] repo_config, ) = serialized_artifacts.unserialize() - offline_store.offline_write_batch( - config=repo_config, - feature_view=feature_view, - table=table, - progress=lambda x: None, - ) + arrow_schema = to_arrow_schema(df_schema) + + while True: + chunk = list(islice(rows, _CHUNK)) + if not chunk: + break + pdf = pd.DataFrame([r.asDict(recursive=True) for r in chunk]) + table = pa.Table.from_pandas( + pdf, schema=arrow_schema, preserve_index=False + ) + offline_store.offline_write_batch( + config=repo_config, + feature_view=feature_view, + table=table, + progress=lambda x: None, + ) spark_df.foreachPartition(_write_partition)