@@ -83,29 +83,6 @@ class _MessageType(object):
8383 custom_payload = None
8484 warnings = None
8585
86- def to_binary (self , stream_id , protocol_version , compression = None ):
87- flags = 0
88- body = io .BytesIO ()
89- if self .custom_payload :
90- if protocol_version < 4 :
91- raise UnsupportedOperation ("Custom key/value payloads can only be used with protocol version 4 or higher" )
92- flags |= CUSTOM_PAYLOAD_FLAG
93- write_bytesmap (body , self .custom_payload )
94- self .send_body (body , protocol_version )
95- body = body .getvalue ()
96-
97- if compression and len (body ) > 0 :
98- body = compression (body )
99- flags |= COMPRESSED_FLAG
100- if self .tracing :
101- flags |= TRACING_FLAG
102-
103- msg = io .BytesIO ()
104- write_header (msg , protocol_version , flags , stream_id , self .opcode , len (body ))
105- msg .write (body )
106-
107- return msg .getvalue ()
108-
10986 def update_custom_payload (self , other ):
11087 if other :
11188 if not self .custom_payload :
@@ -126,50 +103,6 @@ def _get_params(message_obj):
126103 )
127104
128105
129- def decode_response (protocol_version , user_type_map , stream_id , flags , opcode , body ,
130- decompressor = None ):
131- if flags & COMPRESSED_FLAG :
132- if decompressor is None :
133- raise Exception ("No de-compressor available for compressed frame!" )
134- body = decompressor (body )
135- flags ^= COMPRESSED_FLAG
136-
137- body = io .BytesIO (body )
138- if flags & TRACING_FLAG :
139- trace_id = UUID (bytes = body .read (16 ))
140- flags ^= TRACING_FLAG
141- else :
142- trace_id = None
143-
144- if flags & WARNING_FLAG :
145- warnings = read_stringlist (body )
146- flags ^= WARNING_FLAG
147- else :
148- warnings = None
149-
150- if flags & CUSTOM_PAYLOAD_FLAG :
151- custom_payload = read_bytesmap (body )
152- flags ^= CUSTOM_PAYLOAD_FLAG
153- else :
154- custom_payload = None
155-
156- if flags :
157- log .warning ("Unknown protocol flags set: %02x. May cause problems." , flags )
158-
159- msg_class = _message_types_by_opcode [opcode ]
160- msg = msg_class .recv_body (body , protocol_version , user_type_map )
161- msg .stream_id = stream_id
162- msg .trace_id = trace_id
163- msg .custom_payload = custom_payload
164- msg .warnings = warnings
165-
166- if msg .warnings :
167- for w in msg .warnings :
168- log .warning ("Server warning: %s" , w )
169-
170- return msg
171-
172-
173106error_classes = {}
174107
175108
@@ -609,7 +542,7 @@ class ResultMessage(_MessageType):
609542 results = None
610543 paging_state = None
611544
612- _type_codes = {
545+ type_codes = {
613546 0x0000 : CUSTOM_TYPE ,
614547 0x0001 : AsciiType ,
615548 0x0002 : LongType ,
@@ -744,7 +677,7 @@ def recv_results_schema_change(cls, f, protocol_version):
744677 def read_type (cls , f , user_type_map ):
745678 optid = read_short (f )
746679 try :
747- typeclass = cls ._type_codes [optid ]
680+ typeclass = cls .type_codes [optid ]
748681 except KeyError :
749682 raise NotSupportedError ("Unknown data type code 0x%04x. Have to skip"
750683 " entire result set." % (optid ,))
@@ -964,13 +897,122 @@ def recv_schema_change(cls, f, protocol_version):
964897 return event
965898
966899
967- def write_header (f , version , flags , stream_id , opcode , length ):
900+ class ProtocolHandler (object ):
901+ """
902+ ProtocolHander handles encoding and decoding messages.
903+
904+ This class can be specialized to compose Handlers which implement alternative
905+ result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
906+ on initialization.
907+
908+ Contracted class methods are :meth:`ProtocolHandler.encode_message` and :meth:`ProtocolHandler.decode_message`.
909+ """
910+
911+ message_types_by_opcode = _message_types_by_opcode .copy ()
968912 """
969- Write a CQL protocol frame header.
913+ Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
914+ this to instantiate a message and populate using ``recv_body``. This mapping can be updated to inject specialized
915+ result decoding implementations.
970916 """
971- pack = v3_header_pack if version >= 3 else header_pack
972- f .write (pack (version , flags , stream_id , opcode ))
973- write_int (f , length )
917+
918+ @classmethod
919+ def encode_message (cls , msg , stream_id , protocol_version , compressor ):
920+ """
921+ Encodes a message using the specified frame parameters, and compressor
922+
923+ :param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
924+ :param stream_id: protocol stream id for the frame header
925+ :param protocol_version: version for the frame header, and used encoding contents
926+ :param compressor: optional compression function to be used on the body
927+ :return:
928+ """
929+ flags = 0
930+ body = io .BytesIO ()
931+ if msg .custom_payload :
932+ if protocol_version < 4 :
933+ raise UnsupportedOperation ("Custom key/value payloads can only be used with protocol version 4 or higher" )
934+ flags |= CUSTOM_PAYLOAD_FLAG
935+ write_bytesmap (body , msg .custom_payload )
936+ msg .send_body (body , protocol_version )
937+ body = body .getvalue ()
938+
939+ if compressor and len (body ) > 0 :
940+ body = compressor (body )
941+ flags |= COMPRESSED_FLAG
942+
943+ if msg .tracing :
944+ flags |= TRACING_FLAG
945+
946+ buff = io .BytesIO ()
947+ cls ._write_header (buff , protocol_version , flags , stream_id , msg .opcode , len (body ))
948+ buff .write (body )
949+
950+ return buff .getvalue ()
951+
952+ @staticmethod
953+ def _write_header (f , version , flags , stream_id , opcode , length ):
954+ """
955+ Write a CQL protocol frame header.
956+ """
957+ pack = v3_header_pack if version >= 3 else header_pack
958+ f .write (pack (version , flags , stream_id , opcode ))
959+ write_int (f , length )
960+
961+ @classmethod
962+ def decode_message (cls , protocol_version , user_type_map , stream_id , flags , opcode , body ,
963+ decompressor ):
964+ """
965+ Decodes a native protocol message body
966+
967+ :param protocol_version: version to use decoding contents
968+ :param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
969+ :param stream_id: native protocol stream id from the frame header
970+ :param flags: native protocol flags bitmap from the header
971+ :param opcode: native protocol opcode from the header
972+ :param body: frame body
973+ :param decompressor: optional decompression function to inflate the body
974+ :return: a message decoded from the body and frame attributes
975+ """
976+ if flags & COMPRESSED_FLAG :
977+ if decompressor is None :
978+ raise Exception ("No de-compressor available for compressed frame!" )
979+ body = decompressor (body )
980+ flags ^= COMPRESSED_FLAG
981+
982+ body = io .BytesIO (body )
983+ if flags & TRACING_FLAG :
984+ trace_id = UUID (bytes = body .read (16 ))
985+ flags ^= TRACING_FLAG
986+ else :
987+ trace_id = None
988+
989+ if flags & WARNING_FLAG :
990+ warnings = read_stringlist (body )
991+ flags ^= WARNING_FLAG
992+ else :
993+ warnings = None
994+
995+ if flags & CUSTOM_PAYLOAD_FLAG :
996+ custom_payload = read_bytesmap (body )
997+ flags ^= CUSTOM_PAYLOAD_FLAG
998+ else :
999+ custom_payload = None
1000+
1001+ if flags :
1002+ log .warning ("Unknown protocol flags set: %02x. May cause problems." , flags )
1003+
1004+ msg_class = cls .message_types_by_opcode [opcode ]
1005+ msg = msg_class .recv_body (body , protocol_version , user_type_map )
1006+ msg .stream_id = stream_id
1007+ msg .trace_id = trace_id
1008+ msg .custom_payload = custom_payload
1009+ msg .warnings = warnings
1010+
1011+ if msg .warnings :
1012+ for w in msg .warnings :
1013+ log .warning ("Server warning: %s" , w )
1014+
1015+ return msg
9741016
9751017
9761018def read_byte (f ):
0 commit comments