Skip to content

Commit d05d601

Browse files
almost have retrieval working, having to make a lot of changes to online retrieval. long term this can all go in the FeatureView class and in get_online_features
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 2b18594 commit d05d601

File tree

10 files changed

+94
-67
lines changed

10 files changed

+94
-67
lines changed

sdk/python/feast/feature_store.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,9 +1750,10 @@ async def get_online_features_async(
17501750

17511751
def retrieve_online_documents(
17521752
self,
1753-
feature: str,
1753+
feature: Optional[str],
17541754
query: Union[str, List[float]],
17551755
top_k: int,
1756+
features: Optional[List[str]] = None,
17561757
distance_metric: Optional[str] = None,
17571758
) -> OnlineResponse:
17581759
"""
@@ -1762,6 +1763,7 @@ def retrieve_online_documents(
17621763
feature: The list of document features that should be retrieved from the online document store. These features can be
17631764
specified either as a list of string document feature references or as a feature service. String feature
17641765
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
1766+
features: The list of features that should be retrieved from the online store.
17651767
query: The query to retrieve the closest document features for.
17661768
top_k: The number of closest document features to retrieve.
17671769
distance_metric: The distance metric to use for retrieval.
@@ -1770,18 +1772,39 @@ def retrieve_online_documents(
17701772
raise ValueError(
17711773
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
17721774
)
1775+
feature_list = features or [feature]
17731776
(
17741777
available_feature_views,
17751778
_,
17761779
) = utils._get_feature_views_to_use(
17771780
registry=self._registry,
17781781
project=self.project,
1779-
features=[feature],
1782+
features=feature_list,
17801783
allow_cache=True,
17811784
hide_dummy_entity=False,
17821785
)
1786+
if features:
1787+
feature_view_set = set()
1788+
for feature in features:
1789+
feature_view_name = feature.split(":")[0]
1790+
feature_view = self.get_feature_view(feature_view_name)
1791+
feature_view_set.add(feature_view.name)
1792+
if len(feature_view_set) > 1:
1793+
raise ValueError(
1794+
"Document retrieval only supports a single feature view."
1795+
)
1796+
requested_feature = None
1797+
requested_features = [
1798+
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
1799+
]
1800+
else:
1801+
requested_feature = (
1802+
feature.split(":")[1] if isinstance(feature, str) else feature
1803+
)
1804+
requested_features = [requested_feature]
1805+
17831806
requested_feature_view_name = (
1784-
feature.split(":")[0] if isinstance(feature, str) else feature
1807+
feature.split(":")[0] if feature else list(feature_view_set)[0]
17851808
)
17861809
for feature_view in available_feature_views:
17871810
if feature_view.name == requested_feature_view_name:
@@ -1790,14 +1813,15 @@ def retrieve_online_documents(
17901813
raise ValueError(
17911814
f"Feature view {requested_feature_view} not found in the registry."
17921815
)
1793-
requested_feature = (
1794-
feature.split(":")[1] if isinstance(feature, str) else feature
1795-
)
1816+
1817+
requested_feature_view = available_feature_views[0]
1818+
17961819
provider = self._get_provider()
17971820
document_features = self._retrieve_from_online_store(
17981821
provider,
17991822
requested_feature_view,
18001823
requested_feature,
1824+
requested_features,
18011825
query,
18021826
top_k,
18031827
distance_metric,
@@ -1833,7 +1857,8 @@ def _retrieve_from_online_store(
18331857
self,
18341858
provider: Provider,
18351859
table: FeatureView,
1836-
requested_feature: str,
1860+
requested_feature: Optional[str],
1861+
requested_features: Optional[List[str]],
18371862
query: List[float],
18381863
top_k: int,
18391864
distance_metric: Optional[str],
@@ -1849,6 +1874,7 @@ def _retrieve_from_online_store(
18491874
config=self.config,
18501875
table=table,
18511876
requested_feature=requested_feature,
1877+
requested_features=requested_features,
18521878
query=query,
18531879
top_k=top_k,
18541880
distance_metric=distance_metric,

sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
)
3939

4040
PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
41-
PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.STRING,
41+
PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
4242
PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_val"]: DataType.BOOL,
43-
PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.STRING,
43+
PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.VARCHAR,
4444
PROTO_VALUE_TO_VALUE_TYPE_MAP["float_val"]: DataType.FLOAT,
4545
PROTO_VALUE_TO_VALUE_TYPE_MAP["double_val"]: DataType.DOUBLE,
4646
PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_val"]: DataType.INT32,
@@ -71,6 +71,8 @@
7171
ValueType.DOUBLE,
7272
]:
7373
FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.FLOAT_VECTOR
74+
elif base_value_type == ValueType.STRING:
75+
FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.VARCHAR
7476
elif base_value_type == ValueType.BOOL:
7577
FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.BINARY_VECTOR
7678

@@ -149,7 +151,14 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
149151
dim=config.online_store.embedding_dim,
150152
)
151153
)
152-
154+
elif dtype == DataType.VARCHAR:
155+
fields.append(
156+
FieldSchema(
157+
name=field.name,
158+
dtype=dtype,
159+
max_length=512,
160+
)
161+
)
153162
else:
154163
fields.append(FieldSchema(name=field.name, dtype=dtype))
155164

@@ -210,17 +219,14 @@ def online_write_batch(
210219
int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
211220
)
212221
for feature_name in values_dict:
213-
for vector_list_type_name in numeric_vector_list_types:
214-
vector_list = getattr(
215-
values_dict[feature_name], vector_list_type_name, None
216-
)
217-
if vector_list:
218-
vector_values = getattr(
219-
values_dict[feature_name], vector_list_type_name
220-
).val
221-
if vector_values != []:
222-
# Note here we are over-writing the feature and collapsing the list into a single value
223-
values_dict[feature_name] = vector_values
222+
feature_values = values_dict[feature_name]
223+
for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
224+
if feature_values.HasField(proto_val_type):
225+
if proto_val_type in numeric_vector_list_types:
226+
vector_values = getattr(feature_values, proto_val_type).val
227+
else:
228+
vector_values = getattr(feature_values, proto_val_type)
229+
values_dict[feature_name] = vector_values
224230

225231
single_entity_record = {
226232
composite_key_name: entity_key_str,
@@ -243,40 +249,7 @@ def online_read(
243249
entity_keys: List[EntityKeyProto],
244250
requested_features: Optional[List[str]] = None,
245251
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
246-
collection = self._get_collection(config, table)
247-
results = []
248-
249-
for entity_key in entity_keys:
250-
entity_key_str = serialize_entity_key(
251-
entity_key,
252-
entity_key_serialization_version=config.entity_key_serialization_version,
253-
).hex()
254-
expr = f"entity_key == '{entity_key_str}'"
255-
if requested_features:
256-
features_str = ", ".join([f"'{f}'" for f in requested_features])
257-
expr += f" && feature_name in [{features_str}]"
258-
259-
res = collection.query(
260-
expr,
261-
output_fields=["feature_name", "value", "event_ts"],
262-
consistency_level="Strong",
263-
)
264-
265-
res_dict = {}
266-
res_ts = None
267-
for r in res:
268-
feature_name = r["feature_name"]
269-
val_bin = r["value"]
270-
val = ValueProto()
271-
val.ParseFromString(val_bin)
272-
res_dict[feature_name] = val
273-
res_ts = datetime.fromtimestamp(r["event_ts"] / 1e6)
274-
if not res_dict:
275-
results.append((None, None))
276-
else:
277-
results.append((res_ts, res_dict))
278-
279-
return results
252+
raise NotImplementedError
280253

281254
def update(
282255
self,
@@ -320,6 +293,7 @@ def retrieve_online_documents(
320293
config: RepoConfig,
321294
table: FeatureView,
322295
requested_feature: str,
296+
requested_features: List[str],
323297
embedding: List[float],
324298
top_k: int,
325299
distance_metric: Optional[str] = None,
@@ -342,13 +316,22 @@ def retrieve_online_documents(
342316
}
343317
expr = f"feature_name == '{requested_feature}'"
344318

319+
composite_key_name = (
320+
"_".join([str(value) for value in table.entity_columns]) + "_pk"
321+
)
322+
if requested_features:
323+
features_str = ", ".join([f"'{f}'" for f in requested_features])
324+
expr += f" && feature_name in [{features_str}]"
325+
345326
results = collection.search(
346327
data=[embedding],
347328
anns_field="vector_value",
348329
param=search_params,
349330
limit=top_k,
350331
expr=expr,
351-
output_fields=["entity_key", "value", "event_ts"],
332+
output_fields=[composite_key_name]
333+
+ requested_features
334+
+ ["created_ts", "event_ts"],
352335
consistency_level="Strong",
353336
)
354337

sdk/python/feast/infra/passthrough_provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def retrieve_online_documents(
295295
config: RepoConfig,
296296
table: FeatureView,
297297
requested_feature: str,
298+
requested_features: Optional[List[str]],
298299
query: List[float],
299300
top_k: int,
300301
distance_metric: Optional[str] = None,
@@ -305,6 +306,7 @@ def retrieve_online_documents(
305306
config,
306307
table,
307308
requested_feature,
309+
requested_features,
308310
query,
309311
top_k,
310312
distance_metric,

sdk/python/feast/infra/provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def retrieve_online_documents(
420420
config: RepoConfig,
421421
table: FeatureView,
422422
requested_feature: str,
423+
requested_features: Optional[List[str]],
423424
query: List[float],
424425
top_k: int,
425426
distance_metric: Optional[str] = None,
@@ -440,6 +441,7 @@ def retrieve_online_documents(
440441
config: The config for the current feature store.
441442
table: The feature view whose embeddings should be searched.
442443
requested_feature: the requested document feature name.
444+
requested_features: the requested document feature names.
443445
query: The query embedding to search for.
444446
top_k: The number of documents to return.
445447

sdk/python/tests/conftest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@
5656
driver,
5757
location,
5858
)
59-
from tests.integration.feature_repos.universal.online_store.milvus import (
60-
MilvusOnlineStoreCreator,
61-
)
6259
from tests.utils.auth_permissions_util import default_store
6360
from tests.utils.generate_self_signed_certifcate_util import generate_self_signed_cert
6461
from tests.utils.http_server import check_port_open, free_port # noqa: E402
@@ -204,7 +201,6 @@ def environment(request, worker_id):
204201
e.teardown()
205202

206203

207-
208204
@pytest.fixture
209205
def vectordb_environment(request, worker_id):
210206
db_config = IntegrationTestRepoConfig(
@@ -231,6 +227,7 @@ def vectordb_environment(request, worker_id):
231227

232228
e.teardown()
233229

230+
234231
_config_cache: Any = {}
235232

236233

sdk/python/tests/data/data_creator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def get_feature_values_for_dtype(
8484
def create_document_dataset() -> pd.DataFrame:
8585
data = {
8686
"item_id": [1, 2, 3],
87+
"string_feature": ["a", "b", "c"],
88+
"float_feature": [1.0, 2.0, 3.0],
8789
"embedding_float": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]],
8890
"embedding_double": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]],
8991
"ts": [

sdk/python/tests/foo_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def retrieve_online_documents(
150150
config: RepoConfig,
151151
table: FeatureView,
152152
requested_feature: str,
153+
requested_features: Optional[List[str]],
153154
query: List[float],
154155
top_k: int,
155156
distance_metric: Optional[str] = None,

sdk/python/tests/integration/feature_repos/universal/feature_views.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from feast.data_source import DataSource, RequestSource
1818
from feast.feature_view_projection import FeatureViewProjection
1919
from feast.on_demand_feature_view import PandasTransformation, SubstraitTransformation
20-
from feast.types import Array, FeastType, Float32, Float64, Int32, Int64
20+
from feast.types import Array, FeastType, Float32, Float64, Int32, Int64, String
2121
from tests.integration.feature_repos.universal.entities import (
2222
customer,
2323
driver,
@@ -160,8 +160,20 @@ def create_item_embeddings_feature_view(source, infer_features: bool = False):
160160
schema=None
161161
if infer_features
162162
else [
163-
Field(name="embedding_double", dtype=Array(Float64)),
164-
Field(name="embedding_float", dtype=Array(Float32)),
163+
Field(
164+
name="embedding_double",
165+
dtype=Array(Float64),
166+
vector_index=True,
167+
vector_search_metric="L2",
168+
),
169+
Field(
170+
name="embedding_float",
171+
dtype=Array(Float32),
172+
vector_index=True,
173+
vector_search_metric="L2",
174+
),
175+
Field(name="string_feature", dtype=String),
176+
Field(name="float_feature", dtype=Float32),
165177
],
166178
source=source,
167179
ttl=timedelta(hours=2),

sdk/python/tests/integration/online_store/test_universal_online.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,7 +903,12 @@ def test_retrieve_online_documents2(environment, fake_document_data):
903903
fs.apply([item_embeddings_feature_view, item()])
904904
fs.write_to_online_store("item_embeddings", df)
905905
documents = fs.retrieve_online_documents(
906-
feature="item_embeddings:embedding_float",
906+
feature=None,
907+
features=[
908+
"item_embeddings:embedding_float",
909+
"item_embeddings:item_id",
910+
"item_embeddings:string_feature",
911+
],
907912
query=[1.0, 2.0],
908913
top_k=2,
909914
distance_metric="L2",

sdk/python/tests/unit/online_store/test_online_retrieval.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212

1313
from feast import FeatureStore, RepoConfig
1414
from feast.errors import FeatureViewNotFoundException
15-
from feast.infra.online_stores.milvus_online_store.milvus import MilvusOnlineStoreConfig
16-
from feast.infra.provider import Provider
1715
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
1816
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
1917
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -563,4 +561,3 @@ def test_sqlite_vec_import() -> None:
563561
""").fetchall()
564562
result = [(rowid, round(distance, 2)) for rowid, distance in result]
565563
assert result == [(2, 2.39), (1, 2.39)]
566-

0 commit comments

Comments
 (0)