@@ -125,7 +125,7 @@ def get_historical_features(
125125 config : RepoConfig ,
126126 feature_views : List [FeatureView ],
127127 feature_refs : List [str ],
128- entity_df : Union [pandas .DataFrame , str ],
128+ entity_df : Union [pandas .DataFrame , str , pyspark . sql . DataFrame ],
129129 registry : Registry ,
130130 project : str ,
131131 full_feature_names : bool = False ,
@@ -473,15 +473,16 @@ def _get_entity_df_event_timestamp_range(
473473 entity_df_event_timestamp .min ().to_pydatetime (),
474474 entity_df_event_timestamp .max ().to_pydatetime (),
475475 )
476- elif isinstance (entity_df , str ):
476+ elif isinstance (entity_df , str ) or isinstance ( entity_df , pyspark . sql . DataFrame ) :
477477 # If the entity_df is a string (SQL query), determine range
478478 # from table
479- df = spark_session .sql (entity_df ).select (entity_df_event_timestamp_col )
480-
481- # Checks if executing entity sql resulted in any data
482- if df .rdd .isEmpty ():
483- raise EntitySQLEmptyResults (entity_df )
484-
479+ if isinstance (entity_df , str ):
480+ df = spark_session .sql (entity_df ).select (entity_df_event_timestamp_col )
481+ # Checks if executing entity sql resulted in any data
482+ if df .rdd .isEmpty ():
483+ raise EntitySQLEmptyResults (entity_df )
484+ else :
485+ df = entity_df
485486 # TODO(kzhang132): need utc conversion here.
486487
487488 entity_df_event_timestamp_range = (
@@ -499,8 +500,11 @@ def _get_entity_schema(
499500) -> Dict [str , np .dtype ]:
500501 if isinstance (entity_df , pd .DataFrame ):
501502 return dict (zip (entity_df .columns , entity_df .dtypes ))
502- elif isinstance (entity_df , str ):
503- entity_spark_df = spark_session .sql (entity_df )
503+ elif isinstance (entity_df , str ) or isinstance (entity_df , pyspark .sql .DataFrame ):
504+ if isinstance (entity_df , str ):
505+ entity_spark_df = spark_session .sql (entity_df )
506+ else :
507+ entity_spark_df = entity_df
504508 return dict (
505509 zip (
506510 entity_spark_df .columns ,
@@ -525,6 +529,8 @@ def _upload_entity_df(
525529 return
526530 elif isinstance (entity_df , str ):
527531 spark_session .sql (entity_df ).createOrReplaceTempView (table_name )
532+ elif isinstance (entity_df , pyspark .sql .DataFrame ):
533+ entity_df .createOrReplaceTempView (table_name )
528534 return
529535 else :
530536 raise InvalidEntityType (type (entity_df ))
0 commit comments