|
11 | 11 | from pandas.testing import assert_frame_equal |
12 | 12 |
|
13 | 13 | from feast import FeatureStore, RepoConfig |
| 14 | +from feast.infra.online_stores.contrib.milvus import MilvusOnlineStoreConfig |
14 | 15 | from feast.errors import FeatureViewNotFoundException |
15 | 16 | from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto |
16 | 17 | from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto |
17 | 18 | from feast.protos.feast.types.Value_pb2 import Value as ValueProto |
18 | 19 | from feast.repo_config import RegistryConfig |
19 | 20 | from feast.utils import _utc_now |
| 21 | +from feast.infra.provider import Provider |
20 | 22 | from tests.integration.feature_repos.universal.feature_views import TAGS |
21 | 23 | from tests.utils.cli_repo_creator import CliRunner, get_example_repo |
22 | 24 |
|
@@ -561,3 +563,104 @@ def test_sqlite_vec_import() -> None: |
561 | 563 | """).fetchall() |
562 | 564 | result = [(rowid, round(distance, 2)) for rowid, distance in result] |
563 | 565 | assert result == [(2, 2.39), (1, 2.39)] |
| 566 | + |
def test_milvus_get_online_documents() -> None:
    """
    Test retrieving documents from the online store in local mode using Milvus.

    The test re-points the example repo at a Milvus online store configured
    for ``localhost:19530``, writes ``n`` embedding rows through the provider
    and another ``n`` through ``FeatureStore.write_to_online_store``, checks
    the collection's entity count, and finally runs a top-3 vector search.
    """
    n = 10  # rows per write; two writes are performed, so 2 * n rows are expected
    vector_length = 8
    runner = CliRunner()
    with runner.local_repo(
        get_example_repo("example_feature_repo_1.py"), "file"
    ) as local_store:
        # Reuse the local repo's project/registry settings, but swap the
        # online store for Milvus.
        milvus_config = RepoConfig(
            project=local_store.config.project,
            registry=local_store.config.registry,
            provider=local_store.config.provider,
            online_store=MilvusOnlineStoreConfig(
                type="milvus",
                host="localhost",
                port=19530,
                index_type="IVF_FLAT",
                metric_type="L2",
                embedding_dim=vector_length,
                vector_enabled=True,
            ),
            entity_key_serialization_version=(
                local_store.config.entity_key_serialization_version
            ),
        )
        # A distinct name avoids shadowing the context-managed store above.
        store = FeatureStore(config=milvus_config, repo_path=local_store.repo_path)
        # Apply the new configuration.
        store.apply([])

        document_embeddings_fv = store.get_feature_view(name="document_embeddings")
        provider: Provider = store._get_provider()

        # ``ValueProto`` already *is* the ``Value`` message class (see the
        # import alias), so it is instantiated directly.
        item_keys = [
            EntityKeyProto(join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)])
            for i in range(n)
        ]
        # One (entity key, feature values, event ts, created ts) tuple per key.
        data = [
            (
                item_key,
                {
                    "Embeddings": ValueProto(
                        float_list_val=FloatListProto(
                            val=np.random.random(vector_length).tolist()
                        )
                    )
                },
                _utc_now(),
                _utc_now(),
            )
            for item_key in item_keys
        ]

        # First write: straight through the provider.
        provider.online_write_batch(
            config=store.config,
            table=document_embeddings_fv,
            data=data,
            progress=None,
        )

        # Second write: through the public DataFrame-based API.
        documents_df = pd.DataFrame(
            {
                "item_id": list(range(n)),
                "Embeddings": [
                    np.random.random(vector_length).tolist() for _ in range(n)
                ],
                "event_timestamp": [_utc_now() for _ in range(n)],
            }
        )
        store.write_to_online_store(
            feature_view_name="document_embeddings",
            df=documents_df,
        )

        # For Milvus, get the collection and check the number of entities.
        collection = provider._online_store._get_collection(
            store.config, document_embeddings_fv
        )
        record_count = collection.num_entities
        assert record_count == len(data) + documents_df.shape[0]

        query_embedding = np.random.random(vector_length).tolist()

        # Retrieve online documents using Milvus.
        result = store.retrieve_online_documents(
            feature="document_embeddings:Embeddings", query=query_embedding, top_k=3
        ).to_dict()

        assert "Embeddings" in result
        assert "distance" in result
        assert len(result["distance"]) == 3
0 commit comments