Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/zeroconf/_listener.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cdef class AsyncListener:
cdef public object sock_description
cdef public cython.dict _deferred
cdef public cython.dict _timers
cdef public cython.dict _deferred_deadlines

@cython.locals(now=double, debug=cython.bint)
cpdef datagram_received(self, cython.bytes bytes, cython.tuple addrs)
Expand All @@ -38,7 +39,7 @@ cdef class AsyncListener:

cdef _cancel_any_timers_for_addr(self, object addr)

@cython.locals(incoming=DNSIncoming, deferred=list)
@cython.locals(incoming=DNSIncoming, deferred=list, now=double, delay=double, deadline=object, fire_at=double)
cpdef handle_query_or_defer(
self,
DNSIncoming msg,
Expand Down
20 changes: 18 additions & 2 deletions src/zeroconf/_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class AsyncListener:

__slots__ = (
"_deferred",
"_deferred_deadlines",
"_query_handler",
"_record_manager",
"_registry",
Expand All @@ -82,6 +83,7 @@ def __init__(self, zc: Zeroconf) -> None:
self.sock_description: str | None = None
self._deferred: dict[str, list[DNSIncoming]] = {}
self._timers: dict[str, asyncio.TimerHandle] = {}
self._deferred_deadlines: dict[str, float] = {}
super().__init__()

def datagram_received(self, data: _bytes, addrs: tuple[str, int] | tuple[str, int, int, int]) -> None:
Expand Down Expand Up @@ -203,12 +205,25 @@ def handle_query_or_defer(
if incoming.data == msg.data:
return
deferred.append(msg)
delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) # noqa: S311
loop = self.zc.loop
assert loop is not None
now = loop.time()
delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) # noqa: S311
# Bound the assembly window to first_arrival + max delay so a peer
# streaming TC packets cannot keep deferring the flush indefinitely.
deadline = self._deferred_deadlines.get(addr)
if deadline is None:
deadline = now + millis_to_seconds(_TC_DELAY_RANDOM_INTERVAL[1])
self._deferred_deadlines[addr] = deadline
fire_at = now + delay
if fire_at >= deadline:
# Existing timer (if any) already fires at or before the deadline.
if addr in self._timers:
return
fire_at = deadline
self._cancel_any_timers_for_addr(addr)
self._timers[addr] = loop.call_at(
loop.time() + delay,
fire_at,
self._respond_query,
None,
addr,
Expand All @@ -232,6 +247,7 @@ def _respond_query(
) -> None:
"""Respond to a query and reassemble any truncated deferred packets."""
self._cancel_any_timers_for_addr(addr)
self._deferred_deadlines.pop(addr, None)
packets = self._deferred.pop(addr, [])
if msg:
packets.append(msg)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,51 @@ def test_tc_bit_defers_last_response_missing():
zc.close()


def test_tc_bit_defer_window_is_bounded():
"""TC-deferral assembly window must not slide past first_arrival + max delay."""
zc = Zeroconf(interfaces=["127.0.0.1"])
_wait_for_start(zc)
type_ = "_boundeddefer._tcp.local."
registration_name = f"knownname.{type_}"

info = r.ServiceInfo(
type_,
registration_name,
80,
0,
0,
{"path": "/~paulsm/"},
"ash-2.local.",
addresses=[socket.inet_aton("10.0.1.2")],
)
zc.registry.async_add(info)

protocol = zc.engine.protocols[0]
now_ms = r.current_time_millis()
_clear_cache(zc)
source_ip = "203.0.113.99"

generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
generated.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN))
for _ in range(300):
generated.add_answer_at_time(info.dns_pointer(), now_ms)
packets = generated.packets()
assert len(packets) >= 3

# Pin the per-packet delay at its maximum so any subsequent reset would
# land past the deadline established by the first packet.
with patch("zeroconf._listener.random.randint", return_value=500):
threadsafe_query(zc, protocol, r.DNSIncoming(packets[0]), source_ip, const._MDNS_PORT, Mock(), ())
first_when = protocol._timers[source_ip].when()

for raw in packets[1:-1]:
threadsafe_query(zc, protocol, r.DNSIncoming(raw), source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._timers[source_ip].when() <= first_when

zc.registry.async_remove(info)
zc.close()


@pytest.mark.asyncio
async def test_open_close_twice_from_async() -> None:
"""Test we can close twice from a coroutine when using Zeroconf.
Expand Down
Loading