Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import sys
import threading
from types import TracebackType
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union
from typing import Awaitable, Dict, List, Optional, Set, Tuple, Type, Union

from ._cache import DNSCache
from ._dns import DNSQuestion, DNSQuestionType
Expand All @@ -49,11 +49,13 @@
from ._transport import _WrappedTransport
from ._updates import RecordUpdateListener
from ._utils.asyncio import (
_resolve_all_futures_to_none,
await_awaitable,
get_running_loop,
run_coro_with_timeout,
shutdown_loop,
wait_event_or_timeout,
wait_for_future_set_or_timeout,
)
from ._utils.name import service_type_name
from ._utils.net import (
Expand Down Expand Up @@ -188,7 +190,7 @@ def __init__(
self.query_handler = QueryHandler(self.registry, self.cache, self.question_history)
self.record_manager = RecordManager(self)

self.notify_event: Optional[asyncio.Event] = None
self._notify_futures: Set[asyncio.Future] = set()
self.loop: Optional[asyncio.AbstractEventLoop] = None
self._loop_thread: Optional[threading.Thread] = None

Expand All @@ -206,7 +208,6 @@ def start(self) -> None:
"""Start Zeroconf."""
self.loop = get_running_loop()
if self.loop:
self.notify_event = asyncio.Event()
self.engine.setup(self.loop, None)
return
self._start_thread()
Expand All @@ -218,7 +219,6 @@ def _start_thread(self) -> None:
def _run_loop() -> None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.notify_event = asyncio.Event()
self.engine.setup(self.loop, loop_thread_ready)
self.loop.run_forever()

Expand All @@ -245,8 +245,9 @@ def listeners(self) -> List[RecordUpdateListener]:

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
assert self.notify_event is not None
await wait_event_or_timeout(self.notify_event, timeout=millis_to_seconds(timeout))
loop = self.loop
assert loop is not None
await wait_for_future_set_or_timeout(loop, self._notify_futures, timeout)

def notify_all(self) -> None:
"""Notifies all waiting threads and notify listeners."""
Expand All @@ -255,9 +256,9 @@ def notify_all(self) -> None:

def async_notify_all(self) -> None:
"""Schedule an async_notify_all."""
assert self.notify_event is not None
self.notify_event.set()
self.notify_event.clear()
notify_futures = self._notify_futures
if notify_futures:
_resolve_all_futures_to_none(notify_futures)

def get_service_info(
self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None
Expand Down
36 changes: 14 additions & 22 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.asyncio import get_running_loop, run_coro_with_timeout
from .._utils.asyncio import (
_resolve_all_futures_to_none,
get_running_loop,
run_coro_with_timeout,
wait_for_future_set_or_timeout,
)
from .._utils.name import service_type_name
from .._utils.net import IPVersion, _encode_address
from .._utils.time import current_time_millis, millis_to_seconds
from .._utils.time import current_time_millis
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
Expand Down Expand Up @@ -89,12 +94,6 @@ def instance_name_from_service_info(info: "ServiceInfo", strict: bool = True) ->
_cached_ip_addresses = lru_cache(maxsize=256)(ip_address)


def _set_future_none_if_not_done(fut: asyncio.Future) -> None:
"""Set a future to None if it is not done."""
if not fut.done(): # pragma: no branch
fut.set_result(None)


class ServiceInfo(RecordUpdateListener):
"""Service information.

Expand Down Expand Up @@ -180,7 +179,7 @@ def __init__(
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._new_records_futures: List[asyncio.Future] = []
self._new_records_futures: Set[asyncio.Future] = set()

@property
def name(self) -> str:
Expand Down Expand Up @@ -242,14 +241,9 @@ def properties(self) -> Dict[Union[str, bytes], Optional[Union[str, bytes]]]:

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
loop = asyncio.get_running_loop()
future = loop.create_future()
self._new_records_futures.append(future)
handle = loop.call_later(millis_to_seconds(timeout), _set_future_none_if_not_done, future)
try:
await future
finally:
handle.cancel()
loop = get_running_loop()
assert loop is not None
await wait_for_future_set_or_timeout(loop, self._new_records_futures, timeout)

def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""List addresses matching IP version.
Expand Down Expand Up @@ -441,11 +435,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU

This method will be run in the event loop.
"""
if self._process_records_threadsafe(zc, now, records) and self._new_records_futures:
for future in self._new_records_futures:
if not future.done():
future.set_result(None)
self._new_records_futures.clear()
new_records_futures = self._new_records_futures
if self._process_records_threadsafe(zc, now, records) and new_records_futures:
_resolve_all_futures_to_none(new_records_futures)

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.
Expand Down
27 changes: 27 additions & 0 deletions src/zeroconf/_utils/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@
_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT


def _set_future_none_if_not_done(fut: asyncio.Future) -> None:
"""Set a future to None if it is not done."""
if not fut.done(): # pragma: no branch
fut.set_result(None)


def _resolve_all_futures_to_none(futures: Set[asyncio.Future]) -> None:
"""Resolve all futures to None."""
for fut in futures:
_set_future_none_if_not_done(fut)
futures.clear()


async def wait_for_future_set_or_timeout(
loop: asyncio.AbstractEventLoop, future_set: Set[asyncio.Future], timeout: float
) -> None:
"""Wait for a future or timeout (in milliseconds)."""
future = loop.create_future()
future_set.add(future)
handle = loop.call_later(millis_to_seconds(timeout), _set_future_none_if_not_done, future)
try:
await future
finally:
handle.cancel()
future_set.discard(future)


async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None:
"""Wait for an event or timeout."""
with contextlib.suppress(asyncio.TimeoutError):
Expand Down