diff --git a/examples/browser.py b/examples/browser.py index c851be676..8b4fd5ee2 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -19,7 +19,8 @@ def on_service_state_change( if state_change is ServiceStateChange.Added: info = zeroconf.get_service_info(service_type, name) if info: - print(" Address: %s:%d" % (socket.inet_ntoa(cast(bytes, info.address)), cast(int, info.port))) + addresses = ["%s:%d" % (socket.inet_ntoa(addr), cast(int, info.port)) for addr in info.addresses] + print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) print(" Server: %s" % (info.server,)) if info.properties: diff --git a/examples/registration.py b/examples/registration.py index 7829acc92..bda55b831 100755 --- a/examples/registration.py +++ b/examples/registration.py @@ -20,12 +20,10 @@ info = ServiceInfo( "_http._tcp.local.", "Paul's Test Web Site._http._tcp.local.", - socket.inet_aton("127.0.0.1"), - 80, - 0, - 0, - desc, - "ash-2.local.", + addresses=[socket.inet_aton("127.0.0.1")], + port=80, + properties=desc, + server="ash-2.local.", ) zeroconf = Zeroconf() diff --git a/examples/self_test.py b/examples/self_test.py index 6667d13ee..62e325732 100755 --- a/examples/self_test.py +++ b/examples/self_test.py @@ -21,11 +21,9 @@ info = ServiceInfo( "_http._tcp.local.", "My Service Name._http._tcp.local.", - socket.inet_aton("127.0.0.1"), - 1234, - 0, - 0, - desc, + addresses=[socket.inet_aton("127.0.0.1")], + port=1234, + properties=desc, ) print(" Registering service...") r.register_service(info) diff --git a/test_zeroconf.py b/test_zeroconf.py index 9737d38f4..59f03df6c 100644 --- a/test_zeroconf.py +++ b/test_zeroconf.py @@ -958,3 +958,40 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): assert service_removed.is_set() browser.cancel() zeroconf_browser.close() + + +def test_multiple_addresses(): + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + desc = {'path': '/~paulsm/'} + address = socket.inet_aton("10.0.1.2") + + # Old way + info = ServiceInfo(type_, registration_name, address, 80, 0, 0, desc, "ash-2.local.") + + assert info.address == address + assert info.addresses == [address] + + # Updating works + address2 = socket.inet_aton("10.0.1.3") + info.address = address2 + + assert info.address == address2 + assert info.addresses == [address2] + + info.address = None + + assert info.address is None + assert info.addresses == [] + + # Compatibility way + info = ServiceInfo(type_, registration_name, [address, address], 80, 0, 0, desc, "ash-2.local.") + + assert info.addresses == [address, address] + + # New kwarg way + info = ServiceInfo( + type_, registration_name, None, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address] + ) + + assert info.addresses == [address, address] diff --git a/zeroconf.py b/zeroconf.py index b8c664730..c0f16797e 100644 --- a/zeroconf.py +++ b/zeroconf.py @@ -30,6 +30,7 @@ import sys import threading import time +import warnings from functools import reduce from typing import AnyStr, Dict, List, Optional, Union, cast from typing import Callable, Set, Tuple # noqa # used in type hints @@ -1429,11 +1430,14 @@ class ServiceInfo(RecordUpdateListener): """Service information""" + # FIXME(dtantsur): black 19.3b0 produces code that is not valid syntax on + # Python 3.5: https://github.com/python/black/issues/759 + # fmt: off def __init__( self, type_: str, name: str, - address: Optional[bytes] = None, + address: Optional[Union[bytes, List[bytes]]] = None, port: Optional[int] = None, weight: int = 0, priority: int = 0, @@ -1441,12 +1445,14 @@ def __init__( server: Optional[str] = None, host_ttl: int = _DNS_HOST_TTL, other_ttl: int = _DNS_OTHER_TTL, + *, + addresses: Optional[List[bytes]] = None ) -> None: """Create a service description. type_: fully qualified service type name name: fully qualified service name - address: IP address as unsigned short, network byte order + address: IP address as unsigned short, network byte order (deprecated, use addresses) port: port that the service runs on weight: weight of the service priority: priority of the service @@ -1454,13 +1460,29 @@ def __init__( bytes for the text field) server: fully qualified name for service host (defaults to name) host_ttl: ttl used for A/SRV records - other_ttl: ttl used for PTR/TXT records""" + other_ttl: ttl used for PTR/TXT records + addresses: List of IP addresses as unsigned short, network byte + order + """ + + # Accept both none, or one, but not both. + if address is not None and addresses is not None: + raise TypeError("address and addresses cannot be provided together") if not type_.endswith(service_type_name(name, allow_underscores=True)): raise BadTypeInNameException self.type = type_ self.name = name - self.address = address + if addresses is not None: + self.addresses = addresses + elif address is not None: + warnings.warn("address is deprecated, use addresses instead", DeprecationWarning) + if isinstance(address, list): + self.addresses = address + else: + self.addresses = [address] + else: + self.addresses = [] self.port = port self.weight = weight self.priority = priority @@ -1472,6 +1494,23 @@ def __init__( self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl + # fmt: on + + @property + def address(self): + warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) + try: + return self.addresses[0] + except IndexError: + return None + + @address.setter + def address(self, value): + warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) + if value is None: + self.addresses = [] + else: + self.addresses = [value] @property def properties(self) -> ServicePropertiesType: @@ -1553,7 +1592,8 @@ def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: assert isinstance(record, DNSAddress) # if record.name == self.name: if record.name == self.server: - self.address = record.address + if record.address not in self.addresses: + self.addresses.append(record.address) elif record.type == _TYPE_SRV: assert isinstance(record, DNSService) if record.name == self.name: @@ -1585,12 +1625,12 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: if cached: self.update_record(zc, now, cached) - if None not in (self.server, self.address, self.text): + if self.server is not None and self.text is not None and self.addresses: return True try: zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN)) - while None in (self.server, self.address, self.text): + while self.server is None or self.text is None or not self.addresses: if last <= now: return False if next_ <= now: @@ -1629,7 +1669,16 @@ def __repr__(self) -> str: type(self).__name__, ', '.join( '%s=%r' % (name, getattr(self, name)) - for name in ('type', 'name', 'address', 'port', 'weight', 'priority', 'server', 'properties') + for name in ( + 'type', + 'name', + 'addresses', + 'port', + 'weight', + 'priority', + 'server', + 'properties', + ) ), ) @@ -1916,10 +1965,8 @@ def _broadcast_service(self, info: ServiceInfo) -> None: ) out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, info.other_ttl, info.text), 0) - if info.address: - out.add_answer_at_time( - DNSAddress(info.server, _TYPE_A, _CLASS_IN, info.host_ttl, info.address), 0 - ) + for address in info.addresses: + out.add_answer_at_time(DNSAddress(info.server, _TYPE_A, _CLASS_IN, info.host_ttl, address), 0) self.send(out) i += 1 next_time += _REGISTER_TIME @@ -1952,8 +1999,8 @@ def unregister_service(self, info: ServiceInfo) -> None: ) out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) - if info.address: - out.add_answer_at_time(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0) + for address in info.addresses: + out.add_answer_at_time(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, address), 0) self.send(out) i += 1 next_time += _UNREGISTER_TIME @@ -1986,10 +2033,8 @@ def unregister_all_services(self) -> None: 0, ) out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) - if info.address: - out.add_answer_at_time( - DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, info.address), 0 - ) + for address in info.addresses: + out.add_answer_at_time(DNSAddress(info.server, _TYPE_A, _CLASS_IN, 0, address), 0) self.send(out) i += 1 next_time += _UNREGISTER_TIME @@ -2126,16 +2171,17 @@ def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None: if question.type in (_TYPE_A, _TYPE_ANY): for service in self.services.values(): if service.server == question.name.lower(): - out.add_answer( - msg, - DNSAddress( - question.name, - _TYPE_A, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.address, - ), - ) + for address in service.addresses: + out.add_answer( + msg, + DNSAddress( + question.name, + _TYPE_A, + _CLASS_IN | _CLASS_UNIQUE, + service.host_ttl, + address, + ), + ) name_to_find = question.name.lower() if name_to_find not in self.services: @@ -2168,15 +2214,16 @@ def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None: ), ) if question.type == _TYPE_SRV: - out.add_additional_answer( - DNSAddress( - service.server, - _TYPE_A, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.address, + for address in service.addresses: + out.add_additional_answer( + DNSAddress( + service.server, + _TYPE_A, + _CLASS_IN | _CLASS_UNIQUE, + service.host_ttl, + address, + ) ) - ) except Exception: # TODO stop catching all Exceptions self.log_exception_warning()