Skip to content
Next Next commit
feat: Offline Store historical features retrieval based on datetime range for spark

Signed-off-by: Aniket Paluskar <apaluska@redhat.com>
  • Loading branch information
aniketpalu committed Nov 13, 2025
commit 6ae5adaf417086518eb74b42ee6ba56491df15ea
3 changes: 3 additions & 0 deletions sdk/python/feast/arrow_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def wrapper(*args, **kwargs):
except Exception as e:
if isinstance(e, FeastError):
raise fl.FlightError(e.to_error_detail())
# Re-raise non-Feast exceptions so Arrow Flight returns a proper error
# instead of allowing the server method to return None.
raise e

return wrapper

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import uuid
import warnings
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
KeysView,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -151,10 +152,11 @@ def get_historical_features(
config: RepoConfig,
feature_views: List[FeatureView],
feature_refs: List[str],
entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
entity_df: Optional[Union[pandas.DataFrame, str, pyspark.sql.DataFrame]],
registry: BaseRegistry,
project: str,
full_feature_names: bool = False,
**kwargs,
) -> RetrievalJob:
assert isinstance(config.offline_store, SparkOfflineStoreConfig)
date_partition_column_formats = []
Expand All @@ -175,33 +177,124 @@ def get_historical_features(
)
tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name()

entity_schema = _get_entity_schema(
spark_session=spark_session,
entity_df=entity_df,
)
event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema=entity_schema,
)
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df,
event_timestamp_col,
spark_session,
)
_upload_entity_df(
spark_session=spark_session,
table_name=tmp_entity_df_table_name,
entity_df=entity_df,
event_timestamp_col=event_timestamp_col,
)
# Non-entity mode: synthesize a left table and timestamp range from start/end dates to avoid requiring entity_df.
# This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time.
non_entity_mode = entity_df is None
if non_entity_mode:
start_date: Optional[datetime] = kwargs.get("start_date")
end_date: Optional[datetime] = kwargs.get("end_date")

end_date = end_date or datetime.now(timezone.utc)
if start_date is None:
max_ttl_seconds = 0
for fv in feature_views:
if fv.ttl and isinstance(fv.ttl, timedelta):
max_ttl_seconds = max(
max_ttl_seconds, int(fv.ttl.total_seconds())
)
start_date = (
end_date - timedelta(seconds=max_ttl_seconds)
if max_ttl_seconds > 0
else end_date - timedelta(days=30)
)

expected_join_keys = offline_utils.get_expected_join_keys(
project=project, feature_views=feature_views, registry=registry
)
offline_utils.assert_expected_columns_in_entity_df(
entity_schema=entity_schema,
join_keys=expected_join_keys,
entity_df_event_timestamp_col=event_timestamp_col,
)
# Build query contexts so we can reuse entity names and per-view table info consistently.
entity_df_event_timestamp_range = (start_date, end_date)
fv_query_contexts = offline_utils.get_feature_view_query_context(
feature_refs,
feature_views,
registry,
project,
entity_df_event_timestamp_range,
)

# Collect the union of entity columns required across all feature views.
all_entities: List[str] = []
for ctx in fv_query_contexts:
for e in ctx.entities:
if e not in all_entities:
all_entities.append(e)

# Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned.
start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)
per_view_selects: List[str] = []
for fv, ctx, date_format in zip(
feature_views, fv_query_contexts, date_partition_column_formats
):
from_expression = fv.batch_source.get_table_query_string()
timestamp_field = fv.batch_source.timestamp_field or "event_timestamp"
date_partition_column = fv.batch_source.date_partition_column
partition_clause = ""
if date_partition_column:
partition_clause = (
f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'"
f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'"
)
# Select all required entity columns, filling missing ones with NULL to keep UNION schemas aligned.
select_entities = []
ctx_entities_set = set(ctx.entities)
for col in all_entities:
if col in ctx_entities_set:
# Cast entity columns to STRING to guarantee UNION schema alignment across sources.
select_entities.append(f"CAST({col} AS STRING) AS {col}")
else:
select_entities.append(f"CAST(NULL AS STRING) AS {col}")

per_view_selects.append(
f"""
SELECT DISTINCT {", ".join(select_entities)}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause}
"""
)

union_query = "\nUNION DISTINCT\n".join(
[s.strip() for s in per_view_selects]
)
spark_session.sql(
f"CREATE OR REPLACE TEMPORARY VIEW {tmp_entity_df_table_name} AS {union_query}"
)

