Skip to content

Commit 841c597

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add metadata_filter usage to all retrieval and generator methods in rag_retrieval.py
PiperOrigin-RevId: 889497482
1 parent 33cc6e2 commit 841c597

2 files changed

Lines changed: 33 additions & 0 deletions

File tree

tests/unit/vertex_rag/test_rag_retrieval_preview.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,27 @@ def test_retrieval_query_rag_corpora_config_rank_service_success(self):
223223
)
224224
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
225225

226+
@pytest.mark.usefixtures("retrieve_contexts_mock")
227+
def test_retrieval_query_with_metadata_filter(self, retrieve_contexts_mock):
228+
metadata_filter = 'doc.metadata.genre == "fiction"'
229+
rag_retrieval_config = rag.RagRetrievalConfig(
230+
top_k=10,
231+
filter=rag.Filter(
232+
vector_distance_threshold=0.5, metadata_filter=metadata_filter
233+
),
234+
)
235+
rag.retrieval_query(
236+
rag_resources=[tc.TEST_RAG_RESOURCE],
237+
text=tc.TEST_QUERY_TEXT,
238+
rag_retrieval_config=rag_retrieval_config,
239+
)
240+
retrieve_contexts_mock.assert_called_once()
241+
args, kwargs = retrieve_contexts_mock.call_args
242+
request = kwargs["request"]
243+
assert (
244+
request.query.rag_retrieval_config.filter.metadata_filter == metadata_filter
245+
)
246+
226247
@pytest.mark.usefixtures("retrieve_contexts_mock")
227248
def test_retrieval_query_rag_corpora_config_llm_ranker_success(self):
228249
response = rag.retrieval_query(

vertexai/preview/rag/rag_retrieval.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ def retrieval_query(
246246
api_retrival_config.filter.vector_similarity_threshold = (
247247
rag_retrieval_config.filter.vector_similarity_threshold
248248
)
249+
if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter:
250+
api_retrival_config.filter.metadata_filter = (
251+
rag_retrieval_config.filter.metadata_filter
252+
)
249253

250254
if (
251255
rag_retrieval_config.ranking
@@ -495,6 +499,10 @@ async def async_retrieve_contexts(
495499
api_retrival_config.ranking.llm_ranker.model_name = (
496500
rag_retrieval_config.ranking.llm_ranker.model_name
497501
)
502+
if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter:
503+
api_retrival_config.filter.metadata_filter = (
504+
rag_retrieval_config.filter.metadata_filter
505+
)
498506

499507
query = aiplatform_v1beta1.RagQuery(
500508
text=text,
@@ -742,6 +750,10 @@ def ask_contexts(
742750
api_retrival_config.ranking.llm_ranker.model_name = (
743751
rag_retrieval_config.ranking.llm_ranker.model_name
744752
)
753+
if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter:
754+
api_retrival_config.filter.metadata_filter = (
755+
rag_retrieval_config.filter.metadata_filter
756+
)
745757

746758
query = aiplatform_v1beta1.RagQuery(
747759
text=text,

0 commit comments

Comments
 (0)