3 changes: 2 additions & 1 deletion sdk/python/feast/infra/compute_engines/spark/compute.py
@@ -81,7 +81,8 @@ def teardown_infra(
def _get_feature_view_spark_session(
self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView]
) -> SparkSession:
spark_conf = self._get_feature_view_engine_config(feature_view)
config = self._get_feature_view_engine_config(feature_view)
spark_conf = config.get("spark_conf", config)
return get_or_create_new_spark_session(spark_conf)

def _materialize_one(
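Reviewer note: the new lookup tolerates an engine config that is either a flat mapping of Spark settings or one that nests them under a "spark_conf" key. A minimal sketch of the two shapes (the setting name below is illustrative, not taken from this PR):

# Hypothetical configs; both resolve to the same spark_conf dict.
flat = {"spark.sql.shuffle.partitions": "8"}
nested = {"spark_conf": {"spark.sql.shuffle.partitions": "8"}}
for config in (flat, nested):
    spark_conf = config.get("spark_conf", config)
    assert spark_conf == {"spark.sql.shuffle.partitions": "8"}
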
18 changes: 6 additions & 12 deletions sdk/python/feast/infra/compute_engines/spark/nodes.py
@@ -32,7 +32,10 @@
from feast.infra.compute_engines.dag.model import DAGFormat
from feast.infra.compute_engines.dag.node import DAGNode
from feast.infra.compute_engines.dag.value import DAGValue
from feast.infra.compute_engines.spark.utils import map_in_arrow
from feast.infra.compute_engines.spark.utils import (
write_to_offline_store,
write_to_online_store,
)
from feast.infra.compute_engines.utils import (
create_offline_store_retrieval_job,
)
@@ -572,21 +575,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
feature_view=self.feature_view, repo_config=context.repo_config
)

# ✅ 1. Write to online store if online enabled
if self.feature_view.online:
spark_df.mapInArrow(
lambda x: map_in_arrow(x, serialized_artifacts, mode="online"),
spark_df.schema,
).count()
write_to_online_store(spark_df, serialized_artifacts)

# ✅ 2. Write to offline store if offline enabled
if self.feature_view.offline:
if not isinstance(self.feature_view.batch_source, SparkSource):
spark_df.mapInArrow(
lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"),
spark_df.schema,
).count()
# Directly write spark df to spark offline store without using mapInArrow
write_to_offline_store(spark_df, serialized_artifacts)
else:
dest_path = self.feature_view.batch_source.path
file_format = self.feature_view.batch_source.file_format
196 changes: 180 additions & 16 deletions sdk/python/feast/infra/compute_engines/spark/utils.py
@@ -1,6 +1,6 @@
import logging
import os
from typing import Dict, Iterable, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, Literal, Optional

import pandas as pd
import pyarrow
@@ -18,21 +18,16 @@
boto3 = None # type: ignore[assignment]
BotoConfig = None # type: ignore[assignment,misc]

if TYPE_CHECKING:
from pyspark.sql import DataFrame

logger = logging.getLogger(__name__)


def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
"""Pre-create the S3A event log prefix before SparkContext initialisation.

Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called inside
SparkContext.__init__ and crashes if the S3A path doesn't exist yet (S3 has no
real directories, so an empty prefix returns a 404). This function writes a
zero-byte placeholder so the prefix exists before SparkContext is built.
"""Pre-create an S3A event-log prefix so SparkContext.__init__ doesn't 404.

This is only attempted when:
- spark.eventLog.enabled == "true"
- spark.eventLog.dir starts with "s3a://"
Failures are non-fatal: Spark will surface its own error if the dir is still missing.
Only acts when eventLog is enabled with an s3a:// path. Non-fatal on failure.
"""
if spark_config.get("spark.eventLog.enabled", "false").lower() != "true":
return
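Reviewer note: a minimal sketch of the config combination that makes _ensure_s3a_event_log_dir do work, using a hypothetical bucket; anything else returns immediately, and failures are non-fatal per the docstring.

# Hypothetical bucket/prefix; only this combination triggers the placeholder write.
spark_config = {
    "spark.eventLog.enabled": "true",
    "spark.eventLog.dir": "s3a://example-bucket/spark-event-logs/",
}
_ensure_s3a_event_log_dir(spark_config)  # best-effort pre-creation of the prefix

# Disabled event logging or a non-s3a path is a no-op:
_ensure_s3a_event_log_dir({"spark.eventLog.enabled": "false"})
_ensure_s3a_event_log_dir(
    {"spark.eventLog.enabled": "true", "spark.eventLog.dir": "hdfs:///spark-logs"}
)
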
@@ -121,6 +116,18 @@ def get_or_create_new_spark_session(
)

spark_session = spark_builder.getOrCreate()

# getOrCreate() silently drops new configs on a reused session.
# Re-apply spark.sql.* and spark.hadoop.* which are safe to set post-creation.
if spark_config:
_RUNTIME_PREFIXES = ("spark.sql.", "spark.hadoop.")
for k, v in spark_config.items():
if any(k.startswith(p) for p in _RUNTIME_PREFIXES):
try:
spark_session.conf.set(k, v)
except Exception as e:
logger.debug("Could not set runtime config %s: %s", k, e)

spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
return spark_session

@@ -146,7 +153,9 @@ def map_in_arrow(
for entity in feature_view.entity_columns
}

batch_size = repo_config.materialization_config.online_write_batch_size
batch_size = getattr(
repo_config.materialization_config, "online_write_batch_size", None
)
# Single batch if None (backward compatible), otherwise use configured batch_size
sub_batches = (
[table]
@@ -202,7 +211,9 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts):
for entity in feature_view.entity_columns
}

batch_size = repo_config.materialization_config.online_write_batch_size
batch_size = getattr(
repo_config.materialization_config, "online_write_batch_size", None
)
# Single batch if None (backward compatible), otherwise use configured batch_size
sub_batches = (
[table]
@@ -220,6 +231,159 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts):
lambda x: None,
)

yield pd.DataFrame(
[pd.Series(range(1, 2))]
) # dummy result because mapInPandas needs to return something
yield pd.DataFrame({"status": [0]})


def write_to_online_store(
spark_df: "DataFrame",
serialized_artifacts: SerializedArtifacts,
) -> None:
"""Write a Spark DataFrame to the online store via foreachPartition.

Uses foreachPartition instead of mapInArrow to avoid a Spark 3.5
serialiser mismatch (ArrowStreamPandasUDFSerializer vs ArrowStreamUDFSerializer)
when WindowGroupLimitExec precedes MapInArrowExec.

Rows are consumed in fixed-size chunks (default 5 000) so that Python
memory stays bounded regardless of partition size. Previous behaviour
called ``list(rows)`` which materialised the entire partition and caused
``MemoryError`` / OOMKill on large feature views.
"""
from itertools import islice

from pyspark.sql.pandas.types import to_arrow_schema

df_schema = spark_df.schema
_CHUNK = 5_000

def _write_partition(rows): # type: ignore[type-arg]
import pyarrow as pa

from feast.utils import _convert_arrow_to_proto

(
feature_view,
online_store,
_,
repo_config,
) = serialized_artifacts.unserialize()

join_key_to_value_type = {
entity.name: entity.dtype.to_value_type()
for entity in feature_view.entity_columns
}
arrow_schema = to_arrow_schema(df_schema)

batch_size = getattr(
repo_config.materialization_config, "online_write_batch_size", None
)
write_size = batch_size or _CHUNK

while True:
chunk = list(islice(rows, write_size))
if not chunk:
break
pdf = pd.DataFrame([r.asDict(recursive=True) for r in chunk])
table = pa.Table.from_pandas(
pdf, schema=arrow_schema, preserve_index=False
)
online_store.online_write_batch(
config=repo_config,
table=feature_view,
data=_convert_arrow_to_proto(
table, feature_view, join_key_to_value_type
),
progress=lambda x: None,
)

spark_df.foreachPartition(_write_partition)


def write_to_offline_store(
spark_df: "DataFrame",
serialized_artifacts: SerializedArtifacts,
) -> None:
"""Write a Spark DataFrame to the offline store via foreachPartition.

Same Spark 3.5 serialiser workaround and chunked-iterator pattern as
``write_to_online_store``.
"""
from itertools import islice

from pyspark.sql.pandas.types import to_arrow_schema

df_schema = spark_df.schema
_CHUNK = 5_000

def _write_partition(rows): # type: ignore[type-arg]
import pyarrow as pa

(
feature_view,
_,
offline_store,
repo_config,
) = serialized_artifacts.unserialize()

arrow_schema = to_arrow_schema(df_schema)

while True:
chunk = list(islice(rows, _CHUNK))
if not chunk:
break
pdf = pd.DataFrame([r.asDict(recursive=True) for r in chunk])
table = pa.Table.from_pandas(
pdf, schema=arrow_schema, preserve_index=False
)
offline_store.offline_write_batch(
config=repo_config,
feature_view=feature_view,
table=table,
progress=lambda x: None,
)

spark_df.foreachPartition(_write_partition)


_FEAST_EMBED_MODEL_CACHE: Dict[tuple, "Any"] = {}


def spark_embed(
df: "DataFrame",
text_col: str,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
output_col: str = "embedding",
batch_size: int = 64,
) -> "DataFrame":
"""Append an embedding column to *df* using a sentence-transformer.

Intended for ``@batch_feature_view`` with ``TransformationMode.PYTHON``.
Uses ``localCheckpoint(eager=True)`` to sever Python lineage and avoid
downstream Arrow serialiser mismatches. Model is cached per executor.
"""
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import pandas_udf

model_id = model
bs = batch_size
_cache = _FEAST_EMBED_MODEL_CACHE

@pandas_udf(T.ArrayType(T.FloatType()))
def _embed_udf(texts: pd.Series) -> pd.Series:
import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"
cache_key = (model_id, device)
if cache_key not in _cache:
_cache[cache_key] = SentenceTransformer(model_id, device=device)
sent_model = _cache[cache_key]

embeddings = sent_model.encode(
texts.tolist(), batch_size=bs, show_progress_bar=False
)
return pd.Series([e.astype("float32").tolist() for e in embeddings])

embedded = df.withColumn(output_col, _embed_udf(F.col(text_col)))
return embedded.localCheckpoint(eager=True)
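Reviewer note: a hedged usage sketch for spark_embed inside a Python-mode batch feature view, matching the docstring's stated intent. The import paths, decorator arguments, and source/entity names below are assumptions for illustration only, not part of this PR.

from feast import batch_feature_view  # assumed import path
from feast.transformation.mode import TransformationMode  # assumed import path

from feast.infra.compute_engines.spark.utils import spark_embed

@batch_feature_view(
    sources=[documents_source],   # hypothetical SparkSource defined elsewhere
    entities=[document],          # hypothetical entity
    mode=TransformationMode.PYTHON,
)
def document_embeddings(df):
    # Appends an ArrayType(FloatType()) "embedding" column; the model runs on executors.
    return spark_embed(df, text_col="body", batch_size=32)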