Skip to content

Commit 3503e76

Browse files
authored
Prefix cache functions that are non threadsafe with async_ (#724)
1 parent 88aa610 commit 3503e76

8 files changed

Lines changed: 55 additions & 51 deletions

File tree

tests/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,4 @@ def has_working_ipv6():
6464

6565

6666
def _clear_cache(zc):
67-
for name in zc.cache.names():
68-
for record in zc.cache.entries_with_name(name):
69-
zc.cache.remove(record)
67+
zc.cache.cache.clear()

tests/test_cache.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_order(self):
3131
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
3232
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
3333
cache = r.DNSCache()
34-
cache.add_records([record1, record2])
34+
cache.async_add_records([record1, record2])
3535
entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN)
3636
cached_record = cache.get(entry)
3737
assert cached_record == record2
@@ -45,7 +45,7 @@ def test_adding_same_record_to_cache_different_ttls(self):
4545
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
4646
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
4747
cache = r.DNSCache()
48-
cache.add_records([record1, record2])
48+
cache.async_add_records([record1, record2])
4949
entry = r.DNSEntry(record2)
5050
cached_record = cache.get(entry)
5151
assert cached_record == record2
@@ -61,26 +61,26 @@ def test_adding_same_record_to_cache_different_ttls(self):
6161
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
6262
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
6363
cache = r.DNSCache()
64-
cache.add_records([record1, record2])
64+
cache.async_add_records([record1, record2])
6565
cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)
6666
assert cached_records == [record2]
6767

6868
def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self):
6969
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
7070
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
7171
cache = r.DNSCache()
72-
cache.add_records([record1, record2])
72+
cache.async_add_records([record1, record2])
7373
assert 'a' in cache.cache
74-
cache.remove_records([record1, record2])
74+
cache.async_remove_records([record1, record2])
7575
assert 'a' not in cache.cache
7676

7777
def test_cache_empty_multiple_calls(self):
7878
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
7979
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
8080
cache = r.DNSCache()
81-
cache.add_records([record1, record2])
81+
cache.async_add_records([record1, record2])
8282
assert 'a' in cache.cache
83-
cache.remove_records([record1, record2])
83+
cache.async_remove_records([record1, record2])
8484
assert 'a' not in cache.cache
8585

8686

@@ -91,22 +91,22 @@ def test_get(self):
9191
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
9292
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
9393
cache = r.DNSCache()
94-
cache.add_records([record1, record2])
94+
cache.async_add_records([record1, record2])
9595
assert cache.get(record1) == record1
9696
assert cache.get(record2) == record2
9797

9898
def test_get_by_details(self):
9999
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
100100
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
101101
cache = r.DNSCache()
102-
cache.add_records([record1, record2])
102+
cache.async_add_records([record1, record2])
103103
assert cache.get_by_details('a', const._TYPE_A, const._CLASS_IN) == record2
104104

105105
def test_get_all_by_details(self):
106106
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
107107
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
108108
cache = r.DNSCache()
109-
cache.add_records([record1, record2])
109+
cache.async_add_records([record1, record2])
110110
assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2])
111111

112112
def test_entries_with_server(self):
@@ -117,7 +117,7 @@ def test_entries_with_server(self):
117117
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
118118
)
119119
cache = r.DNSCache()
120-
cache.add_records([record1, record2])
120+
cache.async_add_records([record1, record2])
121121
assert set(cache.entries_with_server('ab')) == set([record1, record2])
122122

123123
def test_entries_with_name(self):
@@ -128,7 +128,7 @@ def test_entries_with_name(self):
128128
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
129129
)
130130
cache = r.DNSCache()
131-
cache.add_records([record1, record2])
131+
cache.async_add_records([record1, record2])
132132
assert set(cache.entries_with_name('irrelevant')) == set([record1, record2])
133133

134134
def test_current_entry_with_name_and_alias(self):
@@ -139,7 +139,7 @@ def test_current_entry_with_name_and_alias(self):
139139
'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'y.irrelevant'
140140
)
141141
cache = r.DNSCache()
142-
cache.add_records([record1, record2])
142+
cache.async_add_records([record1, record2])
143143
assert cache.current_entry_with_name_and_alias('irrelevant', 'x.irrelevant') == record1
144144

145145
def test_entries_with_name(self):
@@ -150,5 +150,5 @@ def test_entries_with_name(self):
150150
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
151151
)
152152
cache = r.DNSCache()
153-
cache.add_records([record1, record2])
153+
cache.async_add_records([record1, record2])
154154
assert cache.names() == ['irrelevant']

