Skip to content

Commit 92fed1d

Browse files
committed
fix postgres test
1 parent 7de2016 commit 92fed1d

File tree

12 files changed

+179
-87
lines changed

12 files changed

+179
-87
lines changed

sdk/python/feast/feature_store.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,14 +1740,19 @@ def _retrieve_online_documents(
17401740
query,
17411741
top_k,
17421742
)
1743+
document_feature_vals = [feature[2] for feature in document_features]
1744+
document_feature_distance_vals = [feature[3] for feature in document_features]
17431745
online_features_response = GetOnlineFeaturesResponse(results=[])
1744-
self._populate_response_from_feature_data(
1745-
document_features,
1746-
[],
1747-
online_features_response,
1748-
False,
1749-
requested_feature,
1750-
requested_feature_views[0],
1746+
1747+
# TODO: Refactor to a better way of populating the result
1748+
# TODO: Populate the entity in the response once returning the entity in document_features is supported
1749+
self._populate_result_rows_from_columnar(
1750+
online_features_response=online_features_response,
1751+
data={requested_feature: document_feature_vals}
1752+
)
1753+
self._populate_result_rows_from_columnar(
1754+
online_features_response=online_features_response,
1755+
data={"distance": document_feature_distance_vals}
17511756
)
17521757
return OnlineResponse(online_features_response)
17531758

@@ -1974,7 +1979,7 @@ def _retrieve_from_online_store(
19741979
requested_feature: str,
19751980
query: List[float],
19761981
top_k: int,
1977-
) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[Value]]]:
1982+
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
19781983
"""
19791984
Search and return document features from the online document store.
19801985
"""
@@ -1985,25 +1990,27 @@ def _retrieve_from_online_store(
19851990
query=query,
19861991
top_k=top_k,
19871992
)
1988-
# Each row is a set of features for a given entity key. We only need to convert
1989-
# the data to Protobuf once.
1993+
19901994
null_value = Value()
1995+
not_found_status = FieldStatus.NOT_FOUND
1996+
present_status = FieldStatus.PRESENT
1997+
19911998
read_row_protos = []
1999+
row_ts_proto = Timestamp()
19922000

1993-
for doc in documents:
1994-
row_ts_proto = Timestamp()
1995-
row_ts, feature_data = doc
1996-
# TODO (Ly): reuse whatever timestamp if row_ts is None?
2001+
for row_ts, feature_val, distance in documents:
2002+
# NOTE: row_ts_proto is created once before the loop; it is only updated (never reset) when row_ts is not None
19972003
if row_ts is not None:
19982004
row_ts_proto.FromDatetime(row_ts)
1999-
event_timestamps = [row_ts_proto]
2000-
if feature_data is None:
2001-
statuses = [FieldStatus.NOT_FOUND]
2002-
values = [null_value]
2005+
2006+
if feature_val is None:
2007+
status = not_found_status
2008+
value = null_value
20032009
else:
2004-
statuses = [FieldStatus.PRESENT]
2005-
values = [feature_data]
2006-
read_row_protos.append((event_timestamps, statuses, values))
2010+
status = present_status
2011+
value = feature_val
2012+
2013+
read_row_protos.append((row_ts_proto, status, value, distance))
20072014
return read_row_protos
20082015

20092016
@staticmethod

sdk/python/feast/infra/online_stores/contrib/postgres.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from psycopg2 import sql
1010
from psycopg2.extras import execute_values
1111
from psycopg2.pool import SimpleConnectionPool
12-
1312
from feast import Entity
1413
from feast.feature_view import FeatureView
1514
from feast.infra.key_encoding_utils import serialize_entity_key
@@ -21,19 +20,16 @@
2120
from feast.repo_config import RepoConfig
2221
from feast.usage import log_exceptions_and_usage
2322

24-
# Search query template to find the top k items that are closest to the given embedding
25-
# SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
26-
SEARCH_QUERY_TEMPLATE = """
27-
SELECT feature_name, value, event_ts FROM {table_name}
28-
WHERE feature_name = '{feature_name}'
29-
ORDER BY value <-> %s
30-
LIMIT %s;
31-
"""
32-
3323

3424
class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
3525
type: Literal["postgres"] = "postgres"
3626

27+
# Whether to enable the pgvector extension for vector similarity search
28+
pgvector_enabled: Optional[bool] = False
29+
30+
# If pgvector is enabled, the length of the vector field
31+
vector_len: Optional[int] = 512
32+
3733

