Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions sdk/python/feast/infra/compute_engines/spark/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging
import os
from typing import Dict, Iterable, Literal, Optional

import pandas as pd
Expand All @@ -9,6 +11,102 @@
from feast.infra.common.serde import SerializedArtifacts
from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping

try:
import boto3
from botocore.client import Config as BotoConfig
except ImportError:
boto3 = None # type: ignore[assignment]
BotoConfig = None # type: ignore[assignment,misc]

logger = logging.getLogger(__name__)


def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
"""Pre-create the S3A event log prefix before SparkContext initialisation.

Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called inside
SparkContext.__init__ and crashes if the S3A path doesn't exist yet (S3 has no
real directories, so an empty prefix returns a 404). This function writes a
zero-byte placeholder so the prefix exists before SparkContext is built.

This is only attempted when:
- spark.eventLog.enabled == "true"
- spark.eventLog.dir starts with "s3a://"
Failures are non-fatal: Spark will surface its own error if the dir is still missing.
"""
if spark_config.get("spark.eventLog.enabled", "false").lower() != "true":
return
event_dir = spark_config.get("spark.eventLog.dir", "")
if not event_dir.startswith("s3a://"):
return

path = event_dir[len("s3a://") :]
bucket, _, prefix = path.partition("/")
prefix = prefix.rstrip("/")
prefix = (prefix + "/") if prefix else prefix
placeholder_key = prefix + ".keep"

endpoint = spark_config.get(
"spark.hadoop.fs.s3a.endpoint",
os.environ.get("AWS_ENDPOINT_URL", ""),
)
access_key = spark_config.get(
"spark.hadoop.fs.s3a.access.key",
os.environ.get("AWS_ACCESS_KEY_ID", ""),
)
secret_key = spark_config.get(
"spark.hadoop.fs.s3a.secret.key",
os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
)
session_token = (
spark_config.get(
"spark.hadoop.fs.s3a.session.token",
os.environ.get("AWS_SESSION_TOKEN", ""),
)
or None
)

try:
if boto3 is None:
raise ImportError("boto3 is not installed")

addressing_style = (
"path"
if spark_config.get(
"spark.hadoop.fs.s3a.path.style.access", "false"
).lower()
== "true"
else "auto"
)

s3 = boto3.client(
"s3",
endpoint_url=endpoint if endpoint else None,
aws_access_key_id=access_key or None,
aws_secret_access_key=secret_key or None,
aws_session_token=session_token,
config=BotoConfig(
signature_version="s3v4",
s3={"addressing_style": addressing_style},
),
)
resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
if resp.get("KeyCount", 0) == 0:
s3.put_object(Bucket=bucket, Key=placeholder_key, Body=b"")
logger.debug(
"Created S3A event log dir placeholder: s3a://%s/%s",
bucket,
placeholder_key,
)
except Exception as exc:
logger.warning(
"Could not pre-create S3A event log dir s3a://%s/%s — "
"SparkContext may fail if the path still doesn't exist: %s",
bucket,
prefix,
exc,
)


def get_or_create_new_spark_session(
spark_config: Optional[Dict[str, str]] = None,
Expand All @@ -17,6 +115,7 @@ def get_or_create_new_spark_session(
if not spark_session:
spark_builder = SparkSession.builder
if spark_config:
_ensure_s3a_event_log_dir(spark_config)
spark_builder = spark_builder.config(
conf=SparkConf().setAll([(k, v) for k, v in spark_config.items()])
)
Expand Down
274 changes: 274 additions & 0 deletions sdk/python/tests/component/spark/test_spark_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
from unittest.mock import MagicMock, patch

from feast.infra.compute_engines.spark.utils import _ensure_s3a_event_log_dir

BOTO3_PATH = "feast.infra.compute_engines.spark.utils.boto3"
BOTOCONFIG_PATH = "feast.infra.compute_engines.spark.utils.BotoConfig"


def _base_conf(event_log_dir: str) -> dict:
return {
"spark.eventLog.enabled": "true",
"spark.eventLog.dir": event_log_dir,
"spark.hadoop.fs.s3a.endpoint": "http://minio:9000",
}


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_creates_placeholder_when_empty(mock_boto3):
"""S3A prefix doesn't exist -> placeholder object is written."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 0}

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))

s3.list_objects_v2.assert_called_once_with(
Bucket="my-bucket", Prefix="spark-events/", MaxKeys=1
)
s3.put_object.assert_called_once_with(
Bucket="my-bucket", Key="spark-events/.keep", Body=b""
)


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_skips_when_prefix_exists(mock_boto3):
"""S3A prefix already has objects -> no placeholder written."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 3}

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))

s3.put_object.assert_not_called()


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_noop_when_event_log_disabled(mock_boto3):
"""spark.eventLog.enabled != true -> boto3 never called."""
_ensure_s3a_event_log_dir(
{"spark.eventLog.enabled": "false", "spark.eventLog.dir": "s3a://b/p/"}
)
mock_boto3.client.assert_not_called()


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_noop_for_non_s3a_path(mock_boto3):
"""Non-S3A paths (hdfs://, file://, etc.) are left untouched."""
_ensure_s3a_event_log_dir(
{"spark.eventLog.enabled": "true", "spark.eventLog.dir": "hdfs:///spark-logs"}
)
mock_boto3.client.assert_not_called()


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_non_fatal_on_s3_error(mock_boto3):
"""boto3 errors are swallowed -> SparkContext will surface its own error."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.side_effect = Exception("connection refused")

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))


# ---------------------------------------------------------------------------
# Bucket-root edge cases (s3a://bucket, s3a://bucket/)
# ---------------------------------------------------------------------------


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_bucket_root_no_trailing_slash(mock_boto3):
"""s3a://bucket (no path) -> .keep at bucket root, not /.keep."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 0}

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket"))

