From 28318b109378c6313a7ac256da7370df0f6fd307 Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Wed, 27 May 2026 14:14:33 +0530 Subject: [PATCH 1/3] fix(spark): S3/GCS PyArrow filesystem for staging paths Signed-off-by: abhijeet-dhumal --- .../contrib/spark_offline_store/spark.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 3fc675ea402..b8b04d2e2a9 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -482,21 +482,52 @@ def _to_arrow_via_staging(self) -> pyarrow.Table: if not parquet_paths: return pyarrow.table({}) - normalized_paths = self._normalize_staging_paths(parquet_paths) - dataset = ds.dataset(normalized_paths, format="parquet") + pa_fs, stripped_paths = self._resolve_staging_filesystem(parquet_paths) + dataset = ds.dataset(stripped_paths, format="parquet", filesystem=pa_fs) return dataset.to_table() - def _normalize_staging_paths(self, paths: List[str]) -> List[str]: - """Normalize staging paths for PyArrow datasets.""" + def _resolve_staging_filesystem( + self, paths: List[str] + ) -> Tuple[Optional[pyarrow.fs.FileSystem], List[str]]: + """Return (pyarrow filesystem, prefix-stripped paths) for staging URIs.""" + sample = paths[0] + + if sample.startswith("s3://") or sample.startswith("s3a://"): + import pyarrow.fs as pafs + + endpoint = os.environ.get("AWS_ENDPOINT_URL_S3") or os.environ.get( + "AWS_S3_ENDPOINT", "" + ) + region = getattr( + self._config.offline_store, "region", None + ) or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") + kwargs: Dict[str, Any] = {"region": region} + if endpoint: + kwargs["endpoint_override"] = endpoint.rstrip("/").replace( + "https://", "" + ).replace("http://", "") + kwargs["scheme"] = ( + "https" if endpoint.startswith("https") else "http" + ) + fs = pafs.S3FileSystem(**kwargs) + stripped = [p.replace("s3a://", "").replace("s3://", "") for p in paths] + return fs, stripped + + if sample.startswith("gs://"): + import pyarrow.fs as pafs + + fs = pafs.GcsFileSystem() + stripped = [p[len("gs://") :] for p in paths] + return fs, stripped + + # Local paths normalized = [] - for path in paths: - if path.startswith("file://"): - normalized.append(path[len("file://") :]) - elif "://" in path: - normalized.append(path) + for p in paths: + if p.startswith("file://"): + normalized.append(p[len("file://") :]) else: - normalized.append(path) - return normalized + normalized.append(p) + return None, normalized def to_feast_df( self, From bcb7340de06e3a92717f203a1e1a41b95b243743 Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Wed, 27 May 2026 17:54:47 +0530 Subject: [PATCH 2/3] style: ruff format spark.py Signed-off-by: abhijeet-dhumal --- .../contrib/spark_offline_store/spark.py | 8 +- .../test_spark_staging_filesystem.py | 134 ++++++++++++++++++ 2 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_staging_filesystem.py diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index b8b04d2e2a9..35ef211a6a1 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -503,12 +503,10 @@ def _resolve_staging_filesystem( ) or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") kwargs: Dict[str, Any] = {"region": region} if endpoint: - kwargs["endpoint_override"] = endpoint.rstrip("/").replace( - "https://", "" - ).replace("http://", "") - kwargs["scheme"] = ( - "https" if endpoint.startswith("https") else "http" + kwargs["endpoint_override"] = ( + endpoint.rstrip("/").replace("https://", "").replace("http://", "") ) + kwargs["scheme"] = "https" if endpoint.startswith("https") else "http" fs = pafs.S3FileSystem(**kwargs) stripped = [p.replace("s3a://", "").replace("s3://", "") for p in paths] return fs, stripped diff --git a/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_staging_filesystem.py b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_staging_filesystem.py new file mode 100644 index 00000000000..5b91646ef7c --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/contrib/spark_offline_store/test_spark_staging_filesystem.py @@ -0,0 +1,134 @@ +""" +Unit tests for SparkRetrievalJob._resolve_staging_filesystem. + +Verifies that the correct PyArrow filesystem and prefix-stripped paths +are returned for S3, S3A, GCS, file://, and plain local paths. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from feast.infra.offline_stores.contrib.spark_offline_store.spark import ( + SparkRetrievalJob, +) + + +@pytest.fixture() +def retrieval_job(): + """Minimal SparkRetrievalJob with a mock config that has no offline_store region.""" + job = object.__new__(SparkRetrievalJob) + config = MagicMock() + config.offline_store.region = None + job._config = config + return job + + +class TestResolveS3Filesystem: + def test_s3_scheme_returns_s3_filesystem(self, retrieval_job): + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + fs, paths = retrieval_job._resolve_staging_filesystem( + ["s3://my-bucket/path/a.parquet", "s3://my-bucket/path/b.parquet"] + ) + mock_s3.assert_called_once() + assert fs is mock_s3.return_value + assert paths == ["my-bucket/path/a.parquet", "my-bucket/path/b.parquet"] + + def test_s3a_scheme_strips_prefix(self, retrieval_job): + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + fs, paths = retrieval_job._resolve_staging_filesystem( + ["s3a://bucket/dir/file.parquet"] + ) + assert paths == ["bucket/dir/file.parquet"] + + def test_s3_with_minio_endpoint(self, retrieval_job, monkeypatch): + monkeypatch.setenv("AWS_ENDPOINT_URL_S3", "http://minio.local:9000") + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"]) + call_kwargs = mock_s3.call_args[1] + assert call_kwargs["endpoint_override"] == "minio.local:9000" + assert call_kwargs["scheme"] == "http" + + def test_s3_with_https_endpoint(self, retrieval_job, monkeypatch): + monkeypatch.setenv("AWS_ENDPOINT_URL_S3", "https://s3.custom.corp") + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"]) + call_kwargs = mock_s3.call_args[1] + assert call_kwargs["endpoint_override"] == "s3.custom.corp" + assert call_kwargs["scheme"] == "https" + + def test_s3_falls_back_to_aws_s3_endpoint_env(self, retrieval_job, monkeypatch): + monkeypatch.delenv("AWS_ENDPOINT_URL_S3", raising=False) + monkeypatch.setenv("AWS_S3_ENDPOINT", "http://legacy-minio:9000") + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"]) + call_kwargs = mock_s3.call_args[1] + assert "endpoint_override" in call_kwargs + + def test_s3_no_endpoint_no_override(self, retrieval_job, monkeypatch): + monkeypatch.delenv("AWS_ENDPOINT_URL_S3", raising=False) + monkeypatch.delenv("AWS_S3_ENDPOINT", raising=False) + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"]) + call_kwargs = mock_s3.call_args[1] + assert "endpoint_override" not in call_kwargs + assert "scheme" not in call_kwargs + + def test_s3_region_from_offline_store_config(self, retrieval_job): + retrieval_job._config.offline_store.region = "eu-west-1" + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"]) + call_kwargs = mock_s3.call_args[1] + assert call_kwargs["region"] == "eu-west-1" + + def test_s3_region_fallback_to_env(self, retrieval_job, monkeypatch): + retrieval_job._config.offline_store.region = None + monkeypatch.setenv("AWS_DEFAULT_REGION", "ap-southeast-1") + with patch("pyarrow.fs.S3FileSystem") as mock_s3: + mock_s3.return_value = MagicMock(name="s3fs") + retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"]) + call_kwargs = mock_s3.call_args[1] + assert call_kwargs["region"] == "ap-southeast-1" + + +class TestResolveGCSFilesystem: + def test_gs_scheme_returns_gcs_filesystem(self, retrieval_job): + with patch("pyarrow.fs.GcsFileSystem") as mock_gcs: + mock_gcs.return_value = MagicMock(name="gcsfs") + fs, paths = retrieval_job._resolve_staging_filesystem( + ["gs://my-bucket/path/a.parquet", "gs://my-bucket/path/b.parquet"] + ) + mock_gcs.assert_called_once() + assert fs is mock_gcs.return_value + assert paths == ["my-bucket/path/a.parquet", "my-bucket/path/b.parquet"] + + +class TestResolveLocalFilesystem: + def test_file_scheme_stripped(self, retrieval_job): + fs, paths = retrieval_job._resolve_staging_filesystem( + ["file:///tmp/staging/a.parquet"] + ) + assert fs is None + assert paths == ["/tmp/staging/a.parquet"] + + def test_plain_local_path_unchanged(self, retrieval_job): + fs, paths = retrieval_job._resolve_staging_filesystem( + ["/tmp/staging/a.parquet", "/tmp/staging/b.parquet"] + ) + assert fs is None + assert paths == ["/tmp/staging/a.parquet", "/tmp/staging/b.parquet"] + + def test_mixed_file_and_plain_paths(self, retrieval_job): + fs, paths = retrieval_job._resolve_staging_filesystem( + ["file:///tmp/a.parquet", "/tmp/b.parquet"] + ) + assert fs is None + assert paths == ["/tmp/a.parquet", "/tmp/b.parquet"] From 63adf09d9cd782d8cc18e6132bff3662cae9156e Mon Sep 17 00:00:00 2001 From: abhijeet-dhumal Date: Mon, 1 Jun 2026 12:16:15 +0530 Subject: [PATCH 3/3] fix(spark): address review feedback on staging filesystem PR Signed-off-by: abhijeet-dhumal --- .../contrib/spark_offline_store/spark.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 35ef211a6a1..532437f68a7 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -492,9 +492,9 @@ def _resolve_staging_filesystem( """Return (pyarrow filesystem, prefix-stripped paths) for staging URIs.""" sample = paths[0] - if sample.startswith("s3://") or sample.startswith("s3a://"): - import pyarrow.fs as pafs + import pyarrow.fs as pafs + if sample.startswith("s3://") or sample.startswith("s3a://"): endpoint = os.environ.get("AWS_ENDPOINT_URL_S3") or os.environ.get( "AWS_S3_ENDPOINT", "" ) @@ -504,16 +504,16 @@ def _resolve_staging_filesystem( kwargs: Dict[str, Any] = {"region": region} if endpoint: kwargs["endpoint_override"] = ( - endpoint.rstrip("/").replace("https://", "").replace("http://", "") + endpoint.rstrip("/") + .removeprefix("https://") + .removeprefix("http://") ) kwargs["scheme"] = "https" if endpoint.startswith("https") else "http" fs = pafs.S3FileSystem(**kwargs) - stripped = [p.replace("s3a://", "").replace("s3://", "") for p in paths] + stripped = [p.removeprefix("s3a://").removeprefix("s3://") for p in paths] return fs, stripped if sample.startswith("gs://"): - import pyarrow.fs as pafs - fs = pafs.GcsFileSystem() stripped = [p[len("gs://") :] for p in paths] return fs, stripped