1616#
1717
1818from dataclasses import dataclass , field
19- from typing import Dict , List , Optional , Sequence , Tuple
19+ from typing import Dict , List , Optional , Sequence , Tuple , Union
2020
2121from google .auth import credentials as auth_credentials
2222from google .cloud .aiplatform import base
@@ -148,6 +148,37 @@ def __post_init__(self):
148148 )
149149
150150
151+ @dataclass
152+ class HybridQuery :
153+ """
154+ Hyrbid query. Could be used for dense-only or sparse-only or hybrid queries.
155+
156+ dense_embedding (List[float]):
157+ Optional. The dense part of the hybrid queries.
158+ sparse_embedding_values (List[float]):
159+ Optional. The sparse values of the sparse part of the queries.
160+
161+ sparse_embedding_dimensions (List[int]):
162+ Optional. The corresponding dimensions of the sparse values.
163+ For example, values [1,2,3] with dimensions [4,5,6] means value 1 is of the
164+ 4th dimension, value 2 is of the 4th dimension, and value 3 is of the 6th
165+ dimension.
166+
167+ rrf_ranking_alpha (float):
168+ Optional. This should not be specified for dense-only or sparse-only queries.
169+ A value between 0 and 1 for ranking algorithm RRF, representing
170+ the ratio for sparse v.s. dense embeddings returned in the query result.
171+ If the alpha is 0, only sparse embeddings are being returned, and no dense
172+ embedding is being returned. When alhpa is 1, only dense embeddings are being
173+ returned, and no sparse embedding is being returned.
174+ """
175+
176+ dense_embedding : List [float ] = None
177+ sparse_embedding_values : List [float ] = None
178+ sparse_embedding_dimensions : List [int ] = None
179+ rrf_ranking_alpha : float = None
180+
181+
151182@dataclass
152183class MatchNeighbor :
153184 """The id and distance of a nearest neighbor match for a given query embedding.
@@ -157,7 +188,7 @@ class MatchNeighbor:
157188 Required. The id of the neighbor.
158189 distance (float):
159190 Required. The distance to the query embedding.
160- feature_vector (List( float) ):
191+ feature_vector (List[ float] ):
161192 Optional. The feature vector of the matching datapoint.
162193 crowding_tag (Optional[str]):
163194 Optional. Crowding tag of the datapoint, the
@@ -167,6 +198,14 @@ class MatchNeighbor:
167198 Optional. The restricts of the matching datapoint.
168199 numeric_restricts:
169200 Optional. The numeric restricts of the matching datapoint.
201+ sparse_embedding_values (List[float]):
202+ Optional. The sparse values of the sparse part of the matching
203+ datapoint.
204+ sparse_embedding_dimensions (List[int]):
205+ Optional. The corresponding dimensions of the sparse values.
206+ For example, values [1,2,3] with dimensions [4,5,6] means value 1 is
207+ of the 4th dimension, value 2 is of the 4th dimension, and value 3 is
208+ of the 6th dimension.
170209
171210 """
172211
@@ -176,6 +215,8 @@ class MatchNeighbor:
176215 crowding_tag : Optional [str ] = None
177216 restricts : Optional [List [Namespace ]] = None
178217 numeric_restricts : Optional [List [NumericNamespace ]] = None
218+ sparse_embedding_values : Optional [List [float ]] = None
219+ sparse_embedding_dimensions : Optional [List [int ]] = None
179220
180221 def from_index_datapoint (
181222 self , index_datapoint : gca_index_v1beta1 .IndexDatapoint
@@ -207,22 +248,31 @@ def from_index_datapoint(
207248 ]
208249 if index_datapoint .numeric_restricts is not None :
209250 self .numeric_restricts = []
210- for restrict in index_datapoint .numeric_restricts :
211- numeric_namespace = None
212- restrict_value_type = restrict ._pb .WhichOneof ("Value" )
213- if restrict_value_type == "value_int" :
214- numeric_namespace = NumericNamespace (
215- name = restrict .namespace , value_int = restrict .value_int
216- )
217- elif restrict_value_type == "value_float" :
218- numeric_namespace = NumericNamespace (
219- name = restrict .namespace , value_float = restrict .value_float
220- )
221- elif restrict_value_type == "value_double" :
222- numeric_namespace = NumericNamespace (
223- name = restrict .namespace , value_double = restrict .value_double
224- )
225- self .numeric_restricts .append (numeric_namespace )
251+ for restrict in index_datapoint .numeric_restricts :
252+ numeric_namespace = None
253+ restrict_value_type = restrict ._pb .WhichOneof ("Value" )
254+ if restrict_value_type == "value_int" :
255+ numeric_namespace = NumericNamespace (
256+ name = restrict .namespace , value_int = restrict .value_int
257+ )
258+ elif restrict_value_type == "value_float" :
259+ numeric_namespace = NumericNamespace (
260+ name = restrict .namespace , value_float = restrict .value_float
261+ )
262+ elif restrict_value_type == "value_double" :
263+ numeric_namespace = NumericNamespace (
264+ name = restrict .namespace , value_double = restrict .value_double
265+ )
266+ self .numeric_restricts .append (numeric_namespace )
267+ # sparse embeddings
268+ if (
269+ index_datapoint .sparse_embedding is not None
270+ and index_datapoint .sparse_embedding .values is not None
271+ ):
272+ self .sparse_embedding_values = index_datapoint .sparse_embedding .values
273+ self .sparse_embedding_dimensions = (
274+ index_datapoint .sparse_embedding .dimensions
275+ )
226276 return self
227277
228278 def from_embedding (self , embedding : match_service_pb2 .Embedding ) -> "MatchNeighbor" :
@@ -250,22 +300,22 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
250300 ]
251301 if embedding .numeric_restricts :
252302 self .numeric_restricts = []
253- for restrict in embedding .numeric_restricts :
254- numeric_namespace = None
255- restrict_value_type = restrict .WhichOneof ("Value" )
256- if restrict_value_type == "value_int" :
257- numeric_namespace = NumericNamespace (
258- name = restrict .name , value_int = restrict .value_int
259- )
260- elif restrict_value_type == "value_float" :
261- numeric_namespace = NumericNamespace (
262- name = restrict .name , value_float = restrict .value_float
263- )
264- elif restrict_value_type == "value_double" :
265- numeric_namespace = NumericNamespace (
266- name = restrict .name , value_double = restrict .value_double
267- )
268- self .numeric_restricts .append (numeric_namespace )
303+ for restrict in embedding .numeric_restricts :
304+ numeric_namespace = None
305+ restrict_value_type = restrict .WhichOneof ("Value" )
306+ if restrict_value_type == "value_int" :
307+ numeric_namespace = NumericNamespace (
308+ name = restrict .name , value_int = restrict .value_int
309+ )
310+ elif restrict_value_type == "value_float" :
311+ numeric_namespace = NumericNamespace (
312+ name = restrict .name , value_float = restrict .value_float
313+ )
314+ elif restrict_value_type == "value_double" :
315+ numeric_namespace = NumericNamespace (
316+ name = restrict .name , value_double = restrict .value_double
317+ )
318+ self .numeric_restricts .append (numeric_namespace )
269319 return self
270320
271321
@@ -1322,7 +1372,7 @@ def find_neighbors(
13221372 self ,
13231373 * ,
13241374 deployed_index_id : str ,
1325- queries : Optional [List [List [float ]]] = None ,
1375+ queries : Optional [Union [ List [List [float ]], List [ HybridQuery ]]] = None ,
13261376 num_neighbors : int = 10 ,
13271377 filter : Optional [List [Namespace ]] = None ,
13281378 per_crowding_attribute_neighbor_count : Optional [int ] = None ,
@@ -1346,8 +1396,15 @@ def find_neighbors(
13461396 Args:
13471397 deployed_index_id (str):
13481398 Required. The ID of the DeployedIndex to match the queries against.
1349- queries (List[List[float]]):
1350- Required. A list of queries. Each query is a list of floats, representing a single embedding.
1399+ queries (Union[List[List[float]], List[HybridQuery]]):
1400+ Optional. A list of queries.
1401+
1402+ For regular dense-only queries, each query is a list of floats,
1403+ representing a single embedding.
1404+
1405+ For hybrid queries, each query is a hybrid query of type
1406+ aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery.
1407+
13511408 num_neighbors (int):
13521409 Required. The number of nearest neighbors to be retrieved from database for
13531410 each query.
@@ -1381,7 +1438,7 @@ def find_neighbors(
13811438 Note that returning full datapoint will significantly increase the
13821439 latency and cost of the query.
13831440
1384- numeric_filter (list [NumericNamespace]):
1441+ numeric_filter (List [NumericNamespace]):
13851442 Optional. A list of NumericNamespaces for filtering the matching
13861443 results. For example:
13871444 [NumericNamespace(name="cost", value_int=5, op="GREATER")]
@@ -1437,30 +1494,54 @@ def find_neighbors(
14371494 numeric_restrict .value_double = numeric_namespace .value_double
14381495 numeric_restricts .append (numeric_restrict )
14391496 # Queries
1440- query_by_id = False if queries else True
1441- queries = queries if queries else embedding_ids
1442- if queries :
1443- for query in queries :
1444- find_neighbors_query = gca_match_service_v1beta1 .FindNeighborsRequest .Query (
1445- neighbor_count = num_neighbors ,
1446- per_crowding_attribute_neighbor_count = per_crowding_attribute_neighbor_count ,
1447- approximate_neighbor_count = approx_num_neighbors ,
1448- fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1449- )
1450- datapoint = gca_index_v1beta1 .IndexDatapoint (
1451- datapoint_id = query if query_by_id else None ,
1452- feature_vector = None if query_by_id else query ,
1453- )
1454- datapoint .restricts .extend (restricts )
1455- datapoint .numeric_restricts .extend (numeric_restricts )
1456- find_neighbors_query .datapoint = datapoint
1457- find_neighbors_request .queries .append (find_neighbors_query )
1497+ query_by_id = False
1498+ query_is_hybrid = False
1499+ if embedding_ids :
1500+ query_by_id = True
1501+ query_iterators : list [str ] = embedding_ids
1502+ elif queries :
1503+ query_is_hybrid = isinstance (queries [0 ], HybridQuery )
1504+ query_iterators = queries
14581505 else :
14591506 raise ValueError (
14601507 "To find neighbors using matching engine,"
1461- "please specify `queries` or `embedding_ids`"
1508+ "please specify `queries` or `embedding_ids` or `hybrid_queries` "
14621509 )
14631510
1511+ for query in query_iterators :
1512+ find_neighbors_query = gca_match_service_v1beta1 .FindNeighborsRequest .Query (
1513+ neighbor_count = num_neighbors ,
1514+ per_crowding_attribute_neighbor_count = per_crowding_attribute_neighbor_count ,
1515+ approximate_neighbor_count = approx_num_neighbors ,
1516+ fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1517+ )
1518+ if query_by_id :
1519+ datapoint = gca_index_v1beta1 .IndexDatapoint (
1520+ datapoint_id = query ,
1521+ )
1522+ elif query_is_hybrid :
1523+ datapoint = gca_index_v1beta1 .IndexDatapoint (
1524+ feature_vector = query .dense_embedding ,
1525+ sparse_embedding = gca_index_v1beta1 .IndexDatapoint .SparseEmbedding (
1526+ values = query .sparse_embedding_values ,
1527+ dimensions = query .sparse_embedding_dimensions ,
1528+ ),
1529+ )
1530+ if query .rrf_ranking_alpha :
1531+ find_neighbors_query .rrf = (
1532+ gca_match_service_v1beta1 .FindNeighborsRequest .Query .RRF (
1533+ alpha = query .rrf_ranking_alpha ,
1534+ )
1535+ )
1536+ else :
1537+ datapoint = gca_index_v1beta1 .IndexDatapoint (
1538+ feature_vector = query ,
1539+ )
1540+ datapoint .restricts .extend (restricts )
1541+ datapoint .numeric_restricts .extend (numeric_restricts )
1542+ find_neighbors_query .datapoint = datapoint
1543+ find_neighbors_request .queries .append (find_neighbors_query )
1544+
14641545 response = self ._public_match_client .find_neighbors (find_neighbors_request )
14651546
14661547 # Wrap the results in MatchNeighbor objects and return
@@ -1543,7 +1624,6 @@ def read_index_datapoints(
15431624 read_index_datapoints_request
15441625 )
15451626
1546- # Wrap the results and return
15471627 return response .datapoints
15481628
15491629 def _batch_get_embeddings (
0 commit comments