2929import sys
3030import threading
3131from types import TracebackType # noqa # used in type hints
32- from typing import Dict , List , Optional , Tuple , Type , Union , cast
32+ from typing import Awaitable , Dict , List , Optional , Tuple , Type , Union , cast
3333
3434from ._cache import DNSCache
3535from ._dns import DNSQuestion , DNSQuestionType
4343from ._services .info import ServiceInfo , instance_name_from_service_info
4444from ._services .registry import ServiceRegistry
4545from ._updates import RecordUpdate , RecordUpdateListener
46- from ._utils .asyncio import get_running_loop , shutdown_loop , wait_event_or_timeout
46+ from ._utils .asyncio import await_awaitable , get_running_loop , shutdown_loop , wait_event_or_timeout
4747from ._utils .name import service_type_name
4848from ._utils .net import (
4949 IPVersion ,
7474
7575_TC_DELAY_RANDOM_INTERVAL = (400 , 500 )
7676_CLOSE_TIMEOUT = 3
77+ _REGISTER_BROADCASTS = 3
7778
7879
7980class AsyncEngine :
@@ -478,6 +479,27 @@ def register_service(
478479 allow_name_change : bool = False ,
479480 cooperating_responders : bool = False ,
480481 ) -> None :
482+ """Registers service information to the network with a default TTL.
483+ Zeroconf will then respond to requests for information for that
484+ service. The name of the service may be changed if needed to make
485+ it unique on the network. Additionally multiple cooperating responders
486+ can register the same service on the network for resilience
487+ (if you want this behavior set `cooperating_responders` to `True`)."""
488+ assert self .loop is not None
489+ asyncio .run_coroutine_threadsafe (
490+ await_awaitable (
491+ self .async_register_service (info , ttl , allow_name_change , cooperating_responders )
492+ ),
493+ self .loop ,
494+ ).result (millis_to_seconds (_REGISTER_TIME * _REGISTER_BROADCASTS ) + _LOADED_SYSTEM_TIMEOUT )
495+
496+ async def async_register_service (
497+ self ,
498+ info : ServiceInfo ,
499+ ttl : Optional [int ] = None ,
500+ allow_name_change : bool = False ,
501+ cooperating_responders : bool = False ,
502+ ) -> Awaitable :
481503 """Registers service information to the network with a default TTL.
482504 Zeroconf will then respond to requests for information for that
483505 service. The name of the service may be changed if needed to make
@@ -489,47 +511,41 @@ def register_service(
489511 # Setting TTLs via ServiceInfo is preferred
490512 info .host_ttl = ttl
491513 info .other_ttl = ttl
492- self .check_service (info , allow_name_change , cooperating_responders )
514+
515+ await self .async_wait_for_start ()
516+ await self .async_check_service (info , allow_name_change , cooperating_responders )
493517 self .registry .add (info )
494- self ._broadcast_service (info , _REGISTER_TIME , None )
518+ return asyncio . ensure_future ( self ._async_broadcast_service (info , _REGISTER_TIME , None ) )
495519
496520 def update_service (self , info : ServiceInfo ) -> None :
497521 """Registers service information to the network with a default TTL.
498522 Zeroconf will then respond to requests for information for that
499523 service."""
524+ assert self .loop is not None
525+ asyncio .run_coroutine_threadsafe (await_awaitable (self .async_update_service (info )), self .loop ).result (
526+ millis_to_seconds (_REGISTER_TIME * _REGISTER_BROADCASTS ) + _LOADED_SYSTEM_TIMEOUT
527+ )
500528
529+ async def async_update_service (self , info : ServiceInfo ) -> Awaitable :
530+ """Registers service information to the network with a default TTL.
531+ Zeroconf will then respond to requests for information for that
532+ service."""
501533 self .registry .update (info )
502- self ._broadcast_service (info , _REGISTER_TIME , None )
534+ return asyncio . ensure_future ( self ._async_broadcast_service (info , _REGISTER_TIME , None ) )
503535
504- def _broadcast_service (self , info : ServiceInfo , interval : int , ttl : Optional [int ]) -> None :
536+ async def _async_broadcast_service (self , info : ServiceInfo , interval : int , ttl : Optional [int ]) -> None :
505537 """Send a broadcasts to announce a service at intervals."""
506- now = current_time_millis ()
507- next_time = now
508- i = 0
509- while i < 3 :
510- if now < next_time :
511- self .wait (next_time - now )
512- now = current_time_millis ()
513- continue
514-
515- self .send_service_broadcast (info , ttl )
516- i += 1
517- next_time += interval
518-
519- def send_service_broadcast (self , info : ServiceInfo , ttl : Optional [int ]) -> None :
520- """Send a broadcast to announce a service."""
521- self .send (self .generate_service_broadcast (info , ttl ))
538+ for i in range (_REGISTER_BROADCASTS ):
539+ if i != 0 :
540+ await asyncio .sleep (millis_to_seconds (interval ))
541+ self .async_send (self .generate_service_broadcast (info , ttl ))
522542
523543 def generate_service_broadcast (self , info : ServiceInfo , ttl : Optional [int ]) -> DNSOutgoing :
524544 """Generate a broadcast to announce a service."""
525545 out = DNSOutgoing (_FLAGS_QR_RESPONSE | _FLAGS_AA )
526546 self ._add_broadcast_answer (out , info , ttl )
527547 return out
528548
529- def send_service_query (self , info : ServiceInfo ) -> None :
530- """Send a query to lookup a service."""
531- self .send (self .generate_service_query (info ))
532-
533549 def generate_service_query (self , info : ServiceInfo ) -> DNSOutgoing : # pylint: disable=no-self-use
534550 """Generate a query to lookup a service."""
535551 out = DNSOutgoing (_FLAGS_QR_QUERY | _FLAGS_AA )
@@ -559,9 +575,16 @@ def _add_broadcast_answer( # pylint: disable=no-self-use
559575 out .add_answer_at_time (dns_address , 0 )
560576
561577 def unregister_service (self , info : ServiceInfo ) -> None :
578+ """Unregister a service."""
579+ assert self .loop is not None
580+ asyncio .run_coroutine_threadsafe (
581+ await_awaitable (self .async_unregister_service (info )), self .loop
582+ ).result (millis_to_seconds (_UNREGISTER_TIME * _REGISTER_BROADCASTS ) + _LOADED_SYSTEM_TIMEOUT )
583+
584+ async def async_unregister_service (self , info : ServiceInfo ) -> Awaitable :
562585 """Unregister a service."""
563586 self .registry .remove (info )
564- self ._broadcast_service (info , _UNREGISTER_TIME , 0 )
587+ return asyncio . ensure_future ( self ._async_broadcast_service (info , _UNREGISTER_TIME , 0 ) )
565588
566589 def generate_unregister_all_services (self ) -> Optional [DNSOutgoing ]:
567590 """Generate a DNSOutgoing goodbye for all services and remove them from the registry."""
@@ -574,6 +597,22 @@ def generate_unregister_all_services(self) -> Optional[DNSOutgoing]:
574597 self .registry .remove (service_infos )
575598 return out
576599
600+ async def async_unregister_all_services (self ) -> None :
601+ """Unregister all registered services.
602+
603+ Unlike async_register_service and async_unregister_service, this
604+ method does not return a future and is always expected to be
605+ awaited since its only called at shutdown.
606+ """
607+ # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
608+ out = self .generate_unregister_all_services ()
609+ if not out :
610+ return
611+ for i in range (_REGISTER_BROADCASTS ):
612+ if i != 0 :
613+ await asyncio .sleep (millis_to_seconds (_UNREGISTER_TIME ))
614+ self .async_send (out )
615+
577616 def unregister_all_services (self ) -> None :
578617 """Unregister all registered services."""
579618 # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
@@ -592,7 +631,7 @@ def unregister_all_services(self) -> None:
592631 i += 1
593632 next_time += _UNREGISTER_TIME
594633
595- def check_service (
634+ async def async_check_service (
596635 self , info : ServiceInfo , allow_name_change : bool , cooperating_responders : bool = False
597636 ) -> None :
598637 """Checks the network for a unique service name, modifying the
@@ -603,7 +642,7 @@ def check_service(
603642 next_instance_number = 2
604643 next_time = now = current_time_millis ()
605644 i = 0
606- while i < 3 :
645+ while i < _REGISTER_BROADCASTS :
607646 # check for a name conflict
608647 while self .cache .current_entry_with_name_and_alias (info .type , info .name ):
609648 if not allow_name_change :
@@ -617,11 +656,11 @@ def check_service(
617656 i = 0
618657
619658 if now < next_time :
620- self .wait (next_time - now )
659+ await self .async_wait (next_time - now )
621660 now = current_time_millis ()
622661 continue
623662
624- self .send_service_query ( info )
663+ self .async_send ( self . generate_service_query ( info ) )
625664 i += 1
626665 next_time += _CHECK_TIME
627666
0 commit comments