Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/zeroconf/_handlers/query_handler.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ from .._cache cimport DNSCache
from .._dns cimport DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
from .._history cimport QuestionHistory
from .._protocol.incoming cimport DNSIncoming
from .._services.info cimport ServiceInfo
from .._services.registry cimport ServiceRegistry


cdef object TYPE_CHECKING, QuestionAnswers
cdef cython.uint _ONE_SECOND, _TYPE_PTR, _TYPE_ANY, _TYPE_A, _TYPE_AAAA, _TYPE_SRV, _TYPE_TXT
cdef str _SERVICE_TYPE_ENUMERATION_NAME
cdef cython.set _RESPOND_IMMEDIATE_TYPES
cdef cython.set _ADDRESS_RECORD_TYPES
cdef object IPVersion
cdef object _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL

cdef class _QueryResponse:

Expand Down Expand Up @@ -45,13 +49,16 @@ cdef class QueryHandler:
cdef DNSCache cache
cdef QuestionHistory question_history

@cython.locals(service=ServiceInfo)
cdef _add_service_type_enumeration_query_answers(self, cython.dict answer_set, DNSRRSet known_answers)

@cython.locals(service=ServiceInfo)
cdef _add_pointer_answers(self, str lower_name, cython.dict answer_set, DNSRRSet known_answers)

@cython.locals(service=ServiceInfo)
cdef _add_address_answers(self, str lower_name, cython.dict answer_set, DNSRRSet known_answers, cython.uint type_)

@cython.locals(question_lower_name=str, type_=cython.uint)
@cython.locals(question_lower_name=str, type_=cython.uint, service=ServiceInfo)
cdef _answer_question(self, DNSQuestion question, DNSRRSet known_answers)

