Skip to content

Commit c9aca2d

Browse files
authored
fix: Invalid column names in get_historical_features when there are field mappings on join keys (#4886)
* Map join key to original column name in field mapping. Signed-off-by: Aloysius Lim <aloysius.lim@cuezen.com>
* Add test. Signed-off-by: Aloysius Lim <aloysius.lim@cuezen.com>
* Format. Signed-off-by: Aloysius Lim <aloysius.lim@cuezen.com>

---------

Signed-off-by: Aloysius Lim <aloysius.lim@cuezen.com>
1 parent 6607d3d commit c9aca2d

File tree

2 files changed

+106
-5
lines changed

2 files changed

+106
-5
lines changed

sdk/python/feast/infra/offline_stores/offline_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,27 @@ def get_feature_view_query_context(
118118

119119
query_context = []
120120
for feature_view, features in feature_views_to_feature_map.items():
121+
reverse_field_mapping = {
122+
v: k for k, v in feature_view.batch_source.field_mapping.items()
123+
}
124+
121125
join_keys: List[str] = []
122126
entity_selections: List[str] = []
123127
for entity_column in feature_view.entity_columns:
124128
join_key = feature_view.projection.join_key_map.get(
125129
entity_column.name, entity_column.name
126130
)
127131
join_keys.append(join_key)
128-
entity_selections.append(f"{entity_column.name} AS {join_key}")
132+
entity_selections.append(
133+
f"{reverse_field_mapping.get(entity_column.name, entity_column.name)} "
134+
f"AS {join_key}"
135+
)
129136

130137
if isinstance(feature_view.ttl, timedelta):
131138
ttl_seconds = int(feature_view.ttl.total_seconds())
132139
else:
133140
ttl_seconds = 0
134141

135-
reverse_field_mapping = {
136-
v: k for k, v in feature_view.batch_source.field_mapping.items()
137-
}
138142
features = [reverse_field_mapping.get(feature, feature) for feature in features]
139143
timestamp_field = reverse_field_mapping.get(
140144
feature_view.batch_source.timestamp_field,

sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from feast.infra.offline_stores.offline_utils import (
1515
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
1616
)
17-
from feast.types import Float32, Int32
17+
from feast.types import Float32, Int32, String
1818
from feast.utils import _utc_now
1919
from tests.integration.feature_repos.repo_configuration import (
2020
construct_universal_feature_views,
@@ -639,3 +639,100 @@ def test_historical_features_containing_backfills(environment):
639639
actual_df,
640640
sort_by=["driver_id"],
641641
)
642+
643+
644+
@pytest.mark.integration
@pytest.mark.universal_offline_stores
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_historical_features_field_mapping(
    environment, universal_data_sources, full_feature_names
):
    """Historical retrieval when the batch source renames a source column onto
    the entity join key via ``field_mapping`` (``{"id": "driver_id"}``).

    Regression test for invalid column names produced by
    ``get_historical_features`` when a field mapping applies to a join key.

    Args:
        environment: universal test environment fixture providing the feature
            store and offline data source creator.
        universal_data_sources: standard fixture (unused here; the test builds
            its own source so it can control the field mapping).
        full_feature_names: parametrized flag forwarded to
            ``get_historical_features``; also changes the expected feature
            column name.
    """
    store = environment.feature_store

    now = datetime.now().replace(microsecond=0, second=0, minute=0)
    tomorrow = now + timedelta(days=1)
    day_after_tomorrow = now + timedelta(days=2)

    # Entity dataframe uses the join key name "driver_id".
    entity_df = pd.DataFrame(
        data=[
            {"driver_id": 1001, "event_timestamp": day_after_tomorrow},
            {"driver_id": 1002, "event_timestamp": day_after_tomorrow},
        ]
    )

    # Source data deliberately uses "id" instead of "driver_id"; the
    # field_mapping on the data source below performs the rename.
    driver_stats_df = pd.DataFrame(
        data=[
            {
                "id": 1001,
                "avg_daily_trips": 20,
                "event_timestamp": now,
                "created": tomorrow,
            },
            {
                "id": 1002,
                "avg_daily_trips": 40,
                "event_timestamp": tomorrow,
                "created": now,
            },
        ]
    )

    # With full feature names enabled the output column is prefixed with the
    # feature view name ("<view>__<feature>").
    feature_col = (
        "driver_stats__avg_daily_trips" if full_feature_names else "avg_daily_trips"
    )
    expected_df = pd.DataFrame(
        data=[
            {
                "driver_id": 1001,
                "event_timestamp": day_after_tomorrow,
                feature_col: 20,
            },
            {
                "driver_id": 1002,
                "event_timestamp": day_after_tomorrow,
                feature_col: 40,
            },
        ]
    )

    driver_stats_data_source = environment.data_source_creator.create_data_source(
        df=driver_stats_df,
        destination_name=f"test_driver_stats_{int(time.time_ns())}_{random.randint(1000, 9999)}",
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
        # Map original "id" column to "driver_id" join key
        field_mapping={"id": "driver_id"},
    )

    driver = Entity(name="driver", join_keys=["driver_id"])
    driver_fv = FeatureView(
        name="driver_stats",
        entities=[driver],
        schema=[
            Field(name="driver_id", dtype=String),
            Field(name="avg_daily_trips", dtype=Int32),
        ],
        source=driver_stats_data_source,
    )

    store.apply([driver, driver_fv])

    offline_job = store.get_historical_features(
        entity_df=entity_df,
        features=["driver_stats:avg_daily_trips"],
        # Fix: forward the parametrized flag instead of hard-coding False,
        # otherwise the full_feature_names=True case exercised nothing new.
        full_feature_names=full_feature_names,
    )

    start_time = _utc_now()
    actual_df = offline_job.to_df()
    end_time = _utc_now()

    print(f"actual_df shape: {actual_df.shape}")
    print(str(f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n"))

    assert sorted(expected_df.columns) == sorted(actual_df.columns)
    validate_dataframes(
        expected_df,
        actual_df,
        sort_by=["driver_id"],
    )

0 commit comments

Comments
 (0)