@@ -751,42 +751,72 @@ class DNSOutgoing(object):
751751
752752 """Object representation of an outgoing packet"""
753753
754- def __init__ (self , flags , multicast = True ):
754+ def __init__ (self , flags , multicast = True , build_on_fly = False ):
755755 self .finished = False
756756 self .id = 0
757757 self .multicast = multicast
758758 self .flags = flags
759759 self .names = {}
760760 self .data = []
761761 self .size = 12
762+ self .build_on_fly = build_on_fly
763+ self .state = self .State .init
762764
763765 self .questions = []
764766 self .answers = []
765767 self .authorities = []
766768 self .additionals = []
767769
770+ class State (enum .Enum ):
771+ init = 0
772+ adding_questions = 1
773+ adding_answers = 2
774+ adding_authoratives = 3
775+ adding_additionals = 4
776+ finished = 4
777+
778+ def set_state (self , state ):
779+ if self .state != state :
780+ if self .state .value > state .value :
781+ raise Error ('Out of order DNSOutgoing build %s -> %s' % (
782+ self .state .name , state .name ))
783+ self .state = state
784+ return self .state != self .State .finished
785+
768786 def add_question (self , record ):
769787 """Adds a question"""
770788 self .questions .append (record )
789+ if self .build_on_fly :
790+ if self .set_state (self .State .adding_questions ):
791+ self .write_question (record )
771792
772793 def add_answer (self , inp , record ):
773794 """Adds an answer"""
774795 if not record .suppressed_by (inp ):
775796 self .add_answer_at_time (record , 0 )
776797
777798 def add_answer_at_time (self , record , now ):
778- """Adds an answer if if does not expire by a certain time"""
799+ """Adds an answer if it does not expire by a certain time"""
779800 if record is not None :
780801 if now == 0 or not record .is_expired (now ):
781802 self .answers .append ((record , now ))
803+ if self .build_on_fly :
804+ if self .set_state (self .State .adding_answers ):
805+ self .write_record (record , now )
782806
783807 def add_authorative_answer (self , record ):
784808 """Adds an authoritative answer"""
785809 self .authorities .append (record )
810+ if self .build_on_fly :
811+ if self .set_state (self .State .adding_authoratives ):
812+ self .write_record (record , 0 )
786813
787814 def add_additional_answer (self , record ):
788815 """Adds an additional answer"""
789816 self .additionals .append (record )
817+ if self .build_on_fly :
818+ if self .set_state (self .State .adding_additionals ):
819+ self .write_record (record , 0 )
790820
791821 def pack (self , format_ , value ):
792822 self .data .append (struct .pack (format_ , value ))
@@ -887,6 +917,7 @@ def write_question(self, question):
887917 def write_record (self , record , now ):
888918 """Writes a record (answer, authoritative answer, additional) to
889919 the packet"""
920+ start_data_length , start_size = len (self .data ), self .size
890921 self .write_name (record .name )
891922 self .write_short (record .type )
892923 if record .unique and self .multicast :
@@ -898,30 +929,42 @@ def write_record(self, record, now):
898929 else :
899930 self .write_int (record .get_remaining_ttl (now ))
900931 index = len (self .data )
932+
901933 # Adjust size for the short we will write before this record
902- #
903934 self .size += 2
904935 record .write (self )
905936 self .size -= 2
906937
907- length = len (b'' .join (self .data [index :]))
908- self .insert_short (index , length ) # Here is the short we adjusted for
938+ length = sum ((len (d ) for d in self .data [index :]))
939+ # Here is the short we adjusted for
940+ self .insert_short (index , length )
941+
942+ # if we go over, then rollback and quit
943+ if self .size > _MAX_MSG_ABSOLUTE :
944+ while len (self .data ) > start_data_length :
945+ self .data .pop ()
946+ self .size = start_size
947+ self .state = self .State .finished
909948
910949 def packet (self ):
911950 """Returns a string containing the packet's bytes
912951
913952 No further parts should be added to the packet once this
914953 is done."""
915- if not self .finished :
916- self .finished = True
917- for question in self .questions :
918- self .write_question (question )
919- for answer , time_ in self .answers :
920- self .write_record (answer , time_ )
921- for authority in self .authorities :
922- self .write_record (authority , 0 )
923- for additional in self .additionals :
924- self .write_record (additional , 0 )
954+ if self .state != self .State .finished :
955+ if not self .build_on_fly :
956+ for question in self .questions :
957+ self .write_question (question )
958+ for answer , time_ in self .answers :
959+ if self .state != self .State .finished :
960+ self .write_record (answer , time_ )
961+ for authority in self .authorities :
962+ if self .state != self .State .finished :
963+ self .write_record (authority , 0 )
964+ for additional in self .additionals :
965+ if self .state != self .State .finished :
966+ self .write_record (additional , 0 )
967+ self .state = self .State .finished
925968
926969 self .insert_short (0 , len (self .additionals ))
927970 self .insert_short (0 , len (self .authorities ))
@@ -1240,13 +1283,15 @@ def run(self):
12401283 if self .zc .done or self .done :
12411284 return
12421285 now = current_time_millis ()
1243-
12441286 if self .next_time <= now :
1245- out = DNSOutgoing (_FLAGS_QR_QUERY )
1287+ out = DNSOutgoing (_FLAGS_QR_QUERY , build_on_fly = True )
12461288 out .add_question (DNSQuestion (self .type , _TYPE_PTR , _CLASS_IN ))
12471289 for record in self .services .values ():
12481290 if not record .is_expired (now ):
12491291 out .add_answer_at_time (record , now )
1292+ if out .state == out .State .finished :
1293+ break
1294+
12501295 self .zc .send (out )
12511296 self .next_time = now + self .delay
12521297 self .delay = min (20 * 1000 , self .delay * 2 )
0 commit comments