3834
class PostgreSQLOnlineStore(OnlineStore):
3935
_conn: Optional[psycopg2._psycopg.connection] = None
@@ -77,11 +73,15 @@ def online_write_batch(
7773
created_ts = _to_naive_utc(created_ts)
7874

7975
for feature_name, val in values.items():
76+
if config.online_config["pgvector_enabled"]:
77+
val = str(val.float_list_val.val)
78+
else:
79+
val = val.SerializeToString()
8080
insert_values.append(
8181
(
8282
entity_key_bin,
8383
feature_name,
84-
val.SerializeToString(),
84+
val,
8585
timestamp,
8686
created_ts,
8787
)
@@ -221,14 +221,17 @@ def update(
221221

222222
for table in tables_to_keep:
223223
table_name = _table_id(project, table)
224+
value_type = "BYTEA"
225+
if config.online_config["pgvector_enabled"]:
226+
value_type = f'vector({config.online_config["vector_len"]})'
224227
cur.execute(
225228
sql.SQL(
226229
"""
227230
CREATE TABLE IF NOT EXISTS {}
228231
(
229232
entity_key BYTEA,
230233
feature_name TEXT,
231-
value BYTEA,
234+
value {},
232235
event_ts TIMESTAMPTZ,
233236
created_ts TIMESTAMPTZ,
234237
PRIMARY KEY(entity_key, feature_name)
@@ -237,6 +240,7 @@ def update(
237240
"""
238241
).format(
239242
sql.Identifier(table_name),
243+
sql.SQL(value_type),
240244
sql.Identifier(f"{table_name}_ek"),
241245
sql.Identifier(table_name),
242246
)
@@ -267,7 +271,7 @@ def retrieve_online_documents(
267271
requested_feature: str,
268272
embedding: List[float],
269273
top_k: int,
270-
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
274+
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
271275
"""
272276
273277
Args:
@@ -280,25 +284,50 @@ def retrieve_online_documents(
280284
List of tuples containing the event timestamp, the document feature, and the distance to the query embedding
281285
282286
"""
287+
project = config.project
283288

284289
# Convert the embedding to a string to be used in postgres vector search
285-
query_embedding_str = f"'[{','.join(str(el) for el in embedding)}]'"
290+
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"
286291

287-
result: List[Tuple[Optional[datetime], Optional[ValueProto]]] = []
292+
result: List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]] = []
288293
with self._get_conn(config) as conn, conn.cursor() as cur:
294+
table_name = _table_id(project, table)
295+
296+
# Search query template to find the top k items that are closest to the given embedding
297+
# SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
289298
cur.execute(
290-
SEARCH_QUERY_TEMPLATE.format(
291-
table_name=table, feature_name=requested_feature
299+
sql.SQL(
300+
"""
301+
SELECT
302+
entity_key,
303+
feature_name,
304+
value,
305+
value <-> %s as distance,
306+
event_ts FROM {table_name}
307+
WHERE feature_name = {feature_name}
308+
ORDER BY distance
309+
LIMIT {top_k};
310+
"""
311+
).format(
312+
table_name=sql.Identifier(table_name),
313+
feature_name=sql.Literal(requested_feature),
314+
top_k=sql.Literal(top_k)
292315
),
293-
(query_embedding_str, top_k),
316+
(query_embedding_str,),
294317
)
295318
rows = cur.fetchall()
296319

297-
for feature_name, value, event_ts in rows:
298-
val = ValueProto()
299-
val.ParseFromString(value)
320+
for entity_key, feature_name, value, distance, event_ts in rows:
321+
322+
# TODO Deserialize entity_key to return the entity in response
323+
entity_key_proto = EntityKeyProto()
324+
entity_key_proto_bin = bytes(entity_key)
325+
326+
# TODO Convert to List[float] for value type proto
327+
feature_value_proto = ValueProto(string_val=value)
300328

301-
result.append((event_ts, val))
329+
distance_value_proto = ValueProto(float_val=distance)
330+
result.append((event_ts, feature_value_proto, distance_value_proto))
302331

303332
return result
304333

sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22
IntegrationTestRepoConfig,
33
)
44
from tests.integration.feature_repos.universal.online_store.postgres import (
5-
PostgresOnlieStoreCreator,
5+
PostgresOnlineStoreCreator,
6+
PGVectorOnlineStoreCreator
67
)
78

89
FULL_REPO_CONFIGS = [
910
IntegrationTestRepoConfig(
10-
online_store="postgres", online_store_creator=PostgresOnlieStoreCreator
11+
online_store="postgres",
12+
online_store_creator=PostgresOnlineStoreCreator
13+
),
14+
IntegrationTestRepoConfig(
15+
online_store="pgvector",
16+
online_store_creator=PGVectorOnlineStoreCreator
1117
),
1218
]

sdk/python/feast/infra/online_stores/online_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def retrieve_online_documents(
142142
requested_feature: str,
143143
embedding: List[float],
144144
top_k: int,
145-
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
145+
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
146146
"""
147147
Retrieves online feature values for the specified embeddings.
148148

sdk/python/feast/infra/passthrough_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@ def retrieve_online_documents(
196196
config: RepoConfig,
197197
table: FeatureView,
198198
requested_feature: str,
199-
embedding: List[float],
199+
query: List[float],
200200
top_k: int,
201201
) -> List:
202202
set_usage_attribute("provider", self.__class__.__name__)
203203
result = []
204204
if self.online_store:
205205
result = self.online_store.retrieve_online_documents(
206-
config, table, requested_feature, embedding, top_k
206+
config, table, requested_feature, query, top_k
207207
)
208208
return result
209209

sdk/python/feast/infra/provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def retrieve_online_documents(
303303
requested_feature: str,
304304
query: List[float],
305305
top_k: int,
306-
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
306+
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
307307
"""
308308
Searches for the top-k nearest neighbors of the given document in the online document store.
309309

sdk/python/tests/conftest.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
import pytest
2424
from _pytest.nodes import Item
2525

26+
from feast.data_source import DataSource
2627
from feast.feature_store import FeatureStore # noqa: E402
2728
from feast.wait import wait_retry_backoff # noqa: E402
28-
from tests.data.data_creator import create_basic_driver_dataset # noqa: E402
29+
from tests.data.data_creator import create_basic_driver_dataset, create_document_dataset # noqa: E402
2930
from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
3031
IntegrationTestRepoConfig,
3132
)
@@ -270,12 +271,12 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
270271

271272
# aws lambda works only with dynamo
272273
if (
273-
config.get("python_feature_server")
274-
and config.get("provider") == "aws"
275-
and (
274+
config.get("python_feature_server")
275+
and config.get("provider") == "aws"
276+
and (
276277
not isinstance(online_store, dict)
277278
or online_store["type"] != "dynamodb"
278-
)
279+
)
279280
):
280281
continue
281282

@@ -297,8 +298,8 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
297298
@pytest.fixture
298299
def feature_server_endpoint(environment):
299300
if (
300-
not environment.python_feature_server
301-
or environment.test_repo_config.provider != "local"
301+
not environment.python_feature_server
302+
or environment.test_repo_config.provider != "local"
302303
):
303304
yield environment.feature_store.get_feature_server_endpoint()
304305
return
@@ -310,8 +311,8 @@ def feature_server_endpoint(environment):
310311
args=(environment.feature_store.repo_path, port),
311312
)
312313
if (
313-
environment.python_feature_server
314-
and environment.test_repo_config.provider == "local"
314+
environment.python_feature_server
315+
and environment.test_repo_config.provider == "local"
315316
):
316317
proc.start()
317318
# Wait for server to start
@@ -354,7 +355,7 @@ def e2e_data_sources(environment: Environment):
354355

355356
@pytest.fixture
356357
def feature_store_for_online_retrieval(
357-
environment, universal_data_sources
358+
environment, universal_data_sources
358359
) -> Tuple[FeatureStore, List[str], List[Dict[str, int]]]:
359360
"""
360361
Returns a feature store that is ready for online retrieval, along with entity rows and feature
@@ -408,12 +409,10 @@ def fake_ingest_data():
408409

409410

410411
@pytest.fixture
411-
def fake_ingest_document_data():
412-
"""Fake document data to ingest into the feature store"""
413-
data = {
414-
"driver_id": [1],
415-
"doc": [4, 5],
416-
"event_timestamp": [pd.Timestamp(datetime.utcnow()).round("ms")],
417-
"created": [pd.Timestamp(datetime.utcnow()).round("ms")],
418-
}
419-
return pd.DataFrame(data)
412+
def fake_document_data(environment: Environment) -> Tuple[pd.DataFrame, DataSource]:
413+
df = create_document_dataset()
414+
data_source = environment.data_source_creator.create_data_source(
415+
df,
416+
environment.feature_store.project,
417+
)
418+
return df, data_source

sdk/python/tests/data/data_creator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,21 @@ def get_feature_values_for_dtype(
7878
return [[n, n] if n is not None else None for n in non_list_val]
7979
else:
8080
return non_list_val
81+
82+
83+
def create_document_dataset() -> pd.DataFrame:
84+
data = {
85+
"item_id": [1, 2, 3],
86+
"embedding_float": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]],
87+
"ts": [
88+
pd.Timestamp(datetime.utcnow()).round("ms"),
89+
pd.Timestamp(datetime.utcnow()).round("ms"),
90+
pd.Timestamp(datetime.utcnow()).round("ms"),
91+
],
92+
"created_ts": [
93+
pd.Timestamp(datetime.utcnow()).round("ms"),
94+
pd.Timestamp(datetime.utcnow()).round("ms"),
95+
pd.Timestamp(datetime.utcnow()).round("ms"),
96+
],
97+
}
98+
return pd.DataFrame(data)

sdk/python/tests/foo_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,5 @@ def retrieve_online_documents(
111111
requested_feature: str,
112112
query: List[float],
113113
top_k: int,
114-
) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]:
114+
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
115115
return []

sdk/python/tests/integration/feature_repos/universal/feature_views.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def create_item_embeddings_feature_view(source, infer_features: bool = False):
140140
schema=None
141141
if infer_features
142142
else [
143-
Field(name="embedding_double", dtype=Array(Float64)),
144143
Field(name="embedding_float", dtype=Array(Float32)),
145144
],
146145
source=source,

0 commit comments

Comments
 (0)