Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add faiss & in memory online store
Signed-off-by: cmuhao <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Aug 29, 2024
commit 36ed17642f8a7a4c607a4b00f7cc25b2c43f9ba1
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from google.protobuf.timestamp_pb2 import Timestamp

from feast import Entity, FeatureView, RepoConfig
from feast.infra.key_encoding_utils import (
    deserialize_entity_key,
    serialize_entity_key,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
from feast.protos.feast.types.Value_pb2 import Value
from feast.repo_config import FeastConfigBaseModel


class FaissOnlineStoreConfig(FeastConfigBaseModel):
Expand All @@ -26,16 +26,20 @@ def __init__(self):
self.feature_names: List[str] = []
self.entity_keys: Dict[Tuple[str, ...], int] = {}

def update(
    self, feature_names: List[str], entity_keys: Dict[Tuple[str, ...], int]
) -> None:
    """Replace the stored feature-name list and entity-key index mapping.

    Both attributes are overwritten wholesale; no merging with previous
    contents takes place.
    """
    self.feature_names, self.entity_keys = feature_names, entity_keys

def delete(self, entity_keys: List[Tuple[str, ...]]) -> None:
    """Remove the given entity keys from the key->row-index mapping.

    Keys not present in the mapping are silently ignored, matching the
    original membership-guarded delete.
    """
    for entity_key in entity_keys:
        # pop(..., None) does one lookup instead of the `in` test + `del`
        # double lookup, and is tolerant of missing keys.
        self.entity_keys.pop(entity_key, None)

def read(self, entity_keys: List[Tuple[str, ...]]) -> List[Optional[int]]:
    """Look up the row index for each entity key.

    Returns one entry per requested key, in order; ``None`` marks a key
    that is not present in the mapping.
    """
    return list(map(self.entity_keys.get, entity_keys))

def teardown(self):
Expand All @@ -49,19 +53,20 @@ class FaissOnlineStore(OnlineStore):
_config: Optional[FaissOnlineStoreConfig] = None
_logger: logging.Logger = logging.getLogger(__name__)

def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat:
def _get_index(self,
config: RepoConfig) -> faiss.IndexIVFFlat:
if self._index is None or self._config is None:
raise ValueError("Index is not initialized")
return self._index

def update(
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
):
feature_views = tables_to_keep
if not feature_views:
Expand All @@ -82,27 +87,29 @@ def update(
self._in_memory_store.update(feature_names, {})

def teardown(
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
self._index = None
self._in_memory_store.teardown()

def online_read(
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKey],
requested_features: Optional[List[str]] = None,
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKey],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]]:
if self._index is None:
return [(None, None)] * len(entity_keys)

results = []
for entity_key in entity_keys:
serialized_key = serialize_entity_key(entity_key, entity_key_serialization_version=2)
serialized_key = serialize_entity_key(
entity_key, entity_key_serialization_version=2
)
idx = self._in_memory_store.entity_keys.get(serialized_key, -1)
if idx == -1:
results.append((None, None))
Expand All @@ -118,11 +125,11 @@ def online_read(
return results

def online_write_batch(
self,
config: RepoConfig,
table: FeatureView,
data: List[Tuple[EntityKey, Dict[str, Value], datetime, Optional[datetime]]],
progress: Optional[Callable[[int], Any]],
self,
config: RepoConfig,
table: FeatureView,
data: List[Tuple[EntityKey, Dict[str, Value], datetime, Optional[datetime]]],
progress: Optional[Callable[[int], Any]],
) -> None:
if self._index is None:
self._logger.warning("Index is not initialized. Skipping write operation.")
Expand All @@ -132,7 +139,9 @@ def online_write_batch(
serialized_keys = []

for entity_key, feature_dict, _, _ in data:
serialized_key = serialize_entity_key(entity_key, entity_key_serialization_version=2)
serialized_key = serialize_entity_key(
entity_key, entity_key_serialization_version=2
)
feature_vector = np.array(
[
feature_dict[name].double_val
Expand Down Expand Up @@ -167,13 +176,13 @@ def online_write_batch(
progress(len(data))

def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
) -> List[
Tuple[
Optional[datetime],
Expand Down Expand Up @@ -222,11 +231,11 @@ def retrieve_online_documents(
return results

async def online_read_async(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKey],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, Value]]]]:
    """Async reads are not supported by the FAISS online store.

    Raises:
        NotImplementedError: always; use the synchronous ``online_read``
            instead.
    """
    raise NotImplementedError("Async read is not implemented for FaissOnlineStore")