|
20 | 20 | USA |
21 | 21 | """ |
22 | 22 |
|
| 23 | +import asyncio |
23 | 24 | import ipaddress |
24 | 25 | import random |
25 | 26 | from functools import lru_cache |
|
37 | 38 | from .._logger import log |
38 | 39 | from .._protocol.outgoing import DNSOutgoing |
39 | 40 | from .._updates import RecordUpdate, RecordUpdateListener |
40 | | -from .._utils.asyncio import get_running_loop, run_coro_with_timeout |
| 41 | +from .._utils.asyncio import ( |
| 42 | + get_running_loop, |
| 43 | + run_coro_with_timeout, |
| 44 | + wait_event_or_timeout, |
| 45 | +) |
41 | 46 | from .._utils.name import service_type_name |
42 | 47 | from .._utils.net import IPVersion, _encode_address |
43 | | -from .._utils.time import current_time_millis |
| 48 | +from .._utils.time import current_time_millis, millis_to_seconds |
44 | 49 | from ..const import ( |
45 | 50 | _CLASS_IN, |
46 | 51 | _CLASS_UNIQUE, |
@@ -166,6 +171,7 @@ def __init__( |
166 | 171 | self.host_ttl = host_ttl |
167 | 172 | self.other_ttl = other_ttl |
168 | 173 | self.interface_index = interface_index |
| 174 | + self._notify_event: Optional[asyncio.Event] = None |
169 | 175 |
|
170 | 176 | @property |
171 | 177 | def name(self) -> str: |
@@ -221,6 +227,12 @@ def properties(self) -> Dict: |
221 | 227 | """ |
222 | 228 | return self._properties |
223 | 229 |
|
| 230 | + async def async_wait(self, timeout: float) -> None: |
| 231 | + """Calling task waits for a given number of milliseconds or until notified.""" |
| 232 | + if self._notify_event is None: |
| 233 | + self._notify_event = asyncio.Event() |
| 234 | + await wait_event_or_timeout(self._notify_event, timeout=millis_to_seconds(timeout)) |
| 235 | + |
224 | 236 | def addresses_by_version(self, version: IPVersion) -> List[bytes]: |
225 | 237 | """List addresses matching IP version. |
226 | 238 |
|
@@ -384,7 +396,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU |
384 | 396 |
|
385 | 397 | This method will be run in the event loop. |
386 | 398 | """ |
387 | | - self._process_records_threadsafe(zc, now, records) |
| 399 | + if self._process_records_threadsafe(zc, now, records) and self._notify_event: |
| 400 | + self._notify_event.set() |
| 401 | + self._notify_event.clear() |
388 | 402 |
|
389 | 403 | def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool: |
390 | 404 | """Thread safe record updating. |
@@ -605,7 +619,7 @@ async def async_request( |
605 | 619 | delay *= 2 |
606 | 620 | next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL) |
607 | 621 |
|
608 | | - await zc.async_wait(min(next_, last) - now) |
| 622 | + await self.async_wait(min(next_, last) - now) |
609 | 623 | now = current_time_millis() |
610 | 624 | finally: |
611 | 625 | zc.async_remove_listener(self) |
|
0 commit comments