|
4 | 4 | from starlette.concurrency import run_in_threadpool |
5 | 5 | from starlette.exceptions import HTTPException |
6 | 6 | from starlette.requests import Request |
7 | | -from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send |
| 7 | +from starlette.types import ( |
| 8 | + ASGIApp, |
| 9 | + ExceptionHandler, |
| 10 | + HTTPExceptionHandler, |
| 11 | + Message, |
| 12 | + Receive, |
| 13 | + Scope, |
| 14 | + Send, |
| 15 | + WebSocketExceptionHandler, |
| 16 | +) |
8 | 17 | from starlette.websockets import WebSocket |
9 | 18 |
|
10 | 19 | ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler] |
@@ -59,12 +68,17 @@ async def sender(message: Message) -> None: |
59 | 68 | raise RuntimeError(msg) from exc |
60 | 69 |
|
61 | 70 | if scope["type"] == "http": |
| 71 | + nonlocal conn |
| 72 | + handler = typing.cast(HTTPExceptionHandler, handler) |
| 73 | + conn = typing.cast(Request, conn) |
62 | 74 | if is_async_callable(handler): |
63 | 75 | response = await handler(conn, exc) |
64 | 76 | else: |
65 | 77 | response = await run_in_threadpool(handler, conn, exc) |
66 | 78 | await response(scope, receive, sender) |
67 | 79 | elif scope["type"] == "websocket": |
| 80 | + handler = typing.cast(WebSocketExceptionHandler, handler) |
| 81 | + conn = typing.cast(WebSocket, conn) |
68 | 82 | if is_async_callable(handler): |
69 | 83 | await handler(conn, exc) |
70 | 84 | else: |
|
0 commit comments