Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix: Fixed comments
Signed-off-by: ntkathole <nikhilkathole2683@gmail.com>
  • Loading branch information
ntkathole committed Sep 7, 2025
commit d02913b69416b7d57e66bba39651ce27f1e82f28
29 changes: 11 additions & 18 deletions sdk/python/feast/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
_image_dependencies_available = False


COMBINATION_STRATEGIES = ["weighted_sum", "concatenate", "average"]


def _check_image_dependencies():
"""Check if image processing dependencies are available."""
if not _image_dependencies_available:
Expand All @@ -58,7 +61,10 @@ class ImageFeatureExtractor:

Using different models::

# Use ViT model for better performance
# ResNet-50
extractor = ImageFeatureExtractor("resnet50")
embedding = extractor.extract_embedding(image_bytes)
# ViT model
extractor = ImageFeatureExtractor("vit_base_patch16_224")
embedding = extractor.extract_embedding(image_bytes)
"""
Expand Down Expand Up @@ -88,6 +94,7 @@ def __init__(self, model_name: str = "resnet34"):

config = resolve_data_config({}, model=model_name)
self.preprocess = create_transform(**config)

except Exception as e:
raise RuntimeError(f"Failed to load model '{model_name}': {e}")

Expand Down Expand Up @@ -115,17 +122,6 @@ def extract_embedding(self, image_bytes: bytes) -> List[float]:
except Exception as e:
raise ValueError(f"Failed to extract embedding from image: {e}")

def get_embedding_dimension(self) -> int:
"""
Get the dimension of embeddings produced by this model.
Returns:
Integer dimension of the embedding vector
"""
dummy_input = torch.randn(1, 3, 224, 224) # Standard input size
with torch.no_grad():
output = self.model(dummy_input)
return output.shape[1]

def batch_extract_embeddings(
self, image_bytes_list: List[bytes]
) -> List[List[float]]:
Expand Down Expand Up @@ -177,18 +173,15 @@ def combine_embeddings(
Args:
text_embedding: Text embedding vector
image_embedding: Image embedding vector
strategy: Combination strategy. Options:
- "weighted_sum": Weighted sum of aligned embeddings (default)
- "concatenate": Concatenate embeddings end-to-end
- "average": Simple average of aligned embeddings
strategy: Combination strategy (default: "weighted_sum")
text_weight: Weight for text embedding (for weighted strategies)
image_weight: Weight for image embedding (for weighted strategies)

Returns:
Combined embedding vector as list of floats

Raises:
ValueError: If strategy is invalid or weights don't sum to 1.0
ValueError: If weights don't sum to 1.0 for weighted_sum strategy

Examples:
Weighted combination (emphasize image)::
Expand Down Expand Up @@ -236,7 +229,7 @@ def combine_embeddings(
else:
raise ValueError(
f"Unknown combination strategy: {strategy}. "
f"Supported strategies: weighted_sum, concatenate, average"
f"Supported strategies: {', '.join(COMBINATION_STRATEGIES)}"
)


Expand Down
8 changes: 0 additions & 8 deletions sdk/python/tests/unit/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,6 @@ def test_extract_embedding_invalid_image(self):
with pytest.raises(ValueError, match="Failed to extract embedding"):
extractor.extract_embedding(b"invalid image data")

def test_get_embedding_dimension(self):
"""Test getting embedding dimension."""
extractor = ImageFeatureExtractor()
dimension = extractor.get_embedding_dimension()

assert isinstance(dimension, int)
assert dimension > 0

def test_batch_extract_embeddings(self):
"""Test batch embedding extraction."""
extractor = ImageFeatureExtractor()
Expand Down
Loading