Skip to content

Commit 3f38038

Browse files
TechNiickScirlat DanutKludex
authored
Add type hints to test_testclient.py (#2493)
* Add type hints to test_testclient.py * Fix check errors * Apply suggestions from code review * Use ASGIInstance instead --------- Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan> Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent eaee85b commit 3f38038

1 file changed

Lines changed: 45 additions & 41 deletions

File tree

tests/test_testclient.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from __future__ import annotations
2+
13
import itertools
24
import sys
3-
from asyncio import current_task as asyncio_current_task
5+
from asyncio import Task, current_task as asyncio_current_task
46
from contextlib import asynccontextmanager
5-
from typing import Callable
7+
from typing import Any, AsyncGenerator, Callable
68

79
import anyio
810
import anyio.lowlevel
@@ -15,19 +17,21 @@
1517
from starlette.requests import Request
1618
from starlette.responses import JSONResponse, RedirectResponse, Response
1719
from starlette.routing import Route
18-
from starlette.testclient import TestClient
20+
from starlette.testclient import ASGIInstance, TestClient
1921
from starlette.types import ASGIApp, Receive, Scope, Send
2022
from starlette.websockets import WebSocket, WebSocketDisconnect
2123

24+
TestClientFactory = Callable[..., TestClient]
25+
2226

23-
def mock_service_endpoint(request: Request):
27+
def mock_service_endpoint(request: Request) -> JSONResponse:
2428
return JSONResponse({"mock": "example"})
2529

2630

2731
mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)])
2832

2933

30-
def current_task():
34+
def current_task() -> Task[Any] | trio.lowlevel.Task:
3135
# anyio's TaskInfo comparisons are invalid after their associated native
3236
# task object is GC'd https://github.com/agronholm/anyio/issues/324
3337
asynclib_name = sniffio.current_async_library()
@@ -42,19 +46,19 @@ def current_task():
4246
raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover
4347

4448

45-
def startup():
49+
def startup() -> None:
4650
raise RuntimeError()
4751

4852

49-
def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClient]):
53+
def test_use_testclient_in_endpoint(test_client_factory: TestClientFactory) -> None:
5054
"""
5155
We should be able to use the test client within applications.
5256
5357
This is useful if we need to mock out other services,
5458
during tests or in development.
5559
"""
5660

57-
def homepage(request: Request):
61+
def homepage(request: Request) -> JSONResponse:
5862
client = test_client_factory(mock_service)
5963
response = client.get("/")
6064
return JSONResponse(response.json())
@@ -66,7 +70,7 @@ def homepage(request: Request):
6670
assert response.json() == {"mock": "example"}
6771

6872

