Skip to content

Commit a65a8ef

Browse files
sararobcopybara-github
authored andcommitted
chore: fix remaining strict mypy errors
PiperOrigin-RevId: 749786746
1 parent b9d3be1 commit a65a8ef

9 files changed

Lines changed: 51 additions & 44 deletions

File tree

.github/workflows/mypy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ jobs:
3232
pip install -r requirements.txt
3333
3434
- name: Run mypy ${{ matrix.python-version }}
35-
run: mypy google/genai/ --config-file=google/genai/mypy.ini
35+
run: mypy google/genai/ --strict --config-file=google/genai/mypy.ini

google/genai/_api_client.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"""
2121

2222
import asyncio
23+
from collections.abc import Awaitable, Generator
2324
import copy
2425
from dataclasses import dataclass
2526
import datetime
@@ -201,7 +202,7 @@ def __aiter__(self) -> 'HttpResponse':
201202
self.segment_iterator = self.async_segments()
202203
return self
203204

204-
async def __anext__(self):
205+
async def __anext__(self) -> Any:
205206
try:
206207
return await self.segment_iterator.__anext__()
207208
except StopIteration:
@@ -213,7 +214,7 @@ def json(self) -> Any:
213214
return ''
214215
return json.loads(self.response_stream[0])
215216

216-
def segments(self):
217+
def segments(self) -> Generator[Any, None, None]:
217218
if isinstance(self.response_stream, list):
218219
# list of objects retrieved from replay or from non-streaming API.
219220
for chunk in self.response_stream:
@@ -222,7 +223,7 @@ def segments(self):
222223
yield from []
223224
else:
224225
# Iterator of objects retrieved from the API.
225-
for chunk in self.response_stream.iter_lines():
226+
for chunk in self.response_stream.iter_lines(): # type: ignore[union-attr]
226227
if chunk:
227228
# In streaming mode, the chunk of JSON is prefixed with "data:" which
228229
# we must strip before parsing.
@@ -256,7 +257,7 @@ async def async_segments(self) -> AsyncIterator[Any]:
256257
else:
257258
raise ValueError('Error parsing streaming response.')
258259

259-
def byte_segments(self):
260+
def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
260261
if isinstance(self.byte_stream, list):
261262
# list of objects retrieved from replay or from non-streaming API.
262263
yield from self.byte_stream
@@ -267,7 +268,7 @@ def byte_segments(self):
267268
'Byte segments are not supported for streaming responses.'
268269
)
269270

270-
def _copy_to_dict(self, response_payload: dict[str, object]):
271+
def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
271272
# Cannot pickle 'generator' object.
272273
delattr(self, 'segment_iterator')
273274
for attribute in dir(self):
@@ -521,11 +522,11 @@ def _access_token(self) -> str:
521522
_refresh_auth(self._credentials)
522523
if not self._credentials.token:
523524
raise RuntimeError('Could not resolve API token from the environment')
524-
return self._credentials.token
525+
return self._credentials.token # type: ignore[no-any-return]
525526
else:
526527
raise RuntimeError('Could not resolve API token from the environment')
527528

528-
async def _async_access_token(self) -> str:
529+
async def _async_access_token(self) -> Union[str, Any]:
529530
"""Retrieves the access token for the credentials asynchronously."""
530531
if not self._credentials:
531532
async with self._auth_lock:
@@ -675,7 +676,7 @@ def _request(
675676

676677
async def _async_request(
677678
self, http_request: HttpRequest, stream: bool = False
678-
):
679+
) -> HttpResponse:
679680
data: Optional[Union[str, bytes]] = None
680681
if self.vertexai and not self.api_key:
681682
http_request.headers['Authorization'] = (
@@ -735,7 +736,7 @@ def request(
735736
path: str,
736737
request_dict: dict[str, object],
737738
http_options: Optional[HttpOptionsOrDict] = None,
738-
):
739+
) -> Union[BaseResponse, Any]:
739740
http_request = self._build_request(
740741
http_method, path, request_dict, http_options
741742
)
@@ -753,7 +754,7 @@ def request_streamed(
753754
path: str,
754755
request_dict: dict[str, object],
755756
http_options: Optional[HttpOptionsOrDict] = None,
756-
):
757+
) -> Generator[Any, None, None]:
757758
http_request = self._build_request(
758759
http_method, path, request_dict, http_options
759760
)
@@ -768,7 +769,7 @@ async def async_request(
768769
path: str,
769770
request_dict: dict[str, object],
770771
http_options: Optional[HttpOptionsOrDict] = None,
771-
) -> dict[str, object]:
772+
) -> Union[BaseResponse, Any]:
772773
http_request = self._build_request(
773774
http_method, path, request_dict, http_options
774775
)
@@ -785,18 +786,18 @@ async def async_request_streamed(
785786
path: str,
786787
request_dict: dict[str, object],
787788
http_options: Optional[HttpOptionsOrDict] = None,
788-
):
789+
) -> Any:
789790
http_request = self._build_request(
790791
http_method, path, request_dict, http_options
791792
)
792793

793794
response = await self._async_request(http_request=http_request, stream=True)
794795

795-
async def async_generator():
796+
async def async_generator(): # type: ignore[no-untyped-def]
796797
async for chunk in response:
797798
yield chunk
798799

799-
return async_generator()
800+
return async_generator() # type: ignore[no-untyped-call]
800801

801802
def upload_file(
802803
self,
@@ -977,7 +978,7 @@ async def async_upload_file(
977978

978979
async def _async_upload_fd(
979980
self,
980-
file: Union[io.IOBase, anyio.AsyncFile],
981+
file: Union[io.IOBase, anyio.AsyncFile[Any]],
981982
upload_url: str,
982983
upload_size: int,
983984
*,
@@ -1093,5 +1094,5 @@ async def async_download_file(
10931094
# This method does nothing in the real api client. It is used in the
10941095
# replay_api_client to verify the response from the SDK method matches the
10951096
# recorded response.
1096-
def _verify_response(self, response_model: _common.BaseModel):
1097+
def _verify_response(self, response_model: _common.BaseModel) -> None:
10971098
pass

google/genai/_extra_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def convert_argument_from_function(
218218

219219

220220
def invoke_function_from_dict_args(
221-
args: Dict[str, Any], function_to_invoke: Callable
221+
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
222222
) -> Any:
223223
converted_args = convert_argument_from_function(args, function_to_invoke)
224224
try:
@@ -232,7 +232,7 @@ def invoke_function_from_dict_args(
232232

233233

234234
async def invoke_function_from_dict_args_async(
235-
args: Dict[str, Any], function_to_invoke: Callable
235+
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
236236
) -> Any:
237237
converted_args = convert_argument_from_function(args, function_to_invoke)
238238
try:
@@ -280,7 +280,7 @@ def get_function_response_parts(
280280

281281
async def get_function_response_parts_async(
282282
response: types.GenerateContentResponse,
283-
function_map: dict[str, Callable],
283+
function_map: dict[str, Callable[..., Any]],
284284
) -> list[types.Part]:
285285
"""Returns the function response parts from the response."""
286286
func_response_parts = []

google/genai/_replay_api_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def initialize_replay_session(self, replay_id: str) -> None:
226226
self._replay_id = replay_id
227227
self._initialize_replay_session()
228228

229-
def _get_replay_file_path(self):
229+
def _get_replay_file_path(self) -> str:
230230
return self._generate_file_path_from_replay_id(
231231
self.replays_directory, self._replay_id
232232
)
@@ -575,7 +575,7 @@ def download_file(
575575

576576
async def async_download_file(
577577
self, path: str, *, http_options: Optional[HttpOptionsOrDict] = None
578-
):
578+
) -> Any:
579579
self._initialize_replay_session_if_not_loaded()
580580
request = self._build_request(
581581
'get', path=path, request_dict={}, http_options=http_options

google/genai/_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict[str,
948948
if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
949949
raise RuntimeError(f'Operation {name} timed out.\n{operation}')
950950
# TODO(b/374433890): Replace with LRO module once it's available.
951-
operation = api_client.request(
951+
operation = api_client.request( # type: ignore[assignment]
952952
http_method='GET', path=name, request_dict={}
953953
)
954954
time.sleep(delay_seconds)

google/genai/chats.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
#
1515

16-
from collections.abc import AsyncGenerator
16+
from collections.abc import Iterator
1717
import sys
1818
from typing import AsyncIterator, Awaitable, Optional, Union, get_args
1919

@@ -283,7 +283,7 @@ def send_message_stream(
283283
self,
284284
message: Union[list[PartUnionDict], PartUnionDict],
285285
config: Optional[GenerateContentConfigOrDict] = None,
286-
):
286+
) -> Iterator[GenerateContentResponse]:
287287
"""Sends the conversation history with the additional message and yields the model's response in chunks.
288288
289289
Args:
@@ -478,7 +478,7 @@ async def async_generator(): # type: ignore[no-untyped-def]
478478
chunk = None
479479
async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined]
480480
model=self._model,
481-
contents=self._curated_history + [input_content],
481+
contents=self._curated_history + [input_content], # type: ignore[arg-type]
482482
config=config if config else self._config,
483483
):
484484
if not _validate_response(chunk):
@@ -489,13 +489,16 @@ async def async_generator(): # type: ignore[no-untyped-def]
489489
finish_reason = chunk.candidates[0].finish_reason
490490
yield chunk
491491

