Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zeroconf/_dns.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ cdef class DNSNsec(DNSRecord):

cdef class DNSRRSet:

cdef cython.list _record_sets
cdef cython.list _records
cdef cython.dict _lookup

@cython.locals(other=DNSRecord)
Expand Down
25 changes: 11 additions & 14 deletions src/zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
def suppressed_by(self, msg: 'DNSIncoming') -> bool:
"""Returns true if any answer in a message can suffice for the
information held in this record."""
answers = msg.answers
answers = msg.answers()
for record in answers:
if self._suppressed_by_answer(record):
return True
Expand Down Expand Up @@ -521,37 +521,34 @@ def __repr__(self) -> str:
class DNSRRSet:
"""A set of dns records with a lookup to get the ttl."""

__slots__ = ('_record_sets', '_lookup')
__slots__ = ('_records', '_lookup')

def __init__(self, record_sets: List[List[DNSRecord]]) -> None:
def __init__(self, records: List[DNSRecord]) -> None:
"""Create an RRset from records sets."""
self._record_sets = record_sets
self._lookup: Optional[Dict[DNSRecord, float]] = None
self._records = records
self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None

@property
def lookup(self) -> Dict[DNSRecord, float]:
def lookup(self) -> Dict[DNSRecord, DNSRecord]:
"""Return the lookup table."""
return self._get_lookup()

def lookup_set(self) -> Set[DNSRecord]:
"""Return the lookup table as aset."""
return set(self._get_lookup())

def _get_lookup(self) -> Dict[DNSRecord, float]:
def _get_lookup(self) -> Dict[DNSRecord, DNSRecord]:
"""Return the lookup table, building it if needed."""
if self._lookup is None:
# Build the hash table so we can lookup the record ttl
self._lookup = {}
for record_sets in self._record_sets:
for record in record_sets:
self._lookup[record] = record.ttl
self._lookup = {record: record for record in self._records}
return self._lookup

def suppresses(self, record: _DNSRecord) -> bool:
"""Returns true if any answer in the rrset can suffice for the
information held in this record."""
lookup = self._get_lookup()
other_ttl = lookup.get(record)
if other_ttl is None:
other = lookup.get(record)
if other is None:
return False
return other_ttl > (record.ttl / 2)
return other.ttl > (record.ttl / 2)
14 changes: 8 additions & 6 deletions src/zeroconf/_handlers/query_handler.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ cdef object _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL
cdef class _QueryResponse:

cdef bint _is_probe
cdef DNSIncoming _msg
cdef cython.list _questions
cdef float _now
cdef DNSCache _cache
cdef cython.dict _additionals
Expand All @@ -31,20 +31,20 @@ cdef class _QueryResponse:
cdef cython.set _mcast_aggregate_last_second

@cython.locals(record=DNSRecord)
cpdef add_qu_question_response(self, cython.dict answers)
cdef add_qu_question_response(self, cython.dict answers)

cpdef add_ucast_question_response(self, cython.dict answers)
cdef add_ucast_question_response(self, cython.dict answers)

@cython.locals(answer=DNSRecord)
cpdef add_mcast_question_response(self, cython.dict answers)
@cython.locals(answer=DNSRecord, question=DNSQuestion)
cdef add_mcast_question_response(self, cython.dict answers)

@cython.locals(maybe_entry=DNSRecord)
cdef bint _has_mcast_within_one_quarter_ttl(self, DNSRecord record)

@cython.locals(maybe_entry=DNSRecord)
cdef bint _has_mcast_record_in_last_second(self, DNSRecord record)

cpdef answers(self)
cdef QuestionAnswers answers(self)

cdef class QueryHandler:

Expand All @@ -70,5 +70,7 @@ cdef class QueryHandler:
answer_set=cython.dict,
known_answers=DNSRRSet,
known_answers_set=cython.set,
is_probe=object,
now=object
)
cpdef async_response(self, cython.list msgs, cython.bint unicast_source)
43 changes: 27 additions & 16 deletions src/zeroconf/_handlers/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class _QueryResponse:

__slots__ = (
"_is_probe",
"_msg",
"_questions",
"_now",
"_cache",
"_additionals",
Expand All @@ -65,15 +65,11 @@ class _QueryResponse:
"_mcast_aggregate_last_second",
)

def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None:
def __init__(self, cache: DNSCache, questions: List[DNSQuestion], is_probe: bool, now: float) -> None:
"""Build a query response."""
self._is_probe = False
for msg in msgs:
if msg.is_probe:
self._is_probe = True
break
self._msg = msgs[0]
self._now = self._msg.now
self._is_probe = is_probe
self._questions = questions
self._now = now
self._cache = cache
self._additionals: _AnswerWithAdditionalsType = {}
self._ucast: Set[DNSRecord] = set()
Expand Down Expand Up @@ -107,10 +103,15 @@ def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> No

if self._has_mcast_record_in_last_second(answer):
self._mcast_aggregate_last_second.add(answer)
elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES:
self._mcast_now.add(answer)
else:
self._mcast_aggregate.add(answer)
continue

if len(self._questions) == 1:
question = self._questions[0]
if question.type in _RESPOND_IMMEDIATE_TYPES:
self._mcast_now.add(answer)
continue

