diff --git a/src/zeroconf/_listener.pxd b/src/zeroconf/_listener.pxd index 4cbc5d00..f07f6b12 100644 --- a/src/zeroconf/_listener.pxd +++ b/src/zeroconf/_listener.pxd @@ -29,6 +29,7 @@ cdef class AsyncListener: cdef public object sock_description cdef public cython.dict _deferred cdef public cython.dict _timers + cdef public cython.dict _deferred_deadlines @cython.locals(now=double, debug=cython.bint) cpdef datagram_received(self, cython.bytes bytes, cython.tuple addrs) @@ -38,7 +39,7 @@ cdef class AsyncListener: cdef _cancel_any_timers_for_addr(self, object addr) - @cython.locals(incoming=DNSIncoming, deferred=list) + @cython.locals(incoming=DNSIncoming, deferred=list, now=double, delay=double, deadline=object, fire_at=double) cpdef handle_query_or_defer( self, DNSIncoming msg, diff --git a/src/zeroconf/_listener.py b/src/zeroconf/_listener.py index ed503169..a3d01368 100644 --- a/src/zeroconf/_listener.py +++ b/src/zeroconf/_listener.py @@ -58,6 +58,7 @@ class AsyncListener: __slots__ = ( "_deferred", + "_deferred_deadlines", "_query_handler", "_record_manager", "_registry", @@ -82,6 +83,7 @@ def __init__(self, zc: Zeroconf) -> None: self.sock_description: str | None = None self._deferred: dict[str, list[DNSIncoming]] = {} self._timers: dict[str, asyncio.TimerHandle] = {} + self._deferred_deadlines: dict[str, float] = {} super().__init__() def datagram_received(self, data: _bytes, addrs: tuple[str, int] | tuple[str, int, int, int]) -> None: @@ -203,12 +205,25 @@ def handle_query_or_defer( if incoming.data == msg.data: return deferred.append(msg) - delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) # noqa: S311 loop = self.zc.loop assert loop is not None + now = loop.time() + delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) # noqa: S311 + # Bound the assembly window to first_arrival + max delay so a peer + # streaming TC packets cannot keep deferring the flush indefinitely. + deadline = self._deferred_deadlines.get(addr) + if deadline is None: + deadline = now + millis_to_seconds(_TC_DELAY_RANDOM_INTERVAL[1]) + self._deferred_deadlines[addr] = deadline + fire_at = now + delay + if fire_at >= deadline: + # Existing timer (if any) already fires at or before the deadline. + if addr in self._timers: + return + fire_at = deadline self._cancel_any_timers_for_addr(addr) self._timers[addr] = loop.call_at( - loop.time() + delay, + fire_at, self._respond_query, None, addr, @@ -232,6 +247,7 @@ def _respond_query( ) -> None: """Respond to a query and reassemble any truncated deferred packets.""" self._cancel_any_timers_for_addr(addr) + self._deferred_deadlines.pop(addr, None) packets = self._deferred.pop(addr, []) if msg: packets.append(msg) diff --git a/tests/test_core.py b/tests/test_core.py index 16f765d4..bef765f0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -743,6 +743,51 @@ def test_tc_bit_defers_last_response_missing(): zc.close() +def test_tc_bit_defer_window_is_bounded(): + """TC-deferral assembly window must not slide past first_arrival + max delay.""" + zc = Zeroconf(interfaces=["127.0.0.1"]) + _wait_for_start(zc) + type_ = "_boundeddefer._tcp.local." + registration_name = f"knownname.{type_}" + + info = r.ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + {"path": "/~paulsm/"}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zc.registry.async_add(info) + + protocol = zc.engine.protocols[0] + now_ms = r.current_time_millis() + _clear_cache(zc) + source_ip = "203.0.113.99" + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + generated.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) + for _ in range(300): + generated.add_answer_at_time(info.dns_pointer(), now_ms) + packets = generated.packets() + assert len(packets) >= 3 + + # Pin the per-packet delay at its maximum so any subsequent reset would + # land past the deadline established by the first packet. + with patch("zeroconf._listener.random.randint", return_value=500): + threadsafe_query(zc, protocol, r.DNSIncoming(packets[0]), source_ip, const._MDNS_PORT, Mock(), ()) + first_when = protocol._timers[source_ip].when() + + for raw in packets[1:-1]: + threadsafe_query(zc, protocol, r.DNSIncoming(raw), source_ip, const._MDNS_PORT, Mock(), ()) + assert protocol._timers[source_ip].when() <= first_when + + zc.registry.async_remove(info) + zc.close() + + @pytest.mark.asyncio async def test_open_close_twice_from_async() -> None: """Test we can close twice from a coroutine when using Zeroconf.