tests/test_core.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import zeroconf as r
2020
from zeroconf import _core, const, ServiceBrowser, Zeroconf, current_time_millis
21+
from zeroconf.aio import AsyncZeroconf
2122

2223
from . import has_working_ipv6, _clear_cache, _inject_response
2324

@@ -36,22 +37,23 @@ def teardown_module():
3637
log.setLevel(original_logging_level)
3738

3839

39-
class TestReaper(unittest.TestCase):
40-
@unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10)
41-
def test_reaper(self):
42-
zeroconf = _core.Zeroconf(interfaces=['127.0.0.1'])
40+
# This test uses asyncio because it needs to access the cache directly
41+
# which is not threadsafe
42+
@pytest.mark.asyncio
43+
async def test_reaper():
44+
with unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10):
45+
assert _core._CACHE_CLEANUP_INTERVAL == 10
46+
aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
47+
zeroconf = aiozc.zeroconf
4348
cache = zeroconf.cache
4449
original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()]))
4550
record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a')
4651
record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
47-
zeroconf.cache.add(record_with_10s_ttl)
48-
zeroconf.cache.add(record_with_1s_ttl)
52+
zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl])
4953
entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()]))
50-
time.sleep(1)
51-
zeroconf.notify_all()
52-
time.sleep(0.1)
54+
await asyncio.sleep(1.2)
5355
entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()]))
54-
zeroconf.close()
56+
await aiozc.async_close()
5557
assert entries != original_entries
5658
assert entries_with_cache != original_entries
5759
assert record_with_10s_ttl in entries

tests/test_handlers.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import zeroconf as r
1515
from zeroconf import ServiceInfo, Zeroconf, current_time_millis
1616
from zeroconf import const
17+
from zeroconf.aio import AsyncZeroconf
1718

1819
from . import _clear_cache, _inject_response
1920

@@ -703,10 +704,14 @@ def test_known_answer_supression_service_type_enumeration_query():
703704
zc.close()
704705

705706

706-
def test_qu_response_only_sends_additionals_if_sends_answer():
707+
# This test uses asyncio because it needs to access the cache directly
708+
# which is not threadsafe
709+
@pytest.mark.asyncio
710+
async def test_qu_response_only_sends_additionals_if_sends_answer():
707711
"""Test that a QU response does not send additionals unless it sends the answer as well."""
708712
# instantiate a zeroconf instance
709-
zc = Zeroconf(interfaces=['127.0.0.1'])
713+
aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
714+
zc = aiozc.zeroconf
710715

711716
type_ = "_addtest1._tcp.local."
712717
name = "knownname"
@@ -731,13 +736,13 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
731736
ptr_record = info.dns_pointer()
732737

733738
# Add the PTR record to the cache
734-
zc.cache.add(ptr_record)
739+
zc.cache.async_add_records([ptr_record])
735740

736741
# Add the A record to the cache with 50% ttl remaining
737742
a_record = info.dns_addresses()[0]
738743
a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl)
739744
assert not a_record.is_recent(current_time_millis())
740-
zc.cache.add(a_record)
745+
zc.cache.async_add_records([a_record])
741746

742747
# With QU should respond to only unicast when the answer has been recently multicast
743748
# even if the additional has not been recently multicast
@@ -755,10 +760,10 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
755760
assert unicast_out.answers[0][0] == ptr_record
756761

757762
# Remove the 50% A record and add a 100% A record
758-
zc.cache.remove(a_record)
763+
zc.cache.async_remove_records([a_record])
759764
a_record = info.dns_addresses()[0]
760765
assert a_record.is_recent(current_time_millis())
761-
zc.cache.add(a_record)
766+
zc.cache.async_add_records([a_record])
762767
# With QU should respond to only unicast when the answer has been recently multicast
763768
# even if the additional has not been recently multicast
764769
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
@@ -775,10 +780,10 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
775780
assert unicast_out.answers[0][0] == ptr_record
776781

