1616 Union ,
1717 cast ,
1818)
19+ from urllib .parse import urlparse
1920
2021if TYPE_CHECKING :
2122 from feast .saved_dataset import ValidationReference
2425import pandas
2526import pandas as pd
2627import pyarrow
28+ import pyarrow .dataset as ds
2729import pyarrow .parquet as pq
2830import pyspark
29- from pydantic import StrictStr
31+ from pydantic import StrictBool , StrictStr
3032from pyspark import SparkConf
3133from pyspark .sql import SparkSession
3234
@@ -66,6 +68,9 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
    staging_location: Optional[StrictStr] = None
    """ Remote path for batch materialization jobs"""

    # When enabled (and staging_location is set), Arrow materialization writes
    # through the staging location instead of collecting the full DataFrame
    # into pandas on the Spark driver.
    staging_allow_materialize: StrictBool = False
    """ Enable use of staging_location during materialization to avoid driver OOM """

    region: Optional[StrictStr] = None
    """ AWS Region if applicable for s3-based staging locations"""
@@ -445,8 +450,44 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
445450
446451 def _to_arrow_internal (self , timeout : Optional [int ] = None ) -> pyarrow .Table :
447452 """Return dataset as pyarrow Table synchronously"""
453+ if self ._should_use_staging_for_arrow ():
454+ return self ._to_arrow_via_staging ()
455+
448456 return pyarrow .Table .from_pandas (self ._to_df_internal (timeout = timeout ))
449457
458+ def _should_use_staging_for_arrow (self ) -> bool :
459+ offline_store = getattr (self ._config , "offline_store" , None )
460+ return bool (
461+ isinstance (offline_store , SparkOfflineStoreConfig )
462+ and getattr (offline_store , "staging_allow_materialize" , False )
463+ and getattr (offline_store , "staging_location" , None )
464+ )
465+
466+ def _to_arrow_via_staging (self ) -> pyarrow .Table :
467+ paths = self .to_remote_storage ()
468+ if not paths :
469+ return pyarrow .table ({})
470+
471+ parquet_paths = _filter_parquet_files (paths )
472+ if not parquet_paths :
473+ return pyarrow .table ({})
474+
475+ normalized_paths = self ._normalize_staging_paths (parquet_paths )
476+ dataset = ds .dataset (normalized_paths , format = "parquet" )
477+ return dataset .to_table ()
478+
479+ def _normalize_staging_paths (self , paths : List [str ]) -> List [str ]:
480+ """Normalize staging paths for PyArrow datasets."""
481+ normalized = []
482+ for path in paths :
483+ if path .startswith ("file://" ):
484+ normalized .append (path [len ("file://" ) :])
485+ elif "://" in path :
486+ normalized .append (path )
487+ else :
488+ normalized .append (path )
489+ return normalized
490+
450491 def to_feast_df (
451492 self ,
452493 validation_reference : Optional ["ValidationReference" ] = None ,
@@ -508,55 +549,53 @@ def supports_remote_storage_export(self) -> bool:
508549
    def to_remote_storage(self) -> List[str]:
        """Write the retrieval result as parquet under the staging location.

        Supports local (``/...``), ``s3://``, ``gs://``, Azure blob
        (``wasbs://``/``abfs://``/``abfss://`` or
        ``https://*.blob.core.windows.net``) and ``hdfs://`` staging locations.

        Returns:
            The list of files written beneath a fresh UUID-named directory
            inside ``staging_location``.

        Raises:
            NotImplementedError: If remote export is not supported, or the
                staging location uses an unrecognized URI scheme.
        """
        if not self.supports_remote_storage_export():
            raise NotImplementedError()

        sdf: pyspark.sql.DataFrame = self.to_spark_df()
        staging_location = self._config.offline_store.staging_location

        if staging_location.startswith("/"):
            local_file_staging_location = os.path.abspath(staging_location)
            # Every export gets its own uuid4 directory so concurrent or
            # repeated runs never collide.
            output_uri = os.path.join(local_file_staging_location, str(uuid.uuid4()))
            sdf.write.parquet(output_uri)
            return _list_files_in_folder(output_uri)
        elif staging_location.startswith("s3://"):
            from feast.infra.utils import aws_utils

            # Spark's Hadoop S3 connector expects the s3a:// scheme.
            spark_compatible_s3_staging_location = staging_location.replace(
                "s3://", "s3a://"
            )
            output_uri = os.path.join(
                spark_compatible_s3_staging_location, str(uuid.uuid4())
            )
            sdf.write.parquet(output_uri)
            # The AWS listing helper wants the canonical s3:// scheme back;
            # only the leading scheme is rewritten (count=1).
            s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)
            return aws_utils.list_s3_files(
                self._config.offline_store.region, s3_uri_for_listing
            )
        elif staging_location.startswith("gs://"):
            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
            sdf.write.parquet(output_uri)
            return _list_gcs_files(output_uri)
        elif staging_location.startswith(("wasbs://", "abfs://", "abfss://")) or (
            staging_location.startswith("https://")
            and ".blob.core.windows.net" in staging_location
        ):
            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
            sdf.write.parquet(output_uri)
            return _list_azure_files(output_uri)
        elif staging_location.startswith("hdfs://"):
            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
            sdf.write.parquet(output_uri)
            # HDFS listing goes through the JVM Hadoop FileSystem API and
            # therefore needs an active Spark session.
            spark_session = get_spark_session_or_start_new_with_repoconfig(
                store_config=self._config.offline_store
            )
            return _list_hdfs_files(spark_session, output_uri)
        else:
            raise NotImplementedError(
                "to_remote_storage is only implemented for file://, s3://, gs://, azure, and hdfs uri schemes"
            )
