11import tempfile
22import uuid
33from contextlib import contextmanager
4+ from dataclasses import dataclass , replace
5+ from datetime import datetime , timedelta
46from pathlib import Path
5- from typing import Dict , List , Union
7+ from typing import Dict , List , Optional , Union
68
79import pytest
8- from attr import dataclass
910
10- from feast import FeatureStore , RepoConfig , importer
11+ from feast import FeatureStore , FeatureView , RepoConfig , driver_test_data , importer
12+ from feast .data_source import DataSource
1113from tests .data .data_creator import create_dataset
1214from tests .integration .feature_repos .universal .data_source_creator import (
1315 DataSourceCreator ,
1416)
15- from tests .integration .feature_repos .universal .entities import driver
17+ from tests .integration .feature_repos .universal .entities import customer , driver
1618from tests .integration .feature_repos .universal .feature_views import (
17- correctness_feature_view ,
19+ create_customer_daily_profile_feature_view ,
20+ create_driver_hourly_stats_feature_view ,
1821)
1922
2023
21- @dataclass
24+ @dataclass ( frozen = True , repr = True )
2225class TestRepoConfig :
2326 """
2427 This class should hold all possible parameters that may need to be varied by individual tests.
@@ -30,20 +33,21 @@ class TestRepoConfig:
3033 offline_store_creator : str = "tests.integration.feature_repos.universal.data_sources.file.FileDataSourceCreator"
3134
3235 full_feature_names : bool = True
36+ infer_event_timestamp_col : bool = True
3337
3438
# One entry per complete provider / offline store / online store combination.
# The parametrize_* decorators in this module build their test matrix from it.
FULL_REPO_CONFIGS: List[TestRepoConfig] = [
    # Local: every field keeps its TestRepoConfig default.
    TestRepoConfig(),
    # GCP provider, BigQuery data source creator, Datastore online store.
    TestRepoConfig(
        provider="gcp",
        offline_store_creator="tests.integration.feature_repos.universal.data_sources.bigquery.BigQueryDataSourceCreator",
        online_store="datastore",
    ),
    # AWS provider, Redshift data source creator, DynamoDB online store.
    TestRepoConfig(
        provider="aws",
        offline_store_creator="tests.integration.feature_repos.universal.data_sources.redshift.RedshiftDataSourceCreator",
        online_store={"type": "dynamodb", "region": "us-west-2"},
    ),
]
4852
4953
@@ -52,8 +56,128 @@ class TestRepoConfig:
5256PROVIDERS : List [str ] = []
5357
5458
@dataclass
class Environment:
    """
    Everything a single integration test run needs, bundled together: the feature
    store under test, the repo config it was built from, and the data source
    creator used to create additional sources on demand.

    The feature view / table accessors are lazy. Each one creates its data source
    on first call and caches the result on the instance.
    """

    name: str
    test_repo_config: TestRepoConfig
    feature_store: FeatureStore
    data_source: DataSource
    data_source_creator: DataSourceCreator

    # The attributes below have no annotation, so they are plain class attributes
    # rather than dataclass fields: they are evaluated once, when the class body
    # runs, and are shared by every Environment instance.

    # Current time truncated to the hour.
    end_date = datetime.now().replace(minute=0, second=0, microsecond=0)
    start_date = end_date - timedelta(days=7)
    # The orders data spans a year on either side of end_date, i.e. a much wider
    # window than the 7 days covered by the feature dataframes.
    before_start_date = end_date - timedelta(days=365)
    after_end_date = end_date + timedelta(days=365)

    customer_entities = list(range(1001, 1110))
    customer_df = driver_test_data.create_customer_daily_profile_df(
        customer_entities, start_date, end_date
    )
    # Cache slot for customer_feature_view(); annotated, so it is a dataclass field.
    _customer_feature_view: Optional[FeatureView] = None

    driver_entities = list(range(5001, 5110))
    driver_df = driver_test_data.create_driver_hourly_stats_df(
        driver_entities, start_date, end_date
    )
    # Cache slot for driver_stats_feature_view().
    _driver_stats_feature_view: Optional[FeatureView] = None

    orders_df = driver_test_data.create_orders_df(
        customers=customer_entities,
        drivers=driver_entities,
        start_date=before_start_date,
        end_date=after_end_date,
        order_count=1000,
    )
    # Cache slot for orders_table().
    _orders_table: Optional[str] = None

    def _create_source(self, table_suffix: str, df):
        """
        Create a data source for ``df`` under a table name derived from this
        environment's name and ``table_suffix``.
        """
        table_id = self.data_source_creator.get_prefixed_table_name(
            self.name, table_suffix
        )
        return self.data_source_creator.create_data_sources(
            table_id,
            df,
            event_timestamp_column="event_timestamp",
            created_timestamp_column="created",
        )

    def customer_feature_view(self) -> FeatureView:
        """
        Return the customer daily profile feature view, creating its data source
        from ``customer_df`` the first time this is called.
        """
        if self._customer_feature_view is not None:
            return self._customer_feature_view

        source = self._create_source("customer_profile", self.customer_df)
        self._customer_feature_view = create_customer_daily_profile_feature_view(
            source
        )
        return self._customer_feature_view

    def driver_stats_feature_view(self) -> FeatureView:
        """
        Return the driver hourly stats feature view, creating its data source
        from ``driver_df`` the first time this is called.
        """
        if self._driver_stats_feature_view is not None:
            return self._driver_stats_feature_view

        source = self._create_source("driver_hourly", self.driver_df)
        self._driver_stats_feature_view = create_driver_hourly_stats_feature_view(
            source
        )
        return self._driver_stats_feature_view

    def orders_table(self) -> Optional[str]:
        """
        Return the table reference of the orders data, creating the data source
        from ``orders_df`` on first use.

        The reference is read from the source's ``table_ref`` attribute, falling
        back to ``table``. If the source has neither, nothing is cached and None
        is returned, so the next call creates the source again.
        """
        if self._orders_table is not None:
            return self._orders_table

        source = self._create_source("orders", self.orders_df)
        if hasattr(source, "table_ref"):
            self._orders_table = source.table_ref
        elif hasattr(source, "table"):
            self._orders_table = source.table
        return self._orders_table
139+
140+
def vary_full_feature_names(configs: List[TestRepoConfig]) -> List[TestRepoConfig]:
    """
    Expand every config into two copies: one with ``full_feature_names=True``
    followed by one with ``full_feature_names=False``.

    The input order is preserved and the input configs are not modified.
    """
    return [
        replace(config, full_feature_names=flag)
        for config in configs
        for flag in (True, False)
    ]
148+
149+
def vary_infer_event_timestamp_col(
    configs: List[TestRepoConfig],
) -> List[TestRepoConfig]:
    """
    Expand every config into two copies: one with ``infer_event_timestamp_col=True``
    followed by one with ``infer_event_timestamp_col=False``.

    The input order is preserved and the input configs are not modified.
    """
    return [
        replace(config, infer_event_timestamp_col=flag)
        for config in configs
        for flag in (True, False)
    ]
159+
160+
def vary_providers_for_offline_stores(
    configs: List[TestRepoConfig],
) -> List[TestRepoConfig]:
    """
    Expand each config into one config per provider that its offline store
    should be exercised with.

    File-based configs are passed through untouched. Redshift configs are
    emitted once for "local" and once for "aws"; BigQuery configs once for
    "local" and once for "gcp". A config whose ``offline_store_creator`` matches
    none of these is dropped from the result.
    """
    # Checked in order; the first marker contained in the creator path wins.
    # None means "keep the config exactly as it is".
    provider_choices = (
        ("FileDataSourceCreator", None),
        ("RedshiftDataSourceCreator", ("local", "aws")),
        ("BigQueryDataSourceCreator", ("local", "gcp")),
    )

    expanded = []
    for config in configs:
        for marker, providers in provider_choices:
            if marker not in config.offline_store_creator:
                continue
            if providers is None:
                expanded.append(config)
            else:
                expanded.extend(
                    replace(config, provider=provider) for provider in providers
                )
            break
    return expanded
175+
176+
55177@contextmanager
56- def construct_feature_store (test_repo_config : TestRepoConfig ) -> FeatureStore :
178+ def construct_test_environment (
179+ test_repo_config : TestRepoConfig , create_and_apply : bool = False
180+ ) -> Environment :
57181 """
58182 This method should take in the parameters from the test repo config and created a feature repo, apply it,
59183 and return the constructed feature store object to callers.
@@ -74,8 +198,10 @@ def construct_feature_store(test_repo_config: TestRepoConfig) -> FeatureStore:
74198
75199 offline_creator : DataSourceCreator = importer .get_class_from_type (
76200 module_name , config_class_name , "DataSourceCreator"
77- )()
78- ds = offline_creator .create_data_source (project , df )
201+ )(project )
202+ ds = offline_creator .create_data_sources (
203+ project , df , field_mapping = {"ts_1" : "ts" , "id" : "driver_id" }
204+ )
79205 offline_store = offline_creator .create_offline_store_config ()
80206 online_store = test_repo_config .online_store
81207
@@ -89,21 +215,76 @@ def construct_feature_store(test_repo_config: TestRepoConfig) -> FeatureStore:
89215 repo_path = repo_dir_name ,
90216 )
91217 fs = FeatureStore (config = config )
92- fv = correctness_feature_view (ds )
93- entity = driver ()
94- fs .apply ([fv , entity ])
218+ environment = Environment (
219+ name = project ,
220+ test_repo_config = test_repo_config ,
221+ feature_store = fs ,
222+ data_source = ds ,
223+ data_source_creator = offline_creator ,
224+ )
95225
96- yield fs
226+ fvs = []
227+ entities = []
228+ try :
229+ if create_and_apply :
230+ entities .extend ([driver (), customer ()])
231+ fvs .extend (
232+ [
233+ environment .driver_stats_feature_view (),
234+ environment .customer_feature_view (),
235+ ]
236+ )
237+ fs .apply (fvs + entities )
97238
98- fs .teardown ()
99- offline_creator .teardown (project )
239+ yield environment
240+ finally :
241+ offline_creator .teardown ()
242+ fs .teardown ()
100243
101244
def parametrize_e2e_test(e2e_test):
    """
    Decorator for end-to-end tests.

    The decorated test is marked as an integration test and run once for every
    entry in FULL_REPO_CONFIGS. For each run a test environment is constructed
    for that config and handed to the test as its only argument.

    The environment is built without ``create_and_apply``, so the test itself is
    expected to create and apply the objects it needs and to perform its own
    operations (such as materialization and looking up feature values).

    The environment is managed by a context manager around the test call, so
    cleanup of the feature store and the sample data is not the test's job.
    """

    @pytest.mark.integration
    @pytest.mark.parametrize("config", FULL_REPO_CONFIGS, ids=str)
    def inner_test(config):
        with construct_test_environment(config) as env:
            e2e_test(env)

    return inner_test
