Skip to content

Commit d992e81

Browse files
committed
Avoid using 2 io buffers when checksumming is not used
1 parent bd05fe6 commit d992e81

9 files changed

Lines changed: 114 additions & 55 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Not released
55
Features
66
--------
77
* Ensure the driver can connect when invalid peer hosts are in system.peers (PYTHON-1260)
8+
* Implement protocol v5 checksumming (PYTHON-1258)
89

910
Bug Fixes
1011
---------

cassandra/connection.py

Lines changed: 77 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from threading import Thread, Event, RLock, Condition
2828
import time
2929
import ssl
30+
import weakref
31+
3032

3133
if 'gevent.monkey' in sys.modules:
3234
from gevent.queue import Queue, Empty
@@ -610,6 +612,55 @@ def int_from_buf_item(i):
610612
int_from_buf_item = ord
611613

612614

615+
class _ConnectionIOBuffer(object):
616+
"""
617+
Abstraction class to ease the use of the different connection io buffers. With
618+
protocol V5 and checksumming, the data is read, validated and copied to another
619+
cql frame buffer.
620+
"""
621+
_io_buffer = None
622+
_cql_frame_buffer = None
623+
_connection = None
624+
625+
def __init__(self, connection):
626+
self._io_buffer = io.BytesIO()
627+
self._connection = weakref.proxy(connection)
628+
629+
@property
630+
def io_buffer(self):
631+
return self._io_buffer
632+
633+
@property
634+
def cql_frame_buffer(self):
635+
return self._cql_frame_buffer if self.is_checksumming_enabled else \
636+
self._io_buffer
637+
638+
def set_checksumming_buffer(self):
639+
self.reset_io_buffer()
640+
self._cql_frame_buffer = io.BytesIO()
641+
642+
@property
643+
def is_checksumming_enabled(self):
644+
return self._connection._is_checksumming_enabled
645+
646+
def readable_io_bytes(self):
647+
return self.io_buffer.tell()
648+
649+
def readable_cql_frame_bytes(self):
650+
return self.cql_frame_buffer.tell()
651+
652+
def reset_io_buffer(self):
653+
self._io_buffer = io.BytesIO(self._io_buffer.read())
654+
self._io_buffer.seek(0, 2) # 2 == SEEK_END
655+
656+
def reset_cql_frame_buffer(self):
657+
if self.is_checksumming_enabled:
658+
self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read())
659+
self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END
660+
else:
661+
self.reset_io_buffer()
662+
663+
613664
class Connection(object):
614665

615666
CALLBACK_ERR_THREAD_THRESHOLD = 100
@@ -665,8 +716,6 @@ class Connection(object):
665716

666717
allow_beta_protocol_version = False
667718

668-
_iobuf = None
669-
_frame_iobuf = None
670719
_current_frame = None
671720

672721
_socket = None
@@ -679,6 +728,11 @@ class Connection(object):
679728

680729
_is_checksumming_enabled = False
681730

731+
@property
732+
def _iobuf(self):
733+
# backward compatibility, to avoid any change in the reactors
734+
return self._io_buffer.io_buffer
735+
682736
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
683737
ssl_options=None, sockopts=None, compression=True,
684738
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
@@ -702,8 +756,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
702756
self.no_compact = no_compact
703757
self._push_watchers = defaultdict(set)
704758
self._requests = {}
705-
self._iobuf = io.BytesIO()
706-
self._frame_iobuf = io.BytesIO()
759+
self._io_buffer = _ConnectionIOBuffer(self)
707760
self._continuous_paging_sessions = {}
708761
self._socket_writable = True
709762

@@ -844,6 +897,12 @@ def _connect_socket(self):
844897
for args in self.sockopts:
845898
self._socket.setsockopt(*args)
846899

900+
def _enable_checksumming(self):
901+
self._io_buffer.set_checksumming_buffer()
902+
self._is_checksumming_enabled = True
903+
self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression
904+
log.debug("Enabling protocol checksumming on connection (%s).", id(self))
905+
847906
def close(self):
848907
raise NotImplementedError()
849908

@@ -1032,7 +1091,7 @@ def control_conn_disposed(self):
10321091

