22import torch
33import torch .nn .functional as F
44from feast import FeatureStore
5- from pymilvus import MilvusClient
5+ from pymilvus import MilvusClient , DataType , FieldSchema
66from transformers import AutoTokenizer , AutoModel
77from example_repo import city_embeddings_feature_view , item
88TOKENIZER = "sentence-transformers/all-MiniLM-L6-v2"
@@ -36,12 +36,43 @@ def run_model(sentences, tokenizer, model):
3636def run_demo ():
3737 store = FeatureStore (repo_path = "." )
3838 df = pd .read_parquet ("./data/city_wikipedia_summaries_with_embeddings.parquet" )
39- store .apply ([city_embeddings_feature_view , item ])
40- store .write_to_online_store_async ("city_embeddings" , df )
39+ embedding_length = len (df ['vector' ][0 ])
40+ print (f'embedding length = { embedding_length } ' )
41+
42+ print ('\n data=' )
43+ print (df .head ().T )
4144
42- client = MilvusClient ( alias = "feast" , host = "localhost" , port = "19530" , token = "username:password" )
43- print ( client . list_collections () )
45+ store . apply ([ city_embeddings_feature_view , item ] )
46+ store . write_to_online_store ( "city_embeddings" , df )
4447
48+ client = MilvusClient (uir = "http://localhost:19530" , token = "username:password" )
49+ fields = [
50+ FieldSchema (name = "id" , dtype = DataType .INT64 , is_primary = True ),
51+ FieldSchema (name = 'state' , dtype = DataType .STRING , description = "State" ),
52+ FieldSchema (name = 'wiki_summary' , dtype = DataType .STRING , description = "State" ),
53+ FieldSchema (name = 'sentence_chunks' , dtype = DataType .STRING , description = "Sentence Chunks" ),
54+ FieldSchema (name = "item_id" , dtype = DataType .INT64 , default_value = 0 , description = "Item" ),
55+ FieldSchema (name = "vector" , dtype = DataType .FLOAT_VECTOR , dim = embedding_length , description = "vector" )
56+ ]
57+ cols = [f .name for f in fields ]
58+ client .insert (
59+ collection_name = "demo_collection" ,
60+ data = df [cols ].to_dict (orient = "records" ),
61+ schema = fields ,
62+ )
63+ print ('\n ' )
64+ print ('collections' , client .list_collections ())
65+ print ('query results =' , client .query (
66+ collection_name = "rag_city_embeddings" ,
67+ filter = "item_id == 0" ,
68+ # output_fields=['city_embeddings', 'item_id', 'city_name'],
69+ ))
70+ print ('query results2 =' , client .query (
71+ collection_name = "rag_city_embeddings" ,
72+ filter = "item_id >= 0" ,
73+ output_fields = ["count(*)" ]
74+ # output_fields=['city_embeddings', 'item_id', 'city_name'],
75+ ))
4576 question = "the most populous city in the U.S. state of Texas?"
4677 tokenizer = AutoTokenizer .from_pretrained (TOKENIZER )
4778 model = AutoModel .from_pretrained (MODEL )
@@ -50,7 +81,8 @@ def run_demo():
5081
5182 # Retrieve top k documents
5283 features = store .retrieve_online_documents (
53- feature = "city_embeddings:Embeddings" ,
84+ feature = None ,
85+ features = ["city_embeddings:vector" , "city_embeddings:item_id" , "city_embeddings:state" ],
5486 query = query ,
5587 top_k = 3
5688 )
0 commit comments