Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions tests/services/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,27 @@ def test_multiple_addresses():
# This test uses asyncio because it needs to access the cache directly
# which is not threadsafe
@pytest.mark.asyncio
async def test_multiple_a_addresses():
async def test_multiple_a_addresses_newest_address_first():
"""Test that info.addresses returns the newest seen address first."""
type_ = "_http._tcp.local."
registration_name = "multiarec.%s" % type_
desc = {'path': '/~paulsm/'}
aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
cache = aiozc.zeroconf.cache
host = "multahost.local."
record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'\x7f\x00\x00\x01')
record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'\x7f\x00\x00\x02')
cache.async_add_records([record1, record2])

# New kwarg way
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host)
info.load_from_cache(aiozc.zeroconf)
assert info.addresses == [b'\x7f\x00\x00\x02', b'\x7f\x00\x00\x01']
await aiozc.async_close()


@pytest.mark.asyncio
async def test_invalid_a_addresses(caplog):
type_ = "_http._tcp.local."
registration_name = "multiarec.%s" % type_
desc = {'path': '/~paulsm/'}
Expand All @@ -574,7 +594,9 @@ async def test_multiple_a_addresses():
# New kwarg way
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host)
info.load_from_cache(aiozc.zeroconf)
assert set(info.addresses) == set([b'a', b'b'])
assert not info.addresses
assert "Encountered invalid address while processing record" in caplog.text

await aiozc.async_close()


Expand Down
4 changes: 2 additions & 2 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ def test_invalid_addresses(self):
name = "xxxyyy"
registration_name = f"{name}.{type_}"

bad = ('127.0.0.1', '::1', 42)
bad = (b'127.0.0.1', b'::1')
for addr in bad:
self.assertRaisesRegex(
TypeError,
'Addresses must be bytes',
'Addresses must either ',
ServiceInfo,
type_,
registration_name,
Expand Down
66 changes: 44 additions & 22 deletions zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText
from .._exceptions import BadTypeInNameException
from .._logger import log
from .._protocol.outgoing import DNSOutgoing
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.asyncio import get_running_loop, run_coro_with_timeout
Expand Down Expand Up @@ -124,19 +125,12 @@ def __init__(
self.type = type_
self._name = name
self.key = name.lower()
self._ipv4_addresses: List[ipaddress.IPv4Address] = []
self._ipv6_addresses: List[ipaddress.IPv6Address] = []
if addresses is not None:
self._addresses = addresses
self.addresses = addresses
elif parsed_addresses is not None:
self._addresses = [_encode_address(a) for a in parsed_addresses]
else:
self._addresses = []
# This results in an ugly error when registering, better check now
invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)]
if invalid:
raise TypeError(
'Addresses must be bytes, got %s. Hint: convert string addresses '
'with socket.inet_pton' % invalid
)
self.addresses = [_encode_address(a) for a in parsed_addresses]
self.port = port
self.weight = weight
self.priority = priority
Expand Down Expand Up @@ -178,7 +172,21 @@ def addresses(self, value: List[bytes]) -> None:

This replaces all currently stored addresses, both IPv4 and IPv6.
"""
self._addresses = value
self._ipv4_addresses.clear()
self._ipv6_addresses.clear()

for address in value:
try:
addr = ipaddress.ip_address(address)
except ValueError:
raise TypeError(
"Addresses must either be IPv4 or IPv6 strings, bytes, or integers;"
f" got {address}. Hint: convert string addresses with socket.inet_pton" # type: ignore
)
if addr.version == 4:
self._ipv4_addresses.append(addr)
else:
self._ipv6_addresses.append(addr)

@property
def properties(self) -> Dict:
Expand All @@ -194,10 +202,13 @@ def properties(self) -> Dict:
def addresses_by_version(self, version: IPVersion) -> List[bytes]:
"""List addresses matching IP version."""
if version == IPVersion.V4Only:
return [addr for addr in self._addresses if not _is_v6_address(addr)]
return [addr.packed for addr in self._ipv4_addresses]
if version == IPVersion.V6Only:
return list(filter(_is_v6_address, self._addresses))
return self._addresses
return [addr.packed for addr in self._ipv6_addresses]
return [
*(addr.packed for addr in self._ipv4_addresses),
*(addr.packed for addr in self._ipv6_addresses),
]

def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]:
"""List addresses in their parsed string form."""
Expand All @@ -220,7 +231,7 @@ def is_link_local(addr_str: str) -> Any:

ll_addrs = list(filter(is_link_local, self.parsed_addresses(version)))
other_addrs = list(filter(lambda addr: not is_link_local(addr), self.parsed_addresses(version)))
return ["{}%{}".format(addr, self.interface_index) for addr in ll_addrs] + other_addrs
return [f"{addr}%{self.interface_index}" for addr in ll_addrs] + other_addrs

def _set_properties(self, properties: Dict) -> None:
"""Sets properties and text of this info from a dictionary"""
Expand Down Expand Up @@ -315,9 +326,20 @@ def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None:
return

if isinstance(record, DNSAddress):
if record.key == self.server_key and record.address not in self._addresses:
self._addresses.append(record.address)
if record.type is _TYPE_AAAA and ipaddress.IPv6Address(record.address).is_link_local:
if record.key != self.server_key:
return
try:
ip_addr = ipaddress.ip_address(record.address)
except ValueError as ex:
log.warning("Encountered invalid address while processing %s: %s", record, ex)
return
if ip_addr.version == 4:
if ip_addr not in self._ipv4_addresses:
self._ipv4_addresses.insert(0, ip_addr)
return
if ip_addr not in self._ipv6_addresses:
self._ipv6_addresses.insert(0, ip_addr)
if ip_addr.is_link_local:
self.interface_index = record.scope_id
return

Expand Down Expand Up @@ -422,7 +444,7 @@ def load_from_cache(self, zc: 'Zeroconf') -> bool:
@property
def _is_complete(self) -> bool:
"""The ServiceInfo has all expected properties."""
return not (self.text is None or not self._addresses)
return bool(self.text is not None and (self._ipv4_addresses or self._ipv6_addresses))

def request(
self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None
Expand Down Expand Up @@ -494,10 +516,10 @@ def __eq__(self, other: object) -> bool:

def __repr__(self) -> str:
"""String representation"""
return '%s(%s)' % (
return '{}({})'.format(
type(self).__name__,
', '.join(
'%s=%r' % (name, getattr(self, name))
'{}={!r}'.format(name, getattr(self, name))
for name in (
'type',
'name',
Expand Down