10331092
@defunct_on_error
10341093
def _read_frame_header(self):
1035-
buf = self._frame_iobuf.getvalue()
1094+
buf = self._io_buffer.cql_frame_buffer.getvalue()
10361095
pos = len(buf)
10371096
if pos:
10381097
version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK
@@ -1048,28 +1107,19 @@ def _read_frame_header(self):
10481107
self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
10491108
return pos
10501109

1051-
def _reset_frame(self):
1052-
self._frame_iobuf = io.BytesIO(self._frame_iobuf.read())
1053-
self._frame_iobuf.seek(0, 2) # 2 == SEEK_END
1054-
self._current_frame = None
1055-
1056-
def _reset_io_buffer(self):
1057-
self._iobuf = io.BytesIO(self._iobuf.read())
1058-
self._iobuf.seek(0, 2) # 2 == SEEK_END
1059-
10601110
@defunct_on_error
10611111
def _process_segment_buffer(self):
1062-
readable_bytes = self._iobuf.tell()
1112+
readable_bytes = self._io_buffer.readable_io_bytes()
10631113
if readable_bytes >= self._segment_codec.header_length_with_crc:
10641114
try:
1065-
self._iobuf.seek(0)
1066-
segment_header = self._segment_codec.decode_header(self._iobuf)
1115+
self._io_buffer.io_buffer.seek(0)
1116+
segment_header = self._segment_codec.decode_header(self._io_buffer.io_buffer)
10671117
if readable_bytes >= segment_header.segment_length:
10681118
segment = self._segment_codec.decode(self._iobuf, segment_header)
1069-
self._frame_iobuf.write(segment.payload)
1119+
self._io_buffer.cql_frame_buffer.write(segment.payload)
10701120
else:
10711121
# not enough data to read the segment
1072-
self._iobuf.seek(0, 2)
1122+
self._io_buffer.io_buffer.seek(0, 2)
10731123
except CrcException as exc:
10741124
# re-raise an exception that inherits from ConnectionException
10751125
raise CrcMismatchException(str(exc), self.endpoint)
@@ -1078,21 +1128,15 @@ def process_io_buffer(self):
10781128
while True:
10791129
if self._is_checksumming_enabled:
10801130
self._process_segment_buffer()
1081-
else:
1082-
# We should probably refactor the IO buffering stuff out of the Connection
1083-
# class to handle this in a better way. That would make the segment and frame
1084-
# decoding code clearer.
1085-
self._frame_iobuf.write(self._iobuf.getvalue())
1086-
1087-
self._reset_io_buffer()
1131+
self._io_buffer.reset_io_buffer()
10881132

10891133
if not self._current_frame:
10901134
pos = self._read_frame_header()
10911135
else:
1092-
pos = self._frame_iobuf.tell()
1136+
pos = self._io_buffer.readable_cql_frame_bytes()
10931137

10941138
if not self._current_frame or pos < self._current_frame.end_pos:
1095-
if self._is_checksumming_enabled and self._iobuf.tell():
1139+
if self._is_checksumming_enabled and self._io_buffer.readable_io_bytes():
10961140
# We have a multi-segments message and we need to read more
10971141
# data to complete the current cql frame
10981142
continue
@@ -1103,10 +1147,11 @@ def process_io_buffer(self):
11031147
return
11041148
else:
11051149
frame = self._current_frame
1106-
self._frame_iobuf.seek(frame.body_offset)
1107-
msg = self._frame_iobuf.read(frame.end_pos - frame.body_offset)
1150+
self._io_buffer.cql_frame_buffer.seek(frame.body_offset)
1151+
msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset)
11081152
self.process_msg(frame, msg)
1109-
self._reset_frame()
1153+
self._io_buffer.reset_cql_frame_buffer()
1154+
self._current_frame = None
11101155

11111156
@defunct_on_error
11121157
def process_msg(self, header, body):
@@ -1287,9 +1332,7 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
12871332
self.compressor = self._compressor
12881333

12891334
if ProtocolVersion.has_checksumming_support(self.protocol_version):
1290-
self._is_checksumming_enabled = True
1291-
self._segment_codec = segment_codec_lz4 if self.compressor else segment_codec_no_compression
1292-
log.debug("Enabling protocol checksumming on connection (%s).", id(self))
1335+
self._enable_checksumming()
12931336

12941337
self.connected_event.set()
12951338
elif isinstance(startup_response, AuthenticateMessage):

