|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import sys |
4 | | -from collections.abc import Iterator |
5 | | -from typing import Any, Protocol |
| 4 | +from collections.abc import Awaitable, Iterator |
| 5 | +from typing import Any, Callable, Protocol |
6 | 6 |
|
7 | 7 | if sys.version_info >= (3, 10): # pragma: no cover |
8 | 8 | from typing import ParamSpec |
9 | 9 | else: # pragma: no cover |
10 | 10 | from typing_extensions import ParamSpec |
11 | 11 |
|
12 | | -from starlette.types import ASGIApp |
13 | 12 |
|
14 | 13 | P = ParamSpec("P") |
15 | 14 |
|
16 | 15 |
|
| 16 | +_Scope = Any |
| 17 | +_Receive = Callable[[], Awaitable[Any]] |
| 18 | +_Send = Callable[[Any], Awaitable[None]] |
| 19 | +# Since `starlette.types.ASGIApp` type differs from `ASGIApplication` from `asgiref` |
| 20 | +# we need to define a more permissive version of ASGIApp that doesn't cause type errors. |
| 21 | +_ASGIApp = Callable[[_Scope, _Receive, _Send], Awaitable[None]] |
| 22 | + |
| 23 | + |
17 | 24 | class _MiddlewareFactory(Protocol[P]): |
18 | | - def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover |
| 25 | + def __call__(self, app: _ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> _ASGIApp: ... # pragma: no cover |
19 | 26 |
|
20 | 27 |
|
21 | 28 | class Middleware: |
22 | | - def __init__( |
23 | | - self, |
24 | | - cls: _MiddlewareFactory[P], |
25 | | - *args: P.args, |
26 | | - **kwargs: P.kwargs, |
27 | | - ) -> None: |
| 29 | + def __init__(self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs) -> None: |
28 | 30 | self.cls = cls |
29 | 31 | self.args = args |
30 | 32 | self.kwargs = kwargs |
|
0 commit comments