Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
713768e
feat: add document store
HaoXuAI Mar 31, 2024
58d5d94
feat: add document store
HaoXuAI Mar 31, 2024
2cd73d1
feat: add document store
HaoXuAI Mar 31, 2024
d2e0a59
feat: add document store
HaoXuAI Mar 31, 2024
7079e7f
remove DocumentStore
HaoXuAI Apr 9, 2024
8c9ee97
format
HaoXuAI Apr 9, 2024
513dd39
Merge branch 'master' into feat-documentstore
HaoXuAI Apr 9, 2024
29d98cd
format
HaoXuAI Apr 9, 2024
11eb97f
format
HaoXuAI Apr 9, 2024
865baf2
format
HaoXuAI Apr 9, 2024
47cd117
format
HaoXuAI Apr 9, 2024
3f9f59f
format
HaoXuAI Apr 9, 2024
7935071
remove unused vars
HaoXuAI Apr 9, 2024
ba39f93
add test
HaoXuAI Apr 11, 2024
cf53c71
add test
HaoXuAI Apr 11, 2024
92046af
format
HaoXuAI Apr 11, 2024
d0acd2d
format
HaoXuAI Apr 11, 2024
cc45f73
format
HaoXuAI Apr 11, 2024
006b5c6
format
HaoXuAI Apr 11, 2024
6e0ba03
format
HaoXuAI Apr 11, 2024
a2302be
fix not implemented issue
HaoXuAI Apr 11, 2024
2e6fc55
fix not implemented issue
HaoXuAI Apr 11, 2024
3cbbf21
fix test
HaoXuAI Apr 11, 2024
ec32764
format
HaoXuAI Apr 11, 2024
e2d8008
format
HaoXuAI Apr 12, 2024
523d20f
format
HaoXuAI Apr 12, 2024
5cd085d
format
HaoXuAI Apr 12, 2024
795699e
format
HaoXuAI Apr 12, 2024
67b007f
format
HaoXuAI Apr 12, 2024
33b46bd
update testcontainer
HaoXuAI Apr 12, 2024
82fe5f1
format
HaoXuAI Apr 12, 2024
0618378
fix postgres integration test
HaoXuAI Apr 12, 2024
7de2016
format
HaoXuAI Apr 12, 2024
92fed1d
fix postgres test
HaoXuAI Apr 14, 2024
d4f2639
fix postgres test
HaoXuAI Apr 14, 2024
396d7de
fix postgres test
HaoXuAI Apr 14, 2024
6c38b92
fix postgres test
HaoXuAI Apr 14, 2024
f763dc9
fix postgres test
HaoXuAI Apr 14, 2024
818c055
format
HaoXuAI Apr 14, 2024
a51b555
format
HaoXuAI Apr 15, 2024
2624b22
format
HaoXuAI Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix postgres test
  • Loading branch information
HaoXuAI committed Apr 14, 2024
commit 92fed1d08cee90c3069f1948ef5b42f2eb4a5f34
49 changes: 28 additions & 21 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,14 +1740,19 @@ def _retrieve_online_documents(
query,
top_k,
)
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[3] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
self._populate_response_from_feature_data(
document_features,
[],
online_features_response,
False,
requested_feature,
requested_feature_views[0],

# TODO Refactor to better way of populating result
# TODO populate entity in the response after returning entity in document_features is supported
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={requested_feature: document_feature_vals}
)
self._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={"distance": document_feature_distance_vals}
)
return OnlineResponse(online_features_response)

Expand Down Expand Up @@ -1974,7 +1979,7 @@ def _retrieve_from_online_store(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[Value]]]:
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
"""
Search and return document features from the online document store.
"""
Expand All @@ -1985,25 +1990,27 @@ def _retrieve_from_online_store(
query=query,
top_k=top_k,
)
# Each row is a set of features for a given entity key. We only need to convert
# the data to Protobuf once.

null_value = Value()
not_found_status = FieldStatus.NOT_FOUND
present_status = FieldStatus.PRESENT

read_row_protos = []
row_ts_proto = Timestamp()

for doc in documents:
row_ts_proto = Timestamp()
row_ts, feature_data = doc
# TODO (Ly): reuse whatever timestamp if row_ts is None?
for row_ts, feature_val, distance in documents:
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
event_timestamps = [row_ts_proto]
if feature_data is None:
statuses = [FieldStatus.NOT_FOUND]
values = [null_value]

if feature_val is None:
status = not_found_status
value = null_value
else:
statuses = [FieldStatus.PRESENT]
values = [feature_data]
read_row_protos.append((event_timestamps, statuses, values))
status = present_status
value = feature_val

