Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/zeroconf/_listener.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import cython

from ._handlers.record_manager cimport RecordManager
from ._protocol.incoming cimport DNSIncoming
from ._services.registry cimport ServiceRegistry
from ._utils.time cimport current_time_millis, millis_to_seconds


Expand All @@ -18,6 +19,7 @@ cdef cython.uint _DUPLICATE_PACKET_SUPPRESSION_INTERVAL
cdef class AsyncListener:

cdef public object zc
cdef ServiceRegistry _registry
cdef RecordManager _record_manager
cdef public cython.bytes data
cdef public cython.float last_time
Expand All @@ -34,3 +36,12 @@ cdef class AsyncListener:
cpdef _process_datagram_at_time(self, bint debug, cython.uint data_len, cython.float now, bytes data, cython.tuple addrs)

cdef _cancel_any_timers_for_addr(self, object addr)

cpdef handle_query_or_defer(
self,
DNSIncoming msg,
object addr,
object port,
object transport,
tuple v6_flow_scope
)
12 changes: 9 additions & 3 deletions src/zeroconf/_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class AsyncListener:

__slots__ = (
'zc',
'_registry',
'_record_manager',
'data',
'last_time',
Expand All @@ -69,6 +70,7 @@ class AsyncListener:

def __init__(self, zc: 'Zeroconf') -> None:
self.zc = zc
self._registry = zc.registry
self._record_manager = zc.record_manager
self.data: Optional[bytes] = None
self.last_time: float = 0
Expand Down Expand Up @@ -171,17 +173,21 @@ def _process_datagram_at_time(
self._record_manager.async_updates_from_response(msg)
return

if not self._registry.has_entries:
# If the registry is empty, we have no answers to give.
return

if TYPE_CHECKING:
assert self.transport is not None
self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope)

def handle_query_or_defer(
self,
msg: DNSIncoming,
addr: str,
port: int,
addr: _str,
port: _int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
"""Deal with incoming query packets. Provides a response if
possible."""
Expand Down
5 changes: 5 additions & 0 deletions src/zeroconf/_services/registry.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ cdef class ServiceRegistry:
cdef cython.dict _services
cdef public cython.dict types
cdef public cython.dict servers
cdef public bint has_entries

@cython.locals(
record_list=cython.list,
Expand All @@ -17,6 +18,10 @@ cdef class ServiceRegistry:

cdef _add(self, ServiceInfo info)

@cython.locals(
info=ServiceInfo,
old_service_info=ServiceInfo
)
cdef _remove(self, cython.list infos)

cpdef ServiceInfo async_get_info_name(self, str name)
Expand Down
10 changes: 7 additions & 3 deletions src/zeroconf/_services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ServiceRegistry:
the event loop as it is not thread safe.
"""

__slots__ = ("_services", "types", "servers")
__slots__ = ("_services", "types", "servers", "has_entries")

def __init__(
self,
Expand All @@ -44,6 +44,7 @@ def __init__(
self._services: Dict[str, ServiceInfo] = {}
self.types: Dict[str, List] = {}
self.servers: Dict[str, List] = {}
self.has_entries: bool = False

def async_add(self, info: ServiceInfo) -> None:
"""Add a new service to the registry."""
Expand Down Expand Up @@ -95,14 +96,17 @@ def _add(self, info: ServiceInfo) -> None:
self._services[info.key] = info
self.types.setdefault(info.type.lower(), []).append(info.key)
self.servers.setdefault(info.server_key, []).append(info.key)
self.has_entries = True

def _remove(self, infos: List[ServiceInfo]) -> None:
"""Remove a services under the lock."""
for info in infos:
if info.key not in self._services:
old_service_info = self._services.get(info.key)
if old_service_info is None:
continue
old_service_info = self._services[info.key]
assert old_service_info.server_key is not None
self.types[old_service_info.type.lower()].remove(info.key)
self.servers[old_service_info.server_key].remove(info.key)
del self._services[info.key]

self.has_entries = bool(self._services)
39 changes: 24 additions & 15 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
import time
import unittest
import unittest.mock
from typing import cast
from unittest.mock import patch
from typing import Tuple, Union, cast
from unittest.mock import Mock, patch

if sys.version_info[:3][1] < 8:
from unittest.mock import Mock

AsyncMock = Mock
else:
from unittest.mock import AsyncMock
Expand All @@ -26,6 +24,8 @@

import zeroconf as r
from zeroconf import NotRunningException, Zeroconf, const, current_time_millis
from zeroconf._listener import AsyncListener, _WrappedTransport
from zeroconf._protocol.incoming import DNSIncoming
from zeroconf.asyncio import AsyncZeroconf

from . import _clear_cache, _inject_response, _wait_for_start, has_working_ipv6
Expand All @@ -45,10 +45,19 @@ def teardown_module():
log.setLevel(original_logging_level)


def threadsafe_query(zc, protocol, *args):
def threadsafe_query(
zc: 'Zeroconf',
protocol: 'AsyncListener',
msg: DNSIncoming,
addr: str,
port: int,
transport: _WrappedTransport,
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
) -> None:
async def make_query():
protocol.handle_query_or_defer(*args)
protocol.handle_query_or_defer(msg, addr, port, transport, v6_flow_scope)

assert zc.loop is not None
asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result()


Expand Down Expand Up @@ -476,28 +485,28 @@ def test_tc_bit_defers():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert source_ip not in protocol._deferred
assert source_ip not in protocol._timers

Expand Down Expand Up @@ -555,20 +564,20 @@ def test_tc_bit_defers_last_response_missing():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer1 = protocol._timers[source_ip]

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer2 = protocol._timers[source_ip]
assert timer1.cancelled()
assert timer2 != timer1

# Send the same packet again to similar multi interfaces
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer3 = protocol._timers[source_ip]
Expand All @@ -577,7 +586,7 @@ def test_tc_bit_defers_last_response_missing():

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer4 = protocol._timers[source_ip]
Expand Down
12 changes: 11 additions & 1 deletion tests/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from unittest.mock import MagicMock, patch

import zeroconf as r
from zeroconf import Zeroconf, _engine, _listener, const, current_time_millis
from zeroconf import (
ServiceInfo,
Zeroconf,
_engine,
_listener,
const,
current_time_millis,
)
from zeroconf._protocol import outgoing
from zeroconf._protocol.incoming import DNSIncoming

Expand Down Expand Up @@ -125,6 +132,9 @@ def test_guard_against_duplicate_packets():
These packets can quickly overwhelm the system.
"""
zc = Zeroconf(interfaces=['127.0.0.1'])
zc.registry.async_add(
ServiceInfo("_http._tcp.local.", "Test._http._tcp.local.", server="Test._http._tcp.local.", port=4)
)
zc.question_history = QuestionHistoryWithoutSuppression()

class SubListener(_listener.AsyncListener):
Expand Down