Skip to content

Commit 3338594

Browse files
authored
Fix ServiceInfo with multiple A records (#725)
1 parent e2d4d98 commit 3338594

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

tests/test_services.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ServiceInfo,
2525
ServiceStateChange,
2626
)
27+
from zeroconf.aio import AsyncZeroconf
2728

2829
from . import has_working_ipv6, _clear_cache, _inject_response
2930

@@ -947,6 +948,28 @@ def test_multiple_addresses():
947948
assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed]
948949

949950

951+
# This test uses asyncio because it needs to access the cache directly
952+
# which is not threadsafe
953+
@pytest.mark.asyncio
954+
async def test_multiple_a_addresses():
955+
type_ = "_http._tcp.local."
956+
registration_name = "multiarec.%s" % type_
957+
desc = {'path': '/~paulsm/'}
958+
aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
959+
cache = aiozc.zeroconf.cache
960+
host = "multahost.local."
961+
record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a')
962+
record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b')
963+
cache.add(record1)
964+
cache.add(record2)
965+
966+
# New kwarg way
967+
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host)
968+
info.load_from_cache(aiozc.zeroconf)
969+
assert set(info.addresses) == set([b'a', b'b'])
970+
await aiozc.async_close()
971+
972+
950973
def test_backoff():
951974
got_query = Event()
952975

zeroconf/_services/__init__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -777,12 +777,10 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float]
777777

778778
def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]:
779779
"""Get the address records from the cache."""
780-
address_records = []
781-
cached_a_record = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)
782-
if cached_a_record:
783-
address_records.append(cached_a_record)
784-
address_records.extend(zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN))
785-
return address_records
780+
return [
781+
*zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN),
782+
*zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN),
783+
]
786784

787785
def load_from_cache(self, zc: 'Zeroconf') -> bool:
788786
"""Populate the service info from the cache."""
@@ -844,7 +842,7 @@ def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing:
844842
out = DNSOutgoing(_FLAGS_QR_QUERY)
845843
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN)
846844
out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN)
847-
out.add_question_or_one_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN)
845+
out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN)
848846
out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN)
849847
return out
850848

0 commit comments

Comments
 (0)