Skip to content

Commit c9f3c91

Browse files
authored
Ensure all listeners are cleaned up on ServiceBrowser cancelation (python-zeroconf#290)
When creating listeners for a ServiceBrowser with multiple types they would not all be removed on cancelation. This led to a build up of stale listeners when ServiceBrowsers were frequently added and removed.
1 parent 19e33a6 commit c9f3c91

2 files changed

Lines changed: 61 additions & 32 deletions

File tree

zeroconf/__init__.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@
174174
_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$')
175175
_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]')
176176

177+
_EXPIRE_FULL_TIME_PERCENT = 100
178+
_EXPIRE_STALE_TIME_PERCENT = 50
179+
_EXPIRE_REFRESH_TIME_PERCENT = 75
180+
177181
try:
178182
_IPPROTO_IPV6 = socket.IPPROTO_IPV6
179183
except AttributeError:
@@ -459,8 +463,8 @@ def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -
459463
DNSEntry.__init__(self, name, type_, class_)
460464
self.ttl = ttl
461465
self.created = current_time_millis()
462-
self._expiration_time = self.get_expiration_time(100)
463-
self._stale_time = self.get_expiration_time(50)
466+
self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT)
467+
self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT)
464468

465469
def __eq__(self, other: Any) -> bool:
466470
"""Abstract method"""
@@ -506,8 +510,8 @@ def reset_ttl(self, other: 'DNSRecord') -> None:
506510
another record."""
507511
self.created = other.created
508512
self.ttl = other.ttl
509-
self._expiration_time = self.get_expiration_time(100)
510-
self._stale_time = self.get_expiration_time(50)
513+
self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT)
514+
self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT)
511515

512516
def write(self, out: 'DNSOutgoing') -> None:
513517
"""Abstract method"""
@@ -1609,7 +1613,7 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) ->
16091613
enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
16101614
return
16111615

1612-
expires = record.get_expiration_time(75)
1616+
expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
16131617
if expires < self._next_time[record.name]:
16141618
self._next_time[record.name] = expires
16151619

@@ -1649,8 +1653,8 @@ def cancel(self) -> None:
16491653
self.join()
16501654

16511655
def run(self) -> None:
1652-
for type_ in self.types:
1653-
self.zc.add_listener(self, DNSQuestion(type_, _TYPE_PTR, _CLASS_IN))
1656+
questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]
1657+
self.zc.add_listener(self, questions)
16541658

16551659
while True:
16561660
now = current_time_millis()
@@ -2595,16 +2599,20 @@ def check_service(
25952599
i += 1
25962600
next_time += _CHECK_TIME
25972601

2598-
def add_listener(self, listener: RecordUpdateListener, question: Optional[DNSQuestion]) -> None:
2602+
def add_listener(
2603+
self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
2604+
) -> None:
25992605
"""Adds a listener for a given question. The listener will have
26002606
its update_record method called when information is available to
2601-
answer the question."""
2607+
answer the question(s)."""
26022608
now = current_time_millis()
26032609
self.listeners.append(listener)
26042610
if question is not None:
2605-
for record in self.cache.entries_with_name(question.name):
2606-
if question.answered_by(record) and not record.is_expired(now):
2607-
listener.update_record(self, now, record)
2611+
questions = [question] if isinstance(question, DNSQuestion) else question
2612+
for single_question in questions:
2613+
for record in self.cache.entries_with_name(single_question.name):
2614+
if single_question.answered_by(record) and not record.is_expired(now):
2615+
listener.update_record(self, now, record)
26082616
self.notify_all()
26092617

26102618
def remove_listener(self, listener: RecordUpdateListener) -> None:

zeroconf/test.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ServiceStateChange,
2525
Zeroconf,
2626
ZeroconfServiceTypes,
27+
_EXPIRE_REFRESH_TIME_PERCENT,
2728
)
2829

