Skip to content

Commit b60f6e4

Browse files
jklegar and woop authored
Support join keys in historical feature retrieval (feast-dev#1440)
* Support join keys in historical feature retrieval
  Signed-off-by: Willem Pienaar <git@willem.co>
* Rebase join key support
  Signed-off-by: Willem Pienaar <git@willem.co>
* Remove unused methods
  Signed-off-by: Willem Pienaar <git@willem.co>
Co-authored-by: Willem Pienaar <git@willem.co>
1 parent 22bf06c commit b60f6e4

File tree

10 files changed

+49
-13
lines changed

10 files changed

+49
-13
lines changed

sdk/python/feast/feature_store.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,12 @@ def get_historical_features(
274274
feature_views = _get_requested_feature_views(feature_refs, all_feature_views)
275275
provider = self._get_provider()
276276
job = provider.get_historical_features(
277-
self.config, feature_views, feature_refs, entity_df
277+
self.config,
278+
feature_views,
279+
feature_refs,
280+
entity_df,
281+
self._registry,
282+
self.project,
278283
)
279284
return job
280285

sdk/python/feast/infra/gcp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def get_historical_features(
185185
feature_views: List[FeatureView],
186186
feature_refs: List[str],
187187
entity_df: Union[pandas.DataFrame, str],
188+
registry: Registry,
189+
project: str,
188190
) -> RetrievalJob:
189191
offline_store = get_offline_store_from_sources(
190192
[feature_view.input for feature_view in feature_views]
@@ -194,6 +196,8 @@ def get_historical_features(
194196
feature_views=feature_views,
195197
feature_refs=feature_refs,
196198
entity_df=entity_df,
199+
registry=registry,
200+
project=project,
197201
)
198202
return job
199203

sdk/python/feast/infra/local.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def get_historical_features(
196196
feature_views: List[FeatureView],
197197
feature_refs: List[str],
198198
entity_df: Union[pd.DataFrame, str],
199+
registry: Registry,
200+
project: str,
199201
) -> RetrievalJob:
200202
offline_store = get_offline_store_from_sources(
201203
[feature_view.input for feature_view in feature_views]
@@ -205,6 +207,8 @@ def get_historical_features(
205207
feature_views=feature_views,
206208
feature_refs=feature_refs,
207209
entity_df=entity_df,
210+
registry=registry,
211+
project=project,
208212
)
209213

210214

sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
RetrievalJob,
1616
_get_requested_feature_views_to_features_dict,
1717
)
18+
from feast.registry import Registry
1819
from feast.repo_config import RepoConfig
1920

2021

@@ -70,6 +71,8 @@ def get_historical_features(
7071
feature_views: List[FeatureView],
7172
feature_refs: List[str],
7273
entity_df: Union[pandas.DataFrame, str],
74+
registry: Registry,
75+
project: str,
7376
) -> RetrievalJob:
7477
# TODO: Add entity_df validation in order to fail before interacting with BigQuery
7578

@@ -85,7 +88,9 @@ def get_historical_features(
8588
)
8689

8790
# Build a query context containing all information required to template the BigQuery SQL query
88-
query_context = get_feature_view_query_context(feature_refs, feature_views)
91+
query_context = get_feature_view_query_context(
92+
feature_refs, feature_views, registry, project
93+
)
8994

9095
# TODO: Infer min_timestamp and max_timestamp from entity_df
9196
# Generate the BigQuery SQL query from the query context
@@ -155,7 +160,10 @@ def _upload_entity_df_into_bigquery(project, entity_df) -> str:
155160

156161

157162
def get_feature_view_query_context(
158-
feature_refs: List[str], feature_views: List[FeatureView]
163+
feature_refs: List[str],
164+
feature_views: List[FeatureView],
165+
registry: Registry,
166+
project: str,
159167
) -> List[FeatureViewQueryContext]:
160168
"""Build a query context containing all information required to template a BigQuery point-in-time SQL query"""
161169

@@ -165,7 +173,10 @@ def get_feature_view_query_context(
165173

166174
query_context = []
167175
for feature_view, features in feature_views_to_feature_map.items():
168-
entity_names = [entity for entity in feature_view.entities]
176+
join_keys = []
177+
for entity_name in feature_view.entities:
178+
entity = registry.get_entity(entity_name, project)
179+
join_keys.append(entity.join_key)
169180

170181
if isinstance(feature_view.ttl, timedelta):
171182
ttl_seconds = int(feature_view.ttl.total_seconds())
@@ -177,7 +188,7 @@ def get_feature_view_query_context(
177188
context = FeatureViewQueryContext(
178189
name=feature_view.name,
179190
ttl=ttl_seconds,
180-
entities=entity_names,
191+
entities=join_keys,
181192
features=features,
182193
table_ref=feature_view.input.table_ref,
183194
event_timestamp_column=feature_view.input.event_timestamp_column,

sdk/python/feast/infra/offline_stores/file.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ENTITY_DF_EVENT_TIMESTAMP_COL,
1313
_get_requested_feature_views_to_features_dict,
1414
)
15+
from feast.registry import Registry
1516
from feast.repo_config import RepoConfig
1617

1718

@@ -35,6 +36,8 @@ def get_historical_features(
3536
feature_views: List[FeatureView],
3637
feature_refs: List[str],
3738
entity_df: Union[pd.DataFrame, str],
39+
registry: Registry,
40+
project: str,
3841
) -> FileRetrievalJob:
3942
if not isinstance(entity_df, pd.DataFrame):
4043
raise ValueError(
@@ -80,7 +83,11 @@ def evaluate_historical_retrieval():
8083
)
8184

8285
# Build a list of entity columns to join on (from the right table)
83-
right_entity_columns = [entity for entity in feature_view.entities]
86+
join_keys = []
87+
for entity_name in feature_view.entities:
88+
entity = registry.get_entity(entity_name, project)
89+
join_keys.append(entity.join_key)
90+
right_entity_columns = join_keys
8491
right_entity_key_columns = [
8592
event_timestamp_column
8693
] + right_entity_columns

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from feast.data_source import DataSource
2222
from feast.feature_view import FeatureView
23+
from feast.registry import Registry
2324
from feast.repo_config import RepoConfig
2425

2526

@@ -63,5 +64,7 @@ def get_historical_features(
6364
feature_views: List[FeatureView],
6465
feature_refs: List[str],
6566
entity_df: Union[pd.DataFrame, str],
67+
registry: Registry,
68+
project: str,
6669
) -> RetrievalJob:
6770
pass

sdk/python/feast/infra/provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def get_historical_features(
9999
feature_views: List[FeatureView],
100100
feature_refs: List[str],
101101
entity_df: Union[pandas.DataFrame, str],
102+
registry: Registry,
103+
project: str,
102104
) -> RetrievalJob:
103105
pass
104106

sdk/python/feast/registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]
105105
entities.append(Entity.from_proto(entity_proto))
106106
return entities
107107

108-
def get_entity(self, name: str, project: str) -> Entity:
108+
def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity:
109109
"""
110110
Retrieves an entity.
111111
@@ -117,7 +117,7 @@ def get_entity(self, name: str, project: str) -> Entity:
117117
Returns either the specified entity, or raises an exception if
118118
none is found
119119
"""
120-
registry_proto = self._get_registry_proto()
120+
registry_proto = self._get_registry_proto(allow_cache=allow_cache)
121121
for entity_proto in registry_proto.entities:
122122
if entity_proto.spec.name == name and entity_proto.spec.project == project:
123123
return Entity.from_proto(entity_proto)

sdk/python/tests/test_historical_retrieval.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def stage_driver_hourly_stats_bigquery_source(df, table_id):
5757
def create_driver_hourly_stats_feature_view(source):
5858
driver_stats_feature_view = FeatureView(
5959
name="driver_stats",
60-
entities=["driver_id"],
60+
entities=["driver"],
6161
features=[
6262
Feature(name="conv_rate", dtype=ValueType.FLOAT),
6363
Feature(name="acc_rate", dtype=ValueType.FLOAT),
@@ -226,8 +226,8 @@ def test_historical_features_from_parquet_sources():
226226
temp_dir, customer_df
227227
)
228228
customer_fv = create_customer_daily_profile_feature_view(customer_source)
229-
driver = Entity(name="driver", value_type=ValueType.INT64)
230-
customer = Entity(name="customer", value_type=ValueType.INT64)
229+
driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
230+
customer = Entity(name="customer_id", value_type=ValueType.INT64)
231231

232232
store = FeatureStore(
233233
config=RepoConfig(
@@ -331,8 +331,8 @@ def test_historical_features_from_bigquery_sources(provider_type):
331331
)
332332
customer_fv = create_customer_daily_profile_feature_view(customer_source)
333333

334-
driver = Entity(name="driver", value_type=ValueType.INT64)
335-
customer = Entity(name="customer", value_type=ValueType.INT64)
334+
driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64)
335+
customer = Entity(name="customer_id", value_type=ValueType.INT64)
336336

337337
if provider_type == "local":
338338
store = FeatureStore(

sdk/python/tests/test_materialize_from_bigquery_to_datastore.py

Whitespace-only changes.

0 commit comments

Comments
 (0)