-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat: Enable Vector database and retrieve_online_documents API #4061
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 29 commits
713768e
58d5d94
2cd73d1
d2e0a59
7079e7f
8c9ee97
513dd39
29d98cd
11eb97f
865baf2
47cd117
3f9f59f
7935071
ba39f93
cf53c71
92046af
d0acd2d
cc45f73
006b5c6
6e0ba03
a2302be
2e6fc55
3cbbf21
ec32764
e2d8008
523d20f
5cd085d
795699e
67b007f
33b46bd
82fe5f1
0618378
7de2016
92fed1d
d4f2639
396d7de
6c38b92
f763dc9
818c055
a51b555
2624b22
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1690,6 +1690,67 @@ def _get_online_features( | |
| ) | ||
| return OnlineResponse(online_features_response) | ||
|
|
||
| @log_exceptions_and_usage | ||
| def retrieve_online_documents( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's probably something to be said about having a configurable distance metric to let the user choose which way to get the
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, there are a bunch of different algorithms/configs for Postgresql to retrieve the documents. We can support it in the future after this PR |
||
| self, | ||
| feature: str, | ||
| query: Union[str, List[float]], | ||
| top_k: int, | ||
| ) -> OnlineResponse: | ||
| """ | ||
| Retrieves the top k closest document features. Note: embeddings are a subset of features. | ||
|
|
||
| Args: | ||
| feature: The document feature to retrieve from the online document store, specified as a | ||
| string feature reference with the format "feature_view:feature", | ||
| e.g., "document_fv:document_embeddings". | ||
| query: The query to retrieve the closest document features for. | ||
| top_k: The number of closest document features to retrieve. | ||
| """ | ||
| return self._retrieve_online_documents( | ||
| feature=feature, | ||
| query=query, | ||
| top_k=top_k, | ||
| ) | ||
|
|
||
| def _retrieve_online_documents( | ||
| self, | ||
| feature: str, | ||
| query: Union[str, List[float]], | ||
| top_k: int, | ||
| ): | ||
| if isinstance(query, str): | ||
| raise ValueError( | ||
| "Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents." | ||
| ) | ||
| ( | ||
| requested_feature_views, | ||
| _, | ||
| ) = self._get_feature_views_to_use( | ||
| features=[feature], allow_cache=True, hide_dummy_entity=False | ||
| ) | ||
| requested_feature = ( | ||
| feature.split(":")[1] if isinstance(feature, str) else feature | ||
| ) | ||
| provider = self._get_provider() | ||
| document_features = self._retrieve_from_online_store( | ||
| provider, | ||
| requested_feature_views[0], | ||
| requested_feature, | ||
| query, | ||
| top_k, | ||
| ) | ||
| online_features_response = GetOnlineFeaturesResponse(results=[]) | ||
| self._populate_response_from_feature_data( | ||
| document_features, | ||
| [], | ||
| online_features_response, | ||
| False, | ||
| requested_feature, | ||
| requested_feature_views[0], | ||
| ) | ||
| return OnlineResponse(online_features_response) | ||
|
|
||
| @staticmethod | ||
| def _get_columnar_entity_values( | ||
| rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]] | ||
|
|
@@ -1906,6 +1967,45 @@ def _read_from_online_store( | |
| read_row_protos.append((event_timestamps, statuses, values)) | ||
| return read_row_protos | ||
|
|
||
| def _retrieve_from_online_store( | ||
| self, | ||
| provider: Provider, | ||
| table: FeatureView, | ||
| requested_feature: str, | ||
| query: List[float], | ||
| top_k: int, | ||
| ) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[Value]]]: | ||
| """ | ||
| Search and return document features from the online document store. | ||
| """ | ||
| documents = provider.retrieve_online_documents( | ||
| config=self.config, | ||
| table=table, | ||
| requested_feature=requested_feature, | ||
| query=query, | ||
| top_k=top_k, | ||
| ) | ||
| # Each row is a set of features for a given entity key. We only need to convert | ||
| # the data to Protobuf once. | ||
| null_value = Value() | ||
| read_row_protos = [] | ||
|
|
||
| for doc in documents: | ||
| row_ts_proto = Timestamp() | ||
| row_ts, feature_data = doc | ||
| # TODO (Ly): reuse whatever timestamp if row_ts is None? | ||
| if row_ts is not None: | ||
| row_ts_proto.FromDatetime(row_ts) | ||
| event_timestamps = [row_ts_proto] | ||
| if feature_data is None: | ||
| statuses = [FieldStatus.NOT_FOUND] | ||
| values = [null_value] | ||
| else: | ||
| statuses = [FieldStatus.PRESENT] | ||
| values = [feature_data] | ||
| read_row_protos.append((event_timestamps, statuses, values)) | ||
| return read_row_protos | ||
|
|
||
| @staticmethod | ||
| def _populate_response_from_feature_data( | ||
| feature_data: Iterable[ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,15 @@ | |
| from feast.repo_config import RepoConfig | ||
| from feast.usage import log_exceptions_and_usage | ||
|
|
||
| # Search query template to find the top k items that are closest to the given embedding | ||
| # SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5; | ||
| SEARCH_QUERY_TEMPLATE = """ | ||
| SELECT feature_name, value, event_ts FROM {table_name} | ||
| WHERE feature_name = '{feature_name}' | ||
| ORDER BY value <-> %s | ||
| LIMIT %s; | ||
| """ | ||
|
|
||
|
|
||
| class PostgreSQLOnlineStoreConfig(PostgreSQLConfig): | ||
| type: Literal["postgres"] = "postgres" | ||
|
|
@@ -251,6 +260,48 @@ def teardown( | |
| logging.exception("Teardown failed") | ||
| raise | ||
|
|
||
| def retrieve_online_documents( | ||
| self, | ||
| config: RepoConfig, | ||
| table: FeatureView, | ||
| requested_feature: str, | ||
| embedding: List[float], | ||
| top_k: int, | ||
| ) -> List[Tuple[Optional[datetime], Optional[ValueProto]]]: | ||
| """ | ||
|
|
||
| Args: | ||
| config: Feast configuration object | ||
| table: FeatureView object as the table to search | ||
| requested_feature: The requested feature as the column to search | ||
| embedding: The query embedding to search for | ||
| top_k: The number of items to return | ||
| Returns: | ||
| List of tuples containing the event timestamp and the document feature | ||
|
|
||
| """ | ||
|
|
||
| # Convert the embedding to a string to be used in postgres vector search | ||
| query_embedding_str = f"'[{','.join(str(el) for el in embedding)}]'" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the best serialization we can do? This feels pretty brittle but I get it. |
||
|
|
||
| result: List[Tuple[Optional[datetime], Optional[ValueProto]]] = [] | ||
| with self._get_conn(config) as conn, conn.cursor() as cur: | ||
| cur.execute( | ||
| SEARCH_QUERY_TEMPLATE.format( | ||
| table_name=table, feature_name=requested_feature | ||
| ), | ||
| (query_embedding_str, top_k), | ||
| ) | ||
| rows = cur.fetchall() | ||
|
|
||
| for feature_name, value, event_ts in rows: | ||
| val = ValueProto() | ||
| val.ParseFromString(value) | ||
|
|
||
| result.append((event_ts, val)) | ||
|
|
||
| return result | ||
|
|
||
|
|
||
| def _table_id(project: str, table: FeatureView) -> str: | ||
| return f"{project}_{table.name}" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| from typing import Dict | ||
|
|
||
| from testcontainers.postgres import PostgresContainer | ||
|
|
||
| from tests.integration.feature_repos.universal.online_store_creator import ( | ||
| OnlineStoreCreator, | ||
| ) | ||
|
|
||
|
|
||
| class PostgresOnlineStoreCreator(OnlineStoreCreator): | ||
| def __init__(self, project_name: str, **kwargs): | ||
| super().__init__(project_name) | ||
| self.container = ( | ||
| PostgresContainer("postgres:latest", platform="linux/amd64") | ||
| .with_exposed_ports(5432) | ||
| .with_env("POSTGRES_USER", "root") | ||
| .with_env("POSTGRES_PASSWORD", "test") | ||
| .with_env("POSTGRES_DB", "test") | ||
| ) | ||
|
|
||
| def create_online_store(self) -> Dict[str, str]: | ||
| self.container.start() | ||
| exposed_port = self.container.get_exposed_port(5432) | ||
| return { | ||
| "type": "postgres", | ||
| "user": "root", | ||
| "password": "test", | ||
| "database": "test", | ||
| "port": exposed_port, | ||
| } | ||
|
|
||
| def teardown(self): | ||
| self.container.stop() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -785,3 +785,26 @@ def assert_feature_service_entity_mapping_correctness( | |
| entity_rows=entity_rows, | ||
| full_feature_names=full_feature_names, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.integration | ||
| @pytest.mark.universal_online_stores(only=["postgres"]) | ||
| def test_retrieve_online_documents( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will you be outputting the cosine similarity as well? That would be useful possibly for debugging. Would be good to be able to test that the engine computes it...maybe not doable though.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be possible. Somehow just the integration test doesn't startup the Postgres container. And I'm debugging it. |
||
| environment, universal_data_sources, fake_ingest_document_data | ||
| ): | ||
| fs = environment.feature_store | ||
| entities, datasets, data_sources = universal_data_sources | ||
| driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver) | ||
| driver_entity = driver() | ||
|
|
||
| # Register Feature View and Entity | ||
| fs.apply([driver_hourly_stats, driver_entity]) | ||
|
|
||
| # directly ingest data into the Online Store | ||
| fs.write_to_online_store("document_fv", fake_ingest_document_data) | ||
|
|
||
| # retrieve the online documents | ||
| documents = fs.retrieve_online_documents( | ||
| feature="document_fv:doc", query="[1, 2]", top_k=5 | ||
| ) | ||
| assert len(documents) == 2 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this now, was this the right choice?