Skip to content

Commit 88636de

Browse files
committed
feat: support staging for spark materialization (#5671)
Signed-off-by: Jacob Weinhold <29459386+jfw-ppi@users.noreply.github.com>
1 parent 59dbb33 commit 88636de

File tree

3 files changed

+383
-49
lines changed
  • docs/reference/offline-stores
  • sdk/python
    • feast/infra/offline_stores/contrib/spark_offline_store
    • tests/unit/infra/offline_stores/contrib/spark_offline_store

3 files changed

+383
-49
lines changed

docs/reference/offline-stores/spark.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,14 @@ offline_store:
3232
spark.sql.session.timeZone: "UTC"
3333
spark.sql.execution.arrow.fallback.enabled: "true"
3434
spark.sql.execution.arrow.pyspark.enabled: "true"
35+
# Optional: spill large materializations to the staging location instead of collecting in the driver
36+
staging_location: "s3://my-bucket/tmp/feast"
37+
staging_allow_materialize: true
3538
online_store:
3639
path: data/online_store.db
3740
```
41+
42+
> The `staging_location` can point to object storage (like S3, GCS, or Azure blobs) or a local filesystem directory (e.g., `/tmp/feast/staging`) to spill large materialization outputs before reading them back into Feast.
3843
{% endcode %}
3944

4045
The full set of configuration options is available in [SparkOfflineStoreConfig](https://rtd.feast.dev/en/master/#feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStoreConfig).
@@ -60,7 +65,7 @@ Below is a matrix indicating which functionality is supported by `SparkRetrieval
6065
| export to arrow table | yes |
6166
| export to arrow batches | no |
6267
| export to SQL | no |
63-
| export to data lake (S3, GCS, etc.) | no |
68+
| export to data lake (S3, GCS, etc.) | yes |
6469
| export to data warehouse | no |
6570
| export as Spark dataframe | yes |
6671
| local execution of Python-based on-demand transforms | no |

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 166 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Union,
1717
cast,
1818
)
19+
from urllib.parse import urlparse
1920

2021
if TYPE_CHECKING:
2122
from feast.saved_dataset import ValidationReference
@@ -24,9 +25,10 @@
2425
import pandas
2526
import pandas as pd
2627
import pyarrow
28+
import pyarrow.dataset as ds
2729
import pyarrow.parquet as pq
2830
import pyspark
29-
from pydantic import StrictStr
31+
from pydantic import StrictBool, StrictStr
3032
from pyspark import SparkConf
3133
from pyspark.sql import SparkSession
3234

@@ -66,6 +68,9 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
6668
staging_location: Optional[StrictStr] = None
6769
""" Remote path for batch materialization jobs"""
6870

71+
staging_allow_materialize: StrictBool = False
72+
""" Enable use of staging_location during materialization to avoid driver OOM """
73+
6974
region: Optional[StrictStr] = None
7075
""" AWS Region if applicable for s3-based staging locations"""
7176

@@ -445,8 +450,44 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
445450

446451
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
    """Return the dataset as a pyarrow Table, synchronously.

    When staging-based materialization is enabled on the offline store
    config, the data is spilled to the staging location and read back,
    avoiding collecting the full result in the Spark driver.
    """
    if not self._should_use_staging_for_arrow():
        # Default path: collect to a pandas DataFrame in the driver.
        return pyarrow.Table.from_pandas(self._to_df_internal(timeout=timeout))
    return self._to_arrow_via_staging()
449457

458+
def _should_use_staging_for_arrow(self) -> bool:
    """Return True when arrow conversion should spill via the staging location.

    Requires a Spark offline store config with both ``staging_allow_materialize``
    enabled and a non-empty ``staging_location``.
    """
    offline_store = getattr(self._config, "offline_store", None)
    if not isinstance(offline_store, SparkOfflineStoreConfig):
        return False
    allow_materialize = getattr(offline_store, "staging_allow_materialize", False)
    staging_location = getattr(offline_store, "staging_location", None)
    return bool(allow_materialize and staging_location)
