Skip to content

Commit c613b7e

Browse files
feat: Enable Arrow-based columnar data transfers when converting to pandas in the Spark offline store retrieval job
Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>
1 parent e6fc736 commit c613b7e

File tree

1 file changed

+9
-4
lines changed
  • sdk/python/feast/infra/offline_stores/contrib/spark_offline_store

1 file changed

+9
-4
lines changed

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ def to_spark_df(self) -> pyspark.sql.DataFrame:
338338

339339
def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
    """Return the retrieval-job dataset as a pandas DataFrame, synchronously.

    Enables Arrow-based columnar transfer on the Spark session before
    calling ``toPandas()``, which avoids the slow row-by-row pickle path
    when collecting results to the driver.

    Args:
        timeout: Accepted for interface compatibility with other offline
            stores; not used by the Spark implementation.
            NOTE(review): confirm callers do not rely on a timeout here.

    Returns:
        The job's result, collected to the driver as a pandas DataFrame.
    """
    # Resolve (or start) the Spark session configured for this repo.
    spark_session = get_spark_session_or_start_new_with_repoconfig(
        self._config.offline_store
    )
    # NOTE(review): this flips a session-wide setting and does not restore
    # the previous value; any other code sharing this SparkSession will
    # also run with Arrow transfers enabled. If Arrow conversion fails,
    # Spark may still fall back to the non-Arrow path depending on
    # spark.sql.execution.arrow.pyspark.fallback.enabled — TODO confirm.
    spark_session.conf.set(
        "spark.sql.execution.arrow.pyspark.enabled", "true"
    )
    return self.to_spark_df().toPandas()
342346

343347
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
@@ -442,7 +446,7 @@ def metadata(self) -> Optional[RetrievalMetadata]:
442446
def get_spark_session_or_start_new_with_repoconfig(
443447
store_config: SparkOfflineStoreConfig,
444448
) -> SparkSession:
445-
spark_session = SparkSession.getActiveSession()
449+
spark_session = SparkSession.builder.getOrCreate()
446450
if not spark_session:
447451
spark_builder = SparkSession.builder
448452
spark_conf = store_config.spark_conf
@@ -457,7 +461,7 @@ def get_spark_session_or_start_new_with_repoconfig(
457461

458462

459463
def _get_entity_df_event_timestamp_range(
460-
entity_df: Union[pd.DataFrame, str],
464+
entity_df: Union[pd.DataFrame, str, pyspark.sql.DataFrame],
461465
entity_df_event_timestamp_col: str,
462466
spark_session: SparkSession,
463467
) -> Tuple[datetime, datetime]:
@@ -496,7 +500,8 @@ def _get_entity_df_event_timestamp_range(
496500

497501

498502
def _get_entity_schema(
499-
spark_session: SparkSession, entity_df: Union[pandas.DataFrame, str]
503+
spark_session: SparkSession,
504+
entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
500505
) -> Dict[str, np.dtype]:
501506
if isinstance(entity_df, pd.DataFrame):
502507
return dict(zip(entity_df.columns, entity_df.dtypes))
@@ -518,7 +523,7 @@ def _get_entity_schema(
518523
def _upload_entity_df(
519524
spark_session: SparkSession,
520525
table_name: str,
521-
entity_df: Union[pandas.DataFrame, str],
526+
entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
522527
event_timestamp_col: str,
523528
) -> None:
524529
if isinstance(entity_df, pd.DataFrame):

0 commit comments

Comments
 (0)