Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 89 additions & 8 deletions tests/services/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

""" Unit tests for zeroconf._services.browser. """

import asyncio
import logging
import socket
import time
Expand Down Expand Up @@ -476,13 +477,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
expected_query_time = 0.0
while True:
sleep_count += 1
for _ in range(2):
# If the browser thread is starting up
# its possible we notify before the initial sleep
# which means the test will fail so we need to d
# this twice to eliminate the race condition
zeroconf_browser.notify_all()
got_query.wait(0.05)
got_query.wait(0.1)
if time_offset == expected_query_time:
assert got_query.is_set()
got_query.clear()
Expand All @@ -501,6 +496,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
else:
assert not got_query.is_set()
time_offset += initial_query_interval
zeroconf_browser.loop.call_soon_threadsafe(browser.query_scheduler.set_schedule_changed)

finally:
browser.cancel()
Expand Down Expand Up @@ -726,7 +722,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
while nbr_answers < test_iterations:
# Increase simulated time shift by 1/4 of the TTL in seconds
time_offset += expected_ttl / 4
zeroconf_browser.notify_all()
zeroconf_browser.loop.call_soon_threadsafe(browser.query_scheduler.set_schedule_changed)
sleep_count += 1
got_query.wait(0.5)
# Prevent the test running indefinitely in an error condition
Expand Down Expand Up @@ -1067,3 +1063,88 @@ async def test_generate_service_query_suppress_duplicate_questions():
assert not outs

await aiozc.async_close()


@pytest.mark.asyncio
async def test_query_scheduler():
delay = const._BROWSER_TIME
types_ = set(["_hap._tcp.local.", "_http._tcp.local."])
query_scheduler = _services_browser.QueryScheduler(types_, delay, (0, 0))

now = current_time_millis()
query_scheduler.start(now)

# Test query interval is increasing
assert query_scheduler.millis_to_wait(now - 1) == 1
assert query_scheduler.millis_to_wait(now) is None
assert query_scheduler.millis_to_wait(now + 1) is None

assert set(query_scheduler.process_ready_types(now)) == types_
assert set(query_scheduler.process_ready_types(now)) == set()
assert query_scheduler.millis_to_wait(now) == delay

assert set(query_scheduler.process_ready_types(now + delay)) == types_
assert set(query_scheduler.process_ready_types(now + delay)) == set()
assert query_scheduler.millis_to_wait(now) == delay * 3

assert set(query_scheduler.process_ready_types(now + delay * 3)) == types_
assert set(query_scheduler.process_ready_types(now + delay * 3)) == set()
assert query_scheduler.millis_to_wait(now) == delay * 7

assert set(query_scheduler.process_ready_types(now + delay * 7)) == types_
assert set(query_scheduler.process_ready_types(now + delay * 7)) == set()
assert query_scheduler.millis_to_wait(now) == delay * 15

assert set(query_scheduler.process_ready_types(now + delay * 15)) == types_
assert set(query_scheduler.process_ready_types(now + delay * 15)) == set()

# Test if we reschedule 1 second later, the millis_to_wait goes up by 1
query_scheduler.reschedule_type("_hap._tcp.local.", now + delay * 16)
assert query_scheduler.millis_to_wait(now) == delay * 16

assert set(query_scheduler.process_ready_types(now + delay * 15)) == set()

# Test if we reschedule 1 second later... and its ready for processing
assert set(query_scheduler.process_ready_types(now + delay * 16)) == set(["_hap._tcp.local."])
assert query_scheduler.millis_to_wait(now) == delay * 31
assert set(query_scheduler.process_ready_types(now + delay * 20)) == set()

assert set(query_scheduler.process_ready_types(now + delay * 31)) == set(["_http._tcp.local."])


@pytest.mark.asyncio
async def test_query_scheduler_triggers_async_wait_ready_on_reschedule():
"""Test that a reschedule wakes up the async_wait_ready."""
delay = const._BROWSER_TIME
types_ = set(["_hap._tcp.local.", "_http._tcp.local."])
query_scheduler = _services_browser.QueryScheduler(types_, delay, (0, 0))

