@@ -135,7 +135,7 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
135135 FieldSchema (name = "event_ts" , dtype = DataType .INT64 ),
136136 FieldSchema (name = "created_ts" , dtype = DataType .INT64 ),
137137 ]
138- fields_to_exclude = [field . name for field in table . entity_columns ] + [
138+ fields_to_exclude = [
139139 "event_ts" ,
140140 "created_ts" ,
141141 ]
@@ -199,11 +199,6 @@ def online_write_batch(
199199 progress : Optional [Callable [[int ], Any ]],
200200 ) -> None :
201201 collection = self ._get_collection (config , table )
202- numeric_vector_list_types = [
203- k
204- for k in PROTO_VALUE_TO_VALUE_TYPE_MAP .keys ()
205- if k is not None and "list" in k and "string" not in k
206- ]
207202 entity_batch_to_insert = []
208203 for entity_key , values_dict , timestamp , created_ts in data :
209204 # need to construct the composite primary key also need to handle the fact that entities are a list
@@ -218,15 +213,11 @@ def online_write_batch(
218213 created_ts_int = (
219214 int (to_naive_utc (created_ts ).timestamp () * 1e6 ) if created_ts else 0
220215 )
221- for feature_name in values_dict :
222- feature_values = values_dict [feature_name ]
223- for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP :
224- if feature_values .HasField (proto_val_type ):
225- if proto_val_type in numeric_vector_list_types :
226- vector_values = getattr (feature_values , proto_val_type ).val
227- else :
228- vector_values = getattr (feature_values , proto_val_type )
229- values_dict [feature_name ] = vector_values
216+ values_dict = _extract_proto_values_to_dict (values_dict )
217+ entity_dict = _extract_proto_values_to_dict (
218+ dict (zip (entity_key .join_keys , entity_key .entity_values ))
219+ )
220+ values_dict .update (entity_dict )
230221
231222 single_entity_record = {
232223 composite_key_name : entity_key_str ,
@@ -317,28 +308,51 @@ def retrieve_online_documents(
317308 expr = f"feature_name == '{ requested_feature } '"
318309
319310 composite_key_name = (
320- "_" .join ([str (value ) for value in table .entity_columns ]) + "_pk"
311+ "_" .join ([str (field . name ) for field in table .entity_columns ]) + "_pk"
321312 )
322313 if requested_features :
323314 features_str = ", " .join ([f"'{ f } '" for f in requested_features ])
324315 expr += f" && feature_name in [{ features_str } ]"
325316
317+ output_fields = (
318+ [composite_key_name ] + requested_features + ["created_ts" , "event_ts" ]
319+ )
320+ assert all (field for field in output_fields if field in [f .name for f in collection .schema .fields ]), \
321+ f"field(s) [{ [field for field in output_fields if field not in [f .name for f in collection .schema .fields ]]} '] not found in collection schema"
322+
323+ # Note we choose the first vector field as the field to search on. Not ideal but it's something.
324+ ann_search_field = None
325+ for field in collection .schema .fields :
326+ if field .dtype in [DataType .FLOAT_VECTOR , DataType .BINARY_VECTOR ]:
327+ ann_search_field = field .name
328+ break
329+
326330 results = collection .search (
327331 data = [embedding ],
328- anns_field = "vector_value" ,
332+ anns_field = ann_search_field ,
329333 param = search_params ,
330334 limit = top_k ,
331- expr = expr ,
332- output_fields = [composite_key_name ]
333- + requested_features
334- + ["created_ts" , "event_ts" ],
335+ # expr=expr,
336+ output_fields = output_fields ,
335337 consistency_level = "Strong" ,
336338 )
337339
338340 result_list = []
339341 for hits in results :
340342 for hit in hits :
341- entity_key_str = hit .entity .get ("entity_key" )
343+ single_record = {}
344+ for field in output_fields :
345+ val = hit .entity .get (field )
346+ if field == composite_key_name :
347+ val = deserialize_entity_key (
348+ bytes .fromhex (val ),
349+ config .entity_key_serialization_version ,
350+ )
351+ entity_key_proto = val
352+ single_record [field ] = val
353+
354+
355+ entity_key_str = hit .entity .get (composite_key_name )
342356 val_bin = hit .entity .get ("value" )
343357 val = ValueProto ()
344358 val .ParseFromString (val_bin )
@@ -350,7 +364,7 @@ def retrieve_online_documents(
350364 )
351365 result_list .append (
352366 _build_retrieve_online_document_record (
353- entity_key ,
367+ entity_key_proto ,
354368 val .SerializeToString (),
355369 embedding ,
356370 distance ,
@@ -365,6 +379,24 @@ def _table_id(project: str, table: FeatureView) -> str:
365379 return f"{ project } _{ table .name } "
366380
367381
def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap a mapping of proto value messages into plain Python payloads.

    Each value in ``input_dict`` is probed with ``HasField`` for every key of
    ``PROTO_VALUE_TO_VALUE_TYPE_MAP``. For a populated field, the payload is
    read with ``getattr``; when the field name contains ``"list"`` but not
    ``"string"``, the payload's ``.val`` attribute is taken instead of the
    wrapper itself.

    Args:
        input_dict: Mapping from a name to a message exposing ``HasField``.

    Returns:
        A new dict holding the extracted payload for each name. Names whose
        message has none of the known fields populated are left out. The
        input mapping itself is not modified.
    """
    # Field names whose payload is a wrapper that must be unwrapped via `.val`.
    unwrap_types = {
        type_name
        for type_name in PROTO_VALUE_TO_VALUE_TYPE_MAP
        if type_name is not None
        and "list" in type_name
        and "string" not in type_name
    }

    extracted: Dict[str, Any] = {}
    for name, proto_value in input_dict.items():
        # Every known field name is probed (no early exit), so if several
        # were populated the last one in map order would win.
        for type_name in PROTO_VALUE_TO_VALUE_TYPE_MAP:
            if not proto_value.HasField(type_name):
                continue
            payload = getattr(proto_value, type_name)
            extracted[name] = payload.val if type_name in unwrap_types else payload
    return extracted
398+
399+
368400class MilvusTable (InfraObject ):
369401 """
370402 A Milvus collection managed by Feast.
0 commit comments