Commit 2cee78e

fix(spark): replace mapInArrow with foreachPartition and fix session config forwarding for vector store materialization
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent 50ad181 commit 2cee78e

5 files changed

Lines changed: 315 additions & 110 deletions


sdk/python/feast/infra/compute_engines/spark/compute.py

Lines changed: 2 additions & 1 deletion
@@ -81,7 +81,8 @@ def teardown_infra(
     def _get_feature_view_spark_session(
         self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView]
     ) -> SparkSession:
-        spark_conf = self._get_feature_view_engine_config(feature_view)
+        config = self._get_feature_view_engine_config(feature_view)
+        spark_conf = config.get("spark_conf", config)
         return get_or_create_new_spark_session(spark_conf)

     def _materialize_one
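
Why the replacement line reads config.get("spark_conf", config): the per-feature-view engine config evidently can arrive either as a flat Spark conf dict or with the conf nested under a "spark_conf" key, and the old code passed the whole dict through unchanged. A minimal sketch of the normalisation (hypothetical helper name, not Feast API):

    # Sketch only: accepts both shapes the engine config may take.
    def resolve_spark_conf(config: dict) -> dict:
        # Nested form: {"spark_conf": {"spark.sql.shuffle.partitions": "8"}, ...}
        # Flat form:   {"spark.sql.shuffle.partitions": "8"}
        return config.get("spark_conf", config)

    assert resolve_spark_conf({"spark_conf": {"a": "1"}}) == {"a": "1"}
    assert resolve_spark_conf({"a": "1"}) == {"a": "1"}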

sdk/python/feast/infra/compute_engines/spark/nodes.py

Lines changed: 6 additions & 12 deletions
@@ -32,7 +32,10 @@
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.node import DAGNode
 from feast.infra.compute_engines.dag.value import DAGValue
-from feast.infra.compute_engines.spark.utils import map_in_arrow
+from feast.infra.compute_engines.spark.utils import (
+    write_to_offline_store,
+    write_to_online_store,
+)
 from feast.infra.compute_engines.utils import (
     create_offline_store_retrieval_job,
 )
@@ -572,21 +575,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
             feature_view=self.feature_view, repo_config=context.repo_config
         )

-        # ✅ 1. Write to online store if online enabled
         if self.feature_view.online:
-            spark_df.mapInArrow(
-                lambda x: map_in_arrow(x, serialized_artifacts, mode="online"),
-                spark_df.schema,
-            ).count()
+            write_to_online_store(spark_df, serialized_artifacts)

-        # ✅ 2. Write to offline store if offline enabled
         if self.feature_view.offline:
             if not isinstance(self.feature_view.batch_source, SparkSource):
-                spark_df.mapInArrow(
-                    lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"),
-                    spark_df.schema,
-                ).count()
-                # Directly write spark df to spark offline store without using mapInArrow
+                write_to_offline_store(spark_df, serialized_artifacts)
             else:
                 dest_path = self.feature_view.batch_source.path
                 file_format = self.feature_view.batch_source.file_format
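
A behavioural note on the swap: mapInArrow is a lazy transformation, which is why the old code appended a throwaway .count() to force the write, while foreachPartition is an action and executes directly, never re-entering the Arrow UDF serialiser path that breaks under Spark 3.5. A runnable sketch of the two patterns (generic write_fn/write_partition stand-ins, not the Feast helpers):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.range(4)

    def write_fn(batches):  # mapInArrow-style: iterator of pyarrow.RecordBatch
        for batch in batches:
            yield batch  # a real implementation would write here

    def write_partition(rows):  # foreachPartition-style: iterator of Rows
        for _row in rows:
            pass  # a real implementation would write here

    # Old pattern: lazy transformation plus a dummy action to trigger execution.
    df.mapInArrow(write_fn, df.schema).count()

    # New pattern: foreachPartition is itself an action.
    df.foreachPartition(write_partition)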

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 174 additions & 16 deletions
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Dict, Iterable, Literal, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Literal, Optional

 import pandas as pd
 import pyarrow
@@ -18,21 +18,16 @@
     boto3 = None  # type: ignore[assignment]
     BotoConfig = None  # type: ignore[assignment,misc]

+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+
 logger = logging.getLogger(__name__)


 def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
-    """Pre-create the S3A event log prefix before SparkContext initialisation.
-
-    Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called inside
-    SparkContext.__init__ and crashes if the S3A path doesn't exist yet (S3 has no
-    real directories, so an empty prefix returns a 404). This function writes a
-    zero-byte placeholder so the prefix exists before SparkContext is built.
+    """Pre-create an S3A event-log prefix so SparkContext.__init__ doesn't 404.

-    This is only attempted when:
-    - spark.eventLog.enabled == "true"
-    - spark.eventLog.dir starts with "s3a://"
-    Failures are non-fatal: Spark will surface its own error if the dir is still missing.
+    Only acts when eventLog is enabled with an s3a:// path. Non-fatal on failure.
     """
     if spark_config.get("spark.eventLog.enabled", "false").lower() != "true":
         return
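
The trimmed docstring still rests on the zero-byte placeholder trick spelled out in the old text: S3 has no real directories, so an empty object under the event-log prefix is enough for the s3a:// existence check to pass. A hedged sketch of that idea (bucket and prefix are illustrative; the function body is not part of this diff):

    import boto3

    # Illustrative values; the real prefix comes from spark.eventLog.dir.
    bucket, prefix = "my-spark-bucket", "spark-events/"

    # PUT a zero-byte object at the prefix so the "directory" exists before
    # SparkContext validates spark.eventLog.dir.
    boto3.client("s3").put_object(Bucket=bucket, Key=prefix, Body=b"")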
@@ -121,6 +116,18 @@ def get_or_create_new_spark_session(
         )

     spark_session = spark_builder.getOrCreate()
+
+    # getOrCreate() silently drops new configs on a reused session.
+    # Re-apply spark.sql.* and spark.hadoop.* which are safe to set post-creation.
+    if spark_config:
+        _RUNTIME_PREFIXES = ("spark.sql.", "spark.hadoop.")
+        for k, v in spark_config.items():
+            if any(k.startswith(p) for p in _RUNTIME_PREFIXES):
+                try:
+                    spark_session.conf.set(k, v)
+                except Exception as e:
+                    logger.debug("Could not set runtime config %s: %s", k, e)
+
     spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
     return spark_session

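The comment in the hunk names the root cause: getOrCreate() on an already-running session reuses it, and configs set on the new builder may never reach it (static configs such as executor memory certainly cannot change post-creation). Runtime spark.sql.* options, and spark.hadoop.* options forwarded to Hadoop, remain settable via conf.set, which is what the loop re-applies. A small sketch of the failure mode, hedged on Spark-version behaviour:

    from pyspark.sql import SparkSession

    first = SparkSession.builder.master("local[1]").getOrCreate()

    # Reuse: a static config on a second builder cannot take effect any more,
    # and depending on the Spark version builder options may be dropped entirely.
    second = SparkSession.builder.config("spark.executor.memory", "2g").getOrCreate()
    assert second is first

    # Runtime SQL configs still apply after creation, hence the re-apply loop.
    second.conf.set("spark.sql.shuffle.partitions", "7")
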
@@ -146,7 +153,9 @@ def map_in_arrow(
         for entity in feature_view.entity_columns
     }

-    batch_size = repo_config.materialization_config.online_write_batch_size
+    batch_size = getattr(
+        repo_config.materialization_config, "online_write_batch_size", None
+    )
     # Single batch if None (backward compatible), otherwise use configured batch_size
     sub_batches = (
         [table]
@@ -202,7 +211,9 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts):
         for entity in feature_view.entity_columns
     }

-    batch_size = repo_config.materialization_config.online_write_batch_size
+    batch_size = getattr(
+        repo_config.materialization_config, "online_write_batch_size", None
+    )
     # Single batch if None (backward compatible), otherwise use configured batch_size
     sub_batches = (
         [table]
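
For reference, the batching arithmetic behind sub_batches (and sub_tables in write_to_online_store below) chunks the Arrow table into fixed-size, zero-copy slices; a standalone check of the slice bounds:

    import pyarrow as pa

    table = pa.table({"x": list(range(10))})
    batch_size = 4
    sub_tables = [
        table.slice(offset, min(batch_size, len(table) - offset))
        for offset in range(0, len(table), batch_size)
    ]
    # 10 rows with batch_size 4 yield chunks of 4, 4, and 2 rows.
    assert [len(t) for t in sub_tables] == [4, 4, 2]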
@@ -220,6 +231,153 @@
             lambda x: None,
         )

-    yield pd.DataFrame(
-        [pd.Series(range(1, 2))]
-    )  # dummy result because mapInPandas needs to return something
+    yield pd.DataFrame({"status": [0]})
+
+
+def write_to_online_store(
+    spark_df: "DataFrame",
+    serialized_artifacts: SerializedArtifacts,
+) -> None:
+    """Write a Spark DataFrame to the online store via foreachPartition.
+
+    Uses foreachPartition instead of mapInArrow to avoid a Spark 3.5
+    serialiser mismatch (ArrowStreamPandasUDFSerializer vs ArrowStreamUDFSerializer)
+    when WindowGroupLimitExec precedes MapInArrowExec.
+    """
+    from pyspark.sql.pandas.types import to_arrow_schema
+
+    df_schema = spark_df.schema
+
+    def _write_partition(rows):  # type: ignore[type-arg]
+        rows_list = list(rows)
+        if not rows_list:
+            return
+
+        import pyarrow as pa
+
+        from feast.utils import _convert_arrow_to_proto
+
+        pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list])
+        table = pa.Table.from_pandas(
+            pdf, schema=to_arrow_schema(df_schema), preserve_index=False
+        )
+
+        (
+            feature_view,
+            online_store,
+            _,
+            repo_config,
+        ) = serialized_artifacts.unserialize()
+
+        join_key_to_value_type = {
+            entity.name: entity.dtype.to_value_type()
+            for entity in feature_view.entity_columns
+        }
+
+        batch_size = getattr(
+            repo_config.materialization_config, "online_write_batch_size", None
+        )
+        if batch_size is None:
+            sub_tables = [table]
+        else:
+            sub_tables = [
+                table.slice(offset, min(batch_size, len(table) - offset))
+                for offset in range(0, len(table), batch_size)
+            ]
+
+        for sub_table in sub_tables:
+            online_store.online_write_batch(
+                config=repo_config,
+                table=feature_view,
+                data=_convert_arrow_to_proto(
+                    sub_table, feature_view, join_key_to_value_type
+                ),
+                progress=lambda x: None,
+            )
+
+    spark_df.foreachPartition(_write_partition)
+
+
+def write_to_offline_store(
+    spark_df: "DataFrame",
+    serialized_artifacts: SerializedArtifacts,
+) -> None:
+    """Write a Spark DataFrame to the offline store via foreachPartition.
+
+    Same Spark 3.5 serialiser workaround as ``write_to_online_store``.
+    """
+    from pyspark.sql.pandas.types import to_arrow_schema
+
+    df_schema = spark_df.schema
+
+    def _write_partition(rows):  # type: ignore[type-arg]
+        rows_list = list(rows)
+        if not rows_list:
+            return
+
+        import pyarrow as pa
+
+        pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list])
+        table = pa.Table.from_pandas(
+            pdf, schema=to_arrow_schema(df_schema), preserve_index=False
+        )
+
+        (
+            feature_view,
+            _,
+            offline_store,
+            repo_config,
+        ) = serialized_artifacts.unserialize()
+
+        offline_store.offline_write_batch(
+            config=repo_config,
+            feature_view=feature_view,
+            table=table,
+            progress=lambda x: None,
+        )
+
+    spark_df.foreachPartition(_write_partition)
+
+
+_FEAST_EMBED_MODEL_CACHE: Dict[tuple, object] = {}
+
+
+def spark_embed(
+    df: "DataFrame",
+    text_col: str,
+    model: str = "sentence-transformers/all-MiniLM-L6-v2",
+    output_col: str = "embedding",
+    batch_size: int = 64,
+) -> "DataFrame":
+    """Append an embedding column to *df* using a sentence-transformer.
+
+    Intended for ``@batch_feature_view`` with ``TransformationMode.PYTHON``.
+    Uses ``localCheckpoint(eager=True)`` to sever Python lineage and avoid
+    downstream Arrow serialiser mismatches. Model is cached per executor.
+    """
+    import pyspark.sql.functions as F
+    import pyspark.sql.types as T
+    from pyspark.sql.functions import pandas_udf
+
+    model_id = model
+    bs = batch_size
+    _cache = _FEAST_EMBED_MODEL_CACHE
+
+    @pandas_udf(T.ArrayType(T.FloatType()))
+    def _embed_udf(texts: pd.Series) -> pd.Series:
+        import torch
+        from sentence_transformers import SentenceTransformer
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        cache_key = (model_id, device)
+        if cache_key not in _cache:
+            _cache[cache_key] = SentenceTransformer(model_id, device=device)
+        sent_model = _cache[cache_key]
+
+        embeddings = sent_model.encode(
+            texts.tolist(), batch_size=bs, show_progress_bar=False
+        )
+        return pd.Series([e.astype("float32").tolist() for e in embeddings])
+
+    embedded = df.withColumn(output_col, _embed_udf(F.col(text_col)))
+    return embedded.localCheckpoint(eager=True)
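
A usage sketch for the new helper; the transformation wiring and column names below are illustrative, not taken from this commit:

    # Hypothetical PYTHON-mode batch feature view transformation; "chunk_text"
    # is a made-up input column name.
    from feast.infra.compute_engines.spark.utils import spark_embed

    def embed_documents(df):  # df: pyspark.sql.DataFrame
        # Appends an "embedding" array<float> column; the eager localCheckpoint
        # inside spark_embed severs lineage before any downstream write.
        return spark_embed(df, text_col="chunk_text", batch_size=32)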

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 0 additions & 61 deletions
@@ -33,7 +33,6 @@
 from pyspark.sql import SparkSession

 from feast import FeatureView, OnDemandFeatureView
-from feast.batch_feature_view import BatchFeatureView
 from feast.data_source import DataSource
 from feast.dataframe import DataFrameEngine, FeastDataFrame
 from feast.errors import EntitySQLEmptyResults, InvalidEntityType
@@ -261,10 +260,6 @@ def get_historical_features(
             entity_df_event_timestamp_range,
         )

-        query_context = _apply_bfv_transformations(
-            spark_session, feature_views, query_context
-        )
-
         spark_query_context = [
             SparkFeatureViewQueryContext(
                 **asdict(context),
@@ -718,62 +713,6 @@ def _entity_schema_keys_from(
     )


-def _apply_bfv_transformations(
-    spark_session: SparkSession,
-    feature_views: List[FeatureView],
-    query_contexts: List[offline_utils.FeatureViewQueryContext],
-) -> List[offline_utils.FeatureViewQueryContext]:
-    """
-    For BatchFeatureViews with a UDF, read the raw source into a Spark DataFrame,
-    invoke the transformation, register the result as a temp view, and replace the
-    table_subquery in the query context so the PIT join reads transformed data.
-    """
-    from dataclasses import replace
-
-    from feast.feature_view_utils import (
-        get_transformation_function,
-        has_transformation,
-        resolve_feature_view_source_with_fallback,
-    )
-
-    fv_by_name = {fv.projection.name_to_use(): fv for fv in feature_views}
-
-    updated_contexts = []
-    for ctx in query_contexts:
-        fv = fv_by_name.get(ctx.name)
-        if (
-            fv is not None
-            and isinstance(fv, BatchFeatureView)
-            and has_transformation(fv)
-        ):
-            udf = get_transformation_function(fv)
-            if udf is not None:
-                source_info = resolve_feature_view_source_with_fallback(fv)
-                source_query = source_info.data_source.get_table_query_string()
-
-                timestamp_filter = get_timestamp_filter_sql(
-                    start_date=ctx.min_event_timestamp,
-                    end_date=ctx.max_event_timestamp,
-                    timestamp_field=ctx.timestamp_field,
-                    tz=timezone.utc,
-                    quote_fields=False,
-                )
-                source_df = spark_session.sql(
-                    f"SELECT * FROM {source_query} WHERE {timestamp_filter}"
-                )
-
-                transformed_df = udf(source_df)
-
-                tmp_view_name = "feast_bfv_" + uuid.uuid4().hex
-                transformed_df.createOrReplaceTempView(tmp_view_name)
-
-                ctx = replace(ctx, table_subquery=tmp_view_name)
-
-        updated_contexts.append(ctx)
-
-    return updated_contexts
-
-
 def _get_entity_df_event_timestamp_range(
     entity_df: Union[pd.DataFrame, str],
     entity_df_event_timestamp_col: str,
