Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions zeroconf/_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,21 @@ def _group_ptr_queries_with_known_answers(
return [query_bucket.out for query_bucket in query_buckets]


def generate_service_query(
zc: 'Zeroconf', now: float, types_: List[str], multicast: bool = True
) -> List[DNSOutgoing]:
"""Generate a service query for sending with zeroconf.send."""
questions_with_known_answers: _QuestionWithKnownAnswers = {}
for type_ in types_:
question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)
questions_with_known_answers[question] = set(
cast(DNSPointer, record)
for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN)
if not record.is_stale(now)
)
return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers)


def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]:
"""Generate a service_state_changed handlers from a listener."""

Expand Down Expand Up @@ -271,7 +286,6 @@ def __init__(
self.addr = addr
self.port = port
self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
self._services: Dict[str, Dict[str, DNSPointer]] = {check_type_: {} for check_type_ in self.types}
current_time = current_time_millis()
self._next_time = {check_type_: current_time for check_type_ in self.types}
self._delay = {check_type_: delay for check_type_ in self.types}
Expand Down Expand Up @@ -327,17 +341,14 @@ def _async_process_record_update(self, now: float, record: DNSRecord) -> None:
if isinstance(record, DNSPointer):
if record.name not in self.types:
return
service_key = record.alias.lower()
services_by_type = self._services[record.name]
old_record = services_by_type.get(service_key)
old_record = self.zc.cache.async_get_unique(
DNSPointer(record.name, _TYPE_PTR, _CLASS_IN, 0, record.alias)
)
if old_record is None:
services_by_type[service_key] = record
self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
elif expired:
del services_by_type[service_key]
self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
else:
old_record.reset_ttl(record)
expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
if expires < self._next_time[record.name]:
self._next_time[record.name] = expires
Expand Down Expand Up @@ -407,18 +418,17 @@ def generate_ready_queries(self) -> List[DNSOutgoing]:
if min(self._next_time.values()) > now:
return []

questions_with_known_answers: _QuestionWithKnownAnswers = {}
ready_types = []

for type_, due in self._next_time.items():
if due > now:
continue
questions_with_known_answers[DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)] = set(
record for record in self._services[type_].values() if not record.is_stale(now)
)

ready_types.append(type_)
self._next_time[type_] = now + self._delay[type_]
self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2)

return _group_ptr_queries_with_known_answers(now, self.multicast, questions_with_known_answers)
return generate_service_query(self.zc, now, ready_types, self.multicast)

def _seconds_to_wait(self) -> Optional[float]:
"""Returns the number of seconds to wait for the next event."""
Expand Down