|
| 1 | +import random |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +from feast import FeatureService |
| 6 | +from tests.integration.feature_repos.repo_configuration import ( |
| 7 | + construct_universal_feature_views, |
| 8 | +) |
| 9 | +from tests.integration.feature_repos.universal.entities import customer, driver |
| 10 | + |
| 11 | + |
| 12 | +@pytest.mark.benchmark |
| 13 | +@pytest.mark.integration |
| 14 | +def test_online_retrieval(environment, universal_data_sources, benchmark): |
| 15 | + |
| 16 | + fs = environment.feature_store |
| 17 | + entities, datasets, data_sources = universal_data_sources |
| 18 | + feature_views = construct_universal_feature_views(data_sources) |
| 19 | + |
| 20 | + feature_service = FeatureService( |
| 21 | + "convrate_plus100", |
| 22 | + features=[feature_views["driver"][["conv_rate"]], feature_views["driver_odfv"]], |
| 23 | + ) |
| 24 | + |
| 25 | + feast_objects = [] |
| 26 | + feast_objects.extend(feature_views.values()) |
| 27 | + feast_objects.extend([driver(), customer(), feature_service]) |
| 28 | + fs.apply(feast_objects) |
| 29 | + fs.materialize(environment.start_date, environment.end_date) |
| 30 | + |
| 31 | + sample_drivers = random.sample(entities["driver"], 10) |
| 32 | + |
| 33 | + sample_customers = random.sample(entities["customer"], 10) |
| 34 | + |
| 35 | + entity_rows = [ |
| 36 | + {"driver": d, "customer_id": c, "val_to_add": 50} |
| 37 | + for (d, c) in zip(sample_drivers, sample_customers) |
| 38 | + ] |
| 39 | + |
| 40 | + feature_refs = [ |
| 41 | + "driver_stats:conv_rate", |
| 42 | + "driver_stats:avg_daily_trips", |
| 43 | + "customer_profile:current_balance", |
| 44 | + "customer_profile:avg_passenger_count", |
| 45 | + "customer_profile:lifetime_trip_count", |
| 46 | + "conv_rate_plus_100:conv_rate_plus_100", |
| 47 | + "conv_rate_plus_100:conv_rate_plus_val_to_add", |
| 48 | + "global_stats:num_rides", |
| 49 | + "global_stats:avg_ride_length", |
| 50 | + ] |
| 51 | + unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs if ":" in f] |
| 52 | + # Remove the on demand feature view output features, since they're not present in the source dataframe |
| 53 | + unprefixed_feature_refs.remove("conv_rate_plus_100") |
| 54 | + unprefixed_feature_refs.remove("conv_rate_plus_val_to_add") |
| 55 | + |
| 56 | + benchmark( |
| 57 | + fs.get_online_features, features=feature_refs, entity_rows=entity_rows, |
| 58 | + ) |
0 commit comments