Skip to content

Commit ec19036

Browse files
authored
feat: Enable Vector database and retrieve_online_documents API (feast-dev#4061)
* feat: add document store * feat: add document store * feat: add document store * feat: add document store * remove DocumentStore * format * format * format * format * format * format * remove unused vars * add test * add test * format * format * format * format * format * fix not implemented issue * fix not implemented issue * fix test * format * format * format * format * format * format * update testcontainer * format * fix postgres integration test * format * fix postgres test * fix postgres test * fix postgres test * fix postgres test * fix postgres test * format * format * format
1 parent 3c6ce86 commit ec19036

File tree

15 files changed

+419
-12
lines changed

15 files changed

+419
-12
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ test-python-universal-postgres-offline:
200200
test-python-universal-postgres-online:
201201
PYTHONPATH='.' \
202202
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.postgres_repo_configuration \
203-
PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
203+
PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \
204204
python -m pytest -n 8 --integration \
205205
-k "not test_universal_cli and \
206206
not test_go_feature_server and \

sdk/python/feast/feature_store.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,72 @@ def _get_online_features(
16901690
)
16911691
return OnlineResponse(online_features_response)
16921692

1693+
@log_exceptions_and_usage
1694+
def retrieve_online_documents(
1695+
self,
1696+
feature: str,
1697+
query: Union[str, List[float]],
1698+
top_k: int,
1699+
) -> OnlineResponse:
1700+
"""
1701+
Retrieves the top k closest document features. Note, embeddings are a subset of features.
1702+
1703+
Args:
1704+
feature: The list of document features that should be retrieved from the online document store. These features can be
1705+
specified either as a list of string document feature references or as a feature service. String feature
1706+
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
1707+
query: The query to retrieve the closest document features for.
1708+
top_k: The number of closest document features to retrieve.
1709+
"""
1710+
return self._retrieve_online_documents(
1711+
feature=feature,
1712+
query=query,
1713+
top_k=top_k,
1714+
)
1715+
1716+
def _retrieve_online_documents(
1717+
self,
1718+
feature: str,
1719+
query: Union[str, List[float]],
1720+
top_k: int,
1721+
):
1722+
if isinstance(query, str):
1723+
raise ValueError(
1724+
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
1725+
)
1726+
(
1727+
requested_feature_views,
1728+
_,
1729+
) = self._get_feature_views_to_use(
1730+
features=[feature], allow_cache=True, hide_dummy_entity=False
1731+
)
1732+
requested_feature = (
1733+
feature.split(":")[1] if isinstance(feature, str) else feature
1734+
)
1735+
provider = self._get_provider()
1736+
document_features = self._retrieve_from_online_store(
1737+
provider,
1738+
requested_feature_views[0],
1739+
requested_feature,
1740+
query,
1741+
top_k,
1742+
)
1743+
document_feature_vals = [feature[2] for feature in document_features]
1744+
document_feature_distance_vals = [feature[3] for feature in document_features]
1745+
online_features_response = GetOnlineFeaturesResponse(results=[])
1746+
1747+
# TODO Refactor to better way of populating result
1748+
# TODO populate entity in the response after returning entity in document_features is supported
1749+
self._populate_result_rows_from_columnar(
1750+
online_features_response=online_features_response,
1751+
data={requested_feature: document_feature_vals},
1752+
)
1753+
self._populate_result_rows_from_columnar(
1754+
online_features_response=online_features_response,
1755+
data={"distance": document_feature_distance_vals},
1756+
)
1757+
return OnlineResponse(online_features_response)
1758+
16931759
@staticmethod
16941760
def _get_columnar_entity_values(
16951761
rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]]
@@ -1906,6 +1972,43 @@ def _read_from_online_store(
19061972
read_row_protos.append((event_timestamps, statuses, values))
19071973
return read_row_protos
19081974

1975+
def _retrieve_from_online_store(
1976+
self,
1977+
provider: Provider,
1978+
table: FeatureView,
1979+
requested_feature: str,
1980+
query: List[float],
1981+
top_k: int,
1982+
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
1983+
"""
1984+
Search and return document features from the online document store.
1985+
"""
1986+
documents = provider.retrieve_online_documents(
1987+
config=self.config,
1988+
table=table,
1989+
requested_feature=requested_feature,
1990+
query=query,
1991+
top_k=top_k,
1992+
)
1993+
1994+
read_row_protos = []
1995+
row_ts_proto = Timestamp()
1996+
1997+
for row_ts, feature_val, distance_val in documents:
1998+
# Reset timestamp to default or update if row_ts is not None
1999+
if row_ts is not None:
2000+
row_ts_proto.FromDatetime(row_ts)
2001+
2002+
if feature_val is None or distance_val is None:
2003+
feature_val = Value()
2004+
distance_val = Value()
2005+
status = FieldStatus.NOT_FOUND
2006+
else:
2007+
status = FieldStatus.PRESENT
2008+
2009+
read_row_protos.append((row_ts_proto, status, feature_val, distance_val))
2010+
return read_row_protos
2011+
19092012
@staticmethod
19102013
def _populate_response_from_feature_data(
19112014
feature_data: Iterable[

sdk/python/feast/infra/key_encoding_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,11 @@ def serialize_entity_key(
7272
output.append(val_bytes)
7373

7474
return b"".join(output)
75+
76+
77+
def get_val_str(val):
78+
accept_value_types = ["float_list_val", "double_list_val", "int_list_val"]
79+
for accept_type in accept_value_types:
80+
if val.HasField(accept_type):
81+
return str(getattr(val, accept_type).val)
82+
return None

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from testcontainers.core.waiting_utils import wait_for_logs
88

99
from feast.data_source import DataSource
10+
from feast.feature_logging import LoggingDestination
1011
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import (
1112
PostgreSQLOfflineStoreConfig,
1213
PostgreSQLSource,
@@ -57,6 +58,9 @@ def postgres_container():
5758

5859

5960
class PostgreSQLDataSourceCreator(DataSourceCreator, OnlineStoreCreator):
61+
def create_logged_features_destination(self) -> LoggingDestination:
62+
return None # type: ignore
63+
6064
def __init__(
6165
self, project_name: str, fixture_request: pytest.FixtureRequest, **kwargs
6266
):

sdk/python/feast/infra/online_stores/contrib/postgres.py

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from collections import defaultdict
44
from datetime import datetime
5-
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
5+
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
66

77
import psycopg2
88
import pytz
@@ -12,7 +12,7 @@
1212

1313
from feast import Entity
1414
from feast.feature_view import FeatureView
15-
from feast.infra.key_encoding_utils import serialize_entity_key
15+
from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key
1616
from feast.infra.online_stores.online_store import OnlineStore
1717
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
1818
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
@@ -25,6 +25,12 @@
2525
class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
2626
type: Literal["postgres"] = "postgres"
2727

28+
# Whether to enable the pgvector extension for vector similarity search
29+
pgvector_enabled: Optional[bool] = False
30+
31+
# If pgvector is enabled, the length of the vector field
32+
vector_len: Optional[int] = 512
33+
2834

2935
class PostgreSQLOnlineStore(OnlineStore):
3036
_conn: Optional[psycopg2._psycopg.connection] = None
@@ -68,11 +74,19 @@ def online_write_batch(
6874
created_ts = _to_naive_utc(created_ts)
6975

7076
for feature_name, val in values.items():
77+
val_str: Union[str, bytes]
78+
if (
79+
"pgvector_enabled" in config.online_config
80+
and config.online_config["pgvector_enabled"]
81+
):
82+
val_str = get_val_str(val)
83+
else:
84+
val_str = val.SerializeToString()
7185
insert_values.append(
7286
(
7387
entity_key_bin,
7488
feature_name,
75-
val.SerializeToString(),
89+
val_str,
7690
timestamp,
7791
created_ts,
7892
)
@@ -212,14 +226,20 @@ def update(
212226

213227
for table in tables_to_keep:
214228
table_name = _table_id(project, table)
229+
value_type = "BYTEA"
230+
if (
231+
"pgvector_enabled" in config.online_config
232+
and config.online_config["pgvector_enabled"]
233+
):
234+
value_type = f'vector({config.online_config["vector_len"]})'
215235
cur.execute(
216236
sql.SQL(
217237
"""
218238
CREATE TABLE IF NOT EXISTS {}
219239
(
220240
entity_key BYTEA,
221241
feature_name TEXT,
222-
value BYTEA,
242+
value {},
223243
event_ts TIMESTAMPTZ,
224244
created_ts TIMESTAMPTZ,
225245
PRIMARY KEY(entity_key, feature_name)
@@ -228,6 +248,7 @@ def update(
228248
"""
229249
).format(
230250
sql.Identifier(table_name),
251+
sql.SQL(value_type),
231252
sql.Identifier(f"{table_name}_ek"),
232253
sql.Identifier(table_name),
233254
)
@@ -251,6 +272,74 @@ def teardown(
251272
logging.exception("Teardown failed")
252273
raise
253274

275+
def retrieve_online_documents(
276+
self,
277+
config: RepoConfig,
278+
table: FeatureView,
279+
requested_feature: str,
280+
embedding: List[float],
281+
top_k: int,
282+
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
283+
"""
284+
285+
Args:
286+
config: Feast configuration object
287+
table: FeatureView object as the table to search
288+
requested_feature: The requested feature as the column to search
289+
embedding: The query embedding to search for
290+
top_k: The number of items to return
291+
Returns:
292+
List of tuples containing the event timestamp and the document feature
293+
294+
"""
295+
project = config.project
296+
297+
# Convert the embedding to a string to be used in postgres vector search
298+
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"
299+
300+
result: List[
301+
Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]
302+
] = []
303+
with self._get_conn(config) as conn, conn.cursor() as cur:
304+
table_name = _table_id(project, table)
305+
306+
# Search query template to find the top k items that are closest to the given embedding
307+
# SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
308+
cur.execute(
309+
sql.SQL(
310+
"""
311+
SELECT
312+
entity_key,
313+
feature_name,
314+
value,
315+
value <-> %s as distance,
316+
event_ts FROM {table_name}
317+
WHERE feature_name = {feature_name}
318+
ORDER BY distance
319+
LIMIT {top_k};
320+
"""
321+
).format(
322+
table_name=sql.Identifier(table_name),
323+
feature_name=sql.Literal(requested_feature),
324+
top_k=sql.Literal(top_k),
325+
),
326+
(query_embedding_str,),
327+
)
328+
rows = cur.fetchall()
329+
330+
for entity_key, feature_name, value, distance, event_ts in rows:
331+
# TODO Deserialize entity_key to return the entity in response
332+
# entity_key_proto = EntityKeyProto()
333+
# entity_key_proto_bin = bytes(entity_key)
334+
335+
# TODO Convert to List[float] for value type proto
336+
feature_value_proto = ValueProto(string_val=value)
337+
338+
distance_value_proto = ValueProto(float_val=distance)
339+
result.append((event_ts, feature_value_proto, distance_value_proto))
340+
341+
return result
342+
254343

255344
def _table_id(project: str, table: FeatureView) -> str:
256345
return f"{project}_{table.name}"
Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
2-
PostgreSQLDataSourceCreator,
3-
)
41
from tests.integration.feature_repos.integration_test_repo_config import (
52
IntegrationTestRepoConfig,
63
)
4+
from tests.integration.feature_repos.universal.online_store.postgres import (
5+
PGVectorOnlineStoreCreator,
6+
PostgresOnlineStoreCreator,
7+
)
78

89
FULL_REPO_CONFIGS = [
9-
IntegrationTestRepoConfig(online_store_creator=PostgreSQLDataSourceCreator),
10+
IntegrationTestRepoConfig(
11+
online_store="postgres", online_store_creator=PostgresOnlineStoreCreator
12+
),
13+
IntegrationTestRepoConfig(
14+
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
15+
),
1016
]
17+
18+
AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator}

sdk/python/feast/infra/online_stores/online_store.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,30 @@ def teardown(
134134
entities: Entities whose corresponding infrastructure should be deleted.
135135
"""
136136
pass
137+
138+
def retrieve_online_documents(
139+
self,
140+
config: RepoConfig,
141+
table: FeatureView,
142+
requested_feature: str,
143+
embedding: List[float],
144+
top_k: int,
145+
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
146+
"""
147+
Retrieves online feature values for the specified embeddings.
148+
149+
Args:
150+
config: The config for the current feature store.
151+
table: The feature view whose feature values should be read.
152+
requested_feature: The name of the feature whose embeddings should be used for retrieval.
153+
embedding: The embeddings to use for retrieval.
154+
top_k: The number of nearest neighbors to retrieve.
155+
156+
Returns:
157+
object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple
158+
where the first item is the event timestamp for the row, and the second item is a dict of feature
159+
name to embeddings.
160+
"""
161+
raise NotImplementedError(
162+
f"Online store {self.__class__.__name__} does not support online retrieval"
163+
)

sdk/python/feast/infra/passthrough_provider.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,23 @@ def online_read(
190190
)
191191
return result
192192

193+
@log_exceptions_and_usage(sampler=RatioSampler(ratio=0.001))
194+
def retrieve_online_documents(
195+
self,
196+
config: RepoConfig,
197+
table: FeatureView,
198+
requested_feature: str,
199+
query: List[float],
200+
top_k: int,
201+
) -> List:
202+
set_usage_attribute("provider", self.__class__.__name__)
203+
result = []
204+
if self.online_store:
205+
result = self.online_store.retrieve_online_documents(
206+
config, table, requested_feature, query, top_k
207+
)
208+
return result
209+
193210
def ingest_df(
194211
self,
195212
feature_view: FeatureView,

0 commit comments

Comments
 (0)