1+ from __future__ import annotations
2+
13import itertools
24import sys
3- from asyncio import current_task as asyncio_current_task
5+ from asyncio import Task , current_task as asyncio_current_task
46from contextlib import asynccontextmanager
5- from typing import Callable
7+ from typing import Any , AsyncGenerator , Callable
68
79import anyio
810import anyio .lowlevel
1517from starlette .requests import Request
1618from starlette .responses import JSONResponse , RedirectResponse , Response
1719from starlette .routing import Route
18- from starlette .testclient import TestClient
20+ from starlette .testclient import ASGIInstance , TestClient
1921from starlette .types import ASGIApp , Receive , Scope , Send
2022from starlette .websockets import WebSocket , WebSocketDisconnect
2123
24+ TestClientFactory = Callable [..., TestClient ]
25+
2226
23- def mock_service_endpoint (request : Request ):
27+ def mock_service_endpoint (request : Request ) -> JSONResponse :
2428 return JSONResponse ({"mock" : "example" })
2529
2630
2731mock_service = Starlette (routes = [Route ("/" , endpoint = mock_service_endpoint )])
2832
2933
30- def current_task ():
34+ def current_task () -> Task [ Any ] | trio . lowlevel . Task :
3135 # anyio's TaskInfo comparisons are invalid after their associated native
3236 # task object is GC'd https://github.com/agronholm/anyio/issues/324
3337 asynclib_name = sniffio .current_async_library ()
@@ -42,19 +46,19 @@ def current_task():
4246 raise RuntimeError (f"unsupported asynclib={ asynclib_name } " ) # pragma: no cover
4347
4448
45- def startup ():
49+ def startup () -> None :
4650 raise RuntimeError ()
4751
4852
49- def test_use_testclient_in_endpoint (test_client_factory : Callable [..., TestClient ]) :
53+ def test_use_testclient_in_endpoint (test_client_factory : TestClientFactory ) -> None :
5054 """
5155 We should be able to use the test client within applications.
5256
5357 This is useful if we need to mock out other services,
5458 during tests or in development.
5559 """
5660
57- def homepage (request : Request ):
61+ def homepage (request : Request ) -> JSONResponse :
5862 client = test_client_factory (mock_service )
5963 response = client .get ("/" )
6064 return JSONResponse (response .json ())
@@ -66,7 +70,7 @@ def homepage(request: Request):
6670 assert response .json () == {"mock" : "example" }
6771
6872
69- def test_testclient_headers_behavior ():
73+ def test_testclient_headers_behavior () -> None :
7074 """
7175 We should be able to use the test client with user defined headers.
7276
@@ -86,16 +90,16 @@ def test_testclient_headers_behavior():
8690
8791
8892def test_use_testclient_as_contextmanager (
89- test_client_factory : Callable [..., TestClient ] , anyio_backend_name : str
90- ):
93+ test_client_factory : TestClientFactory , anyio_backend_name : str
94+ ) -> None :
9195 """
9296 This test asserts a number of properties that are important for an
9397 app level task_group
9498 """
9599 counter = itertools .count ()
96100 identity_runvar = anyio .lowlevel .RunVar [int ]("identity_runvar" )
97101
98- def get_identity ():
102+ def get_identity () -> int :
99103 try :
100104 return identity_runvar .get ()
101105 except LookupError :
@@ -109,7 +113,7 @@ def get_identity():
109113 shutdown_loop = None
110114
111115 @asynccontextmanager
112- async def lifespan_context (app : Starlette ):
116+ async def lifespan_context (app : Starlette ) -> AsyncGenerator [ None , None ] :
113117 nonlocal startup_task , startup_loop , shutdown_task , shutdown_loop
114118
115119 startup_task = current_task ()
@@ -119,7 +123,7 @@ async def lifespan_context(app: Starlette):
119123 shutdown_task = current_task ()
120124 shutdown_loop = get_identity ()
121125
122- async def loop_id (request : Request ):
126+ async def loop_id (request : Request ) -> JSONResponse :
123127 return JSONResponse (get_identity ())
124128
125129 app = Starlette (
@@ -143,7 +147,7 @@ async def loop_id(request: Request):
143147 assert startup_task is shutdown_task
144148
145149 # outside the TestClient context, new requests continue to spawn in new
146- # eventloops in new threads
150+ # event loops in new threads
147151 assert client .get ("/loop_id" ).json () == 1
148152 assert client .get ("/loop_id" ).json () == 2
149153
@@ -165,7 +169,7 @@ async def loop_id(request: Request):
165169 assert first_task is not startup_task
166170
167171
168- def test_error_on_startup (test_client_factory : Callable [..., TestClient ]) :
172+ def test_error_on_startup (test_client_factory : TestClientFactory ) -> None :
169173 with pytest .deprecated_call (
170174 match = "The on_startup and on_shutdown parameters are deprecated"
171175 ):
@@ -176,15 +180,15 @@ def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
176180 pass # pragma: no cover
177181
178182
179- def test_exception_in_middleware (test_client_factory : Callable [..., TestClient ]) :
183+ def test_exception_in_middleware (test_client_factory : TestClientFactory ) -> None :
180184 class MiddlewareException (Exception ):
181185 pass
182186
183187 class BrokenMiddleware :
184188 def __init__ (self , app : ASGIApp ):
185189 self .app = app
186190
187- async def __call__ (self , scope : Scope , receive : Receive , send : Send ):
191+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
188192 raise MiddlewareException ()
189193
190194 broken_middleware = Starlette (middleware = [Middleware (BrokenMiddleware )])
@@ -194,9 +198,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
194198 pass # pragma: no cover
195199
196200
197- def test_testclient_asgi2 (test_client_factory : Callable [..., TestClient ]) :
198- def app (scope : Scope ):
199- async def inner (receive : Receive , send : Send ):
201+ def test_testclient_asgi2 (test_client_factory : TestClientFactory ) -> None :
202+ def app (scope : Scope ) -> ASGIInstance :
203+ async def inner (receive : Receive , send : Send ) -> None :
200204 await send (
201205 {
202206 "type" : "http.response.start" ,
@@ -213,8 +217,8 @@ async def inner(receive: Receive, send: Send):
213217 assert response .text == "Hello, world!"
214218
215219
216- def test_testclient_asgi3 (test_client_factory : Callable [..., TestClient ]) :
217- async def app (scope : Scope , receive : Receive , send : Send ):
220+ def test_testclient_asgi3 (test_client_factory : TestClientFactory ) -> None :
221+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
218222 await send (
219223 {
220224 "type" : "http.response.start" ,
@@ -229,12 +233,12 @@ async def app(scope: Scope, receive: Receive, send: Send):
229233 assert response .text == "Hello, world!"
230234
231235
232- def test_websocket_blocking_receive (test_client_factory : Callable [..., TestClient ]) :
233- def app (scope : Scope ):
234- async def respond (websocket : WebSocket ):
236+ def test_websocket_blocking_receive (test_client_factory : TestClientFactory ) -> None :
237+ def app (scope : Scope ) -> ASGIInstance :
238+ async def respond (websocket : WebSocket ) -> None :
235239 await websocket .send_json ({"message" : "test" })
236240
237- async def asgi (receive : Receive , send : Send ):
241+ async def asgi (receive : Receive , send : Send ) -> None :
238242 websocket = WebSocket (scope , receive = receive , send = send )
239243 await websocket .accept ()
240244 async with anyio .create_task_group () as task_group :
@@ -254,9 +258,9 @@ async def asgi(receive: Receive, send: Send):
254258 assert data == {"message" : "test" }
255259
256260
257- def test_websocket_not_block_on_close (test_client_factory : Callable [..., TestClient ]) :
258- def app (scope : Scope ):
259- async def asgi (receive : Receive , send : Send ):
261+ def test_websocket_not_block_on_close (test_client_factory : TestClientFactory ) -> None :
262+ def app (scope : Scope ) -> ASGIInstance :
263+ async def asgi (receive : Receive , send : Send ) -> None :
260264 websocket = WebSocket (scope , receive = receive , send = send )
261265 await websocket .accept ()
262266 while True :
@@ -271,8 +275,8 @@ async def asgi(receive: Receive, send: Send):
271275
272276
273277@pytest .mark .parametrize ("param" , ("2020-07-14T00:00:00+00:00" , "España" , "voilà" ))
274- def test_query_params (test_client_factory : Callable [..., TestClient ], param : str ):
275- def homepage (request : Request ):
278+ def test_query_params (test_client_factory : TestClientFactory , param : str ) -> None :
279+ def homepage (request : Request ) -> Response :
276280 return Response (request .query_params ["param" ])
277281
278282 app = Starlette (routes = [Route ("/" , endpoint = homepage )])
@@ -301,8 +305,8 @@ def homepage(request: Request):
301305 ],
302306)
303307def test_domain_restricted_cookies (
304- test_client_factory : Callable [..., TestClient ] , domain : str , ok : bool
305- ):
308+ test_client_factory : TestClientFactory , domain : str , ok : bool
309+ ) -> None :
306310 """
307311 Test that test client discards domain restricted cookies which do not match the
308312 base_url of the testclient (`http://testserver` by default).
@@ -312,7 +316,7 @@ def test_domain_restricted_cookies(
312316 in accordance with RFC 2965.
313317 """
314318
315- async def app (scope : Scope , receive : Receive , send : Send ):
319+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
316320 response = Response ("Hello, world!" , media_type = "text/plain" )
317321 response .set_cookie (
318322 "mycookie" ,
@@ -328,8 +332,8 @@ async def app(scope: Scope, receive: Receive, send: Send):
328332 assert cookie_set == ok
329333
330334
331- def test_forward_follow_redirects (test_client_factory : Callable [..., TestClient ]) :
332- async def app (scope : Scope , receive : Receive , send : Send ):
335+ def test_forward_follow_redirects (test_client_factory : TestClientFactory ) -> None :
336+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
333337 if "/ok" in scope ["path" ]:
334338 response = Response ("ok" )
335339 else :
@@ -341,8 +345,8 @@ async def app(scope: Scope, receive: Receive, send: Send):
341345 assert response .status_code == 200
342346
343347
344- def test_forward_nofollow_redirects (test_client_factory : Callable [..., TestClient ]) :
345- async def app (scope : Scope , receive : Receive , send : Send ):
348+ def test_forward_nofollow_redirects (test_client_factory : TestClientFactory ) -> None :
349+ async def app (scope : Scope , receive : Receive , send : Send ) -> None :
346350 response = RedirectResponse ("/ok" )
347351 await response (scope , receive , send )
348352
@@ -351,7 +355,7 @@ async def app(scope: Scope, receive: Receive, send: Send):
351355 assert response .status_code == 307
352356
353357
354- def test_with_duplicate_headers (test_client_factory : Callable [[ Starlette ], TestClient ]) :
358+ def test_with_duplicate_headers (test_client_factory : TestClientFactory ) -> None :
355359 def homepage (request : Request ) -> JSONResponse :
356360 return JSONResponse ({"x-token" : request .headers .getlist ("x-token" )})
357361
@@ -361,7 +365,7 @@ def homepage(request: Request) -> JSONResponse:
361365 assert response .json () == {"x-token" : ["foo" , "bar" ]}
362366
363367
364- def test_merge_url (test_client_factory : Callable [..., TestClient ]) :
368+ def test_merge_url (test_client_factory : TestClientFactory ) -> None :
365369 def homepage (request : Request ) -> Response :
366370 return Response (request .url .path )
367371
0 commit comments