From a9116ac4be286c3dba858c87509c3861fbdfcf42 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Wed, 7 Oct 2015 16:34:03 -0500 Subject: [PATCH 01/14] always return ResultSet for query results --- cassandra/cluster.py | 67 +++++++++++++------ cassandra/concurrent.py | 6 +- docs/api/cassandra/cluster.rst | 2 +- .../long/test_loadbalancingpolicies.py | 2 - tests/integration/standard/test_concurrent.py | 4 +- .../standard/test_custom_protocol_handler.py | 11 ++- tests/integration/standard/test_metrics.py | 6 +- tests/integration/standard/test_query.py | 12 ++-- .../integration/standard/test_query_paging.py | 32 ++++----- .../standard/test_row_factories.py | 10 +-- tests/integration/standard/test_udts.py | 10 --- 11 files changed, 86 insertions(+), 76 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 0f3a554faa..42328f85f9 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2720,6 +2720,7 @@ class ResponseFuture(object): _start_time = None _metrics = None _paging_state = None + _is_result_kind_rows = False _custom_payload = None _warnings = None _timer = None @@ -2930,7 +2931,8 @@ def _set_result(self, response): self, **response.results) else: results = getattr(response, 'results', None) - if results is not None and response.kind == RESULT_KIND_ROWS: + self._is_result_kind_rows = response.kind ==RESULT_KIND_ROWS + if results is not None and self._is_result_kind_rows: self._paging_state = response.paging_state results = self.row_factory(*results) self._set_final_result(results) @@ -3197,10 +3199,10 @@ def result(self, timeout=_NOT_SET): if not self._event.is_set(): self._on_timeout() if self._final_result is not _NOT_SET: - if self._paging_state is None: - return self._final_result + if self._is_result_kind_rows: + return ResultSet(self, self._final_result) else: - return PagedResult(self, self._final_result) + return self._final_result else: raise self._final_exception @@ -3333,51 +3335,74 @@ class QueryExhausted(Exception): pass -class PagedResult(object): +class ResultSet(object): """ - An iterator over the rows from a paged query result. Whenever the number - of result rows for a query exceed the :attr:`~.query.Statement.fetch_size` - (or :attr:`~.Session.default_fetch_size`, if not set) an instance of this - class will be returned. + An iterator over the rows from a query result. Also supplies basic equality + and indexing methods for backward-compatability. These methods materialize + the entire result set (loading all pages), and should only be used if the + total result size is understood. Warnings are emitted when paged results + are materialized in this fashion. You can treat this as a normal iterator over rows:: >>> from cassandra.query import SimpleStatement >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10) >>> for user_row in session.execute(statement): - ... process_user(user_row) + ... process_user(user_rowt Whenever there are no more rows in the current page, the next page will be fetched transparently. However, note that it *is* possible for an :class:`Exception` to be raised while fetching the next page, just like you might see on a normal call to ``session.execute()``. - - .. versionadded: 2.0.0 """ - response_future = None - def __init__(self, response_future, initial_response): self.response_future = response_future - self.current_response = iter(initial_response) + self._had_pages = response_future.has_more_pages + self._current_rows = initial_response + self._page_iter = None + + @property + def has_more_pages(self): + return self.response_future.has_more_pages def __iter__(self): + self._page_iter = iter(self._current_rows) return self def next(self): try: - return next(self.current_response) + return next(self._page_iter) except StopIteration: if not self.response_future.has_more_pages: raise self.response_future.start_fetching_next_page() result = self.response_future.result() - if self.response_future.has_more_pages: - self.current_response = result.current_response - else: - self.current_response = iter(result) + self._current_rows = result._current_rows + self._page_iter = iter(self._current_rows) - return next(self.current_response) + return next(self._page_iter) __next__ = next + + def __eq__(self, other): + if self._page_iter and self._had_pages: + raise RuntimeError("Cannot test equality when paged results have been consumed.") + if self.response_future.has_more_pages: + log.warning("Using equality operator on paged results causes entire result set to be materialized.") + self._current_rows = list(self) + return self._current_rows == other + + def __getitem__(self, i): + if self._page_iter and self._had_pages: + raise RuntimeError("Cannot index when paged results have been consumed.") + if self.response_future.has_more_pages: + log.warning("Using indexing on paged results causes entire result set to be materialized.") + self._current_rows = list(self) + return self._current_rows[i] + + def __nonzero__(self): + return bool(self._current_rows) + + __bool__ = __nonzero__ diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py index d29ff84bd1..3391fdeaa2 100644 --- a/cassandra/concurrent.py +++ b/cassandra/concurrent.py @@ -20,7 +20,7 @@ from threading import Condition import sys -from cassandra.cluster import PagedResult +from cassandra.cluster import ResultSet import logging log = logging.getLogger(__name__) @@ -134,8 +134,8 @@ def _execute(self, idx, statement, params): self._put_result(e, idx, False) def _on_success(self, result, future, idx): - if future.has_more_pages: - result = PagedResult(future, result) + if future._is_result_kind_rows: + result = ResultSet(future, result) future.clear_callbacks() self._put_result(result, idx, True) diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 1ccd9282b7..3ea0d23091 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -140,7 +140,7 @@ .. automethod:: add_callbacks(callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_args=None) -.. autoclass:: PagedResult () +.. autoclass:: ResultSet () :members: .. autoexception:: QueryExhausted () diff --git a/tests/integration/long/test_loadbalancingpolicies.py b/tests/integration/long/test_loadbalancingpolicies.py index b1a8d4d251..40f0893944 100644 --- a/tests/integration/long/test_loadbalancingpolicies.py +++ b/tests/integration/long/test_loadbalancingpolicies.py @@ -493,7 +493,6 @@ def test_token_aware_composite_key(self): session.execute(prepared.bind((1, 2, 3))) results = session.execute('SELECT * FROM %s WHERE k1 = 1 AND k2 = 2' % table) - self.assertTrue(len(results) == 1) self.assertTrue(results[0].i) cluster.shutdown() @@ -539,7 +538,6 @@ def test_token_aware_with_local_table(self): p = session.prepare("SELECT * FROM system.local WHERE key=?") # this would blow up prior to 61b4fad r = session.execute(p, ('local',)) - self.assertEqual(len(r), 1) self.assertEqual(r[0].key, 'local') cluster.shutdown() diff --git a/tests/integration/standard/test_concurrent.py b/tests/integration/standard/test_concurrent.py index 4d9ce3aeac..23b810c535 100644 --- a/tests/integration/standard/test_concurrent.py +++ b/tests/integration/standard/test_concurrent.py @@ -18,7 +18,7 @@ from cassandra import InvalidRequest, ConsistencyLevel, ReadTimeout, WriteTimeout, OperationTimedOut, \ ReadFailure, WriteFailure -from cassandra.cluster import Cluster, PagedResult +from cassandra.cluster import Cluster from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args from cassandra.policies import HostDistance from cassandra.query import tuple_factory, SimpleStatement @@ -184,7 +184,7 @@ def test_execute_concurrent_paged_result(self): self.assertEqual(1, len(results)) self.assertTrue(results[0][0]) result = results[0][1] - self.assertIsInstance(result, PagedResult) + self.assertTrue(result.has_more_pages) self.assertEqual(num_statements, sum(1 for _ in result)) def test_execute_concurrent_paged_result_generator(self): diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index 36965a36ff..ce16e78c79 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -65,9 +65,8 @@ def test_custom_raw_uuid_row_results(self): # Ensure that we get normal uuid back first session = Cluster(protocol_version=PROTOCOL_VERSION).connect(keyspace="custserdes") session.row_factory = tuple_factory - result_set = session.execute("SELECT schema_version FROM system.local") - result = result_set.pop() - uuid_type = result[0] + result = session.execute("SELECT schema_version FROM system.local") + uuid_type = result[0][0] self.assertEqual(type(uuid_type), uuid.UUID) # use our custom protocol handlder @@ -75,16 +74,14 @@ def test_custom_raw_uuid_row_results(self): session.client_protocol_handler = CustomTestRawRowType session.row_factory = tuple_factory result_set = session.execute("SELECT schema_version FROM system.local") - result = result_set.pop() - raw_value = result.pop() + raw_value = result_set[0][0] self.assertTrue(isinstance(raw_value, binary_type)) self.assertEqual(len(raw_value), 16) # Ensure that we get normal uuid back when we re-connect session.client_protocol_handler = ProtocolHandler result_set = session.execute("SELECT schema_version FROM system.local") - result = result_set.pop() - uuid_type = result[0] + uuid_type = result_set[0][0] self.assertEqual(type(uuid_type), uuid.UUID) session.shutdown() diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 6731e9073e..b2a046bae6 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -77,7 +77,7 @@ def test_write_timeout(self): # Assert read query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) results = execute_until_pass(session, query) - self.assertEqual(1, len(results)) + self.assertTrue(results) # Pause node so it shows as unreachable to coordinator get_node(1).pause() @@ -110,7 +110,7 @@ def test_read_timeout(self): # Assert read query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) results = execute_until_pass(session, query) - self.assertEqual(1, len(results)) + self.assertTrue(results) # Pause node so it shows as unreachable to coordinator get_node(1).pause() @@ -143,7 +143,7 @@ def test_unavailable(self): # Assert read query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) results = execute_until_pass(session, query) - self.assertEqual(1, len(results)) + self.assertTrue(results) # Stop node gracefully get_node(1).stop(wait=True, wait_other_notice=True) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 92ffbf68f3..2364a49f71 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -402,7 +402,7 @@ def test_conditional_update(self): future = self.session.execute_async(statement) result = future.result() self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) - self.assertEqual(1, len(result)) + self.assertTrue(result) self.assertFalse(result[0].applied) statement = SimpleStatement( @@ -412,7 +412,7 @@ def test_conditional_update(self): future = self.session.execute_async(statement) result = future.result() self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) - self.assertEqual(1, len(result)) + self.assertTrue(result) self.assertTrue(result[0].applied) def test_conditional_update_with_prepared_statements(self): @@ -424,7 +424,7 @@ def test_conditional_update_with_prepared_statements(self): future = self.session.execute_async(statement) result = future.result() self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) - self.assertEqual(1, len(result)) + self.assertTrue(result) self.assertFalse(result[0].applied) statement = self.session.prepare( @@ -434,7 +434,7 @@ def test_conditional_update_with_prepared_statements(self): future = self.session.execute_async(bound) result = future.result() self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) - self.assertEqual(1, len(result)) + self.assertTrue(result) self.assertTrue(result[0].applied) def test_conditional_update_with_batch_statements(self): @@ -445,7 +445,7 @@ def test_conditional_update_with_batch_statements(self): future = self.session.execute_async(statement) result = future.result() self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) - self.assertEqual(1, len(result)) + self.assertTrue(result) self.assertFalse(result[0].applied) statement = BatchStatement(serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) @@ -454,7 +454,7 @@ def test_conditional_update_with_batch_statements(self): future = self.session.execute_async(statement) result = future.result() self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) - self.assertEqual(1, len(result)) + self.assertTrue(result) self.assertTrue(result[0].applied) def test_bad_consistency_level(self): diff --git a/tests/integration/standard/test_query_paging.py b/tests/integration/standard/test_query_paging.py index eeb55e05be..f9e943f93d 100644 --- a/tests/integration/standard/test_query_paging.py +++ b/tests/integration/standard/test_query_paging.py @@ -26,7 +26,7 @@ from six.moves import range from threading import Event -from cassandra.cluster import Cluster, PagedResult +from cassandra.cluster import Cluster from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args from cassandra.policies import HostDistance from cassandra.query import SimpleStatement @@ -301,66 +301,66 @@ def test_fetch_size(self): self.session.default_fetch_size = 10 result = self.session.execute(prepared, []) - self.assertIsInstance(result, PagedResult) + self.assertTrue(result.has_more_pages) self.session.default_fetch_size = 2000 result = self.session.execute(prepared, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) self.session.default_fetch_size = None result = self.session.execute(prepared, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) self.session.default_fetch_size = 10 prepared.fetch_size = 2000 result = self.session.execute(prepared, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) prepared.fetch_size = None result = self.session.execute(prepared, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) prepared.fetch_size = 10 result = self.session.execute(prepared, []) - self.assertIsInstance(result, PagedResult) + self.assertTrue(result.has_more_pages) prepared.fetch_size = 2000 bound = prepared.bind([]) result = self.session.execute(bound, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) prepared.fetch_size = None bound = prepared.bind([]) result = self.session.execute(bound, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) prepared.fetch_size = 10 bound = prepared.bind([]) result = self.session.execute(bound, []) - self.assertIsInstance(result, PagedResult) + self.assertTrue(result.has_more_pages) bound.fetch_size = 2000 result = self.session.execute(bound, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) bound.fetch_size = None result = self.session.execute(bound, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) bound.fetch_size = 10 result = self.session.execute(bound, []) - self.assertIsInstance(result, PagedResult) + self.assertTrue(result.has_more_pages) s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None) result = self.session.execute(s, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) s = SimpleStatement("SELECT * FROM test3rf.test") result = self.session.execute(s, []) - self.assertIsInstance(result, PagedResult) + self.assertTrue(result.has_more_pages) s = SimpleStatement("SELECT * FROM test3rf.test") s.fetch_size = None result = self.session.execute(s, []) - self.assertIsInstance(result, list) + self.assertFalse(result.has_more_pages) diff --git a/tests/integration/standard/test_row_factories.py b/tests/integration/standard/test_row_factories.py index 4fe5cf3916..d4391daffb 100644 --- a/tests/integration/standard/test_row_factories.py +++ b/tests/integration/standard/test_row_factories.py @@ -19,7 +19,7 @@ except ImportError: import unittest # noqa -from cassandra.cluster import Cluster +from cassandra.cluster import Cluster, ResultSet from cassandra.query import tuple_factory, named_tuple_factory, dict_factory, ordered_dict_factory from cassandra.util import OrderedDict @@ -72,7 +72,7 @@ def test_tuple_factory(self): result = session.execute(self.select) - self.assertIsInstance(result, list) + self.assertIsInstance(result, ResultSet) self.assertIsInstance(result[0], tuple) for row in result: @@ -93,7 +93,7 @@ def test_named_tuple_factory(self): result = session.execute(self.select) - self.assertIsInstance(result, list) + self.assertIsInstance(result, ResultSet) for row in result: self.assertEqual(row.k, row.v) @@ -113,7 +113,7 @@ def test_dict_factory(self): result = session.execute(self.select) - self.assertIsInstance(result, list) + self.assertIsInstance(result, ResultSet) self.assertIsInstance(result[0], dict) for row in result: @@ -134,7 +134,7 @@ def test_ordered_dict_factory(self): result = session.execute(self.select) - self.assertIsInstance(result, list) + self.assertIsInstance(result, ResultSet) self.assertIsInstance(result[0], OrderedDict) for row in result: diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index 72aa658165..15ed9dcebb 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -75,7 +75,6 @@ def test_can_insert_unprepared_registered_udts(self): s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob'))) result = s.execute("SELECT b FROM mytable WHERE a=0") - self.assertEqual(1, len(result)) row = result[0] self.assertEqual(42, row.b.age) self.assertEqual('bob', row.b.name) @@ -95,7 +94,6 @@ def test_can_insert_unprepared_registered_udts(self): s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User('Texas', True))) result = s.execute("SELECT b FROM mytable WHERE a=0") - self.assertEqual(1, len(result)) row = result[0] self.assertEqual('Texas', row.b.state) self.assertEqual(True, row.b.is_cool) @@ -142,7 +140,6 @@ def test_can_register_udt_before_connecting(self): s.set_keyspace("udt_test_register_before_connecting") s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) result = s.execute("SELECT b FROM mytable WHERE a=0") - self.assertEqual(1, len(result)) row = result[0] self.assertEqual(42, row.b.age) self.assertEqual('bob', row.b.name) @@ -152,7 +149,6 @@ def test_can_register_udt_before_connecting(self): s.set_keyspace("udt_test_register_before_connecting2") s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User2('Texas', True))) result = s.execute("SELECT b FROM mytable WHERE a=0") - self.assertEqual(1, len(result)) row = result[0] self.assertEqual('Texas', row.b.state) self.assertEqual(True, row.b.is_cool) @@ -177,7 +173,6 @@ def test_can_insert_prepared_unregistered_udts(self): select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) - self.assertEqual(1, len(result)) row = result[0] self.assertEqual(42, row.b.age) self.assertEqual('bob', row.b.name) @@ -197,7 +192,6 @@ def test_can_insert_prepared_unregistered_udts(self): select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) - self.assertEqual(1, len(result)) row = result[0] self.assertEqual('Texas', row.b.state) self.assertEqual(True, row.b.is_cool) @@ -223,7 +217,6 @@ def test_can_insert_prepared_registered_udts(self): select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) - self.assertEqual(1, len(result)) row = result[0] self.assertEqual(42, row.b.age) self.assertEqual('bob', row.b.name) @@ -246,7 +239,6 @@ def test_can_insert_prepared_registered_udts(self): select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) - self.assertEqual(1, len(result)) row = result[0] self.assertEqual('Texas', row.b.state) self.assertEqual(True, row.b.is_cool) @@ -528,7 +520,6 @@ def test_can_insert_udt_all_datatypes(self): # retrieve and verify data results = s.execute("SELECT * FROM mytable") - self.assertEqual(1, len(results)) row = results[0].b for expected, actual in zip(params, row): @@ -587,7 +578,6 @@ def test_can_insert_udt_all_collection_datatypes(self): # retrieve and verify data results = s.execute("SELECT * FROM mytable") - self.assertEqual(1, len(results)) row = results[0].b for expected, actual in zip(params, row): From 6402f71547cb2774263f37746b670b2c0c6337c2 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Thu, 8 Oct 2015 16:56:59 -0500 Subject: [PATCH 02/14] Make ResultSet list mode explicit PYTHON-368 --- cassandra/cluster.py | 32 +++++++++++++++++++++----------- tests/unit/test_concurrent.py | 2 ++ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 42328f85f9..ef41b58235 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3358,15 +3358,17 @@ class ResultSet(object): def __init__(self, response_future, initial_response): self.response_future = response_future - self._had_pages = response_future.has_more_pages self._current_rows = initial_response self._page_iter = None + self._list_mode = False @property def has_more_pages(self): return self.response_future.has_more_pages def __iter__(self): + if self._list_mode: + return iter(self._current_rows) self._page_iter = iter(self._current_rows) return self @@ -3375,6 +3377,8 @@ def next(self): return next(self._page_iter) except StopIteration: if not self.response_future.has_more_pages: + if not self._list_mode: + self._current_rows = [] raise self.response_future.start_fetching_next_page() @@ -3386,20 +3390,26 @@ def next(self): __next__ = next - def __eq__(self, other): - if self._page_iter and self._had_pages: - raise RuntimeError("Cannot test equality when paged results have been consumed.") + def _fetch_all(self): + self._current_rows = list(self) + self._page_iter = None + + def _enter_list_mode(self, operator): + if self._list_mode: + return + if self._page_iter: + raise RuntimeError("Cannot use %s when results have been iterated." % operator) if self.response_future.has_more_pages: - log.warning("Using equality operator on paged results causes entire result set to be materialized.") - self._current_rows = list(self) + log.warning("Using %s on paged results causes entire result set to be materialized.", operator) + self._fetch_all() + self._list_mode = True + + def __eq__(self, other): + self._enter_list_mode("equality operator") return self._current_rows == other def __getitem__(self, i): - if self._page_iter and self._had_pages: - raise RuntimeError("Cannot index when paged results have been consumed.") - if self.response_future.has_more_pages: - log.warning("Using indexing on paged results causes entire result set to be materialized.") - self._current_rows = list(self) + self._enter_list_mode("index operator") return self._current_rows[i] def __nonzero__(self): diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index 0bdb1f9e4d..7465f685c7 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -32,6 +32,8 @@ class MockResponseResponseFuture(): and invoke callback with various timing. """ + _is_result_kind_rows = False + # a list pending callbacks, these will be prioritized in reverse or normal orderd pending_callbacks = PriorityQueue() From 1862c7a0c4bf1447095977b08abfa738f045b8eb Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Thu, 8 Oct 2015 16:59:51 -0500 Subject: [PATCH 03/14] Add unit test for cluster.ResultSet --- tests/unit/test_resultset.py | 160 +++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 tests/unit/test_resultset.py diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py new file mode 100644 index 0000000000..bb85de47ca --- /dev/null +++ b/tests/unit/test_resultset.py @@ -0,0 +1,160 @@ +# Copyright 2013-2015 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from mock import Mock, PropertyMock +import warnings + +from cassandra.cluster import ResultSet + + +class ResultSetTests(unittest.TestCase): + + def test_iter_non_paged(self): + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + itr = iter(rs) + self.assertListEqual(list(itr), expected) + + def test_iter_paged(self): + expected = list(range(10)) + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + itr = iter(rs) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) # after init to avoid side effects being consumed by init + self.assertListEqual(list(itr), expected) + + def test_list_non_paged(self): + # list access on RS for backwards-compatibility + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + for i in range(10): + self.assertEqual(rs[i], expected[i]) + self.assertEqual(list(rs), expected) + + def test_list_paged(self): + # list access on RS for backwards-compatibility + expected = list(range(10)) + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # one True for getitem check/warn, then True, False for two pages + self.assertEqual(rs[9], expected[9]) + self.assertEqual(list(rs), expected) + + def test_has_more_pages(self): + response_future = Mock() + response_future.has_more_pages.side_effect = PropertyMock(side_effect=(True, False)) + rs = ResultSet(response_future, []) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) # after init to avoid side effects being consumed by init + self.assertTrue(rs.has_more_pages) + self.assertFalse(rs.has_more_pages) + + def test_iterate_then_index(self): + # RuntimeError if indexing with no pages + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + itr = iter(rs) + # before consuming + with self.assertRaises(RuntimeError): + rs[0] + list(itr) + # after consuming + with self.assertRaises(RuntimeError): + rs[0] + + self.assertFalse(rs) + self.assertFalse(list(rs)) + + # RuntimeError if indexing during or after pages + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) + itr = iter(rs) + # before consuming + with self.assertRaises(RuntimeError): + rs[0] + for row in itr: + # while consuming + with self.assertRaises(RuntimeError): + rs[0] + # after consuming + with self.assertRaises(RuntimeError): + rs[0] + self.assertFalse(rs) + self.assertFalse(list(rs)) + + def test_index_list_mode(self): + # no pages + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + + # index access before iteration causes list to be materialized + self.assertEqual(rs[0], expected[0]) + + # resusable iteration + self.assertListEqual(list(rs), expected) + self.assertListEqual(list(rs), expected) + + self.assertTrue(rs) + + # pages + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # First True is consumed on check entering list mode + # index access before iteration causes list to be materialized + self.assertEqual(rs[0], expected[0]) + self.assertEqual(rs[9], expected[9]) + # resusable iteration + self.assertListEqual(list(rs), expected) + self.assertListEqual(list(rs), expected) + + self.assertTrue(rs) + + def test_eq(self): + # no pages + expected = list(range(10)) + rs = ResultSet(Mock(has_more_pages=False), expected) + + # eq before iteration causes list to be materialized + self.assertEqual(rs, expected) + + # results can be iterated or indexed once we're materialized + self.assertListEqual(list(rs), expected) + self.assertEqual(rs[9], expected[9]) + self.assertTrue(rs) + + # pages + response_future = Mock(has_more_pages=True) + response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock + rs = ResultSet(response_future, expected[:5]) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) + # eq before iteration causes list to be materialized + self.assertEqual(rs, expected) + + # results can be iterated or indexed once we're materialized + self.assertListEqual(list(rs), expected) + self.assertEqual(rs[9], expected[9]) + self.assertTrue(rs) + + def test_bool(self): + self.assertFalse(ResultSet(Mock(has_more_pages=False), [])) + self.assertTrue(ResultSet(Mock(has_more_pages=False), [1])) From 97adeca31f40ec73546bde5e41a20b739f34f5bc Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Mon, 12 Oct 2015 15:02:30 -0500 Subject: [PATCH 04/14] surface query trace in ResultSet API PYTHON-318 --- cassandra/cluster.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index ef41b58235..9dae0370db 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3218,7 +3218,8 @@ def get_query_trace(self, max_wait=None): if not self._query_trace: return None - self._query_trace.populate(max_wait) + if not self._query_trace.events: + self._query_trace.populate(max_wait) return self._query_trace def add_callback(self, fn, *args, **kwargs): @@ -3361,9 +3362,13 @@ def __init__(self, response_future, initial_response): self._current_rows = initial_response self._page_iter = None self._list_mode = False + self._traces = [response_future._query_trace] if response_future._query_trace else [] @property def has_more_pages(self): + """ + True if the last response indicated more pages; False otherwise + """ return self.response_future.has_more_pages def __iter__(self): @@ -3383,6 +3388,8 @@ def next(self): self.response_future.start_fetching_next_page() result = self.response_future.result() + if self.response_future._query_trace: + self._traces.append(self.response_future._query_trace) self._current_rows = result._current_rows self._page_iter = iter(self._current_rows) @@ -3416,3 +3423,29 @@ def __nonzero__(self): return bool(self._current_rows) __bool__ = __nonzero__ + + def get_query_trace(self, max_wait_sec=None): + """ + Fetches and returns the query trace of the last response, or `None` if tracing was + not enabled. + + Note that this may raise an exception if there are problems retrieving the trace + details from Cassandra. If the trace is not available after `max_wait_sec`, + :exc:`cassandra.query.TraceUnavailable` will be raised. + """ + if self._traces: + self._get_trace(0, max_wait_sec) + + def get_all_query_traces(self, max_wait_sec=None): + """ + Fetches and returns the query traces for all query pages, if tracing was enabled. + + See note in :meth:`~.get_current_query_trace` regarding possible exceptions. + """ + return [self._get_trace(i, max_wait_sec) for i in range(len(self._traces))] + + def _get_trace(self, i, max_wait): + trace = self._traces[i] + if not trace.events: + trace.populate(max_wait=max_wait) + return trace From 7f7b8cc280fed6332ad73a85a51d3594b8ce3433 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Mon, 12 Oct 2015 15:21:49 -0500 Subject: [PATCH 05/14] always return ResultSet, even for non-row responses --- cassandra/cluster.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9dae0370db..6749085fb2 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2720,7 +2720,6 @@ class ResponseFuture(object): _start_time = None _metrics = None _paging_state = None - _is_result_kind_rows = False _custom_payload = None _warnings = None _timer = None @@ -2931,8 +2930,7 @@ def _set_result(self, response): self, **response.results) else: results = getattr(response, 'results', None) - self._is_result_kind_rows = response.kind ==RESULT_KIND_ROWS - if results is not None and self._is_result_kind_rows: + if results is not None and response.kind ==RESULT_KIND_ROWS: self._paging_state = response.paging_state results = self.row_factory(*results) self._set_final_result(results) @@ -3199,10 +3197,7 @@ def result(self, timeout=_NOT_SET): if not self._event.is_set(): self._on_timeout() if self._final_result is not _NOT_SET: - if self._is_result_kind_rows: - return ResultSet(self, self._final_result) - else: - return self._final_result + return ResultSet(self, self._final_result) else: raise self._final_exception @@ -3359,7 +3354,7 @@ class ResultSet(object): def __init__(self, response_future, initial_response): self.response_future = response_future - self._current_rows = initial_response + self._current_rows = initial_response or [] self._page_iter = None self._list_mode = False self._traces = [response_future._query_trace] if response_future._query_trace else [] From 9ca9985f703a872cf6bada1ed4949bf70f04d9c0 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Tue, 13 Oct 2015 11:42:35 -0500 Subject: [PATCH 06/14] Trace info is now attached to ResultSet --- cassandra/cluster.py | 37 ++++--------------- cassandra/concurrent.py | 6 +-- tests/integration/standard/test_cluster.py | 23 ++++++------ tests/integration/standard/test_concurrent.py | 21 +++++++---- tests/integration/standard/test_query.py | 37 ++++++++++++------- .../standard/test_row_factories.py | 1 + tests/unit/test_concurrent.py | 10 ++--- tests/unit/test_response_future.py | 8 ++-- 8 files changed, 69 insertions(+), 74 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 6749085fb2..8ef7b6610f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1601,34 +1601,14 @@ def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_ no timeout. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. - If `trace` is set to :const:`True`, an attempt will be made to - fetch the trace details and attach them to the `query`'s - :attr:`~.Statement.trace` attribute in the form of a :class:`.QueryTrace` - instance. This requires that `query` be a :class:`.Statement` subclass - instance and not just a string. If there is an error fetching the - trace details, the :attr:`~.Statement.trace` attribute will be left as - :const:`None`. + If `trace` is set to :const:`True`, the query will be sent with tracing enabled. + The trace details can be obtained using the returned :class:`.ResultSet` object. `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. If `query` is a Statement with its own custom_payload. The message payload will be a union of the two, with the values specified here taking precedence. """ - if trace and not isinstance(query, Statement): - raise TypeError( - "The query argument must be an instance of a subclass of " - "cassandra.query.Statement when trace=True") - - future = self.execute_async(query, parameters, trace, custom_payload, timeout) - try: - result = future.result() - finally: - if trace: - try: - query.trace = future.get_query_trace(self.max_trace_wait) - except Exception: - log.exception("Unable to fetch query trace:") - - return result + return self.execute_async(query, parameters, trace, custom_payload, timeout).result() def execute_async(self, query, parameters=None, trace=False, custom_payload=None, timeout=_NOT_SET): """ @@ -1638,9 +1618,9 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None on the :class:`.ResponseFuture` to syncronously block for results at any time. - If `trace` is set to :const:`True`, you may call - :meth:`.ResponseFuture.get_query_trace()` after the request - completes to retrieve a :class:`.QueryTrace` instance. + If `trace` is set to :const:`True`, you may get the query trace descriptors using + :meth:`.ResultSet.get_query_trace()` or :meth:`.ResultSet.get_all_query_traces()` + on the future result. `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. If `query` is a Statement with its own custom_payload. The message payload @@ -1730,8 +1710,7 @@ def _create_response_future(self, query, parameters, trace, custom_payload, time query.batch_type, query._statements_and_parameters, cl, query.serial_consistency_level, timestamp) - if trace: - message.tracing = True + message.tracing = trace message.update_custom_payload(query.custom_payload) message.update_custom_payload(custom_payload) @@ -3429,7 +3408,7 @@ def get_query_trace(self, max_wait_sec=None): :exc:`cassandra.query.TraceUnavailable` will be raised. """ if self._traces: - self._get_trace(0, max_wait_sec) + return self._get_trace(0, max_wait_sec) def get_all_query_traces(self, max_wait_sec=None): """ diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py index 3391fdeaa2..a360ffaaca 100644 --- a/cassandra/concurrent.py +++ b/cassandra/concurrent.py @@ -134,10 +134,8 @@ def _execute(self, idx, statement, params): self._put_result(e, idx, False) def _on_success(self, result, future, idx): - if future._is_result_kind_rows: - result = ResultSet(future, result) - future.clear_callbacks() - self._put_result(result, idx, True) + future.clear_callbacks() + self._put_result(ResultSet(future, result), idx, True) def _on_error(self, result, future, idx): self._put_result(result, idx, False) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 5fe978868c..97adb45f52 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -78,7 +78,7 @@ def test_basic(self): CREATE KEYSPACE clustertests WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} """) - self.assertEqual(None, result) + self.assertFalse(result) result = execute_until_pass(session, """ @@ -89,13 +89,13 @@ def test_basic(self): PRIMARY KEY (a, b) ) """) - self.assertEqual(None, result) + self.assertFalse(result) result = session.execute( """ INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c') """) - self.assertEqual(None, result) + self.assertFalse(result) result = session.execute("SELECT * FROM clustertests.cf0") self.assertEqual([('a', 'b', 'c')], result) @@ -152,7 +152,7 @@ def test_connect_on_keyspace(self): """ INSERT INTO test3rf.test (k, v) VALUES (8889, 8889) """) - self.assertEqual(None, result) + self.assertFalse(result) result = session.execute("SELECT * FROM test3rf.test") self.assertEqual([(8889, 8889)], result) @@ -437,8 +437,6 @@ def test_trace(self): cluster = Cluster(protocol_version=PROTOCOL_VERSION) session = cluster.connect() - self.assertRaises(TypeError, session.execute, "SELECT * FROM system.local", trace=True) - def check_trace(trace): self.assertIsNot(None, trace.request_type) self.assertIsNot(None, trace.duration) @@ -446,15 +444,18 @@ def check_trace(trace): self.assertIsNot(None, trace.coordinator) self.assertIsNot(None, trace.events) + result = session.execute( "SELECT * FROM system.local", trace=True) + check_trace(result.get_query_trace()) + query = "SELECT * FROM system.local" statement = SimpleStatement(query) - session.execute(statement, trace=True) - check_trace(statement.trace) + result = session.execute(statement, trace=True) + check_trace(result.get_query_trace()) query = "SELECT * FROM system.local" statement = SimpleStatement(query) - session.execute(statement) - self.assertEqual(None, statement.trace) + result = session.execute(statement) + self.assertIsNone(result.get_query_trace()) statement2 = SimpleStatement(query) future = session.execute_async(statement2, trace=True) @@ -464,7 +465,7 @@ def check_trace(trace): statement2 = SimpleStatement(query) future = session.execute_async(statement2) future.result() - self.assertEqual(None, future.get_query_trace()) + self.assertIsNone(future.get_query_trace()) prepared = session.prepare("SELECT * FROM system.local") future = session.execute_async(prepared, parameters=(), trace=True) diff --git a/tests/integration/standard/test_concurrent.py b/tests/integration/standard/test_concurrent.py index 23b810c535..a0cf79fab8 100644 --- a/tests/integration/standard/test_concurrent.py +++ b/tests/integration/standard/test_concurrent.py @@ -89,7 +89,9 @@ def test_execute_concurrent(self): results = self.execute_concurrent_helper(self.session, list(zip(statements, parameters))) self.assertEqual(num_statements, len(results)) - self.assertEqual([(True, None)] * num_statements, results) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) # read statement = SimpleStatement( @@ -111,7 +113,9 @@ def test_execute_concurrent_with_args(self): results = self.execute_concurrent_args_helper(self.session, statement, parameters) self.assertEqual(num_statements, len(results)) - self.assertEqual([(True, None)] * num_statements, results) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) # read statement = SimpleStatement( @@ -143,8 +147,9 @@ def test_execute_concurrent_with_args_generator(self): parameters = [(i, i) for i in range(num_statements)] results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) - for result in results: - self.assertEqual((True, None), result) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) # read statement = SimpleStatement( @@ -172,7 +177,9 @@ def test_execute_concurrent_paged_result(self): results = self.execute_concurrent_args_helper(self.session, statement, parameters) self.assertEqual(num_statements, len(results)) - self.assertEqual([(True, None)] * num_statements, results) + for success, result in results: + self.assertTrue(success) + self.assertFalse(result) # read statement = SimpleStatement( @@ -273,7 +280,7 @@ def test_no_raise_on_first_failure(self): self.assertIsInstance(result, InvalidRequest) else: self.assertTrue(success) - self.assertEqual(None, result) + self.assertFalse(result) def test_no_raise_on_first_failure_client_side(self): statement = SimpleStatement( @@ -292,4 +299,4 @@ def test_no_raise_on_first_failure_client_side(self): self.assertIsInstance(result, TypeError) else: self.assertTrue(success) - self.assertEqual(None, result) + self.assertFalse(result) diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 2364a49f71..fbb206b14a 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -65,27 +65,34 @@ def test_trace_prints_okay(self): query = "SELECT * FROM system.local" statement = SimpleStatement(query) - session.execute(statement, trace=True) + rs = session.execute(statement, trace=True) # Ensure this does not throw an exception - str(statement.trace) - for event in statement.trace.events: + trace = rs.get_query_trace() + self.assertTrue(trace.events) + str(trace) + for event in trace.events: str(event) cluster.shutdown() - def test_trace_id_to_query(self): + def test_trace_id_to_resultset(self): cluster = Cluster(protocol_version=PROTOCOL_VERSION) session = cluster.connect() - query = "SELECT * FROM system.local" - statement = SimpleStatement(query) - self.assertIsNone(statement.trace_id) - future = session.execute_async(statement, trace=True) + future = session.execute_async("SELECT * FROM system.local", trace=True) + + # future should have the current trace + rs = future.result() + future_trace = future.get_query_trace() + self.assertIsNotNone(future_trace) + + rs_trace = rs.get_query_trace() + self.assertEqual(rs_trace, future_trace) + self.assertTrue(rs_trace.events) + self.assertEqual(len(rs_trace.events), len(future_trace.events)) - # query should have trace_id, even before trace is obtained - future.result() - self.assertIsNotNone(statement.trace_id) + self.assertListEqual([rs_trace], rs.get_all_query_traces()) cluster.shutdown() @@ -96,11 +103,13 @@ def test_trace_ignores_row_factory(self): query = "SELECT * FROM system.local" statement = SimpleStatement(query) - session.execute(statement, trace=True) + rs = session.execute(statement, trace=True) # Ensure this does not throw an exception - str(statement.trace) - for event in statement.trace.events: + trace = rs.get_query_trace() + self.assertTrue(trace.events) + str(trace) + for event in trace.events: str(event) cluster.shutdown() diff --git a/tests/integration/standard/test_row_factories.py b/tests/integration/standard/test_row_factories.py index d4391daffb..c43fe57e02 100644 --- a/tests/integration/standard/test_row_factories.py +++ b/tests/integration/standard/test_row_factories.py @@ -94,6 +94,7 @@ def test_named_tuple_factory(self): result = session.execute(self.select) self.assertIsInstance(result, ResultSet) + result = list(result) for row in result: self.assertEqual(row.k, row.v) diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index 7465f685c7..3c2734e415 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -32,7 +32,7 @@ class MockResponseResponseFuture(): and invoke callback with various timing. """ - _is_result_kind_rows = False + _query_trace = None # a list pending callbacks, these will be prioritized in reverse or normal orderd pending_callbacks = PriorityQueue() @@ -106,7 +106,7 @@ def run(self): self._stopper.wait(.1) callback_args = pending_callback[1] fn, args, kwargs, time_added = callback_args - fn(time_added, *args, **kwargs) + fn([time_added], *args, **kwargs) self._stopper.wait(.001) return @@ -222,8 +222,8 @@ def validate_result_ordering(self, results): :param results: """ last_time_added = 0 - for result in results: - current_time_added = result[1] + for success, result in results: + self.assertTrue(success) + current_time_added = list(result)[0] self.assertLess(last_time_added, current_time_added) last_time_added = current_time_added - diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 3c41c53c4e..5c59259be3 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -94,7 +94,7 @@ def test_set_keyspace_result(self): results="keyspace1") rf._set_result(result) rf._set_keyspace_completed({}) - self.assertEqual(None, rf.result()) + self.assertFalse(rf.result()) def test_schema_change_result(self): session = self.make_session() @@ -113,9 +113,9 @@ def test_other_result_message_kind(self): session = self.make_session() rf = self.make_response_future(session) rf.send_request() - result = object() + result = [1, 2, 3] rf._set_result(Mock(spec=ResultMessage, kind=999, results=result)) - self.assertIs(result, rf.result()) + self.assertListEqual(list(rf.result()), result) def test_read_timeout_error_message(self): session = self.make_session() @@ -172,7 +172,7 @@ def test_retry_policy_says_ignore(self): result = Mock(spec=UnavailableErrorMessage, info={}) rf._set_result(result) - self.assertEqual(None, rf.result()) + self.assertFalse(rf.result()) def test_retry_policy_says_retry(self): session = self.make_session() From ead4f2fc8a478bc7bd85f2fcb82699fb7aaff593 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Tue, 13 Oct 2015 11:43:05 -0500 Subject: [PATCH 07/14] ignore generated test artifact --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a874ebcec1..5c9cbec957 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ setuptools*.egg cassandra/*.c !cassandra/cmurmur3.c cassandra/*.html +tests/unit/cython/bytesio_testhelper.c # OSX .DS_Store From 932a1d5272d075f945ea5659a2a642e3e3cdcb58 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Tue, 13 Oct 2015 11:43:23 -0500 Subject: [PATCH 08/14] Remove trace, trace_id from Statement --- cassandra/cluster.py | 2 -- cassandra/query.py | 12 ------------ 2 files changed, 14 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 8ef7b6610f..208d80de0e 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2880,8 +2880,6 @@ def _set_result(self, response): trace_id = getattr(response, 'trace_id', None) if trace_id: - if self.query: - self.query.trace_id = trace_id self._query_trace = QueryTrace(trace_id, self.session) self._warnings = getattr(response, 'warnings', None) diff --git a/cassandra/query.py b/cassandra/query.py index 0ba6c5947d..f1e0fd8592 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -175,18 +175,6 @@ class Statement(object): will be retried. """ - trace = None - """ - If :meth:`.Session.execute()` is run with `trace` set to :const:`True`, - this will be set to a :class:`.QueryTrace` instance. - """ - - trace_id = None - """ - If :meth:`.Session.execute()` is run with `trace` set to :const:`True`, - this will be set to the tracing ID from the server. - """ - consistency_level = None """ The :class:`.ConsistencyLevel` to be used for this operation. Defaults From 183c8edb16f7509d956e63b13d4c321e104f0438 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Tue, 13 Oct 2015 17:03:23 -0500 Subject: [PATCH 09/14] query trace tracing all in ResponseFuture also add manual page fetching for ResultSet --- cassandra/cluster.py | 83 ++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 45 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 208d80de0e..c236353331 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1619,7 +1619,7 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None any time. If `trace` is set to :const:`True`, you may get the query trace descriptors using - :meth:`.ResultSet.get_query_trace()` or :meth:`.ResultSet.get_all_query_traces()` + :meth:`.ResponseFuture.get_query_trace()` or :meth:`.ResponseFuture.get_all_query_traces()` on the future result. `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. @@ -2689,7 +2689,7 @@ class ResponseFuture(object): _req_id = None _final_result = _NOT_SET _final_exception = None - _query_trace = None + _query_traces = None _callbacks = None _errbacks = None _current_host = None @@ -2880,7 +2880,9 @@ def _set_result(self, response): trace_id = getattr(response, 'trace_id', None) if trace_id: - self._query_trace = QueryTrace(trace_id, self.session) + if not self._query_traces: + self._query_traces = [] + self._query_traces.append(QueryTrace(trace_id, self.session)) self._warnings = getattr(response, 'warnings', None) self._custom_payload = getattr(response, 'custom_payload', None) @@ -3180,19 +3182,31 @@ def result(self, timeout=_NOT_SET): def get_query_trace(self, max_wait=None): """ - Returns the :class:`~.query.QueryTrace` instance representing a trace - of the last attempt for this operation, or :const:`None` if tracing was - not enabled for this query. Note that this may raise an exception if - there are problems retrieving the trace details from Cassandra. If the - trace is not available after `max_wait` seconds, + Fetches and returns the query trace of the last response, or `None` if tracing was + not enabled. + + Note that this may raise an exception if there are problems retrieving the trace + details from Cassandra. If the trace is not available after `max_wait_sec`, :exc:`cassandra.query.TraceUnavailable` will be raised. """ - if not self._query_trace: - return None + if self._query_traces: + return self._get_query_trace(len(self._query_traces) - 1, max_wait) - if not self._query_trace.events: - self._query_trace.populate(max_wait) - return self._query_trace + def get_all_query_traces(self, max_wait_per=None): + """ + Fetches and returns the query traces for all query pages, if tracing was enabled. + + See note in :meth:`~.get_current_query_trace` regarding possible exceptions. + """ + if self._query_traces: + return [self._get_query_trace(i, max_wait_per) for i in range(len(self._query_traces))] + return [] + + def _get_query_trace(self, i, max_wait): + trace = self._query_traces[i] + if not trace.events: + trace.populate(max_wait=max_wait) + return trace def add_callback(self, fn, *args, **kwargs): """ @@ -3334,7 +3348,6 @@ def __init__(self, response_future, initial_response): self._current_rows = initial_response or [] self._page_iter = None self._list_mode = False - self._traces = [response_future._query_trace] if response_future._query_trace else [] @property def has_more_pages(self): @@ -3343,6 +3356,10 @@ def has_more_pages(self): """ return self.response_future.has_more_pages + @property + def current_rows(self): + return self._current_rows or [] + def __iter__(self): if self._list_mode: return iter(self._current_rows) @@ -3358,17 +3375,19 @@ def next(self): self._current_rows = [] raise - self.response_future.start_fetching_next_page() - result = self.response_future.result() - if self.response_future._query_trace: - self._traces.append(self.response_future._query_trace) - self._current_rows = result._current_rows + self.fetch_next_page() self._page_iter = iter(self._current_rows) return next(self._page_iter) __next__ = next + def fetch_next_page(self): + if self.response_future.has_more_pages: + self.response_future.start_fetching_next_page() + result = self.response_future.result() + self._current_rows = result._current_rows + def _fetch_all(self): self._current_rows = list(self) self._page_iter = None @@ -3395,29 +3414,3 @@ def __nonzero__(self): return bool(self._current_rows) __bool__ = __nonzero__ - - def get_query_trace(self, max_wait_sec=None): - """ - Fetches and returns the query trace of the last response, or `None` if tracing was - not enabled. - - Note that this may raise an exception if there are problems retrieving the trace - details from Cassandra. If the trace is not available after `max_wait_sec`, - :exc:`cassandra.query.TraceUnavailable` will be raised. - """ - if self._traces: - return self._get_trace(0, max_wait_sec) - - def get_all_query_traces(self, max_wait_sec=None): - """ - Fetches and returns the query traces for all query pages, if tracing was enabled. - - See note in :meth:`~.get_current_query_trace` regarding possible exceptions. - """ - return [self._get_trace(i, max_wait_sec) for i in range(len(self._traces))] - - def _get_trace(self, i, max_wait): - trace = self._traces[i] - if not trace.events: - trace.populate(max_wait=max_wait) - return trace From 88f0afe9b4c1774f3caf9173ef40412da10fcd10 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Wed, 14 Oct 2015 10:15:33 -0500 Subject: [PATCH 10/14] Add ResultSet get_*query*trace* pass-through to future --- cassandra/cluster.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index c236353331..83307208a0 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3196,7 +3196,7 @@ def get_all_query_traces(self, max_wait_per=None): """ Fetches and returns the query traces for all query pages, if tracing was enabled. - See note in :meth:`~.get_current_query_trace` regarding possible exceptions. + See note in :meth:`~.get_query_trace` regarding possible exceptions. """ if self._query_traces: return [self._get_query_trace(i, max_wait_per) for i in range(len(self._query_traces))] @@ -3387,6 +3387,8 @@ def fetch_next_page(self): self.response_future.start_fetching_next_page() result = self.response_future.result() self._current_rows = result._current_rows + else: + self._current_rows = [] def _fetch_all(self): self._current_rows = list(self) @@ -3414,3 +3416,17 @@ def __nonzero__(self): return bool(self._current_rows) __bool__ = __nonzero__ + + def get_query_trace(self, max_wait_sec=None): + """ + Gets the last query trace from the associated future. + See :meth:`.ResponseFuture.get_query_trace` for details. + """ + return self.response_future.get_query_trace(max_wait_sec) + + def get_all_query_traces(self, max_wait_sec_per=None): + """ + Gets all query traces from the associated future. + See :meth:`.ResponseFuture.get_all_query_traces` for details. + """ + return self.response_future.get_all_query_traces(max_wait_sec_per) From 5957a65ec12167cd83eda10a06ff87c24f3db7ae Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Wed, 14 Oct 2015 10:18:05 -0500 Subject: [PATCH 11/14] ResponseFuture doc updates --- docs/api/cassandra/cluster.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 3ea0d23091..8075476b8a 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -126,6 +126,8 @@ .. automethod:: get_query_trace() + .. automethod:: get_all_query_traces() + .. autoattribute:: custom_payload() .. autoattribute:: has_more_pages From 385baa1659df81d95b81a6582a10c455f25a181e Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Wed, 14 Oct 2015 10:18:23 -0500 Subject: [PATCH 12/14] Trace test updates for PYTHON-318 --- tests/integration/standard/test_client_warnings.py | 4 ++-- tests/integration/standard/test_cluster.py | 10 +++++----- tests/unit/test_resultset.py | 11 +++++++---- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/integration/standard/test_client_warnings.py b/tests/integration/standard/test_client_warnings.py index 4316c9398a..90f224825a 100644 --- a/tests/integration/standard/test_client_warnings.py +++ b/tests/integration/standard/test_client_warnings.py @@ -91,7 +91,7 @@ def test_warning_with_trace(self): future.result() self.assertEqual(len(future.warnings), 1) self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') - self.assertIsNotNone(future._query_trace) + self.assertIsNotNone(future.get_query_trace()) def test_warning_with_custom_payload(self): """ @@ -127,5 +127,5 @@ def test_warning_with_trace_and_custom_payload(self): future.result() self.assertEqual(len(future.warnings), 1) self.assertRegexpMatches(future.warnings[0], 'Batch.*exceeding.*') - self.assertIsNotNone(future._query_trace) + self.assertIsNotNone(future.get_query_trace()) self.assertDictEqual(future.custom_payload, payload) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 97adb45f52..d6d1394f7d 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -438,11 +438,11 @@ def test_trace(self): session = cluster.connect() def check_trace(trace): - self.assertIsNot(None, trace.request_type) - self.assertIsNot(None, trace.duration) - self.assertIsNot(None, trace.started_at) - self.assertIsNot(None, trace.coordinator) - self.assertIsNot(None, trace.events) + self.assertIsNotNone(trace.request_type) + self.assertIsNotNone(trace.duration) + self.assertIsNotNone(trace.started_at) + self.assertIsNotNone(trace.coordinator) + self.assertIsNotNone(trace.events) result = session.execute( "SELECT * FROM system.local", trace=True) check_trace(result.get_query_trace()) diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py index bb85de47ca..2a68376767 100644 --- a/tests/unit/test_resultset.py +++ b/tests/unit/test_resultset.py @@ -37,7 +37,8 @@ def test_iter_paged(self): response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) itr = iter(rs) - type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) # after init to avoid side effects being consumed by init + # this is brittle, depends on internal impl details. Would like to find a better way + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # after init to avoid side effects being consumed by init self.assertListEqual(list(itr), expected) def test_list_non_paged(self): @@ -54,7 +55,8 @@ def test_list_paged(self): response_future = Mock(has_more_pages=True) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) - type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # one True for getitem check/warn, then True, False for two pages + # this is brittle, depends on internal impl details. Would like to find a better way + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode self.assertEqual(rs[9], expected[9]) self.assertEqual(list(rs), expected) @@ -119,7 +121,8 @@ def test_index_list_mode(self): response_future = Mock(has_more_pages=True) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) - type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # First True is consumed on check entering list mode + # this is brittle, depends on internal impl details. Would like to find a better way + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode # index access before iteration causes list to be materialized self.assertEqual(rs[0], expected[0]) self.assertEqual(rs[9], expected[9]) @@ -146,7 +149,7 @@ def test_eq(self): response_future = Mock(has_more_pages=True) response_future.result.side_effect = (ResultSet(Mock(), expected[-5:]), ) # ResultSet is iterable, so it must be protected in order to be returned whole by the Mock rs = ResultSet(response_future, expected[:5]) - type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) + type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # eq before iteration causes list to be materialized self.assertEqual(rs, expected) From d4c380cb010fc3396e6217597ead1ef6024be723 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Fri, 16 Oct 2015 13:56:36 -0500 Subject: [PATCH 13/14] Make ResultSet handle more than just list results for use with custom ProtocolHandlers providing different ResultMessage deserialization PYTHON-430 --- cassandra/cluster.py | 12 +++++++++--- cassandra/obj_parser.pyx | 4 ++++ cassandra/protocol.py | 4 ++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 83307208a0..9d3427b869 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -61,7 +61,7 @@ BatchMessage, RESULT_KIND_PREPARED, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION, - ProtocolHandler) + ProtocolHandler, _RESULT_SEQUENCE_TYPES) from cassandra.metadata import Metadata, protect_name, murmur3 from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, @@ -3345,7 +3345,7 @@ class ResultSet(object): def __init__(self, response_future, initial_response): self.response_future = response_future - self._current_rows = initial_response or [] + self._set_current_rows(initial_response) self._page_iter = None self._list_mode = False @@ -3386,10 +3386,16 @@ def fetch_next_page(self): if self.response_future.has_more_pages: self.response_future.start_fetching_next_page() result = self.response_future.result() - self._current_rows = result._current_rows + self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form else: self._current_rows = [] + def _set_current_rows(self, result): + if isinstance(result, _RESULT_SEQUENCE_TYPES): + self._current_rows = result + else: + self._current_rows = [result] if result else [] + def _fetch_all(self): self._current_rows = list(self) self._page_iter = None diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index 8aa5b3940f..21ce95e0bd 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -38,6 +38,8 @@ cdef class LazyParser(ColumnParser): # supported in cpdef methods return parse_rows_lazy(reader, desc) + cpdef get_cython_generator_type(self): + return get_cython_generator_type() def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): cdef Py_ssize_t i, rowcount @@ -45,6 +47,8 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): cdef RowParser rowparser = TupleRowParser() return (rowparser.unpack_row(reader, desc) for i in range(rowcount)) +def get_cython_generator_type(): + return type((i for i in range(0))) cdef class TupleRowParser(RowParser): """ diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 90fdcd091c..1c6ea85cf1 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -1000,6 +1000,7 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod return msg +_RESULT_SEQUENCE_TYPES = (list, tuple) # types retuned by ResultMessages def cython_protocol_handler(colparser): """ @@ -1045,6 +1046,9 @@ class CythonProtocolHandler(ProtocolHandler): if HAVE_CYTHON: from cassandra.obj_parser import ListParser, LazyParser ProtocolHandler = cython_protocol_handler(ListParser()) + + lazy_parser = LazyParser() + _RESULT_SEQUENCE_TYPES += (lazy_parser.get_cython_generator_type(),) LazyProtocolHandler = cython_protocol_handler(LazyParser()) else: # Use Python-based ProtocolHandler From 795acbd254987c1f11ff41a78f42ecb26a3fb0f4 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Fri, 16 Oct 2015 13:57:36 -0500 Subject: [PATCH 14/14] test: new tests covering paging for Lazy and NumPy parsers PYTHON-430 --- .../standard/test_cython_protocol_handlers.py | 72 +++++++++++++++++-- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 985b795302..888f73fed6 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -55,21 +55,79 @@ def test_cython_lazy_parser(self): """ verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) + @numpytest + def test_cython_lazy_results_paged(self): + """ + Test Cython-based parser that returns an iterator, over multiple pages + """ + # arrays = { 'a': arr1, 'b': arr2, ... } + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + session.row_factory = tuple_factory + session.client_protocol_handler = LazyProtocolHandler + session.default_fetch_size = 2 + + self.assertLess(session.default_fetch_size, self.N_ITEMS) + + results = session.execute("SELECT * FROM test_table") + + self.assertTrue(results.has_more_pages) + self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS) # make sure we see all rows + + cluster.shutdown() + @numpytest def test_numpy_parser(self): """ Test Numpy-based parser that returns a NumPy array """ # arrays = { 'a': arr1, 'b': arr2, ... } - arrays = get_data(NumpyProtocolHandler) + result = get_data(NumpyProtocolHandler) + self.assertFalse(result.has_more_pages) + self._verify_numpy_page(result[0]) + @numpytest + def test_numpy_results_paged(self): + """ + Test Numpy-based parser that returns a NumPy array + """ + # arrays = { 'a': arr1, 'b': arr2, ... } + cluster = Cluster(protocol_version=PROTOCOL_VERSION) + session = cluster.connect(keyspace="testspace") + session.row_factory = tuple_factory + session.client_protocol_handler = NumpyProtocolHandler + session.default_fetch_size = 2 + + expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size + + self.assertLess(session.default_fetch_size, self.N_ITEMS) + + results = session.execute("SELECT * FROM test_table") + + self.assertTrue(results.has_more_pages) + for count, page in enumerate(results, 1): + self.assertIsInstance(page, dict) + for colname, arr in page.items(): + if count <= expected_pages: + self.assertGreater(len(arr), 0, "page count: %d" % (count,)) + self.assertLessEqual(len(arr), session.default_fetch_size) + else: + # we get one extra item out of this iteration because of the way NumpyParser returns results + # The last page is returned as a dict with zero-length arrays + self.assertEqual(len(arr), 0) + self.assertEqual(self._verify_numpy_page(page), len(arr)) + self.assertEqual(count, expected_pages + 1) # see note about extra 'page' above + + cluster.shutdown() + + def _verify_numpy_page(self, page): colnames = self.colnames datatypes = get_primitive_datatypes() for colname, datatype in zip(colnames, datatypes): - arr = arrays[colname] + arr = page[colname] self.match_dtype(datatype, arr.dtype) - verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(arrays, colnames)) + return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames)) def match_dtype(self, datatype, dtype): """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype""" @@ -100,9 +158,7 @@ def arrays_to_list_of_tuples(arrays, colnames): def get_data(protocol_handler): """ - Get some data from the test table. - - :param key: if None, get all results (100.000 results), otherwise get only one result + Get data from the test table. """ cluster = Cluster(protocol_version=PROTOCOL_VERSION) session = cluster.connect(keyspace="testspace") @@ -121,9 +177,11 @@ def verify_iterator_data(assertEqual, results): Check the result of get_data() when this is a list or iterator of tuples """ - for result in results: + count = 0 + for count, result in enumerate(results, 1): params = get_all_primitive_params(result[0]) assertEqual(len(params), len(result), msg="Not the right number of columns?") for expected, actual in zip(params, result): assertEqual(actual, expected) + return count