From fe5d69e04026e82314dfb93f47e184e01c993a22 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Sun, 3 May 2026 22:51:32 -0700 Subject: [PATCH] feat: [Python] Multimodal file search Add embeddingModel for create file search store Add mediaID to GroundingChunkRetrievedContext Add file_search_stores.downloadMedia PiperOrigin-RevId: 909810388 --- google/genai/file_search_stores.py | 125 ++++++++++++++- .../test_multimodal_flow.py | 146 ++++++++++++++++++ google/genai/types.py | 36 +++++ 3 files changed, 304 insertions(+), 3 deletions(-) create mode 100644 google/genai/tests/file_search_stores/test_multimodal_flow.py diff --git a/google/genai/file_search_stores.py b/google/genai/file_search_stores.py index 51eb85088..821a206c6 100644 --- a/google/genai/file_search_stores.py +++ b/google/genai/file_search_stores.py @@ -26,7 +26,9 @@ from . import _api_module from . import _common from . import _extra_utils +from . import _transformers as t from . import types +from ._api_client import BaseApiClient from ._common import get_value_by_path as getv from ._common import set_value_by_path as setv from ._operations_converters import _UploadToFileSearchStoreOperation_from_mldev @@ -37,6 +39,7 @@ def _CreateFileSearchStoreConfig_to_mldev( + api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: @@ -45,17 +48,25 @@ def _CreateFileSearchStoreConfig_to_mldev( if getv(from_object, ['display_name']) is not None: setv(parent_object, ['displayName'], getv(from_object, ['display_name'])) + if getv(from_object, ['embedding_model']) is not None: + setv( + parent_object, + ['_query', 'embeddingModel'], + t.t_model(api_client, getv(from_object, ['embedding_model'])), + ) + return to_object def _CreateFileSearchStoreParameters_to_mldev( + api_client: BaseApiClient, from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: to_object: dict[str, Any] = {} 
if getv(from_object, ['config']) is not None: _CreateFileSearchStoreConfig_to_mldev( - getv(from_object, ['config']), to_object + api_client, getv(from_object, ['config']), to_object ) return to_object @@ -340,7 +351,9 @@ def create( 'This method is only supported in the Gemini Developer client.' ) else: - request_dict = _CreateFileSearchStoreParameters_to_mldev(parameter_model) + request_dict = _CreateFileSearchStoreParameters_to_mldev( + self._api_client, parameter_model + ) request_url_dict = request_dict.get('_url') if request_url_dict: path = 'fileSearchStores'.format_map(request_url_dict) @@ -843,6 +856,58 @@ def upload_to_file_search_store( response=response_dict, kwargs={} ) + def download_media( + self, + *, + media_id: str, + config: Optional[types.DownloadMediaConfigOrDict] = None, + ) -> bytes: + """Downloads media using a Media ID. + + The media_id has the format: + fileSearchStores//media/ + + This is mapped to the DownloadMedia RPC which expects: + GET /{name=fileSearchStores/*/media/*} + + Args: + media_id: The Media ID from grounding metadata. + config: Optional configuration for the download. + + Returns: + bytes: The media data. + """ + if self._api_client.vertexai: + raise ValueError( + 'This method is only supported in the Gemini Developer client.' + ) + + clean_id = media_id.lstrip('/') + if '/media/' not in clean_id: + raise ValueError( + f'Invalid media_id format: {media_id!r}. 
' + 'Expected format: fileSearchStores//media/' + ) + + path = f'{clean_id}?alt=media' + + config_model = None + if config: + if isinstance(config, dict): + config_model = types.DownloadMediaConfig(**config) + else: + config_model = config + + http_options = None + if config_model and getv(config_model, ['http_options']) is not None: + http_options = getv(config_model, ['http_options']) + + data = self._api_client.download_file( + path, + http_options=http_options, + ) + return data + def list( self, *, config: Optional[types.ListFileSearchStoresConfigOrDict] = None ) -> Pager[types.FileSearchStore]: @@ -903,7 +968,9 @@ async def create( 'This method is only supported in the Gemini Developer client.' ) else: - request_dict = _CreateFileSearchStoreParameters_to_mldev(parameter_model) + request_dict = _CreateFileSearchStoreParameters_to_mldev( + self._api_client, parameter_model + ) request_url_dict = request_dict.get('_url') if request_url_dict: path = 'fileSearchStores'.format_map(request_url_dict) @@ -1412,6 +1479,58 @@ async def upload_to_file_search_store( response=response_dict, kwargs={} ) + async def download_media( + self, + *, + media_id: str, + config: Optional[types.DownloadMediaConfigOrDict] = None, + ) -> bytes: + """Downloads media using a Media ID. + + The media_id has the format: + fileSearchStores//media/ + + This is mapped to the DownloadMedia RPC which expects: + GET /{name=fileSearchStores/*/media/*} + + Args: + media_id: The Media ID from grounding metadata. + config: Optional configuration for the download. + + Returns: + bytes: The media data. + """ + if self._api_client.vertexai: + raise ValueError( + 'This method is only supported in the Gemini Developer client.' + ) + + clean_id = media_id.lstrip('/') + if '/media/' not in clean_id: + raise ValueError( + f'Invalid media_id format: {media_id!r}. 
' + 'Expected format: fileSearchStores//media/' + ) + + path = f'{clean_id}?alt=media' + + config_model = None + if config: + if isinstance(config, dict): + config_model = types.DownloadMediaConfig(**config) + else: + config_model = config + + http_options = None + if config_model and getv(config_model, ['http_options']) is not None: + http_options = getv(config_model, ['http_options']) + + data = await self._api_client.async_download_file( + path, + http_options=http_options, + ) + return data + async def list( self, *, config: Optional[types.ListFileSearchStoresConfigOrDict] = None ) -> AsyncPager[types.FileSearchStore]: diff --git a/google/genai/tests/file_search_stores/test_multimodal_flow.py b/google/genai/tests/file_search_stores/test_multimodal_flow.py new file mode 100644 index 000000000..18a259ee3 --- /dev/null +++ b/google/genai/tests/file_search_stores/test_multimodal_flow.py @@ -0,0 +1,146 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import time +import pydantic +from ... import types +from .. 
import pytest_helper + + +class MultimodalFlowParams(pydantic.BaseModel): + display_name: str + query: str + text_content: str + image_relative_path: str + + +test_table: list[pytest_helper.TestTableItem] = [ + pytest_helper.TestTableItem( + name='test_multimodal_search_flow', + parameters=MultimodalFlowParams( + display_name='test-multimodal-store', + query=( + 'Find the photo of the dog in the park, what is the dog doing?' + ), + text_content='This is a test text file content for file search.', + image_relative_path='../data/dog.jpg', + ), + exception_if_vertex='supported', + ) +] + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method='multimodal_search_flow', + test_table=test_table, + http_options={ + 'api_version': 'v1beta', + 'base_url': ( + 'https://autopush-generativelanguage.sandbox.googleapis.com' + ), + }, +) + + +def multimodal_search_flow(client, parameters: MultimodalFlowParams): + # 1. Create Store + store = None + try: + store = client.file_search_stores.create( + config=types.CreateFileSearchStoreConfig( + display_name=parameters.display_name, + embedding_model='models/gemini-embedding-2-preview', + ) + ) + + # 2. Upload Text + text_file = io.BytesIO(parameters.text_content.encode('utf-8')) + op_text = client.file_search_stores.upload_to_file_search_store( + file_search_store_name=store.name, + file=text_file, + config=types.UploadToFileSearchStoreConfig(mime_type='text/plain'), + ) + + original_cwd = os.getcwd() + try: + # cd is necessary because the recorder records the file path, so we need to use a relative path here. + os.chdir(os.path.dirname(__file__)) + op_image = client.file_search_stores.upload_to_file_search_store( + file_search_store_name=store.name, + file=parameters.image_relative_path, + config=types.UploadToFileSearchStoreConfig(mime_type='image/png'), + ) + finally: + os.chdir(original_cwd) + + # 4. Wait for operations + # In replay mode, these might be fast or pre-recorded. 
+ # In live mode, we need to poll. + while not op_text.done: + time.sleep(1) + op_text = client.operations.get(op_text) + + if op_image: + while not op_image.done: + time.sleep(1) + op_image = client.operations.get(op_image) + + # 5. Search + response = client.models.generate_content( + model='gemini-2.5-flash', + contents=parameters.query, + config=types.GenerateContentConfig( + tools=[ + types.Tool( + file_search=types.FileSearch( + file_search_store_names=[store.name] + ) + ) + ] + ), + ) + + # Verify response has grounding metadata + assert response.candidates[0].grounding_metadata is not None + + # 6. Download Media + # Extract Media ID from grounding chunks if available + blob_media_id = None + if response.candidates[0].grounding_metadata.grounding_chunks: + for chunk in response.candidates[0].grounding_metadata.grounding_chunks: + if chunk.retrieved_context and chunk.retrieved_context.media_id: + blob_media_id = chunk.retrieved_context.media_id + break + + # If we are on MLDev, we expect a Media ID and should be able to download it. + if not client.vertexai: + if not blob_media_id: + raise ValueError('No media_id found in grounding metadata to test download.') + content = client.file_search_stores.download_media( + media_id=blob_media_id + ) + assert content is not None + else: + # On Vertex, we expect download_media to fail if we call it. + with pytest_helper.exception_if_vertex(client, ValueError): + if blob_media_id: + client.file_search_stores.download_media(media_id=blob_media_id) + finally: + if store: + client.file_search_stores.delete( + name=store.name, config=types.DeleteFileSearchStoreConfig(force=True) + ) diff --git a/google/genai/types.py b/google/genai/types.py index 37a965025..df3a898c0 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -6955,6 +6955,10 @@ class GroundingChunkRetrievedContext(_common.BaseModel): default=None, description="""Optional. Page number of the retrieved context. 
This field is not supported in Vertex AI.""", ) + media_id: Optional[str] = Field( + default=None, + description="""Optional. Media ID. This field is not supported in Vertex AI.""", + ) class GroundingChunkRetrievedContextDict(TypedDict, total=False): @@ -6988,6 +6992,9 @@ class GroundingChunkRetrievedContextDict(TypedDict, total=False): page_number: Optional[int] """Optional. Page number of the retrieved context. This field is not supported in Vertex AI.""" + media_id: Optional[str] + """Optional. Media ID. This field is not supported in Vertex AI.""" + GroundingChunkRetrievedContextOrDict = Union[ GroundingChunkRetrievedContext, GroundingChunkRetrievedContextDict @@ -15201,6 +15208,12 @@ class CreateFileSearchStoreConfig(_common.BaseModel): description="""The human-readable display name for the file search store. """, ) + embedding_model: Optional[str] = Field( + default=None, + description="""The embedding model to use for the FileSearchStore. + Format: `models/{model}`. If not specified, the default embedding model will be used. + """, + ) class CreateFileSearchStoreConfigDict(TypedDict, total=False): @@ -15213,6 +15226,11 @@ class CreateFileSearchStoreConfigDict(TypedDict, total=False): """The human-readable display name for the file search store. """ + embedding_model: Optional[str] + """The embedding model to use for the FileSearchStore. + Format: `models/{model}`. If not specified, the default embedding model will be used. 
+ """ + CreateFileSearchStoreConfigOrDict = Union[ CreateFileSearchStoreConfig, CreateFileSearchStoreConfigDict @@ -21205,3 +21223,21 @@ def from_api_response( response_dict = _UploadToFileSearchStoreOperation_from_mldev(api_response) return cls._from_response(response=response_dict, kwargs={}) + + +class DownloadMediaConfig(_common.BaseModel): + """Used to override the default configuration.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DownloadMediaConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + +DownloadMediaConfigOrDict = Union[DownloadMediaConfig, DownloadMediaConfigDict]