self._mcast_aggregate.add(answer)

def answers(
self,
Expand Down Expand Up @@ -262,16 +263,26 @@ def async_response( # pylint: disable=unused-argument
This function must be run in the event loop as it is not
threadsafe.
"""
known_answers = DNSRRSet([msg.answers for msg in msgs if not msg.is_probe])
query_res = _QueryResponse(self.cache, msgs)
answers: List[DNSRecord] = []
is_probe = False
msg = msgs[0]
questions = msg.questions
now = msg.now
for msg in msgs:
if not msg.is_probe():
answers.extend(msg.answers())
else:
is_probe = True
known_answers = DNSRRSet(answers)
query_res = _QueryResponse(self.cache, questions, is_probe, now)
known_answers_set: Optional[Set[DNSRecord]] = None

for msg in msgs:
for question in msg.questions:
if not question.unique: # unique and unicast are the same flag
if not known_answers_set: # pragma: no branch
known_answers_set = known_answers.lookup_set()
self.question_history.add_question_at_time(question, msg.now, known_answers_set)
self.question_history.add_question_at_time(question, now, known_answers_set)
answer_set = self._answer_question(question, known_answers)
if not ucast_source and question.unique: # unique and unicast are the same flag
query_res.add_qu_question_response(answer_set)
Expand Down
2 changes: 1 addition & 1 deletion src/zeroconf/_handlers/record_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
now_float = now
unique_types: Set[Tuple[str, int, int]] = set()
cache = self.cache
answers = msg.answers
answers = msg.answers()

for record in answers:
# Protect zeroconf from records that can cause denial of service.
Expand Down
4 changes: 4 additions & 0 deletions src/zeroconf/_protocol/incoming.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ cdef class DNSIncoming:

cpdef is_query(self)

cpdef is_probe(self)

cpdef answers(self)

cpdef is_response(self)

@cython.locals(
Expand Down
4 changes: 1 addition & 3 deletions src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def _log_exception_debug(cls, *logger_data: Any) -> None:
log_exc_info = True
log.debug(*(logger_data or ['Exception occurred']), exc_info=log_exc_info)

@property
def answers(self) -> List[DNSRecord]:
"""Answers in the packet."""
if not self._did_read_others:
Expand All @@ -187,7 +186,6 @@ def answers(self) -> List[DNSRecord]:
)
return self._answers

@property
def is_probe(self) -> bool:
"""Returns true if this is a probe."""
return self.num_authorities > 0
Expand All @@ -203,7 +201,7 @@ def __repr__(self) -> str:
'n_auth=%s' % self.num_authorities,
'n_add=%s' % self.num_additionals,
'questions=%s' % self.questions,
'answers=%s' % self.answers,
'answers=%s' % self.answers(),
]
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
"""Sends an outgoing packet."""
pout = DNSIncoming(out.packets()[0])
nonlocal nbr_answers
for answer in pout.answers:
for answer in pout.answers():
nbr_answers += 1
if not answer.ttl > expected_ttl / 2:
unexpected_ttl.set()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_rrset_does_not_consider_ttl():
longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same')
shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same')

rrset = DNSRRSet([[longarec, shortaaaarec]])
rrset = DNSRRSet([longarec, shortaaaarec])

assert rrset.suppresses(longarec)
assert rrset.suppresses(shortarec)
Expand All @@ -404,7 +404,7 @@ def test_rrset_does_not_consider_ttl():
mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same')
shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same')

rrset2 = DNSRRSet([[mediumarec]])
rrset2 = DNSRRSet([mediumarec])
assert not rrset2.suppresses(verylongarec)
assert rrset2.suppresses(longarec)
assert rrset2.suppresses(mediumarec)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,8 +1425,8 @@ async def test_response_aggregation_timings(run_isolated):
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers
assert info.dns_pointer() in incoming.answers()
assert info2.dns_pointer() in incoming.answers()
send_mock.reset_mock()

protocol.datagram_received(query3.packets()[0], ('127.0.0.1', const._MDNS_PORT))
Expand All @@ -1439,7 +1439,7 @@ async def test_response_aggregation_timings(run_isolated):
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info3.dns_pointer() in incoming.answers
assert info3.dns_pointer() in incoming.answers()
send_mock.reset_mock()

# Because the response was sent in the last second we need to make
Expand All @@ -1461,7 +1461,7 @@ async def test_response_aggregation_timings(run_isolated):
assert len(calls) == 1
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
assert info.dns_pointer() in incoming.answers
assert info.dns_pointer() in incoming.answers()

await aiozc.async_close()

Expand Down Expand Up @@ -1501,7 +1501,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info2.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers()

send_mock.reset_mock()
protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
Expand All @@ -1511,7 +1511,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info2.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers()

send_mock.reset_mock()
protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT))
Expand All @@ -1534,7 +1534,7 @@ async def test_response_aggregation_timings_multiple(run_isolated, disable_dupli
outgoing = send_mock.call_args[0][0]
incoming = r.DNSIncoming(outgoing.packets()[0])
zc.record_manager.async_updates_from_response(incoming)
assert info2.dns_pointer() in incoming.answers
assert info2.dns_pointer() in incoming.answers()


@pytest.mark.asyncio
Expand Down
Loading