diff --git a/tests/test_services.py b/tests/test_services.py index f1499eae..2730811e 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -44,7 +44,7 @@ def test_integration_with_listener_class(self): sub_service_updated = Event() duplicate_service_added = Event() - subtype_name = "My special Subtype" + subtype_name = "_printer" type_ = "_http._tcp.local." subtype = subtype_name + "._sub." + type_ name = "UPPERxxxyyyæøå" diff --git a/tests/utils/test_name.py b/tests/utils/test_name.py index e9b781ad..07feccb7 100644 --- a/tests/utils/test_name.py +++ b/tests/utils/test_name.py @@ -23,3 +23,25 @@ def test_service_type_name_overlong_full_name(): nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.") with pytest.raises(BadTypeInNameException): nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.", strict=False) + + +def test_possible_types(): + """Test possible types from name.""" + assert nameutils.possible_types('.') == set() + assert nameutils.possible_types('local.') == set() + assert nameutils.possible_types('_tcp.local.') == set() + assert nameutils.possible_types('_test-srvc-type._tcp.local.') == {'_test-srvc-type._tcp.local.'} + assert nameutils.possible_types('_any._tcp.local.') == {'_any._tcp.local.'} + assert nameutils.possible_types('.._x._tcp.local.') == {'_x._tcp.local.'} + assert nameutils.possible_types('x.y._http._tcp.local.') == {'_http._tcp.local.'} + assert nameutils.possible_types('1.2.3._mqtt._tcp.local.') == {'_mqtt._tcp.local.'} + assert nameutils.possible_types('x.sub._http._tcp.local.') == {'_http._tcp.local.'} + assert nameutils.possible_types('6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.') == { + '_http._tcp.local.', + '_zget._http._tcp.local.', + } + assert nameutils.possible_types('my._printer._sub._http._tcp.local.') == { + '_http._tcp.local.', + '_sub._http._tcp.local.', + '_printer._sub._http._tcp.local.', + } diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 12c19a1d..bbe5a056 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -38,7 +38,7 @@ SignalRegistrationInterface, ) from .._updates import RecordUpdate, RecordUpdateListener -from .._utils.name import service_type_name +from .._utils.name import possible_types, service_type_name from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _BROWSER_BACKOFF_LIMIT, @@ -326,7 +326,7 @@ def service_state_changed(self) -> SignalRegistrationInterface: def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]: """Return the type and name for records matching the types we are browsing.""" - return [(type_, name) for type_ in self.types for name in names if name.endswith(f".{type_}")] + return [(type_, name) for name in names for type_ in self.types.intersection(possible_types(name))] def _enqueue_callback( self, @@ -352,16 +352,11 @@ def _async_process_record_update( ) -> None: """Process a single record update from a batch of updates.""" if isinstance(record, DNSPointer): - name = record.name - alias = record.alias - matches = self._names_matching_types((alias,)) - if name in self.types: - matches.append((name, alias)) - for type_, name in matches: + for type_ in self.types.intersection(possible_types(record.name)): if old_record is None: - self._enqueue_callback(ServiceStateChange.Added, type_, name) + self._enqueue_callback(ServiceStateChange.Added, type_, record.alias) elif record.is_expired(now): - self._enqueue_callback(ServiceStateChange.Removed, type_, name) + self._enqueue_callback(ServiceStateChange.Removed, type_, record.alias) else: self.reschedule_type(type_, now, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)) return diff --git a/zeroconf/_utils/name.py b/zeroconf/_utils/name.py index f0c34e5d..367dfb18 100644 --- a/zeroconf/_utils/name.py +++ b/zeroconf/_utils/name.py @@ -20,6 +20,8 @@ USA """ +from typing import Set + from .._exceptions import BadTypeInNameException from ..const import ( _HAS_ASCII_CONTROL_CHARS, @@ -155,3 +157,16 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: dis ) return service_name + trailer + + +def possible_types(name: str) -> Set[str]: + """Build a set of all possible types from a fully qualified name.""" + labels = name.split('.') + label_count = len(labels) + types = set() + for count in range(label_count): + parts = labels[label_count - count - 4 :] + if not parts[0].startswith('_'): + break + types.add('.'.join(parts)) + return types