@cython.locals(
Expand Down
21 changes: 11 additions & 10 deletions src/zeroconf/_handlers/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .._history import QuestionHistory
from .._protocol.incoming import DNSIncoming
from .._services.registry import ServiceRegistry
from .._utils.net import IPVersion
from ..const import (
_ADDRESS_RECORD_TYPES,
_CLASS_IN,
Expand Down Expand Up @@ -180,13 +181,13 @@ def _add_pointer_answers(
for service in self.registry.async_get_infos_type(lower_name):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
dns_pointer = service.dns_pointer()
dns_pointer = service._dns_pointer(None)
if known_answers.suppresses(dns_pointer):
continue
answer_set[dns_pointer] = {
service.dns_service(),
service.dns_text(),
} | service.get_address_and_nsec_records()
service._dns_service(None),
service._dns_text(None),
} | service._get_address_and_nsec_records(None)

def _add_address_answers(
self,
Expand All @@ -200,7 +201,7 @@ def _add_address_answers(
answers: List[DNSAddress] = []
additionals: Set[DNSRecord] = set()
seen_types: Set[int] = set()
for dns_address in service.dns_addresses():
for dns_address in service._dns_addresses(None, IPVersion.All):
seen_types.add(dns_address.type)
if dns_address.type != type_:
additionals.add(dns_address)
Expand All @@ -210,12 +211,12 @@ def _add_address_answers(
if answers:
if missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
additionals.add(service.dns_nsec(list(missing_types)))
additionals.add(service._dns_nsec(list(missing_types), None))
for answer in answers:
answer_set[answer] = additionals
elif type_ in missing_types:
assert service.server is not None, "Service server must be set for NSEC record."
answer_set[service.dns_nsec(list(missing_types))] = set()
answer_set[service._dns_nsec(list(missing_types), None)] = set()

def _answer_question(
self,
Expand Down Expand Up @@ -243,11 +244,11 @@ def _answer_question(
if type_ in (_TYPE_SRV, _TYPE_ANY):
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.2.
dns_service = service.dns_service()
dns_service = service._dns_service(None)
if not known_answers.suppresses(dns_service):
answer_set[dns_service] = service.get_address_and_nsec_records()
answer_set[dns_service] = service._get_address_and_nsec_records(None)
if type_ in (_TYPE_TXT, _TYPE_ANY):
dns_text = service.dns_text()
dns_text = service._dns_text(None)
if not known_answers.suppresses(dns_text):
answer_set[dns_text] = set()

Expand Down
21 changes: 20 additions & 1 deletion src/zeroconf/_services/info.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import cython

from .._cache cimport DNSCache
from .._dns cimport DNSPointer, DNSRecord, DNSService, DNSText
from .._dns cimport DNSNsec, DNSPointer, DNSRecord, DNSService, DNSText
from .._protocol.outgoing cimport DNSOutgoing
from .._updates cimport RecordUpdateListener
from .._utils.time cimport current_time_millis
Expand All @@ -27,6 +27,8 @@ cdef object DNS_QUESTION_TYPE_QM
cdef object _IPVersion_All_value
cdef object _IPVersion_V4Only_value

cdef cython.set _ADDRESS_RECORD_TYPES

cdef object TYPE_CHECKING

cdef class ServiceInfo(RecordUpdateListener):
Expand Down Expand Up @@ -85,3 +87,20 @@ cdef class ServiceInfo(RecordUpdateListener):
cdef cython.list _ip_addresses_by_version_value(self, object version_value)

cdef addresses_by_version(self, object version)

@cython.locals(cacheable=cython.bint)
cdef cython.list _dns_addresses(self, object override_ttls, object version)

@cython.locals(cacheable=cython.bint)
cdef DNSPointer _dns_pointer(self, object override_ttl)

@cython.locals(cacheable=cython.bint)
cdef DNSService _dns_service(self, object override_ttl)

@cython.locals(cacheable=cython.bint)
cdef DNSText _dns_text(self, object override_ttl)

cdef DNSNsec _dns_nsec(self, cython.list missing_types, object override_ttl)

@cython.locals(cacheable=cython.bint)
cdef cython.set _get_address_and_nsec_records(self, object override_ttl)
32 changes: 30 additions & 2 deletions src/zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,14 @@ def dns_addresses(
self,
override_ttl: Optional[int] = None,
version: IPVersion = IPVersion.All,
) -> List[DNSAddress]:
"""Return matching DNSAddress from ServiceInfo."""
return self._dns_addresses(override_ttl, version)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have preferred not to add the _ functions but I don't want this to be a breaking change


def _dns_addresses(
self,
override_ttl: Optional[int],
version: IPVersion,
) -> List[DNSAddress]:
"""Return matching DNSAddress from ServiceInfo."""
cacheable = version is IPVersion.All and override_ttl is None
Expand All @@ -544,6 +552,10 @@ def dns_addresses(
return records

def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer:
"""Return DNSPointer from ServiceInfo."""
return self._dns_pointer(override_ttl)

def _dns_pointer(self, override_ttl: Optional[int]) -> DNSPointer:
"""Return DNSPointer from ServiceInfo."""
cacheable = override_ttl is None
if self._dns_pointer_cache is not None and cacheable:
Expand All @@ -561,6 +573,10 @@ def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer:
return record

def dns_service(self, override_ttl: Optional[int] = None) -> DNSService:
"""Return DNSService from ServiceInfo."""
return self._dns_service(override_ttl)

def _dns_service(self, override_ttl: Optional[int]) -> DNSService:
"""Return DNSService from ServiceInfo."""
cacheable = override_ttl is None
if self._dns_service_cache is not None and cacheable:
Expand All @@ -584,6 +600,10 @@ def dns_service(self, override_ttl: Optional[int] = None) -> DNSService:
return record

def dns_text(self, override_ttl: Optional[int] = None) -> DNSText:
"""Return DNSText from ServiceInfo."""
return self._dns_text(override_ttl)

def _dns_text(self, override_ttl: Optional[int]) -> DNSText:
"""Return DNSText from ServiceInfo."""
cacheable = override_ttl is None
if self._dns_text_cache is not None and cacheable:
Expand All @@ -601,6 +621,10 @@ def dns_text(self, override_ttl: Optional[int] = None) -> DNSText:
return record

def dns_nsec(self, missing_types: List[int], override_ttl: Optional[int] = None) -> DNSNsec:
"""Return DNSNsec from ServiceInfo."""
return self._dns_nsec(missing_types, override_ttl)

def _dns_nsec(self, missing_types: List[int], override_ttl: Optional[int]) -> DNSNsec:
"""Return DNSNsec from ServiceInfo."""
return DNSNsec(
self._name,
Expand All @@ -613,18 +637,22 @@ def dns_nsec(self, missing_types: List[int], override_ttl: Optional[int] = None)
)

def get_address_and_nsec_records(self, override_ttl: Optional[int] = None) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
return self._get_address_and_nsec_records(override_ttl)

def _get_address_and_nsec_records(self, override_ttl: Optional[int]) -> Set[DNSRecord]:
"""Build a set of address records and NSEC records for non-present record types."""
cacheable = override_ttl is None
if self._get_address_and_nsec_records_cache is not None and cacheable:
return self._get_address_and_nsec_records_cache
missing_types: Set[int] = _ADDRESS_RECORD_TYPES.copy()
records: Set[DNSRecord] = set()
for dns_address in self.dns_addresses(override_ttl, IPVersion.All):
for dns_address in self._dns_addresses(override_ttl, IPVersion.All):
missing_types.discard(dns_address.type)
records.add(dns_address)
if missing_types:
assert self.server is not None, "Service server must be set for NSEC record."
records.add(self.dns_nsec(list(missing_types), override_ttl))
records.add(self._dns_nsec(list(missing_types), override_ttl))
if cacheable:
self._get_address_and_nsec_records_cache = records
return records
Expand Down
14 changes: 8 additions & 6 deletions src/zeroconf/_services/registry.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@

import cython

from .info cimport ServiceInfo


cdef class ServiceRegistry:

Expand All @@ -11,16 +13,16 @@ cdef class ServiceRegistry:
@cython.locals(
record_list=cython.list,
)
cdef _async_get_by_index(self, cython.dict records, str key)
cdef cython.list _async_get_by_index(self, cython.dict records, str key)

cdef _add(self, object info)
cdef _add(self, ServiceInfo info)

cdef _remove(self, cython.list infos)

cpdef async_get_info_name(self, str name)
cpdef ServiceInfo async_get_info_name(self, str name)

cpdef async_get_types(self)
cpdef cython.list async_get_types(self)

cpdef async_get_infos_type(self, str type_)
cpdef cython.list async_get_infos_type(self, str type_)

cpdef async_get_infos_server(self, str server)
cpdef cython.list async_get_infos_server(self, str server)
15 changes: 15 additions & 0 deletions tests/services/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,3 +1535,18 @@ async def test_release_wait_when_new_recorded_added_concurrency():
assert not pending
assert info.addresses == [b'\x7f\x00\x00\x01']
await aiozc.async_close()


@pytest.mark.asyncio
async def test_service_info_nsec_records():
"""Test we can generate nsec records from ServiceInfo."""
type_ = "_http._tcp.local."
registration_name = "multiareccon.%s" % type_
desc = {'path': '/~paulsm/'}
host = "multahostcon.local."
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host)
nsec_record = info.dns_nsec([const._TYPE_A, const._TYPE_AAAA], 50)
assert nsec_record.name == registration_name
assert nsec_record.type == const._TYPE_NSEC
assert nsec_record.ttl == 50
assert nsec_record.rdtypes == [const._TYPE_A, const._TYPE_AAAA]