77 CollectionSchema ,
88 DataType ,
99 FieldSchema ,
10- connections ,
10+ MilvusClient ,
1111)
12- from pymilvus .orm .connections import Connections
1312
1413from feast import Entity
1514from feast .feature_view import FeatureView
@@ -85,14 +84,15 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
8584 """
8685
8786 type : Literal ["milvus" ] = "milvus"
88-
8987 host : Optional [StrictStr ] = "localhost"
9088 port : Optional [int ] = 19530
9189 index_type : Optional [str ] = "IVF_FLAT"
9290 metric_type : Optional [str ] = "L2"
9391 embedding_dim : Optional [int ] = 128
9492 vector_enabled : Optional [bool ] = True
9593 nlist : Optional [int ] = 128
94+ username : Optional [StrictStr ] = ""
95+ password : Optional [StrictStr ] = ""
9696
9797
9898class MilvusOnlineStore (OnlineStore ):
@@ -103,24 +103,23 @@ class MilvusOnlineStore(OnlineStore):
103103 _collections: Dictionary to cache Milvus collections.
104104 """
105105
106- _conn : Optional [Connections ] = None
107- _collections : Dict [str , Collection ] = {}
106+ client : Optional [MilvusClient ] = None
107+ _collections : Dict [str , Any ] = {}
108108
109- def _connect (self , config : RepoConfig ) -> connections :
110- if not self ._conn :
111- if not connections . has_connection ( "feast" ):
112- self . _conn = connections . connect (
113- alias = "feast" ,
114- host = config .online_store .host ,
115- port = str ( config . online_store . port ) ,
116- )
117- return self ._conn
109+ def _connect (self , config : RepoConfig ) -> MilvusClient :
110+ if not self .client :
111+ self . client = MilvusClient (
112+ url = f" { config . online_store . host } : { config . online_store . port } " ,
113+ token = f" { config . online_store . username } : { config . online_store . password } "
114+ if config . online_store . username and config .online_store .password
115+ else "" ,
116+ )
117+ return self .client
118118
119- def _get_collection (self , config : RepoConfig , table : FeatureView ) -> Collection :
119+ def _get_collection (self , config : RepoConfig , table : FeatureView ) -> Dict [str , Any ]:
120+ self .client = self ._connect (config )
120121 collection_name = _table_id (config .project , table )
121122 if collection_name not in self ._collections :
122- self ._connect (config )
123-
124123 # Create a composite key by combining entity fields
125124 composite_key_name = (
126125 "_" .join ([field .name for field in table .entity_columns ]) + "_pk"
@@ -166,23 +165,38 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
166165 schema = CollectionSchema (
167166 fields = fields , description = "Feast feature view data"
168167 )
169- collection = Collection (name = collection_name , schema = schema , using = "feast" )
170- if not collection .has_index ():
171- index_params = {
172- "index_type" : config .online_store .index_type ,
173- "metric_type" : config .online_store .metric_type ,
174- "params" : {"nlist" : config .online_store .nlist },
175- }
176- for vector_field in schema .fields :
177- if vector_field .dtype in [
178- DataType .FLOAT_VECTOR ,
179- DataType .BINARY_VECTOR ,
180- ]:
181- collection .create_index (
182- field_name = vector_field .name , index_params = index_params
183- )
184- collection .load ()
185- self ._collections [collection_name ] = collection
168+ collection_exists = self .client .has_collection (
169+ collection_name = collection_name
170+ )
171+ if not collection_exists :
172+ self .client .create_collection (
173+ collection_name = collection_name ,
174+ dimension = config .online_store .embedding_dim ,
175+ schema = schema ,
176+ )
177+ index_params = self .client .prepare_index_params ()
178+ for vector_field in schema .fields :
179+ if vector_field .dtype in [
180+ DataType .FLOAT_VECTOR ,
181+ DataType .BINARY_VECTOR ,
182+ ]:
183+ index_params .add_index (
184+ collection_name = collection_name ,
185+ field_name = vector_field .name ,
186+ metric_type = config .online_store .metric_type ,
187+ index_type = config .online_store .index_type ,
188+ index_name = f"vector_index_{ vector_field .name } " ,
189+ params = {"nlist" : config .online_store .nlist },
190+ )
191+ self .client .create_index (
192+ collection_name = collection_name ,
193+ index_params = index_params ,
194+ )
195+ else :
196+ self .client .load_collection (collection_name )
197+ self ._collections [collection_name ] = self .client .describe_collection (
198+ collection_name
199+ )
186200 return self ._collections [collection_name ]
187201
188202 def online_write_batch (
@@ -199,6 +213,7 @@ def online_write_batch(
199213 ],
200214 progress : Optional [Callable [[int ], Any ]],
201215 ) -> None :
216+ self .client = self ._connect (config )
202217 collection = self ._get_collection (config , table )
203218 entity_batch_to_insert = []
204219 for entity_key , values_dict , timestamp , created_ts in data :
@@ -231,8 +246,9 @@ def online_write_batch(
231246 if progress :
232247 progress (1 )
233248
234- collection .insert (entity_batch_to_insert )
235- collection .flush ()
249+ self .client .insert (
250+ collection_name = collection ["collection_name" ], data = entity_batch_to_insert
251+ )
236252
237253 def online_read (
238254 self ,
@@ -252,14 +268,14 @@ def update(
252268 entities_to_keep : Sequence [Entity ],
253269 partial : bool ,
254270 ):
255- self ._connect (config )
271+ self .client = self . _connect (config )
256272 for table in tables_to_keep :
257- self ._get_collection (config , table )
273+ self ._collections = self ._get_collection (config , table )
274+
258275 for table in tables_to_delete :
259276 collection_name = _table_id (config .project , table )
260- collection = Collection (name = collection_name )
261- if collection .exists ():
262- collection .drop ()
277+ if self ._collections .get (collection_name , None ):
278+ self .client .drop_collection (collection_name )
263279 self ._collections .pop (collection_name , None )
264280
265281 def plan (
@@ -273,12 +289,12 @@ def teardown(
273289 tables : Sequence [FeatureView ],
274290 entities : Sequence [Entity ],
275291 ):
276- self ._connect (config )
292+ self .client = self . _connect (config )
277293 for table in tables :
278- collection = self . _get_collection (config , table )
279- if collection :
280- collection . drop ( )
281- self ._collections .pop (collection . name , None )
294+ collection_name = _table_id (config . project , table )
295+ if self . _collections . get ( collection_name , None ) :
296+ self . client . drop_collection ( collection_name )
297+ self ._collections .pop (collection_name , None )
282298
283299 def retrieve_online_documents (
284300 self ,
@@ -298,6 +314,8 @@ def retrieve_online_documents(
298314 Optional [ValueProto ],
299315 ]
300316 ]:
317+ self .client = self ._connect (config )
318+ collection_name = _table_id (config .project , table )
301319 collection = self ._get_collection (config , table )
302320 if not config .online_store .vector_enabled :
303321 raise ValueError ("Vector search is not enabled in the online store config" )
@@ -321,42 +339,45 @@ def retrieve_online_documents(
321339 + ["created_ts" , "event_ts" ]
322340 )
323341 assert all (
324- field
342+ field in [ f [ "name" ] for f in collection [ "fields" ]]
325343 for field in output_fields
326- if field in [f .name for f in collection .schema .fields ]
327- ), f"field(s) [{ [field for field in output_fields if field not in [f .name for f in collection .schema .fields ]]} '] not found in collection schema"
328-
344+ ), f"field(s) [{ [field for field in output_fields if field not in [f ['name' ] for f in collection ['fields' ]]]} ] not found in collection schema"
329345 # Note we choose the first vector field as the field to search on. Not ideal but it's something.
330346 ann_search_field = None
331- for field in collection . schema . fields :
347+ for field in collection [ " fields" ] :
332348 if (
333- field . dtype in [DataType .FLOAT_VECTOR , DataType .BINARY_VECTOR ]
334- and field . name in output_fields
349+ field [ "type" ] in [DataType .FLOAT_VECTOR , DataType .BINARY_VECTOR ]
350+ and field [ " name" ] in output_fields
335351 ):
336- ann_search_field = field . name
352+ ann_search_field = field [ " name" ]
337353 break
338354
339- results = collection .search (
355+ self .client .load_collection (collection_name )
356+ results = self .client .search (
357+ collection_name = collection_name ,
340358 data = [embedding ],
341359 anns_field = ann_search_field ,
342- param = search_params ,
360+ search_params = search_params ,
343361 limit = top_k ,
344362 output_fields = output_fields ,
345- consistency_level = "Strong" ,
346363 )
347364
348365 result_list = []
349366 for hits in results :
350367 for hit in hits :
351368 single_record = {}
352369 for field in output_fields :
353- single_record [field ] = hit .entity .get (field )
370+ single_record [field ] = hit .get ( " entity" , {}) .get (field , None )
354371
355- entity_key_bytes = bytes .fromhex (hit .entity .get (composite_key_name ))
356- embedding = hit .entity .get (ann_search_field )
372+ entity_key_bytes = bytes .fromhex (
373+ hit .get ("entity" , {}).get (composite_key_name , None )
374+ )
375+ embedding = hit .get ("entity" , {}).get (ann_search_field )
357376 serialized_embedding = _serialize_vector_to_float_list (embedding )
358- distance = hit .distance
359- event_ts = datetime .fromtimestamp (hit .entity .get ("event_ts" ) / 1e6 )
377+ distance = hit .get ("distance" , None )
378+ event_ts = datetime .fromtimestamp (
379+ hit .get ("entity" , {}).get ("event_ts" ) / 1e6
380+ )
360381 prepared_result = _build_retrieve_online_document_record (
361382 entity_key_bytes ,
362383 # This may have a bug
@@ -412,7 +433,7 @@ def __init__(self, host: str, port: int, name: str):
412433 self ._connect ()
413434
414435 def _connect (self ):
415- return connections . connect ( alias = "default" , host = self . host , port = str ( self . port ))
436+ raise NotImplementedError
416437
417438 def to_infra_object_proto (self ) -> InfraObjectProto :
418439 # Implement serialization if needed
0 commit comments