560599
561600 @property
562601 def metadata (self ) -> Optional [RetrievalMetadata ]:
@@ -789,6 +828,10 @@ def _list_files_in_folder(folder):
789828 return files
790829
791830
831+ def _filter_parquet_files (paths : List [str ]) -> List [str ]:
832+ return [path for path in paths if path .endswith (".parquet" )]
833+
834+
792835def _list_hdfs_files (spark_session : SparkSession , uri : str ) -> List [str ]:
793836 jvm = spark_session ._jvm
794837 jsc = spark_session ._jsc
@@ -805,6 +848,81 @@ def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
805848 return files
806849
807850
def _list_gcs_files(path: str) -> List[str]:
    """List all object URIs under a ``gs://`` prefix, skipping directory markers."""
    try:
        from google.cloud import storage
    except ImportError as e:
        from feast.errors import FeastExtrasDependencyImportError

        raise FeastExtrasDependencyImportError("gcp", str(e))

    assert path.startswith("gs://"), "GCS path must start with gs://"
    # Split "gs://bucket/some/prefix" into bucket and object-name prefix;
    # a bare bucket URI yields an empty prefix.
    remainder = path[len("gs://"):]
    bucket, _, prefix = remainder.partition("/")

    client = storage.Client()
    return [
        f"gs://{bucket}/{blob.name}"
        for blob in client.bucket(bucket).list_blobs(prefix=prefix)
        if not blob.name.endswith("/")
    ]
875+
876+
def _list_azure_files(path: str) -> List[str]:
    """List all blob URIs under an Azure staging prefix, skipping directory markers.

    Accepts Hadoop-style ``wasbs://``/``abfs://``/``abfss://`` URIs of the
    form ``scheme://container@account_host/prefix`` as well as plain
    ``https://<account_host>/<container>/<prefix>`` URLs. Returned paths
    reuse the same scheme/base as the input so callers can read them back
    with the matching connector.

    Raises:
        FeastExtrasDependencyImportError: If the Azure SDK extras are missing.
        ValueError: If a wasbs/abfs URI lacks the ``container@account`` host.
    """
    try:
        from azure.identity import DefaultAzureCredential
        from azure.storage.blob import BlobServiceClient
    except ImportError as e:
        from feast.errors import FeastExtrasDependencyImportError

        raise FeastExtrasDependencyImportError("azure", str(e))

    parsed = urlparse(path)
    scheme = parsed.scheme

    if scheme in ("wasbs", "abfs", "abfss"):
        # Hadoop-style URI: netloc carries "container@account_host".
        if "@" not in parsed.netloc:
            raise ValueError("Azure staging URI must include container@account host")
        container, account_host = parsed.netloc.split("@", 1)
        account_url = f"https://{account_host}"
        base = f"{scheme}://{container}@{account_host}"
        prefix = parsed.path.lstrip("/")
    else:
        # Plain https URL: the first path segment is the container name.
        account_url = f"{parsed.scheme}://{parsed.netloc}"
        container_and_prefix = parsed.path.lstrip("/").split("/", 1)
        container = container_and_prefix[0]
        base = f"{account_url}/{container}"
        prefix = container_and_prefix[1] if len(container_and_prefix) > 1 else ""

    # Prefer an explicit account key from the environment; otherwise fall
    # back to DefaultAzureCredential (managed identity, CLI login, ...).
    credential = os.environ.get("AZURE_STORAGE_KEY") or os.environ.get(
        "AZURE_STORAGE_ACCOUNT_KEY"
    )
    if credential:
        client = BlobServiceClient(account_url=account_url, credential=credential)
    else:
        # NOTE(review): the shared-token-cache credential is excluded,
        # presumably to avoid stale cached identities in batch jobs — confirm.
        default_credential = DefaultAzureCredential(
            exclude_shared_token_cache_credential=True
        )
        client = BlobServiceClient(
            account_url=account_url, credential=default_credential
        )

    container_client = client.get_container_client(container)
    blobs = container_client.list_blobs(name_starts_with=prefix if prefix else None)

    files = []
    for blob in blobs:
        # Names ending in "/" are virtual-directory placeholders, not data.
        if not blob.name.endswith("/"):
            files.append(f"{base}/{blob.name}")
    return files
924+
925+
808926def _cast_data_frame (
809927 df_new : pyspark .sql .DataFrame , df_existing : pyspark .sql .DataFrame
810928) -> pyspark .sql .DataFrame :
0 commit comments