tests/integration/standard/test_custom_protocol_handler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from tests.integration import use_singledc, drop_keyspace_shutdown_cluster, \
2727
greaterthanorequalcass30, execute_with_long_wait_retry, greaterthanorequaldse51, greaterthanorequalcass3_10, \
28-
greaterthanorequalcass31, TestCluster
28+
TestCluster, greaterthanorequalcass40, requirecassandra
2929
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES
3030
from tests.integration.standard.utils import create_table_with_all_types, get_all_primitive_params
3131
from six import binary_type
@@ -124,7 +124,8 @@ def test_custom_raw_row_results_all_types(self):
124124
self.assertEqual(len(CustomResultMessageTracked.checked_rev_row_set), len(PRIMITIVE_DATATYPES)-1)
125125
cluster.shutdown()
126126

127-
@greaterthanorequalcass31
127+
@requirecassandra
128+
@greaterthanorequalcass40
128129
def test_protocol_divergence_v5_fail_by_continuous_paging(self):
129130
"""
130131
Test to validate that V5 and DSE_V1 diverge. ContinuousPagingOptions is not supported by V5
@@ -170,7 +171,8 @@ def test_protocol_divergence_v4_fail_by_flag_uses_int(self):
170171
self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.V4, uses_int_query_flag=False,
171172
int_flag=True)
172173

173-
@greaterthanorequalcass3_10
174+
@requirecassandra
175+
@greaterthanorequalcass40
174176
def test_protocol_v5_uses_flag_int(self):
175177
"""
176178
Test to validate that the _PAGE_SIZE_FLAG is treated correctly using write_uint for V5
@@ -196,7 +198,8 @@ def test_protocol_dsev1_uses_flag_int(self):
196198
self._protocol_divergence_fail_by_flag_uses_int(ProtocolVersion.DSE_V1, uses_int_query_flag=True,
197199
int_flag=True)
198200

199-
@greaterthanorequalcass3_10
201+
@requirecassandra
202+
@greaterthanorequalcass40
200203
def test_protocol_divergence_v5_fail_by_flag_uses_int(self):
201204
"""
202205
Test to validate that the _PAGE_SIZE_FLAG is treated correctly using write_uint for V5

tests/integration/standard/test_query.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from cassandra.policies import HostDistance, RoundRobinPolicy, WhiteListRoundRobinPolicy
2929
from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, \
3030
greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions, local, get_cluster, setup_keyspace, \
31-
USE_CASS_EXTERNAL, greaterthanorequalcass40, DSE_VERSION, TestCluster
31+
USE_CASS_EXTERNAL, greaterthanorequalcass40, DSE_VERSION, TestCluster, requirecassandra
3232
from tests import notwindows
3333
from tests.integration import greaterthanorequalcass30, get_node
3434

@@ -1408,6 +1408,8 @@ def test_setting_keyspace(self):
14081408
"""
14091409
self._check_set_keyspace_in_statement(self.session)
14101410

