Restructured code, extended existing test cases
Signed-off-by: Aniket Paluskar <apaluska@redhat.com>
aniketpalu committed Nov 23, 2025
commit 4fbf122ed3cdba4cb85aa11e9405f4c894001e7f
@@ -181,25 +181,11 @@ def get_historical_features(
# This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time.
non_entity_mode = entity_df is None
if non_entity_mode:
start_date: Optional[datetime] = kwargs.get("start_date")
end_date: Optional[datetime] = kwargs.get("end_date")

end_date = end_date or datetime.now(timezone.utc)
if start_date is None:
max_ttl_seconds = 0
for fv in feature_views:
if fv.ttl and isinstance(fv.ttl, timedelta):
max_ttl_seconds = max(
max_ttl_seconds, int(fv.ttl.total_seconds())
)
start_date = (
end_date - timedelta(seconds=max_ttl_seconds)
if max_ttl_seconds > 0
else end_date - timedelta(days=30)
)
# Why: derive bounded time window without requiring entities; uses max TTL fallback to constrain scans.
start_date, end_date = _compute_non_entity_dates(feature_views, kwargs)
entity_df_event_timestamp_range = (start_date, end_date)

# Build query contexts so we can reuse entity names and per-view table info consistently.
fv_query_contexts = offline_utils.get_feature_view_query_context(
feature_refs,
feature_views,
@@ -209,60 +195,25 @@ def get_historical_features(
)

# Collect the union of entity columns required across all feature views.
all_entities: List[str] = []
for ctx in fv_query_contexts:
for e in ctx.entities:
if e not in all_entities:
all_entities.append(e)
all_entities = _gather_all_entities(fv_query_contexts)

# Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned.
start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)
per_view_selects: List[str] = []
for fv, ctx, date_format in zip(
feature_views, fv_query_contexts, date_partition_column_formats
):
from_expression = fv.batch_source.get_table_query_string()
timestamp_field = fv.batch_source.timestamp_field or "event_timestamp"
date_partition_column = fv.batch_source.date_partition_column
partition_clause = ""
if date_partition_column:
partition_clause = (
f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'"
f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'"
)
# Select all required entity columns, filling missing ones with NULL to keep UNION schemas aligned.
select_entities = []
ctx_entities_set = set(ctx.entities)
for col in all_entities:
if col in ctx_entities_set:
# Cast entity columns to STRING to guarantee UNION schema alignment across sources.
select_entities.append(f"CAST({col} AS STRING) AS {col}")
else:
select_entities.append(f"CAST(NULL AS STRING) AS {col}")

per_view_selects.append(
f"""
SELECT DISTINCT {", ".join(select_entities)}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause}
"""
)

union_query = "\nUNION DISTINCT\n".join(
[s.strip() for s in per_view_selects]
)
spark_session.sql(
f"CREATE OR REPLACE TEMPORARY VIEW {tmp_entity_df_table_name} AS {union_query}"
_create_temp_entity_union_view(
spark_session=spark_session,
tmp_view_name=tmp_entity_df_table_name,
feature_views=feature_views,
fv_query_contexts=fv_query_contexts,
start_date=start_date,
end_date=end_date,
date_partition_column_formats=date_partition_column_formats,
)

# Add a stable as-of timestamp column for PIT joins.
left_table_query_string = f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS entity_ts FROM {tmp_entity_df_table_name})"
event_timestamp_col = "entity_ts"
# Why: Keep type consistent with entity_df branch (dict KeysView[str]) to satisfy typing and downstream usage.
entity_schema_keys = cast(
KeysView[str],
{k: None for k in (all_entities + [event_timestamp_col])}.keys(),
left_table_query_string, event_timestamp_col = _make_left_table_query(
end_date=end_date, tmp_view_name=tmp_entity_df_table_name
)
entity_schema_keys = _entity_schema_keys_from(
all_entities=all_entities, event_timestamp_col=event_timestamp_col
)
else:
entity_schema = _get_entity_schema(
@@ -633,6 +584,109 @@ def get_spark_session_or_start_new_with_repoconfig(
return spark_session


def _compute_non_entity_dates(
feature_views: List[FeatureView], kwargs: Dict[str, Any]
) -> Tuple[datetime, datetime]:
# Why: bounds the scan window when no entity_df is provided using explicit dates or max TTL fallback.
start_date: Optional[datetime] = kwargs.get("start_date")
end_date: datetime = kwargs.get("end_date") or datetime.now(timezone.utc)

if start_date is None:
max_ttl_seconds = 0
for fv in feature_views:
if fv.ttl and isinstance(fv.ttl, timedelta):
max_ttl_seconds = max(max_ttl_seconds, int(fv.ttl.total_seconds()))
start_date = (
end_date - timedelta(seconds=max_ttl_seconds)
if max_ttl_seconds > 0
else end_date - timedelta(days=30)
)
return start_date, end_date
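
# Illustrative sketch (not part of this module): how the fallback window is
# derived. _FakeView is a hypothetical stand-in for feast.FeatureView; dates
# and TTLs are made up.
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional

@dataclass
class _FakeView:
    ttl: Optional[timedelta]

_end = datetime(2021, 1, 2, tzinfo=timezone.utc)
_views = [_FakeView(ttl=timedelta(days=7)), _FakeView(ttl=None)]
# The largest TTL wins; with no TTLs at all, the window falls back to 30 days.
_max_ttl = max((int(v.ttl.total_seconds()) for v in _views if v.ttl), default=0)
_start = _end - (timedelta(seconds=_max_ttl) if _max_ttl else timedelta(days=30))
assert _start == _end - timedelta(days=7)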


def _gather_all_entities(
fv_query_contexts: List[offline_utils.FeatureViewQueryContext],
) -> List[str]:
# Why: ensure a unified entity set across feature views to align UNION schemas.
all_entities: List[str] = []
for ctx in fv_query_contexts:
for e in ctx.entities:
if e not in all_entities:
all_entities.append(e)
return all_entities
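
# Illustrative sketch (not part of this module): the union preserves first-seen
# order and drops duplicates. _Ctx is a hypothetical stand-in for
# offline_utils.FeatureViewQueryContext; entity names are made up.
class _Ctx:
    def __init__(self, entities):
        self.entities = entities

_contexts = [_Ctx(["driver_id"]), _Ctx(["driver_id", "customer_id"])]
_result = []
for _ctx in _contexts:
    for _e in _ctx.entities:
        if _e not in _result:
            _result.append(_e)
assert _result == ["driver_id", "customer_id"]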


def _create_temp_entity_union_view(
spark_session: SparkSession,
tmp_view_name: str,
feature_views: List[FeatureView],
fv_query_contexts: List[offline_utils.FeatureViewQueryContext],
start_date: datetime,
end_date: datetime,
date_partition_column_formats: List[Optional[str]],
) -> None:
# Why: derive distinct entity keys observed in the time window without requiring an entity_df upfront.
start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)

# Compute the unified entity set to align schemas in the UNION.
all_entities = _gather_all_entities(fv_query_contexts)

per_view_selects: List[str] = []
for fv, ctx, date_format in zip(
feature_views, fv_query_contexts, date_partition_column_formats
):
assert isinstance(fv.batch_source, SparkSource)
from_expression = fv.batch_source.get_table_query_string()
timestamp_field = fv.batch_source.timestamp_field or "event_timestamp"
date_partition_column = fv.batch_source.date_partition_column
partition_clause = ""
if date_partition_column and date_format:
partition_clause = (
f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'"
f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'"
)

# Fill missing entity columns with NULL and cast to STRING to keep UNION schemas aligned.
select_entities: List[str] = []
ctx_entities_set = set(ctx.entities)
for col in all_entities:
if col in ctx_entities_set:
select_entities.append(f"CAST({col} AS STRING) AS {col}")
else:
select_entities.append(f"CAST(NULL AS STRING) AS {col}")

per_view_selects.append(
f"""
SELECT DISTINCT {", ".join(select_entities)}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause}
"""
)

union_query = "\nUNION DISTINCT\n".join([s.strip() for s in per_view_selects])
spark_session.sql(
f"CREATE OR REPLACE TEMPORARY VIEW {tmp_view_name} AS {union_query}"
)
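
# Illustrative sketch (not part of this module): roughly the SQL assembled
# above for two hypothetical sources -- a driver table partitioned with %Y%m%d
# and an unpartitioned driver+customer table. Table and column names and the
# timestamp rendering of _format_datetime are assumptions. NULL-filled, STRING-
# cast entity columns keep the two SELECT schemas union-compatible.
_expected_union_shape = """
SELECT DISTINCT CAST(driver_id AS STRING) AS driver_id, CAST(NULL AS STRING) AS customer_id
FROM db.drivers
WHERE event_timestamp BETWEEN TIMESTAMP('2021-01-01 00:00:00') AND TIMESTAMP('2021-01-02 00:00:00') AND effective_date >= '20210101' AND effective_date <= '20210102'
UNION DISTINCT
SELECT DISTINCT CAST(driver_id AS STRING) AS driver_id, CAST(customer_id AS STRING) AS customer_id
FROM db.orders
WHERE event_timestamp BETWEEN TIMESTAMP('2021-01-01 00:00:00') AND TIMESTAMP('2021-01-02 00:00:00')
"""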


def _make_left_table_query(end_date: datetime, tmp_view_name: str) -> Tuple[str, str]:
# Why: use a stable as-of timestamp for PIT joins when no entity timestamps are provided.
event_timestamp_col = "entity_ts"
left_table_query_string = (
f"(SELECT *, TIMESTAMP('{_format_datetime(end_date)}') AS {event_timestamp_col} "
f"FROM {tmp_view_name})"
)
return left_table_query_string, event_timestamp_col
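
# Illustrative sketch (not part of this module): the left-table subquery pins
# one shared as-of timestamp onto every distinct entity row, so the point-in-
# time join picks each feature view's latest row at or before end_date (within
# TTL). The view name and timestamp rendering below are assumptions.
_tmp_view = "feast_entity_df_tmp"
_left = f"(SELECT *, TIMESTAMP('2021-01-02 00:00:00') AS entity_ts FROM {_tmp_view})"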


def _entity_schema_keys_from(
all_entities: List[str], event_timestamp_col: str
) -> KeysView[str]:
# Why: pass a KeysView[str] to PIT query builder to match entity_df branch typing.
return cast(KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys())
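
# Illustrative sketch (not part of this module): dicts preserve insertion
# order, so the KeysView enumerates entities first and the timestamp column
# last, matching the shape the entity_df branch produces. Names are made up.
from typing import KeysView, cast

_entities = ["driver_id", "customer_id"]
_keys = cast(KeysView[str], {k: None for k in (_entities + ["entity_ts"])}.keys())
assert list(_keys) == ["driver_id", "customer_id", "entity_ts"]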


def _get_entity_df_event_timestamp_range(
entity_df: Union[pd.DataFrame, str],
entity_df_event_timestamp_col: str,
@@ -339,3 +339,168 @@ def _mock_entity():
value_type=ValueType.INT64,
)
]


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_get_historical_features_non_entity_with_date_range(mock_get_spark_session):
mock_spark_session = MagicMock()
# Return a mocked DataFrame from every sql() call; the RetrievalJob consumes the final call's result
final_df = MagicMock()
expected_pdf = pd.DataFrame([{"feature1": 1.0, "feature2": 2.0}])
final_df.toPandas.return_value = expected_pdf
mock_spark_session.sql.return_value = final_df
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source1 = SparkSource(
name="test_nested_batch_source1",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name1",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
date_partition_column="effective_date",
date_partition_column_format="%Y%m%d",
)

test_data_source2 = SparkSource(
name="test_nested_batch_source2",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name2",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
date_partition_column="effective_date",
)

test_feature_view1 = FeatureView(
name="test_feature_view1",
entities=_mock_entity(),
schema=[
Field(name="feature1", dtype=Float32),
],
source=test_data_source1,
)

test_feature_view2 = FeatureView(
name="test_feature_view2",
entities=_mock_entity(),
schema=[
Field(name="feature2", dtype=Float32),
],
source=test_data_source2,
)

mock_registry = MagicMock()
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)
retrieval_job = SparkOfflineStore.get_historical_features(
config=test_repo_config,
feature_views=[test_feature_view2, test_feature_view1],
feature_refs=["test_feature_view2:feature2", "test_feature_view1:feature1"],
entity_df=None,
registry=mock_registry,
project="test_project",
start_date=start_date,
end_date=end_date,
)

# Verify the query is bounded by end_date in both partition date formats used by the two sources
query = retrieval_job.query
assert "effective_date <= '2021-01-02'" in query
assert "effective_date <= '20210102'" in query

# Verify data: the mocked Spark DataFrame flows through to Pandas
pdf = retrieval_job._to_df_internal()
assert pdf.equals(expected_pdf)


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_get_historical_features_non_entity_with_only_end_date(mock_get_spark_session):
mock_spark_session = MagicMock()
final_df = MagicMock()
expected_pdf = pd.DataFrame([{"feature1": 10.0, "feature2": 20.0}])
final_df.toPandas.return_value = expected_pdf
mock_spark_session.sql.return_value = final_df
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source1 = SparkSource(
name="test_nested_batch_source1",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name1",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
date_partition_column="effective_date",
date_partition_column_format="%Y%m%d",
)

test_data_source2 = SparkSource(
name="test_nested_batch_source2",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name2",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
date_partition_column="effective_date",
)

test_feature_view1 = FeatureView(
name="test_feature_view1",
entities=_mock_entity(),
schema=[
Field(name="feature1", dtype=Float32),
],
source=test_data_source1,
)

test_feature_view2 = FeatureView(
name="test_feature_view2",
entities=_mock_entity(),
schema=[
Field(name="feature2", dtype=Float32),
],
source=test_data_source2,
)

mock_registry = MagicMock()
end_date = datetime(2021, 1, 2)
retrieval_job = SparkOfflineStore.get_historical_features(
config=test_repo_config,
feature_views=[test_feature_view2, test_feature_view1],
feature_refs=["test_feature_view2:feature2", "test_feature_view1:feature1"],
@aniketpalu Did we validate the join works from multiple feature views in a feature service?

@jyejare We have tested this scenario and even added it in test cases.

entity_df=None,
registry=mock_registry,
project="test_project",
end_date=end_date,
)

# Verify the query is correctly bounded by end_date for both sources
query = retrieval_job.query
assert "effective_date <= '2021-01-02'" in query
assert "effective_date <= '20210102'" in query

# Verify data: mocked DataFrame flows to Pandas
pdf = retrieval_job._to_df_internal()
assert pdf.equals(expected_pdf)