3333from tests .integration .feature_repos .universal .online_store .redis import (
3434 RedisOnlineStoreCreator ,
3535)
36- from tests .utils .e2e_test_validation import _check_offline_and_online_features
3736
# Reference timestamps captured once at module import time; used to build
# the materialization window and event timestamps in the tests below.
now = datetime.now()
today = datetime.today()
@@ -189,6 +188,20 @@ def transform_feature(df: DataFrame) -> DataFrame:
189188
190189@pytest .mark .integration
191190def test_spark_compute_engine_materialize ():
191+ """
192+ Test the SparkComputeEngine materialize method.
193+ For the current feature view driver_hourly_stats, The below execution plan:
194+ 1. feature data from create_feature_dataset
195+ 2. filter by start_time and end_time, that is, the last 2 days
196+ for the driver_id 1001, the data left is row 0
197+ for the driver_id 1002, the data left is row 2
198+ 3. apply the transform_feature function to the data
199+ for all features, the value is multiplied by 2
200+ 4. write the data to the online store and offline store
201+
202+ Returns:
203+
204+ """
192205 spark_environment = create_spark_environment ()
193206 fs = spark_environment .feature_store
194207 registry = fs .registry
@@ -213,7 +226,7 @@ def transform_feature(df: DataFrame) -> DataFrame:
213226 Field (name = "driver_id" , dtype = Int32 ),
214227 ],
215228 online = True ,
216- offline = False ,
229+ offline = True ,
217230 source = data_source ,
218231 )
219232
@@ -244,18 +257,62 @@ def tqdm_builder(length):
244257
245258 assert spark_materialize_job .status () == MaterializationJobStatus .SUCCEEDED
246259
247- _check_offline_and_online_features (
260+ _check_online_features (
248261 fs = fs ,
249- fv = driver_stats_fv ,
250- driver_id = 1 ,
251- event_timestamp = now ,
252- expected_value = 0.3 ,
262+ driver_id = 1001 ,
263+ feature = "driver_hourly_stats:conv_rate" ,
264+ expected_value = 1.6 ,
253265 full_feature_names = True ,
254- check_offline_store = True ,
266+ )
267+
268+ entity_df = create_entity_df ()
269+
270+ _check_offline_features (
271+ fs = fs ,
272+ feature = "driver_hourly_stats:conv_rate" ,
273+ entity_df = entity_df ,
274+ expected_value = 1.6 ,
255275 )
256276 finally :
257277 spark_environment .teardown ()
258278
259279
280+ def _check_online_features (
281+ fs ,
282+ driver_id ,
283+ feature ,
284+ expected_value ,
285+ full_feature_names : bool = True ,
286+ ):
287+ online_response = fs .get_online_features (
288+ features = [feature ],
289+ entity_rows = [{"driver_id" : driver_id }],
290+ full_feature_names = full_feature_names ,
291+ ).to_dict ()
292+
293+ feature_ref = "__" .join (feature .split (":" ))
294+
295+ assert len (online_response ["driver_id" ]) == 1
296+ assert online_response ["driver_id" ][0 ] == driver_id
297+ assert abs (online_response [feature_ref ][0 ] - expected_value < 1e-6 ), (
298+ "Transformed result"
299+ )
300+
301+
def _check_offline_features(
    fs,
    feature,
    entity_df,
    expected_value,
):
    """Assert the offline store returns one row per entity after materialization.

    Args:
        fs: Feature store instance exposing ``get_historical_features``.
        feature: Feature reference in ``"<view>:<feature>"`` form.
        entity_df: Entity dataframe (driver ids + event timestamps) to join on.
        expected_value: Expected transformed feature value.
            NOTE(review): currently unused — nothing asserts the feature
            column's values; only the row count and driver ids are checked.
            Either assert on the feature column or drop the parameter.
    """
    offline_df = fs.get_historical_features(
        entity_df=entity_df,
        features=[feature],
    ).to_df()

    # Presumably the two drivers produced by create_entity_df(); confirm the
    # ordering guarantee against that helper if this ever flakes.
    assert len(offline_df) == 2
    assert offline_df["driver_id"].to_list() == [1001, 1002]
315+
316+
# Allow running this module directly, outside pytest.
# NOTE(review): this invokes test_spark_compute_engine_get_historical_features
# (defined elsewhere in this file), not the materialize test above — confirm
# that is intentional.
if __name__ == "__main__":
    test_spark_compute_engine_get_historical_features()
0 commit comments