Skip to content

Commit 05f4e8f

Browse files
authored
Python FeatureServer optimization (feast-dev#2202)
* Optimize Python FeatureServer Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com> * Handle `RepeatedValue` proto in `_get_online_features` Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com> * Only initialize `Timestamp` once Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com> * Don't use `defaultdict` Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>
1 parent f32b4f4 commit 05f4e8f

File tree

2 files changed

+163
-110
lines changed

2 files changed

+163
-110
lines changed

sdk/python/feast/feature_server.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import feast
1111
from feast import proto_json
1212
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest
13-
from feast.type_map import feast_value_type_to_python_type
1413

1514

1615
def get_app(store: "feast.FeatureStore"):
@@ -43,16 +42,11 @@ def get_online_features(body=Depends(get_body)):
4342
if any(batch_size != num_entities for batch_size in batch_sizes):
4443
raise HTTPException(status_code=500, detail="Uneven number of columns")
4544

46-
entity_rows = [
47-
{
48-
k: feast_value_type_to_python_type(v.val[idx])
49-
for k, v in request_proto.entities.items()
50-
}
51-
for idx in range(num_entities)
52-
]
53-
54-
response_proto = store.get_online_features(
55-
features, entity_rows, full_feature_names=full_feature_names
45+
response_proto = store._get_online_features(
46+
features,
47+
request_proto.entities,
48+
full_feature_names=full_feature_names,
49+
native_entity_values=False,
5650
).proto
5751

5852
# Convert the Protobuf object to JSON and return it

sdk/python/feast/feature_store.py

Lines changed: 158 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
Dict,
2424
Iterable,
2525
List,
26+
Mapping,
2627
NamedTuple,
2728
Optional,
29+
Sequence,
2830
Set,
2931
Tuple,
3032
Union,
@@ -72,7 +74,7 @@
7274
GetOnlineFeaturesResponse,
7375
)
7476
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
75-
from feast.protos.feast.types.Value_pb2 import Value
77+
from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value
7678
from feast.registry import Registry
7779
from feast.repo_config import RepoConfig, load_repo_config
7880
from feast.request_feature_view import RequestFeatureView
@@ -267,14 +269,18 @@ def _list_feature_views(
267269
return feature_views
268270

269271
@log_exceptions_and_usage
270-
def list_on_demand_feature_views(self) -> List[OnDemandFeatureView]:
272+
def list_on_demand_feature_views(
273+
self, allow_cache: bool = False
274+
) -> List[OnDemandFeatureView]:
271275
"""
272276
Retrieves the list of on demand feature views from the registry.
273277
274278
Returns:
275279
A list of on demand feature views.
276280
"""
277-
return self._registry.list_on_demand_feature_views(self.project)
281+
return self._registry.list_on_demand_feature_views(
282+
self.project, allow_cache=allow_cache
283+
)
278284

279285
@log_exceptions_and_usage
280286
def get_entity(self, name: str) -> Entity:
@@ -1067,6 +1073,30 @@ def get_online_features(
10671073
... )
10681074
>>> online_response_dict = online_response.to_dict()
10691075
"""
1076+
columnar: Dict[str, List[Any]] = {k: [] for k in entity_rows[0].keys()}
1077+
for entity_row in entity_rows:
1078+
for key, value in entity_row.items():
1079+
try:
1080+
columnar[key].append(value)
1081+
except KeyError as e:
1082+
raise ValueError("All entity_rows must have the same keys.") from e
1083+
1084+
return self._get_online_features(
1085+
features=features,
1086+
entity_values=columnar,
1087+
full_feature_names=full_feature_names,
1088+
native_entity_values=True,
1089+
)
1090+
1091+
def _get_online_features(
1092+
self,
1093+
features: Union[List[str], FeatureService],
1094+
entity_values: Mapping[
1095+
str, Union[Sequence[Any], Sequence[Value], RepeatedValue]
1096+
],
1097+
full_feature_names: bool = False,
1098+
native_entity_values: bool = True,
1099+
):
10701100
_feature_refs = self._get_features(features, allow_cache=True)
10711101
(
10721102
requested_feature_views,
@@ -1076,6 +1106,29 @@ def get_online_features(
10761106
features=features, allow_cache=True, hide_dummy_entity=False
10771107
)
10781108

1109+
entity_name_to_join_key_map, entity_type_map = self._get_entity_maps(
1110+
requested_feature_views
1111+
)
1112+
1113+
# Extract Sequence from RepeatedValue Protobuf.
1114+
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
1115+
k: list(v) if isinstance(v, Sequence) else list(v.val)
1116+
for k, v in entity_values.items()
1117+
}
1118+
1119+
entity_proto_values: Dict[str, List[Value]]
1120+
if native_entity_values:
1121+
# Convert values to Protobuf once.
1122+
entity_proto_values = {
1123+
k: python_values_to_proto_values(
1124+
v, entity_type_map.get(k, ValueType.UNKNOWN)
1125+
)
1126+
for k, v in entity_value_lists.items()
1127+
}
1128+
else:
1129+
entity_proto_values = entity_value_lists
1130+
1131+
num_rows = _validate_entity_values(entity_proto_values)
10791132
_validate_feature_refs(_feature_refs, full_feature_names)
10801133
(
10811134
grouped_refs,
@@ -1101,111 +1154,72 @@ def get_online_features(
11011154
}
11021155

11031156
feature_views = list(view for view, _ in grouped_refs)
1104-
entityless_case = DUMMY_ENTITY_NAME in [
1105-
entity_name
1106-
for feature_view in feature_views
1107-
for entity_name in feature_view.entities
1108-
]
1109-
1110-
provider = self._get_provider()
1111-
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
1112-
entity_name_to_join_key_map: Dict[str, str] = {}
1113-
join_key_to_entity_type_map: Dict[str, ValueType] = {}
1114-
for entity in entities:
1115-
entity_name_to_join_key_map[entity.name] = entity.join_key
1116-
join_key_to_entity_type_map[entity.join_key] = entity.value_type
1117-
for feature_view in requested_feature_views:
1118-
for entity_name in feature_view.entities:
1119-
entity = self._registry.get_entity(
1120-
entity_name, self.project, allow_cache=True
1121-
)
1122-
# User directly uses join_key as the entity reference in the entity_rows for the
1123-
# entity mapping case.
1124-
entity_name = feature_view.projection.join_key_map.get(
1125-
entity.join_key, entity.name
1126-
)
1127-
join_key = feature_view.projection.join_key_map.get(
1128-
entity.join_key, entity.join_key
1129-
)
1130-
entity_name_to_join_key_map[entity_name] = join_key
1131-
join_key_to_entity_type_map[join_key] = entity.value_type
11321157

11331158
needed_request_data, needed_request_fv_features = self.get_needed_request_data(
11341159
grouped_odfv_refs, grouped_request_fv_refs
11351160
)
11361161

1137-
join_key_rows = []
1138-
request_data_features: Dict[str, List[Any]] = defaultdict(list)
1162+
join_key_values: Dict[str, List[Value]] = {}
1163+
request_data_features: Dict[str, List[Value]] = {}
11391164
# Entity rows may be either entities or request data.
1140-
for row in entity_rows:
1141-
join_key_row = {}
1142-
for entity_name, entity_value in row.items():
1143-
# Found request data
1144-
if (
1145-
entity_name in needed_request_data
1146-
or entity_name in needed_request_fv_features
1147-
):
1148-
if entity_name in needed_request_fv_features:
1149-
# If the data was requested as a feature then
1150-
# make sure it appears in the result.
1151-
requested_result_row_names.add(entity_name)
1152-
request_data_features[entity_name].append(entity_value)
1153-
else:
1154-
try:
1155-
join_key = entity_name_to_join_key_map[entity_name]
1156-
except KeyError:
1157-
raise EntityNotFoundException(entity_name, self.project)
1158-
# All join keys should be returned in the result.
1159-
requested_result_row_names.add(join_key)
1160-
join_key_row[join_key] = entity_value
1161-
if entityless_case:
1162-
join_key_row[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL
1163-
if len(join_key_row) > 0:
1164-
# May be empty if this entity row was request data
1165-
join_key_rows.append(join_key_row)
1165+
for entity_name, values in entity_proto_values.items():
1166+
# Found request data
1167+
if (
1168+
entity_name in needed_request_data
1169+
or entity_name in needed_request_fv_features
1170+
):
1171+
if entity_name in needed_request_fv_features:
1172+
# If the data was requested as a feature then
1173+
# make sure it appears in the result.
1174+
requested_result_row_names.add(entity_name)
1175+
request_data_features[entity_name] = values
1176+
else:
1177+
try:
1178+
join_key = entity_name_to_join_key_map[entity_name]
1179+
except KeyError:
1180+
raise EntityNotFoundException(entity_name, self.project)
1181+
# All join keys should be returned in the result.
1182+
requested_result_row_names.add(join_key)
1183+
join_key_values[join_key] = values
11661184

11671185
self.ensure_request_data_values_exist(
11681186
needed_request_data, needed_request_fv_features, request_data_features
11691187
)
11701188

1171-
# Convert join_key_rows from rowise to columnar.
1172-
join_key_python_values: Dict[str, List[Value]] = defaultdict(list)
1173-
for join_key_row in join_key_rows:
1174-
for join_key, value in join_key_row.items():
1175-
join_key_python_values[join_key].append(value)
1176-
1177-
# Convert all join key values to Protobuf Values
1178-
join_key_proto_values = {
1179-
k: python_values_to_proto_values(v, join_key_to_entity_type_map[k])
1180-
for k, v in join_key_python_values.items()
1181-
}
1182-
1183-
# Populate online features response proto with join keys
1189+
# Populate online features response proto with join keys and request data features
11841190
online_features_response = GetOnlineFeaturesResponse(
1185-
results=[
1186-
GetOnlineFeaturesResponse.FeatureVector()
1187-
for _ in range(len(entity_rows))
1188-
]
1191+
results=[GetOnlineFeaturesResponse.FeatureVector() for _ in range(num_rows)]
11891192
)
1190-
for key, values in join_key_proto_values.items():
1191-
online_features_response.metadata.feature_names.val.append(key)
1192-
for row_idx, result_row in enumerate(online_features_response.results):
1193-
result_row.values.append(values[row_idx])
1194-
result_row.statuses.append(FieldStatus.PRESENT)
1195-
result_row.event_timestamps.append(Timestamp())
1193+
self._populate_result_rows_from_columnar(
1194+
online_features_response=online_features_response,
1195+
data=dict(**join_key_values, **request_data_features),
1196+
)
1197+
1198+
# Add the Entityless case after populating result rows to avoid having to remove
1199+
# it later.
1200+
entityless_case = DUMMY_ENTITY_NAME in [
1201+
entity_name
1202+
for feature_view in feature_views
1203+
for entity_name in feature_view.entities
1204+
]
1205+
if entityless_case:
1206+
join_key_values[DUMMY_ENTITY_ID] = python_values_to_proto_values(
1207+
[DUMMY_ENTITY_VAL] * num_rows, DUMMY_ENTITY.value_type
1208+
)
11961209

11971210
# Initialize the set of EntityKeyProtos once and reuse them for each FeatureView
11981211
# to avoid initialization overhead.
1199-
entity_keys = [EntityKeyProto() for _ in range(len(join_key_rows))]
1212+
entity_keys = [EntityKeyProto() for _ in range(num_rows)]
1213+
provider = self._get_provider()
12001214
for table, requested_features in grouped_refs:
12011215
# Get the correct set of entity values with the correct join keys.
1202-
entity_values = self._get_table_entity_values(
1203-
table, entity_name_to_join_key_map, join_key_proto_values,
1216+
table_entity_values = self._get_table_entity_values(
1217+
table, entity_name_to_join_key_map, join_key_values,
12041218
)
12051219

12061220
# Set the EntityKeyProtos inplace.
12071221
self._set_table_entity_keys(
1208-
entity_values, entity_keys,
1222+
table_entity_values, entity_keys,
12091223
)
12101224

12111225
# Populate the result_rows with the Features from the OnlineStore inplace.
@@ -1218,10 +1232,6 @@ def get_online_features(
12181232
table,
12191233
)
12201234

1221-
self._populate_request_data_features(
1222-
online_features_response, request_data_features
1223-
)
1224-
12251235
if grouped_odfv_refs:
12261236
self._augment_response_with_on_demand_transforms(
12271237
online_features_response,
@@ -1235,6 +1245,50 @@ def get_online_features(
12351245
)
12361246
return OnlineResponse(online_features_response)
12371247

1248+
@staticmethod
1249+
def _get_columnar_entity_values(
1250+
rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]]
1251+
) -> Dict[str, List[Any]]:
1252+
if (rowise is None and columnar is None) or (
1253+
rowise is not None and columnar is not None
1254+
):
1255+
raise ValueError(
1256+
"Exactly one of `columnar_entity_values` and `rowise_entity_values` must be set."
1257+
)
1258+
1259+
if rowise is not None:
1260+
# Convert entity_rows from rowise to columnar.
1261+
res = defaultdict(list)
1262+
for entity_row in rowise:
1263+
for key, value in entity_row.items():
1264+
res[key].append(value)
1265+
return res
1266+
return cast(Dict[str, List[Any]], columnar)
1267+
1268+
def _get_entity_maps(self, feature_views):
1269+
entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
1270+
entity_name_to_join_key_map: Dict[str, str] = {}
1271+
entity_type_map: Dict[str, ValueType] = {}
1272+
for entity in entities:
1273+
entity_name_to_join_key_map[entity.name] = entity.join_key
1274+
entity_type_map[entity.name] = entity.value_type
1275+
for feature_view in feature_views:
1276+
for entity_name in feature_view.entities:
1277+
entity = self._registry.get_entity(
1278+
entity_name, self.project, allow_cache=True
1279+
)
1280+
# User directly uses join_key as the entity reference in the entity_rows for the
1281+
# entity mapping case.
1282+
entity_name = feature_view.projection.join_key_map.get(
1283+
entity.join_key, entity.name
1284+
)
1285+
join_key = feature_view.projection.join_key_map.get(
1286+
entity.join_key, entity.join_key
1287+
)
1288+
entity_name_to_join_key_map[entity_name] = join_key
1289+
entity_type_map[join_key] = entity.value_type
1290+
return entity_name_to_join_key_map, entity_type_map
1291+
12381292
@staticmethod
12391293
def _get_table_entity_values(
12401294
table: FeatureView,
@@ -1275,23 +1329,21 @@ def _set_table_entity_keys(
12751329
entity_key.entity_values.extend(next(rowise_values))
12761330

12771331
@staticmethod
1278-
def _populate_request_data_features(
1332+
def _populate_result_rows_from_columnar(
12791333
online_features_response: GetOnlineFeaturesResponse,
1280-
request_data_features: Dict[str, List[Any]],
1334+
data: Dict[str, List[Value]],
12811335
):
1282-
# Add more feature values to the existing result rows for the request data features
1283-
for feature_name, feature_values in request_data_features.items():
1284-
proto_values = python_values_to_proto_values(
1285-
feature_values, ValueType.UNKNOWN
1286-
)
1336+
timestamp = Timestamp() # Only initialize this timestamp once.
1337+
# Add more values to the existing result rows
1338+
for feature_name, feature_values in data.items():
12871339

12881340
online_features_response.metadata.feature_names.val.append(feature_name)
12891341

1290-
for row_idx, proto_value in enumerate(proto_values):
1342+
for row_idx, proto_value in enumerate(feature_values):
12911343
result_row = online_features_response.results[row_idx]
12921344
result_row.values.append(proto_value)
12931345
result_row.statuses.append(FieldStatus.PRESENT)
1294-
result_row.event_timestamps.append(Timestamp())
1346+
result_row.event_timestamps.append(timestamp)
12951347

12961348
@staticmethod
12971349
def get_needed_request_data(
@@ -1567,6 +1619,13 @@ def serve_transformations(self, port: int) -> None:
15671619
transformation_server.start_server(self, port)
15681620

15691621

1622+
def _validate_entity_values(join_key_values: Dict[str, List[Value]]):
1623+
set_of_row_lengths = {len(v) for v in join_key_values.values()}
1624+
if len(set_of_row_lengths) > 1:
1625+
raise ValueError("All entity rows must have the same columns.")
1626+
return set_of_row_lengths.pop()
1627+
1628+
15701629
def _validate_feature_refs(feature_refs: List[str], full_feature_names: bool = False):
15711630
collided_feature_refs = []
15721631

0 commit comments

Comments
 (0)