# Add a stable as-of timestamp column for PIT joins.
left_table_query_string = f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS entity_ts FROM {tmp_entity_df_table_name})"
event_timestamp_col = "entity_ts"
# Why: Keep type consistent with entity_df branch (dict KeysView[str]) to satisfy typing and downstream usage.
entity_schema_keys = cast(
KeysView[str],
{k: None for k in (all_entities + [event_timestamp_col])}.keys(),
)
else:
entity_schema = _get_entity_schema(
spark_session=spark_session,
entity_df=entity_df,
)
event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema=entity_schema,
)
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df,
event_timestamp_col,
spark_session,
)
_upload_entity_df(
spark_session=spark_session,
table_name=tmp_entity_df_table_name,
entity_df=entity_df,
event_timestamp_col=event_timestamp_col,
)
left_table_query_string = tmp_entity_df_table_name
entity_schema_keys = cast(KeysView[str], entity_schema.keys())

if not non_entity_mode:
expected_join_keys = offline_utils.get_expected_join_keys(
project=project, feature_views=feature_views, registry=registry
)
offline_utils.assert_expected_columns_in_entity_df(
entity_schema=entity_schema,
join_keys=expected_join_keys,
entity_df_event_timestamp_col=event_timestamp_col,
)

query_context = offline_utils.get_feature_view_query_context(
feature_refs,
Expand Down Expand Up @@ -232,9 +325,9 @@ def get_historical_features(
feature_view_query_contexts=cast(
List[offline_utils.FeatureViewQueryContext], spark_query_context
),
left_table_query_string=tmp_entity_df_table_name,
left_table_query_string=left_table_query_string,
entity_df_event_timestamp_col=event_timestamp_col,
entity_df_columns=entity_schema.keys(),
entity_df_columns=entity_schema_keys,
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
)
Expand All @@ -248,7 +341,7 @@ def get_historical_features(
),
metadata=RetrievalMetadata(
features=feature_refs,
keys=list(set(entity_schema.keys()) - {event_timestamp_col}),
keys=list(set(entity_schema_keys) - {event_timestamp_col}),
min_event_timestamp=entity_df_event_timestamp_range[0],
max_event_timestamp=entity_df_event_timestamp_range[1],
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def get_table_query_string(self) -> str:
# If both the table query string and the actual query are null, we can load from file.
spark_session = SparkSession.getActiveSession()
if spark_session is None:
raise AssertionError("Could not find an active spark session.")
# Remote mode may not have an active session bound to the thread; create one on demand.
spark_session = SparkSession.builder.getOrCreate()
try:
df = self._load_dataframe_from_path(spark_session)
except Exception:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch

from feast.entity import Entity
from feast.feature_view import FeatureView, Field
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkOfflineStore,
SparkOfflineStoreConfig,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
SparkSource,
)
from feast.repo_config import RepoConfig
from feast.types import Float32, ValueType


def _mock_spark_offline_store_config():
    """Build the minimal Spark offline store config used by these unit tests."""
    config = SparkOfflineStoreConfig(type="spark")
    return config


def _mock_entity():
    """Return the single test entity (user_id) wrapped in a list."""
    user_entity = Entity(
        name="user_id",
        join_keys=["user_id"],
        description="User ID",
        value_type=ValueType.INT64,
    )
    return [user_entity]


def _mock_feature_view():
    """Construct a FeatureView backed by a date-partitioned Spark table source."""
    batch_source = SparkSource(
        name="user_stats_source",
        table="default.user_stats",
        timestamp_field="event_timestamp",
        date_partition_column="ds",
        date_partition_column_format="%Y-%m-%d",
    )
    return FeatureView(
        name="user_stats",
        entities=_mock_entity(),
        schema=[Field(name="metric", dtype=Float32)],
        source=batch_source,
    )


@patch(
    "feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_spark_non_entity_historical_retrieval_accepts_dates(mock_get_spark_session):
    """get_historical_features should accept entity_df=None plus start/end dates."""
    from feast.infra.offline_stores.offline_store import RetrievalJob

    # Stub out the Spark session factory so no real Spark SQL executes
    # against non-existent tables during the unit test.
    fake_session = MagicMock()
    mock_get_spark_session.return_value = fake_session

    config = RepoConfig(
        project="test_project",
        registry="test_registry",
        provider="local",
        offline_store=_mock_spark_offline_store_config(),
    )
    view = _mock_feature_view()

    job = SparkOfflineStore.get_historical_features(
        config=config,
        feature_views=[view],
        feature_refs=["user_stats:metric"],
        entity_df=None,  # start/end-only mode
        registry=MagicMock(),
        project="test_project",
        full_feature_names=False,
        start_date=datetime(2023, 1, 1, tzinfo=timezone.utc),
        end_date=datetime(2023, 1, 2, tzinfo=timezone.utc),
    )

    assert isinstance(job, RetrievalJob)
Loading