Skip to content

Commit 5fb3e20

Browse files
authored
Send unicast replies on the same socket the query was received (#952)
When replying to a QU question, we do not know if the sending host is reachable from all of the sending sockets. We now avoid this problem by replying via the receiving socket. This was the existing behavior when `InterfaceChoice.Default` is set. This change extends the unicast relay behavior to used with `InterfaceChoice.Default` to apply when `InterfaceChoice.All` or interfaces are explicitly passed when instantiating a `Zeroconf` instance. Fixes #951
1 parent ebc23ee commit 5fb3e20

3 files changed

Lines changed: 74 additions & 41 deletions

File tree

tests/test_asyncio.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -974,12 +974,14 @@ async def test_legacy_unicast_response(run_isolated):
974974
query = DNSOutgoing(const._FLAGS_QR_QUERY, multicast=False, id_=888)
975975
question = DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
976976
query.add_question(question)
977+
protocol = aiozc.zeroconf.engine.protocols[0]
977978

978979
with patch.object(aiozc.zeroconf, "async_send") as send_mock:
979-
aiozc.zeroconf.engine.protocols[0].datagram_received(query.packets()[0], ('127.0.0.1', 6503))
980+
protocol.datagram_received(query.packets()[0], ('127.0.0.1', 6503))
980981

981982
calls = send_mock.mock_calls
982-
assert calls == [call(ANY, '127.0.0.1', 6503, ())]
983+
# Verify the response is sent back on the socket it was recieved from
984+
assert calls == [call(ANY, '127.0.0.1', 6503, (), protocol.transport)]
983985
outgoing = send_mock.call_args[0][0]
984986
assert isinstance(outgoing, DNSOutgoing)
985987
assert outgoing.questions == [question]

tests/test_core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -480,28 +480,28 @@ def test_tc_bit_defers():
480480

481481
next_packet = r.DNSIncoming(packets.pop(0))
482482
expected_deferred.append(next_packet)
483-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
483+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
484484
assert protocol._deferred[source_ip] == expected_deferred
485485
assert source_ip in protocol._timers
486486

487487
next_packet = r.DNSIncoming(packets.pop(0))
488488
expected_deferred.append(next_packet)
489-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
489+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
490490
assert protocol._deferred[source_ip] == expected_deferred
491491
assert source_ip in protocol._timers
492-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
492+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
493493
assert protocol._deferred[source_ip] == expected_deferred
494494
assert source_ip in protocol._timers
495495

496496
next_packet = r.DNSIncoming(packets.pop(0))
497497
expected_deferred.append(next_packet)
498-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
498+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
499499
assert protocol._deferred[source_ip] == expected_deferred
500500
assert source_ip in protocol._timers
501501

502502
next_packet = r.DNSIncoming(packets.pop(0))
503503
expected_deferred.append(next_packet)
504-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
504+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
505505
assert source_ip not in protocol._deferred
506506
assert source_ip not in protocol._timers
507507

@@ -559,21 +559,21 @@ def test_tc_bit_defers_last_response_missing():
559559

560560
next_packet = r.DNSIncoming(packets.pop(0))
561561
expected_deferred.append(next_packet)
562-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
562+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
563563
assert protocol._deferred[source_ip] == expected_deferred
564564
timer1 = protocol._timers[source_ip]
565565

566566
next_packet = r.DNSIncoming(packets.pop(0))
567567
expected_deferred.append(next_packet)
568-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
568+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
569569
assert protocol._deferred[source_ip] == expected_deferred
570570
timer2 = protocol._timers[source_ip]
571571
if sys.version_info >= (3, 7):
572572
assert timer1.cancelled()
573573
assert timer2 != timer1
574574

575575
# Send the same packet again to similar multi interfaces
576-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
576+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
577577
assert protocol._deferred[source_ip] == expected_deferred
578578
assert source_ip in protocol._timers
579579
timer3 = protocol._timers[source_ip]
@@ -583,7 +583,7 @@ def test_tc_bit_defers_last_response_missing():
583583

584584
next_packet = r.DNSIncoming(packets.pop(0))
585585
expected_deferred.append(next_packet)
586-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
586+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
587587
assert protocol._deferred[source_ip] == expected_deferred
588588
assert source_ip in protocol._timers
589589
timer4 = protocol._timers[source_ip]

zeroconf/_core.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ def __init__(self, zc: 'Zeroconf') -> None:
215215
self.data: Optional[bytes] = None
216216
self.last_time: float = 0
217217
self.transport: Optional[asyncio.DatagramTransport] = None
218-
218+
self.sock_name: Optional[str] = None
219+
self.sock_fileno: Optional[int] = None
219220
self._deferred: Dict[str, List[DNSIncoming]] = {}
220221
self._timers: Dict[str, asyncio.TimerHandle] = {}
221222

@@ -294,15 +295,20 @@ def datagram_received(
294295
self.zc.handle_response(msg)
295296
return
296297

297-
self.handle_query_or_defer(msg, addr, port, v6_flow_scope)
298+
self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope)
298299

299300
def handle_query_or_defer(
300-
self, msg: DNSIncoming, addr: str, port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = ()
301+
self,
302+
msg: DNSIncoming,
303+
addr: str,
304+
port: int,
305+
transport: asyncio.DatagramTransport,
306+
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
301307
) -> None:
302308
"""Deal with incoming query packets. Provides a response if
303309
possible."""
304310
if not msg.truncated:
305-
self._respond_query(msg, addr, port, v6_flow_scope)
311+
self._respond_query(msg, addr, port, transport, v6_flow_scope)
306312
return
307313

308314
deferred = self._deferred.setdefault(addr, [])
@@ -315,7 +321,7 @@ def handle_query_or_defer(
315321
assert self.zc.loop is not None
316322
self._cancel_any_timers_for_addr(addr)
317323
self._timers[addr] = self.zc.loop.call_later(
318-
delay, self._respond_query, None, addr, port, v6_flow_scope
324+
delay, self._respond_query, None, addr, port, transport, v6_flow_scope
319325
)
320326

321327
def _cancel_any_timers_for_addr(self, addr: str) -> None:
@@ -328,6 +334,7 @@ def _respond_query(
328334
msg: Optional[DNSIncoming],
329335
addr: str,
330336
port: int,
337+
transport: asyncio.DatagramTransport,
331338
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
332339
) -> None:
333340
"""Respond to a query and reassemble any truncated deferred packets."""
@@ -336,15 +343,12 @@ def _respond_query(
336343
if msg:
337344
packets.append(msg)
338345

339-
self.zc.handle_assembled_query(packets, addr, port, v6_flow_scope)
346+
self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
340347

341348
@property
342349
def _socket_description(self) -> str:
343350
"""A human readable description of the socket."""
344-
assert self.transport is not None
345-
fileno = self.transport.get_extra_info('socket').fileno()
346-
sockname = self.transport.get_extra_info('sockname')
347-
return f"{fileno} ({sockname})"
351+
return f"{self.sock_fileno} ({self.sock_name})"
348352

349353
def error_received(self, exc: Exception) -> None:
350354
"""Likely socket closed or IPv6."""
@@ -357,6 +361,8 @@ def error_received(self, exc: Exception) -> None:
357361

358362
def connection_made(self, transport: asyncio.BaseTransport) -> None:
359363
self.transport = cast(asyncio.DatagramTransport, transport)
364+
self.sock_name = self.transport.get_extra_info('sockname')
365+
self.sock_fileno = self.transport.get_extra_info('socket').fileno()
360366

361367
def connection_lost(self, exc: Optional[Exception]) -> None:
362368
"""Handle connection lost."""
@@ -400,6 +406,7 @@ def __init__(
400406
if apple_p2p and sys.platform != 'darwin':
401407
raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.')
402408

409+
self.unicast = unicast
403410
listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p)
404411
log.debug('Listen socket %s, respond sockets %s', listen_socket, respond_sockets)
405412

@@ -732,6 +739,7 @@ def handle_assembled_query(
732739
packets: List[DNSIncoming],
733740
addr: str,
734741
port: int,
742+
transport: asyncio.DatagramTransport,
735743
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
736744
) -> None:
737745
"""Respond to a (re)assembled query.
@@ -749,7 +757,10 @@ def handle_assembled_query(
749757
questions = packets[0].questions
750758
id_ = packets[0].id
751759
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
752-
self.async_send(out, addr, port, v6_flow_scope)
760+
# When sending unicast, only send back the reply
761+
# via the same socket that it was recieved from
762+
# as we know its reachable from that socket
763+
self.async_send(out, addr, port, v6_flow_scope, transport)
753764
if question_answers.mcast_now:
754765
self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
755766
if question_answers.mcast_aggregate:
@@ -766,44 +777,64 @@ def send(
766777
addr: Optional[str] = None,
767778
port: int = _MDNS_PORT,
768779
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
780+
transport: Optional[asyncio.DatagramTransport] = None,
769781
) -> None:
770782
"""Sends an outgoing packet threadsafe."""
771783
assert self.loop is not None
772-
self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope)
784+
self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope, transport)
773785

774786
def async_send(
775787
self,
776788
out: DNSOutgoing,
777789
addr: Optional[str] = None,
778790
port: int = _MDNS_PORT,
779791
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
792+
transport: Optional[asyncio.DatagramTransport] = None,
780793
) -> None:
781794
"""Sends an outgoing packet."""
782795
if self._GLOBAL_DONE:
783796
return
784797

798+
# If no transport is specified, we send to all the ones
799+
# with the same address family
800+
transports = [transport] if transport else self.engine.senders
801+
785802
for packet_num, packet in enumerate(out.packets()):
786803
if len(packet) > _MAX_MSG_ABSOLUTE:
787804
self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet)
788805
return
789-
log.debug(
790-
'Sending to (%s, %d) (%d bytes #%d) %r as %r...',
791-
addr,
792-
port,
793-
len(packet),
794-
packet_num + 1,
795-
out,
796-
packet,
797-
)
798-
for transport in self.engine.senders:
799-
s = transport.get_extra_info('socket')
800-
if addr is None:
801-
real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR
802-
elif not can_send_to(s, addr):
803-
continue
804-
else:
805-
real_addr = addr
806-
transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))
806+
for send_transport in transports:
807+
self._async_send_transport(send_transport, packet, packet_num, out, addr, port, v6_flow_scope)
808+
809+
def _async_send_transport(
810+
self,
811+
transport: asyncio.DatagramTransport,
812+
packet: bytes,
813+
packet_num: int,
814+
out: DNSOutgoing,
815+
addr: Optional[str],
816+
port: int,
817+
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
818+
) -> None:
819+
s = transport.get_extra_info('socket')
820+
if addr is None:
821+
real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR
822+
else:
823+
real_addr = addr
824+
if not can_send_to(s, real_addr):
825+
return
826+
log.debug(
827+
'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...',
828+
real_addr,
829+
port or _MDNS_PORT,
830+
s.fileno(),
831+
transport.get_extra_info('sockname'),
832+
len(packet),
833+
packet_num + 1,
834+
out,
835+
packet,
836+
)
837+
transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))
807838

808839
def _close(self) -> None:
809840
"""Set global done and remove all service listeners."""

0 commit comments

Comments
 (0)