Skip to content

Commit f5694e4

Browse files
committed
add and test keyspace-per-query
CASSANDRA-10145/PYTHON-678
1 parent fc8a2c1 commit f5694e4

3 files changed

Lines changed: 149 additions & 17 deletions

File tree

cassandra/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ def uses_prepare_flags(cls, version):
204204
def uses_error_code_map(cls, version):
205205
return version >= cls.V5
206206

207+
@classmethod
208+
def uses_keyspace_flag(cls, version):
209+
return version >= cls.V5
210+
207211

208212
class SchemaChangeType(object):
209213
DROPPED = 'DROPPED'

cassandra/protocol.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -507,20 +507,22 @@ def recv_body(cls, f, *args):
507507
_WITH_PAGING_STATE_FLAG = 0x08
508508
_WITH_SERIAL_CONSISTENCY_FLAG = 0x10
509509
_PROTOCOL_TIMESTAMP = 0x20
510+
_WITH_KEYSPACE_FLAG = 0x80
510511

511512

512513
class QueryMessage(_MessageType):
513514
opcode = 0x07
514515
name = 'QUERY'
515516

516517
def __init__(self, query, consistency_level, serial_consistency_level=None,
517-
fetch_size=None, paging_state=None, timestamp=None):
518+
fetch_size=None, paging_state=None, timestamp=None, keyspace=None):
518519
self.query = query
519520
self.consistency_level = consistency_level
520521
self.serial_consistency_level = serial_consistency_level
521522
self.fetch_size = fetch_size
522523
self.paging_state = paging_state
523524
self.timestamp = timestamp
525+
self.keyspace = keyspace
524526
self._query_params = None # only used internally. May be set to a list of native-encoded values to have them sent with the request.
525527

526528
def send_body(self, f, protocol_version):
@@ -558,6 +560,14 @@ def send_body(self, f, protocol_version):
558560
if self.timestamp is not None:
559561
flags |= _PROTOCOL_TIMESTAMP
560562

563+
if self.keyspace is not None:
564+
if ProtocolVersion.uses_keyspace_flag(protocol_version):
565+
flags |= _WITH_KEYSPACE_FLAG
566+
else:
567+
raise UnsupportedOperation(
568+
"Keyspaces may only be set on queries with protocol version "
569+
"5 or higher. Consider setting Cluster.protocol_version to 5.")
570+
561571
if ProtocolVersion.uses_int_query_flags(protocol_version):
562572
write_uint(f, flags)
563573
else:
@@ -576,6 +586,8 @@ def send_body(self, f, protocol_version):
576586
write_consistency_level(f, self.serial_consistency_level)
577587
if self.timestamp is not None:
578588
write_long(f, self.timestamp)
589+
if self.keyspace is not None:
590+
write_string(f, self.keyspace)
579591

580592

581593
CUSTOM_TYPE = object()
@@ -768,14 +780,38 @@ class PrepareMessage(_MessageType):
768780
opcode = 0x09
769781
name = 'PREPARE'
770782

771-
def __init__(self, query):
783+
def __init__(self, query, keyspace=None):
772784
self.query = query
785+
self.keyspace = keyspace
773786

774787
def send_body(self, f, protocol_version):
775788
write_longstring(f, self.query)
789+
790+
flags = 0x00
791+
792+
if self.keyspace is not None:
793+
if ProtocolVersion.uses_keyspace_flag(protocol_version):
794+
flags |= _WITH_KEYSPACE_FLAG
795+
else:
796+
raise UnsupportedOperation(
797+
"Keyspaces may only be set on queries with protocol version "
798+
"5 or higher. Consider setting Cluster.protocol_version to 5.")
799+
776800
if ProtocolVersion.uses_prepare_flags(protocol_version):
777-
# Write the flags byte; with 0 value for now, but this should change in PYTHON-678
778-
write_uint(f, 0)
801+
write_uint(f, flags)
802+
else:
803+
# checks above should prevent this, but just to be safe...
804+
if flags:
805+
raise UnsupportedOperation(
806+
"Attempted to set flags with value {flags:0=#8x} on"
807+
"protocol version {pv}, which doesn't support flags"
808+
"in prepared statements."
809+
"Consider setting Cluster.protocol_version to 5."
810+
"".format(flags=flags, pv=protocol_version))
811+
812+
if ProtocolVersion.uses_keyspace_flag(protocol_version):
813+
if self.keyspace:
814+
write_string(f, self.keyspace)
779815

