Skip to content

Commit 32407d7

Browse files
committed
Merge pull request apache#425 from datastax/430
PYTHON-430 - fix paged results for non-standard protocol handlers
2 parents 253bdcf + 795acbd commit 32407d7

4 files changed

Lines changed: 82 additions & 10 deletions

File tree

cassandra/cluster.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
BatchMessage, RESULT_KIND_PREPARED,
6262
RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
6363
RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION,
64-
ProtocolHandler)
64+
ProtocolHandler, _RESULT_SEQUENCE_TYPES)
6565
from cassandra.metadata import Metadata, protect_name, murmur3
6666
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
6767
ExponentialReconnectionPolicy, HostDistance,
@@ -3349,7 +3349,7 @@ class ResultSet(object):
33493349

33503350
def __init__(self, response_future, initial_response):
33513351
self.response_future = response_future
3352-
self._current_rows = initial_response or []
3352+
self._set_current_rows(initial_response)
33533353
self._page_iter = None
33543354
self._list_mode = False
33553355

@@ -3390,10 +3390,16 @@ def fetch_next_page(self):
33903390
if self.response_future.has_more_pages:
33913391
self.response_future.start_fetching_next_page()
33923392
result = self.response_future.result()
3393-
self._current_rows = result._current_rows
3393+
self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form
33943394
else:
33953395
self._current_rows = []
33963396

3397+
def _set_current_rows(self, result):
3398+
if isinstance(result, _RESULT_SEQUENCE_TYPES):
3399+
self._current_rows = result
3400+
else:
3401+
self._current_rows = [result] if result else []
3402+
33973403
def _fetch_all(self):
33983404
self._current_rows = list(self)
33993405
self._page_iter = None

cassandra/obj_parser.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@ cdef class LazyParser(ColumnParser):
3838
# supported in cpdef methods
3939
return parse_rows_lazy(reader, desc)
4040

41+
cpdef get_cython_generator_type(self):
42+
return get_cython_generator_type()
4143

4244
def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
4345
cdef Py_ssize_t i, rowcount
4446
rowcount = read_int(reader)
4547
cdef RowParser rowparser = TupleRowParser()
4648
return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
4749

50+
def get_cython_generator_type():
51+
return type((i for i in range(0)))
4852

4953
cdef class TupleRowParser(RowParser):
5054
"""

cassandra/protocol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,7 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod
10001000

10011001
return msg
10021002

1003+
_RESULT_SEQUENCE_TYPES = (list, tuple) # types retuned by ResultMessages
10031004

10041005
def cython_protocol_handler(colparser):
10051006
"""
@@ -1045,6 +1046,9 @@ class CythonProtocolHandler(ProtocolHandler):
10451046
if HAVE_CYTHON:
10461047
from cassandra.obj_parser import ListParser, LazyParser
10471048
ProtocolHandler = cython_protocol_handler(ListParser())
1049+
1050+
lazy_parser = LazyParser()
1051+
_RESULT_SEQUENCE_TYPES += (lazy_parser.get_cython_generator_type(),)
10481052
LazyProtocolHandler = cython_protocol_handler(LazyParser())
10491053
else:
10501054
# Use Python-based ProtocolHandler

tests/integration/standard/test_cython_protocol_handlers.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,79 @@ def test_cython_lazy_parser(self):
5555
"""
5656
verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler))
5757

58+
@numpytest
59+
def test_cython_lazy_results_paged(self):
60+
"""
61+
Test Cython-based parser that returns an iterator, over multiple pages
62+
"""
63+
# arrays = { 'a': arr1, 'b': arr2, ... }
64+
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
65+
session = cluster.connect(keyspace="testspace")
66+
session.row_factory = tuple_factory
67+
session.client_protocol_handler = LazyProtocolHandler
68+
session.default_fetch_size = 2
69+
70+
self.assertLess(session.default_fetch_size, self.N_ITEMS)
71+
72+
results = session.execute("SELECT * FROM test_table")
73+
74+
self.assertTrue(results.has_more_pages)
75+
self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS) # make sure we see all rows
76+
77+
cluster.shutdown()
78+
5879
@numpytest
5980
def test_numpy_parser(self):
6081
"""
6182
Test Numpy-based parser that returns a NumPy array
6283
"""
6384
# arrays = { 'a': arr1, 'b': arr2, ... }
64-
arrays = get_data(NumpyProtocolHandler)
85+
result = get_data(NumpyProtocolHandler)
86+
self.assertFalse(result.has_more_pages)
87+
self._verify_numpy_page(result[0])
6588

89+
@numpytest
90+
def test_numpy_results_paged(self):
91+
"""
92+
Test Numpy-based parser that returns a NumPy array
93+
"""
94+
# arrays = { 'a': arr1, 'b': arr2, ... }
95+
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
96+
session = cluster.connect(keyspace="testspace")
97+
session.row_factory = tuple_factory
98+
session.client_protocol_handler = NumpyProtocolHandler
99+
session.default_fetch_size = 2
100+
101+
expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size
102+
103+
self.assertLess(session.default_fetch_size, self.N_ITEMS)
104+
105+
results = session.execute("SELECT * FROM test_table")
106+
107+
self.assertTrue(results.has_more_pages)
108+
for count, page in enumerate(results, 1):
109+
self.assertIsInstance(page, dict)
110+
for colname, arr in page.items():
111+
if count <= expected_pages:
112+
self.assertGreater(len(arr), 0, "page count: %d" % (count,))
113+
self.assertLessEqual(len(arr), session.default_fetch_size)
114+
else:
115+
# we get one extra item out of this iteration because of the way NumpyParser returns results
116+
# The last page is returned as a dict with zero-length arrays
117+
self.assertEqual(len(arr), 0)
118+
self.assertEqual(self._verify_numpy_page(page), len(arr))
119+
self.assertEqual(count, expected_pages + 1) # see note about extra 'page' above
120+
121+
cluster.shutdown()
122+
123+
def _verify_numpy_page(self, page):
66124
colnames = self.colnames
67125
datatypes = get_primitive_datatypes()
68126
for colname, datatype in zip(colnames, datatypes):
69-
arr = arrays[colname]
127+
arr = page[colname]
70128
self.match_dtype(datatype, arr.dtype)
71129

72-
verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames))
130+
return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames))
73131

74132
def match_dtype(self, datatype, dtype):
75133
"""Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
@@ -100,9 +158,7 @@ def arrays_to_list_of_tuples(arrays, colnames):
100158

101159
def get_data(protocol_handler):
102160
"""
103-
Get some data from the test table.
104-
105-
:param key: if None, get all results (100.000 results), otherwise get only one result
161+
Get data from the test table.
106162
"""
107163
cluster = Cluster(protocol_version=PROTOCOL_VERSION)
108164
session = cluster.connect(keyspace="testspace")
@@ -121,9 +177,11 @@ def verify_iterator_data(assertEqual, results):
121177
Check the result of get_data() when this is a list or
122178
iterator of tuples
123179
"""
124-
for result in results:
180+
count = 0
181+
for count, result in enumerate(results, 1):
125182
params = get_all_primitive_params(result[0])
126183
assertEqual(len(params), len(result),
127184
msg="Not the right number of columns?")
128185
for expected, actual in zip(params, result):
129186
assertEqual(actual, expected)
187+
return count

0 commit comments

Comments
 (0)