From 632e2023d76eece578a0de58f3deb8b7cf8ddf96 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Fri, 4 Oct 2024 11:18:30 +0100 Subject: [PATCH] Enable keep alive for webrtc stream --- ring_doorbell/doorbot.py | 29 +++++-- .../{rtcstream.py => webrtcstream.py} | 80 ++++++++++++++----- 2 files changed, 79 insertions(+), 30 deletions(-) rename ring_doorbell/{rtcstream.py => webrtcstream.py} (83%) diff --git a/ring_doorbell/doorbot.py b/ring_doorbell/doorbot.py index e275bde..ce1d572 100644 --- a/ring_doorbell/doorbot.py +++ b/ring_doorbell/doorbot.py @@ -49,7 +49,7 @@ ) from ring_doorbell.exceptions import RingError from ring_doorbell.generic import RingGeneric -from ring_doorbell.rtcstream import RingWebRtcStream +from ring_doorbell.webrtcstream import RingWebRtcStream _LOGGER = logging.getLogger(__name__) @@ -64,7 +64,7 @@ def __init__(self, ring: Ring, device_api_id: int, *, shared: bool = False) -> N """Initialise the doorbell.""" super().__init__(ring, device_api_id) self.shared = shared - self._rtc_streams: dict[str, RingWebRtcStream] = {} + self._webrtc_streams: dict[str, RingWebRtcStream] = {} @property def family(self) -> str: @@ -452,22 +452,35 @@ async def async_set_motion_detection(self, state: bool) -> None: # noqa: FBT001 await self._ring.async_query(url, method="PATCH", json=payload) await self._ring.async_update_devices() - async def generate_rtc_stream(self, sdp_offer: str) -> str: + async def generate_webrtc_stream( + self, sdp_offer: str, keep_alive_timeout: int | None = 30 + ) -> str: """Generate the rtc stream.""" if session_id := RingWebRtcStream.get_sdp_session_id(sdp_offer): - stream = RingWebRtcStream(self._ring, self.device_api_id) + stream = RingWebRtcStream( + self._ring, + self.device_api_id, + keep_alive_timeout=keep_alive_timeout, + on_close_callback=self.close_webrtc_stream(session_id), + ) sdp_answer = await stream.generate(sdp_offer) - self._rtc_streams[session_id] = stream + self._webrtc_streams[session_id] = stream return sdp_answer - msg = "Unable to generate the stream" + msg = "Unable to generate the stream, could not extract session id from offer." raise RingError(msg) - async def close_rtc_stream(self, sdp_session_id: str) -> None: + async def close_webrtc_stream(self, sdp_session_id: str) -> None: """Close the rtc stream.""" - stream = self._rtc_streams.pop(sdp_session_id, None) + stream = self._webrtc_streams.pop(sdp_session_id, None) if stream: await stream.close() + async def keep_alive_webrtc_stream(self, sdp_session_id: str) -> None: + """Keep alive the rtc stream.""" + stream = self._webrtc_streams.get(sdp_session_id, None) + if stream: + await stream.keep_alive() + def get_ice_servers(self) -> list[str]: """Return the ICE servers.""" return ICE_SERVERS diff --git a/ring_doorbell/rtcstream.py b/ring_doorbell/webrtcstream.py similarity index 83% rename from ring_doorbell/rtcstream.py rename to ring_doorbell/webrtcstream.py index 7f94607..49c0f2a 100644 --- a/ring_doorbell/rtcstream.py +++ b/ring_doorbell/webrtcstream.py @@ -10,6 +10,7 @@ import contextlib import logging import ssl +import time import uuid from json import dumps as json_dumps from json import loads as json_loads @@ -26,6 +27,8 @@ from ring_doorbell.exceptions import RingError if TYPE_CHECKING: + from collections.abc import Coroutine + from websockets import WebSocketClientProtocol from .ring import Ring @@ -38,7 +41,16 @@ class RingWebRtcStream: """Class to handle a Web RTC Stream.""" - def __init__(self, ring: Ring, device_api_id: int) -> None: + PING_TIME_SECONDS = 5 + + def __init__( + self, + ring: Ring, + device_api_id: int, + *, + keep_alive_timeout: int | None = 30, + on_close_callback: Coroutine | None = None, + ) -> None: """Initialise the class.""" self._ring = ring self.device_api_id = device_api_id @@ -51,6 +63,9 @@ def __init__(self, ring: Ring, device_api_id: int) -> None: self.collect_ice_candidates = False self.ssl_context: ssl.SSLContext | None = None self._sdp_answer_event = asyncio.Event() + self._keep_alive_timeout = keep_alive_timeout + self._last_keep_alive: float | None = None + self._on_close_callback: Coroutine | None = on_close_callback @staticmethod def get_sdp_session_id(sdp_offer: str) -> str | None: @@ -136,6 +151,7 @@ async def generate(self, sdp_offer: str) -> str: await self.websocket.send(json_dumps(options_msg)) _LOGGER.debug("Starting ping and reader tasks") + self._last_keep_alive = time.time() self.ping_task = asyncio.create_task(self.pinger()) self.read_task = asyncio.create_task(self.reader()) @@ -160,21 +176,9 @@ async def generate(self, sdp_offer: str) -> str: _LOGGER.debug("Returning SDP answer: %s", self.sdp) return self.sdp - async def close(self) -> None: - """Close the rtc stream.""" - _LOGGER.debug("Closing the RTC Stream") - self.do_ping = False - if self.ping_task and not self.ping_task.done(): - self.ping_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self.ping_task - self.ping_task = None - if self.websocket: - await self.websocket.close() - self.websocket = None - if self.read_task and not self.read_task.done(): - await self.read_task - self.read_task = None + async def keep_alive(self) -> None: + """Keep alive the rtc stream.""" + self._last_keep_alive = time.time() def get_session_message(self, method: str, body: dict[str, Any]) -> dict[str, Any]: """Get a message to send to the session.""" @@ -201,8 +205,13 @@ async def pinger(self) -> None: """Ping to keep the session alive.""" if TYPE_CHECKING: assert self.websocket - while self.do_ping: - await asyncio.sleep(3) + assert self._last_keep_alive + + while self.do_ping and ( + self._keep_alive_timeout is None + or (time.time() - self._last_keep_alive) <= self._keep_alive_timeout + ): + await asyncio.sleep(self.PING_TIME_SECONDS) ping = self.get_session_message("ping", {}) await self.websocket.send(json_dumps(ping)) @@ -237,6 +246,31 @@ def insert_ice_candidates(self) -> None: multi_text = f"a=mid:{line_index}" self.sdp = self.sdp.replace(multi_text, candidates_text + multi_text) + async def close(self) -> None: + """Close the rtc stream.""" + _LOGGER.debug("Closing the RTC Stream") + await self._close(closed_by_self=False) + + async def _close(self, *, closed_by_self: bool) -> None: + """Close the stream.""" + if closed_by_self and (close_cb := self._on_close_callback): + self._on_close_callback = None + await close_cb + self.do_ping = False + if ping_task := self.ping_task: + self.ping_task = None + if not ping_task.done(): + ping_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await ping_task + if websocket := self.websocket: + self.websocket = None + await websocket.close() + if read_task := self.read_task: + self.read_task = None + if not read_task.done(): + await read_task + async def handle_message(self, message_str: str) -> None: """Handle a message from the web socket.""" if TYPE_CHECKING: @@ -254,8 +288,8 @@ async def handle_message(self, message_str: str) -> None: self._sdp_answer_event.set() elif method == "notification": text = message["body"]["text"] - _LOGGER.debug("Notification received: %s", text) if text == "camera_connected": + _LOGGER.debug("Notification received: %s", text) camera_options = self.get_session_message( "camera_options", {"stealth_mode": False} ) @@ -269,8 +303,10 @@ async def handle_message(self, message_str: str) -> None: "Session created: %s___%s", self.session_id[:16], self.session_id[-16:] ) elif method == "close": - _LOGGER.debug("Close: %s", str(message["body"]["reason"])) + _LOGGER.debug("Close message received: %s", str(message["body"]["reason"])) self.do_ping = False - await self.websocket.close() + await self._close(closed_by_self=True) + elif method == "pong": + _LOGGER.debug("Pong message received") else: - _LOGGER.debug("Message received with method: %s", method) + _LOGGER.debug("Unknown message received with method: %s", method)