11from datetime import datetime
2+ < << << << HEAD
23from pathlib import Path
4+ == == == =
5+ >> >> >> > c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
36from typing import Any , Callable , Dict , List , Literal , Optional , Sequence , Tuple , Union
47
58from pydantic import StrictStr
811 CollectionSchema ,
912 DataType ,
1013 FieldSchema ,
14+ < << << << HEAD
1115 MilvusClient ,
1216)
17+ == == == =
18+ connections ,
19+ )
20+ from pymilvus .orm .connections import Connections
21+ > >> >> >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
1322
1423from feast import Entity
1524from feast .feature_view import FeatureView
@@ -85,16 +94,26 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
8594 """
8695
8796 type : Literal ["milvus" ] = "milvus"
97+ < << << << HEAD
8898 path : Optional [StrictStr ] = "data/online_store.db"
8999 host : Optional [StrictStr ] = "localhost"
90100 port : Optional [int ] = 19530
91101 index_type : Optional [str ] = "FLAT"
102+ == == == =
103+
104+ host : Optional [StrictStr ] = "localhost"
105+ port : Optional [int ] = 19530
106+ index_type : Optional [str ] = "IVF_FLAT"
107+ >> >> >> > c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
92108 metric_type : Optional [str ] = "L2"
93109 embedding_dim : Optional [int ] = 128
94110 vector_enabled : Optional [bool ] = True
95111 nlist : Optional [int ] = 128
112+ << < << << HEAD
96113 username : Optional [StrictStr ] = ""
97114 password : Optional [StrictStr ] = ""
115+ == == == =
116+ >> > >> > > c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
98117
99118
100119class MilvusOnlineStore (OnlineStore ):
@@ -105,6 +124,7 @@ class MilvusOnlineStore(OnlineStore):
105124 _collections: Dictionary to cache Milvus collections.
106125 """
107126
127+ << < << << HEAD
108128 client : Optional [MilvusClient ] = None
109129 _collections : Dict [str , Any ] = {}
110130
@@ -139,6 +159,26 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, A
139159 self .client = self ._connect (config )
140160 collection_name = _table_id (config .project , table )
141161 if collection_name not in self ._collections :
162+ == == == =
163+ _conn : Optional [Connections ] = None
164+ _collections : Dict [str , Collection ] = {}
165+
def _connect(self, config: RepoConfig) -> Connections:
    """Return the pymilvus connection registry, opening the "feast" alias on first use.

    NOTE(review): this method is on the incoming (c239766f2) side of an
    unresolved merge conflict, while the HEAD side reaches Milvus through
    ``self.client``. Confirm which side survives before relying on it.

    Args:
        config: Repo configuration; ``config.online_store.host`` and
            ``config.online_store.port`` identify the Milvus server.

    Returns:
        The module-level pymilvus ``connections`` registry, cached on
        ``self._conn`` so later calls skip the connection check.
    """
    if self._conn is None:
        if not connections.has_connection("feast"):
            # connections.connect() returns None, so caching its result would
            # leave self._conn unset; cache the registry itself instead.
            connections.connect(
                alias="feast",
                host=config.online_store.host,
                port=str(config.online_store.port),
            )
        self._conn = connections
    return self._conn