465+
466+
def _to_arrow_via_staging(self) -> pyarrow.Table:
    """Materialize via the staging location and read the result back as Arrow.

    Writes the result set to remote/local staging storage, then builds a
    pyarrow dataset over the parquet part files. Returns an empty table when
    nothing was written (or only non-parquet marker files were produced).
    """
    staged_files = self.to_remote_storage()
    parquet_files = _filter_parquet_files(staged_files) if staged_files else []
    if not parquet_files:
        return pyarrow.table({})

    # file:// URIs must be converted to plain paths before pyarrow reads them.
    readable_paths = self._normalize_staging_paths(parquet_files)
    return ds.dataset(readable_paths, format="parquet").to_table()
478+
479+
def _normalize_staging_paths(self, paths: List[str]) -> List[str]:
480+
"""Normalize staging paths for PyArrow datasets."""
481+
normalized = []
482+
for path in paths:
483+
if path.startswith("file://"):
484+
normalized.append(path[len("file://") :])
485+
elif "://" in path:
486+
normalized.append(path)
487+
else:
488+
normalized.append(path)
489+
return normalized
490+
450491
def to_feast_df(
451492
self,
452493
validation_reference: Optional["ValidationReference"] = None,
@@ -508,55 +549,53 @@ def supports_remote_storage_export(self) -> bool:
508549

509550
def to_remote_storage(self) -> List[str]:
    """Write the retrieval result to the staging location and list the files.

    Supported ``staging_location`` schemes: local absolute paths (``/...``),
    ``s3://``, ``gs://``, Azure blob URIs (``wasbs://``, ``abfs://``,
    ``abfss://``, or ``https://<account>.blob.core.windows.net/...``), and
    ``hdfs://``.

    Returns:
        The URIs/paths of the files written under a unique subdirectory of
        the staging location.

    Raises:
        NotImplementedError: If no staging location is configured, or its
            scheme is not one of the supported ones above.
    """
    if not self.supports_remote_storage_export():
        raise NotImplementedError(
            "to_remote_storage requires offline_store.staging_location to be configured"
        )

    sdf: pyspark.sql.DataFrame = self.to_spark_df()
    staging_location = self._config.offline_store.staging_location

    if staging_location.startswith("/"):
        local_file_staging_location = os.path.abspath(staging_location)
        # Unique subdirectory so concurrent exports never collide.
        output_uri = os.path.join(local_file_staging_location, str(uuid.uuid4()))
        sdf.write.parquet(output_uri)
        return _list_files_in_folder(output_uri)
    elif staging_location.startswith("s3://"):
        from feast.infra.utils import aws_utils

        # Spark's Hadoop S3 connector expects the s3a:// scheme.
        spark_compatible_s3_staging_location = staging_location.replace(
            "s3://", "s3a://"
        )
        output_uri = os.path.join(
            spark_compatible_s3_staging_location, str(uuid.uuid4())
        )
        sdf.write.parquet(output_uri)
        # The boto-based listing helper expects plain s3:// URIs again.
        s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)
        return aws_utils.list_s3_files(
            self._config.offline_store.region, s3_uri_for_listing
        )
    elif staging_location.startswith("gs://"):
        output_uri = os.path.join(staging_location, str(uuid.uuid4()))
        sdf.write.parquet(output_uri)
        return _list_gcs_files(output_uri)
    elif staging_location.startswith(("wasbs://", "abfs://", "abfss://")) or (
        staging_location.startswith("https://")
        and ".blob.core.windows.net" in staging_location
    ):
        output_uri = os.path.join(staging_location, str(uuid.uuid4()))
        sdf.write.parquet(output_uri)
        return _list_azure_files(output_uri)
    elif staging_location.startswith("hdfs://"):
        output_uri = os.path.join(staging_location, str(uuid.uuid4()))
        sdf.write.parquet(output_uri)
        spark_session = get_spark_session_or_start_new_with_repoconfig(
            store_config=self._config.offline_store
        )
        return _list_hdfs_files(spark_session, output_uri)
    else:
        # Keep this message in sync with the branches above; local staging
        # requires a bare absolute path, not a file:// URI.
        raise NotImplementedError(
            "to_remote_storage is only implemented for local (/...), s3://, "
            "gs://, Azure blob (wasbs/abfs/abfss/https), and hdfs:// staging "
            "locations"
        )
560599

561600
@property
562601
def metadata(self) -> Optional[RetrievalMetadata]:
@@ -789,6 +828,10 @@ def _list_files_in_folder(folder):
789828
return files
790829

