Skip to content

Commit 3731e17

Browse files
updating workflow
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent e482ba6 commit 3731e17

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

examples/rag/feature_repo/test_workflow.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import torch.nn.functional as F
44
from feast import FeatureStore
5-
from pymilvus import MilvusClient
5+
from pymilvus import MilvusClient, DataType, FieldSchema
66
from transformers import AutoTokenizer, AutoModel
77
from example_repo import city_embeddings_feature_view, item
88
TOKENIZER = "sentence-transformers/all-MiniLM-L6-v2"
@@ -36,12 +36,43 @@ def run_model(sentences, tokenizer, model):
3636
def run_demo():
3737
store = FeatureStore(repo_path=".")
3838
df = pd.read_parquet("./data/city_wikipedia_summaries_with_embeddings.parquet")
39-
store.apply([city_embeddings_feature_view, item])
40-
store.write_to_online_store_async("city_embeddings", df)
39+
embedding_length = len(df['vector'][0])
40+
print(f'embedding length = {embedding_length}')
41+
42+
print('\ndata=')
43+
print(df.head().T)
4144

42-
client = MilvusClient(alias="feast", host="localhost", port="19530", token="username:password")
43-
print(client.list_collections())
45+
store.apply([city_embeddings_feature_view, item])
46+
store.write_to_online_store("city_embeddings", df)
4447

48+
client = MilvusClient(uir="http://localhost:19530", token="username:password")
49+
fields = [
50+
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
51+
FieldSchema(name='state', dtype=DataType.STRING, description="State"),
52+
FieldSchema(name='wiki_summary', dtype=DataType.STRING, description="State"),
53+
FieldSchema(name='sentence_chunks', dtype=DataType.STRING, description="Sentence Chunks"),
54+
FieldSchema(name="item_id", dtype=DataType.INT64, default_value=0, description="Item"),
55+
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=embedding_length, description="vector")
56+
]
57+
cols = [f.name for f in fields]
58+
client.insert(
59+
collection_name="demo_collection",
60+
data=df[cols].to_dict(orient="records"),
61+
schema=fields,
62+
)
63+
print('\n')
64+
print('collections', client.list_collections())
65+
print('query results =', client.query(
66+
collection_name="rag_city_embeddings",
67+
filter="item_id == 0",
68+
# output_fields=['city_embeddings', 'item_id', 'city_name'],
69+
))
70+
print('query results2 =', client.query(
71+
collection_name="rag_city_embeddings",
72+
filter="item_id >= 0",
73+
output_fields=["count(*)"]
74+
# output_fields=['city_embeddings', 'item_id', 'city_name'],
75+
))
4576
question = "the most populous city in the U.S. state of Texas?"
4677
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
4778
model = AutoModel.from_pretrained(MODEL)
@@ -50,7 +81,8 @@ def run_demo():
5081

5182
# Retrieve top k documents
5283
features = store.retrieve_online_documents(
53-
feature="city_embeddings:Embeddings",
84+
feature=None,
85+
features=["city_embeddings:vector", "city_embeddings:item_id", "city_embeddings:state"],
5486
query=query,
5587
top_k=3
5688
)

0 commit comments

Comments
 (0)