diff --git a/tests/services/test_info.py b/tests/services/test_info.py index a72d82f9..9c9bfa02 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -560,7 +560,27 @@ def test_multiple_addresses(): # This test uses asyncio because it needs to access the cache directly # which is not threadsafe @pytest.mark.asyncio -async def test_multiple_a_addresses(): +async def test_multiple_a_addresses_newest_address_first(): + """Test that info.addresses returns the newest seen address first.""" + type_ = "_http._tcp.local." + registration_name = "multiarec.%s" % type_ + desc = {'path': '/~paulsm/'} + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + cache = aiozc.zeroconf.cache + host = "multahost.local." + record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'\x7f\x00\x00\x01') + record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'\x7f\x00\x00\x02') + cache.async_add_records([record1, record2]) + + # New kwarg way + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host) + info.load_from_cache(aiozc.zeroconf) + assert info.addresses == [b'\x7f\x00\x00\x02', b'\x7f\x00\x00\x01'] + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_invalid_a_addresses(caplog): type_ = "_http._tcp.local." registration_name = "multiarec.%s" % type_ desc = {'path': '/~paulsm/'} @@ -574,7 +594,9 @@ async def test_multiple_a_addresses(): # New kwarg way info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host) info.load_from_cache(aiozc.zeroconf) - assert set(info.addresses) == set([b'a', b'b']) + assert not info.addresses + assert "Encountered invalid address while processing record" in caplog.text + await aiozc.async_close() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 47e68b75..6b682093 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -141,11 +141,11 @@ def test_invalid_addresses(self): name = "xxxyyy" registration_name = f"{name}.{type_}" - bad = ('127.0.0.1', '::1', 42) + bad = (b'127.0.0.1', b'::1') for addr in bad: self.assertRaisesRegex( TypeError, - 'Addresses must be bytes', + 'Addresses must either ', ServiceInfo, type_, registration_name, diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index beaf0678..2558f726 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -27,6 +27,7 @@ from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException +from .._logger import log from .._protocol.outgoing import DNSOutgoing from .._updates import RecordUpdate, RecordUpdateListener from .._utils.asyncio import get_running_loop, run_coro_with_timeout @@ -124,19 +125,12 @@ def __init__( self.type = type_ self._name = name self.key = name.lower() + self._ipv4_addresses: List[ipaddress.IPv4Address] = [] + self._ipv6_addresses: List[ipaddress.IPv6Address] = [] if addresses is not None: - self._addresses = addresses + self.addresses = addresses elif parsed_addresses is not None: - self._addresses = [_encode_address(a) for a in parsed_addresses] - else: - self._addresses = [] - # This results in an ugly error when registering, better check now - invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] - if invalid: - raise TypeError( - 'Addresses must be bytes, got %s. Hint: convert string addresses ' - 'with socket.inet_pton' % invalid - ) + self.addresses = [_encode_address(a) for a in parsed_addresses] self.port = port self.weight = weight self.priority = priority @@ -178,7 +172,21 @@ def addresses(self, value: List[bytes]) -> None: This replaces all currently stored addresses, both IPv4 and IPv6. """ - self._addresses = value + self._ipv4_addresses.clear() + self._ipv6_addresses.clear() + + for address in value: + try: + addr = ipaddress.ip_address(address) + except ValueError: + raise TypeError( + "Addresses must either be IPv4 or IPv6 strings, bytes, or integers;" + f" got {address}. Hint: convert string addresses with socket.inet_pton" # type: ignore + ) + if addr.version == 4: + self._ipv4_addresses.append(addr) + else: + self._ipv6_addresses.append(addr) @property def properties(self) -> Dict: @@ -194,10 +202,13 @@ def properties(self) -> Dict: def addresses_by_version(self, version: IPVersion) -> List[bytes]: """List addresses matching IP version.""" if version == IPVersion.V4Only: - return [addr for addr in self._addresses if not _is_v6_address(addr)] + return [addr.packed for addr in self._ipv4_addresses] if version == IPVersion.V6Only: - return list(filter(_is_v6_address, self._addresses)) - return self._addresses + return [addr.packed for addr in self._ipv6_addresses] + return [ + *(addr.packed for addr in self._ipv4_addresses), + *(addr.packed for addr in self._ipv6_addresses), + ] def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: """List addresses in their parsed string form.""" @@ -220,7 +231,7 @@ def is_link_local(addr_str: str) -> Any: ll_addrs = list(filter(is_link_local, self.parsed_addresses(version))) other_addrs = list(filter(lambda addr: not is_link_local(addr), self.parsed_addresses(version))) - return ["{}%{}".format(addr, self.interface_index) for addr in ll_addrs] + other_addrs + return [f"{addr}%{self.interface_index}" for addr in ll_addrs] + other_addrs def _set_properties(self, properties: Dict) -> None: """Sets properties and text of this info from a dictionary""" @@ -315,9 +326,20 @@ def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: return if isinstance(record, DNSAddress): - if record.key == self.server_key and record.address not in self._addresses: - self._addresses.append(record.address) - if record.type is _TYPE_AAAA and ipaddress.IPv6Address(record.address).is_link_local: + if record.key != self.server_key: + return + try: + ip_addr = ipaddress.ip_address(record.address) + except ValueError as ex: + log.warning("Encountered invalid address while processing %s: %s", record, ex) + return + if ip_addr.version == 4: + if ip_addr not in self._ipv4_addresses: + self._ipv4_addresses.insert(0, ip_addr) + return + if ip_addr not in self._ipv6_addresses: + self._ipv6_addresses.insert(0, ip_addr) + if ip_addr.is_link_local: self.interface_index = record.scope_id return @@ -422,7 +444,7 @@ def load_from_cache(self, zc: 'Zeroconf') -> bool: @property def _is_complete(self) -> bool: """The ServiceInfo has all expected properties.""" - return not (self.text is None or not self._addresses) + return bool(self.text is not None and (self._ipv4_addresses or self._ipv6_addresses)) def request( self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None @@ -494,10 +516,10 @@ def __eq__(self, other: object) -> bool: def __repr__(self) -> str: """String representation""" - return '%s(%s)' % ( + return '{}({})'.format( type(self).__name__, ', '.join( - '%s=%r' % (name, getattr(self, name)) + '{}={!r}'.format(name, getattr(self, name)) for name in ( 'type', 'name',