3939from .._logger import log
4040from .._protocol .outgoing import DNSOutgoing
4141from .._updates import RecordUpdate , RecordUpdateListener
42- from .._utils .asyncio import (
43- get_running_loop ,
44- run_coro_with_timeout ,
45- wait_event_or_timeout ,
46- )
42+ from .._utils .asyncio import get_running_loop , run_coro_with_timeout
4743from .._utils .name import service_type_name
4844from .._utils .net import IPVersion , _encode_address
4945from .._utils .time import current_time_millis , millis_to_seconds
@@ -131,6 +127,7 @@ class ServiceInfo(RecordUpdateListener):
131127 "host_ttl" ,
132128 "other_ttl" ,
133129 "interface_index" ,
130+ "_new_records_futures" ,
134131 )
135132
136133 def __init__ (
@@ -177,7 +174,7 @@ def __init__(
177174 self .host_ttl = host_ttl
178175 self .other_ttl = other_ttl
179176 self .interface_index = interface_index
180- self ._notify_event : Optional [asyncio .Event ] = None
177+ self ._new_records_futures : List [asyncio .Future ] = []
181178
182179 @property
183180 def name (self ) -> str :
@@ -235,9 +232,14 @@ def properties(self) -> Dict:
235232
236233 async def async_wait (self , timeout : float ) -> None :
237234 """Calling task waits for a given number of milliseconds or until notified."""
238- if self ._notify_event is None :
239- self ._notify_event = asyncio .Event ()
240- await wait_event_or_timeout (self ._notify_event , timeout = millis_to_seconds (timeout ))
235+ loop = asyncio .get_running_loop ()
236+ future = loop .create_future ()
237+ self ._new_records_futures .append (future )
238+ handle = loop .call_later (millis_to_seconds (timeout ), future .set_result , None )
239+ try :
240+ await future
241+ finally :
242+ handle .cancel ()
241243
242244 def addresses_by_version (self , version : IPVersion ) -> List [bytes ]:
243245 """List addresses matching IP version.
@@ -409,9 +411,11 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
409411
410412 This method will be run in the event loop.
411413 """
412- if self ._process_records_threadsafe (zc , now , records ) and self ._notify_event :
413- self ._notify_event .set ()
414- self ._notify_event .clear ()
414+ if self ._process_records_threadsafe (zc , now , records ) and self ._new_records_futures :
415+ for future in self ._new_records_futures :
416+ if not future .done ():
417+ future .set_result (None )
418+ self ._new_records_futures .clear ()
415419
416420 def _process_records_threadsafe (self , zc : 'Zeroconf' , now : float , records : List [RecordUpdate ]) -> bool :
417421 """Thread safe record updating.
@@ -591,12 +595,13 @@ def set_server_if_missing(self) -> None:
591595 self .server = self .name
592596 self .server_key = self .server .lower ()
593597
594- def load_from_cache (self , zc : 'Zeroconf' ) -> bool :
598+ def load_from_cache (self , zc : 'Zeroconf' , now : Optional [ float ] = None ) -> bool :
595599 """Populate the service info from the cache.
596600
597601 This method is designed to be threadsafe.
598602 """
599- now = current_time_millis ()
603+ if not now :
604+ now = current_time_millis ()
600605 original_server_key = self .server_key
601606 cached_srv_record = zc .cache .get_by_details (self .name , _TYPE_SRV , _CLASS_IN )
602607 if cached_srv_record :
@@ -664,11 +669,13 @@ async def async_request(
664669 """
665670 if not zc .started :
666671 await zc .async_wait_for_start ()
667- if self .load_from_cache (zc ):
672+
673+ now = current_time_millis ()
674+
675+ if self .load_from_cache (zc , now ):
668676 return True
669677
670678 first_request = True
671- now = current_time_millis ()
672679 delay = _LISTENER_TIME
673680 next_ = now
674681 last = now + timeout
@@ -683,7 +690,7 @@ async def async_request(
683690 )
684691 first_request = False
685692 if not out .questions :
686- return self .load_from_cache (zc )
693+ return self .load_from_cache (zc , now )
687694 zc .async_send (out , addr , port )
688695 next_ = now + delay
689696 delay *= 2
0 commit comments