Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add faiss & in memory online store
Signed-off-by: cmuhao <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Aug 29, 2024
commit 1f10a728160bcc8ae02afd033f54d9c8b8593db5
142 changes: 80 additions & 62 deletions sdk/python/feast/infra/online_stores/contrib/faiss_online_store.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import faiss
import numpy as np
from typing import Sequence, Tuple, List, Optional, Dict, Any, Callable, Union
from google.protobuf.timestamp_pb2 import Timestamp
from protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from protos.feast.types.Value_pb2 import Value as ValueProto

from feast import Entity, FeatureView, RepoConfig
from feast.infra.online_stores.online_store import OnlineStore
from feast.repo_config import FeastConfigBaseModel
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
from feast.protos.feast.types.Value_pb2 import Value
from datetime import datetime
from google.protobuf.timestamp_pb2 import Timestamp
import logging

from protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel


class FaissOnlineStoreConfig(FeastConfigBaseModel):
Expand All @@ -27,19 +28,15 @@ def __init__(self):
self.feature_names = []
self.entity_keys = {}

def update(self, feature_names: List[str], entity_keys: Dict[Tuple[str, ...], int]):
    """Swap in a new feature-name list and entity-key -> row-index mapping.

    Both arguments are stored by reference, replacing any previous state.
    """
    self.entity_keys = entity_keys
    self.feature_names = feature_names

def delete(self, entity_keys: List[EntityKey]):
    """Remove each of the given keys from the entity-key index mapping.

    NOTE(review): a missing key raises KeyError, and the lookup uses the raw
    EntityKey object while ``update``/``read`` paths appear keyed by string
    tuples — confirm the intended key type.
    """
    for key in entity_keys:
        del self.entity_keys[key]

def read(self,
entity_keys: List[EntityKey]):
def read(self, entity_keys: List[EntityKey]):
    """Return the stored index for each key, or None for unknown keys."""
    results = []
    for key in entity_keys:
        results.append(self.entity_keys.get(key))
    return results

def teardown(self):
Expand All @@ -49,22 +46,20 @@ def teardown(self):


class FaissOnlineStore(OnlineStore):

def __init__(self, config: Optional[Dict[str, Any]] = None):
    """Initialize an empty store; the FAISS index itself is built lazily in update()."""
    self._logger = logging.getLogger(__name__)
    # Falsy config (None or {}) leaves the store unconfigured.
    if config:
        self._config = FaissOnlineStoreConfig(**config)
    else:
        self._config = None
    self._in_memory_store = InMemoryStore()
    self._index = None

def update(
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
):
feature_views = tables_to_keep
if not feature_views:
Expand All @@ -76,50 +71,59 @@ def update(
if self._index is None or not partial:
quantizer = faiss.IndexFlatL2(dimension)
self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
self._index.train(np.random.rand(self._config.nlist * 100, dimension).astype(np.float32))
self._index.train(
np.random.rand(self._config.nlist * 100, dimension).astype(np.float32)
)
self._in_memory_store = InMemoryStore()

self._in_memory_store.update(feature_names, {})

def teardown(
    self,
    config: RepoConfig,
    tables: Sequence[FeatureView],
    entities: Sequence[Entity],
):
    """Discard the FAISS index and clear the in-memory key store."""
    # Dropping the reference is enough; update() recreates the index on demand.
    self._index = None
    self._in_memory_store.teardown()

def online_read(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKey],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]]:
    """Fetch stored feature vectors for the given entity keys.

    Returns one (event_timestamp, feature_dict) pair per key; both elements
    are None when the index is uninitialized or the key is unknown. The
    timestamp is always None — write timestamps are not retained here.
    """
    if self._index is None:
        return [(None, None)] * len(entity_keys)

    out: List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]] = []
    names = self._in_memory_store.feature_names
    for key in entity_keys:
        # Keys are serialized the same way online_write_batch stores them.
        key_tuple = tuple(
            f"{field.name}:{field.value.string_val}" for field in key.join_keys
        )
        row = self._in_memory_store.entity_keys.get(key_tuple, -1)
        if row == -1:
            out.append((None, None))
            continue
        vector = self._index.reconstruct(int(row))
        out.append(
            (None, {name: Value(double_val=v) for name, v in zip(names, vector)})
        )
    return out

def online_write_batch(
self,
config: RepoConfig,
table: FeatureView,
data: List[Tuple[EntityKey, Dict[str, Value], datetime, Optional[datetime]]],
progress: Optional[Callable[[int], Any]],
self,
config: RepoConfig,
table: FeatureView,
data: List[Tuple[EntityKey, Dict[str, Value], datetime, Optional[datetime]]],
progress: Optional[Callable[[int], Any]],
) -> None:
if self._index is None:
self._logger.warning("Index is not initialized. Skipping write operation.")
Expand All @@ -129,22 +133,35 @@ def online_write_batch(
entity_key_tuples = []

for entity_key, feature_dict, _, _ in data:
entity_key_tuple = tuple(f"{field.name}:{field.value.string_val}" for field in entity_key.join_keys)
feature_vector = np.array([
feature_dict[name].double_val for name in self._in_memory_store.feature_names
], dtype=np.float32)
entity_key_tuple = tuple(
f"{field.name}:{field.value.string_val}"
for field in entity_key.join_keys
)
feature_vector = np.array(
[
feature_dict[name].double_val
for name in self._in_memory_store.feature_names
],
dtype=np.float32,
)

feature_vectors.append(feature_vector)
entity_key_tuples.append(entity_key_tuple)

feature_vectors = np.array(feature_vectors)

existing_indices = [self._in_memory_store.entity_keys.get(ekt, -1) for ekt in entity_key_tuples]
existing_indices = [
self._in_memory_store.entity_keys.get(ekt, -1) for ekt in entity_key_tuples
]
mask = np.array(existing_indices) != -1
if np.any(mask):
self._index.remove_ids(np.array([idx for idx in existing_indices if idx != -1]))
self._index.remove_ids(
np.array([idx for idx in existing_indices if idx != -1])
)

new_indices = np.arange(self._index.ntotal, self._index.ntotal + len(feature_vectors))
new_indices = np.arange(
self._index.ntotal, self._index.ntotal + len(feature_vectors)
)
self._index.add(feature_vectors)

for ekt, idx in zip(entity_key_tuples, new_indices):
Expand All @@ -154,13 +171,13 @@ def online_write_batch(
progress(len(data))

def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Expand Down Expand Up @@ -201,10 +218,11 @@ def retrieve_online_documents(

return results

async def online_read_async(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Async read is not implemented for the FAISS store.

    NOTE(review): silently returning None diverges from the declared return
    type — consider raising NotImplementedError instead.
    """
    pass