Skip to content

Commit 9a213c1

Browse files
authored
Use ParamSpec for run_in_threadpool (#2375)
1 parent 6715eb4 commit 9a213c1

2 files changed

Lines changed: 23 additions & 4 deletions

File tree

starlette/_exception_handler.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@
44
from starlette.concurrency import run_in_threadpool
55
from starlette.exceptions import HTTPException
66
from starlette.requests import Request
7-
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
7+
from starlette.types import (
8+
ASGIApp,
9+
ExceptionHandler,
10+
HTTPExceptionHandler,
11+
Message,
12+
Receive,
13+
Scope,
14+
Send,
15+
WebSocketExceptionHandler,
16+
)
817
from starlette.websockets import WebSocket
918

1019
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
@@ -59,12 +68,17 @@ async def sender(message: Message) -> None:
5968
raise RuntimeError(msg) from exc
6069

6170
if scope["type"] == "http":
71+
nonlocal conn
72+
handler = typing.cast(HTTPExceptionHandler, handler)
73+
conn = typing.cast(Request, conn)
6274
if is_async_callable(handler):
6375
response = await handler(conn, exc)
6476
else:
6577
response = await run_in_threadpool(handler, conn, exc)
6678
await response(scope, receive, sender)
6779
elif scope["type"] == "websocket":
80+
handler = typing.cast(WebSocketExceptionHandler, handler)
81+
conn = typing.cast(WebSocket, conn)
6882
if is_async_callable(handler):
6983
await handler(conn, exc)
7084
else:

starlette/concurrency.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
import functools
2+
import sys
23
import typing
34
import warnings
45

56
import anyio.to_thread
67

8+
if sys.version_info >= (3, 10): # pragma: no cover
9+
from typing import ParamSpec
10+
else: # pragma: no cover
11+
from typing_extensions import ParamSpec
12+
13+
P = ParamSpec("P")
714
T = typing.TypeVar("T")
815

916

@@ -24,10 +31,8 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ign
2431
task_group.start_soon(run, functools.partial(func, **kwargs))
2532

2633

27-
# TODO: We should use `ParamSpec` here, but mypy doesn't support it yet.
28-
# Check https://github.com/python/mypy/issues/12278 for more details.
2934
async def run_in_threadpool(
30-
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
35+
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
3136
) -> T:
3237
if kwargs: # pragma: no cover
3338
# run_sync doesn't accept 'kwargs', so bind them in here

0 commit comments

Comments
 (0)