|
13 | 13 | import pyspark |
14 | 14 | from pydantic import StrictStr |
15 | 15 | from pyspark import SparkConf |
16 | | -from pyspark.sql import SparkSession, DataFrame as SparkDataFrame |
| 16 | +from pyspark.sql import SparkSession |
17 | 17 | from pytz import utc |
18 | 18 |
|
19 | 19 | from feast import FeatureView, OnDemandFeatureView |
@@ -125,7 +125,7 @@ def get_historical_features( |
125 | 125 | config: RepoConfig, |
126 | 126 | feature_views: List[FeatureView], |
127 | 127 | feature_refs: List[str], |
128 | | - entity_df: Union[pandas.DataFrame, str, SparkDataFrame], |
| 128 | + entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame], |
129 | 129 | registry: Registry, |
130 | 130 | project: str, |
131 | 131 | full_feature_names: bool = False, |
@@ -473,7 +473,7 @@ def _get_entity_df_event_timestamp_range( |
473 | 473 | entity_df_event_timestamp.min().to_pydatetime(), |
474 | 474 | entity_df_event_timestamp.max().to_pydatetime(), |
475 | 475 | ) |
476 | | - elif isinstance(entity_df, str) or isinstance(entity_df, SparkDataFrame): |
| 476 | + elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame): |
477 | 477 | # If the entity_df is a string (SQL query), determine range |
478 | 478 | # from table |
479 | 479 | if isinstance(entity_df, str): |
@@ -501,7 +501,7 @@ def _get_entity_schema( |
501 | 501 | ) -> Dict[str, np.dtype]: |
502 | 502 | if isinstance(entity_df, pd.DataFrame): |
503 | 503 | return dict(zip(entity_df.columns, entity_df.dtypes)) |
504 | | - elif isinstance(entity_df, str) or isinstance(entity_df,SparkDataFrame): |
| 504 | +elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame): |
505 | 505 | if isinstance(entity_df, str): |
506 | 506 | entity_spark_df = spark_session.sql(entity_df) |
507 | 507 | else: |
@@ -530,7 +530,7 @@ def _upload_entity_df( |
530 | 530 | return |
531 | 531 | elif isinstance(entity_df, str): |
532 | 532 | spark_session.sql(entity_df).createOrReplaceTempView(table_name) |
533 | | - elif isinstance(entity_df, SparkDataFrame): |
| 533 | + elif isinstance(entity_df, pyspark.sql.DataFrame): |
534 | 534 | entity_df.createOrReplaceTempView(table_name) |
535 | 535 | return |
536 | 536 | else: |
|
0 commit comments