Skip to content

Commit ec763de

Browse files
authored
Merge pull request feast-dev#4 from dmartinol/remote_offline
Integrating comments
2 parents 01fa2f6 + 31d1fe8 commit ec763de

2 files changed

Lines changed: 45 additions & 71 deletions

File tree

sdk/python/feast/infra/offline_stores/remote.py

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import uuid
23
from datetime import datetime
34
from pathlib import Path
@@ -39,45 +40,24 @@ def __init__(
3940
entity_df: Union[pd.DataFrame, str],
4041
# TODO add missing parameters from the OfflineStore API
4142
):
42-
# Generate unique command identifier
43-
self.command = str(uuid.uuid4())
4443
# Initialize the client connection
4544
self.client = fl.connect(
4645
f"grpc://{config.offline_store.host}:{config.offline_store.port}"
4746
)
48-
# Put API parameters
49-
self._put_parameters(feature_refs, entity_df)
47+
self.feature_refs = feature_refs
48+
self.entity_df = entity_df
5049

51-
def _put_parameters(self, feature_refs, entity_df):
52-
historical_flight_descriptor = fl.FlightDescriptor.for_command(self.command)
50+
# TODO add one specialized implementation for each OfflineStore API
51+
# This can result in a dictionary of functions indexed by api (e.g., "get_historical_features")
52+
def _put_parameters(self, command_descriptor):
53+
entity_df_table = pa.Table.from_pandas(self.entity_df)
5354

54-
entity_df_table = pa.Table.from_pandas(entity_df)
5555
writer, _ = self.client.do_put(
56-
historical_flight_descriptor,
57-
entity_df_table.schema.with_metadata(
58-
{
59-
"command": self.command,
60-
"api": "get_historical_features",
61-
"param": "entity_df",
62-
}
63-
),
56+
command_descriptor,
57+
entity_df_table.schema,
6458
)
65-
writer.write_table(entity_df_table)
66-
writer.close()
6759

68-
features_array = pa.array(feature_refs)
69-
features_batch = pa.RecordBatch.from_arrays([features_array], ["features"])
70-
writer, _ = self.client.do_put(
71-
historical_flight_descriptor,
72-
features_batch.schema.with_metadata(
73-
{
74-
"command": self.command,
75-
"api": "get_historical_features",
76-
"param": "features",
77-
}
78-
),
79-
)
80-
writer.write_batch(features_batch)
60+
writer.write_table(entity_df_table)
8161
writer.close()
8262

8363
# Invoked to realize the Pandas DataFrame
@@ -88,8 +68,21 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
8868
# Invoked to synchronously execute the underlying query and return the result as an arrow table
8969
# This is where do_get service is invoked
9070
def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
91-
upload_descriptor = fl.FlightDescriptor.for_command(self.command)
92-
flight = self.client.get_flight_info(upload_descriptor)
71+
# Generate unique command identifier
72+
command_id = str(uuid.uuid4())
73+
command = {
74+
"command_id": command_id,
75+
"api": "get_historical_features",
76+
"features": self.feature_refs,
77+
}
78+
command_descriptor = fl.FlightDescriptor.for_command(
79+
json.dumps(
80+
command,
81+
)
82+
)
83+
84+
self._put_parameters(command_descriptor)
85+
flight = self.client.get_flight_info(command_descriptor)
9386
ticket = flight.endpoints[0].ticket
9487

9588
reader = self.client.do_get(ticket)
@@ -112,7 +105,6 @@ def get_historical_features(
112105
project: str,
113106
full_feature_names: bool = False,
114107
) -> RemoteRetrievalJob:
115-
print(f"config.offline_store is {type(config.offline_store)}")
116108
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
117109

118110
# TODO: extend RemoteRetrievalJob API with all method parameters

sdk/python/feast/offline_server.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ast
2+
import json
23
import traceback
3-
from typing import Dict
4+
from typing import Any, Dict
45

