Skip to content

Commit 6ce08d3

Browse files
feat: Adding support to return additional features from vector retrieval for Milvus db (#4971)
* checking in progress but this Pr still is not ready yet Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> * feat: Adding new method to FeatureStore to allow more flexible retrieval of features from vector similarity search Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> * Adding requested_features back into online_store Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> * linter Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> * removed type adjustment Signed-off-by: Francisco Javier Arceo <farceo@redhat.com> --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 6a1c102 commit 6ce08d3

File tree

9 files changed

+573
-63
lines changed

9 files changed

+573
-63
lines changed

sdk/python/feast/feature_store.py

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from feast.feast_object import FeastObject
6363
from feast.feature_service import FeatureService
6464
from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView
65+
from feast.field import Field
6566
from feast.inference import (
6667
update_data_sources_with_inferred_event_timestamp_col,
6768
update_feature_views_with_inferred_features_and_entities,
@@ -1833,7 +1834,6 @@ def retrieve_online_documents(
18331834
top_k,
18341835
distance_metric,
18351836
)
1836-
18371837
# TODO currently not return the vector value since it is same as feature value, if embedding is supported,
18381838
# the feature value can be raw text before embedded
18391839
entity_key_vals = [feature[1] for feature in document_features]
@@ -1861,6 +1861,66 @@ def retrieve_online_documents(
18611861
)
18621862
return OnlineResponse(online_features_response)
18631863

1864+
def retrieve_online_documents_v2(
    self,
    query: Union[str, List[float]],
    top_k: int,
    features: List[str],
    distance_metric: Optional[str] = "L2",
) -> OnlineResponse:
    """
    Retrieves the top k closest document features. Note, embeddings are a subset of features.

    Args:
        query: The already-embedded query vector to search with. Raw text is rejected
            because this method does not embed the query itself.
        top_k: The number of closest document features to retrieve.
        features: String feature references in the format "feature_view:feature",
            e.g. "document_fv:document_embeddings". All references must point at the
            same feature view. Entries without a ":" contribute to feature view
            resolution but are not requested as features.
        distance_metric: The distance metric to use for retrieval. It is forwarded to
            _retrieve_from_online_store_v2, which replaces it with the metric declared
            on the feature view's vector field when such a field exists.

    Returns:
        An OnlineResponse built by _retrieve_from_online_store_v2 for the single
        requested feature view.

    Raises:
        ValueError: If the query is a string, if no features are given, if the
            references span more than one feature view, or if no feature view could
            be resolved for the references.
    """
    if isinstance(query, str):
        raise ValueError(
            "Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents_v2."
        )
    # Without at least one reference there is no feature view to search, so fail
    # with a clear message before touching the registry.
    if not features:
        raise ValueError(
            "At least one feature reference is required for document retrieval."
        )

    (
        available_feature_views,
        _,
    ) = utils._get_feature_views_to_use(
        registry=self._registry,
        project=self.project,
        features=features,
        allow_cache=True,
        hide_dummy_entity=False,
    )

    # Resolve each distinct feature view prefix once; get_feature_view is expected
    # to raise for unknown names, so this also validates the references.
    feature_view_set = set()
    resolved_prefixes = set()
    for feature in features:
        feature_view_name = feature.split(":")[0]
        if feature_view_name in resolved_prefixes:
            continue
        resolved_prefixes.add(feature_view_name)
        feature_view_set.add(self.get_feature_view(feature_view_name).name)
    if len(feature_view_set) > 1:
        raise ValueError("Document retrieval only supports a single feature view.")
    (requested_feature_view_name,) = feature_view_set

    requested_features = [
        f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
    ]

    # Guard the lookup: indexing an empty list would raise IndexError and hide the
    # intended "not found" error.
    requested_feature_view = (
        available_feature_views[0] if available_feature_views else None
    )
    if not requested_feature_view:
        raise ValueError(
            f"Feature view {requested_feature_view_name} not found in the registry."
        )

    provider = self._get_provider()
    return self._retrieve_from_online_store_v2(
        provider,
        requested_feature_view,
        requested_features,
        query,
        top_k,
        distance_metric,
    )
1923+
18641924
def _retrieve_from_online_store(
18651925
self,
18661926
provider: Provider,
@@ -1878,6 +1938,10 @@ def _retrieve_from_online_store(
18781938
"""
18791939
Search and return document features from the online document store.
18801940
"""
1941+
vector_field_metadata = _get_feature_view_vector_field_metadata(table)
1942+
if vector_field_metadata:
1943+
distance_metric = vector_field_metadata.vector_search_metric
1944+
18811945
documents = provider.retrieve_online_documents(
18821946
config=self.config,
18831947
table=table,
@@ -1891,7 +1955,7 @@ def _retrieve_from_online_store(
18911955
read_row_protos = []
18921956
row_ts_proto = Timestamp()
18931957

1894-
for row_ts, entity_key, feature_val, vector_value, distance_val in documents:
1958+
for row_ts, entity_key, feature_val, vector_value, distance_val in documents: # type: ignore[misc]
18951959
# Reset timestamp to default or update if row_ts is not None
18961960
if row_ts is not None:
18971961
row_ts_proto.FromDatetime(row_ts)
@@ -1916,6 +1980,70 @@ def _retrieve_from_online_store(
19161980
)
19171981
return read_row_protos
19181982

1983+
def _retrieve_from_online_store_v2(
    self,
    provider: Provider,
    table: FeatureView,
    requested_features: List[str],
    query: List[float],
    top_k: int,
    distance_metric: Optional[str],
) -> OnlineResponse:
    """
    Run a vector similarity search through the provider and wrap the hits in an
    OnlineResponse.

    Args:
        provider: Provider whose retrieve_online_documents_v2 performs the search.
        table: Feature view to search.
        requested_features: Feature names to return alongside "distance".
        query: Query vector.
        top_k: Number of results requested from the provider.
        distance_metric: Metric to search with; replaced by the metric declared on
            the feature view's vector field when such a field exists.

    Returns:
        An OnlineResponse populated with the requested features plus "distance".
    """
    # The metric configured on the vector field takes precedence over the argument.
    vector_field = _get_feature_view_vector_field_metadata(table)
    if vector_field:
        distance_metric = vector_field.vector_search_metric

    search_results = provider.retrieve_online_documents_v2(
        config=self.config,
        table=table,
        requested_features=requested_features,
        query=query,
        top_k=top_k,
        distance_metric=distance_metric,
    )

    # Collect (timestamp, feature dict) rows and group entity values by join key.
    join_key_values: Dict[str, List[ValueProto]] = {}
    read_rows = []
    for event_ts, entity_key, feature_values in search_results:  # type: ignore[misc]
        read_rows.append((event_ts, feature_values))
        if entity_key:
            for join_key, proto_value in zip(
                entity_key.join_keys, entity_key.entity_values
            ):
                join_key_values.setdefault(join_key, []).append(proto_value)

    # Only the row indexes are needed to populate the response.
    _, idxs = utils._get_unique_entities_from_values(
        join_key_values,
    )

    # "distance" is always returned, after any explicitly requested features.
    features_to_request: List[str] = (
        requested_features + ["distance"] if requested_features else ["distance"]
    )
    feature_data = utils._convert_rows_to_protobuf(
        requested_features=features_to_request,
        read_rows=read_rows,
    )

    online_features_response = GetOnlineFeaturesResponse(results=[])
    utils._populate_response_from_feature_data(
        feature_data=feature_data,
        indexes=idxs,
        online_features_response=online_features_response,
        full_feature_names=False,
        requested_features=features_to_request,
        table=table,
    )

    return OnlineResponse(online_features_response)
2046+
19192047
def serve(
19202048
self,
19212049
host: str,
@@ -2265,3 +2393,16 @@ def _validate_data_sources(data_sources: List[DataSource]):
22652393
raise DataSourceRepeatNamesException(case_insensitive_ds_name)
22662394
else:
22672395
ds_names.add(case_insensitive_ds_name)
2396+
2397+
2398+
def _get_feature_view_vector_field_metadata(
2399+
feature_view: FeatureView,
2400+
) -> Optional[Field]:
2401+
vector_fields = [field for field in feature_view.schema if field.vector_index]
2402+
if len(vector_fields) > 1:
2403+
raise ValueError(
2404+
f"Feature view {feature_view.name} has multiple vector fields. Only one vector field per feature view is supported."
2405+
)
2406+
if not vector_fields:
2407+
return None
2408+
return vector_fields[0]

0 commit comments

Comments
 (0)