now = current_time_millis()
query_scheduler.start(now)
assert set(query_scheduler.process_ready_types(now)) == types_
assert query_scheduler.millis_to_wait(now) == delay

task = asyncio.ensure_future(query_scheduler.async_wait_ready(now))
await asyncio.sleep(0) # Start the task
await asyncio.sleep(0) # Make sure its waiting
assert not task.done()
assert query_scheduler.millis_to_wait(now + 1) == delay - 1
query_scheduler.reschedule_type("_hap._tcp.local.", now + 1)
assert query_scheduler.millis_to_wait(now + 1) is None
await asyncio.wait_for(task, timeout=0.1)
assert task.done()

task2 = asyncio.ensure_future(query_scheduler.async_wait_ready(now + 10000))
assert set(query_scheduler.process_ready_types(now + 1)) == set(["_hap._tcp.local."])
assert not task2.done()
assert query_scheduler.millis_to_wait(now + 2) == delay - 2
query_scheduler.reschedule_type("_hap._tcp.local.", now + 2)
assert query_scheduler.millis_to_wait(now + 2) is None
await asyncio.wait_for(task2, timeout=0.1)
assert task2.done()
assert set(query_scheduler.process_ready_types(now + 10000)) == types_
assert query_scheduler.millis_to_wait(now + 10000) == delay * 2

task3 = asyncio.ensure_future(query_scheduler.async_wait_ready(now + 10000))
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(task3, timeout=0.1)
145 changes: 98 additions & 47 deletions zeroconf/_services/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
SignalRegistrationInterface,
)
from .._updates import RecordUpdate, RecordUpdateListener
from .._utils.aio import get_best_available_queue
from .._utils.aio import get_best_available_queue, wait_event_or_timeout
from .._utils.name import service_type_name
from .._utils.time import current_time_millis, millis_to_seconds
from ..const import (
Expand Down Expand Up @@ -183,6 +183,89 @@ def on_change(
return on_change


class QueryScheduler:
"""Schedule outgoing PTR queries for Continuous Multicast DNS Querying

https://datatracker.ietf.org/doc/html/rfc6762#section-5.2

"""

def __init__(
self,
types: Set[str],
delay: int,
first_random_delay_interval: Tuple[int, int],
):
self._schedule_changed_event: Optional[asyncio.Event] = None
self._types = types
self._next_time: Dict[str, float] = {}
self._first_random_delay_interval = first_random_delay_interval
self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self._types}

def start(self, now: float) -> None:
"""Start the scheduler."""
self._schedule_changed_event = asyncio.Event()
self._generate_first_next_time(now)

def _generate_first_next_time(self, now: float) -> None:
"""Generate the initial next query times.

https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
To avoid accidental synchronization when, for some reason, multiple
clients begin querying at exactly the same moment (e.g., because of
some common external trigger event), a Multicast DNS querier SHOULD
also delay the first query of the series by a randomly chosen amount
in the range 20-120 ms.
"""
delay = millis_to_seconds(random.randint(*self._first_random_delay_interval))
next_time = now + delay
self._next_time = {check_type_: next_time for check_type_ in self._types}

def millis_to_wait(self, now: float) -> Optional[float]:
"""Returns the number of milliseconds to wait for the next event."""
# Wait for the type has the smallest next time
next_time = min(self._next_time.values())
return None if next_time <= now else next_time - now

def reschedule_type(self, type_: str, next_time: float) -> None:
"""Reschedule the query for a type to happen sooner."""
if next_time >= self._next_time[type_]:
return

self._next_time[type_] = next_time
self.set_schedule_changed()

def set_schedule_changed(self) -> None:
"""Set the event to unblock async_wait_ready to make sure the adjusted next time is seen."""
assert self._schedule_changed_event is not None
self._schedule_changed_event.set()
self._schedule_changed_event.clear()

def process_ready_types(self, now: float) -> List[str]:
"""Generate a list of ready types that is due and schedule the next time."""
if self.millis_to_wait(now):
return []

ready_types: List[str] = []

for type_, due in self._next_time.items():
if due > now:
continue

ready_types.append(type_)
self._next_time[type_] = now + self._delay[type_]
self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2)

