add faiss & in memory online store
Signed-off-by: cmuhao <sduxuhao@gmail.com>
HaoXuAI committed Aug 29, 2024
commit ab231a7cf1ece2792bdda56e158841499f0478fe
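This commit reworks the contrib FAISS online store: the unpackaged "protos.feast..." import paths become the installable "feast.protos.feast..." modules, per-instance state moves to class-level attributes backed by a small InMemoryStore, and index construction is centralized in a lazy _get_index helper. As a rough wiring sketch, assuming direct construction (only the dimension and nlist fields are visible in this diff; everything else below is an assumption, not part of the change):

    # Hypothetical usage sketch, not taken from the diff. Only dimension and
    # nlist are known FaissOnlineStoreConfig fields here.
    from feast.infra.online_stores.contrib.faiss_online_store import (
        FaissOnlineStore,
        FaissOnlineStoreConfig,
    )

    config = FaissOnlineStoreConfig(dimension=128, nlist=100)
    store = FaissOnlineStore()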
209 changes: 85 additions & 124 deletions sdk/python/feast/infra/online_stores/contrib/faiss_online_store.py
@@ -1,18 +1,14 @@
-import logging
-from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
-
-import faiss
-import numpy as np
-from google.protobuf.timestamp_pb2 import Timestamp
-from protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
-from protos.feast.types.Value_pb2 import Value as ValueProto
-
-from feast import Entity, FeatureView, RepoConfig
from feast.infra.online_stores.online_store import OnlineStore
+from feast import Entity, FeatureView, RepoConfig
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey
+from feast.protos.feast.types.Value_pb2 import Value
+from feast.repo_config import FeastConfigBaseModel
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+from datetime import datetime
+import faiss
+import numpy as np
+import logging
+from google.protobuf.timestamp_pb2 import Timestamp


class FaissOnlineStoreConfig(FeastConfigBaseModel):
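Throughout the diff below, each EntityKey proto is flattened into a hashable tuple so it can key a plain dict. A minimal sketch of that encoding, lifted from the read and write paths (as written it assumes every join key carries a string value; other value types collapse to empty strings):

    def entity_key_to_tuple(entity_key: EntityKey) -> Tuple[str, ...]:
        # One "name:value" string per join key, mirroring the expression
        # used in online_read and online_write_batch below.
        return tuple(
            f"{field.name}:{field.value.string_val}"
            for field in entity_key.join_keys
        )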
@@ -23,146 +19,127 @@ class FaissOnlineStoreConfig(FeastConfigBaseModel):


class InMemoryStore:
+    feature_names: List[str]
+    entity_keys: Dict[Tuple[str, ...], int]

    def __init__(self):
        self._index = None
        self.feature_names = []
        self.entity_keys = {}

-    def update(self, feature_names: List[str], entity_keys: Dict[Tuple[str, ...], int]):
+    def update(self,
+               feature_names: List[str],
+               entity_keys: Dict[Tuple[str, ...], int]):
        self.feature_names = feature_names
        self.entity_keys = entity_keys

-    def delete(self, entity_keys: List[EntityKey]):
+    def delete(self,
+               entity_keys: List[EntityKey]):
        for entity_key in entity_keys:
            del self.entity_keys[entity_key]

-    def read(self, entity_keys: List[EntityKey]):
+    def read(self,
+             entity_keys: List[EntityKey]):
        return [self.entity_keys.get(entity_key, None) for entity_key in entity_keys]

    def teardown(self):
        self._index = None
        self.feature_names = []
        self.entity_keys = {}


class FaissOnlineStore(OnlineStore):
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
-        self._index = None
-        self._in_memory_store = InMemoryStore()
-        self._config = FaissOnlineStoreConfig(**config) if config else None
-        self._logger = logging.getLogger(__name__)
+    _index: Optional[faiss.IndexIVFFlat] = None
+    _in_memory_store: InMemoryStore = InMemoryStore()
+    _config: Optional[FaissOnlineStoreConfig] = None
+    _logger: logging.Logger = logging.getLogger(__name__)

+    def _get_index(self,
+                   config: RepoConfig) -> faiss.IndexIVFFlat:
+        if self._index is None:
+            dimension = config.online_store.dimension
+            quantizer = faiss.IndexFlatL2(dimension)
+            self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
+            self._index.train(np.random.rand(self._config.nlist * 100, dimension).astype(np.float32))
+        return self._index

    def update(
-        self,
-        config: RepoConfig,
-        tables_to_delete: Sequence[FeatureView],
-        tables_to_keep: Sequence[FeatureView],
-        entities_to_delete: Sequence[Entity],
-        entities_to_keep: Sequence[Entity],
-        partial: bool,
+            self,
+            config: RepoConfig,
+            tables_to_delete: Sequence[FeatureView],
+            tables_to_keep: Sequence[FeatureView],
+            entities_to_delete: Sequence[Entity],
+            entities_to_keep: Sequence[Entity],
+            partial: bool,
    ):
        feature_views = tables_to_keep
        if not feature_views:
            return

        feature_names = [f.name for f in feature_views[0].features]
        dimension = len(feature_names)

-        if self._index is None or not partial:
-            quantizer = faiss.IndexFlatL2(dimension)
-            self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
-            self._index.train(
-                np.random.rand(self._config.nlist * 100, dimension).astype(np.float32)
-            )
-            self._in_memory_store = InMemoryStore()

        self._in_memory_store.update(feature_names, {})
+        self._config = FaissOnlineStoreConfig(**config.online_store.dict())
+        self._get_index(config)

    def teardown(
-        self,
-        config: RepoConfig,
-        tables: Sequence[FeatureView],
-        entities: Sequence[Entity],
+            self,
+            config: RepoConfig,
+            tables: Sequence[FeatureView],
+            entities: Sequence[Entity],
    ):
+        # reset index
        self._index = None
        self._in_memory_store.teardown()

    def online_read(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        entity_keys: List[EntityKey],
-        requested_features: Optional[List[str]] = None,
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKey],
+            requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]]:
-        if self._index is None:
-            return [(None, None)] * len(entity_keys)
-
+        index = self._get_index(config)
        results = []
        for entity_key in entity_keys:
-            entity_key_tuple = tuple(
-                f"{field.name}:{field.value.string_val}"
-                for field in entity_key.join_keys
-            )
+            entity_key_tuple = tuple(f"{field.name}:{field.value.string_val}" for field in entity_key.join_keys)
            idx = self._in_memory_store.entity_keys.get(entity_key_tuple, -1)
            if idx == -1:
                results.append((None, None))
            else:
-                feature_vector = self._index.reconstruct(int(idx))
+                feature_vector = index.reconstruct(int(idx))
                feature_dict = {
                    name: Value(double_val=value)
-                    for name, value in zip(
-                        self._in_memory_store.feature_names, feature_vector
-                    )
+                    for name, value in zip(self._in_memory_store.feature_names, feature_vector)
                }
                results.append((None, feature_dict))
        return results

    def online_write_batch(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        data: List[Tuple[EntityKey, Dict[str, Value], datetime, Optional[datetime]]],
-        progress: Optional[Callable[[int], Any]],
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            data: List[Tuple[EntityKey, Dict[str, Value], datetime, Optional[datetime]]],
+            progress: Optional[Callable[[int], Any]],
    ) -> None:
-        if self._index is None:
-            self._logger.warning("Index is not initialized. Skipping write operation.")
-            return
-
+        index = self._get_index(config)
        feature_vectors = []
        entity_key_tuples = []

        for entity_key, feature_dict, _, _ in data:
-            entity_key_tuple = tuple(
-                f"{field.name}:{field.value.string_val}"
-                for field in entity_key.join_keys
-            )
-            feature_vector = np.array(
-                [
-                    feature_dict[name].double_val
-                    for name in self._in_memory_store.feature_names
-                ],
-                dtype=np.float32,
-            )
+            entity_key_tuple = tuple(f"{field.name}:{field.value.string_val}" for field in entity_key.join_keys)
+            feature_vector = np.array([
+                feature_dict[name].double_val for name in self._in_memory_store.feature_names
+            ], dtype=np.float32)

            feature_vectors.append(feature_vector)
            entity_key_tuples.append(entity_key_tuple)

        feature_vectors = np.array(feature_vectors)

-        existing_indices = [
-            self._in_memory_store.entity_keys.get(ekt, -1) for ekt in entity_key_tuples
-        ]
+        existing_indices = [self._in_memory_store.entity_keys.get(ekt, -1) for ekt in entity_key_tuples]
        mask = np.array(existing_indices) != -1
        if np.any(mask):
-            self._index.remove_ids(
-                np.array([idx for idx in existing_indices if idx != -1])
-            )
+            index.remove_ids(np.array([idx for idx in existing_indices if idx != -1]))

-        new_indices = np.arange(
-            self._index.ntotal, self._index.ntotal + len(feature_vectors)
-        )
-        self._index.add(feature_vectors)
+        new_indices = np.arange(index.ntotal, index.ntotal + len(feature_vectors))
+        index.add(feature_vectors)

        for ekt, idx in zip(entity_key_tuples, new_indices):
            self._in_memory_store.entity_keys[ekt] = idx
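A note on the FAISS calls above: an IVF index must be trained before add(), and reconstruct() by sequential id only works once a direct map exists. A standalone sketch of that lifecycle (sizes are illustrative; training on random vectors matches what _get_index does, though it produces arbitrary partitions):

    import faiss
    import numpy as np

    dimension, nlist = 8, 10
    quantizer = faiss.IndexFlatL2(dimension)
    index = faiss.IndexIVFFlat(quantizer, dimension, nlist)

    # IVF indexes must be trained on sample vectors before add().
    index.train(np.random.rand(nlist * 100, dimension).astype(np.float32))
    index.add(np.random.rand(5, dimension).astype(np.float32))

    # Without this, reconstruct() on an IVF index raises
    # "direct map not initialized"; the store never calls it.
    index.make_direct_map()
    vec = index.reconstruct(0)

Also worth noting: FAISS assigns sequential ids starting at ntotal on each add(), so after remove_ids() the positions cached in entity_keys can collide with ids still in the index; the diff does not guard against that.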
@@ -171,34 +148,24 @@ def online_write_batch(
            progress(len(data))

    def retrieve_online_documents(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        requested_feature: str,
-        embedding: List[float],
-        top_k: int,
-        distance_metric: Optional[str] = None,
-    ) -> List[
-        Tuple[
-            Optional[datetime],
-            Optional[Value],
-            Optional[Value],
-            Optional[Value],
-        ]
-    ]:
-        if self._index is None:
-            self._logger.warning("Index is not initialized. Returning empty result.")
-            return []
-
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            requested_feature: str,
+            embedding: List[float],
+            top_k: int,
+            distance_metric: Optional[str] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Value], Optional[Value], Optional[Value]]]:
+        index = self._get_index(config)
        query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1)
-        distances, indices = self._index.search(query_vector, top_k)
+        distances, indices = index.search(query_vector, top_k)

        results = []
        for i, idx in enumerate(indices[0]):
            if idx == -1:
                continue

-            feature_vector = self._index.reconstruct(int(idx))
+            feature_vector = index.reconstruct(int(idx))

            timestamp = Timestamp()
            timestamp.GetCurrentTime()
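For context, the timestamp attached to each result below is just the wall clock read at query time via the protobuf helpers:

    from google.protobuf.timestamp_pb2 import Timestamp

    ts = Timestamp()
    ts.GetCurrentTime()   # fill the message with the current UTC time
    dt = ts.ToDatetime()  # naive UTC datetime, as returned in the tuple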
@@ -207,22 +174,16 @@ def retrieve_online_documents(
            vector_value = Value(string_val=",".join(map(str, feature_vector)))
            distance_value = Value(float_val=distances[0][i])

-            results.append(
-                (
-                    timestamp.ToDatetime(),
-                    feature_value,
-                    vector_value,
-                    distance_value,
-                )
-            )
+            results.append((timestamp.ToDatetime(), feature_value, vector_value, distance_value))

        return results

    async def online_read_async(
-        self,
-        config: RepoConfig,
-        table: FeatureView,
-        entity_keys: List[EntityKeyProto],
-        requested_features: Optional[List[str]] = None,
-    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+            self,
+            config: RepoConfig,
+            table: FeatureView,
+            entity_keys: List[EntityKey],
+            requested_features: Optional[List[str]] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]]:
        # Implement async read if needed
        pass
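Taken end to end, a rough driver for this store might look like the following; repo_config, my_feature_view, and sample_entity_key are hypothetical stand-ins, and only the method names and tuple shapes come from this file:

    from datetime import datetime

    store = FaissOnlineStore()
    store.update(
        config=repo_config,  # assumed RepoConfig whose online_store block carries dimension/nlist
        tables_to_delete=[],
        tables_to_keep=[my_feature_view],
        entities_to_delete=[],
        entities_to_keep=[],
        partial=False,
    )

    # Write one row: (entity key, feature values, event time, created time).
    rows = [
        (
            sample_entity_key,
            {"f1": Value(double_val=0.1), "f2": Value(double_val=0.2)},
            datetime.utcnow(),
            None,
        )
    ]
    store.online_write_batch(config=repo_config, table=my_feature_view, data=rows, progress=None)

    # Nearest-neighbor lookup over the stored vectors.
    hits = store.retrieve_online_documents(
        config=repo_config,
        table=my_feature_view,
        requested_feature="f1",
        embedding=[0.1, 0.2],
        top_k=5,
    )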