Skip to content

Commit e417fc0

Browse files
authored
Reduce duplicate code between zeroconf.asyncio and zeroconf._core (#904)
1 parent f8af0fb commit e417fc0

3 files changed

Lines changed: 87 additions & 77 deletions

File tree

zeroconf/_core.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import sys
3030
import threading
3131
from types import TracebackType # noqa # used in type hints
32-
from typing import Dict, List, Optional, Tuple, Type, Union, cast
32+
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast
3333

3434
from ._cache import DNSCache
3535
from ._dns import DNSQuestion, DNSQuestionType
@@ -43,7 +43,7 @@
4343
from ._services.info import ServiceInfo, instance_name_from_service_info
4444
from ._services.registry import ServiceRegistry
4545
from ._updates import RecordUpdate, RecordUpdateListener
46-
from ._utils.asyncio import get_running_loop, shutdown_loop, wait_event_or_timeout
46+
from ._utils.asyncio import await_awaitable, get_running_loop, shutdown_loop, wait_event_or_timeout
4747
from ._utils.name import service_type_name
4848
from ._utils.net import (
4949
IPVersion,
@@ -74,6 +74,7 @@
7474

7575
_TC_DELAY_RANDOM_INTERVAL = (400, 500)
7676
_CLOSE_TIMEOUT = 3
77+
_REGISTER_BROADCASTS = 3
7778

7879

7980
class AsyncEngine:
@@ -478,6 +479,27 @@ def register_service(
478479
allow_name_change: bool = False,
479480
cooperating_responders: bool = False,
480481
) -> None:
482+
"""Registers service information to the network with a default TTL.
483+
Zeroconf will then respond to requests for information for that
484+
service. The name of the service may be changed if needed to make
485+
it unique on the network. Additionally multiple cooperating responders
486+
can register the same service on the network for resilience
487+
(if you want this behavior set `cooperating_responders` to `True`)."""
488+
assert self.loop is not None
489+
asyncio.run_coroutine_threadsafe(
490+
await_awaitable(
491+
self.async_register_service(info, ttl, allow_name_change, cooperating_responders)
492+
),
493+
self.loop,
494+
).result(millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT)
495+
496+
async def async_register_service(
497+
self,
498+
info: ServiceInfo,
499+
ttl: Optional[int] = None,
500+
allow_name_change: bool = False,
501+
cooperating_responders: bool = False,
502+
) -> Awaitable:
481503
"""Registers service information to the network with a default TTL.
482504
Zeroconf will then respond to requests for information for that
483505
service. The name of the service may be changed if needed to make
@@ -489,47 +511,41 @@ def register_service(
489511
# Setting TTLs via ServiceInfo is preferred
490512
info.host_ttl = ttl
491513
info.other_ttl = ttl
492-
self.check_service(info, allow_name_change, cooperating_responders)
514+
515+
await self.async_wait_for_start()
516+
await self.async_check_service(info, allow_name_change, cooperating_responders)
493517
self.registry.add(info)
494-
self._broadcast_service(info, _REGISTER_TIME, None)
518+
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
495519

496520
def update_service(self, info: ServiceInfo) -> None:
497521
"""Registers service information to the network with a default TTL.
498522
Zeroconf will then respond to requests for information for that
499523
service."""
524+
assert self.loop is not None
525+
asyncio.run_coroutine_threadsafe(await_awaitable(self.async_update_service(info)), self.loop).result(
526+
millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT
527+
)
500528

529+
async def async_update_service(self, info: ServiceInfo) -> Awaitable:
530+
"""Registers service information to the network with a default TTL.
531+
Zeroconf will then respond to requests for information for that
532+
service."""
501533
self.registry.update(info)
502-
self._broadcast_service(info, _REGISTER_TIME, None)
534+
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
503535

504-
def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
536+
async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
505537
"""Send a broadcasts to announce a service at intervals."""
506-
now = current_time_millis()
507-
next_time = now
508-
i = 0
509-
while i < 3:
510-
if now < next_time:
511-
self.wait(next_time - now)
512-
now = current_time_millis()
513-
continue
514-
515-
self.send_service_broadcast(info, ttl)
516-
i += 1
517-
next_time += interval
518-
519-
def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None:
520-
"""Send a broadcast to announce a service."""
521-
self.send(self.generate_service_broadcast(info, ttl))
538+
for i in range(_REGISTER_BROADCASTS):
539+
if i != 0:
540+
await asyncio.sleep(millis_to_seconds(interval))
541+
self.async_send(self.generate_service_broadcast(info, ttl))
522542

523543
def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing:
524544
"""Generate a broadcast to announce a service."""
525545
out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
526546
self._add_broadcast_answer(out, info, ttl)
527547
return out
528548

529-
def send_service_query(self, info: ServiceInfo) -> None:
530-
"""Send a query to lookup a service."""
531-
self.send(self.generate_service_query(info))
532-
533549
def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use
534550
"""Generate a query to lookup a service."""
535551
out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
@@ -559,9 +575,16 @@ def _add_broadcast_answer( # pylint: disable=no-self-use
559575
out.add_answer_at_time(dns_address, 0)
560576

561577
def unregister_service(self, info: ServiceInfo) -> None:
578+
"""Unregister a service."""
579+
assert self.loop is not None
580+
asyncio.run_coroutine_threadsafe(
581+
await_awaitable(self.async_unregister_service(info)), self.loop
582+
).result(millis_to_seconds(_UNREGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT)
583+
584+
async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
562585
"""Unregister a service."""
563586
self.registry.remove(info)
564-
self._broadcast_service(info, _UNREGISTER_TIME, 0)
587+
return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0))
565588

566589
def generate_unregister_all_services(self) -> Optional[DNSOutgoing]:
567590
"""Generate a DNSOutgoing goodbye for all services and remove them from the registry."""
@@ -574,6 +597,22 @@ def generate_unregister_all_services(self) -> Optional[DNSOutgoing]:
574597
self.registry.remove(service_infos)
575598
return out
576599

600+
async def async_unregister_all_services(self) -> None:
601+
"""Unregister all registered services.
602+
603+
Unlike async_register_service and async_unregister_service, this
604+
method does not return a future and is always expected to be
605+
awaited since its only called at shutdown.
606+
"""
607+
# Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
608+
out = self.generate_unregister_all_services()
609+
if not out:
610+
return
611+
for i in range(_REGISTER_BROADCASTS):
612+
if i != 0:
613+
await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME))
614+
self.async_send(out)
615+
577616
def unregister_all_services(self) -> None:
578617
"""Unregister all registered services."""
579618
# Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
@@ -592,7 +631,7 @@ def unregister_all_services(self) -> None:
592631
i += 1
593632
next_time += _UNREGISTER_TIME
594633

595-
def check_service(
634+
async def async_check_service(
596635
self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False
597636
) -> None:
598637
"""Checks the network for a unique service name, modifying the
@@ -603,7 +642,7 @@ def check_service(
603642
next_instance_number = 2
604643
next_time = now = current_time_millis()
605644
i = 0
606-
while i < 3:
645+
while i < _REGISTER_BROADCASTS:
607646
# check for a name conflict
608647
while self.cache.current_entry_with_name_and_alias(info.type, info.name):
609648
if not allow_name_change:
@@ -617,11 +656,11 @@ def check_service(
617656
i = 0
618657

619658
if now < next_time:
620-
self.wait(next_time - now)
659+
await self.async_wait(next_time - now)
621660
now = current_time_millis()
622661
continue
623662

624-
self.send_service_query(info)
663+
self.async_send(self.generate_service_query(info))
625664
i += 1
626665
next_time += _CHECK_TIME
627666

zeroconf/_utils/asyncio.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
import asyncio
2424
import contextlib
2525
import queue
26-
from typing import Any, List, Optional, Set, cast
26+
from typing import Any, Awaitable, List, Optional, Set, cast
2727

28+
# The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT
2829
_TASK_AWAIT_TIMEOUT = 1
2930
_GET_ALL_TASKS_TIMEOUT = 3
3031
_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT
@@ -80,6 +81,12 @@ async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:
8081
await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT)
8182

8283

84+
async def await_awaitable(aw: Awaitable) -> None:
85+
"""Wait on an awaitable and the task it returns."""
86+
task = await aw
87+
await task
88+
89+
8390
def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
8491
"""Wait for pending tasks and stop an event loop."""
8592
pending_tasks = set(

zeroconf/asyncio.py

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,15 @@
2626

2727
from ._core import Zeroconf
2828
from ._dns import DNSQuestionType
29-
from ._exceptions import NonUniqueNameException
3029
from ._services import ServiceListener
3130
from ._services.browser import _ServiceBrowserBase
32-
from ._services.info import ServiceInfo, instance_name_from_service_info
31+
from ._services.info import ServiceInfo
3332
from ._services.types import ZeroconfServiceTypes
3433
from ._utils.net import IPVersion, InterfaceChoice, InterfacesType
35-
from ._utils.time import millis_to_seconds
3634
from .const import (
3735
_BROWSER_TIME,
38-
_CHECK_TIME,
3936
_MDNS_PORT,
40-
_REGISTER_TIME,
4137
_SERVICE_TYPE_ENUMERATION_NAME,
42-
_UNREGISTER_TIME,
4338
)
4439

4540

@@ -172,16 +167,11 @@ def __init__(
172167
)
173168
self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {}
174169

175-
async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
176-
"""Send a broadcasts to announce a service at intervals."""
177-
for i in range(3):
178-
if i != 0:
179-
await asyncio.sleep(millis_to_seconds(interval))
180-
self.zeroconf.async_send(self.zeroconf.generate_service_broadcast(info, ttl))
181-
182170
async def async_register_service(
183171
self,
184172
info: ServiceInfo,
173+
ttl: Optional[int] = None,
174+
allow_name_change: bool = False,
185175
cooperating_responders: bool = False,
186176
) -> Awaitable:
187177
"""Registers service information to the network with a default TTL.
@@ -194,10 +184,9 @@ async def async_register_service(
194184
The service will be broadcast in a task. This task is returned
195185
and therefore can be awaited if necessary.
196186
"""
197-
await self.zeroconf.async_wait_for_start()
198-
await self.async_check_service(info, cooperating_responders)
199-
self.zeroconf.registry.add(info)
200-
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
187+
return await self.zeroconf.async_register_service(
188+
info, ttl, allow_name_change, cooperating_responders
189+
)
201190

202191
async def async_unregister_all_services(self) -> None:
203192
"""Unregister all registered services.
@@ -206,39 +195,15 @@ async def async_unregister_all_services(self) -> None:
206195
method does not return a future and is always expected to be
207196
awaited since its only called at shutdown.
208197
"""
209-
out = self.zeroconf.generate_unregister_all_services()
210-
if not out:
211-
return
212-
for i in range(3):
213-
if i != 0:
214-
await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME))
215-
self.zeroconf.async_send(out)
216-
217-
async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None:
218-
"""Checks the network for a unique service name."""
219-
instance_name_from_service_info(info)
220-
if cooperating_responders:
221-
return
222-
self._raise_on_name_conflict(info)
223-
for i in range(3):
224-
if i != 0:
225-
await asyncio.sleep(millis_to_seconds(_CHECK_TIME))
226-
self.zeroconf.async_send(self.zeroconf.generate_service_query(info))
227-
self._raise_on_name_conflict(info)
228-
229-
def _raise_on_name_conflict(self, info: ServiceInfo) -> None:
230-
"""Raise NonUniqueNameException if the ServiceInfo has a conflict."""
231-
if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name):
232-
raise NonUniqueNameException
198+
await self.zeroconf.async_unregister_all_services()
233199

234200
async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
235201
"""Unregister a service.
236202
237203
The service will be broadcast in a task. This task is returned
238204
and therefore can be awaited if necessary.
239205
"""
240-
self.zeroconf.registry.remove(info)
241-
return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0))
206+
return await self.zeroconf.async_unregister_service(info)
242207

243208
async def async_update_service(self, info: ServiceInfo) -> Awaitable:
244209
"""Registers service information to the network with a default TTL.
@@ -248,8 +213,7 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable:
248213
The service will be broadcast in a task. This task is returned
249214
and therefore can be awaited if necessary.
250215
"""
251-
self.zeroconf.registry.update(info)
252-
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
216+
return await self.zeroconf.async_update_service(info)
253217

254218
async def async_close(self) -> None:
255219
"""Ends the background threads, and prevent this instance from

0 commit comments

Comments
 (0)