From 410f3f10f83baf91cba8d2b952abc0206686735c Mon Sep 17 00:00:00 2001 From: Dmitry Tantsur Date: Mon, 21 Oct 2019 13:24:31 +0200 Subject: [PATCH] Rework exposing IPv6 addresses on ServiceInfo * Return backward compatibility for ServiceInfo.addresses by making it return V4 addresses only * Add ServiceInfo.parsed_addresses for convenient access to addresses * Raise TypeError if addresses are not provided as bytes (otherwise an ugly assertion error is raised when sending) * Add more IPv6 unit tests --- examples/self_test.py | 3 +++ test_zeroconf.py | 36 +++++++++++++++++++++++--- zeroconf.py | 59 +++++++++++++++++++++++++++++++++---------- 3 files changed, 82 insertions(+), 16 deletions(-) diff --git a/examples/self_test.py b/examples/self_test.py index 0b231ac34..35007db13 100755 --- a/examples/self_test.py +++ b/examples/self_test.py @@ -19,8 +19,10 @@ print("1. Testing registration of a service...") desc = {'version': '0.10', 'a': 'test value', 'b': 'another value'} addresses = [socket.inet_aton("127.0.0.1")] + expected = {'127.0.0.1'} if socket.has_ipv6: addresses.append(socket.inet_pton(socket.AF_INET6, '::1')) + expected.add('::1') info = ServiceInfo( "_http._tcp.local.", "My Service Name._http._tcp.local.", @@ -37,6 +39,7 @@ print("3. Testing query of own service...") queried_info = r.get_service_info("_http._tcp.local.", "My Service Name._http._tcp.local.") assert queried_info + assert set(queried_info.parsed_addresses()) == expected print(" Getting self: %s" % (queried_info,)) print(" Query done.") print("4. Testing unregister of service information...") diff --git a/test_zeroconf.py b/test_zeroconf.py index a8c5a81ff..c4d5497e7 100644 --- a/test_zeroconf.py +++ b/test_zeroconf.py @@ -536,6 +536,23 @@ def test_good_service_names(self): r.service_type_name('_one_two._tcp.local.', allow_underscores=True) + def test_invalid_addresses(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + bad = ('127.0.0.1', '::1', 42) + for addr in bad: + self.assertRaisesRegex( + TypeError, + 'Addresses must be bytes', + ServiceInfo, + type_, + registration_name, + port=80, + addresses=[addr], + ) + class TestDnsIncoming(unittest.TestCase): def test_incoming_exception_handling(self): @@ -844,6 +861,9 @@ def update_service(self, zeroconf, type, name): assert info.properties[b'prop_blank'] == properties['prop_blank'] assert info.properties[b'prop_true'] is True assert info.properties[b'prop_false'] is False + assert info.addresses == addresses[:1] # no V6 by default + all_addresses = info.addresses_by_version(r.IPVersion.All) + assert all_addresses == addresses, all_addresses info = zeroconf_browser.get_service_info(subtype, registration_name) assert info is not None @@ -1039,7 +1059,8 @@ def test_multiple_addresses(): type_ = "_http._tcp.local." registration_name = "xxxyyy.%s" % type_ desc = {'path': '/~paulsm/'} - address = socket.inet_aton("10.0.1.2") + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) # Old way info = ServiceInfo(type_, registration_name, address, 80, 0, 0, desc, "ash-2.local.") @@ -1059,6 +1080,11 @@ def test_multiple_addresses(): assert info.address is None assert info.addresses == [] + info.addresses = [address2] + + assert info.address == address2 + assert info.addresses == [address2] + # Compatibility way info = ServiceInfo(type_, registration_name, [address, address], 80, 0, 0, desc, "ash-2.local.") @@ -1072,12 +1098,16 @@ def test_multiple_addresses(): assert info.addresses == [address, address] if socket.has_ipv6: - address_v6 = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + address_v6_parsed = "2001:db8::1" + address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) info = ServiceInfo(type_, registration_name, [address, address_v6], 80, 0, 0, desc, "ash-2.local.") - assert info.addresses == [address, address_v6] + assert info.addresses == [address] assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] assert info.addresses_by_version(r.IPVersion.V4Only) == [address] assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] + assert info.parsed_addresses() == [address_parsed, address_v6_parsed] + assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] + assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] def test_ptr_optimization(): diff --git a/zeroconf.py b/zeroconf.py index 0a0c88da8..302b3e2f6 100644 --- a/zeroconf.py +++ b/zeroconf.py @@ -1498,15 +1498,21 @@ def __init__( self.type = type_ self.name = name if addresses is not None: - self.addresses = addresses + self._addresses = addresses elif address is not None: warnings.warn("address is deprecated, use addresses instead", DeprecationWarning) if isinstance(address, list): - self.addresses = address + self._addresses = address else: - self.addresses = [address] + self._addresses = [address] else: - self.addresses = [] + self._addresses = [] + # This results in an ugly error when registering, better check now + invalid = [a for a in self._addresses + if not isinstance(a, bytes) or len(a) not in (4, 16)] + if invalid: + raise TypeError('Addresses must be bytes, got %s. Hint: convert string addresses ' + 'with socket.inet_pton' % invalid) self.port = port self.weight = weight self.priority = priority @@ -1524,6 +1530,7 @@ def __init__( def address(self): warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) try: + # Return the first V4 address for compatibility return self.addresses[0] except IndexError: return None @@ -1532,9 +1539,27 @@ def address(self): def address(self, value): warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) if value is None: - self.addresses = [] + self._addresses = [] else: - self.addresses = [value] + self._addresses = [value] + + @property + def addresses(self): + """IPv4 addresses of this service. + + Only IPv4 addresses are returned for backward compatibility. + Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to + include IPv6 addresses as well. + """ + return self.addresses_by_version(IPVersion.V4Only) + + @addresses.setter + def addresses(self, value): + """Replace the addresses list. + + This replaces all currently stored addresses, both IPv4 and IPv6. + """ + self._addresses = value @property def properties(self) -> ServicePropertiesType: @@ -1543,11 +1568,19 @@ def properties(self) -> ServicePropertiesType: def addresses_by_version(self, version: IPVersion) -> List[bytes]: """List addresses matching IP version.""" if version == IPVersion.V4Only: - return [addr for addr in self.addresses if not _is_v6_address(addr)] + return [addr for addr in self._addresses if not _is_v6_address(addr)] elif version == IPVersion.V6Only: - return list(filter(_is_v6_address, self.addresses)) + return list(filter(_is_v6_address, self._addresses)) else: - return self.addresses + return self._addresses + + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """List addresses in their parsed string form.""" + result = self.addresses_by_version(version) + return [ + socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) + for addr in result + ] def _set_properties(self, properties: Union[bytes, ServicePropertiesType]): """Sets properties and text of this info from a dictionary""" @@ -1625,8 +1658,8 @@ def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: assert isinstance(record, DNSAddress) # if record.name == self.name: if record.name == self.server: - if record.address not in self.addresses: - self.addresses.append(record.address) + if record.address not in self._addresses: + self._addresses.append(record.address) elif record.type == _TYPE_SRV: assert isinstance(record, DNSService) if record.name == self.name: @@ -1660,12 +1693,12 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: if cached: self.update_record(zc, now, cached) - if self.server is not None and self.text is not None and self.addresses: + if self.server is not None and self.text is not None and self._addresses: return True try: zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN)) - while self.server is None or self.text is None or not self.addresses: + while self.server is None or self.text is None or not self._addresses: if last <= now: return False if next_ <= now: