Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 64 additions & 24 deletions src/zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sys
import threading
from types import TracebackType # noqa # used in type hints
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Type, Union, cast

from ._cache import DNSCache
from ._dns import DNSQuestion, DNSQuestionType
Expand Down Expand Up @@ -105,6 +105,48 @@
_REGISTER_BROADCASTS = 3


class _WrappedTransport:
"""A wrapper for transports."""

__slots__ = (
'transport',
'is_ipv6',
'sock',
'fileno',
'sock_name',
)

def __init__(
self,
transport: asyncio.DatagramTransport,
is_ipv6: bool,
sock: socket.socket,
fileno: int,
sock_name: Any,
) -> None:
"""Initialize the wrapped transport.

These attributes are used when sending packets.
"""
self.transport = transport
self.is_ipv6 = is_ipv6
self.sock = sock
self.fileno = fileno
self.sock_name = sock_name


def _make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport:
"""Make a wrapped transport."""
sock: socket.socket = transport.get_extra_info('socket')
return _WrappedTransport(
transport=transport,
is_ipv6=sock.family == socket.AF_INET6,
sock=sock,
fileno=sock.fileno(),
sock_name=sock.getsockname(),
)


class AsyncEngine:
"""An engine wraps sockets in the event loop."""

Expand All @@ -117,8 +159,8 @@ def __init__(
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.zc = zeroconf
self.protocols: List[AsyncListener] = []
self.readers: List[asyncio.DatagramTransport] = []
self.senders: List[asyncio.DatagramTransport] = []
self.readers: List[_WrappedTransport] = []
self.senders: List[_WrappedTransport] = []
self.running_event: Optional[asyncio.Event] = None
self._listen_socket = listen_socket
self._respond_sockets = respond_sockets
Expand Down Expand Up @@ -158,9 +200,9 @@ async def _async_create_endpoints(self) -> None:
for s in reader_sockets:
transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s)
self.protocols.append(cast(AsyncListener, protocol))
self.readers.append(cast(asyncio.DatagramTransport, transport))
self.readers.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))
if s in sender_sockets:
self.senders.append(cast(asyncio.DatagramTransport, transport))
self.senders.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))

def _async_cache_cleanup(self) -> None:
"""Periodic cache cleanup."""
Expand All @@ -186,8 +228,8 @@ def _async_shutdown(self) -> None:
"""Shutdown transports and sockets."""
assert self.running_event is not None
self.running_event.clear()
for transport in itertools.chain(self.senders, self.readers):
transport.close()
for wrapped_transport in itertools.chain(self.senders, self.readers):
wrapped_transport.transport.close()

def close(self) -> None:
"""Close from sync context.
Expand Down Expand Up @@ -221,7 +263,7 @@ def __init__(self, zc: 'Zeroconf') -> None:
self.zc = zc
self.data: Optional[bytes] = None
self.last_time: float = 0
self.transport: Optional[asyncio.DatagramTransport] = None
self.transport: Optional[_WrappedTransport] = None
self.sock_description: Optional[str] = None
self._deferred: Dict[str, List[DNSIncoming]] = {}
self._timers: Dict[str, asyncio.TimerHandle] = {}
Expand Down Expand Up @@ -309,7 +351,7 @@ def handle_query_or_defer(
msg: DNSIncoming,
addr: str,
port: int,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Deal with incoming query packets. Provides a response if
Expand Down Expand Up @@ -341,7 +383,7 @@ def _respond_query(
msg: Optional[DNSIncoming],
addr: str,
port: int,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Respond to a query and reassemble any truncated deferred packets."""
Expand All @@ -362,27 +404,25 @@ def error_received(self, exc: Exception) -> None:
self.log_exception_once(exc, msg_str, exc)

def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = cast(asyncio.DatagramTransport, transport)
sock_name = self.transport.get_extra_info('sockname')
sock_fileno = self.transport.get_extra_info('socket').fileno()
self.sock_description = f"{sock_fileno} ({sock_name})"
wrapped_transport = _make_wrapped_transport(cast(asyncio.DatagramTransport, transport))
self.transport = wrapped_transport
self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})"

def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle connection lost."""


def async_send_with_transport(
log_debug: bool,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
packet: bytes,
packet_num: int,
out: DNSOutgoing,
addr: Optional[str],
port: int,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
s = transport.get_extra_info('socket')
ipv6_socket = s.family == socket.AF_INET6
ipv6_socket = transport.is_ipv6
if addr is None:
real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR
else:
Expand All @@ -394,8 +434,8 @@ def async_send_with_transport(
'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...',
real_addr,
port or _MDNS_PORT,
s.fileno(),
transport.get_extra_info('sockname'),
transport.fileno,
transport.sock_name,
len(packet),
packet_num + 1,
out,
Expand All @@ -404,9 +444,9 @@ def async_send_with_transport(
# Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6
# address tuple: https://docs.python.org/3.6/library/socket.html#socket-families
if ipv6_socket and not v6_flow_scope:
_, _, sock_flowinfo, sock_scopeid = s.getsockname()
_, _, sock_flowinfo, sock_scopeid = transport.sock_name
v6_flow_scope = (sock_flowinfo, sock_scopeid)
transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))
transport.transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope))


class Zeroconf(QuietLogger):
Expand Down Expand Up @@ -832,7 +872,7 @@ def handle_assembled_query(
packets: List[DNSIncoming],
addr: str,
port: int,
transport: asyncio.DatagramTransport,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
) -> None:
"""Respond to a (re)assembled query.
Expand Down Expand Up @@ -870,7 +910,7 @@ def send(
addr: Optional[str] = None,
port: int = _MDNS_PORT,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
transport: Optional[asyncio.DatagramTransport] = None,
transport: Optional[_WrappedTransport] = None,
) -> None:
"""Sends an outgoing packet threadsafe."""
assert self.loop is not None
Expand All @@ -882,7 +922,7 @@ def async_send(
addr: Optional[str] = None,
port: int = _MDNS_PORT,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
transport: Optional[asyncio.DatagramTransport] = None,
transport: Optional[_WrappedTransport] = None,
) -> None:
"""Sends an outgoing packet."""
if self.done:
Expand Down