2727from threading import Thread , Event , RLock , Condition
2828import time
2929import ssl
30+ import weakref
31+
3032
3133if 'gevent.monkey' in sys .modules :
3234 from gevent .queue import Queue , Empty
@@ -610,6 +612,55 @@ def int_from_buf_item(i):
610612 int_from_buf_item = ord
611613
612614
615+ class _ConnectionIOBuffer (object ):
616+ """
617+ Abstraction class to ease the use of the different connection io buffers. With
618+ protocol V5 and checksumming, the data is read, validated and copied to another
619+ cql frame buffer.
620+ """
621+ _io_buffer = None
622+ _cql_frame_buffer = None
623+ _connection = None
624+
625+ def __init__ (self , connection ):
626+ self ._io_buffer = io .BytesIO ()
627+ self ._connection = weakref .proxy (connection )
628+
629+ @property
630+ def io_buffer (self ):
631+ return self ._io_buffer
632+
633+ @property
634+ def cql_frame_buffer (self ):
635+ return self ._cql_frame_buffer if self .is_checksumming_enabled else \
636+ self ._io_buffer
637+
638+ def set_checksumming_buffer (self ):
639+ self .reset_io_buffer ()
640+ self ._cql_frame_buffer = io .BytesIO ()
641+
642+ @property
643+ def is_checksumming_enabled (self ):
644+ return self ._connection ._is_checksumming_enabled
645+
646+ def readable_io_bytes (self ):
647+ return self .io_buffer .tell ()
648+
649+ def readable_cql_frame_bytes (self ):
650+ return self .cql_frame_buffer .tell ()
651+
652+ def reset_io_buffer (self ):
653+ self ._io_buffer = io .BytesIO (self ._io_buffer .read ())
654+ self ._io_buffer .seek (0 , 2 ) # 2 == SEEK_END
655+
656+ def reset_cql_frame_buffer (self ):
657+ if self .is_checksumming_enabled :
658+ self ._cql_frame_buffer = io .BytesIO (self ._cql_frame_buffer .read ())
659+ self ._cql_frame_buffer .seek (0 , 2 ) # 2 == SEEK_END
660+ else :
661+ self .reset_io_buffer ()
662+
663+
613664class Connection (object ):
614665
615666 CALLBACK_ERR_THREAD_THRESHOLD = 100
@@ -665,8 +716,6 @@ class Connection(object):
665716
666717 allow_beta_protocol_version = False
667718
668- _iobuf = None
669- _frame_iobuf = None
670719 _current_frame = None
671720
672721 _socket = None
@@ -679,6 +728,11 @@ class Connection(object):
679728
680729 _is_checksumming_enabled = False
681730
731+ @property
732+ def _iobuf (self ):
733+ # backward compatibility, to avoid any change in the reactors
734+ return self ._io_buffer .io_buffer
735+
682736 def __init__ (self , host = '127.0.0.1' , port = 9042 , authenticator = None ,
683737 ssl_options = None , sockopts = None , compression = True ,
684738 cql_version = None , protocol_version = ProtocolVersion .MAX_SUPPORTED , is_control_connection = False ,
@@ -702,8 +756,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
702756 self .no_compact = no_compact
703757 self ._push_watchers = defaultdict (set )
704758 self ._requests = {}
705- self ._iobuf = io .BytesIO ()
706- self ._frame_iobuf = io .BytesIO ()
759+ self ._io_buffer = _ConnectionIOBuffer (self )
707760 self ._continuous_paging_sessions = {}
708761 self ._socket_writable = True
709762
@@ -844,6 +897,12 @@ def _connect_socket(self):
844897 for args in self .sockopts :
845898 self ._socket .setsockopt (* args )
846899
900+ def _enable_checksumming (self ):
901+ self ._io_buffer .set_checksumming_buffer ()
902+ self ._is_checksumming_enabled = True
903+ self ._segment_codec = segment_codec_lz4 if self .compressor else segment_codec_no_compression
904+ log .debug ("Enabling protocol checksumming on connection (%s)." , id (self ))
905+
847906 def close (self ):
848907 raise NotImplementedError ()
849908
@@ -1032,7 +1091,7 @@ def control_conn_disposed(self):
10321091
10331092 @defunct_on_error
10341093 def _read_frame_header (self ):
1035- buf = self ._frame_iobuf .getvalue ()
1094+ buf = self ._io_buffer . cql_frame_buffer .getvalue ()
10361095 pos = len (buf )
10371096 if pos :
10381097 version = int_from_buf_item (buf [0 ]) & PROTOCOL_VERSION_MASK
@@ -1048,28 +1107,19 @@ def _read_frame_header(self):
10481107 self ._current_frame = _Frame (version , flags , stream , op , header_size , body_len + header_size )
10491108 return pos
10501109
1051- def _reset_frame (self ):
1052- self ._frame_iobuf = io .BytesIO (self ._frame_iobuf .read ())
1053- self ._frame_iobuf .seek (0 , 2 ) # 2 == SEEK_END
1054- self ._current_frame = None
1055-
1056- def _reset_io_buffer (self ):
1057- self ._iobuf = io .BytesIO (self ._iobuf .read ())
1058- self ._iobuf .seek (0 , 2 ) # 2 == SEEK_END
1059-
10601110 @defunct_on_error
10611111 def _process_segment_buffer (self ):
1062- readable_bytes = self ._iobuf . tell ()
1112+ readable_bytes = self ._io_buffer . readable_io_bytes ()
10631113 if readable_bytes >= self ._segment_codec .header_length_with_crc :
10641114 try :
1065- self ._iobuf .seek (0 )
1066- segment_header = self ._segment_codec .decode_header (self ._iobuf )
1115+ self ._io_buffer . io_buffer .seek (0 )
1116+ segment_header = self ._segment_codec .decode_header (self ._io_buffer . io_buffer )
10671117 if readable_bytes >= segment_header .segment_length :
10681118 segment = self ._segment_codec .decode (self ._iobuf , segment_header )
1069- self ._frame_iobuf .write (segment .payload )
1119+ self ._io_buffer . cql_frame_buffer .write (segment .payload )
10701120 else :
10711121 # not enough data to read the segment
1072- self ._iobuf .seek (0 , 2 )
1122+ self ._io_buffer . io_buffer .seek (0 , 2 )
10731123 except CrcException as exc :
10741124 # re-raise an exception that inherits from ConnectionException
10751125 raise CrcMismatchException (str (exc ), self .endpoint )
@@ -1078,21 +1128,15 @@ def process_io_buffer(self):
10781128 while True :
10791129 if self ._is_checksumming_enabled :
10801130 self ._process_segment_buffer ()
1081- else :
1082- # We should probably refactor the IO buffering stuff out of the Connection
1083- # class to handle this in a better way. That would make the segment and frame
1084- # decoding code clearer.
1085- self ._frame_iobuf .write (self ._iobuf .getvalue ())
1086-
1087- self ._reset_io_buffer ()
1131+ self ._io_buffer .reset_io_buffer ()
10881132
10891133 if not self ._current_frame :
10901134 pos = self ._read_frame_header ()
10911135 else :
1092- pos = self ._frame_iobuf . tell ()
1136+ pos = self ._io_buffer . readable_cql_frame_bytes ()
10931137
10941138 if not self ._current_frame or pos < self ._current_frame .end_pos :
1095- if self ._is_checksumming_enabled and self ._iobuf . tell ():
1139+ if self ._is_checksumming_enabled and self ._io_buffer . readable_io_bytes ():
10961140 # We have a multi-segments message and we need to read more
10971141 # data to complete the current cql frame
10981142 continue
@@ -1103,10 +1147,11 @@ def process_io_buffer(self):
11031147 return
11041148 else :
11051149 frame = self ._current_frame
1106- self ._frame_iobuf .seek (frame .body_offset )
1107- msg = self ._frame_iobuf .read (frame .end_pos - frame .body_offset )
1150+ self ._io_buffer . cql_frame_buffer .seek (frame .body_offset )
1151+ msg = self ._io_buffer . cql_frame_buffer .read (frame .end_pos - frame .body_offset )
11081152 self .process_msg (frame , msg )
1109- self ._reset_frame ()
1153+ self ._io_buffer .reset_cql_frame_buffer ()
1154+ self ._current_frame = None
11101155
11111156 @defunct_on_error
11121157 def process_msg (self , header , body ):
@@ -1287,9 +1332,7 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
12871332 self .compressor = self ._compressor
12881333
12891334 if ProtocolVersion .has_checksumming_support (self .protocol_version ):
1290- self ._is_checksumming_enabled = True
1291- self ._segment_codec = segment_codec_lz4 if self .compressor else segment_codec_no_compression
1292- log .debug ("Enabling protocol checksumming on connection (%s)." , id (self ))
1335+ self ._enable_checksumming ()
12931336
12941337 self .connected_event .set ()
12951338 elif isinstance (startup_response , AuthenticateMessage ):
0 commit comments