Skip to content

Commit a4d619a

Browse files
apworks1bdraco
andauthored
Handle Service types that end with another service type (#1041)
Co-authored-by: J. Nick Koston <nick@koston.org>
1 parent 22ed08c commit a4d619a

4 files changed

Lines changed: 139 additions & 35 deletions

File tree

tests/services/test_browser.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,3 +1017,86 @@ async def test_query_scheduler():
10171017
assert set(query_scheduler.process_ready_types(now + delay * 20)) == set()
10181018

10191019
assert set(query_scheduler.process_ready_types(now + delay * 31)) == {"_http._tcp.local."}
1020+
1021+
1022+
def test_service_browser_matching():
1023+
"""Test that the ServiceBrowser matching does not match partial names."""
1024+
1025+
# instantiate a zeroconf instance
1026+
zc = Zeroconf(interfaces=['127.0.0.1'])
1027+
# start a browser
1028+
type_ = "_http._tcp.local."
1029+
registration_name = "xxxyyy.%s" % type_
1030+
not_match_type_ = "_asustor-looksgood_http._tcp.local."
1031+
not_match_registration_name = "xxxyyy.%s" % not_match_type_
1032+
callbacks = []
1033+
1034+
class MyServiceListener(r.ServiceListener):
1035+
def add_service(self, zc, type_, name) -> None:
1036+
nonlocal callbacks
1037+
if name == registration_name:
1038+
callbacks.append(("add", type_, name))
1039+
1040+
def remove_service(self, zc, type_, name) -> None:
1041+
nonlocal callbacks
1042+
if name == registration_name:
1043+
callbacks.append(("remove", type_, name))
1044+
1045+
def update_service(self, zc, type_, name) -> None:
1046+
nonlocal callbacks
1047+
if name == registration_name:
1048+
callbacks.append(("update", type_, name))
1049+
1050+
listener = MyServiceListener()
1051+
1052+
browser = r.ServiceBrowser(zc, type_, None, listener)
1053+
1054+
desc = {'path': '/~paulsm/'}
1055+
address_parsed = "10.0.1.2"
1056+
address = socket.inet_aton(address_parsed)
1057+
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])
1058+
should_not_match = ServiceInfo(
1059+
not_match_type_, not_match_registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]
1060+
)
1061+
1062+
def mock_incoming_msg(records) -> r.DNSIncoming:
1063+
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
1064+
for record in records:
1065+
generated.add_answer_at_time(record, 0)
1066+
return r.DNSIncoming(generated.packets()[0])
1067+
1068+
_inject_response(
1069+
zc,
1070+
mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
1071+
)
1072+
_inject_response(
1073+
zc,
1074+
mock_incoming_msg(
1075+
[
1076+
should_not_match.dns_pointer(),
1077+
should_not_match.dns_service(),
1078+
should_not_match.dns_text(),
1079+
*should_not_match.dns_addresses(),
1080+
]
1081+
),
1082+
)
1083+
time.sleep(0.2)
1084+
info.port = 400
1085+
_inject_response(
1086+
zc,
1087+
mock_incoming_msg([info.dns_service()]),
1088+
)
1089+
should_not_match.port = 400
1090+
_inject_response(
1091+
zc,
1092+
mock_incoming_msg([should_not_match.dns_service()]),
1093+
)
1094+
time.sleep(0.2)
1095+
1096+
assert callbacks == [
1097+
('add', type_, registration_name),
1098+
('update', type_, registration_name),
1099+
]
1100+
browser.cancel()
1101+
1102+
zc.close()

tests/test_services.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@ def teardown_module():
3838
class ListenerTest(unittest.TestCase):
3939
def test_integration_with_listener_class(self):
4040

41+
sub_service_added = Event()
4142
service_added = Event()
4243
service_removed = Event()
43-
service_updated = Event()
44-
service_updated2 = Event()
44+
sub_service_updated = Event()
45+
duplicate_service_added = Event()
4546

4647
subtype_name = "My special Subtype"
4748
type_ = "_http._tcp.local."
@@ -58,21 +59,32 @@ def remove_service(self, zeroconf, type, name):
5859
service_removed.set()
5960

6061
def update_service(self, zeroconf, type, name):
61-
service_updated2.set()
62+
pass
63+
64+
class DuplicateListener(r.ServiceListener):
65+
def add_service(self, zeroconf, type, name):
66+
duplicate_service_added.set()
67+
68+
def remove_service(self, zeroconf, type, name):
69+
pass
70+
71+
def update_service(self, zeroconf, type, name):
72+
pass
6273

6374
class MySubListener(r.ServiceListener):
6475
def add_service(self, zeroconf, type, name):
76+
sub_service_added.set()
6577
pass
6678

6779
def remove_service(self, zeroconf, type, name):
6880
pass
6981

7082
def update_service(self, zeroconf, type, name):
71-
service_updated.set()
83+
sub_service_updated.set()
7284

7385
listener = MyListener()
7486
zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
75-
zeroconf_browser.add_service_listener(subtype, listener)
87+
zeroconf_browser.add_service_listener(type_, listener)
7688

