diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 958b34688..aa50ddae1 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -28,7 +28,7 @@ import sys import threading from types import TracebackType # noqa # used in type hints -from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Type, Union, cast from ._cache import DNSCache from ._dns import DNSQuestion, DNSQuestionType @@ -105,6 +105,48 @@ _REGISTER_BROADCASTS = 3 +class _WrappedTransport: + """A wrapper for transports.""" + + __slots__ = ( + 'transport', + 'is_ipv6', + 'sock', + 'fileno', + 'sock_name', + ) + + def __init__( + self, + transport: asyncio.DatagramTransport, + is_ipv6: bool, + sock: socket.socket, + fileno: int, + sock_name: Any, + ) -> None: + """Initialize the wrapped transport. + + These attributes are used when sending packets. + """ + self.transport = transport + self.is_ipv6 = is_ipv6 + self.sock = sock + self.fileno = fileno + self.sock_name = sock_name + + +def _make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport: + """Make a wrapped transport.""" + sock: socket.socket = transport.get_extra_info('socket') + return _WrappedTransport( + transport=transport, + is_ipv6=sock.family == socket.AF_INET6, + sock=sock, + fileno=sock.fileno(), + sock_name=sock.getsockname(), + ) + + class AsyncEngine: """An engine wraps sockets in the event loop.""" @@ -117,8 +159,8 @@ def __init__( self.loop: Optional[asyncio.AbstractEventLoop] = None self.zc = zeroconf self.protocols: List[AsyncListener] = [] - self.readers: List[asyncio.DatagramTransport] = [] - self.senders: List[asyncio.DatagramTransport] = [] + self.readers: List[_WrappedTransport] = [] + self.senders: List[_WrappedTransport] = [] self.running_event: Optional[asyncio.Event] = None self._listen_socket = listen_socket self._respond_sockets = respond_sockets @@ -158,9 +200,9 @@ async def _async_create_endpoints(self) -> None: for s in reader_sockets: transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s) self.protocols.append(cast(AsyncListener, protocol)) - self.readers.append(cast(asyncio.DatagramTransport, transport)) + self.readers.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) if s in sender_sockets: - self.senders.append(cast(asyncio.DatagramTransport, transport)) + self.senders.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" @@ -186,8 +228,8 @@ def _async_shutdown(self) -> None: """Shutdown transports and sockets.""" assert self.running_event is not None self.running_event.clear() - for transport in itertools.chain(self.senders, self.readers): - transport.close() + for wrapped_transport in itertools.chain(self.senders, self.readers): + wrapped_transport.transport.close() def close(self) -> None: """Close from sync context. @@ -221,7 +263,7 @@ def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self.data: Optional[bytes] = None self.last_time: float = 0 - self.transport: Optional[asyncio.DatagramTransport] = None + self.transport: Optional[_WrappedTransport] = None self.sock_description: Optional[str] = None self._deferred: Dict[str, List[DNSIncoming]] = {} self._timers: Dict[str, asyncio.TimerHandle] = {} @@ -309,7 +351,7 @@ def handle_query_or_defer( msg: DNSIncoming, addr: str, port: int, - transport: asyncio.DatagramTransport, + transport: _WrappedTransport, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: """Deal with incoming query packets. Provides a response if @@ -341,7 +383,7 @@ def _respond_query( msg: Optional[DNSIncoming], addr: str, port: int, - transport: asyncio.DatagramTransport, + transport: _WrappedTransport, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: """Respond to a query and reassemble any truncated deferred packets.""" @@ -362,10 +404,9 @@ def error_received(self, exc: Exception) -> None: self.log_exception_once(exc, msg_str, exc) def connection_made(self, transport: asyncio.BaseTransport) -> None: - self.transport = cast(asyncio.DatagramTransport, transport) - sock_name = self.transport.get_extra_info('sockname') - sock_fileno = self.transport.get_extra_info('socket').fileno() - self.sock_description = f"{sock_fileno} ({sock_name})" + wrapped_transport = _make_wrapped_transport(cast(asyncio.DatagramTransport, transport)) + self.transport = wrapped_transport + self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})" def connection_lost(self, exc: Optional[Exception]) -> None: """Handle connection lost.""" @@ -373,7 +414,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: def async_send_with_transport( log_debug: bool, - transport: asyncio.DatagramTransport, + transport: _WrappedTransport, packet: bytes, packet_num: int, out: DNSOutgoing, @@ -381,8 +422,7 @@ def async_send_with_transport( port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: - s = transport.get_extra_info('socket') - ipv6_socket = s.family == socket.AF_INET6 + ipv6_socket = transport.is_ipv6 if addr is None: real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR else: @@ -394,8 +434,8 @@ def async_send_with_transport( 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', real_addr, port or _MDNS_PORT, - s.fileno(), - transport.get_extra_info('sockname'), + transport.fileno, + transport.sock_name, len(packet), packet_num + 1, out, @@ -404,9 +444,9 @@ def async_send_with_transport( # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families if ipv6_socket and not v6_flow_scope: - _, _, sock_flowinfo, sock_scopeid = s.getsockname() + _, _, sock_flowinfo, sock_scopeid = transport.sock_name v6_flow_scope = (sock_flowinfo, sock_scopeid) - transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) + transport.transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) class Zeroconf(QuietLogger): @@ -832,7 +872,7 @@ def handle_assembled_query( packets: List[DNSIncoming], addr: str, port: int, - transport: asyncio.DatagramTransport, + transport: _WrappedTransport, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: """Respond to a (re)assembled query. @@ -870,7 +910,7 @@ def send( addr: Optional[str] = None, port: int = _MDNS_PORT, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - transport: Optional[asyncio.DatagramTransport] = None, + transport: Optional[_WrappedTransport] = None, ) -> None: """Sends an outgoing packet threadsafe.""" assert self.loop is not None @@ -882,7 +922,7 @@ def async_send( addr: Optional[str] = None, port: int = _MDNS_PORT, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - transport: Optional[asyncio.DatagramTransport] = None, + transport: Optional[_WrappedTransport] = None, ) -> None: """Sends an outgoing packet.""" if self.done: