11from datetime import datetime
2- from typing import Any , Dict , List
2+ from typing import Any , Dict , List , Optional
33
44import numpy as np
55import pandas as pd
@@ -37,14 +37,23 @@ def find_asof_record(
3737 ts_key : str ,
3838 ts_start : datetime ,
3939 ts_end : datetime ,
40- filter_key : str = "" ,
41- filter_value : Any = None ,
40+ filter_keys : Optional [ List [ str ]] = None ,
41+ filter_values : Optional [ List [ Any ]] = None ,
4242) -> Dict [str , Any ]:
43+ filter_keys = filter_keys or []
44+ filter_values = filter_values or []
45+ assert len (filter_keys ) == len (filter_values )
4346 found_record = {}
4447 for record in records :
4548 if (
46- not filter_key or record [filter_key ] == filter_value
47- ) and ts_start <= record [ts_key ] <= ts_end :
49+ all (
50+ [
51+ record [filter_key ] == filter_value
52+ for filter_key , filter_value in zip (filter_keys , filter_values )
53+ ]
54+ )
55+ and ts_start <= record [ts_key ] <= ts_end
56+ ):
4857 if not found_record or found_record [ts_key ] < record [ts_key ]:
4958 found_record = record
5059 return found_record
@@ -55,43 +64,57 @@ def get_expected_training_df(
5564 customer_fv : FeatureView ,
5665 driver_df : pd .DataFrame ,
5766 driver_fv : FeatureView ,
67+ orders_df : pd .DataFrame ,
68+ order_fv : FeatureView ,
5869 global_df : pd .DataFrame ,
5970 global_fv : FeatureView ,
60- orders_df : pd .DataFrame ,
71+ entity_df : pd .DataFrame ,
6172 event_timestamp : str ,
6273 full_feature_names : bool = False ,
6374):
6475 # Convert all pandas dataframes into records with UTC timestamps
65- order_records = convert_timestamp_records_to_utc (
66- orders_df .to_dict ("records" ), event_timestamp
76+ customer_records = convert_timestamp_records_to_utc (
77+ customer_df .to_dict ("records" ), customer_fv . batch_source . event_timestamp_column
6778 )
6879 driver_records = convert_timestamp_records_to_utc (
6980 driver_df .to_dict ("records" ), driver_fv .batch_source .event_timestamp_column
7081 )
71- customer_records = convert_timestamp_records_to_utc (
72- customer_df .to_dict ("records" ), customer_fv . batch_source . event_timestamp_column
82+ order_records = convert_timestamp_records_to_utc (
83+ orders_df .to_dict ("records" ), event_timestamp
7384 )
7485 global_records = convert_timestamp_records_to_utc (
7586 global_df .to_dict ("records" ), global_fv .batch_source .event_timestamp_column
7687 )
88+ entity_rows = convert_timestamp_records_to_utc (
89+ entity_df .to_dict ("records" ), event_timestamp
90+ )
7791
78- # Manually do point-in-time join of orders to drivers and customers records
79- for order_record in order_records :
92+ # Manually do point-in-time join of driver, customer, and order records against
93+ # the entity df
94+ for entity_row in entity_rows :
95+ customer_record = find_asof_record (
96+ customer_records ,
97+ ts_key = customer_fv .batch_source .event_timestamp_column ,
98+ ts_start = entity_row [event_timestamp ] - customer_fv .ttl ,
99+ ts_end = entity_row [event_timestamp ],
100+ filter_keys = ["customer_id" ],
101+ filter_values = [entity_row ["customer_id" ]],
102+ )
80103 driver_record = find_asof_record (
81104 driver_records ,
82105 ts_key = driver_fv .batch_source .event_timestamp_column ,
83- ts_start = order_record [event_timestamp ] - driver_fv .ttl ,
84- ts_end = order_record [event_timestamp ],
85- filter_key = "driver_id" ,
86- filter_value = order_record [ "driver_id" ],
106+ ts_start = entity_row [event_timestamp ] - driver_fv .ttl ,
107+ ts_end = entity_row [event_timestamp ],
108+ filter_keys = [ "driver_id" ] ,
109+ filter_values = [ entity_row [ "driver_id" ] ],
87110 )
88- customer_record = find_asof_record (
89- customer_records ,
111+ order_record = find_asof_record (
112+ order_records ,
90113 ts_key = customer_fv .batch_source .event_timestamp_column ,
91- ts_start = order_record [event_timestamp ] - customer_fv .ttl ,
92- ts_end = order_record [event_timestamp ],
93- filter_key = "customer_id" ,
94- filter_value = order_record [ "customer_id" ],
114+ ts_start = entity_row [event_timestamp ] - order_fv .ttl ,
115+ ts_end = entity_row [event_timestamp ],
116+ filter_keys = [ "customer_id" , "driver_id" ] ,
117+ filter_values = [ entity_row [ "customer_id" ], entity_row [ "driver_id" ] ],
95118 )
96119 global_record = find_asof_record (
97120 global_records ,
@@ -100,15 +123,7 @@ def get_expected_training_df(
100123 ts_end = order_record [event_timestamp ],
101124 )
102125
103- order_record .update (
104- {
105- (f"driver_stats__{ k } " if full_feature_names else k ): driver_record .get (
106- k , None
107- )
108- for k in ("conv_rate" , "avg_daily_trips" )
109- }
110- )
111- order_record .update (
126+ entity_row .update (
112127 {
113128 (
114129 f"customer_profile__{ k } " if full_feature_names else k
@@ -120,7 +135,21 @@ def get_expected_training_df(
120135 )
121136 }
122137 )
123- order_record .update (
138+ entity_row .update (
139+ {
140+ (f"driver_stats__{ k } " if full_feature_names else k ): driver_record .get (
141+ k , None
142+ )
143+ for k in ("conv_rate" , "avg_daily_trips" )
144+ }
145+ )
146+ entity_row .update (
147+ {
148+ (f"order__{ k } " if full_feature_names else k ): order_record .get (k , None )
149+ for k in ("order_is_success" ,)
150+ }
151+ )
152+ entity_row .update (
124153 {
125154 (f"global_stats__{ k } " if full_feature_names else k ): global_record .get (
126155 k , None
@@ -130,7 +159,7 @@ def get_expected_training_df(
130159 )
131160
132161 # Convert records back to pandas dataframe
133- expected_df = pd .DataFrame (order_records )
162+ expected_df = pd .DataFrame (entity_rows )
134163
135164 # Move "event_timestamp" column to front
136165 current_cols = expected_df .columns .tolist ()
@@ -140,7 +169,7 @@ def get_expected_training_df(
140169 # Cast some columns to expected types, since we lose information when converting pandas DFs into Python objects.
141170 if full_feature_names :
142171 expected_column_types = {
143- "order_is_success " : "int32" ,
172+ "order__order_is_success " : "int32" ,
144173 "driver_stats__conv_rate" : "float32" ,
145174 "customer_profile__current_balance" : "float32" ,
146175 "customer_profile__avg_passenger_count" : "float32" ,
@@ -175,20 +204,23 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
175204 (entities , datasets , data_sources ) = universal_data_sources
176205 feature_views = construct_universal_feature_views (data_sources )
177206
178- customer_df , driver_df , orders_df , global_df = (
207+ customer_df , driver_df , orders_df , global_df , entity_df = (
179208 datasets ["customer" ],
180209 datasets ["driver" ],
181210 datasets ["orders" ],
182211 datasets ["global" ],
212+ datasets ["entity" ],
183213 )
184- orders_df_with_request_data = orders_df .copy (deep = True )
185- orders_df_with_request_data ["val_to_add" ] = [
186- i for i in range (len (orders_df_with_request_data ))
214+ entity_df_with_request_data = entity_df .copy (deep = True )
215+ entity_df_with_request_data ["val_to_add" ] = [
216+ i for i in range (len (entity_df_with_request_data ))
187217 ]
188- customer_fv , driver_fv , driver_odfv , global_fv = (
218+
219+ customer_fv , driver_fv , driver_odfv , order_fv , global_fv = (
189220 feature_views ["customer" ],
190221 feature_views ["driver" ],
191222 feature_views ["driver_odfv" ],
223+ feature_views ["order" ],
192224 feature_views ["global" ],
193225 )
194226
@@ -203,6 +235,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
203235 customer_fv ,
204236 driver_fv ,
205237 driver_odfv ,
238+ order_fv ,
206239 global_fv ,
207240 driver (),
208241 customer (),
@@ -214,7 +247,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
214247 entity_df_query = None
215248 orders_table = table_name_from_data_source (data_sources ["orders" ])
216249 if orders_table :
217- entity_df_query = f"SELECT * FROM { orders_table } "
250+ entity_df_query = f"SELECT customer_id, driver_id, order_id, event_timestamp FROM { orders_table } "
218251
219252 event_timestamp = (
220253 DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
@@ -226,9 +259,11 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
226259 customer_fv ,
227260 driver_df ,
228261 driver_fv ,
262+ orders_df ,
263+ order_fv ,
229264 global_df ,
230265 global_fv ,
231- orders_df_with_request_data ,
266+ entity_df_with_request_data ,
232267 event_timestamp ,
233268 full_feature_names ,
234269 )
@@ -242,6 +277,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
242277 "customer_profile:current_balance" ,
243278 "customer_profile:avg_passenger_count" ,
244279 "customer_profile:lifetime_trip_count" ,
280+ "order:order_is_success" ,
245281 "global_stats:num_rides" ,
246282 "global_stats:avg_ride_length" ,
247283 ],
@@ -297,7 +333,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
297333 assert_frame_equal (expected_df_query , df_from_sql_entities )
298334
299335 job_from_df = store .get_historical_features (
300- entity_df = orders_df_with_request_data ,
336+ entity_df = entity_df_with_request_data ,
301337 features = [
302338 "driver_stats:conv_rate" ,
303339 "driver_stats:avg_daily_trips" ,
@@ -306,6 +342,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
306342 "customer_profile:lifetime_trip_count" ,
307343 "conv_rate_plus_100:conv_rate_plus_100" ,
308344 "conv_rate_plus_100:conv_rate_plus_val_to_add" ,
345+ "order:order_is_success" ,
309346 "global_stats:num_rides" ,
310347 "global_stats:avg_ride_length" ,
311348 ],
@@ -341,7 +378,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
341378 store ,
342379 feature_service ,
343380 full_feature_names ,
344- orders_df_with_request_data ,
381+ entity_df_with_request_data ,
345382 expected_df ,
346383 event_timestamp ,
347384 )
@@ -361,7 +398,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
361398 # If request data is missing that's needed for on demand transform, throw an error
362399 with pytest .raises (RequestDataNotFoundInEntityDfException ):
363400 store .get_historical_features (
364- entity_df = orders_df ,
401+ entity_df = entity_df ,
365402 features = [
366403 "driver_stats:conv_rate" ,
367404 "driver_stats:avg_daily_trips" ,
@@ -388,11 +425,11 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:
388425
389426
390427def assert_feature_service_correctness (
391- store , feature_service , full_feature_names , orders_df , expected_df , event_timestamp
428+ store , feature_service , full_feature_names , entity_df , expected_df , event_timestamp
392429):
393430
394431 job_from_df = store .get_historical_features (
395- entity_df = orders_df ,
432+ entity_df = entity_df ,
396433 features = feature_service ,
397434 full_feature_names = full_feature_names ,
398435 )
0 commit comments