Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix lint
Signed-off-by: cmuhao <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed Sep 10, 2024
commit 2432b7ddaccc9a2ed2493ca6578ee2c6c9c9cbac
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import contextlib
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Callable

import duckdb
from infra.key_encoding_utils import serialize_entity_key
Expand All @@ -26,49 +26,51 @@ class DuckDBOnlineStoreConfig:


class DuckDBOnlineStore(OnlineStore):
__conn: Optional[duckdb.Connection] = None
__conn: Optional[duckdb.DuckDBPyConnection] = None

@abc.abstractmethod
async def online_read_async(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """Asynchronous read entry point; declared abstract with no implementation.

    Args:
        config: Repository configuration.
        table: Feature view to read from.
        entity_keys: Entity keys to look up.
        requested_features: Optional subset of feature names; defaults to None.

    Returns:
        Annotated as a list of ``(timestamp, {feature_name: value})`` tuples
        in which either element may be None. This body returns nothing.
    """
    # NOTE(review): because this is marked @abc.abstractmethod on the store
    # class itself, the class may not be instantiable if the base class uses
    # ABCMeta -- confirm this is intended.
    pass

@contextlib.contextmanager
def _get_conn(self, config: RepoConfig) -> Any:
    """Yield the store's DuckDB connection, opening it on first use.

    The connection is cached on ``self.__conn`` and reused by later calls.
    It is not closed when the ``with`` block exits.

    Args:
        config: Repository configuration; ``config.online_store`` must have
            ``type == "duckdb"`` and provide ``path`` and ``read_only``.

    Yields:
        The cached connection object returned by ``duckdb.connect``.
    """
    assert config.online_store.type == "duckdb"
    online_store_config = config.online_store

    # Connect lazily so that constructing the store does not touch the database.
    if self.__conn is None:
        self.__conn = duckdb.connect(
            database=online_store_config.path,
            read_only=online_store_config.read_only,
        )
    yield self.__conn

def create_vector_index(
    self, config: RepoConfig, table_name: str, vector_column: str
) -> None:
    """Create an HNSW index for vector similarity search.

    Args:
        config: Repository configuration; ``config.online_store`` supplies
            ``enable_vector_search`` and ``distance_metric``.
        table_name: Table to index. Interpolated into the SQL text as-is.
        vector_column: Column holding the vectors. Interpolated as-is.

    Raises:
        ValueError: If vector search is not enabled in the configuration.
    """
    if not config.online_store.enable_vector_search:
        raise ValueError("Vector search is not enabled in the configuration.")
    distance_metric = config.online_store.distance_metric

    # Pass the real config: _get_conn reads config.online_store, so passing
    # None here would fail before any connection is obtained.
    with self._get_conn(config) as conn:
        # NOTE(review): the index name is the fixed literal "idx", so a second
        # call for another table or column presumably collides -- confirm
        # against DuckDB's index naming rules. Identifiers are interpolated,
        # not bound, so callers must pass trusted names.
        conn.execute(
            f"CREATE INDEX idx ON {table_name} USING HNSW ({vector_column}) WITH (metric = '{distance_metric}');"
        )

def online_write_batch(
self,
config: RepoConfig,
table: FeatureView,
data: List[Tuple[EntityKeyProto, Dict[str, ValueProto]]],
self,
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> None:
insert_values = []
for entity_key, values in data:
Expand All @@ -88,11 +90,11 @@ def online_write_batch(
)

def online_read(
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[Dict[str, ValueProto]]]]:
keys = [serialize_entity_key(key).hex() for key in entity_keys]
query = f"SELECT feature_name, value FROM {table.name} WHERE entity_key IN ({','.join(['?'] * len(keys))})"
Expand All @@ -108,13 +110,13 @@ def online_read(
]

def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = "L2",
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = "L2",
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -124,9 +126,12 @@ def retrieve_online_documents(
Optional[ValueProto],
]
]:
online_store_config = config.online_store
"""Perform a vector similarity search using the HNSW index."""
if not self.config.enable_vector_search:

if not online_store_config.enable_vector_search:
raise ValueError("Vector search is not enabled in the configuration.")

if config.entity_key_serialization_version < 3:
raise ValueError(
"Entity key serialization version must be at least 3 for vector search."
Expand Down Expand Up @@ -156,12 +161,12 @@ def retrieve_online_documents(
"""
rows = conn.execute(query, (embedding, top_k)).fetchall()
for (
entity_key,
_,
feature_val,
vector_value,
distance_val,
event_ts,
entity_key,
_,
feature_val,
vector_value,
distance_val,
event_ts,
) in rows:
result.append(
_build_retrieve_online_document_record(
Expand All @@ -177,10 +182,10 @@ def retrieve_online_documents(
return result

def update(
self,
config: RepoConfig,
tables_to_delete: List[FeatureView],
tables_to_keep: List[FeatureView],
self,
config: RepoConfig,
tables_to_delete: List[FeatureView],
tables_to_keep: List[FeatureView],
) -> None:
with self._get_conn(config) as conn:
for table in tables_to_delete:
Expand All @@ -191,9 +196,9 @@ def update(
)

def teardown(
self,
config: RepoConfig,
tables: List[FeatureView],
self,
config: RepoConfig,
tables: List[FeatureView],
) -> None:
with self._get_conn(config) as conn:
for table in tables:
Expand Down