Skip to content

Commit ad45bb4

Browse files
authored
fix: Pgvector patch (#4108)
1 parent 0fb2351 commit ad45bb4

File tree

5 files changed

+73
-28
lines changed

5 files changed

+73
-28
lines changed

sdk/python/feast/feature_store.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,12 +1740,14 @@ def _retrieve_online_documents(
17401740
query,
17411741
top_k,
17421742
)
1743-
document_feature_vals = [feature[2] for feature in document_features]
1744-
document_feature_distance_vals = [feature[3] for feature in document_features]
1745-
online_features_response = GetOnlineFeaturesResponse(results=[])
17461743

17471744
# TODO: refactor to a better way of populating the result
17481745
# TODO populate entity in the response after returning entity in document_features is supported
1746+
# TODO: the vector value is currently not returned because it is the same as the feature value; once embedding is supported,
1747+
# the feature value can be the raw text before it is embedded
1748+
document_feature_vals = [feature[2] for feature in document_features]
1749+
document_feature_distance_vals = [feature[4] for feature in document_features]
1750+
online_features_response = GetOnlineFeaturesResponse(results=[])
17491751
self._populate_result_rows_from_columnar(
17501752
online_features_response=online_features_response,
17511753
data={requested_feature: document_feature_vals},
@@ -1979,7 +1981,7 @@ def _retrieve_from_online_store(
19791981
requested_feature: str,
19801982
query: List[float],
19811983
top_k: int,
1982-
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
1984+
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
19831985
"""
19841986
Search and return document features from the online document store.
19851987
"""
@@ -1994,19 +1996,22 @@ def _retrieve_from_online_store(
19941996
read_row_protos = []
19951997
row_ts_proto = Timestamp()
19961998

1997-
for row_ts, feature_val, distance_val in documents:
1999+
for row_ts, feature_val, vector_value, distance_val in documents:
19982000
# Update the timestamp only when row_ts is not None (it is not reset otherwise)
19992001
if row_ts is not None:
20002002
row_ts_proto.FromDatetime(row_ts)
20012003

2002-
if feature_val is None or distance_val is None:
2004+
if feature_val is None or vector_value is None or distance_val is None:
20032005
feature_val = Value()
2006+
vector_value = Value()
20042007
distance_val = Value()
20052008
status = FieldStatus.NOT_FOUND
20062009
else:
20072010
status = FieldStatus.PRESENT
20082011

2009-
read_row_protos.append((row_ts_proto, status, feature_val, distance_val))
2012+
read_row_protos.append(
2013+
(row_ts_proto, status, feature_val, vector_value, distance_val)
2014+
)
20102015
return read_row_protos
20112016

20122017
@staticmethod

sdk/python/feast/infra/online_stores/contrib/postgres.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,7 @@ def online_write_batch(
7575

7676
for feature_name, val in values.items():
7777
vector_val = None
78-
if (
79-
"pgvector_enabled" in config.online_store
80-
and config.online_store.pgvector_enabled
81-
):
78+
if config.online_store.pgvector_enabled:
8279
vector_val = get_list_val_str(val)
8380
insert_values.append(
8481
(
@@ -226,10 +223,7 @@ def update(
226223

227224
for table in tables_to_keep:
228225
table_name = _table_id(project, table)
229-
if (
230-
"pgvector_enabled" in config.online_store
231-
and config.online_store.pgvector_enabled
232-
):
226+
if config.online_store.pgvector_enabled:
233227
vector_value_type = f"vector({config.online_store.vector_len})"
234228
else:
235229
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
@@ -282,7 +276,14 @@ def retrieve_online_documents(
282276
requested_feature: str,
283277
embedding: List[float],
284278
top_k: int,
285-
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
279+
) -> List[
280+
Tuple[
281+
Optional[datetime],
282+
Optional[ValueProto],
283+
Optional[ValueProto],
284+
Optional[ValueProto],
285+
]
286+
]:
286287
"""
287288
288289
Args:
@@ -297,10 +298,7 @@ def retrieve_online_documents(
297298
"""
298299
project = config.project
299300

300-
if (
301-
"pgvector_enabled" not in config.online_store
302-
or not config.online_store.pgvector_enabled
303-
):
301+
if not config.online_store.pgvector_enabled:
304302
raise ValueError(
305303
"pgvector is not enabled in the online store configuration"
306304
)
@@ -309,7 +307,12 @@ def retrieve_online_documents(
309307
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"
310308

311309
result: List[
312-
Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]
310+
Tuple[
311+
Optional[datetime],
312+
Optional[ValueProto],
313+
Optional[ValueProto],
314+
Optional[ValueProto],
315+
]
313316
] = []
314317
with self._get_conn(config) as conn, conn.cursor() as cur:
315318
table_name = _table_id(project, table)
@@ -322,6 +325,7 @@ def retrieve_online_documents(
322325
SELECT
323326
entity_key,
324327
feature_name,
328+
value,
325329
vector_value,
326330
vector_value <-> %s as distance,
327331
event_ts FROM {table_name}
@@ -338,16 +342,31 @@ def retrieve_online_documents(
338342
)
339343
rows = cur.fetchall()
340344

341-
for entity_key, feature_name, vector_value, distance, event_ts in rows:
345+
for (
346+
entity_key,
347+
feature_name,
348+
value,
349+
vector_value,
350+
distance,
351+
event_ts,
352+
) in rows:
342353
# TODO Deserialize entity_key to return the entity in response
343354
# entity_key_proto = EntityKeyProto()
344355
# entity_key_proto_bin = bytes(entity_key)
345356

346-
# TODO Convert to List[float] for value type proto
347-
feature_value_proto = ValueProto(string_val=vector_value)
357+
feature_value_proto = ValueProto()
358+
feature_value_proto.ParseFromString(bytes(value))
348359

360+
vector_value_proto = ValueProto(string_val=vector_value)
349361
distance_value_proto = ValueProto(float_val=distance)
350-
result.append((event_ts, feature_value_proto, distance_value_proto))
362+
result.append(
363+
(
364+
event_ts,
365+
feature_value_proto,
366+
vector_value_proto,
367+
distance_value_proto,
368+
)
369+
)
351370

352371
return result
353372

sdk/python/feast/infra/online_stores/online_store.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,14 @@ def retrieve_online_documents(
142142
requested_feature: str,
143143
embedding: List[float],
144144
top_k: int,
145-
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
145+
) -> List[
146+
Tuple[
147+
Optional[datetime],
148+
Optional[ValueProto],
149+
Optional[ValueProto],
150+
Optional[ValueProto],
151+
]
152+
]:
146153
"""
147154
Retrieves online feature values for the specified embeddings.
148155

sdk/python/feast/infra/provider.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,14 @@ def retrieve_online_documents(
303303
requested_feature: str,
304304
query: List[float],
305305
top_k: int,
306-
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
306+
) -> List[
307+
Tuple[
308+
Optional[datetime],
309+
Optional[ValueProto],
310+
Optional[ValueProto],
311+
Optional[ValueProto],
312+
]
313+
]:
307314
"""
308315
Searches for the top-k nearest neighbors of the given document in the online document store.
309316

sdk/python/tests/foo_provider.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,12 @@ def retrieve_online_documents(
111111
requested_feature: str,
112112
query: List[float],
113113
top_k: int,
114-
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
114+
) -> List[
115+
Tuple[
116+
Optional[datetime],
117+
Optional[ValueProto],
118+
Optional[ValueProto],
119+
Optional[ValueProto],
120+
]
121+
]:
115122
return []

0 commit comments

Comments
 (0)