Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/reference/online-stores/elasticsearch.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ top_k = 5
# Retrieve the top k closest features to the query vector

feature_values = feature_store.retrieve_online_documents(
feature="my_feature",
features=["my_feature"],
query=query_vector,
top_k=top_k
)
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/online-stores/qdrant.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ top_k = 5
# the vector to use can be specified in the repo config.
# Reference: https://qdrant.tech/documentation/concepts/vectors/#named-vectors
feature_values = feature_store.retrieve_online_documents(
feature="my_feature",
features=["my_feature"],
query=query_vector,
top_k=top_k
)
Expand Down
88 changes: 38 additions & 50 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,19 +1831,15 @@ async def get_online_features_async(

def retrieve_online_documents(
self,
feature: Optional[str],
query: Union[str, List[float]],
top_k: int,
features: Optional[List[str]] = None,
features: List[str],
distance_metric: Optional[str] = "L2",
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note that embeddings are a subset of features.

Args:
feature: The list of document features that should be retrieved from the online document store. These features can be
specified either as a list of string document feature references or as a feature service. String feature
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
features: The list of features that should be retrieved from the online store.
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
Expand All @@ -1853,68 +1849,55 @@ def retrieve_online_documents(
raise ValueError(
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
)
feature_list: List[str] = (
features
if features is not None
else ([feature] if feature is not None else [])
)

(
available_feature_views,
_,
) = utils._get_feature_views_to_use(
registry=self._registry,
project=self.project,
features=feature_list,
features=features,
allow_cache=True,
hide_dummy_entity=False,
)
if features:
Comment thread
jyejare marked this conversation as resolved.
feature_view_set = set()
for feature in features:
feature_view_name = feature.split(":")[0]
feature_view = self.get_feature_view(feature_view_name)
feature_view_set.add(feature_view.name)
if len(feature_view_set) > 1:
raise ValueError(
"Document retrieval only supports a single feature view."
)
requested_feature = None
requested_features = [
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
]
else:
requested_feature = (
feature.split(":")[1] if isinstance(feature, str) else feature
)
requested_features = [requested_feature] if requested_feature else []

requested_feature_view_name = (
feature.split(":")[0] if feature else list(feature_view_set)[0]
)
feature_view_set = set()
for _feature in features:
feature_view_name = _feature.split(":")[0]
feature_view = self.get_feature_view(feature_view_name)
feature_view_set.add(feature_view.name)
if len(feature_view_set) > 1:
raise ValueError("Document retrieval only supports a single feature view.")
requested_features = [
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
]
requested_feature_view_name = list(feature_view_set)[0]
for feature_view in available_feature_views:
if feature_view.name == requested_feature_view_name:
requested_feature_view = feature_view
if not requested_feature_view:
break
else:
raise ValueError(
f"Feature view {requested_feature_view} not found in the registry."
)

requested_feature_view = available_feature_views[0]

provider = self._get_provider()
document_features = self._retrieve_from_online_store(
provider,
requested_feature_view,
requested_feature,
requested_features,
query,
top_k,
distance_metric,
)

# TODO: the vector value is currently not returned because it is the same as the feature value; once embedding is
# supported, the feature value can be the raw text before it is embedded
entity_key_vals = [feature[1] for feature in document_features]
def _doc_feature(x):
return [feature[x] for feature in document_features]

entity_key_vals, document_feature_vals, document_feature_distance_vals = map(
_doc_feature, (1, 4, 5)
)
join_key_values: Dict[str, List[ValueProto]] = {}
for entity_key_val in entity_key_vals:
if entity_key_val is not None:
Expand All @@ -1924,18 +1907,25 @@ def retrieve_online_documents(
if join_key not in join_key_values:
join_key_values[join_key] = []
join_key_values[join_key].append(entity_value)

document_feature_vals = [feature[4] for feature in document_features]
document_feature_distance_vals = [feature[5] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
requested_feature = requested_feature or requested_features[0]
if vector_field_metadata := _get_feature_view_vector_field_metadata(
requested_feature_view
):
vector_field_name = vector_field_metadata.name
data = {
**join_key_values,
vector_field_name: document_feature_vals,
"distance": document_feature_distance_vals,
}
_requested_features = [_feature.split(":")[-1] for _feature in features]
requested_features_data = {
_feature: data[_feature]
for _feature in _requested_features
if _feature in data
}
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={
**join_key_values,
requested_feature: document_feature_vals,
"distance": document_feature_distance_vals,
Comment thread
franciscojavierarceo marked this conversation as resolved.
},
data=requested_features_data,
)
return OnlineResponse(online_features_response)

Expand Down Expand Up @@ -2012,7 +2002,6 @@ def _retrieve_from_online_store(
self,
provider: Provider,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
Expand All @@ -2032,7 +2021,6 @@ def _retrieve_from_online_store(
documents = provider.retrieve_online_documents(
config=self.config,
table=table,
requested_feature=requested_feature,
requested_features=requested_features,
query=query,
top_k=top_k,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
requested_features: List[str],
embedding: List[float],
top_k: int,
*args,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_featres: Optional[List[str]],
requested_featres: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand Down
10 changes: 3 additions & 7 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
requested_features: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand All @@ -413,7 +412,6 @@ def retrieve_online_documents(
distance_metric: distance metric to use for retrieval.
config: The config for the current feature store.
table: The feature view whose feature values should be read.
requested_feature: The name of the feature whose embeddings should be used for retrieval.
requested_features: The list of features whose embeddings should be used for retrieval.
embedding: The embeddings to use for retrieval.
top_k: The number of documents to retrieve.
Expand All @@ -423,10 +421,8 @@ def retrieve_online_documents(
where the first item is the event timestamp for the row, and the second item is a dict of feature
name to embeddings.
"""
if not requested_feature and not requested_features:
raise ValueError(
"Either requested_feature or requested_features must be specified"
)
if not requested_features:
raise ValueError("Requested_features must be specified")
raise NotImplementedError(
f"Online store {self.__class__.__name__} does not support online retrieval"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
embedding: List[float],
top_k: int,
Expand All @@ -373,7 +372,6 @@ def retrieve_online_documents(
Args:
config: Feast configuration object
table: FeatureView object as the table to search
requested_feature: The requested feature as the column to search
requested_features: The list of features whose embeddings should be used for retrieval.
embedding: The query embedding to search for
top_k: The number of items to return
Expand All @@ -394,6 +392,11 @@ def retrieve_online_documents(
f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}"
)

if requested_features:
required_feature_names = ", ".join(
[feature for feature in requested_features]
)

distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]

result: List[
Expand All @@ -415,19 +418,18 @@ def retrieve_online_documents(
"""
SELECT
entity_key,
feature_name,
{feature_names},
value,
vector_value,
vector_value {distance_metric_sql} %s::vector as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
distance_metric_sql=sql.SQL(distance_metric_sql),
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
feature_names=required_feature_names,
top_k=sql.Literal(top_k),
),
(embedding,),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
requested_features: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = "cosine",
Expand Down
5 changes: 2 additions & 3 deletions sdk/python/feast/infra/online_stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,7 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_featuers: Optional[List[str]],
requested_featuers: List[str],
Comment thread
jyejare marked this conversation as resolved.
Outdated
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand All @@ -341,7 +340,7 @@ def retrieve_online_documents(
Args:
config: Feast configuration object
table: FeatureView object as the table to search
requested_feature: The requested feature as the column to search
requested_features: The list of requested features to retrieve
embedding: The query embedding to search for
top_k: The number of items to return
Returns:
Expand Down
2 changes: 0 additions & 2 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
Expand All @@ -305,7 +304,6 @@ def retrieve_online_documents(
result = self.online_store.retrieve_online_documents(
config,
table,
requested_feature,
requested_features,
query,
top_k,
Expand Down
2 changes: 0 additions & 2 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
Expand All @@ -440,7 +439,6 @@ def retrieve_online_documents(
distance_metric: distance metric to use for the search.
config: The config for the current feature store.
table: The feature view whose embeddings should be searched.
requested_feature: the requested document feature name.
requested_features: the requested document feature names.
query: The query embedding to search for.
top_k: The number of documents to return.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def test_retrieve_online_documents(environment, fake_document_data):
fs.write_to_online_store("item_embeddings", df)

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
features=["item_embeddings:embedding_float"],
query=[1.0, 2.0],
top_k=2,
distance_metric="L2",
Expand All @@ -881,7 +881,7 @@ def test_retrieve_online_documents(environment, fake_document_data):
assert len(documents["item_id"]) == 2

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
features=["item_embeddings:embedding_float"],
query=[1.0, 2.0],
top_k=2,
distance_metric="L1",
Expand All @@ -890,7 +890,7 @@ def test_retrieve_online_documents(environment, fake_document_data):

with pytest.raises(ValueError):
fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
features=["item_embeddings:embedding_float"],
query=[1.0, 2.0],
top_k=2,
distance_metric="wrong",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,9 @@ def test_sqlite_get_online_documents() -> None:
vector_length,
)
result = store.retrieve_online_documents(
feature="document_embeddings:Embeddings", query=query_embedding, top_k=3
query=query_embedding,
top_k=3,
features=["document_embeddings:Embeddings", "document_embeddings:distance"],
).to_dict()

assert "Embeddings" in result
Expand Down
Loading