791830

831+
def _filter_parquet_files(paths: List[str]) -> List[str]:
832+
return [path for path in paths if path.endswith(".parquet")]
833+
834+
792835
def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
793836
jvm = spark_session._jvm
794837
jsc = spark_session._jsc
@@ -805,6 +848,81 @@ def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
805848
return files
806849

807850

851+
def _list_gcs_files(path: str) -> List[str]:
    """List all object URIs under a ``gs://`` prefix, excluding directory markers.

    Raises FeastExtrasDependencyImportError if the GCP extra is not installed.
    """
    try:
        from google.cloud import storage
    except ImportError as e:
        from feast.errors import FeastExtrasDependencyImportError

        raise FeastExtrasDependencyImportError("gcp", str(e))

    assert path.startswith("gs://"), "GCS path must start with gs://"
    # Split "gs://bucket/optional/prefix" into bucket and prefix; partition
    # yields an empty prefix when no "/" follows the bucket name.
    remainder = path[len("gs://"):]
    bucket, _, prefix = remainder.partition("/")

    client = storage.Client()
    blobs = client.bucket(bucket).list_blobs(prefix=prefix)

    # Directory placeholder objects end with "/"; keep only real files.
    return [
        f"gs://{bucket}/{blob.name}"
        for blob in blobs
        if not blob.name.endswith("/")
    ]
875+
876+
877+
def _list_azure_files(path: str) -> List[str]:
    """List all blob URIs under an Azure staging path, excluding directory markers.

    Accepts Spark-style URIs (wasbs://container@account.host/prefix, abfs://,
    abfss://) as well as plain https://account.blob.core.windows.net/container/prefix
    URLs. Returned entries keep the same scheme/base as the input path.

    Raises:
        FeastExtrasDependencyImportError: if the Azure extra is not installed.
        ValueError: if a wasbs/abfs(s) URI lacks the container@account host form.
    """
    try:
        from azure.identity import DefaultAzureCredential
        from azure.storage.blob import BlobServiceClient
    except ImportError as e:
        from feast.errors import FeastExtrasDependencyImportError

        raise FeastExtrasDependencyImportError("azure", str(e))

    parsed = urlparse(path)
    scheme = parsed.scheme

    if scheme in ("wasbs", "abfs", "abfss"):
        # Spark-style URI: netloc is "container@account.host".
        if "@" not in parsed.netloc:
            raise ValueError("Azure staging URI must include container@account host")
        container, account_host = parsed.netloc.split("@", 1)
        account_url = f"https://{account_host}"
        # Preserve the original scheme in the returned file URIs.
        base = f"{scheme}://{container}@{account_host}"
        prefix = parsed.path.lstrip("/")
    else:
        # https URL: first path segment is the container, rest is the prefix.
        account_url = f"{parsed.scheme}://{parsed.netloc}"
        container_and_prefix = parsed.path.lstrip("/").split("/", 1)
        container = container_and_prefix[0]
        base = f"{account_url}/{container}"
        prefix = container_and_prefix[1] if len(container_and_prefix) > 1 else ""

    # Prefer an explicit account key from the environment; otherwise fall back
    # to DefaultAzureCredential (env vars, managed identity, CLI, ...).
    credential = os.environ.get("AZURE_STORAGE_KEY") or os.environ.get(
        "AZURE_STORAGE_ACCOUNT_KEY"
    )
    if credential:
        client = BlobServiceClient(account_url=account_url, credential=credential)
    else:
        # NOTE(review): shared-token-cache is excluded, presumably to avoid
        # stale cached identities in non-interactive jobs — confirm.
        default_credential = DefaultAzureCredential(
            exclude_shared_token_cache_credential=True
        )
        client = BlobServiceClient(
            account_url=account_url, credential=default_credential
        )

    container_client = client.get_container_client(container)
    blobs = container_client.list_blobs(name_starts_with=prefix if prefix else None)

    # Directory placeholder blobs end with "/"; keep only real files.
    files = []
    for blob in blobs:
        if not blob.name.endswith("/"):
            files.append(f"{base}/{blob.name}")
    return files
924+
925+
808926
def _cast_data_frame(
809927
df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
810928
) -> pyspark.sql.DataFrame:

0 commit comments

Comments
 (0)