@@ -34,7 +34,7 @@ def test_online() -> None:
3434 provider = store ._get_provider ()
3535
3636 driver_key = EntityKeyProto (
37- join_keys = ["driver " ], entity_values = [ValueProto (int64_val = 1 )]
37+ join_keys = ["driver_id " ], entity_values = [ValueProto (int64_val = 1 )]
3838 )
3939 provider .online_write_batch (
4040 config = store .config ,
@@ -54,7 +54,7 @@ def test_online() -> None:
5454 )
5555
5656 customer_key = EntityKeyProto (
57- join_keys = ["customer " ], entity_values = [ValueProto (string_val = "5" )]
57+ join_keys = ["customer_id " ], entity_values = [ValueProto (string_val = "5" )]
5858 )
5959 provider .online_write_batch (
6060 config = store .config ,
@@ -75,7 +75,7 @@ def test_online() -> None:
7575 )
7676
7777 customer_key = EntityKeyProto (
78- join_keys = ["customer " , "driver " ],
78+ join_keys = ["customer_id " , "driver_id " ],
7979 entity_values = [ValueProto (string_val = "5" ), ValueProto (int64_val = 1 )],
8080 )
8181 provider .online_write_batch (
@@ -100,15 +100,18 @@ def test_online() -> None:
100100 "customer_profile:name" ,
101101 "customer_driver_combined:trips" ,
102102 ],
103- entity_rows = [{"driver" : 1 , "customer" : "5" }, {"driver" : 1 , "customer" : 5 }],
103+ entity_rows = [
104+ {"driver_id" : 1 , "customer_id" : "5" },
105+ {"driver_id" : 1 , "customer_id" : 5 },
106+ ],
104107 full_feature_names = False ,
105108 ).to_dict ()
106109
107110 assert "lon" in result
108111 assert "avg_orders_day" in result
109112 assert "name" in result
110- assert result ["driver " ] == [1 , 1 ]
111- assert result ["customer " ] == ["5" , "5" ]
113+ assert result ["driver_id " ] == [1 , 1 ]
114+ assert result ["customer_id " ] == ["5" , "5" ]
112115 assert result ["lon" ] == ["1.0" , "1.0" ]
113116 assert result ["avg_orders_day" ] == [1.0 , 1.0 ]
114117 assert result ["name" ] == ["John" , "John" ]
@@ -117,7 +120,7 @@ def test_online() -> None:
117120 # Ensure features are still in result when keys not found
118121 result = store .get_online_features (
119122 features = ["customer_driver_combined:trips" ],
120- entity_rows = [{"driver " : 0 , "customer " : 0 }],
123+ entity_rows = [{"driver_id " : 0 , "customer_id " : 0 }],
121124 full_feature_names = False ,
122125 ).to_dict ()
123126
@@ -127,7 +130,7 @@ def test_online() -> None:
127130 with pytest .raises (FeatureViewNotFoundException ):
128131 store .get_online_features (
129132 features = ["driver_locations_bad:lon" ],
130- entity_rows = [{"driver " : 1 }],
133+ entity_rows = [{"driver_id " : 1 }],
131134 full_feature_names = False ,
132135 )
133136
@@ -152,7 +155,7 @@ def test_online() -> None:
152155 "customer_profile:name" ,
153156 "customer_driver_combined:trips" ,
154157 ],
155- entity_rows = [{"driver " : 1 , "customer " : 5 }],
158+ entity_rows = [{"driver_id " : 1 , "customer_id " : 5 }],
156159 full_feature_names = False ,
157160 ).to_dict ()
158161 assert result ["lon" ] == ["1.0" ]
@@ -173,7 +176,7 @@ def test_online() -> None:
173176 "customer_profile:name" ,
174177 "customer_driver_combined:trips" ,
175178 ],
176- entity_rows = [{"driver " : 1 , "customer " : 5 }],
179+ entity_rows = [{"driver_id " : 1 , "customer_id " : 5 }],
177180 full_feature_names = False ,
178181 ).to_dict ()
179182
@@ -188,7 +191,7 @@ def test_online() -> None:
188191 "customer_profile:name" ,
189192 "customer_driver_combined:trips" ,
190193 ],
191- entity_rows = [{"driver " : 1 , "customer " : 5 }],
194+ entity_rows = [{"driver_id " : 1 , "customer_id " : 5 }],
192195 full_feature_names = False ,
193196 ).to_dict ()
194197 assert result ["lon" ] == ["1.0" ]
@@ -214,7 +217,7 @@ def test_online() -> None:
214217 "customer_profile:name" ,
215218 "customer_driver_combined:trips" ,
216219 ],
217- entity_rows = [{"driver " : 1 , "customer " : 5 }],
220+ entity_rows = [{"driver_id " : 1 , "customer_id " : 5 }],
218221 full_feature_names = False ,
219222 ).to_dict ()
220223 assert result ["lon" ] == ["1.0" ]
@@ -234,7 +237,7 @@ def test_online() -> None:
234237 "customer_profile:name" ,
235238 "customer_driver_combined:trips" ,
236239 ],
237- entity_rows = [{"driver " : 1 , "customer " : 5 }],
240+ entity_rows = [{"driver_id " : 1 , "customer_id " : 5 }],
238241 full_feature_names = False ,
239242 ).to_dict ()
240243 assert result ["lon" ] == ["1.0" ]
@@ -284,7 +287,7 @@ def test_online_to_df():
284287 3 3.0 0.3
285288 """
286289 driver_key = EntityKeyProto (
287- join_keys = ["driver " ], entity_values = [ValueProto (int64_val = d )]
290+ join_keys = ["driver_id " ], entity_values = [ValueProto (int64_val = d )]
288291 )
289292 provider .online_write_batch (
290293 config = store .config ,
@@ -311,7 +314,7 @@ def test_online_to_df():
311314 6 6.0 foo6 60
312315 """
313316 customer_key = EntityKeyProto (
314- join_keys = ["customer " ], entity_values = [ValueProto (string_val = str (c ))]
317+ join_keys = ["customer_id " ], entity_values = [ValueProto (string_val = str (c ))]
315318 )
316319 provider .online_write_batch (
317320 config = store .config ,
@@ -340,7 +343,7 @@ def test_online_to_df():
340343 6 3 18
341344 """
342345 combo_keys = EntityKeyProto (
343- join_keys = ["customer " , "driver " ],
346+ join_keys = ["customer_id " , "driver_id " ],
344347 entity_values = [ValueProto (string_val = str (c )), ValueProto (int64_val = d )],
345348 )
346349 provider .online_write_batch (
@@ -369,7 +372,7 @@ def test_online_to_df():
369372 ],
370373 # Reverse the row order
371374 entity_rows = [
372- {"driver " : d , "customer " : c }
375+ {"driver_id " : d , "customer_id " : c }
373376 for (d , c ) in zip (reversed (driver_ids ), reversed (customer_ids ))
374377 ],
375378 ).to_df ()
@@ -381,8 +384,8 @@ def test_online_to_df():
381384 1 4 1.0 0.1 4.0 foo4 40 4
382385 """
383386 df_dict = {
384- "driver " : driver_ids ,
385- "customer " : [str (c ) for c in customer_ids ],
387+ "driver_id " : driver_ids ,
388+ "customer_id " : [str (c ) for c in customer_ids ],
386389 "lon" : [str (d * lon_multiply ) for d in driver_ids ],
387390 "lat" : [d * lat_multiply for d in driver_ids ],
388391 "avg_orders_day" : [c * avg_order_day_multiply for c in customer_ids ],
@@ -392,8 +395,8 @@ def test_online_to_df():
392395 }
393396 # Requested column order
394397 ordered_column = [
395- "driver " ,
396- "customer " ,
398+ "driver_id " ,
399+ "customer_id " ,
397400 "lon" ,
398401 "lat" ,
399402 "avg_orders_day" ,
0 commit comments