@@ -436,52 +436,85 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
436436 return self ._on_demand_feature_views
437437
def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
    """Run the historical-retrieval SQL in Snowflake and return the full result as a pandas DataFrame.

    Args:
        timeout: accepted for interface compatibility with other retrieval jobs; not used here.

    Returns:
        A pandas DataFrame with the complete query result.
    """
    cursor = execute_snowflake_statement(self.snowflake_conn, self.to_sql())
    return cursor.fetch_pandas_all()
446444
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
    """Run the historical-retrieval SQL in Snowflake and return the result as a pyarrow Table.

    Args:
        timeout: accepted for interface compatibility with other retrieval jobs; not used here.

    Returns:
        A pyarrow Table with the query result; when the query yields no rows,
        a zero-row table whose columns match the result-set metadata.
    """
    cursor = execute_snowflake_statement(self.snowflake_conn, self.to_sql())

    # fetch_arrow_all() returns None (not an empty Table) when the result set
    # is empty, so test identity against None rather than truthiness.
    pa_table = cursor.fetch_arrow_all()
    if pa_table is not None:
        return pa_table

    # Empty result: build a zero-row table with the correct column names from
    # the SAME cursor's metadata. The previous code re-executed the full query
    # a second time just to read `.description`, doubling the Snowflake cost.
    return pyarrow.Table.from_pandas(
        pd.DataFrame(columns=[md.name for md in cursor.description])
    )
def to_sql(self) -> str:
    """Return the SQL that Snowflake will execute to build the historical feature table."""
    with self._query_generator() as generated_query:
        return generated_query
462467
def to_snowflake(
    self, table_name: str, allow_overwrite: bool = False, temporary: bool = False
) -> None:
    """Save dataset as a new Snowflake table.

    Args:
        table_name: name of the destination table.
        allow_overwrite: when True, replace an existing table of the same name;
            when False, leave an existing table untouched (IF NOT EXISTS).
        temporary: when True, create the table as a Snowflake temporary table.
    """
    # On-demand feature views require local transformation, so the result is
    # materialized to pandas first and uploaded with write_pandas.
    if self.on_demand_feature_views:
        transformed_df = self.to_df()

        if allow_overwrite:
            drop_query = f'DROP TABLE IF EXISTS "{table_name}"'
            execute_snowflake_statement(self.snowflake_conn, drop_query)

        write_pandas(
            self.snowflake_conn,
            transformed_df,
            table_name,
            auto_create_table=True,
            create_temp_table=temporary,
        )
        return None

    # No on-demand transformations: build the table entirely inside Snowflake
    # with a single CTAS statement.
    create_query = (
        f'CREATE {"OR REPLACE" if allow_overwrite else ""} '
        f'{"TEMPORARY" if temporary else ""} TABLE '
        f'{"IF NOT EXISTS" if not allow_overwrite else ""} '
        f'"{table_name}" AS ({self.to_sql()});\n'
    )
    execute_snowflake_statement(self.snowflake_conn, create_query)
    return None
def to_arrow_batches(self) -> Iterator[pyarrow.Table]:
    """Materialize the query into a temporary Snowflake table and stream it back as Arrow batches.

    Returns:
        An iterator of pyarrow Tables, one per fetched batch.
    """
    staging_table = "temp_arrow_batches_" + uuid.uuid4().hex

    # Stage results in a uniquely-named temporary table so batches can be
    # streamed with a plain SELECT.
    self.to_snowflake(table_name=staging_table, allow_overwrite=True, temporary=True)

    cursor = execute_snowflake_statement(
        self.snowflake_conn, f'SELECT * FROM "{staging_table}"'
    )
    return cursor.fetch_arrow_batches()
505+
def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
    """Materialize the query into a temporary Snowflake table and stream it back as pandas batches.

    Returns:
        An iterator of pandas DataFrames, one per fetched batch.
    """
    staging_table = "temp_pandas_batches_" + uuid.uuid4().hex

    # Stage results in a uniquely-named temporary table so batches can be
    # streamed with a plain SELECT.
    self.to_snowflake(table_name=staging_table, allow_overwrite=True, temporary=True)

    cursor = execute_snowflake_statement(
        self.snowflake_conn, f'SELECT * FROM "{staging_table}"'
    )
    return cursor.fetch_pandas_batches()
485518
486519 def to_spark_df (self , spark_session : "SparkSession" ) -> "DataFrame" :
487520 """
@@ -502,37 +535,33 @@ def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
502535 raise FeastExtrasDependencyImportError ("spark" , str (e ))
503536
504537 if isinstance (spark_session , SparkSession ):
505- with self ._query_generator () as query :
506-
507- arrow_batches = execute_snowflake_statement (
508- self .snowflake_conn , query
509- ).fetch_arrow_batches ()
510-
511- if arrow_batches :
512- spark_df = reduce (
513- DataFrame .unionAll ,
514- [
515- spark_session .createDataFrame (batch .to_pandas ())
516- for batch in arrow_batches
517- ],
518- )
519-
520- return spark_df
521-
522- else :
523- raise EntitySQLEmptyResults (query )
524-
538+ arrow_batches = self .to_arrow_batches ()
539+
540+ if arrow_batches :
541+ spark_df = reduce (
542+ DataFrame .unionAll ,
543+ [
544+ spark_session .createDataFrame (batch .to_pandas ())
545+ for batch in arrow_batches
546+ ],
547+ )
548+ return spark_df
549+ else :
550+ raise EntitySQLEmptyResults (self .to_sql ())
525551 else :
526552 raise InvalidSparkSessionException (spark_session )
527553
def persist(
    self,
    storage: SavedDatasetStorage,
    allow_overwrite: bool = False,
    timeout: Optional[int] = None,
):
    """Persist this retrieval job's result to the Snowflake table named in `storage`.

    Args:
        storage: must be a SavedDatasetSnowflakeStorage; its snowflake_options.table
            names the destination table.
        allow_overwrite: forwarded to to_snowflake; when True an existing table
            of the same name is replaced.
        timeout: accepted for interface compatibility; not used here.
    """
    assert isinstance(storage, SavedDatasetSnowflakeStorage)

    target_table = storage.snowflake_options.table
    self.to_snowflake(table_name=target_table, allow_overwrite=allow_overwrite)
536565
537566 @property
538567 def metadata (self ) -> Optional [RetrievalMetadata ]:
0 commit comments