Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,11 @@ def _mock_get_expiration_time(self, percent):
assert service_added_count == 3
assert service_removed_count == 0

_inject_response(
zeroconf,
mock_incoming_msg(r.ServiceStateChange.Updated, service_types[0], service_names[0], 0),
)

# all three services removed
_inject_response(
zeroconf,
Expand Down Expand Up @@ -1265,6 +1270,124 @@ def mock_incoming_msg(records) -> r.DNSIncoming:
zc.close()


def test_service_browser_listeners_update_service():
"""Test that the ServiceBrowser ServiceListener that implements update_service."""

# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])
# start a browser
type_ = "_hap._tcp.local."
registration_name = "xxxyyy.%s" % type_
callbacks = []

class MyServiceListener(r.ServiceListener):
def add_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("add", type_, name))

def remove_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("remove", type_, name))

def update_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("update", type_, name))

listener = MyServiceListener()

browser = r.ServiceBrowser(zc, type_, None, listener)

desc = {'path': '/~paulsm/'}
address_parsed = "10.0.1.2"
address = socket.inet_aton(address_parsed)
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])

def mock_incoming_msg(records) -> r.DNSIncoming:
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
for record in records:
generated.add_answer_at_time(record, 0)
return r.DNSIncoming(generated.packets()[0])

_inject_response(
zc,
mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
)
time.sleep(0.2)
info.port = 400
_inject_response(
zc,
mock_incoming_msg([info.dns_service()]),
)
time.sleep(0.2)

assert callbacks == [
('add', type_, registration_name),
('update', type_, registration_name),
]
browser.cancel()

zc.close()


def test_service_browser_listeners_no_update_service():
"""Test that the ServiceBrowser ServiceListener that does not implement update_service."""

# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])
# start a browser
type_ = "_hap._tcp.local."
registration_name = "xxxyyy.%s" % type_
callbacks = []

class MyServiceListener:
def add_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("add", type_, name))

def remove_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("remove", type_, name))

listener = MyServiceListener()

browser = r.ServiceBrowser(zc, type_, None, listener)

desc = {'path': '/~paulsm/'}
address_parsed = "10.0.1.2"
address = socket.inet_aton(address_parsed)
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])

def mock_incoming_msg(records) -> r.DNSIncoming:
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
for record in records:
generated.add_answer_at_time(record, 0)
return r.DNSIncoming(generated.packets()[0])

_inject_response(
zc,
mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
)
time.sleep(0.2)
info.port = 400
_inject_response(
zc,
mock_incoming_msg([info.dns_service()]),
)
time.sleep(0.2)

assert callbacks == [
('add', type_, registration_name),
]
browser.cancel()

zc.close()


def test_changing_name_updates_serviceinfo_key():
"""Verify a name change will adjust the underlying key value."""
type_ = "_homeassistant._tcp.local."
Expand Down
49 changes: 26 additions & 23 deletions zeroconf/_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,31 @@ def _group_ptr_queries_with_known_answers(
return [query_bucket.out for query_bucket in query_buckets]


def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]:
"""Generate a service_state_changed handlers from a listener."""

def on_change(
zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange
) -> None:
assert listener is not None
args = (zeroconf, service_type, name)
if state_change is ServiceStateChange.Added:
listener.add_service(*args)
elif state_change is ServiceStateChange.Removed:
listener.remove_service(*args)
elif state_change is ServiceStateChange.Updated:
if hasattr(listener, 'update_service'):
listener.update_service(*args)
else:
warnings.warn(
"%r has no update_service method. Provide one (it can be empty if you "
"don't care about the updates), it'll become mandatory." % (listener,),
FutureWarning,
)

return on_change


class _ServiceBrowserBase(RecordUpdateListener):
"""Base class for ServiceBrowser."""

Expand Down Expand Up @@ -262,29 +287,7 @@ def __init__(
handlers = cast(List[Callable[..., None]], handlers or [])

if listener:

def on_change(
zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange
) -> None:
assert listener is not None
args = (zeroconf, service_type, name)
if state_change is ServiceStateChange.Added:
listener.add_service(*args)
elif state_change is ServiceStateChange.Removed:
listener.remove_service(*args)
elif state_change is ServiceStateChange.Updated:
if hasattr(listener, 'update_service'):
listener.update_service(*args)
else:
warnings.warn(
"%r has no update_service method. Provide one (it can be empty if you "
"don't care about the updates), it'll become mandatory." % (listener,),
FutureWarning,
)
else:
raise NotImplementedError(state_change)

handlers.append(on_change)
handlers.append(_service_state_changed_from_listener(listener))

for h in handlers:
self.service_state_changed.register_handler(h)
Expand Down