7789
properties = dict(
7890
prop_none=None,
@@ -107,6 +119,11 @@ def update_service(self, zeroconf, type, name):
107119
# short pause to allow multicast timers to expire
108120
time.sleep(3)
109121

122+
zeroconf_browser.add_service_listener(type_, DuplicateListener())
123+
duplicate_service_added.wait(
124+
1
125+
) # Ensure a listener for the same type calls back right away from cache
126+
110127
# clear the answer cache to force query
111128
_clear_cache(zeroconf_browser)
112129

@@ -160,7 +177,9 @@ def update_service(self, zeroconf, type, name):
160177

161178
# test TXT record update
162179
sublistener = MySubListener()
163-
zeroconf_browser.add_service_listener(registration_name, sublistener)
180+
181+
zeroconf_browser.add_service_listener(subtype, sublistener)
182+
164183
properties['prop_blank'] = b'an updated string'
165184
desc.update(properties)
166185
info_service = ServiceInfo(
@@ -174,8 +193,9 @@ def update_service(self, zeroconf, type, name):
174193
addresses=[socket.inet_aton("10.0.1.2")],
175194
)
176195
zeroconf_registrar.update_service(info_service)
177-
service_updated.wait(1)
178-
assert service_updated.is_set()
196+
197+
sub_service_added.wait(1) # we cleared the cache above
198+
assert sub_service_added.is_set()
179199

180200
info = zeroconf_browser.get_service_info(type_, registration_name)
181201
assert info is not None

zeroconf/_handlers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -515,12 +515,12 @@ def _async_update_matching_records(
515515
This function must be run from the event loop.
516516
"""
517517
now = current_time_millis()
518-
records: List[RecordUpdate] = []
519-
for question in questions:
520-
for record in self.cache.async_entries_with_name(question.name):
521-
if not record.is_expired(now) and question.answered_by(record):
522-
records.append(RecordUpdate(record, None))
523-
518+
records: List[RecordUpdate] = [
519+
RecordUpdate(record, None)
520+
for question in questions
521+
for record in self.cache.async_entries_with_name(question.name)
522+
if not record.is_expired(now) and question.answered_by(record)
523+
]
524524
if not records:
525525
return
526526
listener.async_update_records(self.zc, now, records)

zeroconf/_services/browser.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import threading
2727
import warnings
2828
from collections import OrderedDict
29-
from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
29+
from typing import Callable, Dict, Iterable, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
3030

3131
from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord
3232
from .._logger import log
@@ -324,9 +324,9 @@ def _async_start(self) -> None:
324324
def service_state_changed(self) -> SignalRegistrationInterface:
325325
return self._service_state_changed.registration_interface
326326

327-
def _record_matching_type(self, record: DNSRecord) -> Optional[str]:
328-
"""Return the type if the record matches one of the types we are browsing."""
329-
return next((type_ for type_ in self.types if record.name.endswith(type_)), None)
327+
def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]:
328+
"""Return the type and name for records matching the types we are browsing."""
329+
return [(type_, name) for type_ in self.types for name in names if name.endswith(f".{type_}")]
330330

331331
def _enqueue_callback(
332332
self,
@@ -352,14 +352,18 @@ def _async_process_record_update(
352352
) -> None:
353353
"""Process a single record update from a batch of updates."""
354354
if isinstance(record, DNSPointer):
355-
if record.name not in self.types:
356-
return
357-
if old_record is None:
358-
self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
359-
elif record.is_expired(now):
360-
self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
361-
else:
362-
self.reschedule_type(record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
355+
name = record.name
356+
alias = record.alias
357+
matches = self._names_matching_types((alias,))
358+
if name in self.types:
359+
matches.append((name, alias))
360+
for type_, name in matches:
361+
if old_record is None:
362+
self._enqueue_callback(ServiceStateChange.Added, type_, name)
363+
elif record.is_expired(now):
364+
self._enqueue_callback(ServiceStateChange.Removed, type_, name)
365+
else:
366+
self.reschedule_type(type_, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
363367
return
364368

365369
# If its expired or already exists in the cache it cannot be updated.
@@ -368,17 +372,14 @@ def _async_process_record_update(
368372

369373
if isinstance(record, DNSAddress):
370374
# Iterate through the DNSCache and callback any services that use this address
371-
for service in self.zc.cache.async_entries_with_server(record.name):
372-
type_ = self._record_matching_type(service)
373-
if type_:
374-
self._enqueue_callback(ServiceStateChange.Updated, type_, service.name)
375-
break
376-
375+
for type_, name in self._names_matching_types(
376+
{service.name for service in self.zc.cache.async_entries_with_server(record.name)}
377+
):
378+
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
377379
return
378380

379-
type_ = self._record_matching_type(record)
380-
if type_:
381-
self._enqueue_callback(ServiceStateChange.Updated, type_, record.name)
381+
for type_, name in self._names_matching_types((record.name,)):
382+
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
382383

383384
def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
384385
"""Callback invoked by Zeroconf when new information arrives.

0 commit comments

Comments
 (0)