From 15daf6b88aaf2bae25af5a022830c12207a05a87 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 23:15:13 -0600 Subject: [PATCH 1/3] feat: speed up instances only used to lookup answers If there is nothing in the ServiceRegistry we would still process incoming questions. If we have no answers to give, we can skip the handling queries completely as it significantly reduces cpu time. --- src/zeroconf/_listener.pxd | 11 +++++++++++ src/zeroconf/_listener.py | 12 +++++++++--- src/zeroconf/_services/registry.pxd | 5 +++++ src/zeroconf/_services/registry.py | 10 +++++++--- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/zeroconf/_listener.pxd b/src/zeroconf/_listener.pxd index 3b1d62313..ec877c78b 100644 --- a/src/zeroconf/_listener.pxd +++ b/src/zeroconf/_listener.pxd @@ -3,6 +3,7 @@ import cython from ._handlers.record_manager cimport RecordManager from ._protocol.incoming cimport DNSIncoming +from ._services.registry cimport ServiceRegistry from ._utils.time cimport current_time_millis, millis_to_seconds @@ -18,6 +19,7 @@ cdef cython.uint _DUPLICATE_PACKET_SUPPRESSION_INTERVAL cdef class AsyncListener: cdef public object zc + cdef ServiceRegistry _registry cdef RecordManager _record_manager cdef public cython.bytes data cdef public cython.float last_time @@ -34,3 +36,12 @@ cdef class AsyncListener: cpdef _process_datagram_at_time(self, bint debug, cython.uint data_len, cython.float now, bytes data, cython.tuple addrs) cdef _cancel_any_timers_for_addr(self, object addr) + + cpdef handle_query_or_defer( + self, + DNSIncoming msg, + object addr, + object port, + object transport, + tuple v6_flow_scope + ) diff --git a/src/zeroconf/_listener.py b/src/zeroconf/_listener.py index c27d1b610..07d059eb0 100644 --- a/src/zeroconf/_listener.py +++ b/src/zeroconf/_listener.py @@ -57,6 +57,7 @@ class AsyncListener: __slots__ = ( 'zc', + '_registry', '_record_manager', 'data', 'last_time', @@ -69,6 +70,7 @@ class AsyncListener: def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc + self._registry = zc.registry self._record_manager = zc.record_manager self.data: Optional[bytes] = None self.last_time: float = 0 @@ -171,6 +173,10 @@ def _process_datagram_at_time( self._record_manager.async_updates_from_response(msg) return + if not self._registry.has_entries: + # If the registry is empty, we have no answers to give. + return + if TYPE_CHECKING: assert self.transport is not None self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) @@ -178,10 +184,10 @@ def _process_datagram_at_time( def handle_query_or_defer( self, msg: DNSIncoming, - addr: str, - port: int, + addr: _str, + port: _int, transport: _WrappedTransport, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + v6_flow_scope: Union[Tuple[()], Tuple[int, int]], ) -> None: """Deal with incoming query packets. Provides a response if possible.""" diff --git a/src/zeroconf/_services/registry.pxd b/src/zeroconf/_services/registry.pxd index 1d0562c3b..6f9017db7 100644 --- a/src/zeroconf/_services/registry.pxd +++ b/src/zeroconf/_services/registry.pxd @@ -9,6 +9,7 @@ cdef class ServiceRegistry: cdef cython.dict _services cdef public cython.dict types cdef public cython.dict servers + cdef public bint has_entries @cython.locals( record_list=cython.list, @@ -17,6 +18,10 @@ cdef class ServiceRegistry: cdef _add(self, ServiceInfo info) + @cython.locals( + info=ServiceInfo, + old_service_info=ServiceInfo + ) cdef _remove(self, cython.list infos) cpdef ServiceInfo async_get_info_name(self, str name) diff --git a/src/zeroconf/_services/registry.py b/src/zeroconf/_services/registry.py index e9dc4a62b..261e8e9cd 100644 --- a/src/zeroconf/_services/registry.py +++ b/src/zeroconf/_services/registry.py @@ -35,7 +35,7 @@ class ServiceRegistry: the event loop as it is not thread safe. """ - __slots__ = ("_services", "types", "servers") + __slots__ = ("_services", "types", "servers", "has_entries") def __init__( self, @@ -44,6 +44,7 @@ def __init__( self._services: Dict[str, ServiceInfo] = {} self.types: Dict[str, List] = {} self.servers: Dict[str, List] = {} + self.has_entries: bool = False def async_add(self, info: ServiceInfo) -> None: """Add a new service to the registry.""" @@ -95,14 +96,17 @@ def _add(self, info: ServiceInfo) -> None: self._services[info.key] = info self.types.setdefault(info.type.lower(), []).append(info.key) self.servers.setdefault(info.server_key, []).append(info.key) + self.has_entries = True def _remove(self, infos: List[ServiceInfo]) -> None: """Remove a services under the lock.""" for info in infos: - if info.key not in self._services: + old_service_info = self._services.get(info.key) + if old_service_info is None: continue - old_service_info = self._services[info.key] assert old_service_info.server_key is not None self.types[old_service_info.type.lower()].remove(info.key) self.servers[old_service_info.server_key].remove(info.key) del self._services[info.key] + + self.has_entries = bool(self._services) From 7c0529333d138fcd170104e2f951064f17c4fb29 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 23:25:50 -0600 Subject: [PATCH 2/3] fix: tests --- tests/test_core.py | 35 +++++++++++++++++++++++------------ tests/test_listener.py | 12 +++++++++++- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 4bce6db97..415833df9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,7 +12,7 @@ import time import unittest import unittest.mock -from typing import cast +from typing import Tuple, Union, cast from unittest.mock import patch if sys.version_info[:3][1] < 8: @@ -26,6 +26,8 @@ import zeroconf as r from zeroconf import NotRunningException, Zeroconf, const, current_time_millis +from zeroconf._listener import AsyncListener, _WrappedTransport +from zeroconf._protocol.incoming import DNSIncoming from zeroconf.asyncio import AsyncZeroconf from . import _clear_cache, _inject_response, _wait_for_start, has_working_ipv6 @@ -45,10 +47,19 @@ def teardown_module(): log.setLevel(original_logging_level) -def threadsafe_query(zc, protocol, *args): +def threadsafe_query( + zc: 'Zeroconf', + protocol: 'AsyncListener', + msg: DNSIncoming, + addr: str, + port: int, + transport: _WrappedTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]], +) -> None: async def make_query(): - protocol.handle_query_or_defer(*args) + protocol.handle_query_or_defer(msg, addr, port, transport, v6_flow_scope) + assert zc.loop is not None asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() @@ -476,28 +487,28 @@ def test_tc_bit_defers(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert source_ip not in protocol._deferred assert source_ip not in protocol._timers @@ -555,20 +566,20 @@ def test_tc_bit_defers_last_response_missing(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred timer1 = protocol._timers[source_ip] next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred timer2 = protocol._timers[source_ip] assert timer1.cancelled() assert timer2 != timer1 # Send the same packet again to similar multi interfaces - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers timer3 = protocol._timers[source_ip] @@ -577,7 +588,7 @@ def test_tc_bit_defers_last_response_missing(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ()) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers timer4 = protocol._timers[source_ip] diff --git a/tests/test_listener.py b/tests/test_listener.py index dff01d78b..bd8022736 100644 --- a/tests/test_listener.py +++ b/tests/test_listener.py @@ -10,7 +10,14 @@ from unittest.mock import MagicMock, patch import zeroconf as r -from zeroconf import Zeroconf, _engine, _listener, const, current_time_millis +from zeroconf import ( + ServiceInfo, + Zeroconf, + _engine, + _listener, + const, + current_time_millis, +) from zeroconf._protocol import outgoing from zeroconf._protocol.incoming import DNSIncoming @@ -125,6 +132,9 @@ def test_guard_against_duplicate_packets(): These packets can quickly overwhelm the system. """ zc = Zeroconf(interfaces=['127.0.0.1']) + zc.registry.async_add( + ServiceInfo("_http._tcp.local.", "Test._http._tcp.local.", server="Test._http._tcp.local.", port=4) + ) zc.question_history = QuestionHistoryWithoutSuppression() class SubListener(_listener.AsyncListener): From 36a83dce9fbad50e19de579ffdbe40b695df728b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 11 Nov 2023 23:28:45 -0600 Subject: [PATCH 3/3] fix: tests --- tests/test_core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 415833df9..de4b2ef5b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -13,11 +13,9 @@ import unittest import unittest.mock from typing import Tuple, Union, cast -from unittest.mock import patch +from unittest.mock import Mock, patch if sys.version_info[:3][1] < 8: - from unittest.mock import Mock - AsyncMock = Mock else: from unittest.mock import AsyncMock