Skip to content

Commit ceb92cf

Browse files
authored
feat: refactor notify implementation to reduce overhead of adding and removing listeners (#1224)
1 parent 0e96220 commit ceb92cf

3 files changed

Lines changed: 51 additions & 31 deletions

File tree

src/zeroconf/_core.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import sys
2626
import threading
2727
from types import TracebackType
28-
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union
28+
from typing import Awaitable, Dict, List, Optional, Set, Tuple, Type, Union
2929

3030
from ._cache import DNSCache
3131
from ._dns import DNSQuestion, DNSQuestionType
@@ -49,11 +49,13 @@
4949
from ._transport import _WrappedTransport
5050
from ._updates import RecordUpdateListener
5151
from ._utils.asyncio import (
52+
_resolve_all_futures_to_none,
5253
await_awaitable,
5354
get_running_loop,
5455
run_coro_with_timeout,
5556
shutdown_loop,
5657
wait_event_or_timeout,
58+
wait_for_future_set_or_timeout,
5759
)
5860
from ._utils.name import service_type_name
5961
from ._utils.net import (
@@ -188,7 +190,7 @@ def __init__(
188190
self.query_handler = QueryHandler(self.registry, self.cache, self.question_history)
189191
self.record_manager = RecordManager(self)
190192

191-
self.notify_event: Optional[asyncio.Event] = None
193+
self._notify_futures: Set[asyncio.Future] = set()
192194
self.loop: Optional[asyncio.AbstractEventLoop] = None
193195
self._loop_thread: Optional[threading.Thread] = None
194196

@@ -206,7 +208,6 @@ def start(self) -> None:
206208
"""Start Zeroconf."""
207209
self.loop = get_running_loop()
208210
if self.loop:
209-
self.notify_event = asyncio.Event()
210211
self.engine.setup(self.loop, None)
211212
return
212213
self._start_thread()
@@ -218,7 +219,6 @@ def _start_thread(self) -> None:
218219
def _run_loop() -> None:
219220
self.loop = asyncio.new_event_loop()
220221
asyncio.set_event_loop(self.loop)
221-
self.notify_event = asyncio.Event()
222222
self.engine.setup(self.loop, loop_thread_ready)
223223
self.loop.run_forever()
224224

@@ -245,8 +245,9 @@ def listeners(self) -> List[RecordUpdateListener]:
245245

246246
async def async_wait(self, timeout: float) -> None:
247247
"""Calling task waits for a given number of milliseconds or until notified."""
248-
assert self.notify_event is not None
249-
await wait_event_or_timeout(self.notify_event, timeout=millis_to_seconds(timeout))
248+
loop = self.loop
249+
assert loop is not None
250+
await wait_for_future_set_or_timeout(loop, self._notify_futures, timeout)
250251

251252
def notify_all(self) -> None:
252253
"""Notifies all waiting threads and notify listeners."""
@@ -255,9 +256,9 @@ def notify_all(self) -> None:
255256

256257
def async_notify_all(self) -> None:
257258
"""Schedule an async_notify_all."""
258-
assert self.notify_event is not None
259-
self.notify_event.set()
260-
self.notify_event.clear()
259+
notify_futures = self._notify_futures
260+
if notify_futures:
261+
_resolve_all_futures_to_none(notify_futures)
261262

262263
def get_service_info(
263264
self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None

src/zeroconf/_services/info.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@
3939
from .._logger import log
4040
from .._protocol.outgoing import DNSOutgoing
4141
from .._updates import RecordUpdate, RecordUpdateListener
42-
from .._utils.asyncio import get_running_loop, run_coro_with_timeout
42+
from .._utils.asyncio import (
43+
_resolve_all_futures_to_none,
44+
get_running_loop,
45+
run_coro_with_timeout,
46+
wait_for_future_set_or_timeout,
47+
)
4348
from .._utils.name import service_type_name
4449
from .._utils.net import IPVersion, _encode_address
45-
from .._utils.time import current_time_millis, millis_to_seconds
50+
from .._utils.time import current_time_millis
4651
from ..const import (
4752
_ADDRESS_RECORD_TYPES,
4853
_CLASS_IN,
@@ -89,12 +94,6 @@ def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) ->
8994
_cached_ip_addresses = lru_cache(maxsize=256)(ip_address)
9095

9196

92-
def _set_future_none_if_not_done(fut: asyncio.Future) -> None:
93-
"""Set a future to None if it is not done."""
94-
if not fut.done(): # pragma: no branch
95-
fut.set_result(None)
96-
97-
9897
class ServiceInfo(RecordUpdateListener):
9998
"""Service information.
10099
@@ -180,7 +179,7 @@ def __init__(
180179
self.host_ttl = host_ttl
181180
self.other_ttl = other_ttl
182181
self.interface_index = interface_index
183-
self._new_records_futures: List[asyncio.Future] = []
182+
self._new_records_futures: Set[asyncio.Future] = set()
184183

185184
@property
186185
def name(self) -> str:
@@ -242,14 +241,9 @@ def properties(self) -> Dict[Union[str, bytes], Optional[Union[str, bytes]]]:
242241

243242
async def async_wait(self, timeout: float) -> None:
244243
"""Calling task waits for a given number of milliseconds or until notified."""
245-
loop = asyncio.get_running_loop()
246-
future = loop.create_future()
247-
self._new_records_futures.append(future)
248-
handle = loop.call_later(millis_to_seconds(timeout), _set_future_none_if_not_done, future)
249-
try:
250-
await future
251-
finally:
252-
handle.cancel()
244+
loop = get_running_loop()
245+
assert loop is not None
246+
await wait_for_future_set_or_timeout(loop, self._new_records_futures, timeout)
253247

254248
def addresses_by_version(self, version: IPVersion) -> List[bytes]:
255249
"""List addresses matching IP version.
@@ -441,11 +435,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
441435
442436
This method will be run in the event loop.
443437
"""
444-
if self._process_records_threadsafe(zc, now, records) and self._new_records_futures:
445-
for future in self._new_records_futures:
446-
if not future.done():
447-
future.set_result(None)
448-
self._new_records_futures.clear()
438+
new_records_futures = self._new_records_futures
439+
if self._process_records_threadsafe(zc, now, records) and new_records_futures:
440+
_resolve_all_futures_to_none(new_records_futures)
449441

450442
def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
451443
"""Thread safe record updating.

src/zeroconf/_utils/asyncio.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,33 @@
4141
_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT
4242

4343

44+
def _set_future_none_if_not_done(fut: asyncio.Future) -> None:
45+
"""Set a future to None if it is not done."""
46+
if not fut.done(): # pragma: no branch
47+
fut.set_result(None)
48+
49+
50+
def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None:
51+
"""Resolve all futures to None."""
52+
for fut in futures:
53+
_set_future_none_if_not_done(fut)
54+
futures.clear()
55+
56+
57+
async def wait_for_future_set_or_timeout(
58+
loop: asyncio.AbstractEventLoop, future_set: Set[asyncio.Future], timeout: float
59+
) -> None:
60+
"""Wait for a future or timeout (in milliseconds)."""
61+
future = loop.create_future()
62+
future_set.add(future)
63+
handle = loop.call_later(millis_to_seconds(timeout), _set_future_none_if_not_done, future)
64+
try:
65+
await future
66+
finally:
67+
handle.cancel()
68+
future_set.discard(future)
69+
70+
4471
async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None:
4572
"""Wait for an event or timeout."""
4673
with contextlib.suppress(asyncio.TimeoutError):

0 commit comments

Comments
 (0)