@@ -55,21 +55,79 @@ def test_cython_lazy_parser(self):
5555 """
5656 verify_iterator_data (self .assertEqual , get_data (LazyProtocolHandler ))
5757
58+ @numpytest
59+ def test_cython_lazy_results_paged (self ):
60+ """
61+ Test Cython-based parser that returns an iterator, over multiple pages
62+ """
63+ # arrays = { 'a': arr1, 'b': arr2, ... }
64+ cluster = Cluster (protocol_version = PROTOCOL_VERSION )
65+ session = cluster .connect (keyspace = "testspace" )
66+ session .row_factory = tuple_factory
67+ session .client_protocol_handler = LazyProtocolHandler
68+ session .default_fetch_size = 2
69+
70+ self .assertLess (session .default_fetch_size , self .N_ITEMS )
71+
72+ results = session .execute ("SELECT * FROM test_table" )
73+
74+ self .assertTrue (results .has_more_pages )
75+ self .assertEqual (verify_iterator_data (self .assertEqual , results ), self .N_ITEMS ) # make sure we see all rows
76+
77+ cluster .shutdown ()
78+
5879 @numpytest
5980 def test_numpy_parser (self ):
6081 """
6182 Test Numpy-based parser that returns a NumPy array
6283 """
6384 # arrays = { 'a': arr1, 'b': arr2, ... }
64- arrays = get_data (NumpyProtocolHandler )
85+ result = get_data (NumpyProtocolHandler )
86+ self .assertFalse (result .has_more_pages )
87+ self ._verify_numpy_page (result [0 ])
6588
89+ @numpytest
90+ def test_numpy_results_paged (self ):
91+ """
92+ Test Numpy-based parser that returns a NumPy array
93+ """
94+ # arrays = { 'a': arr1, 'b': arr2, ... }
95+ cluster = Cluster (protocol_version = PROTOCOL_VERSION )
96+ session = cluster .connect (keyspace = "testspace" )
97+ session .row_factory = tuple_factory
98+ session .client_protocol_handler = NumpyProtocolHandler
99+ session .default_fetch_size = 2
100+
101+ expected_pages = (self .N_ITEMS + session .default_fetch_size - 1 ) // session .default_fetch_size
102+
103+ self .assertLess (session .default_fetch_size , self .N_ITEMS )
104+
105+ results = session .execute ("SELECT * FROM test_table" )
106+
107+ self .assertTrue (results .has_more_pages )
108+ for count , page in enumerate (results , 1 ):
109+ self .assertIsInstance (page , dict )
110+ for colname , arr in page .items ():
111+ if count <= expected_pages :
112+ self .assertGreater (len (arr ), 0 , "page count: %d" % (count ,))
113+ self .assertLessEqual (len (arr ), session .default_fetch_size )
114+ else :
115+ # we get one extra item out of this iteration because of the way NumpyParser returns results
116+ # The last page is returned as a dict with zero-length arrays
117+ self .assertEqual (len (arr ), 0 )
118+ self .assertEqual (self ._verify_numpy_page (page ), len (arr ))
119+ self .assertEqual (count , expected_pages + 1 ) # see note about extra 'page' above
120+
121+ cluster .shutdown ()
122+
123+ def _verify_numpy_page (self , page ):
66124 colnames = self .colnames
67125 datatypes = get_primitive_datatypes ()
68126 for colname , datatype in zip (colnames , datatypes ):
69- arr = arrays [colname ]
127+ arr = page [colname ]
70128 self .match_dtype (datatype , arr .dtype )
71129
72- verify_iterator_data (self .assertEqual , arrays_to_list_of_tuples (arrays , colnames ))
130+ return verify_iterator_data (self .assertEqual , arrays_to_list_of_tuples (page , colnames ))
73131
74132 def match_dtype (self , datatype , dtype ):
75133 """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype"""
@@ -100,9 +158,7 @@ def arrays_to_list_of_tuples(arrays, colnames):
100158
101159def get_data (protocol_handler ):
102160 """
103- Get some data from the test table.
104-
105- :param key: if None, get all results (100.000 results), otherwise get only one result
161+ Get data from the test table.
106162 """
107163 cluster = Cluster (protocol_version = PROTOCOL_VERSION )
108164 session = cluster .connect (keyspace = "testspace" )
@@ -121,9 +177,11 @@ def verify_iterator_data(assertEqual, results):
121177 Check the result of get_data() when this is a list or
122178 iterator of tuples
123179 """
124- for result in results :
180+ count = 0
181+ for count , result in enumerate (results , 1 ):
125182 params = get_all_primitive_params (result [0 ])
126183 assertEqual (len (params ), len (result ),
127184 msg = "Not the right number of columns?" )
128185 for expected , actual in zip (params , result ):
129186 assertEqual (actual , expected )
187+ return count
0 commit comments