read_row_protos.append((row_ts_proto, status, value, distance))
return read_row_protos

@staticmethod
Expand Down
73 changes: 51 additions & 22 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from psycopg2 import sql
from psycopg2.extras import execute_values
from psycopg2.pool import SimpleConnectionPool

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import serialize_entity_key
Expand All @@ -21,19 +20,16 @@
from feast.repo_config import RepoConfig
from feast.usage import log_exceptions_and_usage

# Search query template to find the top k items that are closest to the given embedding
# SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
# NOTE(review): `<->` is the pgvector L2-distance operator, so rows come back
# nearest-first. {table_name} and {feature_name} are interpolated with
# str.format rather than bound as parameters — safe only while both originate
# from trusted repo config, since format-interpolation into SQL is injectable.
# The embedding and LIMIT values are bound via %s placeholders at execute time.
SEARCH_QUERY_TEMPLATE = """
SELECT feature_name, value, event_ts FROM {table_name}
WHERE feature_name = '{feature_name}'
ORDER BY value <-> %s
LIMIT %s;
"""


class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
    """Configuration for the PostgreSQL online store.

    Inherits the connection settings from ``PostgreSQLConfig`` and adds the
    pgvector-specific options used for document / vector similarity search.
    """

    # Discriminator so the repo config loader selects this online store.
    type: Literal["postgres"] = "postgres"

    # Whether to enable the pgvector extension for vector similarity search
    pgvector_enabled: Optional[bool] = False

    # If pgvector is enabled, the length of the vector field
    # (presumably must match the dimensionality of the stored embeddings —
    # TODO confirm against the embedding model used)
    vector_len: Optional[int] = 512


class PostgreSQLOnlineStore(OnlineStore):
_conn: Optional[psycopg2._psycopg.connection] = None
Expand Down Expand Up @@ -77,11 +73,15 @@ def online_write_batch(
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
if config.online_config["pgvector_enabled"]:
val = str(val.float_list_val.val)
else:
val = val.SerializeToString()
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
val,
timestamp,
created_ts,
)
Expand Down Expand Up @@ -221,14 +221,17 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
value_type = "BYTEA"
if config.online_config["pgvector_enabled"]:
value_type = f'vector({config.online_config["vector_len"]})'
cur.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {}
(
entity_key BYTEA,
feature_name TEXT,
value BYTEA,
value {},
event_ts TIMESTAMPTZ,
created_ts TIMESTAMPTZ,
PRIMARY KEY(entity_key, feature_name)
Expand All @@ -237,6 +240,7 @@ def update(
"""
).format(
sql.Identifier(table_name),
sql.SQL(value_type),
sql.Identifier(f"{table_name}_ek"),
sql.Identifier(table_name),
)
Expand Down Expand Up @@ -267,7 +271,7 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
"""

Args:
Expand All @@ -280,25 +284,50 @@ def retrieve_online_documents(
List of tuples containing the event timestamp and the document feature

"""
project = config.project

# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"'[{','.join(str(el) for el in embedding)}]'"
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

result: List[Tuple[Optional[datetime], Optional[ValueProto]]] = []
result: List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]] = []
with self._get_conn(config) as conn, conn.cursor() as cur:
table_name = _table_id(project, table)

# Search query template to find the top k items that are closest to the given embedding
# SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
cur.execute(
SEARCH_QUERY_TEMPLATE.format(
table_name=table, feature_name=requested_feature
sql.SQL(
"""
SELECT
entity_key,
feature_name,
value,
value <-> %s as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
top_k=sql.Literal(top_k)
),
(query_embedding_str, top_k),
(query_embedding_str,),
)
rows = cur.fetchall()

for feature_name, value, event_ts in rows:
val = ValueProto()
val.ParseFromString(value)
for entity_key, feature_name, value, distance, event_ts in rows:

# TODO Deserialize entity_key to return the entity in response
entity_key_proto = EntityKeyProto()
entity_key_proto_bin = bytes(entity_key)

# TODO Convert to List[float] for value type proto
feature_value_proto = ValueProto(string_val=value)

result.append((event_ts, val))
distance_value_proto = ValueProto(float_val=distance)
result.append((event_ts, feature_value_proto, distance_value_proto))

return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PostgresOnlieStoreCreator,
PostgresOnlineStoreCreator,
PGVectorOnlineStoreCreator
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(
online_store="postgres", online_store_creator=PostgresOnlieStoreCreator
online_store="postgres",
online_store_creator=PostgresOnlineStoreCreator
),
IntegrationTestRepoConfig(
online_store="pgvector",
online_store_creator=PGVectorOnlineStoreCreator
),
]
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
"""
Retrieves online feature values for the specified embeddings.

Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,14 @@ def retrieve_online_documents(
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
query: List[float],
top_k: int,
) -> List:
set_usage_attribute("provider", self.__class__.__name__)
result = []
if self.online_store:
result = self.online_store.retrieve_online_documents(
config, table, requested_feature, embedding, top_k
config, table, requested_feature, query, top_k
)
return result

Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
"""
Searches for the top-k nearest neighbors of the given document in the online document store.

