Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test: Added test for image search
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
  • Loading branch information
ntkathole committed Sep 7, 2025
commit 8ba87f4cef45117b390c4593753f7c6689eb41e2
61 changes: 34 additions & 27 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2267,38 +2267,41 @@ def retrieve_online_documents_v2(
OnlineResponse with similar documents and metadata

Examples:
Text search only:
>>> results = store.retrieve_online_documents_v2(
... features=["documents:embedding", "documents:title"],
... query=text_embedding,
... top_k=5
... )
Text search only::

Image search only:
>>> results = store.retrieve_online_documents_v2(
... features=["images:embedding", "images:filename"],
... query_image_bytes=image_bytes,
... top_k=5
... )
results = store.retrieve_online_documents_v2(
features=["documents:embedding", "documents:title"],
query=[0.1, 0.2, 0.3], # text embedding vector
top_k=5
)

Combined text + image search:
>>> results = store.retrieve_online_documents_v2(
... features=["documents:embedding", "documents:title"],
... query=text_embedding,
... query_image_bytes=image_bytes,
... combine_with_text=True,
... text_weight=0.3,
... image_weight=0.7,
... top_k=5
... )
Image search only::

results = store.retrieve_online_documents_v2(
features=["images:embedding", "images:filename"],
query_image_bytes=b"image_data", # image bytes
top_k=5
)

Combined text + image search::

results = store.retrieve_online_documents_v2(
features=["documents:embedding", "documents:title"],
query=[0.1, 0.2, 0.3], # text embedding vector
query_image_bytes=b"image_data", # image bytes
combine_with_text=True,
text_weight=0.3,
image_weight=0.7,
top_k=5
)
"""
if not query and not query_image_bytes and not query_string:
if query is None and not query_image_bytes and not query_string:
raise ValueError(
"Must provide either query (text embedding), "
"query_image_bytes, or query_string"
)

if combine_with_text and not (query and query_image_bytes):
if combine_with_text and not (query is not None and query_image_bytes):
raise ValueError(
"combine_with_text=True requires both query (text embedding) "
"and query_image_bytes"
Expand All @@ -2323,7 +2326,11 @@ def retrieve_online_documents_v2(

text_embedding = query

if combine_with_text and text_embedding and image_embedding:
if (
combine_with_text
and text_embedding is not None
and image_embedding is not None
):
# Combine text and image embeddings
from feast.image_utils import combine_embeddings

Expand All @@ -2334,9 +2341,9 @@ def retrieve_online_documents_v2(
text_weight=text_weight,
image_weight=image_weight,
)
elif image_embedding:
elif image_embedding is not None:
final_query = image_embedding
elif text_embedding:
elif text_embedding is not None:
final_query = text_embedding
else:
final_query = None
Expand Down
48 changes: 26 additions & 22 deletions sdk/python/feast/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,18 @@ class ImageFeatureExtractor:
from images using pre-trained vision models like ResNet, ViT, etc.

Examples:
Basic usage:
>>> extractor = ImageFeatureExtractor()
>>> with open("image.jpg", "rb") as f:
... image_bytes = f.read()
>>> embedding = extractor.extract_embedding(image_bytes)

Using different models:
>>> # Use ViT model for better performance
>>> extractor = ImageFeatureExtractor("vit_base_patch16_224")
>>> embedding = extractor.extract_embedding(image_bytes)
Basic usage::

extractor = ImageFeatureExtractor()
with open("image.jpg", "rb") as f:
image_bytes = f.read()
embedding = extractor.extract_embedding(image_bytes)

Using different models::

# Use ViT model for better performance
extractor = ImageFeatureExtractor("vit_base_patch16_224")
embedding = extractor.extract_embedding(image_bytes)
"""

def __init__(self, model_name: str = "resnet34"):
Expand Down Expand Up @@ -189,18 +191,20 @@ def combine_embeddings(
ValueError: If strategy is invalid or weights don't sum to 1.0

Examples:
Weighted combination (emphasize image):
>>> combined = combine_embeddings(
... text_emb, image_emb,
... strategy="weighted_sum",
... text_weight=0.3, image_weight=0.7
... )

Concatenation for full information:
>>> combined = combine_embeddings(
... text_emb, image_emb,
... strategy="concatenate"
... )
Weighted combination (emphasize image)::

combined = combine_embeddings(
[0.1, 0.2], [0.8, 0.9], # text_emb, image_emb
strategy="weighted_sum",
text_weight=0.3, image_weight=0.7
)

Concatenation for full information::

combined = combine_embeddings(
[0.1, 0.2], [0.8, 0.9], # text_emb, image_emb
strategy="concatenate"
)
"""
if strategy == "weighted_sum":
if abs(text_weight + image_weight - 1.0) > 1e-6:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -43,6 +44,7 @@

PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
ValueType.IMAGE_BYTES: DataType.VARCHAR, # IMAGE_BYTES serializes as bytes_val
Comment thread
ntkathole marked this conversation as resolved.
Outdated
PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_val"]: DataType.BOOL,
PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.VARCHAR,
PROTO_VALUE_TO_VALUE_TYPE_MAP["float_val"]: DataType.FLOAT,
Expand Down Expand Up @@ -247,7 +249,7 @@ def online_write_batch(
) -> None:
self.client = self._connect(config)
collection = self._get_or_create_collection(config, table)
vector_cols = [f.name for f in table.features if f.vector_index]
vector_cols = [f.name for f in table.schema if f.vector_index]
Comment thread
ntkathole marked this conversation as resolved.
Outdated
entity_batch_to_insert = []
unique_entities: dict[str, dict[str, Any]] = {}
required_fields = {field["name"] for field in collection["fields"]}
Expand Down Expand Up @@ -503,6 +505,14 @@ def retrieve_online_documents_v2(
entity_name_feast_primitive_type_map = {
k.name: k.dtype for k in table.entity_columns
}
# Also include feature columns for proper type mapping
feature_name_feast_primitive_type_map = {
k.name: k.dtype for k in table.features
}
field_name_feast_primitive_type_map = {
**entity_name_feast_primitive_type_map,
**feature_name_feast_primitive_type_map,
}
self.client = self._connect(config)
collection_name = _table_id(config.project, table)
collection = self._get_or_create_collection(config, table)
Expand Down Expand Up @@ -662,14 +672,25 @@ def retrieve_online_documents_v2(
embedding
)
res[ann_search_field] = serialized_embedding
elif entity_name_feast_primitive_type_map.get(
field, PrimitiveFeastType.INVALID
) in [
PrimitiveFeastType.STRING,
PrimitiveFeastType.BYTES,
]:
elif (
field_name_feast_primitive_type_map.get(
field, PrimitiveFeastType.INVALID
)
== PrimitiveFeastType.STRING
):
res[field] = ValueProto(string_val=str(field_value))
elif entity_name_feast_primitive_type_map.get(
elif (
field_name_feast_primitive_type_map.get(
field, PrimitiveFeastType.INVALID
)
== PrimitiveFeastType.BYTES
):
try:
decoded_bytes = base64.b64decode(field_value)
res[field] = ValueProto(bytes_val=decoded_bytes)
except Exception:
res[field] = ValueProto(string_val=str(field_value))
elif field_name_feast_primitive_type_map.get(
field, PrimitiveFeastType.INVALID
) in [
PrimitiveFeastType.INT64,
Expand Down Expand Up @@ -732,9 +753,13 @@ def _extract_proto_values_to_dict(
else:
if (
serialize_to_string
and proto_val_type not in ["string_val"] + numeric_types
and proto_val_type
not in ["string_val", "bytes_val"] + numeric_types
):
vector_values = feature_values.SerializeToString().decode()
elif proto_val_type == "bytes_val":
byte_data = getattr(feature_values, proto_val_type)
vector_values = base64.b64encode(byte_data).decode("utf-8")
else:
if not isinstance(feature_values, str):
vector_values = str(
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,9 @@ def _construct_random_input(
ValueType.PDF_BYTES: [
b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Francisco Javier Arceo)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
],
ValueType.IMAGE_BYTES: [
b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.' \",#\x1c\x1c(7),01444\x1f'9=82<.342\xff\xc0\x00\x11\x08\x00\x01\x00\x01\x01\x01\x11\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x14\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x08\xff\xc4\x00\x14\x10\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xda\x00\x0c\x03\x01\x00\x02\x11\x03\x11\x00\x3f\x00\xaa\xff\xd9"
],
ValueType.STRING: ["hello world"],
ValueType.INT32: [1],
ValueType.INT64: [1],
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def _type_err(item, dtype):
),
ValueType.STRING: ("string_val", lambda x: str(x), None),
ValueType.BYTES: ("bytes_val", lambda x: x, {bytes}),
ValueType.IMAGE_BYTES: ("bytes_val", lambda x: x, {bytes}),
ValueType.BOOL: ("bool_val", lambda x: x, {bool, np.bool_, int, np.int_}),
}

Expand Down
6 changes: 6 additions & 0 deletions sdk/python/feast/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"INVALID": "UNKNOWN",
"BYTES": "BYTES",
"PDF_BYTES": "PDF_BYTES",
"IMAGE_BYTES": "IMAGE_BYTES",
"STRING": "STRING",
"INT32": "INT32",
"INT64": "INT64",
Expand Down Expand Up @@ -81,6 +82,7 @@ class PrimitiveFeastType(Enum):
BOOL = 7
UNIX_TIMESTAMP = 8
PDF_BYTES = 9
IMAGE_BYTES = 10

def to_value_type(self) -> ValueType:
"""
Expand All @@ -105,6 +107,7 @@ def __hash__(self):
Invalid = PrimitiveFeastType.INVALID
Bytes = PrimitiveFeastType.BYTES
PdfBytes = PrimitiveFeastType.PDF_BYTES
ImageBytes = PrimitiveFeastType.IMAGE_BYTES
String = PrimitiveFeastType.STRING
Bool = PrimitiveFeastType.BOOL
Int32 = PrimitiveFeastType.INT32
Expand All @@ -118,6 +121,7 @@ def __hash__(self):
String,
Bytes,
PdfBytes,
ImageBytes,
Bool,
Int32,
Int64,
Expand All @@ -131,6 +135,7 @@ def __hash__(self):
"STRING": "String",
"BYTES": "Bytes",
"PDF_BYTES": "PdfBytes",
"IMAGE_BYTES": "ImageBytes",
"BOOL": "Bool",
"INT32": "Int32",
"INT64": "Int64",
Expand Down Expand Up @@ -174,6 +179,7 @@ def __str__(self):
ValueType.UNKNOWN: Invalid,
ValueType.BYTES: Bytes,
ValueType.PDF_BYTES: PdfBytes,
ValueType.IMAGE_BYTES: ImageBytes,
ValueType.STRING: String,
ValueType.INT32: Int32,
ValueType.INT64: Int64,
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/value_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class ValueType(enum.Enum):
UNIX_TIMESTAMP_LIST = 18
NULL = 19
PDF_BYTES = 20
IMAGE_BYTES = 21


ListType = Union[
Expand Down
11 changes: 11 additions & 0 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tests.data.data_creator import (
create_basic_driver_dataset, # noqa: E402
create_document_dataset,
create_image_dataset,
)
from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
IntegrationTestRepoConfig,
Expand Down Expand Up @@ -446,6 +447,16 @@ def fake_document_data(environment: Environment) -> Tuple[pd.DataFrame, DataSour
return df, data_source


@pytest.fixture
def fake_image_data(environment: Environment) -> Tuple[pd.DataFrame, DataSource]:
    """Provide a synthetic image dataset plus a matching data source for tests."""
    image_df = create_image_dataset()
    source = environment.data_source_creator.create_data_source(
        image_df,
        environment.feature_store.project,
    )
    return image_df, source


@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
50 changes: 50 additions & 0 deletions sdk/python/tests/data/data_creator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -100,3 +101,52 @@ def create_document_dataset() -> pd.DataFrame:
],
}
return pd.DataFrame(data)


def create_image_dataset() -> pd.DataFrame:
    """Create a dataset with image data for testing image search functionality.

    Returns:
        pd.DataFrame: Three rows (red, green, blue test images), each with raw
        JPEG bytes, a 2-dimensional embedding, filename/category/description
        metadata, and event (``ts``) / created (``created_ts``) timestamps.
    """

    def create_test_image_bytes(color=(255, 0, 0), size=(32, 32)):
        """Create synthetic JPEG bytes for testing.

        Falls back to a hard-coded minimal JPEG when PIL is not installed so
        the dataset can be built without the optional dependency.
        """
        try:
            from PIL import Image

            img = Image.new("RGB", size, color=color)
            img_bytes = io.BytesIO()
            img.save(img_bytes, format="JPEG")
            return img_bytes.getvalue()
        except ImportError:
            # Return dummy bytes if PIL not available
            return b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.' \",#\x1c\x1c(7),01444\x1f'9=82<.342\xff\xc0\x00\x11\x08\x00 \x00 \x01\x01\x11\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x14\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x08\xff\xc4\x00\x14\x10\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xda\x00\x0c\x03\x01\x00\x02\x11\x03\x11\x00\x3f\x00\xaa\xff\xd9"

    # Capture one shared timestamp so every row's ts/created_ts is identical;
    # separate _utc_now() calls could otherwise differ by a few milliseconds.
    now = pd.Timestamp(_utc_now()).round("ms")

    data = {
        "item_id": [1, 2, 3],
        "image_filename": ["red_image.jpg", "green_image.jpg", "blue_image.jpg"],
        "image_bytes": [
            create_test_image_bytes((255, 0, 0)),  # Red
            create_test_image_bytes((0, 255, 0)),  # Green
            create_test_image_bytes((0, 0, 255)),  # Blue
        ],
        "image_embedding": [
            [0.9, 0.1],  # Red-ish embedding
            [0.2, 0.8],  # Green-ish embedding
            [0.1, 0.9],  # Blue-ish embedding
        ],
        "category": ["primary", "primary", "primary"],
        "description": [
            "A red colored image",
            "A green colored image",
            "A blue colored image",
        ],
        "ts": [now, now, now],
        "created_ts": [now, now, now],
    }
    return pd.DataFrame(data)
Loading