492+
if not output_contents or finish_reason is None:
493+
is_valid = False
494+
492495
self.record_history(
493496
user_input=input_content,
494497
model_output=output_contents,
495498
automatic_function_calling_history=chunk.automatic_function_calling_history,
496-
is_valid=is_valid and output_contents and finish_reason,
499+
is_valid=is_valid,
497500
)
498-
return async_generator()
501+
return async_generator() # type: ignore[no-untyped-call, no-any-return]
499502

500503

501504
class AsyncChats:

google/genai/live.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141

4242

4343
try:
44-
from websockets.asyncio.client import ClientConnection # type: ignore
45-
from websockets.asyncio.client import connect # type: ignore
44+
from websockets.asyncio.client import ClientConnection
45+
from websockets.asyncio.client import connect
4646
except ModuleNotFoundError:
4747
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
4848
from websockets.client import ClientConnection # type: ignore
@@ -127,7 +127,7 @@ async def send_client_content(
127127
]
128128
] = None,
129129
turn_complete: bool = True,
130-
):
130+
) -> None:
131131
"""Send non-realtime, turn based content to the model.
132132
133133
There are two ways to send messages to the live API:
@@ -203,7 +203,7 @@ async def send_client_content(
203203

204204
await self._ws.send(json.dumps({'client_content': client_content_dict}))
205205

206-
async def send_realtime_input(self, *, media: t.BlobUnion):
206+
async def send_realtime_input(self, *, media: t.BlobUnion) -> None:
207207
"""Send realtime media chunks to the model.
208208
209209
Use `send_realtime_input` for realtime audio chunks and video
@@ -267,7 +267,7 @@ async def send_tool_response(
267267
types.FunctionResponseOrDict,
268268
Sequence[types.FunctionResponseOrDict],
269269
],
270-
):
270+
) -> None:
271271
"""Send a tool response to the session.
272272
273273
Use `send_tool_response` to reply to `LiveServerToolCall` messages

google/genai/models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5565,7 +5565,7 @@ async def _generate_content_stream(
55655565
'post', path, request_dict, http_options
55665566
)
55675567

5568-
async def async_generator():
5568+
async def async_generator(): # type: ignore[no-untyped-def]
55695569
async for response_dict in response_stream:
55705570

55715571
if self._api_client.vertexai:
@@ -5584,7 +5584,7 @@ async def async_generator():
55845584
self._api_client._verify_response(return_value)
55855585
yield return_value
55865586

5587-
return async_generator()
5587+
return async_generator() # type: ignore[no-untyped-call, no-any-return]
55885588

55895589
async def embed_content(
55905590
self,
@@ -6602,13 +6602,13 @@ async def generate_content_stream(
66026602
model=model, contents=contents, config=config
66036603
)
66046604

6605-
async def base_async_generator(model, contents, config):
6605+
async def base_async_generator(model, contents, config): # type: ignore[no-untyped-def]
66066606
async for chunk in response: # type: ignore[attr-defined]
66076607
yield chunk
66086608

6609-
return base_async_generator(model, contents, config)
6609+
return base_async_generator(model, contents, config) # type: ignore[no-untyped-call, no-any-return]
66106610

6611-
async def async_generator(model, contents, config):
6611+
async def async_generator(model, contents, config): # type: ignore[no-untyped-def]
66126612
remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
66136613
logger.info(
66146614
f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.'
@@ -6687,15 +6687,15 @@ async def async_generator(model, contents, config):
66876687
)
66886688
contents = t.t_contents(self._api_client, contents)
66896689
if not automatic_function_calling_history:
6690-
automatic_function_calling_history.extend(contents) # type: ignore[arg-type]
6690+
automatic_function_calling_history.extend(contents)
66916691
if isinstance(contents, list) and func_call_content is not None:
6692-
contents.append(func_call_content) # type: ignore[arg-type]
6693-
contents.append(func_response_content) # type: ignore[arg-type]
6692+
contents.append(func_call_content)
6693+
contents.append(func_response_content)
66946694
if func_call_content is not None:
66956695
automatic_function_calling_history.append(func_call_content)
66966696
automatic_function_calling_history.append(func_response_content)
66976697

6698-
return async_generator(model, contents, config)
6698+
return async_generator(model, contents, config) # type: ignore[no-untyped-call, no-any-return]
66996699

67006700
async def edit_image(
67016701
self,

google/genai/mypy.ini

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
[mypy]
2-
exclude = tests/.*
2+
exclude = (tests/|_test_api_client\.py)
33
plugins = pydantic.mypy
4-
disable_error_code = import-not-found, import-untyped
4+
; we are ignoring 'unused-ignore' because we run mypy on Python 3.9 - 3.13 and
5+
; some errors in _automatic_function_calling_util.py only apply in 3.10+
6+
; 'import-not-found' and 'import-untyped' are environment specific
7+
disable_error_code = import-not-found, import-untyped, unused-ignore

0 commit comments

Comments
 (0)