@@ -82,6 +82,8 @@ def get_expected_training_df(
8282 location_fv : FeatureView ,
8383 global_df : pd .DataFrame ,
8484 global_fv : FeatureView ,
85+ field_mapping_df : pd .DataFrame ,
86+ field_mapping_fv : FeatureView ,
8587 entity_df : pd .DataFrame ,
8688 event_timestamp : str ,
8789 full_feature_names : bool = False ,
@@ -102,6 +104,10 @@ def get_expected_training_df(
102104 global_records = convert_timestamp_records_to_utc (
103105 global_df .to_dict ("records" ), global_fv .batch_source .event_timestamp_column
104106 )
107+ field_mapping_records = convert_timestamp_records_to_utc (
108+ field_mapping_df .to_dict ("records" ),
109+ field_mapping_fv .batch_source .event_timestamp_column ,
110+ )
105111 entity_rows = convert_timestamp_records_to_utc (
106112 entity_df .to_dict ("records" ), event_timestamp
107113 )
@@ -156,6 +162,13 @@ def get_expected_training_df(
156162 ts_end = order_record [event_timestamp ],
157163 )
158164
165+ field_mapping_record = find_asof_record (
166+ field_mapping_records ,
167+ ts_key = field_mapping_fv .batch_source .event_timestamp_column ,
168+ ts_start = order_record [event_timestamp ] - field_mapping_fv .ttl ,
169+ ts_end = order_record [event_timestamp ],
170+ )
171+
159172 entity_row .update (
160173 {
161174 (
@@ -197,6 +210,16 @@ def get_expected_training_df(
197210 }
198211 )
199212
213+ # get field_mapping_record by column name, but label by feature name
214+ entity_row .update (
215+ {
216+ (
217+ f"field_mapping__{ feature } " if full_feature_names else feature
218+ ): field_mapping_record .get (column , None )
219+ for (column , feature ) in field_mapping_fv .input .field_mapping .items ()
220+ }
221+ )
222+
200223 # Convert records back to pandas dataframe
201224 expected_df = pd .DataFrame (entity_rows )
202225
@@ -213,6 +236,7 @@ def get_expected_training_df(
213236 "customer_profile__current_balance" : "float32" ,
214237 "customer_profile__avg_passenger_count" : "float32" ,
215238 "global_stats__avg_ride_length" : "float32" ,
239+ "field_mapping__feature_name" : "int32" ,
216240 }
217241 else :
218242 expected_column_types = {
@@ -221,6 +245,7 @@ def get_expected_training_df(
221245 "current_balance" : "float32" ,
222246 "avg_passenger_count" : "float32" ,
223247 "avg_ride_length" : "float32" ,
248+ "feature_name" : "int32" ,
224249 }
225250
226251 for col , typ in expected_column_types .items ():
@@ -311,6 +336,8 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
311336 feature_views ["location" ],
312337 datasets ["global" ],
313338 feature_views ["global" ],
339+ datasets ["field_mapping" ],
340+ feature_views ["field_mapping" ],
314341 entity_df_with_request_data ,
315342 event_timestamp ,
316343 full_feature_names ,
@@ -336,6 +363,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
336363 "global_stats:num_rides" ,
337364 "global_stats:avg_ride_length" ,
338365 "driver_age:driver_age" ,
366+ "field_mapping:feature_name" ,
339367 ],
340368 full_feature_names = full_feature_names ,
341369 )
@@ -404,6 +432,7 @@ def test_historical_features_with_missing_request_data(
404432 "conv_rate_plus_100:conv_rate_plus_val_to_add" ,
405433 "global_stats:num_rides" ,
406434 "global_stats:avg_ride_length" ,
435+ "field_mapping:feature_name" ,
407436 ],
408437 full_feature_names = full_feature_names ,
409438 )
@@ -419,6 +448,7 @@ def test_historical_features_with_missing_request_data(
419448 "driver_age:driver_age" ,
420449 "global_stats:num_rides" ,
421450 "global_stats:avg_ride_length" ,
451+ "field_mapping:feature_name" ,
422452 ],
423453 full_feature_names = full_feature_names ,
424454 )
@@ -452,6 +482,7 @@ def test_historical_features_with_entities_from_query(
452482 "order:order_is_success" ,
453483 "global_stats:num_rides" ,
454484 "global_stats:avg_ride_length" ,
485+ "field_mapping:feature_name" ,
455486 ],
456487 full_feature_names = full_feature_names ,
457488 )
@@ -477,6 +508,8 @@ def test_historical_features_with_entities_from_query(
477508 feature_views ["location" ],
478509 datasets ["global" ],
479510 feature_views ["global" ],
511+ datasets ["field_mapping" ],
512+ feature_views ["field_mapping" ],
480513 datasets ["entity" ],
481514 event_timestamp ,
482515 full_feature_names ,
@@ -538,6 +571,7 @@ def test_historical_features_persisting(
538571 "order:order_is_success" ,
539572 "global_stats:num_rides" ,
540573 "global_stats:avg_ride_length" ,
574+ "field_mapping:feature_name" ,
541575 ],
542576 full_feature_names = full_feature_names ,
543577 )
@@ -561,6 +595,8 @@ def test_historical_features_persisting(
561595 feature_views ["location" ],
562596 datasets ["global" ],
563597 feature_views ["global" ],
598+ datasets ["field_mapping" ],
599+ feature_views ["field_mapping" ],
564600 entity_df ,
565601 event_timestamp ,
566602 full_feature_names ,
0 commit comments