2930
log = logging.getLogger('zeroconf')
@@ -1237,16 +1238,18 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi
12371238
assert service_removed_count == 1
12381239

12391240
finally:
1241+
assert len(zeroconf.listeners) == 1
12401242
service_browser.cancel()
1243+
assert len(zeroconf.listeners) == 0
12411244
zeroconf.remove_all_service_listeners()
12421245
zeroconf.close()
12431246

12441247

12451248
class TestServiceBrowserMultipleTypes(unittest.TestCase):
12461249
def test_update_record(self):
12471250

1248-
service_names = ['name._type._tcp.local.', 'name._type._udp.local']
1249-
service_types = ['_type._tcp.local.', '_type._udp.local.']
1251+
service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local']
1252+
service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.']
12501253

12511254
service_added_count = 0
12521255
service_removed_count = 0
@@ -1257,25 +1260,19 @@ class MyServiceListener(r.ServiceListener):
12571260
def add_service(self, zc, type_, name) -> None:
12581261
nonlocal service_added_count
12591262
service_added_count += 1
1260-
if service_added_count == 2:
1263+
if service_added_count == 3:
12611264
service_add_event.set()
12621265

12631266
def remove_service(self, zc, type_, name) -> None:
12641267
nonlocal service_removed_count
12651268
service_removed_count += 1
1266-
if service_removed_count == 2:
1269+
if service_removed_count == 3:
12671270
service_removed_event.set()
12681271

12691272
def mock_incoming_msg(
1270-
service_state_change: r.ServiceStateChange, service_type: str, service_name: str
1273+
service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int
12711274
) -> r.DNSIncoming:
12721275
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
1273-
1274-
if service_state_change == r.ServiceStateChange.Removed:
1275-
ttl = 0
1276-
else:
1277-
ttl = 120
1278-
12791276
generated.add_answer_at_time(
12801277
r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0
12811278
)
@@ -1287,30 +1284,54 @@ def mock_incoming_msg(
12871284
try:
12881285
wait_time = 3
12891286

1290-
# both services added
1287+
# all three services added
1288+
zeroconf.handle_response(
1289+
mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120)
1290+
)
12911291
zeroconf.handle_response(
1292-
mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0])
1292+
mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120)
12931293
)
12941294
zeroconf.handle_response(
1295-
mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1])
1295+
mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120)
12961296
)
1297+
1298+
called_with_refresh_time_check = False
1299+
1300+
def _mock_get_expiration_time(self, percent):
1301+
nonlocal called_with_refresh_time_check
1302+
if percent == _EXPIRE_REFRESH_TIME_PERCENT:
1303+
called_with_refresh_time_check = True
1304+
return 0
1305+
return self.created + (percent * self.ttl * 10)
1306+
1307+
# Set an expire time that will force a refresh
1308+
with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time):
1309+
zeroconf.handle_response(
1310+
mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120)
1311+
)
12971312
service_add_event.wait(wait_time)
1298-
assert service_added_count == 2
1313+
assert called_with_refresh_time_check is True
1314+
assert service_added_count == 3
12991315
assert service_removed_count == 0
13001316

1301-
# both services removed
1317+
# all three services removed
1318+
zeroconf.handle_response(
1319+
mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0)
1320+
)
13021321
zeroconf.handle_response(
1303-
mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0])
1322+
mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0)
13041323
)
13051324
zeroconf.handle_response(
1306-
mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1])
1325+
mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0)
13071326
)
13081327
service_removed_event.wait(wait_time)
1309-
assert service_added_count == 2
1310-
assert service_removed_count == 2
1328+
assert service_added_count == 3
1329+
assert service_removed_count == 3
13111330

13121331
finally:
1332+
assert len(zeroconf.listeners) == 1
13131333
service_browser.cancel()
1334+
assert len(zeroconf.listeners) == 0
13141335
zeroconf.remove_all_service_listeners()
13151336
zeroconf.close()
13161337

0 commit comments

Comments
 (0)