Skip to content

Commit 8355c85

Browse files
committed
Limit the size of the packet that can be built
1 parent 5d9f40d commit 8355c85

2 files changed

Lines changed: 63 additions & 20 deletions

File tree

test_zeroconf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT):
179179

180180
# create a bunch of servers
181181
type_ = "_my-service._tcp.local."
182-
server_count = 200
182+
server_count = 300
183183
records_per_server = 2
184184
for i in range(int(server_count / 10)):
185185
self.generate_many_hosts(zeroconf, type_, 10)
@@ -200,8 +200,6 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
200200
zeroconf.close()
201201

202202
# now the browser has sent at least one request, verify the size
203-
# this assertion is not currently super useful, but the code above
204-
# exercise several code paths.
205203
assert longest_packet[0] < r._MAX_MSG_ABSOLUTE
206204

207205
def generate_many_hosts(self, zc, type_, number_hosts):

zeroconf.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -751,42 +751,72 @@ class DNSOutgoing(object):
751751

752752
"""Object representation of an outgoing packet"""
753753

754-
def __init__(self, flags, multicast=True):
754+
def __init__(self, flags, multicast=True, build_on_fly=False):
755755
self.finished = False
756756
self.id = 0
757757
self.multicast = multicast
758758
self.flags = flags
759759
self.names = {}
760760
self.data = []
761761
self.size = 12
762+
self.build_on_fly = build_on_fly
763+
self.state = self.State.init
762764

763765
self.questions = []
764766
self.answers = []
765767
self.authorities = []
766768
self.additionals = []
767769

770+
class State(enum.Enum):
771+
init = 0
772+
adding_questions = 1
773+
adding_answers = 2
774+
adding_authoratives = 3
775+
adding_additionals = 4
776+
finished = 4
777+
778+
def set_state(self, state):
779+
if self.state != state:
780+
if self.state.value > state.value:
781+
raise Error('Out of order DNSOutgoing build %s -> %s' % (
782+
self.state.name, state.name))
783+
self.state = state
784+
return self.state != self.State.finished
785+
768786
def add_question(self, record):
769787
"""Adds a question"""
770788
self.questions.append(record)
789+
if self.build_on_fly:
790+
if self.set_state(self.State.adding_questions):
791+
self.write_question(record)
771792

772793
def add_answer(self, inp, record):
773794
"""Adds an answer"""
774795
if not record.suppressed_by(inp):
775796
self.add_answer_at_time(record, 0)
776797

777798
def add_answer_at_time(self, record, now):
778-
"""Adds an answer if if does not expire by a certain time"""
799+
"""Adds an answer if it does not expire by a certain time"""
779800
if record is not None:
780801
if now == 0 or not record.is_expired(now):
781802
self.answers.append((record, now))
803+
if self.build_on_fly:
804+
if self.set_state(self.State.adding_answers):
805+
self.write_record(record, now)
782806

783807
def add_authorative_answer(self, record):
784808
"""Adds an authoritative answer"""
785809
self.authorities.append(record)
810+
if self.build_on_fly:
811+
if self.set_state(self.State.adding_authoratives):
812+
self.write_record(record, 0)
786813

787814
def add_additional_answer(self, record):
788815
"""Adds an additional answer"""
789816
self.additionals.append(record)
817+
if self.build_on_fly:
818+
if self.set_state(self.State.adding_additionals):
819+
self.write_record(record, 0)
790820

791821
def pack(self, format_, value):
792822
self.data.append(struct.pack(format_, value))
@@ -887,6 +917,7 @@ def write_question(self, question):
887917
def write_record(self, record, now):
888918
"""Writes a record (answer, authoritative answer, additional) to
889919
the packet"""
920+
start_data_length, start_size = len(self.data), self.size
890921
self.write_name(record.name)
891922
self.write_short(record.type)
892923
if record.unique and self.multicast:
@@ -898,30 +929,42 @@ def write_record(self, record, now):
898929
else:
899930
self.write_int(record.get_remaining_ttl(now))
900931
index = len(self.data)
932+
901933
# Adjust size for the short we will write before this record
902-
#
903934
self.size += 2
904935
record.write(self)
905936
self.size -= 2
906937

907-
length = len(b''.join(self.data[index:]))
908-
self.insert_short(index, length) # Here is the short we adjusted for
938+
length = sum((len(d) for d in self.data[index:]))
939+
# Here is the short we adjusted for
940+
self.insert_short(index, length)
941+
942+
# if we go over, then rollback and quit
943+
if self.size > _MAX_MSG_ABSOLUTE:
944+
while len(self.data) > start_data_length:
945+
self.data.pop()
946+
self.size = start_size
947+
self.state = self.State.finished
909948

910949
def packet(self):
911950
"""Returns a string containing the packet's bytes
912951
913952
No further parts should be added to the packet once this
914953
is done."""
915-
if not self.finished:
916-
self.finished = True
917-
for question in self.questions:
918-
self.write_question(question)
919-
for answer, time_ in self.answers:
920-
self.write_record(answer, time_)
921-
for authority in self.authorities:
922-
self.write_record(authority, 0)
923-
for additional in self.additionals:
924-
self.write_record(additional, 0)
954+
if self.state != self.State.finished:
955+
if not self.build_on_fly:
956+
for question in self.questions:
957+
self.write_question(question)
958+
for answer, time_ in self.answers:
959+
if self.state != self.State.finished:
960+
self.write_record(answer, time_)
961+
for authority in self.authorities:
962+
if self.state != self.State.finished:
963+
self.write_record(authority, 0)
964+
for additional in self.additionals:
965+
if self.state != self.State.finished:
966+
self.write_record(additional, 0)
967+
self.state = self.State.finished
925968

926969
self.insert_short(0, len(self.additionals))
927970
self.insert_short(0, len(self.authorities))
@@ -1240,13 +1283,15 @@ def run(self):
12401283
if self.zc.done or self.done:
12411284
return
12421285
now = current_time_millis()
1243-
12441286
if self.next_time <= now:
1245-
out = DNSOutgoing(_FLAGS_QR_QUERY)
1287+
out = DNSOutgoing(_FLAGS_QR_QUERY, build_on_fly=True)
12461288
out.add_question(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
12471289
for record in self.services.values():
12481290
if not record.is_expired(now):
12491291
out.add_answer_at_time(record, now)
1292+
if out.state == out.State.finished:
1293+
break
1294+
12501295
self.zc.send(out)
12511296
self.next_time = now + self.delay
12521297
self.delay = min(20 * 1000, self.delay * 2)

0 commit comments

Comments
 (0)