Skip to content

Commit 0701b8a

Browse files
authored
feat: speed up instances only used to lookup answers (#1307)
1 parent 9ca9a57 commit 0701b8a

6 files changed

Lines changed: 67 additions & 22 deletions

File tree

src/zeroconf/_listener.pxd

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import cython
33

44
from ._handlers.record_manager cimport RecordManager
55
from ._protocol.incoming cimport DNSIncoming
6+
from ._services.registry cimport ServiceRegistry
67
from ._utils.time cimport current_time_millis, millis_to_seconds
78

89

@@ -18,6 +19,7 @@ cdef cython.uint _DUPLICATE_PACKET_SUPPRESSION_INTERVAL
1819
cdef class AsyncListener:
1920

2021
cdef public object zc
22+
cdef ServiceRegistry _registry
2123
cdef RecordManager _record_manager
2224
cdef public cython.bytes data
2325
cdef public cython.float last_time
@@ -34,3 +36,12 @@ cdef class AsyncListener:
3436
cpdef _process_datagram_at_time(self, bint debug, cython.uint data_len, cython.float now, bytes data, cython.tuple addrs)
3537

3638
cdef _cancel_any_timers_for_addr(self, object addr)
39+
40+
cpdef handle_query_or_defer(
41+
self,
42+
DNSIncoming msg,
43+
object addr,
44+
object port,
45+
object transport,
46+
tuple v6_flow_scope
47+
)

src/zeroconf/_listener.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class AsyncListener:
5757

5858
__slots__ = (
5959
'zc',
60+
'_registry',
6061
'_record_manager',
6162
'data',
6263
'last_time',
@@ -69,6 +70,7 @@ class AsyncListener:
6970

7071
def __init__(self, zc: 'Zeroconf') -> None:
7172
self.zc = zc
73+
self._registry = zc.registry
7274
self._record_manager = zc.record_manager
7375
self.data: Optional[bytes] = None
7476
self.last_time: float = 0
@@ -171,17 +173,21 @@ def _process_datagram_at_time(
171173
self._record_manager.async_updates_from_response(msg)
172174
return
173175

176+
if not self._registry.has_entries:
177+
# If the registry is empty, we have no answers to give.
178+
return
179+
174180
if TYPE_CHECKING:
175181
assert self.transport is not None
176182
self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope)
177183

178184
def handle_query_or_defer(
179185
self,
180186
msg: DNSIncoming,
181-
addr: str,
182-
port: int,
187+
addr: _str,
188+
port: _int,
183189
transport: _WrappedTransport,
184-
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
190+
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
185191
) -> None:
186192
"""Deal with incoming query packets. Provides a response if
187193
possible."""

src/zeroconf/_services/registry.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ cdef class ServiceRegistry:
99
cdef cython.dict _services
1010
cdef public cython.dict types
1111
cdef public cython.dict servers
12+
cdef public bint has_entries
1213

1314
@cython.locals(
1415
record_list=cython.list,
@@ -17,6 +18,10 @@ cdef class ServiceRegistry:
1718

1819
cdef _add(self, ServiceInfo info)
1920

21+
@cython.locals(
22+
info=ServiceInfo,
23+
old_service_info=ServiceInfo
24+
)
2025
cdef _remove(self, cython.list infos)
2126

2227
cpdef ServiceInfo async_get_info_name(self, str name)

src/zeroconf/_services/registry.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ServiceRegistry:
3535
the event loop as it is not thread safe.
3636
"""
3737

38-
__slots__ = ("_services", "types", "servers")
38+
__slots__ = ("_services", "types", "servers", "has_entries")
3939

4040
def __init__(
4141
self,
@@ -44,6 +44,7 @@ def __init__(
4444
self._services: Dict[str, ServiceInfo] = {}
4545
self.types: Dict[str, List] = {}
4646
self.servers: Dict[str, List] = {}
47+
self.has_entries: bool = False
4748

4849
def async_add(self, info: ServiceInfo) -> None:
4950
"""Add a new service to the registry."""
@@ -95,14 +96,17 @@ def _add(self, info: ServiceInfo) -> None:
9596
self._services[info.key] = info
9697
self.types.setdefault(info.type.lower(), []).append(info.key)
9798
self.servers.setdefault(info.server_key, []).append(info.key)
99+
self.has_entries = True
98100

99101
def _remove(self, infos: List[ServiceInfo]) -> None:
100102
"""Remove a services under the lock."""
101103
for info in infos:
102-
if info.key not in self._services:
104+
old_service_info = self._services.get(info.key)
105+
if old_service_info is None:
103106
continue
104-
old_service_info = self._services[info.key]
105107
assert old_service_info.server_key is not None
106108
self.types[old_service_info.type.lower()].remove(info.key)
107109
self.servers[old_service_info.server_key].remove(info.key)
108110
del self._services[info.key]
111+
112+
self.has_entries = bool(self._services)

tests/test_core.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@
1212
import time
1313
import unittest
1414
import unittest.mock
15-
from typing import cast
16-
from unittest.mock import patch
15+
from typing import Tuple, Union, cast
16+
from unittest.mock import Mock, patch
1717

1818
if sys.version_info[:3][1] < 8:
19-
from unittest.mock import Mock
20-
2119
AsyncMock = Mock
2220
else:
2321
from unittest.mock import AsyncMock
@@ -26,6 +24,8 @@
2624

2725
import zeroconf as r
2826
from zeroconf import NotRunningException, Zeroconf, const, current_time_millis
27+
from zeroconf._listener import AsyncListener, _WrappedTransport
28+
from zeroconf._protocol.incoming import DNSIncoming
2929
from zeroconf.asyncio import AsyncZeroconf
3030

3131
from . import _clear_cache, _inject_response, _wait_for_start, has_working_ipv6
@@ -45,10 +45,19 @@ def teardown_module():
4545
log.setLevel(original_logging_level)
4646

4747

48-
def threadsafe_query(zc, protocol, *args):
48+
def threadsafe_query(
49+
zc: 'Zeroconf',
50+
protocol: 'AsyncListener',
51+
msg: DNSIncoming,
52+
addr: str,
53+
port: int,
54+
transport: _WrappedTransport,
55+
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
56+
) -> None:
4957
async def make_query():
50-
protocol.handle_query_or_defer(*args)
58+
protocol.handle_query_or_defer(msg, addr, port, transport, v6_flow_scope)
5159

60+
assert zc.loop is not None
5261
asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result()
5362

5463

@@ -476,28 +485,28 @@ def test_tc_bit_defers():
476485

477486
next_packet = r.DNSIncoming(packets.pop(0))
478487
expected_deferred.append(next_packet)
479-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
488+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
480489
assert protocol._deferred[source_ip] == expected_deferred
481490
assert source_ip in protocol._timers
482491

483492
next_packet = r.DNSIncoming(packets.pop(0))
484493
expected_deferred.append(next_packet)
485-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
494+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
486495
assert protocol._deferred[source_ip] == expected_deferred
487496
assert source_ip in protocol._timers
488-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
497+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
489498
assert protocol._deferred[source_ip] == expected_deferred
490499
assert source_ip in protocol._timers
491500

492501
next_packet = r.DNSIncoming(packets.pop(0))
493502
expected_deferred.append(next_packet)
494-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
503+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
495504
assert protocol._deferred[source_ip] == expected_deferred
496505
assert source_ip in protocol._timers
497506

498507
next_packet = r.DNSIncoming(packets.pop(0))
499508
expected_deferred.append(next_packet)
500-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
509+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
501510
assert source_ip not in protocol._deferred
502511
assert source_ip not in protocol._timers
503512

@@ -555,20 +564,20 @@ def test_tc_bit_defers_last_response_missing():
555564

556565
next_packet = r.DNSIncoming(packets.pop(0))
557566
expected_deferred.append(next_packet)
558-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
567+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
559568
assert protocol._deferred[source_ip] == expected_deferred
560569
timer1 = protocol._timers[source_ip]
561570

562571
next_packet = r.DNSIncoming(packets.pop(0))
563572
expected_deferred.append(next_packet)
564-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
573+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
565574
assert protocol._deferred[source_ip] == expected_deferred
566575
timer2 = protocol._timers[source_ip]
567576
assert timer1.cancelled()
568577
assert timer2 != timer1
569578

570579
# Send the same packet again to similar multi interfaces
571-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
580+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
572581
assert protocol._deferred[source_ip] == expected_deferred
573582
assert source_ip in protocol._timers
574583
timer3 = protocol._timers[source_ip]
@@ -577,7 +586,7 @@ def test_tc_bit_defers_last_response_missing():
577586

578587
next_packet = r.DNSIncoming(packets.pop(0))
579588
expected_deferred.append(next_packet)
580-
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
589+
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
581590
assert protocol._deferred[source_ip] == expected_deferred
582591
assert source_ip in protocol._timers
583592
timer4 = protocol._timers[source_ip]

tests/test_listener.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@
1010
from unittest.mock import MagicMock, patch
1111

1212
import zeroconf as r
13-
from zeroconf import Zeroconf, _engine, _listener, const, current_time_millis
13+
from zeroconf import (
14+
ServiceInfo,
15+
Zeroconf,
16+
_engine,
17+
_listener,
18+
const,
19+
current_time_millis,
20+
)
1421
from zeroconf._protocol import outgoing
1522
from zeroconf._protocol.incoming import DNSIncoming
1623

@@ -125,6 +132,9 @@ def test_guard_against_duplicate_packets():
125132
These packets can quickly overwhelm the system.
126133
"""
127134
zc = Zeroconf(interfaces=['127.0.0.1'])
135+
zc.registry.async_add(
136+
ServiceInfo("_http._tcp.local.", "Test._http._tcp.local.", server="Test._http._tcp.local.", port=4)
137+
)
128138
zc.question_history = QuestionHistoryWithoutSuppression()
129139

130140
class SubListener(_listener.AsyncListener):

0 commit comments

Comments
 (0)