Skip to content

Commit b60d47c

Browse files
fix(spark): handle bucket-root S3A paths, read credentials from spark config, add session token support
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent 448212d commit b60d47c

3 files changed

Lines changed: 225 additions & 81 deletions

File tree

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from feast.infra.common.serde import SerializedArtifacts
1212
from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping
1313

14+
try:
15+
import boto3
16+
from botocore.client import Config as BotoConfig
17+
except ImportError:
18+
boto3 = None # type: ignore[assignment]
19+
BotoConfig = None # type: ignore[assignment,misc]
20+
1421
logger = logging.getLogger(__name__)
1522

1623

@@ -35,26 +42,38 @@ def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
3542

3643
path = event_dir[len("s3a://") :]
3744
bucket, _, prefix = path.partition("/")
38-
prefix = prefix.rstrip("/") + "/"
45+
prefix = prefix.rstrip("/")
46+
prefix = (prefix + "/") if prefix else prefix
3947
placeholder_key = prefix + ".keep"
4048

4149
endpoint = spark_config.get(
4250
"spark.hadoop.fs.s3a.endpoint",
4351
os.environ.get("FEAST_S3A_ENDPOINT", ""),
4452
)
45-
access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
46-
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
53+
access_key = spark_config.get(
54+
"spark.hadoop.fs.s3a.access.key",
55+
os.environ.get("AWS_ACCESS_KEY_ID", ""),
56+
)
57+
secret_key = spark_config.get(
58+
"spark.hadoop.fs.s3a.secret.key",
59+
os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
60+
)
61+
session_token = spark_config.get(
62+
"spark.hadoop.fs.s3a.session.token",
63+
os.environ.get("AWS_SESSION_TOKEN", ""),
64+
) or None
4765

4866
try:
49-
import boto3
50-
from botocore.client import Config
67+
if boto3 is None:
68+
raise ImportError("boto3 is not installed")
5169

5270
s3 = boto3.client(
5371
"s3",
5472
endpoint_url=endpoint if endpoint else None,
5573
aws_access_key_id=access_key or None,
5674
aws_secret_access_key=secret_key or None,
57-
config=Config(signature_version="s3v4"),
75+
aws_session_token=session_token,
76+
config=BotoConfig(signature_version="s3v4"),
5877
)
5978
resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
6079
if resp.get("KeyCount", 0) == 0:
@@ -81,9 +100,6 @@ def get_or_create_new_spark_session(
81100
if not spark_session:
82101
spark_builder = SparkSession.builder
83102
if spark_config:
84-
# Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called
85-
# during SparkContext.__init__ and will crash if the S3A event log
86-
# prefix doesn't exist yet. Ensure the prefix exists first.
87103
_ensure_s3a_event_log_dir(spark_config)
88104
spark_builder = spark_builder.config(
89105
conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])

sdk/python/tests/component/spark/test_compute.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from feast.infra.common.retrieval_task import HistoricalRetrievalTask
1616
from feast.infra.compute_engines.spark.compute import SparkComputeEngine
1717
from feast.infra.compute_engines.spark.job import SparkDAGRetrievalJob
18-
from feast.infra.compute_engines.spark.utils import _ensure_s3a_event_log_dir
1918
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
2019
SparkOfflineStore,
2120
)
@@ -193,76 +192,5 @@ def tqdm_builder(length):
193192
spark_environment.teardown()
194193

195194

