11import ast
2+ import json
23import traceback
3- from typing import Dict
4+ from typing import Any , Dict
45
56import pyarrow as pa
67import pyarrow .flight as fl
@@ -12,7 +13,8 @@ class OfflineServer(fl.FlightServerBase):
1213 def __init__ (self , store : FeatureStore , location : str , ** kwargs ):
1314 super (OfflineServer , self ).__init__ (location , ** kwargs )
1415 self ._location = location
15- self .flights : Dict [str , Dict [str , str ]] = {}
16+ # A dictionary of configured flights, e.g. API calls received and not yet served
17+ self .flights : Dict [str , Any ] = {}
1618 self .store = store
1719
1820 @classmethod
@@ -23,20 +25,12 @@ def descriptor_to_key(self, descriptor):
2325 tuple (descriptor .path or tuple ()),
2426 )
2527
26- # TODO: since we cannot anticipate here the call to get_historical_features call, what data should we return?
27- # ATM it returns the metadata of the "entity_df" table
2828 def _make_flight_info (self , key , descriptor , params ):
29- table = params ["entity_df" ]
3029 endpoints = [fl .FlightEndpoint (repr (key ), [self ._location ])]
31- mock_sink = pa .MockOutputStream ()
32- stream_writer = pa .RecordBatchStreamWriter (mock_sink , table .schema )
33- stream_writer .write_table (table )
34- stream_writer .close ()
35- data_size = mock_sink .size ()
36-
37- return fl .FlightInfo (
38- table .schema , descriptor , endpoints , table .num_rows , data_size
39- )
30+ # TODO calculate actual schema from the given features
31+ schema = pa .schema ([])
32+
33+ return fl .FlightInfo (schema , descriptor , endpoints , - 1 , - 1 )
4034
4135 def get_flight_info (self , context , descriptor ):
4236 key = OfflineServer .descriptor_to_key (descriptor )
@@ -59,23 +53,12 @@ def list_flights(self, context, criteria):
5953 def do_put (self , context , descriptor , reader , writer ):
6054 key = OfflineServer .descriptor_to_key (descriptor )
6155
62- if key in self .flights :
63- params = self .flights [key ]
56+ command = json .loads (key [1 ])
57+ if "api" in command :
58+ data = reader .read_all ()
59+ self .flights [key ] = data
6460 else :
65- params = {}
66- decoded_metadata = {
67- key .decode (): value .decode ()
68- for key , value in reader .schema .metadata .items ()
69- }
70- if "command" in decoded_metadata :
71- command = decoded_metadata ["command" ]
72- api = decoded_metadata ["api" ]
73- param = decoded_metadata ["param" ]
74- value = reader .read_all ()
75- # Merge the existing dictionary for the same key, as we have multiple calls to do_put for the same key
76- params .update ({"command" : command , "api" : api , param : value })
77-
78- self .flights [key ] = params
61+ print (f"No 'api' field in command: { command } " )
7962
8063 # Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
8164 # and returns the stream of data
@@ -85,18 +68,17 @@ def do_get(self, context, ticket):
8568 print (f"Unknown key { key } " )
8669 return None
8770
88- api = self .flights [key ]["api" ]
89- # print(f"get key is {key}")
71+ command = json .loads (key [1 ])
72+ api = command ["api" ]
73+ # print(f"get command is {command}")
9074 # print(f"requested api is {api}")
9175 if api == "get_historical_features" :
92- # Extract parameters from the internal flight descriptor
93- entity_df_value = self .flights [key ][ "entity_df" ]
76+ # Extract parameters from the internal flights dictionary
77+ entity_df_value = self .flights [key ]
9478 entity_df = pa .Table .to_pandas (entity_df_value )
9579 # print(f"entity_df is {entity_df}")
9680
97- features_value = self .flights [key ]["features" ]
98- features = pa .RecordBatch .to_pylist (features_value )
99- features = [item ["features" ] for item in features ]
81+ features = command ["features" ]
10082 # print(f"features is {features}")
10183
10284 print (
@@ -113,7 +95,7 @@ def do_get(self, context, ticket):
11395 traceback .print_exc ()
11496 table = pa .Table .from_pandas (training_df )
11597
116- # Get service is consumed, so we clear the corresponding flight
98+ # Get service is consumed, so we clear the corresponding flight and data
11799 del self .flights [key ]
118100
119101 return fl .RecordBatchStream (table )
0 commit comments