diff --git a/examples/self_test.py b/examples/self_test.py index 62e325732..0b231ac34 100755 --- a/examples/self_test.py +++ b/examples/self_test.py @@ -18,10 +18,13 @@ r = Zeroconf() print("1. Testing registration of a service...") desc = {'version': '0.10', 'a': 'test value', 'b': 'another value'} + addresses = [socket.inet_aton("127.0.0.1")] + if socket.has_ipv6: + addresses.append(socket.inet_pton(socket.AF_INET6, '::1')) info = ServiceInfo( "_http._tcp.local.", "My Service Name._http._tcp.local.", - addresses=[socket.inet_aton("127.0.0.1")], + addresses=addresses, port=1234, properties=desc, ) diff --git a/test_zeroconf.py b/test_zeroconf.py index fbef8e631..d5d887027 100644 --- a/test_zeroconf.py +++ b/test_zeroconf.py @@ -556,9 +556,16 @@ def test_incoming_unknown_type(self): assert parsed.is_query() != parsed.is_response() def test_incoming_ipv6(self): - # ::TODO:: could use a test here if we add IPV6 record handling - # ie: _TYPE_AAAA - pass + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + packed = socket.inet_pton(socket.AF_INET6, addr) + generated = r.DNSOutgoing(0) + answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN, 1, packed) + generated.add_additional_answer(answer) + packet = generated.packet() + parsed = r.DNSIncoming(packet) + record = parsed.answers[0] + assert isinstance(record, r.DNSAddress) + assert record.address == packed class TestRegistrar(unittest.TestCase): @@ -689,6 +696,30 @@ def test_integration_with_listener(self): finally: zeroconf_registrar.close() + @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') + def test_integration_with_listener_v6_records(self): + + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, socket.inet_pton(socket.AF_INET6, addr), 80, 0, 0, desc, "ash-2.local." + ) + zeroconf_registrar.register_service(info) + + try: + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + assert type_ in service_types + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types + + finally: + zeroconf_registrar.close() + @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') @attr('IPv6') def test_integration_with_listener_ipv6(self): @@ -1037,6 +1068,14 @@ def test_multiple_addresses(): assert info.addresses == [address, address] + if socket.has_ipv6: + address_v6 = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, [address, address_v6], 80, 0, 0, desc, "ash-2.local.") + assert info.addresses == [address, address_v6] + assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] + assert info.addresses_by_version(r.IPVersion.V4Only) == [address] + assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] + def test_ptr_optimization(): diff --git a/zeroconf.py b/zeroconf.py index 6c567ee25..8f6654e28 100644 --- a/zeroconf.py +++ b/zeroconf.py @@ -210,6 +210,10 @@ def current_time_millis() -> float: return time.time() * 1000 +def _is_v6_address(addr): + return len(addr) == 16 + + def service_type_name(type_, *, allow_underscores: bool = False): """ Validate a fully qualified service name, instance or subtype. [rfc6763] @@ -1535,6 +1539,15 @@ def address(self, value): def properties(self) -> ServicePropertiesType: return self._properties + def addresses_by_version(self, version: IPVersion) -> List[bytes]: + """List addresses matching IP version.""" + if version == IPVersion.V4Only: + return [addr for addr in self.addresses if not _is_v6_address(addr)] + elif version == IPVersion.V6Only: + return list(filter(_is_v6_address, self.addresses)) + else: + return self.addresses + def _set_properties(self, properties: Union[bytes, ServicePropertiesType]): """Sets properties and text of this info from a dictionary""" if isinstance(properties, dict): @@ -1607,7 +1620,7 @@ def get_name(self): def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: """Updates service information from a DNS record""" if record is not None and not record.is_expired(now): - if record.type == _TYPE_A: + if record.type in [_TYPE_A, _TYPE_AAAA]: assert isinstance(record, DNSAddress) # if record.name == self.name: if record.name == self.server: @@ -1622,6 +1635,7 @@ def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: self.priority = record.priority # self.address = None self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)) + self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) elif record.type == _TYPE_TXT: assert isinstance(record, DNSText) if record.name == self.name: @@ -1639,6 +1653,7 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: record_types_for_check_cache = [(_TYPE_SRV, _CLASS_IN), (_TYPE_TXT, _CLASS_IN)] if self.server is not None: record_types_for_check_cache.append((_TYPE_A, _CLASS_IN)) + record_types_for_check_cache.append((_TYPE_AAAA, _CLASS_IN)) for record_type in record_types_for_check_cache: cached = zc.cache.get_by_details(self.name, *record_type) if cached: @@ -1663,6 +1678,10 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: if self.server is not None: out.add_question(DNSQuestion(self.server, _TYPE_A, _CLASS_IN)) out.add_answer_at_time(zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN), now) + out.add_question(DNSQuestion(self.server, _TYPE_AAAA, _CLASS_IN)) + out.add_answer_at_time( + zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN), now + ) zc.send(out) next_ = now + delay delay *= 2 @@ -2142,7 +2161,8 @@ def _broadcast_service(self, info: ServiceInfo) -> None: out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, info.other_ttl, info.text), 0) for address in info.addresses: - out.add_answer_at_time(DNSAddress(info.server, _TYPE_A, _CLASS_IN, info.host_ttl, address), 0) + type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A + out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, info.host_ttl, address), 0) self.send(out) i += 1 next_time += _REGISTER_TIME @@ -2176,7 +2196,8 @@ def unregister_service(self, info: ServiceInfo) -> None: out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) for address in info.addresses: - out.add_answer_at_time(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, address), 0) + type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A + out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0) self.send(out) i += 1 next_time += _UNREGISTER_TIME @@ -2210,7 +2231,8 @@ def unregister_all_services(self) -> None: ) out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) for address in info.addresses: - out.add_answer_at_time(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, address), 0) + type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A + out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0) self.send(out) i += 1 next_time += _UNREGISTER_TIME