Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
USA
"""

import asyncio
import ipaddress
import random
from functools import lru_cache
Expand All @@ -37,10 +38,14 @@
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.asyncio import get_running_loop, run_coro_with_timeout
from .._utils.asyncio import (
get_running_loop,
run_coro_with_timeout,
wait_event_or_timeout,
)
from .._utils.name import service_type_name
from .._utils.net import IPVersion, _encode_address
from .._utils.time import current_time_millis
from .._utils.time import current_time_millis, millis_to_seconds
from ..const import (
_CLASS_IN,
_CLASS_UNIQUE,
Expand Down Expand Up @@ -166,6 +171,7 @@ def __init__(
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._notify_event: Optional[asyncio.Event] = None

@property
def name(self) -> str:
Expand Down Expand Up @@ -221,6 +227,12 @@ def properties(self) -> Dict:
"""
return self._properties

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
if self._notify_event is None:
self._notify_event = asyncio.Event()
await wait_event_or_timeout(self._notify_event, timeout=millis_to_seconds(timeout))

def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""List addresses matching IP version.

Expand Down Expand Up @@ -384,7 +396,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU

This method will be run in the event loop.
"""
self._process_records_threadsafe(zc, now, records)
if self._process_records_threadsafe(zc, now, records) and self._notify_event:
self._notify_event.set()
self._notify_event.clear()

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.
Expand Down Expand Up @@ -605,7 +619,7 @@ async def async_request(
delay *= 2
next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL)

await zc.async_wait(min(next_, last) - now)
await self.async_wait(min(next_, last) - now)
now = current_time_millis()
finally:
zc.async_remove_listener(self)
Expand Down