265+
266+
def parametrize_offline_retrieval_test(offline_retrieval_test):
    """
    This decorator should be used for offline retrieval tests.

    The decorated test is marked as an integration test and parameterized over
    FULL_REPO_CONFIGS, expanded so that each offline store is combined with every
    provider it should be tested with, and with both values of
    ``full_feature_names`` and ``infer_event_timestamp_col``.

    Unlike ``parametrize_e2e_test``, the environment is constructed with
    ``create_and_apply=True``: the driver and customer entities and the driver
    stats and customer profile feature views are already applied to the feature
    store when the test starts. The test receives the environment as its only
    argument and should only perform the retrieval and its checks.

    The environment is managed by a context manager around the test call, which
    tears down the feature store and the sample data afterwards.
    """
    # Order matters for the generated ids only: providers first, then the two
    # boolean flags, each step doubling the number of configs.
    configs = vary_providers_for_offline_stores(FULL_REPO_CONFIGS)
    configs = vary_full_feature_names(configs)
    configs = vary_infer_event_timestamp_col(configs)

    @pytest.mark.integration
    @pytest.mark.parametrize("config", configs, ids=str)
    def inner_test(config):
        with construct_test_environment(config, create_and_apply=True) as environment:
            offline_retrieval_test(environment)

    return inner_test
0 commit comments