Expand Down
37 changes: 18 additions & 19 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
import pytest
from _pytest.nodes import Item

from feast.data_source import DataSource
from feast.feature_store import FeatureStore # noqa: E402
from feast.wait import wait_retry_backoff # noqa: E402
from tests.data.data_creator import create_basic_driver_dataset # noqa: E402
from tests.data.data_creator import create_basic_driver_dataset, create_document_dataset # noqa: E402
from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
IntegrationTestRepoConfig,
)
Expand Down Expand Up @@ -270,12 +271,12 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):

# aws lambda works only with dynamo
if (
config.get("python_feature_server")
and config.get("provider") == "aws"
and (
config.get("python_feature_server")
and config.get("provider") == "aws"
and (
not isinstance(online_store, dict)
or online_store["type"] != "dynamodb"
)
)
):
continue

Expand All @@ -297,8 +298,8 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
@pytest.fixture
def feature_server_endpoint(environment):
if (
not environment.python_feature_server
or environment.test_repo_config.provider != "local"
not environment.python_feature_server
or environment.test_repo_config.provider != "local"
):
yield environment.feature_store.get_feature_server_endpoint()
return
Expand All @@ -310,8 +311,8 @@ def feature_server_endpoint(environment):
args=(environment.feature_store.repo_path, port),
)
if (
environment.python_feature_server
and environment.test_repo_config.provider == "local"
environment.python_feature_server
and environment.test_repo_config.provider == "local"
):
proc.start()
# Wait for server to start
Expand Down Expand Up @@ -354,7 +355,7 @@ def e2e_data_sources(environment: Environment):

@pytest.fixture
def feature_store_for_online_retrieval(
environment, universal_data_sources
environment, universal_data_sources
) -> Tuple[FeatureStore, List[str], List[Dict[str, int]]]:
"""
Returns a feature store that is ready for online retrieval, along with entity rows and feature
Expand Down Expand Up @@ -408,12 +409,10 @@ def fake_ingest_data():


@pytest.fixture
def fake_ingest_document_data():
"""Fake document data to ingest into the feature store"""
data = {
"driver_id": [1],
"doc": [4, 5],
"event_timestamp": [pd.Timestamp(datetime.utcnow()).round("ms")],
"created": [pd.Timestamp(datetime.utcnow()).round("ms")],
}
return pd.DataFrame(data)
def fake_document_data(environment: Environment) -> Tuple[pd.DataFrame, DataSource]:
df = create_document_dataset()
data_source = environment.data_source_creator.create_data_source(
df,
environment.feature_store.project,
)
return df, data_source
18 changes: 18 additions & 0 deletions sdk/python/tests/data/data_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,21 @@ def get_feature_values_for_dtype(
return [[n, n] if n is not None else None for n in non_list_val]
else:
return non_list_val


def create_document_dataset() -> pd.DataFrame:
    """Build a small in-memory document dataset for retrieval tests.

    Returns:
        A DataFrame with three items, each carrying a 2-d float embedding
        (``embedding_float``) plus ``ts``/``created_ts`` event and creation
        timestamps rounded to millisecond precision.
    """
    # Compute the timestamp once so every row shares the exact same value;
    # the original called utcnow() per cell, which could yield rows with
    # slightly different stamps within one fixture.
    now = pd.Timestamp(datetime.utcnow()).round("ms")
    return pd.DataFrame(
        {
            "item_id": [1, 2, 3],
            "embedding_float": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]],
            "ts": [now] * 3,
            "created_ts": [now] * 3,
        }
    )
2 changes: 1 addition & 1 deletion sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
return []
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def create_item_embeddings_feature_view(source, infer_features: bool = False):
schema=None
if infer_features
else [
Field(name="embedding_double", dtype=Array(Float64)),
Field(name="embedding_float", dtype=Array(Float32)),
],
source=source,
Expand Down
Loading