 import logging
 import os
-from typing import Dict, Iterable, Literal, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Literal, Optional
 
 import pandas as pd
 import pyarrow
     boto3 = None  # type: ignore[assignment]
     BotoConfig = None  # type: ignore[assignment,misc]
 
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+
 logger = logging.getLogger(__name__)
 
 
 def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
-    """Pre-create the S3A event log prefix before SparkContext initialisation.
-
-    Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called inside
-    SparkContext.__init__ and crashes if the S3A path doesn't exist yet (S3 has no
-    real directories, so an empty prefix returns a 404). This function writes a
-    zero-byte placeholder so the prefix exists before SparkContext is built.
+    """Pre-create an S3A event-log prefix so SparkContext.__init__ doesn't 404.
 
-    This is only attempted when:
-    - spark.eventLog.enabled == "true"
-    - spark.eventLog.dir starts with "s3a://"
-    Failures are non-fatal: Spark will surface its own error if the dir is still missing.
+    Only acts when eventLog is enabled with an s3a:// path. Non-fatal on failure.
     """
     if spark_config.get("spark.eventLog.enabled", "false").lower() != "true":
         return
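For orientation, a minimal sketch of a spark_config that would take the code path above; the bucket and prefix are illustrative placeholders, not values taken from this change:

spark_config = {
    "spark.eventLog.enabled": "true",
    "spark.eventLog.dir": "s3a://example-bucket/spark-events",  # placeholder location
}
# Writes a zero-byte placeholder object so the prefix exists before
# SparkContext is built; non-s3a:// eventLog dirs are ignored.
_ensure_s3a_event_log_dir(spark_config)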
@@ -121,6 +116,18 @@ def get_or_create_new_spark_session(
             )
 
         spark_session = spark_builder.getOrCreate()
+
+    # getOrCreate() silently drops new configs on a reused session.
+    # Re-apply spark.sql.* and spark.hadoop.* which are safe to set post-creation.
+    if spark_config:
+        _RUNTIME_PREFIXES = ("spark.sql.", "spark.hadoop.")
+        for k, v in spark_config.items():
+            if any(k.startswith(p) for p in _RUNTIME_PREFIXES):
+                try:
+                    spark_session.conf.set(k, v)
+                except Exception as e:
+                    logger.debug("Could not set runtime config %s: %s", k, e)
+
     spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
     return spark_session
 
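A usage sketch of the re-apply block above (config values are illustrative): in Spark 3.x only runtime settings such as spark.sql.* can be changed on a live session, which is why the loop filters on those prefixes.

spark = get_or_create_new_spark_session(
    {
        "spark.sql.shuffle.partitions": "64",  # runtime config: re-applied even on a reused session
        "spark.executor.memory": "4g",  # static config: skipped by the prefix filter
    }
)
assert spark.conf.get("spark.sql.shuffle.partitions") == "64"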
@@ -146,7 +153,9 @@ def map_in_arrow(
             for entity in feature_view.entity_columns
         }
 
-        batch_size = repo_config.materialization_config.online_write_batch_size
+        batch_size = getattr(
+            repo_config.materialization_config, "online_write_batch_size", None
+        )
         # Single batch if None (backward compatible), otherwise use configured batch_size
         sub_batches = (
             [table]
@@ -202,7 +211,9 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts):
             for entity in feature_view.entity_columns
         }
 
-        batch_size = repo_config.materialization_config.online_write_batch_size
+        batch_size = getattr(
+            repo_config.materialization_config, "online_write_batch_size", None
+        )
         # Single batch if None (backward compatible), otherwise use configured batch_size
         sub_batches = (
             [table]
@@ -220,6 +231,153 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts):
                 lambda x: None,
             )
 
-    yield pd.DataFrame(
-        [pd.Series(range(1, 2))]
-    )  # dummy result because mapInPandas needs to return something
+    yield pd.DataFrame({"status": [0]})
+
+
+def write_to_online_store(
+    spark_df: "DataFrame",
+    serialized_artifacts: SerializedArtifacts,
+) -> None:
+    """Write a Spark DataFrame to the online store via foreachPartition.
+
+    Uses foreachPartition instead of mapInArrow to avoid a Spark 3.5
+    serialiser mismatch (ArrowStreamPandasUDFSerializer vs ArrowStreamUDFSerializer)
+    when WindowGroupLimitExec precedes MapInArrowExec.
+    """
+    from pyspark.sql.pandas.types import to_arrow_schema
+
+    df_schema = spark_df.schema
+
+    def _write_partition(rows):  # type: ignore[type-arg]
+        rows_list = list(rows)
+        if not rows_list:
+            return
+
+        import pyarrow as pa
+
+        from feast.utils import _convert_arrow_to_proto
+
+        pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list])
+        table = pa.Table.from_pandas(
+            pdf, schema=to_arrow_schema(df_schema), preserve_index=False
+        )
+
+        (
+            feature_view,
+            online_store,
+            _,
+            repo_config,
+        ) = serialized_artifacts.unserialize()
+
+        join_key_to_value_type = {
+            entity.name: entity.dtype.to_value_type()
+            for entity in feature_view.entity_columns
+        }
+
+        batch_size = getattr(
+            repo_config.materialization_config, "online_write_batch_size", None
+        )
+        if batch_size is None:
+            sub_tables = [table]
+        else:
+            sub_tables = [
+                table.slice(offset, min(batch_size, len(table) - offset))
+                for offset in range(0, len(table), batch_size)
+            ]
+
+        for sub_table in sub_tables:
+            online_store.online_write_batch(
+                config=repo_config,
+                table=feature_view,
+                data=_convert_arrow_to_proto(
+                    sub_table, feature_view, join_key_to_value_type
+                ),
+                progress=lambda x: None,
+            )
+
+    spark_df.foreachPartition(_write_partition)
+
+
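To make the sub-batching arithmetic above concrete, a standalone sketch of the same slice expression (table contents are made up):

import pyarrow as pa

table = pa.table({"driver_id": list(range(10))})  # 10 rows, illustrative
batch_size = 4
sub_tables = [
    table.slice(offset, min(batch_size, len(table) - offset))
    for offset in range(0, len(table), batch_size)
]
print([len(t) for t in sub_tables])  # [4, 4, 2]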
+def write_to_offline_store(
+    spark_df: "DataFrame",
+    serialized_artifacts: SerializedArtifacts,
+) -> None:
+    """Write a Spark DataFrame to the offline store via foreachPartition.
+
+    Same Spark 3.5 serialiser workaround as ``write_to_online_store``.
+    """
+    from pyspark.sql.pandas.types import to_arrow_schema
+
+    df_schema = spark_df.schema
+
+    def _write_partition(rows):  # type: ignore[type-arg]
+        rows_list = list(rows)
+        if not rows_list:
+            return
+
+        import pyarrow as pa
+
+        pdf = pd.DataFrame([r.asDict(recursive=True) for r in rows_list])
+        table = pa.Table.from_pandas(
+            pdf, schema=to_arrow_schema(df_schema), preserve_index=False
+        )
+
+        (
+            feature_view,
+            _,
+            offline_store,
+            repo_config,
+        ) = serialized_artifacts.unserialize()
+
+        offline_store.offline_write_batch(
+            config=repo_config,
+            feature_view=feature_view,
+            table=table,
+            progress=lambda x: None,
+        )
+
+    spark_df.foreachPartition(_write_partition)
+
+
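A hedged end-to-end sketch of how the two writers might be driven; it assumes SerializedArtifacts exposes a serialize() counterpart to the unserialize() used above, and that feature_view, repo_config, and a Spark DataFrame df already exist:

artifacts = SerializedArtifacts.serialize(  # assumed constructor; not shown in this diff
    feature_view=feature_view, repo_config=repo_config
)
write_to_online_store(df, artifacts)
write_to_offline_store(df, artifacts)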
+_FEAST_EMBED_MODEL_CACHE: Dict[tuple, object] = {}
+
+
+def spark_embed(
+    df: "DataFrame",
+    text_col: str,
+    model: str = "sentence-transformers/all-MiniLM-L6-v2",
+    output_col: str = "embedding",
+    batch_size: int = 64,
+) -> "DataFrame":
+    """Append an embedding column to *df* using a sentence-transformer.
+
+    Intended for ``@batch_feature_view`` with ``TransformationMode.PYTHON``.
+    Uses ``localCheckpoint(eager=True)`` to sever Python lineage and avoid
+    downstream Arrow serialiser mismatches. Model is cached per executor.
+    """
+    import pyspark.sql.functions as F
+    import pyspark.sql.types as T
+    from pyspark.sql.functions import pandas_udf
+
+    model_id = model
+    bs = batch_size
+    _cache = _FEAST_EMBED_MODEL_CACHE
+
+    @pandas_udf(T.ArrayType(T.FloatType()))
+    def _embed_udf(texts: pd.Series) -> pd.Series:
+        import torch
+        from sentence_transformers import SentenceTransformer
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        cache_key = (model_id, device)
+        if cache_key not in _cache:
+            _cache[cache_key] = SentenceTransformer(model_id, device=device)
+        sent_model = _cache[cache_key]
+
+        embeddings = sent_model.encode(
+            texts.tolist(), batch_size=bs, show_progress_bar=False
+        )
+        return pd.Series([e.astype("float32").tolist() for e in embeddings])
+
+    embedded = df.withColumn(output_col, _embed_udf(F.col(text_col)))
+    return embedded.localCheckpoint(eager=True)
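Finally, an illustrative call to spark_embed; the DataFrame and column names are placeholders, and sentence-transformers plus torch must be importable on every executor:

docs = spark.createDataFrame(
    [(1, "driver was on time"), (2, "trip was cancelled")],
    ["doc_id", "text"],
)
embedded = spark_embed(docs, text_col="text", output_col="embedding", batch_size=32)
embedded.select("doc_id", "embedding").show(truncate=40)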