Skip to content

Commit 37d93f1

Browse files
almost have deserialization from the search results done
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent dbc11e2 commit 37d93f1

File tree

1 file changed

+55
-23
lines changed
  • sdk/python/feast/infra/online_stores/milvus_online_store

1 file changed

+55
-23
lines changed

sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
135135
FieldSchema(name="event_ts", dtype=DataType.INT64),
136136
FieldSchema(name="created_ts", dtype=DataType.INT64),
137137
]
138-
fields_to_exclude = [field.name for field in table.entity_columns] + [
138+
fields_to_exclude = [
139139
"event_ts",
140140
"created_ts",
141141
]
@@ -199,11 +199,6 @@ def online_write_batch(
199199
progress: Optional[Callable[[int], Any]],
200200
) -> None:
201201
collection = self._get_collection(config, table)
202-
numeric_vector_list_types = [
203-
k
204-
for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys()
205-
if k is not None and "list" in k and "string" not in k
206-
]
207202
entity_batch_to_insert = []
208203
for entity_key, values_dict, timestamp, created_ts in data:
209204
# Need to construct the composite primary key; also need to handle the fact that entities are a list.
@@ -218,15 +213,11 @@ def online_write_batch(
218213
created_ts_int = (
219214
int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
220215
)
221-
for feature_name in values_dict:
222-
feature_values = values_dict[feature_name]
223-
for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
224-
if feature_values.HasField(proto_val_type):
225-
if proto_val_type in numeric_vector_list_types:
226-
vector_values = getattr(feature_values, proto_val_type).val
227-
else:
228-
vector_values = getattr(feature_values, proto_val_type)
229-
values_dict[feature_name] = vector_values
216+
values_dict = _extract_proto_values_to_dict(values_dict)
217+
entity_dict = _extract_proto_values_to_dict(
218+
dict(zip(entity_key.join_keys, entity_key.entity_values))
219+
)
220+
values_dict.update(entity_dict)
230221

231222
single_entity_record = {
232223
composite_key_name: entity_key_str,
@@ -317,28 +308,51 @@ def retrieve_online_documents(
317308
expr = f"feature_name == '{requested_feature}'"
318309

319310
composite_key_name = (
320-
"_".join([str(value) for value in table.entity_columns]) + "_pk"
311+
"_".join([str(field.name) for field in table.entity_columns]) + "_pk"
321312
)
322313
if requested_features:
323314
features_str = ", ".join([f"'{f}'" for f in requested_features])
324315
expr += f" && feature_name in [{features_str}]"
325316

317+
output_fields = (
318+
[composite_key_name] + requested_features + ["created_ts", "event_ts"]
319+
)
320+
assert all(field for field in output_fields if field in [f.name for f in collection.schema.fields]), \
321+
f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema"
322+
323+
# Note we choose the first vector field as the field to search on. Not ideal but it's something.
324+
ann_search_field = None
325+
for field in collection.schema.fields:
326+
if field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
327+
ann_search_field = field.name
328+
break
329+
326330
results = collection.search(
327331
data=[embedding],
328-
anns_field="vector_value",
332+
anns_field=ann_search_field,
329333
param=search_params,
330334
limit=top_k,
331-
expr=expr,
332-
output_fields=[composite_key_name]
333-
+ requested_features
334-
+ ["created_ts", "event_ts"],
335+
# expr=expr,
336+
output_fields=output_fields,
335337
consistency_level="Strong",
336338
)
337339

338340
result_list = []
339341
for hits in results:
340342
for hit in hits:
341-
entity_key_str = hit.entity.get("entity_key")
343+
single_record = {}
344+
for field in output_fields:
345+
val = hit.entity.get(field)
346+
if field == composite_key_name:
347+
val = deserialize_entity_key(
348+
bytes.fromhex(val),
349+
config.entity_key_serialization_version,
350+
)
351+
entity_key_proto = val
352+
single_record[field] = val
353+
354+
355+
entity_key_str = hit.entity.get(composite_key_name)
342356
val_bin = hit.entity.get("value")
343357
val = ValueProto()
344358
val.ParseFromString(val_bin)
@@ -350,7 +364,7 @@ def retrieve_online_documents(
350364
)
351365
result_list.append(
352366
_build_retrieve_online_document_record(
353-
entity_key,
367+
entity_key_proto,
354368
val.SerializeToString(),
355369
embedding,
356370
distance,
@@ -365,6 +379,24 @@ def _table_id(project: str, table: FeatureView) -> str:
365379
return f"{project}_{table.name}"
366380

367381

382+
def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap the populated field of every proto value in ``input_dict``.

    Each value is probed with ``HasField`` for every field name listed in
    ``PROTO_VALUE_TO_VALUE_TYPE_MAP``. When a field is reported as set, its
    attribute is stored in the result under the same key. For field names
    that contain ``"list"`` but not ``"string"``, the attribute's ``.val``
    is stored instead of the attribute itself.

    Args:
        input_dict: Mapping of name to a message-like object exposing
            ``HasField`` and one attribute per field name.

    Returns:
        A new dict holding the unwrapped payloads. Keys whose value has no
        populated field are omitted; if several fields report as set, the
        one that comes last in the map's iteration order wins.
    """
    # Field names whose payload is a wrapper that must be unwrapped via `.val`.
    unwrap_via_val = [
        field_name
        for field_name in PROTO_VALUE_TO_VALUE_TYPE_MAP
        if not (
            field_name is None
            or "list" not in field_name
            or "string" in field_name
        )
    ]

    extracted: Dict[str, Any] = {}
    for name, proto_value in input_dict.items():
        for field_name in PROTO_VALUE_TO_VALUE_TYPE_MAP:
            if not proto_value.HasField(field_name):
                continue
            payload = getattr(proto_value, field_name)
            extracted[name] = (
                payload.val if field_name in unwrap_via_val else payload
            )
    return extracted
398+
399+
368400
class MilvusTable(InfraObject):
369401
"""
370402
A Milvus collection managed by Feast.

0 commit comments

Comments
 (0)