class InvalidSparkSessionException(Exception):
    """Raised when a Spark session is required but another object was given.

    Args:
        spark_arg: The object that was supplied in place of a Spark session;
            its type is reported in the error message.
    """

    def __init__(self, spark_arg):
        # Adjacent string literals are used instead of a backslash
        # continuation inside the f-string: the continuation embedded the
        # source indentation of the following line into the message text.
        super().__init__(
            "Need Spark Session to convert results to spark data frame, "
            f"received {type(spark_arg)} instead."
        )
def to_spark_df(self, spark_session: SparkSession) -> DataFrame:
    """
    Method to convert snowflake query results to pyspark data frame.

    The query produced by ``self._query_generator()`` is executed on
    ``self.snowflake_conn``; every arrow batch of the result is converted
    to pandas, turned into a Spark data frame, and the per-batch frames are
    unioned into a single data frame.

    Args:
        spark_session: spark Session variable of current environment.

    Returns:
        spark_df: A pyspark dataframe.

    Raises:
        InvalidSparkSessionException: If ``spark_session`` is not a
            ``SparkSession`` instance.
        EntitySQLEmptyResults: If the query yields no arrow batches.
    """
    # Validate first so that no query is run for an unusable session.
    if not isinstance(spark_session, SparkSession):
        raise InvalidSparkSessionException(spark_session)

    with self._query_generator() as query:
        arrow_batches = execute_snowflake_statement(
            self.snowflake_conn, query
        ).fetch_arrow_batches()

        # The batches come back as an iterator, which is truthy even when it
        # yields nothing, so emptiness must be tested on the materialised
        # list. `or ()` additionally tolerates a None result.
        batch_frames = [
            spark_session.createDataFrame(batch.to_pandas())
            for batch in arrow_batches or ()
        ]

        if not batch_frames:
            raise EntitySQLEmptyResults(query)

        # DataFrame.union is the non-deprecated spelling of unionAll.
        return reduce(DataFrame.union, batch_frames)