diff --git a/src/zeroconf/_listener.pxd b/src/zeroconf/_listener.pxd index 3b1d62313..ec877c78b 100644 --- a/src/zeroconf/_listener.pxd +++ b/src/zeroconf/_listener.pxd @@ -3,6 +3,7 @@ import cython from ._handlers.record_manager cimport RecordManager from ._protocol.incoming cimport DNSIncoming +from ._services.registry cimport ServiceRegistry from ._utils.time cimport current_time_millis, millis_to_seconds @@ -18,6 +19,7 @@ cdef cython.uint _DUPLICATE_PACKET_SUPPRESSION_INTERVAL cdef class AsyncListener: cdef public object zc + cdef ServiceRegistry _registry cdef RecordManager _record_manager cdef public cython.bytes data cdef public cython.float last_time @@ -34,3 +36,12 @@ cdef class AsyncListener: cpdef _process_datagram_at_time(self, bint debug, cython.uint data_len, cython.float now, bytes data, cython.tuple addrs) cdef _cancel_any_timers_for_addr(self, object addr) + + cpdef handle_query_or_defer( + self, + DNSIncoming msg, + object addr, + object port, + object transport, + tuple v6_flow_scope + ) diff --git a/src/zeroconf/_listener.py b/src/zeroconf/_listener.py index c27d1b610..07d059eb0 100644 --- a/src/zeroconf/_listener.py +++ b/src/zeroconf/_listener.py @@ -57,6 +57,7 @@ class AsyncListener: __slots__ = ( 'zc', + '_registry', '_record_manager', 'data', 'last_time', @@ -69,6 +70,7 @@ class AsyncListener: def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc + self._registry = zc.registry self._record_manager = zc.record_manager self.data: Optional[bytes] = None self.last_time: float = 0 @@ -171,6 +173,10 @@ def _process_datagram_at_time( self._record_manager.async_updates_from_response(msg) return + if not self._registry.has_entries: + # If the registry is empty, we have no answers to give. + return + if TYPE_CHECKING: assert self.transport is not None self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) @@ -178,10 +184,10 @@ def _process_datagram_at_time( def handle_query_or_defer( self, msg: DNSIncoming, - addr: str, - port: int, + addr: _str, + port: _int, transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + v6_flow_scope: Union[Tuple[()], Tuple[int, int]], ) -> None: """Deal with incoming query packets. Provides a response if possible.""" diff --git a/src/zeroconf/_services/registry.pxd b/src/zeroconf/_services/registry.pxd index 1d0562c3b..6f9017db7 100644 --- a/src/zeroconf/_services/registry.pxd +++ b/src/zeroconf/_services/registry.pxd @@ -9,6 +9,7 @@ cdef class ServiceRegistry: cdef cython.dict _services cdef public cython.dict types cdef public cython.dict servers + cdef public bint has_entries @cython.locals( record_list=cython.list, @@ -17,6 +18,10 @@ cdef class ServiceRegistry: cdef _add(self, ServiceInfo info) + @cython.locals( + info=ServiceInfo, + old_service_info=ServiceInfo + ) cdef _remove(self, cython.list infos) cpdef ServiceInfo async_get_info_name(self, str name) diff --git a/src/zeroconf/_services/registry.py b/src/zeroconf/_services/registry.py index e9dc4a62b..261e8e9cd 100644 --- a/src/zeroconf/_services/registry.py +++ b/src/zeroconf/_services/registry.py @@ -35,7 +35,7 @@ class ServiceRegistry: the event loop as it is not thread safe. """ - __slots__ = ("_services", "types", "servers") + __slots__ = ("_services", "types", "servers", "has_entries") def __init__( self, @@ -44,6 +44,7 @@ def __init__( self._services: Dict[str, ServiceInfo] = {} self.types: Dict[str, List] = {} self.servers: Dict[str, List] = {} + self.has_entries: bool = False def async_add(self, info: ServiceInfo) -> None: """Add a new service to the registry.""" @@ -95,14 +96,17 @@ def _add(self, info: ServiceInfo) -> None: self._services[info.key] = info self.types.setdefault(info.type.lower(), []).append(info.key) self.servers.setdefault(info.server_key, []).append(info.key) + self.has_entries = True def _remove(self, infos: List[ServiceInfo]) -> None: """Remove a services under the lock.""" for info in infos: - if info.key not in self._services: + old_service_info = self._services.get(info.key) + if old_service_info is None: continue - old_service_info = self._services[info.key] assert old_service_info.server_key is not None self.types[old_service_info.type.lower()].remove(info.key) self.servers[old_service_info.server_key].remove(info.key) del self._services[info.key] + + self.has_entries = bool(self._services) diff --git a/tests/test_core.py b/tests/test_core.py index 4bce6db97..de4b2ef5b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,12 +12,10 @@ import time import unittest import unittest.mock -from typing import cast -from unittest.mock import patch +from typing import Tuple, Union, cast +from unittest.mock import Mock, patch if sys.version_info[:3][1] < 8: - from unittest.mock import Mock - AsyncMock = Mock else: from unittest.mock import AsyncMock @@ -26,6 +24,8 @@ import zeroconf as r from zeroconf import NotRunningException, Zeroconf, const, current_time_millis +from zeroconf._listener import AsyncListener, _WrappedTransport +from zeroconf._protocol.incoming import DNSIncoming from zeroconf.asyncio import AsyncZeroconf from . import _clear_cache, _inject_response, _wait_for_start, has_working_ipv6 @@ -45,10 +45,19 @@ def teardown_module(): log.setLevel(original_logging_level) -def threadsafe_query(zc, protocol, *args): +def threadsafe_query( + zc: 'Zeroconf', + protocol: 'AsyncListener', + msg: DNSIncoming, + addr: str, + port: int, + transport: _WrappedTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]], +) -> None: async def make_query(): - protocol.handle_query_or_defer(*args) + protocol.handle_query_or_defer(msg, addr, port, transport, v6_flow_scope) + assert zc.loop is not None asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() @@ -476,28 +485,28 @@ def test_tc_bit_defers(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert source_ip not in protocol._deferred assert source_ip not in protocol._timers @@ -555,20 +564,20 @@ def test_tc_bit_defers_last_response_missing(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred timer1 = protocol._timers[source_ip] next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred timer2 = protocol._timers[source_ip] assert timer1.cancelled() assert timer2 != timer1 # Send the same packet again to similar multi interfaces - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers timer3 = protocol._timers[source_ip] @@ -577,7 +586,7 @@ def test_tc_bit_defers_last_response_missing(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers timer4 = protocol._timers[source_ip] diff --git a/tests/test_listener.py b/tests/test_listener.py index dff01d78b..bd8022736 100644 --- a/tests/test_listener.py +++ b/tests/test_listener.py @@ -10,7 +10,14 @@ from unittest.mock import MagicMock, patch import zeroconf as r -from zeroconf import Zeroconf, _engine, _listener, const, current_time_millis +from zeroconf import ( + ServiceInfo, + Zeroconf, + _engine, + _listener, + const, + current_time_millis, +) from zeroconf._protocol import outgoing from zeroconf._protocol.incoming import DNSIncoming @@ -125,6 +132,9 @@ def test_guard_against_duplicate_packets(): These packets can quickly overwhelm the system. """ zc = Zeroconf(interfaces=['127.0.0.1']) + zc.registry.async_add( + ServiceInfo("_http._tcp.local.", "Test._http._tcp.local.", server="Test._http._tcp.local.", port=4) + ) zc.question_history = QuestionHistoryWithoutSuppression() class SubListener(_listener.AsyncListener):