175+
176+ def _get_collection (self , config : RepoConfig , table : FeatureView ) - > Collection :
177+ collection_name = _table_id (config .project , table )
178+ if collection_name not in self ._collections :
179+ self ._connect (config )
180+
181+ >> >> > >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
142182 # Create a composite key by combining entity fields
143183 composite_key_name = (
144184 "_" .join ([field .name for field in table .entity_columns ]) + "_pk"
@@ -184,6 +224,7 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, A
184224 schema = CollectionSchema (
185225 fields = fields , description = "Feast feature view data"
186226 )
227+ << << < << HEAD
187228 collection_exists = self .client .has_collection (
188229 collection_name = collection_name
189230 )
@@ -216,6 +257,25 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, A
216257 self ._collections [collection_name ] = self .client .describe_collection (
217258 collection_name
218259 )
260+ == == == =
261+ collection = Collection (name = collection_name , schema = schema , using = "feast" )
262+ if not collection .has_index ():
263+ index_params = {
264+ "index_type" : config .online_store .index_type ,
265+ "metric_type" : config .online_store .metric_type ,
266+ "params" : {"nlist" : config .online_store .nlist },
267+ }
268+ for vector_field in schema .fields :
269+ if vector_field .dtype in [
270+ DataType .FLOAT_VECTOR ,
271+ DataType .BINARY_VECTOR ,
272+ ]:
273+ collection .create_index (
274+ field_name = vector_field .name , index_params = index_params
275+ )
276+ collection .load ()
277+ self ._collections [collection_name ] = collection
278+ >> >> > >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
219279 return self ._collections [collection_name ]
220280
221281 def online_write_batch (
@@ -232,7 +292,10 @@ def online_write_batch(
232292 ],
233293 progress : Optional [Callable [[int ], Any ]],
234294 ) - > None :
295+ << < << < < HEAD
235296 self .client = self ._connect (config )
297+ == == == =
298+ >> > >> > > c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
236299 collection = self ._get_collection (config , table )
237300 entity_batch_to_insert = []
238301 for entity_key , values_dict , timestamp , created_ts in data :
@@ -265,10 +328,15 @@ def online_write_batch(
265328 if progress :
266329 progress (1 )
267330
331+ << < << << HEAD
268332 self .client .insert (
269333 collection_name = collection ["collection_name" ],
270334 data = entity_batch_to_insert ,
271335 )
336+ == == == =
337+ collection .insert (entity_batch_to_insert )
338+ collection .flush ()
339+ >> > >> >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
272340
273341 def online_read (
274342 self ,
@@ -288,6 +356,7 @@ def update(
288356 entities_to_keep : Sequence [Entity ],
289357 partial : bool ,
290358 ):
359+ << < << < < HEAD
291360 self .client = self ._connect (config )
292361 for table in tables_to_keep :
293362 self ._collections = self ._get_collection (config , table )
@@ -296,6 +365,16 @@ def update(
296365 collection_name = _table_id (config .project , table )
297366 if self ._collections .get (collection_name , None ):
298367 self .client .drop_collection (collection_name )
368+ == == == =
369+ self ._connect (config )
370+ for table in tables_to_keep :
371+ self ._get_collection (config , table )
372+ for table in tables_to_delete :
373+ collection_name = _table_id (config .project , table )
374+ collection = Collection (name = collection_name )
375+ if collection .exists ():
376+ collection .drop ()
377+ >> > >> >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
299378 self ._collections .pop (collection_name , None )
300379
301380 def plan (
@@ -309,12 +388,21 @@ def teardown(
309388 tables : Sequence [FeatureView ],
310389 entities : Sequence [Entity ],
311390 ):
391+ << < << < < HEAD
312392 self .client = self ._connect (config )
313393 for table in tables :
314394 collection_name = _table_id (config .project , table )
315395 if self ._collections .get (collection_name , None ):
316396 self .client .drop_collection (collection_name )
317397 self ._collections .pop (collection_name , None )
398+ == == == =
399+ self ._connect (config )
400+ for table in tables :
401+ collection = self ._get_collection (config , table )
402+ if collection :
403+ collection .drop ()
404+ self ._collections .pop (collection .name , None )
405+ >> > >> >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
318406
319407 def retrieve_online_documents (
320408 self ,
@@ -334,8 +422,11 @@ def retrieve_online_documents(
334422 Optional [ValueProto ],
335423 ]
336424 ]:
425+ << < << < < HEAD
337426 self .client = self ._connect (config )
338427 collection_name = _table_id (config .project , table )
428+ == == == =
429+ >> > >> > > c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
339430 collection = self ._get_collection (config , table )
340431 if not config .online_store .vector_enabled :
341432 raise ValueError ("Vector search is not enabled in the online store config" )
@@ -359,6 +450,7 @@ def retrieve_online_documents(
359450 + ["created_ts" , "event_ts" ]
360451 )
361452 assert all (
453+ << << << < HEAD
362454 field in [f ["name" ] for f in collection ["fields" ]]
363455 for field in output_fields
364456 ), f"field(s) [{ [field for field in output_fields if field not in [f ['name' ] for f in collection ['fields' ]]]} ] not found in collection schema"
@@ -380,13 +472,38 @@ def retrieve_online_documents(
380472 search_params = search_params ,
381473 limit = top_k ,
382474 output_fields = output_fields ,
475+ == == == =
476+ field
477+ for field in output_fields
478+ if field in [f .name for f in collection .schema .fields ]
479+ ), f"field(s) [{ [field for field in output_fields if field not in [f .name for f in collection .schema .fields ]]} '] not found in collection schema"
480+
481+ # Note we choose the first vector field as the field to search on. Not ideal but it's something.
482+ ann_search_field = None
483+ for field in collection .schema .fields :
484+ if (
485+ field .dtype in [DataType .FLOAT_VECTOR , DataType .BINARY_VECTOR ]
486+ and field .name in output_fields
487+ ):
488+ ann_search_field = field .name
489+ break
490+
491+ results = collection .search (
492+ data = [embedding ],
493+ anns_field = ann_search_field ,
494+ param = search_params ,
495+ limit = top_k ,
496+ output_fields = output_fields ,
497+ consistency_level = "Strong" ,
498+ >> >> >> > c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
383499 )
384500
385501 result_list = []
386502 for hits in results :
387503 for hit in hits :
388504 single_record = {}
389505 for field in output_fields :
506+ << < << < < HEAD
390507 single_record [field ] = hit .get ("entity" , {}).get (field , None )
391508
392509 entity_key_bytes = bytes .fromhex (
@@ -398,6 +515,15 @@ def retrieve_online_documents(
398515 event_ts = datetime .fromtimestamp (
399516 hit .get ("entity" , {}).get ("event_ts" ) / 1e6
400517 )
518+ == == == =
519+ single_record [field ] = hit .entity .get (field )
520+
521+ entity_key_bytes = bytes .fromhex (hit .entity .get (composite_key_name ))
522+ embedding = hit .entity .get (ann_search_field )
523+ serialized_embedding = _serialize_vector_to_float_list (embedding )
524+ distance = hit .distance
525+ event_ts = datetime .fromtimestamp (hit .entity .get ("event_ts" ) / 1e6 )
526+ >> > >> >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
401527 prepared_result = _build_retrieve_online_document_record (
402528 entity_key_bytes ,
403529 # This may have a bug
@@ -453,7 +579,11 @@ def __init__(self, host: str, port: int, name: str):
453579 self ._connect ()
454580
455581 def _connect (self ):
582+ << << << < HEAD
456583 raise NotImplementedError
584+ == == == =
585+ return connections .connect (alias = "default" , host = self .host , port = str (self .port ))
586+ >> > >> >> c239766f2 (feat : Add Milvus Vector Database Implementation (#4751))
457587
458588 def to_infra_object_proto (self ) - > InfraObjectProto :
459589 # Implement serialization if needed
0 commit comments