Skip to content

Commit 346cc97

Browse files
ElliotNguyen68 authored and tqtensor committed
feat: Add Entity df in format of a Spark Dataframe instead of just pd.DataFrame or string for SparkOfflineStore (feast-dev#3988)
* remove unused parameter when init sparksource

  Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>

* feat: add entity df to SparkOfflineStore when get_historical_features

  Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>

* fix: lint error

  Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>

---------

Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>
Co-authored-by: tanlocnguyen <tanlocnguyen296@gmail.com>
1 parent 1b2726d commit 346cc97

1 file changed

Lines changed: 17 additions & 10 deletions

File tree

  • sdk/python/feast/infra/offline_stores/contrib/spark_offline_store

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ def get_historical_features(
125125
config: RepoConfig,
126126
feature_views: List[FeatureView],
127127
feature_refs: List[str],
128-
entity_df: Union[pandas.DataFrame, str],
128+
entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
129129
registry: Registry,
130130
project: str,
131131
full_feature_names: bool = False,
@@ -473,15 +473,16 @@ def _get_entity_df_event_timestamp_range(
473473
entity_df_event_timestamp.min().to_pydatetime(),
474474
entity_df_event_timestamp.max().to_pydatetime(),
475475
)
476-
elif isinstance(entity_df, str):
476+
elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
477477
# If the entity_df is a string (SQL query), determine range
478478
# from table
479-
df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
480-
481-
# Checks if executing entity sql resulted in any data
482-
if df.rdd.isEmpty():
483-
raise EntitySQLEmptyResults(entity_df)
484-
479+
if isinstance(entity_df, str):
480+
df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
481+
# Checks if executing entity sql resulted in any data
482+
if df.rdd.isEmpty():
483+
raise EntitySQLEmptyResults(entity_df)
484+
else:
485+
df = entity_df
485486
# TODO(kzhang132): need utc conversion here.
486487

487488
entity_df_event_timestamp_range = (
@@ -499,8 +500,11 @@ def _get_entity_schema(
499500
) -> Dict[str, np.dtype]:
500501
if isinstance(entity_df, pd.DataFrame):
501502
return dict(zip(entity_df.columns, entity_df.dtypes))
502-
elif isinstance(entity_df, str):
503-
entity_spark_df = spark_session.sql(entity_df)
503+
elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
504+
if isinstance(entity_df, str):
505+
entity_spark_df = spark_session.sql(entity_df)
506+
else:
507+
entity_spark_df = entity_df
504508
return dict(
505509
zip(
506510
entity_spark_df.columns,
@@ -526,6 +530,9 @@ def _upload_entity_df(
526530
elif isinstance(entity_df, str):
527531
spark_session.sql(entity_df).createOrReplaceTempView(table_name)
528532
return
533+
elif isinstance(entity_df, pyspark.sql.DataFrame):
534+
entity_df.createOrReplaceTempView(table_name)
535+
return
529536
else:
530537
raise InvalidEntityType(type(entity_df))
531538

0 commit comments

Comments (0)