return ready_types

async def async_wait_ready(self, now: float) -> None:
"""Wait for at least one query to be ready."""
timeout = self.millis_to_wait(now)
if timeout:
assert self._schedule_changed_event is not None
await wait_event_or_timeout(self._schedule_changed_event, timeout=millis_to_seconds(timeout))


class _ServiceBrowserBase(RecordUpdateListener):
"""Base class for ServiceBrowser."""

Expand Down Expand Up @@ -225,10 +308,9 @@ def __init__(
self.port = port
self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
self.question_type = question_type
self._next_time: Dict[str, float] = {}
self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self.types}
self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict()
self._service_state_changed = Signal()
self.query_scheduler = QueryScheduler(self.types, delay, _FIRST_QUERY_DELAY_RANDOM_INTERVAL)
self.queue: Optional[queue.Queue] = None
self.done = False

Expand All @@ -250,25 +332,11 @@ def _async_start(self) -> None:
Must be called by uses of this base class after they
have finished setting their properties.
"""
self._generate_first_next_time()
self.query_scheduler.start(current_time_millis())
self.zc.async_add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types])
# Only start queries after the listener is installed
self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task()))

def _generate_first_next_time(self) -> None:
"""Generate the initial next query times.

https://datatracker.ietf.org/doc/html/rfc6762#section-5.2
To avoid accidental synchronization when, for some reason, multiple
clients begin querying at exactly the same moment (e.g., because of
some common external trigger event), a Multicast DNS querier SHOULD
also delay the first query of the series by a randomly chosen amount
in the range 20-120 ms.
"""
delay = millis_to_seconds(random.randint(*_FIRST_QUERY_DELAY_RANDOM_INTERVAL))
next_time = current_time_millis() + delay
self._next_time = {check_type_: next_time for check_type_ in self.types}

@property
def service_state_changed(self) -> SignalRegistrationInterface:
return self._service_state_changed.registration_interface
Expand Down Expand Up @@ -310,9 +378,9 @@ def _async_process_record_update(
elif expired:
self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
else:
expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
if expires < self._next_time[record.name]:
self._next_time[record.name] = expires
self.query_scheduler.reschedule_type(
record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
)
return

# If its expired or already exists in the cache it cannot be updated.
Expand Down Expand Up @@ -385,47 +453,30 @@ def _async_cancel(self) -> None:
def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]:
"""Generate the service browser query for any type that is due."""
now = current_time_millis()
if self._millis_to_wait(current_time_millis()):
ready_types = self.query_scheduler.process_ready_types(now)
if not ready_types:
return []

ready_types = []

for type_, due in self._next_time.items():
if due > now:
continue

ready_types.append(type_)
self._next_time[type_] = now + self._delay[type_]
self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2)

# If they did not specify and this is the first request, ask QU questions
# https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 since we are
# just starting up and we know our cache is likely empty. This ensures
# the next outgoing will be sent with the known answers list.
question_type = DNSQuestionType.QU if not self.question_type and first_request else self.question_type

return generate_service_query(self.zc, now, ready_types, self.multicast, question_type)

def _millis_to_wait(self, now: float) -> Optional[float]:
"""Returns the number of milliseconds to wait for the next event."""
# Wait for the type has the smallest next time
next_time = min(self._next_time.values())
return None if next_time <= now else next_time - now

async def async_browser_task(self) -> None:
"""Run the browser task."""
await self.zc.async_wait_for_start()
first_request = True
while True:
timeout = self._millis_to_wait(current_time_millis())
if timeout:
await self.zc.async_wait(timeout)

await self.query_scheduler.async_wait_ready(current_time_millis())
outs = self._generate_ready_queries(first_request)
if outs:
first_request = False
for out in outs:
self.zc.async_send(out, addr=self.addr, port=self.port)
if not outs:
continue

first_request = False
for out in outs:
self.zc.async_send(out, addr=self.addr, port=self.port)

async def _async_cancel_browser(self) -> None:
"""Cancel the browser."""
Expand Down