196-
# ---------------------------------------------------------------------------
197-
# Unit tests for _ensure_s3a_event_log_dir — no Spark dependency needed
198-
# ---------------------------------------------------------------------------
199-
200-
201-
def _base_conf(event_log_dir: str) -> dict:
202-
return {
203-
"spark.eventLog.enabled": "true",
204-
"spark.eventLog.dir": event_log_dir,
205-
"spark.hadoop.fs.s3a.endpoint": "http://minio:9000",
206-
}
207-
208-
209-
@patch("feast.infra.compute_engines.spark.utils.boto3")
210-
def test_ensure_s3a_event_log_dir_creates_placeholder_when_empty(mock_boto3):
211-
"""S3A prefix doesn't exist → placeholder object is written."""
212-
s3 = MagicMock()
213-
mock_boto3.client.return_value = s3
214-
s3.list_objects_v2.return_value = {"KeyCount": 0}
215-
216-
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))
217-
218-
s3.list_objects_v2.assert_called_once_with(
219-
Bucket="my-bucket", Prefix="spark-events/", MaxKeys=1
220-
)
221-
s3.put_object.assert_called_once_with(
222-
Bucket="my-bucket", Key="spark-events/.keep", Body=b""
223-
)
224-
225-
226-
@patch("feast.infra.compute_engines.spark.utils.boto3")
227-
def test_ensure_s3a_event_log_dir_skips_when_prefix_exists(mock_boto3):
228-
"""S3A prefix already has objects → no placeholder written."""
229-
s3 = MagicMock()
230-
mock_boto3.client.return_value = s3
231-
s3.list_objects_v2.return_value = {"KeyCount": 3}
232-
233-
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))
234-
235-
s3.put_object.assert_not_called()
236-
237-
238-
@patch("feast.infra.compute_engines.spark.utils.boto3")
239-
def test_ensure_s3a_event_log_dir_noop_when_event_log_disabled(mock_boto3):
240-
"""spark.eventLog.enabled != true → boto3 never called."""
241-
_ensure_s3a_event_log_dir(
242-
{"spark.eventLog.enabled": "false", "spark.eventLog.dir": "s3a://b/p/"}
243-
)
244-
mock_boto3.client.assert_not_called()
245-
246-
247-
@patch("feast.infra.compute_engines.spark.utils.boto3")
248-
def test_ensure_s3a_event_log_dir_noop_for_non_s3a_path(mock_boto3):
249-
"""Non-S3A paths (hdfs://, file://, etc.) are left untouched."""
250-
_ensure_s3a_event_log_dir(
251-
{"spark.eventLog.enabled": "true", "spark.eventLog.dir": "hdfs:///spark-logs"}
252-
)
253-
mock_boto3.client.assert_not_called()
254-
255-
256-
@patch("feast.infra.compute_engines.spark.utils.boto3")
257-
def test_ensure_s3a_event_log_dir_non_fatal_on_s3_error(mock_boto3):
258-
"""boto3 errors are swallowed — SparkContext will surface its own error."""
259-
s3 = MagicMock()
260-
mock_boto3.client.return_value = s3
261-
s3.list_objects_v2.side_effect = Exception("connection refused")
262-
263-
# Must not raise
264-
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))
265-
266-
267195
if __name__ == "__main__":
268196
test_spark_compute_engine_get_historical_features()
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
from feast.infra.compute_engines.spark.utils import _ensure_s3a_event_log_dir
4+
5+
BOTO3_PATH = "feast.infra.compute_engines.spark.utils.boto3"
6+
BOTOCONFIG_PATH = "feast.infra.compute_engines.spark.utils.BotoConfig"
7+
8+
9+
def _base_conf(event_log_dir: str) -> dict:
10+
return {
11+
"spark.eventLog.enabled": "true",
12+
"spark.eventLog.dir": event_log_dir,
13+
"spark.hadoop.fs.s3a.endpoint": "http://minio:9000",
14+
}
15+
16+
17+
@patch(BOTOCONFIG_PATH, MagicMock())
18+
@patch(BOTO3_PATH)
19+
def test_ensure_s3a_event_log_dir_creates_placeholder_when_empty(mock_boto3):
20+
"""S3A prefix doesn't exist -> placeholder object is written."""
21+
s3 = MagicMock()
22+
mock_boto3.client.return_value = s3
23+
s3.list_objects_v2.return_value = {"KeyCount": 0}
24+
25+
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))
26+
27+
s3.list_objects_v2.assert_called_once_with(
28+
Bucket="my-bucket", Prefix="spark-events/", MaxKeys=1
29+
)
30+
s3.put_object.assert_called_once_with(
31+
Bucket="my-bucket", Key="spark-events/.keep", Body=b""
32+
)
33+
34+
35+
@patch(BOTOCONFIG_PATH, MagicMock())
36+
@patch(BOTO3_PATH)
37+
def test_ensure_s3a_event_log_dir_skips_when_prefix_exists(mock_boto3):
38+
"""S3A prefix already has objects -> no placeholder written."""
39+
s3 = MagicMock()
40+
mock_boto3.client.return_value = s3
41+
s3.list_objects_v2.return_value = {"KeyCount": 3}
42+
43+
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))
44+
45+
s3.put_object.assert_not_called()
46+
47+
48+
@patch(BOTOCONFIG_PATH, MagicMock())
49+
@patch(BOTO3_PATH)
50+
def test_ensure_s3a_event_log_dir_noop_when_event_log_disabled(mock_boto3):
51+
"""spark.eventLog.enabled != true -> boto3 never called."""
52+
_ensure_s3a_event_log_dir(
53+
{"spark.eventLog.enabled": "false", "spark.eventLog.dir": "s3a://b/p/"}
54+
)
55+
mock_boto3.client.assert_not_called()
56+
57+
58+
@patch(BOTOCONFIG_PATH, MagicMock())
59+
@patch(BOTO3_PATH)
60+
def test_ensure_s3a_event_log_dir_noop_for_non_s3a_path(mock_boto3):
61+
"""Non-S3A paths (hdfs://, file://, etc.) are left untouched."""
62+
_ensure_s3a_event_log_dir(
63+
{"spark.eventLog.enabled": "true", "spark.eventLog.dir": "hdfs:///spark-logs"}
64+
)
65+
mock_boto3.client.assert_not_called()
66+
67+
68+
@patch(BOTOCONFIG_PATH, MagicMock())
69+
@patch(BOTO3_PATH)
70+
def test_ensure_s3a_event_log_dir_non_fatal_on_s3_error(mock_boto3):
71+
"""boto3 errors are swallowed -> SparkContext will surface its own error."""
72+
s3 = MagicMock()
73+
mock_boto3.client.return_value = s3
74+
s3.list_objects_v2.side_effect = Exception("connection refused")
75+
76+
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))
77+
78+
79+
# ---------------------------------------------------------------------------
80+
# Bucket-root edge cases (s3a://bucket, s3a://bucket/)
81+
# ---------------------------------------------------------------------------
82+
83+
84+
@patch(BOTOCONFIG_PATH, MagicMock())
85+
@patch(BOTO3_PATH)
86+
def test_ensure_s3a_event_log_dir_bucket_root_no_trailing_slash(mock_boto3):
87+
"""s3a://bucket (no path) -> .keep at bucket root, not /.keep."""
88+
s3 = MagicMock()
89+
mock_boto3.client.return_value = s3
90+
s3.list_objects_v2.return_value = {"KeyCount": 0}
91+
92+
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket"))
93+
94+
s3.list_objects_v2.assert_called_once_with(
95+
Bucket="my-bucket", Prefix="", MaxKeys=1
96+
)
97+
s3.put_object.assert_called_once_with(
98+
Bucket="my-bucket", Key=".keep", Body=b""
99+
)
100+
101+
102+
@patch(BOTOCONFIG_PATH, MagicMock())
103+
@patch(BOTO3_PATH)
104+
def test_ensure_s3a_event_log_dir_bucket_root_trailing_slash(mock_boto3):
105+
"""s3a://bucket/ (trailing slash, empty prefix) -> .keep at bucket root."""
106+
s3 = MagicMock()
107+
mock_boto3.client.return_value = s3
108+
s3.list_objects_v2.return_value = {"KeyCount": 0}
109+
110+
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/"))
111+
112+
s3.list_objects_v2.assert_called_once_with(
113+
Bucket="my-bucket", Prefix="", MaxKeys=1
114+
)
115+
s3.put_object.assert_called_once_with(
116+
Bucket="my-bucket", Key=".keep", Body=b""
117+
)
118+
119+
120+
# ---------------------------------------------------------------------------
121+
# Credentials from spark config / env var fallback
122+
# ---------------------------------------------------------------------------
123+
124+
125+
@patch.dict(
126+
"os.environ",
127+
{
128+
"AWS_ACCESS_KEY_ID": "env-ak",
129+
"AWS_SECRET_ACCESS_KEY": "env-sk",
130+
"AWS_SESSION_TOKEN": "env-st",
131+
},
132+
)
133+
@patch(BOTOCONFIG_PATH, MagicMock())
134+
@patch(BOTO3_PATH)
135+
def test_ensure_s3a_event_log_dir_uses_spark_config_credentials(mock_boto3):
136+
"""Credentials in spark config take precedence over env vars."""
137+
s3 = MagicMock()
138+
mock_boto3.client.return_value = s3
139+
s3.list_objects_v2.return_value = {"KeyCount": 1}
140+
141+
conf = {
142+
**_base_conf("s3a://my-bucket/logs/"),
143+
"spark.hadoop.fs.s3a.access.key": "spark-ak",
144+
"spark.hadoop.fs.s3a.secret.key": "spark-sk",
145+
"spark.hadoop.fs.s3a.session.token": "spark-st",
146+
}
147+
_ensure_s3a_event_log_dir(conf)
148+
149+
mock_boto3.client.assert_called_once()
150+
kwargs = mock_boto3.client.call_args
151+
assert kwargs.kwargs["aws_access_key_id"] == "spark-ak"
152+
assert kwargs.kwargs["aws_secret_access_key"] == "spark-sk"
153+
assert kwargs.kwargs["aws_session_token"] == "spark-st"
154+
155+
156+
@patch.dict(
157+
"os.environ",
158+
{
159+
"AWS_ACCESS_KEY_ID": "env-ak",
160+
"AWS_SECRET_ACCESS_KEY": "env-sk",
161+
"AWS_SESSION_TOKEN": "env-st",
162+
},
163+
)
164+
@patch(BOTOCONFIG_PATH, MagicMock())
165+
@patch(BOTO3_PATH)
166+
def test_ensure_s3a_event_log_dir_falls_back_to_env_credentials(mock_boto3):
167+
"""Without spark config keys, env vars are used."""
168+
s3 = MagicMock()
169+
mock_boto3.client.return_value = s3
170+
s3.list_objects_v2.return_value = {"KeyCount": 1}
171+
172+
_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))
173+
174+
mock_boto3.client.assert_called_once()
175+
kwargs = mock_boto3.client.call_args
176+
assert kwargs.kwargs["aws_access_key_id"] == "env-ak"
177+
assert kwargs.kwargs["aws_secret_access_key"] == "env-sk"
178+
assert kwargs.kwargs["aws_session_token"] == "env-st"
179+
180+
181+
@patch.dict("os.environ", {}, clear=True)
182+
@patch(BOTOCONFIG_PATH, MagicMock())
183+
@patch(BOTO3_PATH)
184+
def test_ensure_s3a_event_log_dir_no_credentials_passes_none(mock_boto3):
185+
"""No credentials anywhere -> None passed to boto3 (anonymous / instance role)."""
186+
s3 = MagicMock()
187+
mock_boto3.client.return_value = s3
188+
s3.list_objects_v2.return_value = {"KeyCount": 1}
189+
190+
conf = {
191+
"spark.eventLog.enabled": "true",
192+
"spark.eventLog.dir": "s3a://my-bucket/logs/",
193+
}
194+
_ensure_s3a_event_log_dir(conf)
195+
196+
mock_boto3.client.assert_called_once()
197+
kwargs = mock_boto3.client.call_args
198+
assert kwargs.kwargs["aws_access_key_id"] is None
199+
assert kwargs.kwargs["aws_secret_access_key"] is None
200+
assert kwargs.kwargs["aws_session_token"] is None

0 commit comments

Comments
 (0)