Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,7 @@
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.asyncio import (
get_running_loop,
run_coro_with_timeout,
wait_event_or_timeout,
)
from .._utils.asyncio import get_running_loop, run_coro_with_timeout
from .._utils.name import service_type_name
from .._utils.net import IPVersion, _encode_address
from .._utils.time import current_time_millis, millis_to_seconds
Expand Down Expand Up @@ -131,6 +127,7 @@ class ServiceInfo(RecordUpdateListener):
"host_ttl",
"other_ttl",
"interface_index",
"_new_records_futures",
)

def __init__(
Expand Down Expand Up @@ -177,7 +174,7 @@ def __init__(
self.host_ttl = host_ttl
self.other_ttl = other_ttl
self.interface_index = interface_index
self._notify_event: Optional[asyncio.Event] = None
self._new_records_futures: List[asyncio.Future] = []

@property
def name(self) -> str:
Expand Down Expand Up @@ -235,9 +232,14 @@ def properties(self) -> Dict:

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
if self._notify_event is None:
self._notify_event = asyncio.Event()
await wait_event_or_timeout(self._notify_event, timeout=millis_to_seconds(timeout))
loop = asyncio.get_running_loop()
future = loop.create_future()
self._new_records_futures.append(future)
handle = loop.call_later(millis_to_seconds(timeout), future.set_result, None)
try:
await future
finally:
handle.cancel()

def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""List addresses matching IP version.
Expand Down Expand Up @@ -409,9 +411,11 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU

This method will be run in the event loop.
"""
if self._process_records_threadsafe(zc, now, records) and self._notify_event:
self._notify_event.set()
self._notify_event.clear()
if self._process_records_threadsafe(zc, now, records) and self._new_records_futures:
for future in self._new_records_futures:
if not future.done():
future.set_result(None)
self._new_records_futures.clear()

def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
"""Thread safe record updating.
Expand Down Expand Up @@ -591,12 +595,13 @@ def set_server_if_missing(self) -> None:
self.server = self.name
self.server_key = self.server.lower()

def load_from_cache(self, zc: 'Zeroconf') -> bool:
def load_from_cache(self, zc: 'Zeroconf', now: Optional[float] = None) -> bool:
"""Populate the service info from the cache.

This method is designed to be threadsafe.
"""
now = current_time_millis()
if not now:
now = current_time_millis()
original_server_key = self.server_key
cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN)
if cached_srv_record:
Expand Down Expand Up @@ -664,11 +669,13 @@ async def async_request(
"""
if not zc.started:
await zc.async_wait_for_start()
if self.load_from_cache(zc):

now = current_time_millis()

if self.load_from_cache(zc, now):
return True

first_request = True
now = current_time_millis()
delay = _LISTENER_TIME
next_ = now
last = now + timeout
Expand All @@ -683,7 +690,7 @@ async def async_request(
)
first_request = False
if not out.questions:
return self.load_from_cache(zc)
return self.load_from_cache(zc, now)
zc.async_send(out, addr, port)
next_ = now + delay
delay *= 2
Expand Down