780816

781817
class ExecuteMessage(_MessageType):
@@ -852,12 +888,14 @@ class BatchMessage(_MessageType):
852888
name = 'BATCH'
853889

854890
def __init__(self, batch_type, queries, consistency_level,
855-
serial_consistency_level=None, timestamp=None):
891+
serial_consistency_level=None, timestamp=None,
892+
keyspace=None):
856893
self.batch_type = batch_type
857894
self.queries = queries
858895
self.consistency_level = consistency_level
859896
self.serial_consistency_level = serial_consistency_level
860897
self.timestamp = timestamp
898+
self.keyspace = keyspace
861899

862900
def send_body(self, f, protocol_version):
863901
write_byte(f, self.batch_type.value)
@@ -881,6 +919,13 @@ def send_body(self, f, protocol_version):
881919
flags |= _WITH_SERIAL_CONSISTENCY_FLAG
882920
if self.timestamp is not None:
883921
flags |= _PROTOCOL_TIMESTAMP
922+
if self.keyspace:
923+
if ProtocolVersion.uses_keyspace_flag(protocol_version):
924+
flags |= _WITH_KEYSPACE_FLAG
925+
else:
926+
raise UnsupportedOperation(
927+
"Keyspaces may only be set on queries with protocol version "
928+
"5 or higher. Consider setting Cluster.protocol_version to 5.")
884929

885930
if ProtocolVersion.uses_int_query_flags(protocol_version):
886931
write_int(f, flags)
@@ -892,6 +937,10 @@ def send_body(self, f, protocol_version):
892937
if self.timestamp is not None:
893938
write_long(f, self.timestamp)
894939

940+
if ProtocolVersion.uses_keyspace_flag(protocol_version):
941+
if self.keyspace is not None:
942+
write_string(f, self.keyspace)
943+
895944

