Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -42,6 +43,8 @@
to_naive_utc,
)

logger = logging.getLogger(__name__)

PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
ValueType.IMAGE_BYTES: DataType.VARCHAR,
Expand Down Expand Up @@ -140,11 +143,13 @@ def _connect(self, config: RepoConfig) -> MilvusClient:
if not self.client:
if config.provider == "local" and config.online_store.path:
db_path = self._get_db_path(config)
print(f"Connecting to Milvus in local mode using {db_path}")
logger.info("Connecting to Milvus in local mode using %s", db_path)
self.client = MilvusClient(db_path)
else:
print(
f"Connecting to Milvus remotely at {config.online_store.host}:{config.online_store.port}"
logger.info(
"Connecting to Milvus remotely at %s:%s",
config.online_store.host,
config.online_store.port,
)
self.client = MilvusClient(
uri=f"{config.online_store.host}:{config.online_store.port}",
Expand Down Expand Up @@ -339,7 +344,6 @@ def online_read(
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
full_feature_names: bool = False,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
self.client = self._connect(config)
collection_name = _table_id(config.project, table)
Expand Down Expand Up @@ -487,7 +491,7 @@ def update(
):
self.client = self._connect(config)
for table in tables_to_keep:
self._collections = self._get_or_create_collection(config, table)
self._get_or_create_collection(config, table)

for table in tables_to_delete:
collection_name = _table_id(config.project, table)
Expand All @@ -498,7 +502,7 @@ def update(
def plan(
self, config: RepoConfig, desired_registry_proto: RegistryProto
) -> List[InfraObject]:
raise NotImplementedError
return []

def teardown(
self,
Expand Down Expand Up @@ -686,9 +690,8 @@ def retrieve_online_documents_v2(
for hit in hits:
res = {}
res_ts = None
entity_key_bytes = bytes.fromhex(
hit.get("entity", {}).get(composite_key_name, None)
)
raw_key = hit.get("entity", {}).get(composite_key_name)
entity_key_bytes = bytes.fromhex(raw_key) if raw_key else None
entity_key_proto = (
deserialize_entity_key(entity_key_bytes)
if entity_key_bytes
Expand Down
178 changes: 178 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,3 +1714,181 @@ def test_milvus_keyword_search() -> None:
assert len(result_hybrid["content"]) > 0
assert any("Feast" in content for content in result_hybrid["content"])
assert len(result_hybrid["vector"]) > 0


def test_milvus_update_preserves_collection_cache() -> None:
"""
Regression test: update() used to overwrite self._collections with the
describe_collection() dict of the last processed table, replacing the
dict-of-dicts cache with a single flat dict. After the fix, each call
to _get_or_create_collection() updates the keyed entry in-place and the
cache remains a proper mapping from collection name to collection info.
"""
from datetime import timedelta

from feast import Entity, FeatureView, Field, FileSource
from feast.types import Array, Float32, Int64, String

runner = CliRunner()
with runner.local_repo(
example_repo_py=get_example_repo("example_rag_feature_repo.py"),
offline_store="file",
online_store="milvus",
apply=False,
teardown=False,
) as store:
source = FileSource(
path="data/dummy.parquet",
timestamp_field="event_timestamp",
created_timestamp_column="created_timestamp",
)
entity_a = Entity(name="id_a", join_keys=["id_a"], value_type=ValueType.INT64)
entity_b = Entity(name="id_b", join_keys=["id_b"], value_type=ValueType.INT64)

fv_a = FeatureView(
name="fv_a",
entities=[entity_a],
schema=[
Field(name="id_a", dtype=Int64),
Field(
name="vec_a",
dtype=Array(Float32),
vector_index=True,
vector_search_metric="COSINE",
),
Field(name="text_a", dtype=String),
],
source=source,
ttl=timedelta(hours=1),
)
fv_b = FeatureView(
name="fv_b",
entities=[entity_b],
schema=[
Field(name="id_b", dtype=Int64),
Field(
name="vec_b",
dtype=Array(Float32),
vector_index=True,
vector_search_metric="COSINE",
),
Field(name="text_b", dtype=String),
],
source=source,
ttl=timedelta(hours=1),
)

store.apply([source, entity_a, entity_b, fv_a, fv_b])

online_store = store._provider._online_store
# After applying two feature views, the cache must be a proper dict
# mapping collection names to collection-info dicts, not a flat dict.
assert isinstance(online_store._collections, dict), (
"_collections should be a dict"
)
collection_name_a = f"{store.config.project}_fv_a"
collection_name_b = f"{store.config.project}_fv_b"
assert collection_name_a in online_store._collections, (
f"Cache missing entry for {collection_name_a}"
)
assert collection_name_b in online_store._collections, (
f"Cache missing entry for {collection_name_b} — "
"update() likely overwrote _collections with a single collection dict"
)
# Each cached value must be a collection-info dict (has a 'fields' key),
# not itself keyed by collection name.
for name in [collection_name_a, collection_name_b]:
assert "fields" in online_store._collections[name], (
f"Cache entry for {name} looks like a corrupted flat dict"
)


def test_milvus_plan_returns_empty_list() -> None:
"""
Regression test: plan() used to raise NotImplementedError, causing
`feast plan` to crash for any project using the Milvus online store.
It should return [] matching the OnlineStore base class default.
"""
from feast.infra.online_stores.milvus_online_store.milvus import MilvusOnlineStore

store = MilvusOnlineStore()
result = store.plan(config=None, desired_registry_proto=None) # type: ignore[arg-type]
assert result == [], f"plan() should return [] but returned {result!r}"


def test_milvus_retrieve_online_documents_v2_missing_entity_key() -> None:
"""
Regression test: retrieve_online_documents_v2() passed the raw
hit.get("entity", {}).get(composite_key_name, None) directly to
bytes.fromhex(), raising TypeError when the key was absent.
After the fix, a missing composite key produces a None entity_key_proto
instead of crashing.
"""
from datetime import timedelta
from unittest.mock import patch

from feast import Entity, FeatureView, Field, FileSource
from feast.types import Array, Float32, Int64, String

runner = CliRunner()
with runner.local_repo(
example_repo_py=get_example_repo("example_rag_feature_repo.py"),
offline_store="file",
online_store="milvus",
apply=False,
teardown=False,
) as store:
source = FileSource(
path="data/dummy.parquet",
timestamp_field="event_timestamp",
created_timestamp_column="created_timestamp",
)
entity = Entity(name="doc_id", join_keys=["doc_id"], value_type=ValueType.INT64)
fv = FeatureView(
name="docs",
entities=[entity],
schema=[
Field(name="doc_id", dtype=Int64),
Field(
name="vec",
dtype=Array(Float32),
vector_index=True,
vector_search_metric="COSINE",
),
Field(name="text", dtype=String),
],
source=source,
ttl=timedelta(hours=1),
)
store.apply([source, entity, fv])

online_store = store._provider._online_store
fv_obj = store.get_feature_view("docs")
# Simulate a search hit that is missing the composite primary key.
fake_hit = {
"entity": {
"event_ts": int(_utc_now().timestamp() * 1e6),
"created_ts": int(_utc_now().timestamp() * 1e6),
"text": "hello",
},
"distance": 0.9,
}

mock_results = [[fake_hit]]
with patch.object(online_store.client, "search", return_value=mock_results):
with patch.object(
online_store.client, "load_collection", return_value=None
):
# Before the fix this raised TypeError: fromhex argument must be str, not None
result = online_store.retrieve_online_documents_v2(
config=store.config,
table=fv_obj,
requested_features=["text"],
embedding=[0.1] * 10,
top_k=1,
)
assert len(result) == 1
_ts, entity_key_proto, _features = result[0]
assert entity_key_proto is None, (
"entity_key_proto should be None when the composite key is absent from the hit"
)
Loading