@@ -74,22 +74,23 @@ def do_put(
7474 logger .debug (f"do_put: command is{ command } , data is { data } " )
7575 self .flights [key ] = data
7676
77- self ._call_api (command , key )
77+ self ._call_api (command [ "api" ], command , key )
7878 else :
7979 logger .warning (f"No 'api' field in command: { command } " )
8080
81- def _call_api (self , command : dict , key : str ):
81+ def _call_api (self , api : str , command : dict , key : str ):
82+ assert api is not None , "api can not be empty"
83+
8284 remove_data = False
8385 try :
84- api = command ["api" ]
8586 if api == OfflineServer .offline_write_batch .__name__ :
8687 self .offline_write_batch (command , key )
8788 remove_data = True
8889 elif api == OfflineServer .write_logged_features .__name__ :
8990 self .write_logged_features (command , key )
9091 remove_data = True
9192 elif api == OfflineServer .persist .__name__ :
92- self .persist (command [ "retrieve_func" ], command , key )
93+ self .persist (command , key )
9394 remove_data = True
9495 except Exception as e :
9596 remove_data = True
@@ -150,6 +151,9 @@ def list_feature_views_by_name(
150151 for index , fv_name in enumerate (feature_view_names )
151152 ]
152153
154+ def _validate_do_get_parameters (self , command : dict ):
155+ assert "api" in command , "api parameter is mandatory"
156+
153157 # Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
154158 # and returns the stream of data
155159 def do_get (self , context : fl .ServerCallContext , ticket : fl .Ticket ):
@@ -159,6 +163,9 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
159163 return None
160164
161165 command = json .loads (key [1 ])
166+
167+ self ._validate_do_get_parameters (command )
168+
162169 api = command ["api" ]
163170 logger .debug (f"get command is { command } " )
164171 logger .debug (f"requested api is { api } " )
@@ -180,33 +187,52 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
180187 del self .flights [key ]
181188 return fl .RecordBatchStream (table )
182189
183- def offline_write_batch (self , command : dict , key : str ):
190+ def _validate_offline_write_batch_parameters (self , command : dict ):
191+ assert (
192+ "feature_view_names" in command
193+ ), "feature_view_names is a mandatory parameter"
194+ assert "name_aliases" in command , "name_aliases is a mandatory parameter"
195+
184196 feature_view_names = command ["feature_view_names" ]
185197 assert (
186198 len (feature_view_names ) == 1
187199 ), "feature_view_names list should only have one item"
200+
188201 name_aliases = command ["name_aliases" ]
189202 assert len (name_aliases ) == 1 , "name_aliases list should only have one item"
203+
204+ def offline_write_batch (self , command : dict , key : str ):
205+ self ._validate_offline_write_batch_parameters (command )
206+
207+ feature_view_names = command ["feature_view_names" ]
208+ name_aliases = command ["name_aliases" ]
209+
190210 project = self .store .config .project
191211 feature_views = self .list_feature_views_by_name (
192212 feature_view_names = feature_view_names ,
193213 name_aliases = name_aliases ,
194214 project = project ,
195215 )
196216
197- assert len (feature_views ) == 1
217+ assert len (feature_views ) == 1 , "incorrect feature view"
198218 table = self .flights [key ]
199219 self .offline_store .offline_write_batch (
200220 self .store .config , feature_views [0 ], table , command ["progress" ]
201221 )
202222
223+ def _validate_write_logged_features_parameters (self , command : dict ):
224+ assert "feature_service_name" in command
225+
203226 def write_logged_features (self , command : dict , key : str ):
227+ self ._validate_write_logged_features_parameters (command )
204228 table = self .flights [key ]
205229 feature_service = self .store .get_feature_service (
206230 command ["feature_service_name" ]
207231 )
208232
209- assert feature_service .logging_config is not None
233+ assert (
234+ feature_service .logging_config is not None
235+ ), "feature service must have logging_config set"
210236
211237 self .offline_store .write_logged_features (
212238 config = self .store .config ,
@@ -218,7 +244,23 @@ def write_logged_features(self, command: dict, key: str):
218244 registry = self .store .registry ,
219245 )
220246
247+ def _validate_pull_all_from_table_or_query_parameters (self , command : dict ):
248+ assert (
249+ "data_source_name" in command
250+ ), "data_source_name is a mandatory parameter"
251+ assert (
252+ "join_key_columns" in command
253+ ), "join_key_columns is a mandatory parameter"
254+ assert (
255+ "feature_name_columns" in command
256+ ), "feature_name_columns is a mandatory parameter"
257+ assert "timestamp_field" in command , "timestamp_field is a mandatory parameter"
258+ assert "start_date" in command , "start_date is a mandatory parameter"
259+ assert "end_date" in command , "end_date is a mandatory parameter"
260+
221261 def pull_all_from_table_or_query (self , command : dict ):
262+ self ._validate_pull_all_from_table_or_query_parameters (command )
263+
222264 return self .offline_store .pull_all_from_table_or_query (
223265 self .store .config ,
224266 self .store .get_data_source (command ["data_source_name" ]),
@@ -229,7 +271,23 @@ def pull_all_from_table_or_query(self, command: dict):
229271 utils .make_tzaware (datetime .fromisoformat (command ["end_date" ])),
230272 )
231273
274+ def _validate_pull_latest_from_table_or_query_parameters (self , command : dict ):
275+ assert (
276+ "data_source_name" in command
277+ ), "data_source_name is a mandatory parameter"
278+ assert (
279+ "join_key_columns" in command
280+ ), "join_key_columns is a mandatory parameter"
281+ assert (
282+ "feature_name_columns" in command
283+ ), "feature_name_columns is a mandatory parameter"
284+ assert "timestamp_field" in command , "timestamp_field is a mandatory parameter"
285+ assert "start_date" in command , "start_date is a mandatory parameter"
286+ assert "end_date" in command , "end_date is a mandatory parameter"
287+
232288 def pull_latest_from_table_or_query (self , command : dict ):
289+ self ._validate_pull_latest_from_table_or_query_parameters (command )
290+
233291 return self .offline_store .pull_latest_from_table_or_query (
234292 self .store .config ,
235293 self .store .get_data_source (command ["data_source_name" ]),
@@ -258,20 +316,33 @@ def list_actions(self, context):
258316 ),
259317 ]
260318
319+ def _validate_get_historical_features_parameters (self , command : dict , key : str ):
320+ assert key in self .flights , f"missing key={ key } "
321+ assert "feature_view_names" in command , "feature_view_names is mandatory"
322+ assert "name_aliases" in command , "name_aliases is mandatory"
323+ assert "feature_refs" in command , "feature_refs is mandatory"
324+ assert "project" in command , "project is mandatory"
325+ assert "full_feature_names" in command , "full_feature_names is mandatory"
326+
261327 def get_historical_features (self , command : dict , key : str ):
328+ self ._validate_get_historical_features_parameters (command , key )
329+
262330 # Extract parameters from the internal flights dictionary
263331 entity_df_value = self .flights [key ]
264332 entity_df = pa .Table .to_pandas (entity_df_value )
333+
265334 feature_view_names = command ["feature_view_names" ]
266335 name_aliases = command ["name_aliases" ]
267336 feature_refs = command ["feature_refs" ]
268337 project = command ["project" ]
269338 full_feature_names = command ["full_feature_names" ]
339+
270340 feature_views = self .list_feature_views_by_name (
271341 feature_view_names = feature_view_names ,
272342 name_aliases = name_aliases ,
273343 project = project ,
274344 )
345+
275346 retJob = self .offline_store .get_historical_features (
276347 config = self .store .config ,
277348 feature_views = feature_views ,
@@ -281,10 +352,19 @@ def get_historical_features(self, command: dict, key: str):
281352 project = project ,
282353 full_feature_names = full_feature_names ,
283354 )
355+
284356 return retJob
285357
286- def persist (self , retrieve_func : str , command : dict , key : str ):
358+ def _validate_persist_parameters (self , command : dict ):
359+ assert "retrieve_func" in command , "retrieve_func is mandatory"
360+ assert "data_source_name" in command , "data_source_name is mandatory"
361+ assert "allow_overwrite" in command , "allow_overwrite is mandatory"
362+
363+ def persist (self , command : dict , key : str ):
364+ self ._validate_persist_parameters (command )
365+
287366 try :
367+ retrieve_func = command ["retrieve_func" ]
288368 if retrieve_func == OfflineServer .get_historical_features .__name__ :
289369 ret_job = self .get_historical_features (command , key )
290370 elif (
0 commit comments