Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
update postgres online store to make it work with write and read APIs.
  • Loading branch information
HaoXuAI committed Apr 16, 2024
commit e274a2311e1d222ee900c776850633fb29445d64
19 changes: 19 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,25 @@ test-python-universal-postgres-online:
not test_snowflake" \
sdk/python/tests

test-python-universal-pgvector-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.pgvector_repo_configuration \
PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \
python -m pytest -n 8 --integration \
-k "not test_universal_cli and \
not test_go_feature_server and \
not test_feature_logging and \
not test_reorder_columns and \
not test_logged_features_validation and \
not test_lambda_materialization_consistency and \
not test_offline_write and \
not test_push_features_to_offline_store and \
not gcs_registry and \
not s3_registry and \
not test_universal_types and \
not test_snowflake" \
sdk/python/tests

test-python-universal-mysql-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.mysql_repo_configuration \
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/key_encoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def serialize_entity_key(
return b"".join(output)


def get_val_str(val):
accept_value_types = ["float_list_val", "double_list_val", "int_list_val"]
def get_list_val_str(val: ValueProto):
accept_value_types = ["float_list_val", "double_list_val", "int32_list_val", "int64_list_val"]
for accept_type in accept_value_types:
if val.HasField(accept_type):
return str(getattr(val, accept_type).val)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from tests.integration.feature_repos.integration_test_repo_config import (
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]
48 changes: 24 additions & 24 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

import psycopg2
import pytz
Expand All @@ -12,7 +12,7 @@

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
Expand Down Expand Up @@ -74,19 +74,15 @@ def online_write_batch(
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
val_str: Union[str, bytes]
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
):
val_str = get_val_str(val)
else:
val_str = val.SerializeToString()
vector_val = None
if "pgvector_enabled" in config.online_store and config.online_store.pgvector_enabled:
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val_str,
val.SerializeToString(),
vector_val,
timestamp,
created_ts,
)
Expand All @@ -100,11 +96,12 @@ def online_write_batch(
sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, event_ts, created_ts)
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES %s
ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
""",
Expand Down Expand Up @@ -226,20 +223,20 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
value_type = "BYTEA"
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
):
value_type = f'vector({config.online_config["vector_len"]})'
if "pgvector_enabled" in config.online_store and config.online_store.pgvector_enabled:
vector_value_type = f'vector({config.online_store.vector_len})'
else:
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
vector_value_type = 'BYTEA'
cur.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {}
(
entity_key BYTEA,
feature_name TEXT,
value {},
value BYTEA,
vector_value {} NULL,
event_ts TIMESTAMPTZ,
created_ts TIMESTAMPTZ,
PRIMARY KEY(entity_key, feature_name)
Expand All @@ -248,7 +245,7 @@ def update(
"""
).format(
sql.Identifier(table_name),
sql.SQL(value_type),
sql.SQL(vector_value_type),
sql.Identifier(f"{table_name}_ek"),
sql.Identifier(table_name),
)
Expand Down Expand Up @@ -294,6 +291,9 @@ def retrieve_online_documents(
"""
project = config.project

if "pgvector_enabled" not in config.online_store or not config.online_store.pgvector_enabled:
raise ValueError("pgvector is not enabled in the online store configuration")

# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

Expand All @@ -311,8 +311,8 @@ def retrieve_online_documents(
SELECT
entity_key,
feature_name,
value,
value <-> %s as distance,
vector_value,
vector_value <-> %s as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
Expand All @@ -327,13 +327,13 @@ def retrieve_online_documents(
)
rows = cur.fetchall()

for entity_key, feature_name, value, distance, event_ts in rows:
for entity_key, feature_name, vector_value, distance, event_ts in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

# TODO Convert to List[float] for value type proto
feature_value_proto = ValueProto(string_val=value)
feature_value_proto = ValueProto(string_val=vector_value)

distance_value_proto = ValueProto(float_val=distance)
result.append((event_ts, feature_value_proto, distance_value_proto))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
PostgresOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(
online_store="postgres", online_store_creator=PostgresOnlineStoreCreator
),
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]

AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE EXTENSION IF NOT EXISTS vector;
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.postgres import PostgresContainer

import os
from tests.integration.feature_repos.universal.online_store_creator import (
OnlineStoreCreator,
)
Expand Down Expand Up @@ -37,12 +37,17 @@ def teardown(self):
class PGVectorOnlineStoreCreator(OnlineStoreCreator):
def __init__(self, project_name: str, **kwargs):
super().__init__(project_name)
script_directory = os.path.dirname(os.path.abspath(__file__))
self.container = (
DockerContainer("pgvector/pgvector:pg16")
.with_env("POSTGRES_USER", "root")
.with_env("POSTGRES_PASSWORD", "test")
.with_env("POSTGRES_DB", "test")
.with_exposed_ports(5432)
.with_volume_mapping(
os.path.join(script_directory, 'init.sql'),
"/docker-entrypoint-initdb.d/init.sql",
)
)

def create_online_store(self) -> Dict[str, str]:
Expand All @@ -51,8 +56,10 @@ def create_online_store(self) -> Dict[str, str]:
wait_for_logs(
container=self.container, predicate=log_string_to_wait_for, timeout=10
)
command = "psql -h localhost -p 5432 -U root -d test -c 'CREATE EXTENSION IF NOT EXISTS vector;'"
self.container.exec(command)
init_log_string_to_wait_for = "PostgreSQL init process complete"
wait_for_logs(
container=self.container, predicate=init_log_string_to_wait_for, timeout=10
)
return {
"host": "localhost",
"type": "postgres",
Expand Down