@@ -61,15 +61,39 @@ def list_flights(self, context, criteria):
6161 # Indexed by the unique command
6262 def do_put (self , context , descriptor , reader , writer ):
6363 key = OfflineServer .descriptor_to_key (descriptor )
64-
6564 command = json .loads (key [1 ])
6665 if "api" in command :
6766 data = reader .read_all ()
6867 logger .debug (f"do_put: command is{ command } , data is { data } " )
6968 self .flights [key ] = data
69+
70+ self ._call_api (command , key )
7071 else :
7172 logger .warning (f"No 'api' field in command: { command } " )
7273
74+ def _call_api (self , command , key ):
75+ remove_data = False
76+ try :
77+ api = command ["api" ]
78+ if api == OfflineServer .offline_write_batch .__name__ :
79+ self .offline_write_batch (command , key )
80+ remove_data = True
81+ elif api == OfflineServer .write_logged_features .__name__ :
82+ self .write_logged_features (command , key )
83+ remove_data = True
84+ elif api == OfflineServer .persist .__name__ :
85+ self .persist (command ["retrieve_func" ], command , key )
86+ remove_data = True
87+ except Exception as e :
88+ remove_data = True
89+ logger .exception (e )
90+ traceback .print_exc ()
91+ raise e
92+ finally :
93+ if remove_data :
94+ # Get service is consumed, so we clear the corresponding flight and data
95+ del self .flights [key ]
96+
7397 def get_feature_view_by_name (
7498 self , fv_name : str , name_alias : str , project : str
7599 ) -> FeatureView :
@@ -133,20 +157,18 @@ def do_get(self, context, ticket):
133157 logger .debug (f"requested api is { api } " )
134158 try :
135159 if api == OfflineServer .get_historical_features .__name__ :
136- df = self .get_historical_features (command , key ).to_df ()
160+ table = self .get_historical_features (command , key ).to_arrow ()
137161 elif api == OfflineServer .pull_all_from_table_or_query .__name__ :
138- df = self .pull_all_from_table_or_query (command ).to_df ()
162+ table = self .pull_all_from_table_or_query (command ).to_arrow ()
139163 elif api == OfflineServer .pull_latest_from_table_or_query .__name__ :
140- df = self .pull_latest_from_table_or_query (command ).to_df ()
164+ table = self .pull_latest_from_table_or_query (command ).to_arrow ()
141165 else :
142166 raise NotImplementedError
143167 except Exception as e :
144168 logger .exception (e )
145169 traceback .print_exc ()
146170 raise e
147171
148- table = pa .Table .from_pandas (df )
149-
150172 # Get service is consumed, so we clear the corresponding flight and data
151173 del self .flights [key ]
152174 return fl .RecordBatchStream (table )
@@ -252,14 +274,15 @@ def get_historical_features(self, command, key):
252274 )
253275 return retJob
254276
255- def persist (self , command , key ):
277+ def persist (self , retrieve_func , command , key ):
256278 try :
257- api = command ["api" ]
258- if api == OfflineServer .get_historical_features .__name__ :
279+ if retrieve_func == OfflineServer .get_historical_features .__name__ :
259280 ret_job = self .get_historical_features (command , key )
260- elif api == OfflineServer .pull_latest_from_table_or_query .__name__ :
281+ elif (
282+ retrieve_func == OfflineServer .pull_latest_from_table_or_query .__name__
283+ ):
261284 ret_job = self .pull_latest_from_table_or_query (command )
262- elif api == OfflineServer .pull_all_from_table_or_query .__name__ :
285+ elif retrieve_func == OfflineServer .pull_all_from_table_or_query .__name__ :
263286 ret_job = self .pull_all_from_table_or_query (command )
264287 else :
265288 raise NotImplementedError
@@ -273,25 +296,7 @@ def persist(self, command, key):
273296 raise e
274297
275298 def do_action (self , context , action ):
276- command_descriptor = fl .FlightDescriptor .deserialize (action .body .to_pybytes ())
277-
278- key = OfflineServer .descriptor_to_key (command_descriptor )
279- command = json .loads (key [1 ])
280- logger .info (f"do_action command is { command } " )
281-
282- try :
283- if action .type == OfflineServer .offline_write_batch .__name__ :
284- self .offline_write_batch (command , key )
285- elif action .type == OfflineServer .write_logged_features .__name__ :
286- self .write_logged_features (command , key )
287- elif action .type == OfflineServer .persist .__name__ :
288- self .persist (command , key )
289- else :
290- raise NotImplementedError
291- except Exception as e :
292- logger .exception (e )
293- traceback .print_exc ()
294- raise e
299+ pass
295300
296301 def do_drop_dataset (self , dataset ):
297302 pass
0 commit comments