Skip to content

Commit 22b92df

Browse files
committed
Add protocol_hander_class to allow extension, serdes specialization
1 parent a327ab1 commit 22b92df

6 files changed

Lines changed: 151 additions & 82 deletions

File tree

cassandra/cluster.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
IsBootstrappingErrorMessage,
6161
BatchMessage, RESULT_KIND_PREPARED,
6262
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
63-
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION)
63+
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION, ProtocolHandler)
6464
from cassandra.metadata import Metadata, protect_name, murmur3
6565
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
6666
ExponentialReconnectionPolicy, HostDistance,
@@ -419,6 +419,15 @@ def auth_provider(self, value):
419419
GeventConnection will be used automatically.
420420
"""
421421

422+
protocol_handler_class = ProtocolHandler
423+
"""
424+
Specifies a protocol handler class, which can be used to override or extend features
425+
such as message or type deserialization.
426+
427+
The class must conform to the public classmethod interface defined in the default
428+
implementation, :class:`cassandra.protocol.ProtocolHandler`
429+
"""
430+
422431
control_connection_timeout = 2.0
423432
"""
424433
A timeout, in seconds, for queries made by the control connection, such
@@ -515,7 +524,8 @@ def __init__(self,
515524
idle_heartbeat_interval=30,
516525
schema_event_refresh_window=2,
517526
topology_event_refresh_window=10,
518-
connect_timeout=5):
527+
connect_timeout=5,
528+
protocol_handler_class=None):
519529
"""
520530
Any of the mutable Cluster attributes may be set as keyword arguments
521531
to the constructor.
@@ -559,6 +569,9 @@ def __init__(self,
559569
if connection_class is not None:
560570
self.connection_class = connection_class
561571

572+
if protocol_handler_class is not None:
573+
self.protocol_handler_class = protocol_handler_class
574+
562575
self.metrics_enabled = metrics_enabled
563576
self.ssl_options = ssl_options
564577
self.sockopts = sockopts
@@ -798,6 +811,7 @@ def _make_connection_kwargs(self, address, kwargs_dict):
798811
kwargs_dict['cql_version'] = self.cql_version
799812
kwargs_dict['protocol_version'] = self.protocol_version
800813
kwargs_dict['user_type_map'] = self._user_types
814+
kwargs_dict['protocol_handler_class'] = self.protocol_handler_class
801815

802816
return kwargs_dict
803817

cassandra/connection.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from cassandra.marshal import int32_pack, uint8_unpack
4343
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
4444
StartupMessage, ErrorMessage, CredentialsMessage,
45-
QueryMessage, ResultMessage, decode_response,
45+
QueryMessage, ResultMessage, ProtocolHandler,
4646
InvalidRequestException, SupportedMessage,
4747
AuthResponseMessage, AuthChallengeMessage,
4848
AuthSuccessMessage, ProtocolException,
@@ -209,7 +209,7 @@ class Connection(object):
209209
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
210210
ssl_options=None, sockopts=None, compression=True,
211211
cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
212-
user_type_map=None):
212+
user_type_map=None, protocol_handler_class=ProtocolHandler):
213213
self.host = host
214214
self.port = port
215215
self.authenticator = authenticator
@@ -220,6 +220,8 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
220220
self.protocol_version = protocol_version
221221
self.is_control_connection = is_control_connection
222222
self.user_type_map = user_type_map
223+
self.decoder = protocol_handler_class.decode_message
224+
self.encoder = protocol_handler_class.encode_message
223225
self._push_watchers = defaultdict(set)
224226
self._callbacks = {}
225227
self._iobuf = io.BytesIO()
@@ -362,7 +364,7 @@ def send_msg(self, msg, request_id, cb):
362364
raise ConnectionShutdown("Connection to %s is closed" % self.host)
363365

364366
self._callbacks[request_id] = cb
365-
self.push(msg.to_binary(request_id, self.protocol_version, compression=self.compressor))
367+
self.push(self.encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
366368
return request_id
367369

368370
def wait_for_response(self, msg, timeout=None):
@@ -498,8 +500,8 @@ def process_msg(self, header, body):
498500
self.msg_received = True
499501

500502
try:
501-
response = decode_response(header.version, self.user_type_map, stream_id,
502-
header.flags, header.opcode, body, self.decompressor)
503+
response = self.decoder(header.version, self.user_type_map, stream_id,
504+
header.flags, header.opcode, body, self.decompressor)
503505
except Exception as exc:
504506
log.exception("Error decoding response from Cassandra. "
505507
"opcode: %04x; message contents: %r", header.opcode, body)

cassandra/protocol.py

Lines changed: 116 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,6 @@ class _MessageType(object):
8383
custom_payload = None
8484
warnings = None
8585

86-
def to_binary(self, stream_id, protocol_version, compression=None):
87-
flags = 0
88-
body = io.BytesIO()
89-
if self.custom_payload:
90-
if protocol_version < 4:
91-
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
92-
flags |= CUSTOM_PAYLOAD_FLAG
93-
write_bytesmap(body, self.custom_payload)
94-
self.send_body(body, protocol_version)
95-
body = body.getvalue()
96-
97-
if compression and len(body) > 0:
98-
body = compression(body)
99-
flags |= COMPRESSED_FLAG
100-
if self.tracing:
101-
flags |= TRACING_FLAG
102-
103-
msg = io.BytesIO()
104-
write_header(msg, protocol_version, flags, stream_id, self.opcode, len(body))
105-
msg.write(body)
106-
107-
return msg.getvalue()
108-
10986
def update_custom_payload(self, other):
11087
if other:
11188
if not self.custom_payload:
@@ -126,50 +103,6 @@ def _get_params(message_obj):
126103
)
127104

128105

129-
def decode_response(protocol_version, user_type_map, stream_id, flags, opcode, body,
130-
decompressor=None):
131-
if flags & COMPRESSED_FLAG:
132-
if decompressor is None:
133-
raise Exception("No de-compressor available for compressed frame!")
134-
body = decompressor(body)
135-
flags ^= COMPRESSED_FLAG
136-
137-
body = io.BytesIO(body)
138-
if flags & TRACING_FLAG:
139-
trace_id = UUID(bytes=body.read(16))
140-
flags ^= TRACING_FLAG
141-
else:
142-
trace_id = None
143-
144-
if flags & WARNING_FLAG:
145-
warnings = read_stringlist(body)
146-
flags ^= WARNING_FLAG
147-
else:
148-
warnings = None
149-
150-
if flags & CUSTOM_PAYLOAD_FLAG:
151-
custom_payload = read_bytesmap(body)
152-
flags ^= CUSTOM_PAYLOAD_FLAG
153-
else:
154-
custom_payload = None
155-
156-
if flags:
157-
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
158-
159-
msg_class = _message_types_by_opcode[opcode]
160-
msg = msg_class.recv_body(body, protocol_version, user_type_map)
161-
msg.stream_id = stream_id
162-
msg.trace_id = trace_id
163-
msg.custom_payload = custom_payload
164-
msg.warnings = warnings
165-
166-
if msg.warnings:
167-
for w in msg.warnings:
168-
log.warning("Server warning: %s", w)
169-
170-
return msg
171-
172-
173106
error_classes = {}
174107

175108

@@ -609,7 +542,7 @@ class ResultMessage(_MessageType):
609542
results = None
610543
paging_state = None
611544

612-
_type_codes = {
545+
type_codes = {
613546
0x0000: CUSTOM_TYPE,
614547
0x0001: AsciiType,
615548
0x0002: LongType,
@@ -744,7 +677,7 @@ def recv_results_schema_change(cls, f, protocol_version):
744677
def read_type(cls, f, user_type_map):
745678
optid = read_short(f)
746679
try:
747-
typeclass = cls._type_codes[optid]
680+
typeclass = cls.type_codes[optid]
748681
except KeyError:
749682
raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
750683
" entire result set." % (optid,))
@@ -964,13 +897,122 @@ def recv_schema_change(cls, f, protocol_version):
964897
return event
965898

966899

967-
def write_header(f, version, flags, stream_id, opcode, length):
900+
class ProtocolHandler(object):
901+
"""
902+
ProtocolHander handles encoding and decoding messages.
903+
904+
This class can be specialized to compose Handlers which implement alternative
905+
result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
906+
on initialization.
907+
908+
Contracted class methods are :meth:`ProtocolHandler.encode_message` and :meth:`ProtocolHandler.decode_message`.
909+
"""
910+
911+
message_types_by_opcode = _message_types_by_opcode.copy()
968912
"""
969-
Write a CQL protocol frame header.
913+
Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
914+
this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized
915+
result decoding implementations.
970916
"""
971-
pack = v3_header_pack if version >= 3 else header_pack
972-
f.write(pack(version, flags, stream_id, opcode))
973-
write_int(f, length)
917+
918+
@classmethod
919+
def encode_message(cls, msg, stream_id, protocol_version, compressor):
920+
"""
921+
Encodes a message using the specified frame parameters, and compressor
922+
923+
:param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
924+
:param stream_id: protocol stream id for the frame header
925+
:param protocol_version: version for the frame header, and used encoding contents
926+
:param compressor: optional compression function to be used on the body
927+
:return:
928+
"""
929+
flags = 0
930+
body = io.BytesIO()
931+
if msg.custom_payload:
932+
if protocol_version < 4:
933+
raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
934+
flags |= CUSTOM_PAYLOAD_FLAG
935+
write_bytesmap(body, msg.custom_payload)
936+
msg.send_body(body, protocol_version)
937+
body = body.getvalue()
938+
939+
if compressor and len(body) > 0:
940+
body = compressor(body)
941+
flags |= COMPRESSED_FLAG
942+
943+
if msg.tracing:
944+
flags |= TRACING_FLAG
945+
946+
buff = io.BytesIO()
947+
cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
948+
buff.write(body)
949+
950+
return buff.getvalue()
951+
952+
@staticmethod
953+
def _write_header(f, version, flags, stream_id, opcode, length):
954+
"""
955+
Write a CQL protocol frame header.
956+
"""
957+
pack = v3_header_pack if version >= 3 else header_pack
958+
f.write(pack(version, flags, stream_id, opcode))
959+
write_int(f, length)
960+
961+
@classmethod
962+
def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body,
963+
decompressor):
964+
"""
965+
Decodes a native protocol message body
966+
967+
:param protocol_version: version to use decoding contents
968+
:param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
969+
:param stream_id: native protocol stream id from the frame header
970+
:param flags: native protocol flags bitmap from the header
971+
:param opcode: native protocol opcode from the header
972+
:param body: frame body
973+
:param decompressor: optional decompression function to inflate the body
974+
:return: a message decoded from the body and frame attributes
975+
"""
976+
if flags & COMPRESSED_FLAG:
977+
if decompressor is None:
978+
raise Exception("No de-compressor available for compressed frame!")
979+
body = decompressor(body)
980+
flags ^= COMPRESSED_FLAG
981+
982+
body = io.BytesIO(body)
983+
if flags & TRACING_FLAG:
984+
trace_id = UUID(bytes=body.read(16))
985+
flags ^= TRACING_FLAG
986+
else:
987+
trace_id = None
988+
989+
if flags & WARNING_FLAG:
990+
warnings = read_stringlist(body)
991+
flags ^= WARNING_FLAG
992+
else:
993+
warnings = None
994+
995+
if flags & CUSTOM_PAYLOAD_FLAG:
996+
custom_payload = read_bytesmap(body)
997+
flags ^= CUSTOM_PAYLOAD_FLAG
998+
else:
999+
custom_payload = None
1000+
1001+
if flags:
1002+
log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)
1003+
1004+
msg_class = cls.message_types_by_opcode[opcode]
1005+
msg = msg_class.recv_body(body, protocol_version, user_type_map)
1006+
msg.stream_id = stream_id
1007+
msg.trace_id = trace_id
1008+
msg.custom_payload = custom_payload
1009+
msg.warnings = warnings
1010+
1011+
if msg.warnings:
1012+
for w in msg.warnings:
1013+
log.warning("Server warning: %s", w)
1014+
1015+
return msg
9741016

9751017

9761018
def read_byte(f):

docs/api/cassandra/cluster.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828
.. autoattribute:: connection_class
2929

30+
.. autoattribute:: protocol_handler_class
31+
3032
.. autoattribute:: metrics_enabled
3133

3234
.. autoattribute:: metrics

docs/api/cassandra/protocol.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,12 @@ By default these are ignored by the server. They can be useful for servers imple
1515
a custom QueryHandler.
1616

1717
See :meth:`.Session.execute`, ::meth:`.Session.execute_async`, :attr:`.ResponseFuture.custom_payload`.
18+
19+
.. autoclass:: ProtocolHandler
20+
21+
.. autoattribute:: message_types_by_opcode
22+
:annotation: = {default mapping}
23+
24+
.. automethod:: encode_message
25+
26+
.. automethod:: decode_message

tests/integration/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _tuple_version(version_string):
8080

8181
USE_CASS_EXTERNAL = bool(os.getenv('USE_CASS_EXTERNAL', False))
8282

83-
default_cassandra_version = '2.1.5'
83+
default_cassandra_version = '2.1.6'
8484

8585
if USE_CASS_EXTERNAL:
8686
if CCMClusterFactory:

0 commit comments

Comments
 (0)