From 5123bed4762406e6bdc2ab1810d11fff4c2b6896 Mon Sep 17 00:00:00 2001 From: yassinnouh21 Date: Fri, 10 Apr 2026 12:24:24 +0200 Subject: [PATCH] feat: Add feature view versioning support to FAISS online store When enable_online_feature_view_versioning is enabled, FAISS indices are namespaced by versioned table keys (e.g. project_driver_stats_v2) so multiple feature view versions can coexist in memory. Reuses the shared compute_table_id() from helpers.py for consistency with PostgreSQL and MySQL stores. Signed-off-by: yassinnouh21 --- sdk/python/feast/errors.py | 2 +- .../infra/online_stores/faiss_online_store.py | 119 +++++---- .../feast/infra/online_stores/online_store.py | 6 + .../online_store/test_faiss_versioning.py | 249 ++++++++++++++++++ 4 files changed, 331 insertions(+), 45 deletions(-) create mode 100644 sdk/python/tests/unit/infra/online_store/test_faiss_versioning.py diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index fb35ff79de..08d4082743 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -142,7 +142,7 @@ class VersionedOnlineReadNotSupported(FeastError): def __init__(self, store_name: str, version: int): super().__init__( f"Versioned feature reads (@v{version}) are not yet supported by {store_name}. " - f"Currently only SQLite, PostgreSQL, and MySQL support version-qualified feature references. " + f"Currently only SQLite, PostgreSQL, MySQL, and FAISS support version-qualified feature references. " ) diff --git a/sdk/python/feast/infra/online_stores/faiss_online_store.py b/sdk/python/feast/infra/online_stores/faiss_online_store.py index 3e3d92cde6..dfa7d6c376 100644 --- a/sdk/python/feast/infra/online_stores/faiss_online_store.py +++ b/sdk/python/feast/infra/online_stores/faiss_online_store.py @@ -8,6 +8,7 @@ from feast import Entity, FeatureView, RepoConfig from feast.infra.key_encoding_utils import serialize_entity_key +from feast.infra.online_stores.helpers import compute_table_id from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto @@ -43,16 +44,21 @@ def teardown(self): self.entity_keys = {} +def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str: + return compute_table_id(project, table, enable_versioning) + + class FaissOnlineStore(OnlineStore): - _index: Optional[faiss.IndexIVFFlat] = None - _in_memory_store: InMemoryStore = InMemoryStore() - _config: Optional[FaissOnlineStoreConfig] = None _logger: logging.Logger = logging.getLogger(__name__) - def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat: - if self._index is None or self._config is None: - raise ValueError("Index is not initialized") - return self._index + def __init__(self): + super().__init__() + self._indices: Dict[str, faiss.IndexIVFFlat] = {} + self._in_memory_stores: Dict[str, InMemoryStore] = {} + self._config: Optional[FaissOnlineStoreConfig] = None + + def _get_index(self, table_key: str) -> Optional[faiss.IndexIVFFlat]: + return self._indices.get(table_key) def update( self, @@ -63,23 +69,31 @@ def update( entities_to_keep: Sequence[Entity], partial: bool, ): - feature_views = tables_to_keep - if not feature_views: - return - - feature_names = [f.name for f in feature_views[0].features] - dimension = len(feature_names) - self._config = FaissOnlineStoreConfig(**config.online_store.dict()) - if self._index is None or not partial: - quantizer = faiss.IndexFlatL2(dimension) - self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist) - self._index.train( - np.random.rand(self._config.nlist * 100, dimension).astype(np.float32) - ) - self._in_memory_store = InMemoryStore() + versioning = config.registry.enable_online_feature_view_versioning + + for table in tables_to_delete: + table_key = _table_id(config.project, table, versioning) + self._indices.pop(table_key, None) + self._in_memory_stores.pop(table_key, None) + + for table in tables_to_keep: + table_key = _table_id(config.project, table, versioning) + feature_names = [f.name for f in table.features] + dimension = len(feature_names) + + if table_key not in self._indices or not partial: + quantizer = faiss.IndexFlatL2(dimension) + index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist) + index.train( + np.random.rand(self._config.nlist * 100, dimension).astype( + np.float32 + ) + ) + self._indices[table_key] = index + self._in_memory_stores[table_key] = InMemoryStore() - self._in_memory_store.update(feature_names, {}) + self._in_memory_stores[table_key].update(feature_names, {}) def teardown( self, @@ -87,8 +101,13 @@ def teardown( tables: Sequence[FeatureView], entities: Sequence[Entity], ): - self._index = None - self._in_memory_store.teardown() + versioning = config.registry.enable_online_feature_view_versioning + for table in tables: + table_key = _table_id(config.project, table, versioning) + self._indices.pop(table_key, None) + store = self._in_memory_stores.pop(table_key, None) + if store is not None: + store.teardown() def online_read( self, @@ -97,7 +116,12 @@ def online_read( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - if self._index is None: + versioning = config.registry.enable_online_feature_view_versioning + table_key = _table_id(config.project, table, versioning) + index = self._get_index(table_key) + in_memory_store = self._in_memory_stores.get(table_key) + + if index is None or in_memory_store is None: return [(None, None)] * len(entity_keys) results: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = [] @@ -105,15 +129,15 @@ def online_read( serialized_key = serialize_entity_key( entity_key, config.entity_key_serialization_version ).hex() - idx = self._in_memory_store.entity_keys.get(serialized_key, -1) + idx = in_memory_store.entity_keys.get(serialized_key, -1) if idx == -1: results.append((None, None)) else: - feature_vector = self._index.reconstruct(int(idx)) + feature_vector = index.reconstruct(int(idx)) feature_dict = { name: ValueProto(double_val=value) for name, value in zip( - self._in_memory_store.feature_names, feature_vector + in_memory_store.feature_names, feature_vector ) } results.append((None, feature_dict)) @@ -128,8 +152,16 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: - if self._index is None: - self._logger.warning("Index is not initialized. Skipping write operation.") + versioning = config.registry.enable_online_feature_view_versioning + table_key = _table_id(config.project, table, versioning) + index = self._get_index(table_key) + in_memory_store = self._in_memory_stores.get(table_key) + + if index is None or in_memory_store is None: + self._logger.warning( + "Index for table '%s' is not initialized. Skipping write operation.", + table_key, + ) return feature_vectors = [] @@ -142,7 +174,7 @@ def online_write_batch( feature_vector = np.array( [ feature_dict[name].double_val - for name in self._in_memory_store.feature_names + for name in in_memory_store.feature_names ], dtype=np.float32, ) @@ -153,21 +185,17 @@ def online_write_batch( feature_vectors_array = np.array(feature_vectors) existing_indices = [ - self._in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys + in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys ] mask = np.array(existing_indices) != -1 if np.any(mask): - self._index.remove_ids( - np.array([idx for idx in existing_indices if idx != -1]) - ) + index.remove_ids(np.array([idx for idx in existing_indices if idx != -1])) - new_indices = np.arange( - self._index.ntotal, self._index.ntotal + len(feature_vectors_array) - ) - self._index.add(feature_vectors_array) + new_indices = np.arange(index.ntotal, index.ntotal + len(feature_vectors_array)) + index.add(feature_vectors_array) for sk, idx in zip(serialized_keys, new_indices): - self._in_memory_store.entity_keys[sk] = idx + in_memory_store.entity_keys[sk] = idx if progress: progress(len(data)) @@ -189,12 +217,16 @@ def retrieve_online_documents( Optional[ValueProto], ] ]: - if self._index is None: + versioning = config.registry.enable_online_feature_view_versioning + table_key = _table_id(config.project, table, versioning) + index = self._get_index(table_key) + + if index is None: self._logger.warning("Index is not initialized. Returning empty result.") return [] query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1) - distances, indices = self._index.search(query_vector, top_k) + distances, indices = index.search(query_vector, top_k) results: List[ Tuple[ @@ -209,7 +241,7 @@ def retrieve_online_documents( if idx == -1: continue - feature_vector = self._index.reconstruct(int(idx)) + feature_vector = index.reconstruct(int(idx)) timestamp = Timestamp() timestamp.GetCurrentTime() @@ -237,5 +269,4 @@ async def online_read_async( entity_keys: List[EntityKeyProto], requested_features: Optional[List[str]] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: - # Implement async read if needed raise NotImplementedError("Async read is not implemented for FaissOnlineStore") diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 41555ccfb2..cc77abf39b 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -274,6 +274,12 @@ def _check_versioned_read_support(self, grouped_refs): supported_types.append(PostgreSQLOnlineStore) except ImportError: pass + try: + from feast.infra.online_stores.faiss_online_store import FaissOnlineStore + + supported_types.append(FaissOnlineStore) + except ImportError: + pass if isinstance(self, tuple(supported_types)): return diff --git a/sdk/python/tests/unit/infra/online_store/test_faiss_versioning.py b/sdk/python/tests/unit/infra/online_store/test_faiss_versioning.py new file mode 100644 index 0000000000..84b0aa24e9 --- /dev/null +++ b/sdk/python/tests/unit/infra/online_store/test_faiss_versioning.py @@ -0,0 +1,249 @@ +"""Unit tests for FAISS online store feature view versioning.""" + +import sys +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from feast import Entity, FeatureView +from feast.field import Field +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.types import Float32 +from feast.value_type import ValueType + + +def _make_feature_view(name="driver_stats", version_number=None, version_tag=None): + entity = Entity( + name="driver_id", + join_keys=["driver_id"], + value_type=ValueType.INT64, + ) + fv = FeatureView( + name=name, + entities=[entity], + ttl=timedelta(days=1), + schema=[Field(name="feature_a", dtype=Float32)], + ) + if version_number is not None: + fv.current_version_number = version_number + if version_tag is not None: + fv.projection.version_tag = version_tag + return fv + + +@pytest.fixture(autouse=True) +def _mock_faiss(): + """Inject a minimal faiss mock so faiss_online_store can be imported.""" + faiss_mock = MagicMock() + with patch.dict(sys.modules, {"faiss": faiss_mock}): + sys.modules.pop("feast.infra.online_stores.faiss_online_store", None) + yield faiss_mock + sys.modules.pop("feast.infra.online_stores.faiss_online_store", None) + + +class TestFaissTableId: + """Test _table_id generates correct versioned table names.""" + + def test_default_no_versioning(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view() + assert _table_id("proj", fv) == "proj_driver_stats" + + def test_versioning_explicitly_disabled(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view(version_number=3) + assert _table_id("proj", fv, enable_versioning=False) == "proj_driver_stats" + + def test_versioning_enabled_no_version_set(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view() + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_versioning_enabled_with_current_version_number(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view(version_number=2) + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats_v2" + + def test_version_zero_no_suffix(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view(version_number=0) + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_projection_version_tag_takes_priority(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view(version_number=1, version_tag=3) + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats_v3" + + def test_projection_version_tag_zero_no_suffix(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view(version_tag=0, version_number=3) + assert _table_id("proj", fv, enable_versioning=True) == "proj_driver_stats" + + def test_different_project_names(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view(version_number=1) + assert _table_id("prod", fv, enable_versioning=True) == "prod_driver_stats_v1" + assert ( + _table_id("staging", fv, enable_versioning=True) + == "staging_driver_stats_v1" + ) + + def test_different_feature_view_names(self): + from feast.infra.online_stores.faiss_online_store import _table_id + + fv = _make_feature_view(name="user_stats", version_number=2) + assert _table_id("proj", fv, enable_versioning=True) == "proj_user_stats_v2" + + +class TestFaissVersionedReadSupport: + """Test that FaissOnlineStore passes _check_versioned_read_support.""" + + def test_allowed_with_version_tag(self): + from feast.infra.online_stores.faiss_online_store import FaissOnlineStore + + store = FaissOnlineStore() + fv = _make_feature_view() + fv.projection.version_tag = 2 + store._check_versioned_read_support([(fv, ["feature_a"])]) + + def test_allowed_without_version_tag(self): + from feast.infra.online_stores.faiss_online_store import FaissOnlineStore + + store = FaissOnlineStore() + fv = _make_feature_view() + store._check_versioned_read_support([(fv, ["feature_a"])]) + + +def _make_config(project="test_project", versioning=False): + """Build a minimal RepoConfig-like mock.""" + config = MagicMock() + config.project = project + config.entity_key_serialization_version = 2 + config.online_store.dict.return_value = { + "dimension": 1, + "index_path": "/tmp/test.index", + "index_type": "IVFFlat", + "nlist": 10, + } + config.registry.enable_online_feature_view_versioning = versioning + return config + + +def _make_entity_key(driver_id=1): + return EntityKeyProto( + join_keys=["driver_id"], + entity_values=[ValueProto(int64_val=driver_id)], + ) + + +class TestFaissOnlineStoreVersionedReadWrite: + def _make_store(self, faiss_mock, nlist=10): + """Create a FaissOnlineStore with a real-enough faiss mock.""" + index_mock = MagicMock() + index_mock.ntotal = 0 + + def add_side_effect(vectors): + index_mock.ntotal += len(vectors) + + index_mock.add.side_effect = add_side_effect + + def reconstruct_side_effect(idx): + return np.array([float(idx)], dtype=np.float32) + + index_mock.reconstruct.side_effect = reconstruct_side_effect + + faiss_mock.IndexFlatL2.return_value = MagicMock() + faiss_mock.IndexIVFFlat.return_value = index_mock + + from feast.infra.online_stores.faiss_online_store import FaissOnlineStore + + store = FaissOnlineStore() + return store, index_mock + + def test_write_and_read_without_versioning(self, _mock_faiss): + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=False) + fv = _make_feature_view() + + store.update(config, [], [fv], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=42) + data = [(entity_key, {"feature_a": ValueProto(double_val=1.5)}, None, None)] + store.online_write_batch(config, fv, data, None) + + results = store.online_read(config, fv, [entity_key]) + assert len(results) == 1 + _, feature_dict = results[0] + assert feature_dict is not None + assert "feature_a" in feature_dict + + def test_write_and_read_with_versioning(self, _mock_faiss): + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + fv_v2 = _make_feature_view(version_number=2) + + store.update(config, [], [fv_v2], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=7) + data = [(entity_key, {"feature_a": ValueProto(double_val=2.0)}, None, None)] + store.online_write_batch(config, fv_v2, data, None) + + results = store.online_read(config, fv_v2, [entity_key]) + assert len(results) == 1 + _, feature_dict = results[0] + assert feature_dict is not None + + def test_versioned_namespaces_are_isolated(self, _mock_faiss): + """Data written under v1 must not be visible when reading under v2.""" + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + + fv_v1 = _make_feature_view(version_number=1) + fv_v2 = _make_feature_view(version_number=2) + + store.update(config, [], [fv_v1, fv_v2], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=99) + data = [(entity_key, {"feature_a": ValueProto(double_val=9.9)}, None, None)] + store.online_write_batch(config, fv_v1, data, None) + + results_v2 = store.online_read(config, fv_v2, [entity_key]) + assert results_v2 == [(None, None)] + + results_v1 = store.online_read(config, fv_v1, [entity_key]) + assert results_v1[0][1] is not None + + def test_missing_index_returns_none(self, _mock_faiss): + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + fv = _make_feature_view(version_number=5) + entity_key = _make_entity_key(driver_id=1) + results = store.online_read(config, fv, [entity_key]) + assert results == [(None, None)] + + def test_teardown_removes_versioned_index(self, _mock_faiss): + store, _ = self._make_store(_mock_faiss) + config = _make_config(versioning=True) + fv = _make_feature_view(version_number=3) + + store.update(config, [], [fv], [], [], partial=False) + + entity_key = _make_entity_key(driver_id=1) + data = [(entity_key, {"feature_a": ValueProto(double_val=3.0)}, None, None)] + store.online_write_batch(config, fv, data, None) + + store.teardown(config, [fv], []) + + results = store.online_read(config, fv, [entity_key]) + assert results == [(None, None)]