@@ -45,6 +45,72 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit):
4545 return expected_pb
4646
4747
48+ @pytest .mark .parametrize (
49+ "distance_measure, expected_distance" ,
50+ [
51+ (
52+ DistanceMeasure .EUCLIDEAN ,
53+ StructuredQuery .FindNearest .DistanceMeasure .EUCLIDEAN ,
54+ ),
55+ (DistanceMeasure .COSINE , StructuredQuery .FindNearest .DistanceMeasure .COSINE ),
56+ (
57+ DistanceMeasure .DOT_PRODUCT ,
58+ StructuredQuery .FindNearest .DistanceMeasure .DOT_PRODUCT ,
59+ ),
60+ ],
61+ )
62+ @pytest .mark .asyncio
63+ async def test_async_vector_query (distance_measure , expected_distance ):
64+ # Create a minimal fake GAPIC.
65+ firestore_api = AsyncMock (spec = ["run_query" ])
66+ client = make_async_client ()
67+ client ._firestore_api_internal = firestore_api
68+
69+ # Make a **real** collection reference as parent.
70+ parent = client .collection ("dee" )
71+ parent_path , expected_prefix = parent ._parent_info ()
72+
73+ data = {"snooze" : 10 , "embedding" : Vector ([1.0 , 2.0 , 3.0 ])}
74+ response_pb1 = _make_query_response (
75+ name = "{}/test_doc" .format (expected_prefix ), data = data
76+ )
77+
78+ kwargs = make_retry_timeout_kwargs (retry = None , timeout = None )
79+
80+ # Execute the vector query and check the response.
81+ firestore_api .run_query .return_value = AsyncIter ([response_pb1 ])
82+
83+ vector_async_query = parent .find_nearest (
84+ vector_field = "embedding" ,
85+ query_vector = Vector ([1.0 , 2.0 , 3.0 ]),
86+ distance_measure = distance_measure ,
87+ limit = 5 ,
88+ )
89+
90+ returned = await vector_async_query .get (transaction = _transaction (client ), ** kwargs )
91+ assert isinstance (returned , list )
92+ assert len (returned ) == 1
93+ assert returned [0 ].to_dict () == data
94+
95+ expected_pb = _expected_pb (
96+ parent = parent ,
97+ vector_field = "embedding" ,
98+ vector = Vector ([1.0 , 2.0 , 3.0 ]),
99+ distance_type = expected_distance ,
100+ limit = 5 ,
101+ )
102+
103+ firestore_api .run_query .assert_called_once_with (
104+ request = {
105+ "parent" : parent_path ,
106+ "structured_query" : expected_pb ,
107+ "transaction" : _TXN_ID ,
108+ },
109+ metadata = client ._rpc_metadata ,
110+ ** kwargs ,
111+ )
112+
113+
48114@pytest .mark .parametrize (
49115 "distance_measure, expected_distance" ,
50116 [
@@ -84,14 +150,14 @@ async def test_async_vector_query_with_filter(distance_measure, expected_distanc
84150 # Execute the vector query and check the response.
85151 firestore_api .run_query .return_value = AsyncIter ([response_pb1 , response_pb2 ])
86152
87- vector_async__query = query .where ("snooze" , "==" , 10 ).find_nearest (
153+ vector_async_query = query .where ("snooze" , "==" , 10 ).find_nearest (
88154 vector_field = "embedding" ,
89155 query_vector = Vector ([1.0 , 2.0 , 3.0 ]),
90156 distance_measure = distance_measure ,
91157 limit = 5 ,
92158 )
93159
94- returned = await vector_async__query .get (transaction = _transaction (client ), ** kwargs )
160+ returned = await vector_async_query .get (transaction = _transaction (client ), ** kwargs )
95161 assert isinstance (returned , list )
96162 assert len (returned ) == 2
97163 assert returned [0 ].to_dict () == data
0 commit comments