1-
21try :
32 import unittest2 as unittest
43except ImportError :
54 import unittest # noqa
65
76from mock import Mock
8- from cassandra import ProtocolVersion
9- from cassandra .protocol import PrepareMessage , QueryMessage , ExecuteMessage
7+ from cassandra import ProtocolVersion , UnsupportedOperation
8+ from cassandra .protocol import (PrepareMessage , QueryMessage , ExecuteMessage ,
9+ BatchMessage )
10+ from cassandra .query import SimpleStatement , BatchType
1011
1112class MessageTest (unittest .TestCase ):
1213
@@ -53,20 +54,21 @@ def test_query_message(self):
5354
5455 @test_category connection
5556 """
56- message = QueryMessage ("a" ,3 )
57+ message = QueryMessage ("a" , 3 )
5758 io = Mock ()
58-
59- message .send_body (io ,4 )
59+
60+ message .send_body (io , 4 )
6061 self ._check_calls (io , [(b'\x00 \x00 \x00 \x01 ' ,), (b'a' ,), (b'\x00 \x03 ' ,), (b'\x00 ' ,)])
6162
6263 io .reset_mock ()
63- message .send_body (io ,5 )
64+ message .send_body (io , 5 )
6465 self ._check_calls (io , [(b'\x00 \x00 \x00 \x01 ' ,), (b'a' ,), (b'\x00 \x03 ' ,), (b'\x00 \x00 \x00 \x00 ' ,)])
6566
6667 def _check_calls (self , io , expected ):
67- self .assertEqual (len (io .write .mock_calls ), len (expected ))
68- for call , expect in zip (io .write .mock_calls , expected ):
69- self .assertEqual (call [1 ], expect )
68+ self .assertEqual (
69+ tuple (c [1 ] for c in io .write .mock_calls ),
70+ tuple (expected )
71+ )
7072
7173 def test_prepare_flag (self ):
7274 """
@@ -83,9 +85,86 @@ def test_prepare_flag(self):
8385 for version in ProtocolVersion .SUPPORTED_VERSIONS :
8486 message .send_body (io , version )
8587 if ProtocolVersion .uses_prepare_flags (version ):
86- # This should pass after PYTHON-696
8788 self .assertEqual (len (io .write .mock_calls ), 3 )
88- # self.assertEqual(uint32_unpack(io.write.mock_calls[2][1][0]) & _WITH_SERIAL_CONSISTENCY_FLAG, 1)
8989 else :
9090 self .assertEqual (len (io .write .mock_calls ), 2 )
9191 io .reset_mock ()
92+
93+ def test_prepare_flag_with_keyspace (self ):
94+ message = PrepareMessage ("a" , keyspace = 'ks' )
95+ io = Mock ()
96+
97+ for version in ProtocolVersion .SUPPORTED_VERSIONS :
98+ if ProtocolVersion .uses_keyspace_flag (version ):
99+ message .send_body (io , version )
100+ self ._check_calls (io , [
101+ ('\x00 \x00 \x00 \x01 ' ,),
102+ ('a' ,),
103+ ('\x00 \x00 \x00 \x80 ' ,),
104+ (b'\x00 \x02 ' ,),
105+ (b'ks' ,),
106+ ])
107+ else :
108+ with self .assertRaises (UnsupportedOperation ):
109+ message .send_body (io , version )
110+ io .reset_mock ()
111+
112+ def test_keyspace_flag_raises_before_v5 (self ):
113+ keyspace_message = QueryMessage ('a' , consistency_level = 3 , keyspace = 'ks' )
114+ io = Mock (name = 'io' )
115+
116+ with self .assertRaisesRegex (UnsupportedOperation , 'Keyspaces.*set' ):
117+ keyspace_message .send_body (io , protocol_version = 4 )
118+ io .assert_not_called ()
119+
120+ def test_keyspace_written_with_length (self ):
121+ io = Mock (name = 'io' )
122+ base_expected = [
123+ (b'\x00 \x00 \x00 \x01 ' ,),
124+ (b'a' ,),
125+ (b'\x00 \x03 ' ,),
126+ (b'\x00 \x00 \x00 \x80 ' ,), # options w/ keyspace flag
127+ ]
128+
129+ QueryMessage ('a' , consistency_level = 3 , keyspace = 'ks' ).send_body (
130+ io , protocol_version = 5
131+ )
132+ self ._check_calls (io , base_expected + [
133+ (b'\x00 \x02 ' ,), # length of keyspace string
134+ (b'ks' ,),
135+ ])
136+
137+ io .reset_mock ()
138+
139+ QueryMessage ('a' , consistency_level = 3 , keyspace = 'keyspace' ).send_body (
140+ io , protocol_version = 5
141+ )
142+ self ._check_calls (io , base_expected + [
143+ (b'\x00 \x08 ' ,), # length of keyspace string
144+ (b'keyspace' ,),
145+ ])
146+
147+ def test_batch_message_with_keyspace (self ):
148+ self .maxDiff = None
149+ io = Mock (name = 'io' )
150+ batch = BatchMessage (
151+ batch_type = BatchType .LOGGED ,
152+ queries = ((False , 'stmt a' , ('param a' ,)),
153+ (False , 'stmt b' , ('param b' ,)),
154+ (False , 'stmt c' , ('param c' ,))
155+ ),
156+ consistency_level = 3 ,
157+ keyspace = 'ks'
158+ )
159+ batch .send_body (io , protocol_version = 5 )
160+ self ._check_calls (io ,
161+ (('\x00 ' ,), ('\x00 \x03 ' ,), ('\x00 ' ,),
162+ ('\x00 \x00 \x00 \x06 ' ,), ('stmt a' ,),
163+ ('\x00 \x01 ' ,), ('\x00 \x00 \x00 \x07 ' ,), ('param a' ,),
164+ ('\x00 ' ,), ('\x00 \x00 \x00 \x06 ' ,), ('stmt b' ,),
165+ ('\x00 \x01 ' ,), ('\x00 \x00 \x00 \x07 ' ,), ('param b' ,),
166+ ('\x00 ' ,), ('\x00 \x00 \x00 \x06 ' ,), ('stmt c' ,),
167+ ('\x00 \x01 ' ,), ('\x00 \x00 \x00 \x07 ' ,), ('param c' ,),
168+ ('\x00 \x03 ' ,),
169+ ('\x00 \x00 \x00 \x80 ' ,), ('\x00 \x02 ' ,), ('ks' ,))
170+ )
0 commit comments