Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,12 +974,14 @@ async def test_legacy_unicast_response(run_isolated):
query = DNSOutgoing(const._FLAGS_QR_QUERY, multicast=False, id_=888)
question = DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)
query.add_question(question)
protocol = aiozc.zeroconf.engine.protocols[0]

with patch.object(aiozc.zeroconf, "async_send") as send_mock:
aiozc.zeroconf.engine.protocols[0].datagram_received(query.packets()[0], ('127.0.0.1', 6503))
protocol.datagram_received(query.packets()[0], ('127.0.0.1', 6503))

calls = send_mock.mock_calls
assert calls == [call(ANY, '127.0.0.1', 6503, ())]
# Verify the response is sent back on the socket it was recieved from
assert calls == [call(ANY, '127.0.0.1', 6503, (), protocol.transport)]
outgoing = send_mock.call_args[0][0]
assert isinstance(outgoing, DNSOutgoing)
assert outgoing.questions == [question]
Expand Down
18 changes: 9 additions & 9 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,28 +480,28 @@ def test_tc_bit_defers():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert source_ip not in protocol._deferred
assert source_ip not in protocol._timers

Expand Down Expand Up @@ -559,21 +559,21 @@ def test_tc_bit_defers_last_response_missing():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
timer1 = protocol._timers[source_ip]

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
timer2 = protocol._timers[source_ip]
if sys.version_info >= (3, 7):
assert timer1.cancelled()
assert timer2 != timer1

# Send the same packet again to similar multi interfaces
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer3 = protocol._timers[source_ip]
Expand All @@ -583,7 +583,7 @@ def test_tc_bit_defers_last_response_missing():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer4 = protocol._timers[source_ip]
Expand Down
91 changes: 61 additions & 30 deletions zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def __init__(self, zc: 'Zeroconf') -> None:
self.data: Optional[bytes] = None
self.last_time: float = 0
self.transport: Optional[asyncio.DatagramTransport] = None

self.sock_name: Optional[str] = None
self.sock_fileno: Optional[int] = None
self._deferred: Dict[str, List[DNSIncoming]] = {}
self._timers: Dict[str, asyncio.TimerHandle] = {}

Expand Down Expand Up @@ -294,15 +295,20 @@ def datagram_received(
self.zc.handle_response(msg)
return

self.handle_query_or_defer(msg, addr, port, v6_flow_scope)
self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope)

def handle_query_or_defer(
self, msg: DNSIncoming, addr: str, port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = ()
self,
msg: DNSIncoming,
addr: str,
port: int,
transport: asyncio.DatagramTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Deal with incoming query packets. Provides a response if
possible."""
if not msg.truncated:
self._respond_query(msg, addr, port, v6_flow_scope)
self._respond_query(msg, addr, port, transport, v6_flow_scope)
return

deferred = self._deferred.setdefault(addr, [])
Expand All @@ -315,7 +321,7 @@ def handle_query_or_defer(
assert self.zc.loop is not None
self._cancel_any_timers_for_addr(addr)
self._timers[addr] = self.zc.loop.call_later(
delay, self._respond_query, None, addr, port, v6_flow_scope
delay, self._respond_query, None, addr, port, transport, v6_flow_scope
)

def _cancel_any_timers_for_addr(self, addr: str) -> None:
Expand All @@ -328,6 +334,7 @@ def _respond_query(
msg: Optional[DNSIncoming],
addr: str,
port: int,
transport: asyncio.DatagramTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Respond to a query and reassemble any truncated deferred packets."""
Expand All @@ -336,15 +343,12 @@ def _respond_query(
if msg:
packets.append(msg)

self.zc.handle_assembled_query(packets, addr, port, v6_flow_scope)
self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)

@property
def _socket_description(self) -> str:
"""A human readable description of the socket."""
assert self.transport is not None
fileno = self.transport.get_extra_info('socket').fileno()
sockname = self.transport.get_extra_info('sockname')
return f"{fileno} ({sockname})"
return f"{self.sock_fileno} ({self.sock_name})"

def error_received(self, exc: Exception) -> None:
"""Likely socket closed or IPv6."""
Expand All @@ -357,6 +361,8 @@ def error_received(self, exc: Exception) -> None:

def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = cast(asyncio.DatagramTransport, transport)
self.sock_name = self.transport.get_extra_info('sockname')
self.sock_fileno = self.transport.get_extra_info('socket').fileno()

def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle connection lost."""
Expand Down Expand Up @@ -400,6 +406,7 @@ def __init__(
if apple_p2p and sys.platform != 'darwin':
raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.')

self.unicast = unicast
listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p)
log.debug('Listen socket %s, respond sockets %s', listen_socket, respond_sockets)

Expand Down Expand Up @@ -732,6 +739,7 @@ def handle_assembled_query(
packets: List[DNSIncoming],
addr: str,
port: int,
transport: asyncio.DatagramTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Respond to a (re)assembled query.
Expand All @@ -749,7 +757,10 @@ def handle_assembled_query(
questions = packets[0].questions
id_ = packets[0].id
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
self.async_send(out, addr, port, v6_flow_scope)
# When sending unicast, only send back the reply
# via the same socket that it was recieved from
# as we know its reachable from that socket
self.async_send(out, addr, port, v6_flow_scope, transport)
if question_answers.mcast_now:
self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
if question_answers.mcast_aggregate:
Expand All @@ -766,44 +777,64 @@ def send(
addr: Optional[str] = None,
port: int = _MDNS_PORT,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
transport: Optional[asyncio.DatagramTransport] = None,
) -> None:
"""Sends an outgoing packet threadsafe."""
assert self.loop is not None
self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope)
self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope, transport)

def async_send(
self,
out: DNSOutgoing,
addr: Optional[str] = None,
port: int = _MDNS_PORT,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
transport: Optional[asyncio.DatagramTransport] = None,
) -> None:
"""Sends an outgoing packet."""
if self._GLOBAL_DONE:
return

# If no transport is specified, we send to all the ones
# with the same address family
transports = [transport] if transport else self.engine.senders

for packet_num, packet in enumerate(out.packets()):
if len(packet) > _MAX_MSG_ABSOLUTE:
self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet)
return
log.debug(
'Sending to (%s, %d) (%d bytes #%d) %r as %r...',
addr,
port,
len(packet),
packet_num + 1,
out,
packet,
)
for transport in self.engine.senders:
s = transport.get_extra_info('socket')
if addr is None:
real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR
elif not can_send_to(s, addr):
continue
else:
real_addr = addr
transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))
for send_transport in transports:
self._async_send_transport(send_transport, packet, packet_num, out, addr, port, v6_flow_scope)

def _async_send_transport(
self,
transport: asyncio.DatagramTransport,
packet: bytes,
packet_num: int,
out: DNSOutgoing,
addr: Optional[str],
port: int,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
s = transport.get_extra_info('socket')
if addr is None:
real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR
else:
real_addr = addr
if not can_send_to(s, real_addr):
return
log.debug(
'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...',
real_addr,
port or _MDNS_PORT,
s.fileno(),
transport.get_extra_info('sockname'),
len(packet),
packet_num + 1,
out,
packet,
)
transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))

def _close(self) -> None:
"""Set global done and remove all service listeners."""
Expand Down