Skip to content

Commit 552a030

Browse files
mattsaxonjstasiak
authored andcommitted
Call UpdateService on SRV & A/AAAA updates as well as TXT (#239)
Fix #235 Contains: * Add lock around handlers list * Reverse DNSCache order to ensure newest records take precedence When there are multiple records in the cache, the behaviour was inconsistent. Whilst the DNSCache.get() method returned the newest, any function which iterated over the entire cache suffered from a last write winds issue. This change makes this behaviour consistent and allows the removal of an (incorrect) wait from one of the unit tests.
1 parent f8fe400 commit 552a030

2 files changed

Lines changed: 150 additions & 84 deletions

File tree

zeroconf/__init__.py

Lines changed: 85 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import threading
3636
import time
3737
import warnings
38+
from collections import OrderedDict
3839
from typing import Dict, List, Optional, Sequence, Union, cast
3940
from typing import Any, Callable, Set, Tuple # noqa # used in type hints
4041

@@ -1121,8 +1122,9 @@ def __init__(self) -> None:
11211122

11221123
def add(self, entry: DNSRecord) -> None:
11231124
"""Adds an entry"""
1124-
# Insert first in list so get returns newest entry
1125-
self.cache.setdefault(entry.key, []).insert(0, entry)
1125+
# Insert last in list, get will return newest entry
1126+
# iteration will result in last update winning
1127+
self.cache.setdefault(entry.key, []).append(entry)
11261128

11271129
def remove(self, entry: DNSRecord) -> None:
11281130
"""Removes an entry"""
@@ -1142,7 +1144,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
11421144
matching entry."""
11431145
try:
11441146
list_ = self.cache[entry.key]
1145-
for cached_entry in list_:
1147+
for cached_entry in reversed(list_):
11461148
if entry.__eq__(cached_entry):
11471149
return cached_entry
11481150
return None
@@ -1164,7 +1166,7 @@ def entries_with_name(self, name: str) -> List[DNSRecord]:
11641166

11651167
def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
11661168
now = current_time_millis()
1167-
for record in self.entries_with_name(name):
1169+
for record in reversed(self.entries_with_name(name)):
11681170
if (
11691171
record.type == _TYPE_PTR
11701172
and not record.is_expired(now)
@@ -1400,7 +1402,7 @@ def __init__(
14001402
self.services = {} # type: Dict[str, DNSRecord]
14011403
self.next_time = current_time_millis()
14021404
self.delay = delay
1403-
self._handlers_to_call = [] # type: List[Callable[[Zeroconf], None]]
1405+
self._handlers_to_call = OrderedDict() # type: OrderedDict[str, ServiceStateChange]
14041406

14051407
self._service_state_changed = Signal()
14061408

@@ -1445,14 +1447,30 @@ def service_state_changed(self) -> SignalRegistrationInterface:
14451447
def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
14461448
"""Callback invoked by Zeroconf when new information arrives.
14471449
1448-
Updates information required by browser in the Zeroconf cache."""
1450+
Updates information required by browser in the Zeroconf cache.
1451+
1452+
Ensures that there is are no unecessary duplicates in the list
1453+
1454+
"""
14491455

14501456
def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
1451-
self._handlers_to_call.append(
1452-
lambda zeroconf: self._service_state_changed.fire(
1453-
zeroconf=zeroconf, service_type=self.type, name=name, state_change=state_change
1457+
1458+
# Code to ensure we only do a single update message
1459+
# Precedence is; Added, Remove, Update
1460+
1461+
if (
1462+
state_change is ServiceStateChange.Added
1463+
or (
1464+
state_change is ServiceStateChange.Removed
1465+
and (
1466+
self._handlers_to_call.get(name) is ServiceStateChange.Updated
1467+
or self._handlers_to_call.get(name) is ServiceStateChange.Added
1468+
or self._handlers_to_call.get(name) is None
1469+
)
14541470
)
1455-
)
1471+
or (state_change is ServiceStateChange.Updated and name not in self._handlers_to_call)
1472+
):
1473+
self._handlers_to_call[name] = state_change
14561474

14571475
if record.type == _TYPE_PTR and record.name == self.type:
14581476
assert isinstance(record, DNSPointer)
@@ -1476,8 +1494,20 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
14761494
if expires < self.next_time:
14771495
self.next_time = expires
14781496

1479-
elif record.type == _TYPE_TXT and record.name.endswith(self.type):
1480-
assert isinstance(record, DNSText)
1497+
elif record.type == _TYPE_A or record.type == _TYPE_AAAA:
1498+
assert isinstance(record, DNSAddress)
1499+
1500+
# Iterate through the DNSCache and callback any services that use this address
1501+
for service in zc.cache.entries():
1502+
if (
1503+
isinstance(service, DNSService)
1504+
and service.name.endswith(self.type)
1505+
and service.server == record.name
1506+
and not record.is_expired(now)
1507+
):
1508+
enqueue_callback(ServiceStateChange.Updated, service.name)
1509+
1510+
elif record.name.endswith(self.type):
14811511
expired = record.is_expired(now)
14821512
if not expired:
14831513
enqueue_callback(ServiceStateChange.Updated, record.name)
@@ -1509,8 +1539,11 @@ def run(self) -> None:
15091539
self.delay = min(_BROWSER_BACKOFF_LIMIT * 1000, self.delay * 2)
15101540

15111541
if len(self._handlers_to_call) > 0 and not self.zc.done:
1512-
handler = self._handlers_to_call.pop(0)
1513-
handler(self.zc)
1542+
with self.zc._handlers_lock:
1543+
handler = self._handlers_to_call.popitem(False)
1544+
self._service_state_changed.fire(
1545+
zeroconf=self.zc, service_type=self.type, name=handler[0], state_change=handler[1]
1546+
)
15141547

15151548

15161549
class ServiceInfo(RecordUpdateListener):
@@ -2173,6 +2206,8 @@ def __init__(
21732206

21742207
self.debug = None # type: Optional[DNSOutgoing]
21752208

2209+
self._handlers_lock = threading.Lock() # ensure we process a full message in one go
2210+
21762211
@property
21772212
def done(self) -> bool:
21782213
return self._GLOBAL_DONE
@@ -2449,42 +2484,45 @@ def update_record(self, now: float, rec: DNSRecord) -> None:
24492484
def handle_response(self, msg: DNSIncoming) -> None:
24502485
"""Deal with incoming response packets. All answers
24512486
are held in the cache, and listeners are notified."""
2452-
now = current_time_millis()
2453-
for record in msg.answers:
2454-
2455-
updated = True
2456-
2457-
if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
2458-
# Since the cache format is keyed on the lower case record name
2459-
# we can avoid iterating everything in the cache and
2460-
# only look though entries for the specific name.
2461-
# entries_with_name will take care of converting to lowercase
2462-
#
2463-
# We make a copy of the list that entries_with_name returns
2464-
# since we cannot iterate over something we might remove
2465-
for entry in self.cache.entries_with_name(record.name).copy():
24662487

2467-
if entry == record:
2468-
updated = False
2488+
with self._handlers_lock:
24692489

2470-
# Check the time first because it is far cheaper
2471-
# than the __eq__
2472-
if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record):
2473-
self.cache.remove(entry)
2474-
2475-
expired = record.is_expired(now)
2476-
maybe_entry = self.cache.get(record)
2477-
if not expired:
2478-
if maybe_entry is not None:
2479-
maybe_entry.reset_ttl(record)
2490+
now = current_time_millis()
2491+
for record in msg.answers:
2492+
2493+
updated = True
2494+
2495+
if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
2496+
# Since the cache format is keyed on the lower case record name
2497+
# we can avoid iterating everything in the cache and
2498+
# only look though entries for the specific name.
2499+
# entries_with_name will take care of converting to lowercase
2500+
#
2501+
# We make a copy of the list that entries_with_name returns
2502+
# since we cannot iterate over something we might remove
2503+
for entry in self.cache.entries_with_name(record.name).copy():
2504+
2505+
if entry == record:
2506+
updated = False
2507+
2508+
# Check the time first because it is far cheaper
2509+
# than the __eq__
2510+
if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record):
2511+
self.cache.remove(entry)
2512+
2513+
expired = record.is_expired(now)
2514+
maybe_entry = self.cache.get(record)
2515+
if not expired:
2516+
if maybe_entry is not None:
2517+
maybe_entry.reset_ttl(record)
2518+
else:
2519+
self.cache.add(record)
2520+
if updated:
2521+
self.update_record(now, record)
24802522
else:
2481-
self.cache.add(record)
2482-
if updated:
2483-
self.update_record(now, record)
2484-
else:
2485-
if maybe_entry is not None:
2486-
self.update_record(now, record)
2487-
self.cache.remove(maybe_entry)
2523+
if maybe_entry is not None:
2524+
self.update_record(now, record)
2525+
self.cache.remove(maybe_entry)
24882526

24892527
def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None:
24902528
"""Deal with incoming query packets. Provides a response if

zeroconf/test.py

Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_numbers(self):
292292
def test_numbers_questions(self):
293293
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
294294
question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)
295-
for i in range(10):
295+
for i in range(10): # pylint: disable=unused-variable
296296
generated.add_question(question)
297297
bytes = generated.packet()
298298
(num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12])
@@ -756,7 +756,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
756756
"""Sends an outgoing packet."""
757757
nonlocal nbr_answers, nbr_additionals, nbr_authorities
758758

759-
for answer, time_ in out.answers:
759+
for answer, time_ in out.answers: # pylint: disable=unused-variable
760760
nbr_answers += 1
761761
assert answer.ttl == get_ttl(answer.type)
762762
for answer in out.additionals:
@@ -1053,62 +1053,57 @@ def test_update_record(self):
10531053

10541054
service_name = 'name._type._tcp.local.'
10551055
service_type = '_type._tcp.local.'
1056-
service_server = 'ash-2.local.'
1057-
service_text = b'path=/~paulsm/'
1056+
service_server = 'ash-1.local.'
1057+
service_text = b'path=/~matt1/'
10581058
service_address = '10.0.1.2'
10591059

1060-
service_added = False
1061-
service_removed = False
1060+
service_added_count = 0
1061+
service_removed_count = 0
10621062
service_updated_count = 0
10631063
service_add_event = Event()
10641064
service_removed_event = Event()
10651065
service_updated_event = Event()
10661066

10671067
class MyServiceListener(r.ServiceListener):
10681068
def add_service(self, zc, type_, name) -> None:
1069-
nonlocal service_added
1070-
service_added = True
1069+
nonlocal service_added_count
1070+
service_added_count += 1
10711071
service_add_event.set()
10721072

10731073
def remove_service(self, zc, type_, name) -> None:
1074-
nonlocal service_added, service_removed
1075-
service_added = False
1076-
service_removed = True
1074+
nonlocal service_removed_count
1075+
service_removed_count += 1
10771076
service_removed_event.set()
10781077

10791078
def update_service(self, zc, type_, name) -> None:
10801079
nonlocal service_updated_count
10811080
service_updated_count += 1
1082-
10831081
service_info = zc.get_service_info(type_, name)
1082+
assert service_info.addresses[0] == socket.inet_aton(service_address)
10841083
assert service_info.text == service_text
1084+
assert service_info.server == service_server
10851085
service_updated_event.set()
10861086

10871087
def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming:
1088-
ttl = 120
1089-
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
10901088

1091-
if service_state_change == r.ServiceStateChange.Updated:
1092-
generated.add_answer_at_time(
1093-
r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
1094-
)
1095-
return r.DNSIncoming(generated.packet())
1089+
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
10961090

10971091
if service_state_change == r.ServiceStateChange.Removed:
10981092
ttl = 0
1093+
else:
1094+
ttl = 120
10991095

11001096
generated.add_answer_at_time(
1101-
r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0
1097+
r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
11021098
)
1099+
11031100
generated.add_answer_at_time(
11041101
r.DNSService(
11051102
service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server
11061103
),
11071104
0,
11081105
)
1109-
generated.add_answer_at_time(
1110-
r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0
1111-
)
1106+
11121107
generated.add_answer_at_time(
11131108
r.DNSAddress(
11141109
service_server,
@@ -1120,36 +1115,69 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi
11201115
0,
11211116
)
11221117

1118+
generated.add_answer_at_time(
1119+
r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0
1120+
)
1121+
11231122
return r.DNSIncoming(generated.packet())
11241123

11251124
zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
11261125
service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener())
11271126

11281127
try:
1128+
wait_time = 3
1129+
11291130
# service added
11301131
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added))
1131-
service_add_event.wait(1)
1132-
service_updated_event.wait(1)
1133-
assert service_added is True
1134-
assert service_updated_count == 1
1135-
assert service_removed is False
1132+
service_add_event.wait(wait_time)
1133+
assert service_added_count == 1
1134+
assert service_updated_count == 0
1135+
assert service_removed_count == 0
11361136

1137-
# service updated. currently only text record can be updated
1137+
# service SRV updated
11381138
service_updated_event.clear()
1139-
service_text = b'path=/~humingchun/'
1139+
service_server = 'ash-2.local.'
11401140
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
1141+
service_updated_event.wait(wait_time)
1142+
assert service_added_count == 1
1143+
assert service_updated_count == 1
1144+
assert service_removed_count == 0
1145+
1146+
# service TXT updated
1147+
service_updated_event.clear()
1148+
service_text = b'path=/~matt2/'
11411149
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
1142-
service_updated_event.wait(1)
1143-
assert service_added is True
1150+
service_updated_event.wait(wait_time)
1151+
assert service_added_count == 1
11441152
assert service_updated_count == 2
1145-
assert service_removed is False
1153+
assert service_removed_count == 0
1154+
1155+
# service A updated
1156+
service_updated_event.clear()
1157+
service_address = '10.0.1.3'
1158+
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
1159+
service_updated_event.wait(wait_time)
1160+
assert service_added_count == 1
1161+
assert service_updated_count == 3
1162+
assert service_removed_count == 0
1163+
1164+
# service all updated
1165+
service_updated_event.clear()
1166+
service_server = 'ash-3.local.'
1167+
service_text = b'path=/~matt3/'
1168+
service_address = '10.0.1.3'
1169+
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated))
1170+
service_updated_event.wait(wait_time)
1171+
assert service_added_count == 1
1172+
assert service_updated_count == 4
1173+
assert service_removed_count == 0
11461174

11471175
# service removed
11481176
zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed))
1149-
service_removed_event.wait(1)
1150-
assert service_added is False
1151-
assert service_updated_count == 2
1152-
assert service_removed is True
1177+
service_removed_event.wait(wait_time)
1178+
assert service_added_count == 1
1179+
assert service_updated_count == 4
1180+
assert service_removed_count == 1
11531181

11541182
finally:
11551183
service_browser.cancel()

0 commit comments

Comments
 (0)