2626import threading
2727import warnings
2828from collections import OrderedDict
29- from typing import Callable , Dict , List , Optional , Set , TYPE_CHECKING , Tuple , Union , cast
29+ from typing import Callable , Dict , Iterable , List , Optional , Set , TYPE_CHECKING , Tuple , Union , cast
3030
3131from .._dns import DNSAddress , DNSPointer , DNSQuestion , DNSQuestionType , DNSRecord
3232from .._logger import log
@@ -324,9 +324,9 @@ def _async_start(self) -> None:
324324 def service_state_changed (self ) -> SignalRegistrationInterface :
325325 return self ._service_state_changed .registration_interface
326326
327- def _record_matching_type (self , record : DNSRecord ) -> Optional [ str ]:
328- """Return the type if the record matches one of the types we are browsing."""
329- return next (( type_ for type_ in self .types if record . name .endswith (type_ )), None )
327+ def _names_matching_types (self , names : Iterable [ str ] ) -> List [ Tuple [ str , str ] ]:
328+ """Return the type and name for records matching the types we are browsing."""
329+ return [( type_ , name ) for type_ in self .types for name in names if name .endswith (f". { type_ } " )]
330330
331331 def _enqueue_callback (
332332 self ,
@@ -352,14 +352,18 @@ def _async_process_record_update(
352352 ) -> None :
353353 """Process a single record update from a batch of updates."""
354354 if isinstance (record , DNSPointer ):
355- if record .name not in self .types :
356- return
357- if old_record is None :
358- self ._enqueue_callback (ServiceStateChange .Added , record .name , record .alias )
359- elif record .is_expired (now ):
360- self ._enqueue_callback (ServiceStateChange .Removed , record .name , record .alias )
361- else :
362- self .reschedule_type (record .name , record .get_expiration_time (_EXPIRE_REFRESH_TIME_PERCENT ))
355+ name = record .name
356+ alias = record .alias
357+ matches = self ._names_matching_types ((alias ,))
358+ if name in self .types :
359+ matches .append ((name , alias ))
360+ for type_ , name in matches :
361+ if old_record is None :
362+ self ._enqueue_callback (ServiceStateChange .Added , type_ , name )
363+ elif record .is_expired (now ):
364+ self ._enqueue_callback (ServiceStateChange .Removed , type_ , name )
365+ else :
366+ self .reschedule_type (type_ , record .get_expiration_time (_EXPIRE_REFRESH_TIME_PERCENT ))
363367 return
364368
365369 # If its expired or already exists in the cache it cannot be updated.
@@ -368,17 +372,14 @@ def _async_process_record_update(
368372
369373 if isinstance (record , DNSAddress ):
370374 # Iterate through the DNSCache and callback any services that use this address
371- for service in self .zc .cache .async_entries_with_server (record .name ):
372- type_ = self ._record_matching_type (service )
373- if type_ :
374- self ._enqueue_callback (ServiceStateChange .Updated , type_ , service .name )
375- break
376-
375+ for type_ , name in self ._names_matching_types (
376+ {service .name for service in self .zc .cache .async_entries_with_server (record .name )}
377+ ):
378+ self ._enqueue_callback (ServiceStateChange .Updated , type_ , name )
377379 return
378380
379- type_ = self ._record_matching_type (record )
380- if type_ :
381- self ._enqueue_callback (ServiceStateChange .Updated , type_ , record .name )
381+ for type_ , name in self ._names_matching_types ((record .name ,)):
382+ self ._enqueue_callback (ServiceStateChange .Updated , type_ , name )
382383
383384 def async_update_records (self , zc : 'Zeroconf' , now : float , records : List [RecordUpdate ]) -> None :
384385 """Callback invoked by Zeroconf when new information arrives.
0 commit comments