777782
# Remove the 100% PTR record and add a 50% PTR record
778-
zc.cache.remove(ptr_record)
783+
zc.cache.async_remove_records([ptr_record])
779784
ptr_record.set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl)
780785
assert not ptr_record.is_recent(current_time_millis())
781-
zc.cache.add(ptr_record)
786+
zc.cache.async_add_records([ptr_record])
782787
# With QU should respond to only multicast since the has less
783788
# than 75% of its ttl remaining
784789
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
@@ -811,7 +816,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
811816
question.unicast = True # Set the QU bit
812817
assert question.unicast is True
813818
query.add_question(question)
814-
zc.cache.add(info2.dns_pointer()) # Add 100% TTL for info2 to the cache
819+
zc.cache.async_add_records([info2.dns_pointer()]) # Add 100% TTL for info2 to the cache
815820

816821
unicast_out, multicast_out = zc.query_handler.async_response(
817822
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
@@ -828,4 +833,4 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
828833

829834
# unregister
830835
zc.registry.remove(info)
831-
zc.close()
836+
await aiozc.async_close()

tests/test_services.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,8 +960,7 @@ async def test_multiple_a_addresses():
960960
host = "multahost.local."
961961
record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a')
962962
record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b')
963-
cache.add(record1)
964-
cache.add(record2)
963+
cache.async_add_records([record1, record2])
965964

966965
# New kwarg way
967966
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host)

zeroconf/_cache.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def __init__(self) -> None:
4747
self.cache: _DNSRecordCacheType = {}
4848
self.service_cache: _DNSRecordCacheType = {}
4949

50-
# Functions prefixed with are NOT threadsafe and must
50+
# Functions prefixed with async_ are NOT threadsafe and must
5151
# be run in the event loop.
5252

53-
def add(self, entry: DNSRecord) -> None:
53+
def _async_add(self, entry: DNSRecord) -> None:
5454
"""Adds an entry.
5555
5656
This function must be run in from event loop.
@@ -65,15 +65,15 @@ def add(self, entry: DNSRecord) -> None:
6565
if isinstance(entry, DNSService):
6666
self.service_cache.setdefault(entry.server, {})[entry] = entry
6767

68-
def add_records(self, entries: Iterable[DNSRecord]) -> None:
68+
def async_add_records(self, entries: Iterable[DNSRecord]) -> None:
6969
"""Add multiple records.
7070
7171
This function must be run in from event loop.
7272
"""
7373
for entry in entries:
74-
self.add(entry)
74+
self._async_add(entry)
7575

76-
def remove(self, entry: DNSRecord) -> None:
76+
def _async_remove(self, entry: DNSRecord) -> None:
7777
"""Removes an entry.
7878
7979
This function must be run in from event loop.
@@ -82,23 +82,23 @@ def remove(self, entry: DNSRecord) -> None:
8282
_remove_key(self.service_cache, entry.server, entry)
8383
_remove_key(self.cache, entry.key, entry)
8484

85-
def remove_records(self, entries: Iterable[DNSRecord]) -> None:
85+
def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
8686
"""Remove multiple records.
8787
8888
This function must be run in from event loop.
8989
"""
9090
for entry in entries:
91-
self.remove(entry)
91+
self._async_remove(entry)
9292

93-
def expire(self, now: float) -> Iterable[DNSRecord]:
93+
def async_expire(self, now: float) -> Iterable[DNSRecord]:
9494
"""Purge expired entries from the cache.
9595
9696
This function must be run in from event loop.
9797
"""
9898
for name in self.names():
9999
for record in self.entries_with_name(name):
100100
if record.is_expired(now):
101-
self.remove(record)
101+
self._async_remove(record)
102102
yield record
103103

104104
# The below functions are threadsafe and do not need to be run in the

zeroconf/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ async def _async_cache_cleanup(self) -> None:
146146
"""Periodic cache cleanup."""
147147
while not self.zc.done:
148148
now = current_time_millis()
149-
self.zc.record_manager.async_updates(now, list(self.zc.cache.expire(now)))
149+
self.zc.record_manager.async_updates(now, list(self.zc.cache.async_expire(now)))
150150
self.zc.record_manager.async_updates_complete()
151151
await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL))
152152

zeroconf/_handlers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,11 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
362362
# zc.get_service_info will see the cached value
363363
# but ONLY after all the record updates have been
364364
# processsed.
365-
self.cache.add_records(itertools.chain(address_adds, other_adds))
365+
self.cache.async_add_records(itertools.chain(address_adds, other_adds))
366366
# Removes are processed last since
367367
# ServiceInfo could generate an un-needed query
368368
# because the data was not yet populated.
369-
self.cache.remove_records(removes)
369+
self.cache.async_remove_records(removes)
370370
self.async_updates_complete()
371371

372372
def add_listener(

0 commit comments

Comments
 (0)