Skip to content

Commit c7a4664

Browse files
Add support for multiple entities in Redshift (#1850)
* Modify universal historical retrieval integration test to use multiple entities
  Signed-off-by: Felix Wang <wangfelix98@gmail.com>
* Update Redshift historical retrieval query to allow for multiple entities
  Signed-off-by: Felix Wang <wangfelix98@gmail.com>
* Modify universal online retrieval integration test to use multiple entities
  Signed-off-by: Felix Wang <wangfelix98@gmail.com>
* Modify Redis online store to serialize entity keys properly
  Signed-off-by: Felix Wang <wangfelix98@gmail.com>
* Fix bug in universal online tests
  Signed-off-by: Felix Wang <wangfelix98@gmail.com>
* Small fixes
  Signed-off-by: Felix Wang <wangfelix98@gmail.com>
* Small fix
  Signed-off-by: Felix Wang <wangfelix98@gmail.com>
1 parent bf5dc7d commit c7a4664

6 files changed

Lines changed: 151 additions & 67 deletions

File tree

sdk/python/feast/infra/offline_stores/redshift.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -377,9 +377,9 @@ def _upload_entity_df_and_get_entity_schema(
377377
{{entity_df_event_timestamp_col}} AS entity_timestamp
378378
{% for featureview in featureviews %}
379379
{% if featureview.entities %}
380-
,CONCAT(
380+
,(
381381
{% for entity in featureview.entities %}
382-
CAST({{entity}} AS VARCHAR),
382+
CAST({{entity}} as VARCHAR) ||
383383
{% endfor %}
384384
CAST({{entity_df_event_timestamp_col}} AS VARCHAR)
385385
) AS {{featureview.name}}__entity_row_unique_id

sdk/python/feast/infra/online_stores/helpers.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,12 @@
11
import importlib
22
import struct
3-
from typing import Any
3+
from typing import Any, List
44

55
import mmh3
66

77
from feast import errors
88
from feast.infra.key_encoding_utils import serialize_entity_key
99
from feast.infra.online_stores.online_store import OnlineStore
10-
from feast.protos.feast.storage.Redis_pb2 import RedisKeyV2 as RedisKeyProto
1110
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
1211

1312

@@ -37,13 +36,9 @@ def get_online_store_from_config(online_store_config: Any,) -> OnlineStore:
3736
return online_store_class()
3837

3938

40-
def _redis_key(project: str, entity_key: EntityKeyProto):
41-
redis_key = RedisKeyProto(
42-
project=project,
43-
entity_names=entity_key.join_keys,
44-
entity_values=entity_key.entity_values,
45-
)
46-
return redis_key.SerializeToString()
39+
def _redis_key(project: str, entity_key: EntityKeyProto) -> bytes:
40+
key: List[bytes] = [serialize_entity_key(entity_key), project.encode("utf-8")]
41+
return b"".join(key)
4742

4843

4944
def _mmh3(key: str):

sdk/python/tests/integration/feature_repos/repo_configuration.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
create_customer_daily_profile_feature_view,
2929
create_driver_hourly_stats_feature_view,
3030
create_global_stats_feature_view,
31+
create_order_feature_view,
3132
)
3233

3334

@@ -94,17 +95,19 @@ def construct_universal_datasets(
9495
orders_df = driver_test_data.create_orders_df(
9596
customers=entities["customer"],
9697
drivers=entities["driver"],
97-
start_date=end_time - timedelta(days=3),
98-
end_date=end_time + timedelta(days=3),
98+
start_date=start_time,
99+
end_date=end_time,
99100
order_count=20,
100101
)
101102
global_df = driver_test_data.create_global_daily_stats_df(start_time, end_time)
103+
entity_df = orders_df[["customer_id", "driver_id", "order_id", "event_timestamp"]]
102104

103105
return {
104106
"customer": customer_df,
105107
"driver": driver_df,
106108
"orders": orders_df,
107109
"global": global_df,
110+
"entity": entity_df,
108111
}
109112

110113

@@ -127,7 +130,7 @@ def construct_universal_data_sources(
127130
datasets["orders"],
128131
destination_name="orders",
129132
event_timestamp_column="event_timestamp",
130-
created_timestamp_column="created",
133+
created_timestamp_column=None,
131134
)
132135
global_ds = data_source_creator.create_data_source(
133136
datasets["global"],
@@ -161,6 +164,7 @@ def construct_universal_feature_views(
161164
"input_request": create_conv_rate_request_data_source(),
162165
}
163166
),
167+
"order": create_order_feature_view(data_sources["orders"]),
164168
}
165169

166170

sdk/python/tests/integration/feature_repos/universal/feature_views.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -117,3 +117,15 @@ def create_global_stats_feature_view(source, infer_features: bool = False):
117117
ttl=timedelta(days=2),
118118
)
119119
return global_stats_feature_view
120+
121+
122+
def create_order_feature_view(source, infer_features: bool = False):
123+
return FeatureView(
124+
name="order",
125+
entities=["driver", "customer_id"],
126+
features=None
127+
if infer_features
128+
else [Feature(name="order_is_success", dtype=ValueType.INT32)],
129+
batch_source=source,
130+
ttl=timedelta(days=2),
131+
)

sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py

Lines changed: 83 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Any, Dict, List
2+
from typing import Any, Dict, List, Optional
33

44
import numpy as np
55
import pandas as pd
@@ -37,14 +37,23 @@ def find_asof_record(
3737
ts_key: str,
3838
ts_start: datetime,
3939
ts_end: datetime,
40-
filter_key: str = "",
41-
filter_value: Any = None,
40+
filter_keys: Optional[List[str]] = None,
41+
filter_values: Optional[List[Any]] = None,
4242
) -> Dict[str, Any]:
43+
filter_keys = filter_keys or []
44+
filter_values = filter_values or []
45+
assert len(filter_keys) == len(filter_values)
4346
found_record = {}
4447
for record in records:
4548
if (
46-
not filter_key or record[filter_key] == filter_value
47-
) and ts_start <= record[ts_key] <= ts_end:
49+
all(
50+
[
51+
record[filter_key] == filter_value
52+
for filter_key, filter_value in zip(filter_keys, filter_values)
53+
]
54+
)
55+
and ts_start <= record[ts_key] <= ts_end
56+
):
4857
if not found_record or found_record[ts_key] < record[ts_key]:
4958
found_record = record
5059
return found_record
@@ -55,43 +64,57 @@ def get_expected_training_df(
5564
customer_fv: FeatureView,
5665
driver_df: pd.DataFrame,
5766
driver_fv: FeatureView,
67+
orders_df: pd.DataFrame,
68+
order_fv: FeatureView,
5869
global_df: pd.DataFrame,
5970
global_fv: FeatureView,
60-
orders_df: pd.DataFrame,
71+
entity_df: pd.DataFrame,
6172
event_timestamp: str,
6273
full_feature_names: bool = False,
6374
):
6475
# Convert all pandas dataframes into records with UTC timestamps
65-
order_records = convert_timestamp_records_to_utc(
66-
orders_df.to_dict("records"), event_timestamp
76+
customer_records = convert_timestamp_records_to_utc(
77+
customer_df.to_dict("records"), customer_fv.batch_source.event_timestamp_column
6778
)
6879
driver_records = convert_timestamp_records_to_utc(
6980
driver_df.to_dict("records"), driver_fv.batch_source.event_timestamp_column
7081
)
71-
customer_records = convert_timestamp_records_to_utc(
72-
customer_df.to_dict("records"), customer_fv.batch_source.event_timestamp_column
82+
order_records = convert_timestamp_records_to_utc(
83+
orders_df.to_dict("records"), event_timestamp
7384
)
7485
global_records = convert_timestamp_records_to_utc(
7586
global_df.to_dict("records"), global_fv.batch_source.event_timestamp_column
7687
)
88+
entity_rows = convert_timestamp_records_to_utc(
89+
entity_df.to_dict("records"), event_timestamp
90+
)
7791

78-
# Manually do point-in-time join of orders to drivers and customers records
79-
for order_record in order_records:
92+
# Manually do point-in-time join of driver, customer, and order records against
93+
# the entity df
94+
for entity_row in entity_rows:
95+
customer_record = find_asof_record(
96+
customer_records,
97+
ts_key=customer_fv.batch_source.event_timestamp_column,
98+
ts_start=entity_row[event_timestamp] - customer_fv.ttl,
99+
ts_end=entity_row[event_timestamp],
100+
filter_keys=["customer_id"],
101+
filter_values=[entity_row["customer_id"]],
102+
)
80103
driver_record = find_asof_record(
81104
driver_records,
82105
ts_key=driver_fv.batch_source.event_timestamp_column,
83-
ts_start=order_record[event_timestamp] - driver_fv.ttl,
84-
ts_end=order_record[event_timestamp],
85-
filter_key="driver_id",
86-
filter_value=order_record["driver_id"],
106+
ts_start=entity_row[event_timestamp] - driver_fv.ttl,
107+
ts_end=entity_row[event_timestamp],
108+
filter_keys=["driver_id"],
109+
filter_values=[entity_row["driver_id"]],
87110
)
88-
customer_record = find_asof_record(
89-
customer_records,
111+
order_record = find_asof_record(
112+
order_records,
90113
ts_key=customer_fv.batch_source.event_timestamp_column,
91-
ts_start=order_record[event_timestamp] - customer_fv.ttl,
92-
ts_end=order_record[event_timestamp],
93-
filter_key="customer_id",
94-
filter_value=order_record["customer_id"],
114+
ts_start=entity_row[event_timestamp] - order_fv.ttl,
115+
ts_end=entity_row[event_timestamp],
116+
filter_keys=["customer_id", "driver_id"],
117+
filter_values=[entity_row["customer_id"], entity_row["driver_id"]],
95118
)
96119
global_record = find_asof_record(
97120
global_records,
@@ -100,15 +123,7 @@ def get_expected_training_df(
100123
ts_end=order_record[event_timestamp],
101124
)
102125

103-
order_record.update(
104-
{
105-
(f"driver_stats__{k}" if full_feature_names else k): driver_record.get(
106-
k, None
107-
)
108-
for k in ("conv_rate", "avg_daily_trips")
109-
}
110-
)
111-
order_record.update(
126+
entity_row.update(
112127
{
113128
(
114129
f"customer_profile__{k}" if full_feature_names else k
@@ -120,7 +135,21 @@ def get_expected_training_df(
120135
)
121136
}
122137
)
123-
order_record.update(
138+
entity_row.update(
139+
{
140+
(f"driver_stats__{k}" if full_feature_names else k): driver_record.get(
141+
k, None
142+
)
143+
for k in ("conv_rate", "avg_daily_trips")
144+
}
145+
)
146+
entity_row.update(
147+
{
148+
(f"order__{k}" if full_feature_names else k): order_record.get(k, None)
149+
for k in ("order_is_success",)
150+
}
151+
)
152+
entity_row.update(
124153
{
125154
(f"global_stats__{k}" if full_feature_names else k): global_record.get(
126155
k, None
@@ -130,7 +159,7 @@ def get_expected_training_df(
130159
)
131160

132161
# Convert records back to pandas dataframe
133-
expected_df = pd.DataFrame(order_records)
162+
expected_df = pd.DataFrame(entity_rows)
134163

135164
# Move "event_timestamp" column to front
136165
current_cols = expected_df.columns.tolist()
@@ -140,7 +169,7 @@ def get_expected_training_df(
140169
# Cast some columns to expected types, since we lose information when converting pandas DFs into Python objects.
141170
if full_feature_names:
142171
expected_column_types = {
143-
"order_is_success": "int32",
172+
"order__order_is_success": "int32",
144173
"driver_stats__conv_rate": "float32",
145174
"customer_profile__current_balance": "float32",
146175
"customer_profile__avg_passenger_count": "float32",
@@ -175,20 +204,23 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
175204
(entities, datasets, data_sources) = universal_data_sources
176205
feature_views = construct_universal_feature_views(data_sources)
177206

178-
customer_df, driver_df, orders_df, global_df = (
207+
customer_df, driver_df, orders_df, global_df, entity_df = (
179208
datasets["customer"],
180209
datasets["driver"],
181210
datasets["orders"],
182211
datasets["global"],
212+
datasets["entity"],
183213
)
184-
orders_df_with_request_data = orders_df.copy(deep=True)
185-
orders_df_with_request_data["val_to_add"] = [
186-
i for i in range(len(orders_df_with_request_data))
214+
entity_df_with_request_data = entity_df.copy(deep=True)
215+
entity_df_with_request_data["val_to_add"] = [
216+
i for i in range(len(entity_df_with_request_data))
187217
]
188-
customer_fv, driver_fv, driver_odfv, global_fv = (
218+
219+
customer_fv, driver_fv, driver_odfv, order_fv, global_fv = (
189220
feature_views["customer"],
190221
feature_views["driver"],
191222
feature_views["driver_odfv"],
223+
feature_views["order"],
192224
feature_views["global"],
193225
)
194226

@@ -203,6 +235,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
203235
customer_fv,
204236
driver_fv,
205237
driver_odfv,
238+
order_fv,
206239
global_fv,
207240
driver(),
208241
customer(),
@@ -214,7 +247,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
214247
entity_df_query = None
215248
orders_table = table_name_from_data_source(data_sources["orders"])
216249
if orders_table:
217-
entity_df_query = f"SELECT * FROM {orders_table}"
250+
entity_df_query = f"SELECT customer_id, driver_id, order_id, event_timestamp FROM {orders_table}"
218251

219252
event_timestamp = (
220253
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
@@ -226,9 +259,11 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
226259
customer_fv,
227260
driver_df,
228261
driver_fv,
262+
orders_df,
263+
order_fv,
229264
global_df,
230265
global_fv,
231-
orders_df_with_request_data,
266+
entity_df_with_request_data,
232267
event_timestamp,
233268
full_feature_names,
234269
)
@@ -242,6 +277,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
242277
"customer_profile:current_balance",
243278
"customer_profile:avg_passenger_count",
244279
"customer_profile:lifetime_trip_count",
280+
"order:order_is_success",
245281
"global_stats:num_rides",
246282
"global_stats:avg_ride_length",
247283
],
@@ -297,7 +333,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
297333
assert_frame_equal(expected_df_query, df_from_sql_entities)
298334

299335
job_from_df = store.get_historical_features(
300-
entity_df=orders_df_with_request_data,
336+
entity_df=entity_df_with_request_data,
301337
features=[
302338
"driver_stats:conv_rate",
303339
"driver_stats:avg_daily_trips",
@@ -306,6 +342,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
306342
"customer_profile:lifetime_trip_count",
307343
"conv_rate_plus_100:conv_rate_plus_100",
308344
"conv_rate_plus_100:conv_rate_plus_val_to_add",
345+
"order:order_is_success",
309346
"global_stats:num_rides",
310347
"global_stats:avg_ride_length",
311348
],
@@ -341,7 +378,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
341378
store,
342379
feature_service,
343380
full_feature_names,
344-
orders_df_with_request_data,
381+
entity_df_with_request_data,
345382
expected_df,
346383
event_timestamp,
347384
)
@@ -361,7 +398,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
361398
# If request data is missing that's needed for on demand transform, throw an error
362399
with pytest.raises(RequestDataNotFoundInEntityDfException):
363400
store.get_historical_features(
364-
entity_df=orders_df,
401+
entity_df=entity_df,
365402
features=[
366403
"driver_stats:conv_rate",
367404
"driver_stats:avg_daily_trips",
@@ -388,11 +425,11 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:
388425

389426

390427
def assert_feature_service_correctness(
391-
store, feature_service, full_feature_names, orders_df, expected_df, event_timestamp
428+
store, feature_service, full_feature_names, entity_df, expected_df, event_timestamp
392429
):
393430

394431
job_from_df = store.get_historical_features(
395-
entity_df=orders_df,
432+
entity_df=entity_df,
396433
features=feature_service,
397434
full_feature_names=full_feature_names,
398435
)

0 commit comments

Comments (0)