896945
known_event_types = frozenset((
897946
'TOPOLOGY_CHANGE',

tests/unit/test_protocol.py

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
21
try:
32
import unittest2 as unittest
43
except ImportError:
54
import unittest # noqa
65

76
from mock import Mock
8-
from cassandra import ProtocolVersion
9-
from cassandra.protocol import PrepareMessage, QueryMessage, ExecuteMessage
7+
from cassandra import ProtocolVersion, UnsupportedOperation
8+
from cassandra.protocol import (PrepareMessage, QueryMessage, ExecuteMessage,
9+
BatchMessage)
10+
from cassandra.query import SimpleStatement, BatchType
1011

1112
class MessageTest(unittest.TestCase):
1213

@@ -53,20 +54,21 @@ def test_query_message(self):
5354
5455
@test_category connection
5556
"""
56-
message = QueryMessage("a",3)
57+
message = QueryMessage("a", 3)
5758
io = Mock()
58-
59-
message.send_body(io,4)
59+
60+
message.send_body(io, 4)
6061
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)])
6162

6263
io.reset_mock()
63-
message.send_body(io,5)
64+
message.send_body(io, 5)
6465
self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)])
6566

6667
def _check_calls(self, io, expected):
67-
self.assertEqual(len(io.write.mock_calls), len(expected))
68-
for call, expect in zip(io.write.mock_calls, expected):
69-
self.assertEqual(call[1], expect)
68+
self.assertEqual(
69+
tuple(c[1] for c in io.write.mock_calls),
70+
tuple(expected)
71+
)
7072

7173
def test_prepare_flag(self):
7274
"""
@@ -83,9 +85,86 @@ def test_prepare_flag(self):
8385
for version in ProtocolVersion.SUPPORTED_VERSIONS:
8486
message.send_body(io, version)
8587
if ProtocolVersion.uses_prepare_flags(version):
86-
# This should pass after PYTHON-696
8788
self.assertEqual(len(io.write.mock_calls), 3)
88-
# self.assertEqual(uint32_unpack(io.write.mock_calls[2][1][0]) & _WITH_SERIAL_CONSISTENCY_FLAG, 1)
8989
else:
9090
self.assertEqual(len(io.write.mock_calls), 2)
9191
io.reset_mock()
92+
93+
def test_prepare_flag_with_keyspace(self):
94+
message = PrepareMessage("a", keyspace='ks')
95+
io = Mock()
96+
97+
for version in ProtocolVersion.SUPPORTED_VERSIONS:
98+
if ProtocolVersion.uses_keyspace_flag(version):
99+
message.send_body(io, version)
100+
self._check_calls(io, [
101+
('\x00\x00\x00\x01',),
102+
('a',),
103+
('\x00\x00\x00\x80',),
104+
(b'\x00\x02',),
105+
(b'ks',),
106+
])
107+
else:
108+
with self.assertRaises(UnsupportedOperation):
109+
message.send_body(io, version)
110+
io.reset_mock()
111+
112+
def test_keyspace_flag_raises_before_v5(self):
113+
keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks')
114+
io = Mock(name='io')
115+
116+
with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'):
117+
keyspace_message.send_body(io, protocol_version=4)
118+
io.assert_not_called()
119+
120+
def test_keyspace_written_with_length(self):
121+
io = Mock(name='io')
122+
base_expected = [
123+
(b'\x00\x00\x00\x01',),
124+
(b'a',),
125+
(b'\x00\x03',),
126+
(b'\x00\x00\x00\x80',), # options w/ keyspace flag
127+
]
128+
129+
QueryMessage('a', consistency_level=3, keyspace='ks').send_body(
130+
io, protocol_version=5
131+
)
132+
self._check_calls(io, base_expected + [
133+
(b'\x00\x02',), # length of keyspace string
134+
(b'ks',),
135+
])
136+
137+
io.reset_mock()
138+
139+
QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body(
140+
io, protocol_version=5
141+
)
142+
self._check_calls(io, base_expected + [
143+
(b'\x00\x08',), # length of keyspace string
144+
(b'keyspace',),
145+
])
146+
147+
def test_batch_message_with_keyspace(self):
148+
self.maxDiff = None
149+
io = Mock(name='io')
150+
batch = BatchMessage(
151+
batch_type=BatchType.LOGGED,
152+
queries=((False, 'stmt a', ('param a',)),
153+
(False, 'stmt b', ('param b',)),
154+
(False, 'stmt c', ('param c',))
155+
),
156+
consistency_level=3,
157+
keyspace='ks'
158+
)
159+
batch.send_body(io, protocol_version=5)
160+
self._check_calls(io,
161+
(('\x00',), ('\x00\x03',), ('\x00',),
162+
('\x00\x00\x00\x06',), ('stmt a',),
163+
('\x00\x01',), ('\x00\x00\x00\x07',), ('param a',),
164+
('\x00',), ('\x00\x00\x00\x06',), ('stmt b',),
165+
('\x00\x01',), ('\x00\x00\x00\x07',), ('param b',),
166+
('\x00',), ('\x00\x00\x00\x06',), ('stmt c',),
167+
('\x00\x01',), ('\x00\x00\x00\x07',), ('param c',),
168+
('\x00\x03',),
169+
('\x00\x00\x00\x80',), ('\x00\x02',), ('ks',))
170+
)

0 commit comments

Comments
 (0)