Skip to content

Commit 1ef7de9

Browse files
feat: add entity df to SparkOfflineStore when get_historical_features
Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>
1 parent 8b7d6fb commit 1ef7de9

File tree

1 file changed

+16
-10
lines changed
  • sdk/python/feast/infra/offline_stores/contrib/spark_offline_store

1 file changed

+16
-10
lines changed

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 16 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -125,7 +125,7 @@ def get_historical_features(
125125
config: RepoConfig,
126126
feature_views: List[FeatureView],
127127
feature_refs: List[str],
128-
entity_df: Union[pandas.DataFrame, str],
128+
entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
129129
registry: Registry,
130130
project: str,
131131
full_feature_names: bool = False,
@@ -473,15 +473,16 @@ def _get_entity_df_event_timestamp_range(
473473
entity_df_event_timestamp.min().to_pydatetime(),
474474
entity_df_event_timestamp.max().to_pydatetime(),
475475
)
476-
elif isinstance(entity_df, str):
476+
elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
477477
# If the entity_df is a string (SQL query), determine range
478478
# from table
479-
df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
480-
481-
# Checks if executing entity sql resulted in any data
482-
if df.rdd.isEmpty():
483-
raise EntitySQLEmptyResults(entity_df)
484-
479+
if isinstance(entity_df, str):
480+
df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
481+
# Checks if executing entity sql resulted in any data
482+
if df.rdd.isEmpty():
483+
raise EntitySQLEmptyResults(entity_df)
484+
else:
485+
df = entity_df
485486
# TODO(kzhang132): need utc conversion here.
486487

487488
entity_df_event_timestamp_range = (
@@ -499,8 +500,11 @@ def _get_entity_schema(
499500
) -> Dict[str, np.dtype]:
500501
if isinstance(entity_df, pd.DataFrame):
501502
return dict(zip(entity_df.columns, entity_df.dtypes))
502-
elif isinstance(entity_df, str):
503-
entity_spark_df = spark_session.sql(entity_df)
503+
elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
504+
if isinstance(entity_df, str):
505+
entity_spark_df = spark_session.sql(entity_df)
506+
else:
507+
entity_spark_df = entity_df
504508
return dict(
505509
zip(
506510
entity_spark_df.columns,
@@ -525,6 +529,8 @@ def _upload_entity_df(
525529
return
526530
elif isinstance(entity_df, str):
527531
spark_session.sql(entity_df).createOrReplaceTempView(table_name)
532+
elif isinstance(entity_df, pyspark.sql.DataFrame):
533+
entity_df.createOrReplaceTempView(table_name)
528534
return
529535
else:
530536
raise InvalidEntityType(type(entity_df))

0 commit comments

Comments (0)