69-
def test_testclient_headers_behavior():
73+
def test_testclient_headers_behavior() -> None:
7074
"""
7175
We should be able to use the test client with user defined headers.
7276
@@ -86,16 +90,16 @@ def test_testclient_headers_behavior():
8690

8791

8892
def test_use_testclient_as_contextmanager(
89-
test_client_factory: Callable[..., TestClient], anyio_backend_name: str
90-
):
93+
test_client_factory: TestClientFactory, anyio_backend_name: str
94+
) -> None:
9195
"""
9296
This test asserts a number of properties that are important for an
9397
app level task_group
9498
"""
9599
counter = itertools.count()
96100
identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")
97101

98-
def get_identity():
102+
def get_identity() -> int:
99103
try:
100104
return identity_runvar.get()
101105
except LookupError:
@@ -109,7 +113,7 @@ def get_identity():
109113
shutdown_loop = None
110114

111115
@asynccontextmanager
112-
async def lifespan_context(app: Starlette):
116+
async def lifespan_context(app: Starlette) -> AsyncGenerator[None, None]:
113117
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
114118

115119
startup_task = current_task()
@@ -119,7 +123,7 @@ async def lifespan_context(app: Starlette):
119123
shutdown_task = current_task()
120124
shutdown_loop = get_identity()
121125

122-
async def loop_id(request: Request):
126+
async def loop_id(request: Request) -> JSONResponse:
123127
return JSONResponse(get_identity())
124128

125129
app = Starlette(
@@ -143,7 +147,7 @@ async def loop_id(request: Request):
143147
assert startup_task is shutdown_task
144148

145149
# outside the TestClient context, new requests continue to spawn in new
146-
# eventloops in new threads
150+
# event loops in new threads
147151
assert client.get("/loop_id").json() == 1
148152
assert client.get("/loop_id").json() == 2
149153

@@ -165,7 +169,7 @@ async def loop_id(request: Request):
165169
assert first_task is not startup_task
166170

167171

168-
def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
172+
def test_error_on_startup(test_client_factory: TestClientFactory) -> None:
169173
with pytest.deprecated_call(
170174
match="The on_startup and on_shutdown parameters are deprecated"
171175
):
@@ -176,15 +180,15 @@ def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
176180
pass # pragma: no cover
177181

178182

179-
def test_exception_in_middleware(test_client_factory: Callable[..., TestClient]):
183+
def test_exception_in_middleware(test_client_factory: TestClientFactory) -> None:
180184
class MiddlewareException(Exception):
181185
pass
182186

183187
class BrokenMiddleware:
184188
def __init__(self, app: ASGIApp):
185189
self.app = app
186190

187-
async def __call__(self, scope: Scope, receive: Receive, send: Send):
191+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
188192
raise MiddlewareException()
189193

190194
broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)])
@@ -194,9 +198,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
194198
pass # pragma: no cover
195199

196200

197-
def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]):
198-
def app(scope: Scope):
199-
async def inner(receive: Receive, send: Send):
201+
def test_testclient_asgi2(test_client_factory: TestClientFactory) -> None:
202+
def app(scope: Scope) -> ASGIInstance:
203+
async def inner(receive: Receive, send: Send) -> None:
200204
await send(
201205
{
202206
"type": "http.response.start",
@@ -213,8 +217,8 @@ async def inner(receive: Receive, send: Send):
213217
assert response.text == "Hello, world!"
214218

215219

216-
def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]):
217-
async def app(scope: Scope, receive: Receive, send: Send):
220+
def test_testclient_asgi3(test_client_factory: TestClientFactory) -> None:
221+
async def app(scope: Scope, receive: Receive, send: Send) -> None:
218222
await send(
219223
{
220224
"type": "http.response.start",
@@ -229,12 +233,12 @@ async def app(scope: Scope, receive: Receive, send: Send):
229233
assert response.text == "Hello, world!"
230234

231235

232-
def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClient]):
233-
def app(scope: Scope):
234-
async def respond(websocket: WebSocket):
236+
def test_websocket_blocking_receive(test_client_factory: TestClientFactory) -> None:
237+
def app(scope: Scope) -> ASGIInstance:
238+
async def respond(websocket: WebSocket) -> None:
235239
await websocket.send_json({"message": "test"})
236240

237-
async def asgi(receive: Receive, send: Send):
241+
async def asgi(receive: Receive, send: Send) -> None:
238242
websocket = WebSocket(scope, receive=receive, send=send)
239243
await websocket.accept()
240244
async with anyio.create_task_group() as task_group:
@@ -254,9 +258,9 @@ async def asgi(receive: Receive, send: Send):
254258
assert data == {"message": "test"}
255259

256260

257-
def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]):
258-
def app(scope: Scope):
259-
async def asgi(receive: Receive, send: Send):
261+
def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
262+
def app(scope: Scope) -> ASGIInstance:
263+
async def asgi(receive: Receive, send: Send) -> None:
260264
websocket = WebSocket(scope, receive=receive, send=send)
261265
await websocket.accept()
262266
while True:
@@ -271,8 +275,8 @@ async def asgi(receive: Receive, send: Send):
271275

272276

273277
@pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
274-
def test_query_params(test_client_factory: Callable[..., TestClient], param: str):
275-
def homepage(request: Request):
278+
def test_query_params(test_client_factory: TestClientFactory, param: str) -> None:
279+
def homepage(request: Request) -> Response:
276280
return Response(request.query_params["param"])
277281

278282
app = Starlette(routes=[Route("/", endpoint=homepage)])
@@ -301,8 +305,8 @@ def homepage(request: Request):
301305
],
302306
)
303307
def test_domain_restricted_cookies(
304-
test_client_factory: Callable[..., TestClient], domain: str, ok: bool
305-
):
308+
test_client_factory: TestClientFactory, domain: str, ok: bool
309+
) -> None:
306310
"""
307311
Test that test client discards domain restricted cookies which do not match the
308312
base_url of the testclient (`http://testserver` by default).
@@ -312,7 +316,7 @@ def test_domain_restricted_cookies(
312316
in accordance with RFC 2965.
313317
"""
314318

315-
async def app(scope: Scope, receive: Receive, send: Send):
319+
async def app(scope: Scope, receive: Receive, send: Send) -> None:
316320
response = Response("Hello, world!", media_type="text/plain")
317321
response.set_cookie(
318322
"mycookie",
@@ -328,8 +332,8 @@ async def app(scope: Scope, receive: Receive, send: Send):
328332
assert cookie_set == ok
329333

330334

331-
def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]):
332-
async def app(scope: Scope, receive: Receive, send: Send):
335+
def test_forward_follow_redirects(test_client_factory: TestClientFactory) -> None:
336+
async def app(scope: Scope, receive: Receive, send: Send) -> None:
333337
if "/ok" in scope["path"]:
334338
response = Response("ok")
335339
else:
@@ -341,8 +345,8 @@ async def app(scope: Scope, receive: Receive, send: Send):
341345
assert response.status_code == 200
342346

343347

344-
def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClient]):
345-
async def app(scope: Scope, receive: Receive, send: Send):
348+
def test_forward_nofollow_redirects(test_client_factory: TestClientFactory) -> None:
349+
async def app(scope: Scope, receive: Receive, send: Send) -> None:
346350
response = RedirectResponse("/ok")
347351
await response(scope, receive, send)
348352

@@ -351,7 +355,7 @@ async def app(scope: Scope, receive: Receive, send: Send):
351355
assert response.status_code == 307
352356

353357

354-
def test_with_duplicate_headers(test_client_factory: Callable[[Starlette], TestClient]):
358+
def test_with_duplicate_headers(test_client_factory: TestClientFactory) -> None:
355359
def homepage(request: Request) -> JSONResponse:
356360
return JSONResponse({"x-token": request.headers.getlist("x-token")})
357361

@@ -361,7 +365,7 @@ def homepage(request: Request) -> JSONResponse:
361365
assert response.json() == {"x-token": ["foo", "bar"]}
362366

363367

364-
def test_merge_url(test_client_factory: Callable[..., TestClient]):
368+
def test_merge_url(test_client_factory: TestClientFactory) -> None:
365369
def homepage(request: Request) -> Response:
366370
return Response(request.url.path)
367371

0 commit comments

Comments
 (0)