1411+
@requirecassandra
1412+
@greaterthanorequalcass40
14111413
def test_setting_keyspace_and_session(self):
14121414
"""
14131415
Test we can still send the keyspace independently even the session

tests/unit/io/test_twistedreactor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ def test_handle_read__incomplete(self):
148148
# incomplete header
149149
self.obj_ut._iobuf.write(b'\x84\x00\x00\x00\x00')
150150
self.obj_ut.handle_read()
151-
self.assertEqual(self.obj_ut._frame_iobuf.getvalue(), b'\x84\x00\x00\x00\x00')
151+
self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), b'\x84\x00\x00\x00\x00')
152152

153153
# full header, but incomplete body
154154
self.obj_ut._iobuf.write(b'\x00\x00\x00\x15')
155155
self.obj_ut.handle_read()
156-
self.assertEqual(self.obj_ut._frame_iobuf.getvalue(),
156+
self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(),
157157
b'\x84\x00\x00\x00\x00\x00\x00\x00\x15')
158158
self.assertEqual(self.obj_ut._current_frame.end_pos, 30)
159159

@@ -174,7 +174,7 @@ def test_handle_read__fullmessage(self):
174174
self.obj_ut._iobuf.write(
175175
b'\x84\x01\x00\x02\x03\x00\x00\x00\x15' + body + extra)
176176
self.obj_ut.handle_read()
177-
self.assertEqual(self.obj_ut._frame_iobuf.getvalue(), extra)
177+
self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), extra)
178178
self.obj_ut.process_msg.assert_called_with(
179179
_Frame(version=4, flags=1, stream=2, opcode=3, body_offset=9, end_pos=9 + len(body)), body)
180180

tests/unit/io/utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,14 +309,14 @@ def chunk(size):
309309

310310
for message, expected_size in messages:
311311
message_chunks = message
312-
c._iobuf = io.BytesIO()
312+
c._io_buffer._io_buffer = io.BytesIO()
313313
c.process_io_buffer.reset_mock()
314314
c.handle_read(*self.null_handle_function_args)
315-
c._iobuf.seek(0, os.SEEK_END)
315+
c._io_buffer.io_buffer.seek(0, os.SEEK_END)
316316

317317
# Ensure the message size is the good one and that the
318318
# message has been processed if it is non-empty
319-
self.assertEqual(c._iobuf.tell(), expected_size)
319+
self.assertEqual(c._io_buffer.io_buffer.tell(), expected_size)
320320
if expected_size == 0:
321321
c.process_io_buffer.assert_not_called()
322322
else:
@@ -435,11 +435,11 @@ def test_partial_header_read(self):
435435

436436
self.get_socket(c).recv.return_value = message[0:1]
437437
c.handle_read(*self.null_handle_function_args)
438-
self.assertEqual(c._frame_iobuf.getvalue(), message[0:1])
438+
self.assertEqual(c._io_buffer.cql_frame_buffer.getvalue(), message[0:1])
439439

440440
self.get_socket(c).recv.return_value = message[1:]
441441
c.handle_read(*self.null_handle_function_args)
442-
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
442+
self.assertEqual(six.binary_type(), c._io_buffer.io_buffer.getvalue())
443443

444444
# let it write out a StartupMessage
445445
c.handle_write(*self.null_handle_function_args)
@@ -461,12 +461,12 @@ def test_partial_message_read(self):
461461
# read in the first nine bytes
462462
self.get_socket(c).recv.return_value = message[:9]
463463
c.handle_read(*self.null_handle_function_args)
464-
self.assertEqual(c._frame_iobuf.getvalue(), message[:9])
464+
self.assertEqual(c._io_buffer.cql_frame_buffer.getvalue(), message[:9])
465465

466466
# ... then read in the rest
467467
self.get_socket(c).recv.return_value = message[9:]
468468
c.handle_read(*self.null_handle_function_args)
469-
self.assertEqual(six.binary_type(), c._iobuf.getvalue())
469+
self.assertEqual(six.binary_type(), c._io_buffer.io_buffer.getvalue())
470470

471471
# let it write out a StartupMessage
472472
c.handle_write(*self.null_handle_function_args)
@@ -501,7 +501,7 @@ def test_mixed_message_and_buffer_sizes(self):
501501

502502
for i in range(1, 15):
503503
c.process_io_buffer.reset_mock()
504-
c._iobuf = io.BytesIO()
504+
c._io_buffer._io_buffer = io.BytesIO()
505505
message = io.BytesIO(six.b('a') * (2**i))
506506

507507
def recv_side_effect(*args):
@@ -511,7 +511,7 @@ def recv_side_effect(*args):
511511

512512
self.get_socket(c).recv.side_effect = recv_side_effect
513513
c.handle_read(*self.null_handle_function_args)
514-
if c._iobuf.tell():
514+
if c._io_buffer.io_buffer.tell():
515515
c.process_io_buffer.assert_called_once()
516516
else:
517517
c.process_io_buffer.assert_not_called()

tests/unit/test_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_bad_protocol_version(self, *args):
100100
header = self.make_header_prefix(SupportedMessage, version=0x7f)
101101
options = self.make_options_body()
102102
message = self.make_msg(header, options)
103-
c._iobuf = BytesIO()
103+
c._iobuf._io_buffer = BytesIO()
104104
c._iobuf.write(message)
105105
c.process_io_buffer()
106106

@@ -117,7 +117,7 @@ def test_negative_body_length(self, *args):
117117
# read in a SupportedMessage response
118118
header = self.make_header_prefix(SupportedMessage)
119119
message = header + int32_pack(-13)
120-
c._iobuf = BytesIO()
120+
c._iobuf._io_buffer = BytesIO()
121121
c._iobuf.write(message)
122122
c.process_io_buffer()
123123

0 commit comments

Comments
 (0)