Skip to content

Commit e1f19f2

Browse files
committed
test: Added test for image search
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
1 parent 93ee3e7 commit e1f19f2

File tree

8 files changed

+335
-60
lines changed

8 files changed

+335
-60
lines changed

sdk/python/feast/feature_store.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,38 +2267,41 @@ def retrieve_online_documents_v2(
22672267
OnlineResponse with similar documents and metadata
22682268
22692269
Examples:
2270-
Text search only:
2271-
>>> results = store.retrieve_online_documents_v2(
2272-
... features=["documents:embedding", "documents:title"],
2273-
... query=text_embedding,
2274-
... top_k=5
2275-
... )
2270+
Text search only::
22762271
2277-
Image search only:
2278-
>>> results = store.retrieve_online_documents_v2(
2279-
... features=["images:embedding", "images:filename"],
2280-
... query_image_bytes=image_bytes,
2281-
... top_k=5
2282-
... )
2272+
results = store.retrieve_online_documents_v2(
2273+
features=["documents:embedding", "documents:title"],
2274+
query=[0.1, 0.2, 0.3], # text embedding vector
2275+
top_k=5
2276+
)
22832277
2284-
Combined text + image search:
2285-
>>> results = store.retrieve_online_documents_v2(
2286-
... features=["documents:embedding", "documents:title"],
2287-
... query=text_embedding,
2288-
... query_image_bytes=image_bytes,
2289-
... combine_with_text=True,
2290-
... text_weight=0.3,
2291-
... image_weight=0.7,
2292-
... top_k=5
2293-
... )
2278+
Image search only::
2279+
2280+
results = store.retrieve_online_documents_v2(
2281+
features=["images:embedding", "images:filename"],
2282+
query_image_bytes=b"image_data", # image bytes
2283+
top_k=5
2284+
)
2285+
2286+
Combined text + image search::
2287+
2288+
results = store.retrieve_online_documents_v2(
2289+
features=["documents:embedding", "documents:title"],
2290+
query=[0.1, 0.2, 0.3], # text embedding vector
2291+
query_image_bytes=b"image_data", # image bytes
2292+
combine_with_text=True,
2293+
text_weight=0.3,
2294+
image_weight=0.7,
2295+
top_k=5
2296+
)
22942297
"""
2295-
if not query and not query_image_bytes and not query_string:
2298+
if query is None and not query_image_bytes and not query_string:
22962299
raise ValueError(
22972300
"Must provide either query (text embedding), "
22982301
"query_image_bytes, or query_string"
22992302
)
23002303

2301-
if combine_with_text and not (query and query_image_bytes):
2304+
if combine_with_text and not (query is not None and query_image_bytes):
23022305
raise ValueError(
23032306
"combine_with_text=True requires both query (text embedding) "
23042307
"and query_image_bytes"
@@ -2323,7 +2326,11 @@ def retrieve_online_documents_v2(
23232326

23242327
text_embedding = query
23252328

2326-
if combine_with_text and text_embedding and image_embedding:
2329+
if (
2330+
combine_with_text
2331+
and text_embedding is not None
2332+
and image_embedding is not None
2333+
):
23272334
# Combine text and image embeddings
23282335
from feast.image_utils import combine_embeddings
23292336

@@ -2334,9 +2341,9 @@ def retrieve_online_documents_v2(
23342341
text_weight=text_weight,
23352342
image_weight=image_weight,
23362343
)
2337-
elif image_embedding:
2344+
elif image_embedding is not None:
23382345
final_query = image_embedding
2339-
elif text_embedding:
2346+
elif text_embedding is not None:
23402347
final_query = text_embedding
23412348
else:
23422349
final_query = None

sdk/python/feast/image_utils.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,18 @@ class ImageFeatureExtractor:
4949
from images using pre-trained vision models like ResNet, ViT, etc.
5050
5151
Examples:
52-
Basic usage:
53-
>>> extractor = ImageFeatureExtractor()
54-
>>> with open("image.jpg", "rb") as f:
55-
... image_bytes = f.read()
56-
>>> embedding = extractor.extract_embedding(image_bytes)
57-
58-
Using different models:
59-
>>> # Use ViT model for better performance
60-
>>> extractor = ImageFeatureExtractor("vit_base_patch16_224")
61-
>>> embedding = extractor.extract_embedding(image_bytes)
52+
Basic usage::
53+
54+
extractor = ImageFeatureExtractor()
55+
with open("image.jpg", "rb") as f:
56+
image_bytes = f.read()
57+
embedding = extractor.extract_embedding(image_bytes)
58+
59+
Using different models::
60+
61+
# Use ViT model for better performance
62+
extractor = ImageFeatureExtractor("vit_base_patch16_224")
63+
embedding = extractor.extract_embedding(image_bytes)
6264
"""
6365

6466
def __init__(self, model_name: str = "resnet34"):
@@ -189,18 +191,20 @@ def combine_embeddings(
189191
ValueError: If strategy is invalid or weights don't sum to 1.0
190192
191193
Examples:
192-
Weighted combination (emphasize image):
193-
>>> combined = combine_embeddings(
194-
... text_emb, image_emb,
195-
... strategy="weighted_sum",
196-
... text_weight=0.3, image_weight=0.7
197-
... )
198-
199-
Concatenation for full information:
200-
>>> combined = combine_embeddings(
201-
... text_emb, image_emb,
202-
... strategy="concatenate"
203-
... )
194+
Weighted combination (emphasize image)::
195+
196+
combined = combine_embeddings(
197+
[0.1, 0.2], [0.8, 0.9], # text_emb, image_emb
198+
strategy="weighted_sum",
199+
text_weight=0.3, image_weight=0.7
200+
)
201+
202+
Concatenation for full information::
203+
204+
combined = combine_embeddings(
205+
[0.1, 0.2], [0.8, 0.9], # text_emb, image_emb
206+
strategy="concatenate"
207+
)
204208
"""
205209
if strategy == "weighted_sum":
206210
if abs(text_weight + image_weight - 1.0) > 1e-6:

sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
from datetime import datetime
23
from pathlib import Path
34
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
@@ -247,7 +248,7 @@ def online_write_batch(
247248
) -> None:
248249
self.client = self._connect(config)
249250
collection = self._get_or_create_collection(config, table)
250-
vector_cols = [f.name for f in table.features if f.vector_index]
251+
vector_cols = [f.name for f in table.schema if f.vector_index]
251252
entity_batch_to_insert = []
252253
unique_entities: dict[str, dict[str, Any]] = {}
253254
required_fields = {field["name"] for field in collection["fields"]}
@@ -503,6 +504,14 @@ def retrieve_online_documents_v2(
503504
entity_name_feast_primitive_type_map = {
504505
k.name: k.dtype for k in table.entity_columns
505506
}
507+
# Also include feature columns for proper type mapping
508+
feature_name_feast_primitive_type_map = {
509+
k.name: k.dtype for k in table.features
510+
}
511+
field_name_feast_primitive_type_map = {
512+
**entity_name_feast_primitive_type_map,
513+
**feature_name_feast_primitive_type_map,
514+
}
506515
self.client = self._connect(config)
507516
collection_name = _table_id(config.project, table)
508517
collection = self._get_or_create_collection(config, table)
@@ -662,14 +671,25 @@ def retrieve_online_documents_v2(
662671
embedding
663672
)
664673
res[ann_search_field] = serialized_embedding
665-
elif entity_name_feast_primitive_type_map.get(
666-
field, PrimitiveFeastType.INVALID
667-
) in [
668-
PrimitiveFeastType.STRING,
669-
PrimitiveFeastType.BYTES,
670-
]:
674+
elif (
675+
field_name_feast_primitive_type_map.get(
676+
field, PrimitiveFeastType.INVALID
677+
)
678+
== PrimitiveFeastType.STRING
679+
):
671680
res[field] = ValueProto(string_val=str(field_value))
672-
elif entity_name_feast_primitive_type_map.get(
681+
elif (
682+
field_name_feast_primitive_type_map.get(
683+
field, PrimitiveFeastType.INVALID
684+
)
685+
== PrimitiveFeastType.BYTES
686+
):
687+
try:
688+
decoded_bytes = base64.b64decode(field_value)
689+
res[field] = ValueProto(bytes_val=decoded_bytes)
690+
except Exception:
691+
res[field] = ValueProto(string_val=str(field_value))
692+
elif field_name_feast_primitive_type_map.get(
673693
field, PrimitiveFeastType.INVALID
674694
) in [
675695
PrimitiveFeastType.INT64,
@@ -732,9 +752,13 @@ def _extract_proto_values_to_dict(
732752
else:
733753
if (
734754
serialize_to_string
735-
and proto_val_type not in ["string_val"] + numeric_types
755+
and proto_val_type
756+
not in ["string_val", "bytes_val"] + numeric_types
736757
):
737758
vector_values = feature_values.SerializeToString().decode()
759+
elif proto_val_type == "bytes_val":
760+
byte_data = getattr(feature_values, proto_val_type)
761+
vector_values = base64.b64encode(byte_data).decode("utf-8")
738762
else:
739763
if not isinstance(feature_values, str):
740764
vector_values = str(

sdk/python/tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from tests.data.data_creator import (
3535
create_basic_driver_dataset, # noqa: E402
3636
create_document_dataset,
37+
create_image_dataset,
3738
)
3839
from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
3940
IntegrationTestRepoConfig,
@@ -446,6 +447,16 @@ def fake_document_data(environment: Environment) -> Tuple[pd.DataFrame, DataSour
446447
return df, data_source
447448

448449

450+
@pytest.fixture
def fake_image_data(environment: Environment) -> Tuple[pd.DataFrame, DataSource]:
    """Provide the synthetic image dataset together with a data source built from it.

    The frame is produced by ``create_image_dataset`` and handed to the
    environment's data source creator, using the project of the
    environment's feature store.

    Returns:
        A ``(dataframe, data_source)`` pair.
    """
    image_frame = create_image_dataset()
    source_factory = environment.data_source_creator
    image_source = source_factory.create_data_source(
        image_frame,
        environment.feature_store.project,
    )
    return image_frame, image_source
458+
459+
449460
@pytest.fixture
450461
def temp_dir():
451462
with tempfile.TemporaryDirectory() as temp_dir:

sdk/python/tests/data/data_creator.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
from datetime import datetime, timedelta, timezone
23
from typing import Dict, List, Optional
34
from zoneinfo import ZoneInfo
@@ -100,3 +101,52 @@ def create_document_dataset() -> pd.DataFrame:
100101
],
101102
}
102103
return pd.DataFrame(data)
104+
105+
106+
def create_image_dataset() -> pd.DataFrame:
    """Build a three-row frame of synthetic image records for image search tests.

    Columns: ``item_id``, ``image_filename``, ``image_bytes``,
    ``image_embedding`` (two floats per row), ``category``, ``description``,
    and the millisecond-rounded timestamps ``ts`` and ``created_ts``.
    """

    def _solid_color_image_bytes(color=(255, 0, 0), size=(32, 32)):
        """Return JPEG bytes of a solid-colour image, or canned bytes without PIL."""
        try:
            from PIL import Image

            picture = Image.new("RGB", size, color=color)
            buffer = io.BytesIO()
            picture.save(buffer, format="JPEG")
            return buffer.getvalue()
        except ImportError:
            # PIL is optional here: fall back to a fixed dummy payload.
            return b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.' \",#\x1c\x1c(7),01444\x1f'9=82<.342\xff\xc0\x00\x11\x08\x00 \x00 \x01\x01\x11\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x14\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x08\xff\xc4\x00\x14\x10\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xda\x00\x0c\x03\x01\x00\x02\x11\x03\x11\x00\x3f\x00\xaa\xff\xd9"

    def _now_ms():
        """Return the current UTC time as a Timestamp rounded to milliseconds."""
        return pd.Timestamp(_utc_now()).round("ms")

    row_count = 3
    # Red, green, blue -- one solid-colour image per row.
    rgb_colors = ((255, 0, 0), (0, 255, 0), (0, 0, 255))

    columns = {
        "item_id": [1, 2, 3],
        "image_filename": ["red_image.jpg", "green_image.jpg", "blue_image.jpg"],
        "image_bytes": [_solid_color_image_bytes(rgb) for rgb in rgb_colors],
        "image_embedding": [
            [0.9, 0.1],  # Red-ish embedding
            [0.2, 0.8],  # Green-ish embedding
            [0.1, 0.9],  # Blue-ish embedding
        ],
        "category": ["primary", "primary", "primary"],
        "description": [
            "A red colored image",
            "A green colored image",
            "A blue colored image",
        ],
        # Each cell takes its own clock reading, one per row and column.
        "ts": [_now_ms() for _ in range(row_count)],
        "created_ts": [_now_ms() for _ in range(row_count)],
    }
    return pd.DataFrame(columns)

sdk/python/tests/integration/feature_repos/universal/online_store/milvus.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,25 @@
88
class MilvusOnlineStoreCreator(OnlineStoreCreator):
    """Online store creator that supplies a Milvus configuration backed by a local file."""

    def __init__(self, project_name: str, **kwargs):
        """Initialise the creator and record the database file path.

        Args:
            project_name: Forwarded to the base ``OnlineStoreCreator``.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(project_name)
        # Kept on the instance so create_online_store() and teardown() agree
        # on which file is in use.
        self.db_path = "online_store.db"

    def create_online_store(self) -> Dict[str, Any]:
        """Return the Milvus online store configuration dictionary."""
        return {
            "type": "milvus",
            "path": self.db_path,
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "embedding_dim": 2,
            "vector_enabled": True,
            "nlist": 1,
        }

    def teardown(self) -> None:
        """Remove the database file, if present.

        Cleanup is best-effort: a missing file is not an error, and any other
        OS-level failure is logged rather than raised so that it cannot fail
        the caller. Unlike a blanket ``except Exception: pass``, unexpected
        problems stay visible.
        """
        import logging
        import os

        try:
            # Attempt the removal directly instead of checking existence
            # first, which avoids a check-then-act race.
            os.remove(self.db_path)
        except FileNotFoundError:
            # Nothing was created, or it was already cleaned up.
            pass
        except OSError as exc:
            logging.getLogger(__name__).warning(
                "Could not remove Milvus online store file %s: %s",
                self.db_path,
                exc,
            )

0 commit comments

Comments
 (0)