56
import pyarrow as pa
67
import pyarrow.flight as fl
@@ -12,7 +13,8 @@ class OfflineServer(fl.FlightServerBase):
1213
def __init__(self, store: FeatureStore, location: str, **kwargs):
1314
super(OfflineServer, self).__init__(location, **kwargs)
1415
self._location = location
15-
self.flights: Dict[str, Dict[str, str]] = {}
16+
# A dictionary of configured flights, e.g. API calls received and not yet served
17+
self.flights: Dict[str, Any] = {}
1618
self.store = store
1719

1820
@classmethod
@@ -23,20 +25,12 @@ def descriptor_to_key(self, descriptor):
2325
tuple(descriptor.path or tuple()),
2426
)
2527

26-
# TODO: since we cannot anticipate here the call to get_historical_features, what data should we return?
27-
# At the moment it returns the metadata of the "entity_df" table
2828
def _make_flight_info(self, key, descriptor, params):
29-
table = params["entity_df"]
3029
endpoints = [fl.FlightEndpoint(repr(key), [self._location])]
31-
mock_sink = pa.MockOutputStream()
32-
stream_writer = pa.RecordBatchStreamWriter(mock_sink, table.schema)
33-
stream_writer.write_table(table)
34-
stream_writer.close()
35-
data_size = mock_sink.size()
36-
37-
return fl.FlightInfo(
38-
table.schema, descriptor, endpoints, table.num_rows, data_size
39-
)
30+
# TODO calculate actual schema from the given features
31+
schema = pa.schema([])
32+
33+
return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)
4034

4135
def get_flight_info(self, context, descriptor):
4236
key = OfflineServer.descriptor_to_key(descriptor)
@@ -59,23 +53,12 @@ def list_flights(self, context, criteria):
5953
def do_put(self, context, descriptor, reader, writer):
6054
key = OfflineServer.descriptor_to_key(descriptor)
6155

62-
if key in self.flights:
63-
params = self.flights[key]
56+
command = json.loads(key[1])
57+
if "api" in command:
58+
data = reader.read_all()
59+
self.flights[key] = data
6460
else:
65-
params = {}
66-
decoded_metadata = {
67-
key.decode(): value.decode()
68-
for key, value in reader.schema.metadata.items()
69-
}
70-
if "command" in decoded_metadata:
71-
command = decoded_metadata["command"]
72-
api = decoded_metadata["api"]
73-
param = decoded_metadata["param"]
74-
value = reader.read_all()
75-
# Merge the existing dictionary for the same key, as we have multiple calls to do_put for the same key
76-
params.update({"command": command, "api": api, param: value})
77-
78-
self.flights[key] = params
61+
print(f"No 'api' field in command: {command}")
7962

8063
# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
8164
# and returns the stream of data
@@ -85,18 +68,17 @@ def do_get(self, context, ticket):
8568
print(f"Unknown key {key}")
8669
return None
8770

88-
api = self.flights[key]["api"]
89-
# print(f"get key is {key}")
71+
command = json.loads(key[1])
72+
api = command["api"]
73+
# print(f"get command is {command}")
9074
# print(f"requested api is {api}")
9175
if api == "get_historical_features":
92-
# Extract parameters from the internal flight descriptor
93-
entity_df_value = self.flights[key]["entity_df"]
76+
# Extract parameters from the internal flights dictionary
77+
entity_df_value = self.flights[key]
9478
entity_df = pa.Table.to_pandas(entity_df_value)
9579
# print(f"entity_df is {entity_df}")
9680

97-
features_value = self.flights[key]["features"]
98-
features = pa.RecordBatch.to_pylist(features_value)
99-
features = [item["features"] for item in features]
81+
features = command["features"]
10082
# print(f"features is {features}")
10183

10284
print(
@@ -113,7 +95,7 @@ def do_get(self, context, ticket):
11395
traceback.print_exc()
11496
table = pa.Table.from_pandas(training_df)
11597

116-
# Get service is consumed, so we clear the corresponding flight
98+
# Get service is consumed, so we clear the corresponding flight and data
11799
del self.flights[key]
118100

119101
return fl.RecordBatchStream(table)

0 commit comments

Comments
 (0)