s3.list_objects_v2.assert_called_once_with(Bucket="my-bucket", Prefix="", MaxKeys=1)
s3.put_object.assert_called_once_with(Bucket="my-bucket", Key=".keep", Body=b"")


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_bucket_root_trailing_slash(mock_boto3):
"""s3a://bucket/ (trailing slash, empty prefix) -> .keep at bucket root."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 0}

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/"))

s3.list_objects_v2.assert_called_once_with(Bucket="my-bucket", Prefix="", MaxKeys=1)
s3.put_object.assert_called_once_with(Bucket="my-bucket", Key=".keep", Body=b"")


# ---------------------------------------------------------------------------
# Credentials from spark config / env var fallback
# ---------------------------------------------------------------------------


@patch.dict(
"os.environ",
{
"AWS_ACCESS_KEY_ID": "env-ak",
"AWS_SECRET_ACCESS_KEY": "env-sk", # pragma: allowlist secret
"AWS_SESSION_TOKEN": "env-st",
},
)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_uses_spark_config_credentials(mock_boto3):
"""Credentials in spark config take precedence over env vars."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 1}

conf = {
**_base_conf("s3a://my-bucket/logs/"),
"spark.hadoop.fs.s3a.access.key": "spark-ak",
"spark.hadoop.fs.s3a.secret.key": "spark-sk", # pragma: allowlist secret
"spark.hadoop.fs.s3a.session.token": "spark-st",
}
_ensure_s3a_event_log_dir(conf)

mock_boto3.client.assert_called_once()
kw = mock_boto3.client.call_args.kwargs
assert kw["aws_access_key_id"] == "spark-ak"
assert kw["aws_secret_access_key"] == "spark-sk" # pragma: allowlist secret
assert kw["aws_session_token"] == "spark-st"


@patch.dict(
"os.environ",
{
"AWS_ACCESS_KEY_ID": "env-ak",
"AWS_SECRET_ACCESS_KEY": "env-sk", # pragma: allowlist secret
"AWS_SESSION_TOKEN": "env-st",
},
)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_falls_back_to_env_credentials(mock_boto3):
"""Without spark config keys, env vars are used."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 1}

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

mock_boto3.client.assert_called_once()
kw = mock_boto3.client.call_args.kwargs
assert kw["aws_access_key_id"] == "env-ak"
assert kw["aws_secret_access_key"] == "env-sk" # pragma: allowlist secret
assert kw["aws_session_token"] == "env-st"


@patch.dict("os.environ", {}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_no_credentials_passes_none(mock_boto3):
"""No credentials anywhere -> None passed to boto3 (anonymous / instance role)."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 1}

conf = {
"spark.eventLog.enabled": "true",
"spark.eventLog.dir": "s3a://my-bucket/logs/",
}
_ensure_s3a_event_log_dir(conf)

mock_boto3.client.assert_called_once()
kw = mock_boto3.client.call_args.kwargs
assert kw["aws_access_key_id"] is None
assert kw["aws_secret_access_key"] is None
assert kw["aws_session_token"] is None


# ---------------------------------------------------------------------------
# Path-style addressing (MinIO / S3-compatible)
# ---------------------------------------------------------------------------


@patch(BOTOCONFIG_PATH)
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_path_style_when_enabled(mock_boto3, mock_config_cls):
"""spark.hadoop.fs.s3a.path.style.access=true -> addressing_style='path'."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 1}

conf = {
**_base_conf("s3a://my-bucket/logs/"),
"spark.hadoop.fs.s3a.path.style.access": "true",
}
_ensure_s3a_event_log_dir(conf)

mock_config_cls.assert_called_once()
config_kwargs = mock_config_cls.call_args
assert config_kwargs.kwargs["s3"] == {"addressing_style": "path"}


@patch(BOTOCONFIG_PATH)
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_virtual_hosted_style_by_default(
mock_boto3, mock_config_cls
):
"""No path.style.access config -> addressing_style='auto'."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 1}

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

mock_config_cls.assert_called_once()
config_kwargs = mock_config_cls.call_args
assert config_kwargs.kwargs["s3"] == {"addressing_style": "auto"}


# ---------------------------------------------------------------------------
# Endpoint env var fallback (AWS_ENDPOINT_URL)
# ---------------------------------------------------------------------------


@patch.dict("os.environ", {"AWS_ENDPOINT_URL": "http://localhost:9000"}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_endpoint_from_env(mock_boto3):
"""AWS_ENDPOINT_URL env var is used when spark config has no endpoint."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 1}

conf = {
"spark.eventLog.enabled": "true",
"spark.eventLog.dir": "s3a://my-bucket/logs/",
}
_ensure_s3a_event_log_dir(conf)

mock_boto3.client.assert_called_once()
kw = mock_boto3.client.call_args.kwargs
assert kw["endpoint_url"] == "http://localhost:9000"


@patch.dict("os.environ", {"AWS_ENDPOINT_URL": "http://env-endpoint:9000"}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_spark_endpoint_over_env(mock_boto3):
"""spark.hadoop.fs.s3a.endpoint takes precedence over AWS_ENDPOINT_URL."""
s3 = MagicMock()
mock_boto3.client.return_value = s3
s3.list_objects_v2.return_value = {"KeyCount": 1}

_ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

mock_boto3.client.assert_called_once()
kw = mock_boto3.client.call_args.kwargs
assert kw["endpoint_url"] == "http://minio:9000"
Loading