Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Add Entity df in format of a Spark Dataframe instead of just pd.DataFrame or string for SparkOfflineStore (#3988)

* remove unused parameter when init sparksource

Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>

* feat: add entity df to SparkOfflineStore when get_historical_features

Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>

* fix: lint error

Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>

---------

Signed-off-by: tanlocnguyen <tanlocnguyen296@gmail.com>
Co-authored-by: tanlocnguyen <tanlocnguyen296@gmail.com>
  • Loading branch information
ElliotNguyen68 and ElliotNguyen68 authored Mar 6, 2024
commit 43b2c287705c2a3e882517524229f155c9ce0a01
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_historical_features(
config: RepoConfig,
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pandas.DataFrame, str],
entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
registry: Registry,
project: str,
full_feature_names: bool = False,
Expand Down Expand Up @@ -473,15 +473,16 @@ def _get_entity_df_event_timestamp_range(
entity_df_event_timestamp.min().to_pydatetime(),
entity_df_event_timestamp.max().to_pydatetime(),
)
elif isinstance(entity_df, str):
elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
# If the entity_df is a string (SQL query), determine range
# from table
df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)

# Checks if executing entity sql resulted in any data
if df.rdd.isEmpty():
raise EntitySQLEmptyResults(entity_df)

if isinstance(entity_df, str):
df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
# Checks if executing entity sql resulted in any data
if df.rdd.isEmpty():
raise EntitySQLEmptyResults(entity_df)
else:
df = entity_df
# TODO(kzhang132): need utc conversion here.

entity_df_event_timestamp_range = (
Expand All @@ -499,8 +500,11 @@ def _get_entity_schema(
) -> Dict[str, np.dtype]:
if isinstance(entity_df, pd.DataFrame):
return dict(zip(entity_df.columns, entity_df.dtypes))
elif isinstance(entity_df, str):
entity_spark_df = spark_session.sql(entity_df)
elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
if isinstance(entity_df, str):
entity_spark_df = spark_session.sql(entity_df)
else:
entity_spark_df = entity_df
return dict(
zip(
entity_spark_df.columns,
Expand All @@ -526,6 +530,9 @@ def _upload_entity_df(
elif isinstance(entity_df, str):
spark_session.sql(entity_df).createOrReplaceTempView(table_name)
return
elif isinstance(entity_df, pyspark.sql.DataFrame):
entity_df.createOrReplaceTempView(table_name)
return
else:
raise InvalidEntityType(type(entity_df))

Expand Down