Skip to content

Commit d438fb1

Browse files
authored
Support ws_ping_interval and ws_ping_timeout in wsproto implementation (#2916)
1 parent 3e6b964 commit d438fb1

3 files changed

Lines changed: 111 additions & 10 deletions

File tree

docs/settings.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ Using Uvicorn with watchfiles will enable the following options (which are other
9696
* `--ws <str>` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. There are two versions of `websockets` supported: `websockets` and `websockets-sansio`. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*.
9797
* `--ws-max-size <int>` - Set the WebSockets max message size, in bytes. **Default:** *16777216* (16 MB).
9898
* `--ws-max-queue <int>` - Set the maximum length of the WebSocket incoming message queue. Only available with the `websockets` protocol. **Default:** *32*.
99-
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
100-
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. Available with the `websockets` and `websockets-sansio` protocols. **Default:** *20.0*.
99+
* `--ws-ping-interval <float>` - Set the WebSockets ping interval, in seconds. **Default:** *20.0*.
100+
* `--ws-ping-timeout <float>` - Set the WebSockets ping timeout, in seconds. **Default:** *20.0*.
101101
* `--ws-per-message-deflate <bool>` - Enable/disable WebSocket per-message-deflate compression. Only available with the `websockets` protocol. **Default:** *True*.
102102
* `--lifespan <str>` - Set the Lifespan protocol implementation. **Options:** *'auto', 'on', 'off'.* **Default:** *'auto'*.
103103
* `--h11-max-incomplete-event-size <int>` - Set the maximum number of bytes to buffer of an incomplete event. Only available for `h11` HTTP protocol implementation. **Default:** *16384* (16 KB).

tests/protocols/test_websocket.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
HTTPProtocol: TypeAlias = "type[H11Protocol | HttpToolsProtocol]"
4646
WSProtocol: TypeAlias = "type[_WSProtocol | WebSocketProtocol]"
47+
KeepaliveWSProtocol: TypeAlias = "type[_WSProtocol | WebSocketsSansIOProtocol]"
4748

4849
pytestmark = pytest.mark.anyio
4950

@@ -1230,7 +1231,27 @@ async def app_wrapper(scope: Scope, receive: ASGIReceiveCallable, send: ASGISend
12301231
assert expected_states == actual_states
12311232

12321233

1233-
async def test_server_keepalive_ping_pong(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
1234+
@pytest.fixture(
1235+
params=[
1236+
pytest.param(
1237+
"uvicorn.protocols.websockets.wsproto_impl:WSProtocol",
1238+
marks=skip_if_no_wsproto,
1239+
id="wsproto",
1240+
),
1241+
pytest.param(
1242+
"uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio"
1243+
),
1244+
]
1245+
)
1246+
def keepalive_ws_protocol_cls(request: pytest.FixtureRequest):
1247+
from uvicorn.importer import import_from_string
1248+
1249+
return import_from_string(request.param)
1250+
1251+
1252+
async def test_server_keepalive_ping_pong(
1253+
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
1254+
):
12341255
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
12351256
while True:
12361257
message = await receive()
@@ -1241,7 +1262,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
12411262

12421263
config = Config(
12431264
app=app,
1244-
ws=WebSocketsSansIOProtocol,
1265+
ws=keepalive_ws_protocol_cls,
12451266
http=http_protocol_cls,
12461267
lifespan="off",
12471268
ws_ping_interval=0.1,
@@ -1252,7 +1273,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
12521273
# The websockets client auto-responds to ping frames, keeping the connection alive.
12531274
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
12541275
protocol = list(server.server_state.connections)[0]
1255-
assert isinstance(protocol, WebSocketsSansIOProtocol)
1276+
assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol))
12561277

12571278
# Wait until the server sends at least one keepalive ping, then
12581279
# sleep past the timeout window and ensure the connection stays open.
@@ -1267,7 +1288,9 @@ async def ping_sent() -> None:
12671288
assert not protocol.transport.is_closing()
12681289

12691290

1270-
async def test_server_keepalive_ping_timeout(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
1291+
async def test_server_keepalive_ping_timeout(
1292+
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
1293+
):
12711294
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
12721295
while True:
12731296
message = await receive()
@@ -1278,7 +1301,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
12781301

12791302
config = Config(
12801303
app=app,
1281-
ws=WebSocketsSansIOProtocol,
1304+
ws=keepalive_ws_protocol_cls,
12821305
http=http_protocol_cls,
12831306
lifespan="off",
12841307
ws_ping_interval=0.1,
@@ -1297,7 +1320,9 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
12971320
assert exc_info.value.rcvd.reason == "keepalive ping timeout"
12981321

12991322

1300-
async def test_server_keepalive_disabled(http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
1323+
async def test_server_keepalive_disabled(
1324+
keepalive_ws_protocol_cls: KeepaliveWSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int
1325+
):
13011326
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
13021327
while True:
13031328
message = await receive()
@@ -1308,7 +1333,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
13081333

13091334
config = Config(
13101335
app=app,
1311-
ws=WebSocketsSansIOProtocol,
1336+
ws=keepalive_ws_protocol_cls,
13121337
http=http_protocol_cls,
13131338
lifespan="off",
13141339
ws_ping_interval=None,
@@ -1317,5 +1342,5 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
13171342
async with run_server(config) as server:
13181343
async with websockets.connect(f"ws://127.0.0.1:{unused_tcp_port}", ping_interval=None):
13191344
protocol = list(server.server_state.connections)[0]
1320-
assert isinstance(protocol, WebSocketsSansIOProtocol)
1345+
assert isinstance(protocol, (_WSProtocol, WebSocketsSansIOProtocol))
13211346
assert protocol.ping_timer is None

uvicorn/protocols/websockets/wsproto_impl.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import asyncio
44
import logging
5+
import random
6+
import struct
7+
from asyncio import TimerHandle
58
from io import BytesIO, StringIO
69
from typing import Any, Literal, cast
710
from urllib.parse import unquote
@@ -99,6 +102,15 @@ def __init__(
99102
self.writable = asyncio.Event()
100103
self.writable.set()
101104

105+
# Keepalive state
106+
self.ping_interval = config.ws_ping_interval
107+
self.ping_timeout = config.ws_ping_timeout
108+
self.ping_timer: TimerHandle | None = None
109+
self.pong_timer: TimerHandle | None = None
110+
self.pending_ping_payload: bytes | None = None
111+
self.ping_sent_at: float = 0.0
112+
self.last_ping_rtt: float = 0.0
113+
102114
# Buffer
103115
self.buffer = WebsocketBuffer(self.config.ws_max_size)
104116

@@ -116,6 +128,7 @@ def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore
116128
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
117129

118130
def connection_lost(self, exc: Exception | None) -> None:
131+
self.stop_keepalive()
119132
code = 1005 if self.handshake_complete else 1006
120133
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
121134
self.connections.remove(self)
@@ -153,6 +166,8 @@ def handle_events(self) -> None:
153166
self.handle_close(event)
154167
elif isinstance(event, events.Ping):
155168
self.handle_ping(event)
169+
elif isinstance(event, events.Pong):
170+
self.handle_pong(event)
156171

157172
def pause_writing(self) -> None:
158173
"""
@@ -167,6 +182,7 @@ def resume_writing(self) -> None:
167182
self.writable.set() # pragma: full coverage
168183

169184
def shutdown(self) -> None:
185+
self.stop_keepalive()
170186
if self.handshake_complete:
171187
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
172188
output = self.conn.send(wsproto.events.CloseConnection(code=1012))
@@ -235,6 +251,65 @@ def handle_close(self, event: events.CloseConnection) -> None:
235251
def handle_ping(self, event: events.Ping) -> None:
236252
self.transport.write(self.conn.send(event.response()))
237253

254+
def handle_pong(self, event: events.Pong) -> None:
255+
# Ignore unsolicited pongs and stale pongs whose payload doesn't match the ping currently in flight.
256+
if self.pending_ping_payload is None or bytes(event.payload) != self.pending_ping_payload:
257+
return # pragma: no cover
258+
259+
self.last_ping_rtt = self.loop.time() - self.ping_sent_at
260+
self.pending_ping_payload = None
261+
# The peer answered in time; cancel the pong deadline and chain the next ping. This `schedule_ping()` call is
262+
# what keeps the keepalive loop running when ping_timeout is set. When ping_timeout is None the next ping is
263+
# already scheduled by `send_keepalive_ping`, so we must not schedule a duplicate here.
264+
if self.pong_timer is not None:
265+
self.pong_timer.cancel()
266+
self.pong_timer = None
267+
self.schedule_ping()
268+
269+
def start_keepalive(self) -> None:
270+
if self.ping_interval is not None and self.ping_interval > 0:
271+
self.schedule_ping()
272+
273+
def stop_keepalive(self) -> None:
274+
if self.ping_timer is not None:
275+
self.ping_timer.cancel()
276+
self.ping_timer = None
277+
if self.pong_timer is not None: # pragma: no cover
278+
self.pong_timer.cancel()
279+
self.pong_timer = None
280+
self.pending_ping_payload = None
281+
282+
def schedule_ping(self) -> None:
283+
assert self.ping_interval is not None
284+
delay = max(0.0, self.ping_interval - self.last_ping_rtt)
285+
self.ping_timer = self.loop.call_later(delay, self.send_keepalive_ping)
286+
287+
def send_keepalive_ping(self) -> None:
288+
self.ping_timer = None
289+
if self.close_sent or self.transport.is_closing(): # pragma: no cover
290+
return
291+
# Random 4-byte payload identifies this ping; `handle_pong` uses it to ignore stale or unsolicited pongs.
292+
self.pending_ping_payload = struct.pack("!I", random.getrandbits(32))
293+
self.ping_sent_at = self.loop.time()
294+
self.transport.write(self.conn.send(wsproto.events.Ping(payload=self.pending_ping_payload)))
295+
if self.ping_timeout is not None:
296+
self.pong_timer = self.loop.call_later(self.ping_timeout, self.keepalive_timeout)
297+
else: # pragma: no cover
298+
self.schedule_ping()
299+
300+
def keepalive_timeout(self) -> None:
301+
self.pong_timer = None
302+
self.pending_ping_payload = None
303+
if self.close_sent or self.transport.is_closing(): # pragma: no cover
304+
return
305+
if self.logger.level <= TRACE_LOG_LEVEL:
306+
prefix = "%s:%d - " % self.client if self.client else ""
307+
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket keepalive ping timeout", prefix)
308+
reason = "keepalive ping timeout"
309+
self.transport.write(self.conn.send(wsproto.events.CloseConnection(code=1011, reason=reason)))
310+
self.close_sent = True
311+
self.transport.close()
312+
238313
def send_500_response(self) -> None:
239314
if self.response_started or self.handshake_complete:
240315
return # we cannot send responses anymore
@@ -288,6 +363,7 @@ async def send(self, message: ASGISendEvent) -> None:
288363
)
289364
)
290365
self.transport.write(output)
366+
self.start_keepalive()
291367

292368
elif message["type"] == "websocket.close":
293369
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})

0 commit comments

Comments
 (0)