File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -223,6 +223,27 @@ def test_retrieval_query_rag_corpora_config_rank_service_success(self):
223223 )
224224 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
225225
226+ @pytest .mark .usefixtures ("retrieve_contexts_mock" )
227+ def test_retrieval_query_with_metadata_filter (self , retrieve_contexts_mock ):
228+ metadata_filter = 'doc.metadata.genre == "fiction"'
229+ rag_retrieval_config = rag .RagRetrievalConfig (
230+ top_k = 10 ,
231+ filter = rag .Filter (
232+ vector_distance_threshold = 0.5 , metadata_filter = metadata_filter
233+ ),
234+ )
235+ rag .retrieval_query (
236+ rag_resources = [tc .TEST_RAG_RESOURCE ],
237+ text = tc .TEST_QUERY_TEXT ,
238+ rag_retrieval_config = rag_retrieval_config ,
239+ )
240+ retrieve_contexts_mock .assert_called_once ()
241+ args , kwargs = retrieve_contexts_mock .call_args
242+ request = kwargs ["request" ]
243+ assert (
244+ request .query .rag_retrieval_config .filter .metadata_filter == metadata_filter
245+ )
246+
226247 @pytest .mark .usefixtures ("retrieve_contexts_mock" )
227248 def test_retrieval_query_rag_corpora_config_llm_ranker_success (self ):
228249 response = rag .retrieval_query (
Original file line number Diff line number Diff line change @@ -246,6 +246,10 @@ def retrieval_query(
246246 api_retrival_config .filter .vector_similarity_threshold = (
247247 rag_retrieval_config .filter .vector_similarity_threshold
248248 )
249+ if rag_retrieval_config .filter and rag_retrieval_config .filter .metadata_filter :
250+ api_retrival_config .filter .metadata_filter = (
251+ rag_retrieval_config .filter .metadata_filter
252+ )
249253
250254 if (
251255 rag_retrieval_config .ranking
@@ -495,6 +499,10 @@ async def async_retrieve_contexts(
495499 api_retrival_config .ranking .llm_ranker .model_name = (
496500 rag_retrieval_config .ranking .llm_ranker .model_name
497501 )
502+ if rag_retrieval_config .filter and rag_retrieval_config .filter .metadata_filter :
503+ api_retrival_config .filter .metadata_filter = (
504+ rag_retrieval_config .filter .metadata_filter
505+ )
498506
499507 query = aiplatform_v1beta1 .RagQuery (
500508 text = text ,
@@ -742,6 +750,10 @@ def ask_contexts(
742750 api_retrival_config .ranking .llm_ranker .model_name = (
743751 rag_retrieval_config .ranking .llm_ranker .model_name
744752 )
753+ if rag_retrieval_config .filter and rag_retrieval_config .filter .metadata_filter :
754+ api_retrival_config .filter .metadata_filter = (
755+ rag_retrieval_config .filter .metadata_filter
756+ )
745757
746758 query = aiplatform_v1beta1 .RagQuery (
747759 text = text ,
You can’t perform that action at this time.
0 commit comments