|
29 | 29 | from cassandra.concurrent import execute_concurrent |
30 | 30 | from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, |
31 | 31 | RetryPolicy, SimpleConvictionPolicy, HostDistance, |
32 | | - WhiteListRoundRobinPolicy, AddressTranslator) |
| 32 | + WhiteListRoundRobinPolicy, AddressTranslator, TokenAwarePolicy, HostFilterPolicy) |
33 | 33 | from cassandra.pool import Host |
34 | 34 | from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory |
35 | 35 |
|
36 | 36 |
|
37 | | -from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, CASSANDRA_VERSION, DSE_VERSION, execute_until_pass, execute_with_long_wait_retry, get_node,\ |
38 | | - MockLoggingHandler, get_unsupported_lower_protocol, get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP |
| 37 | +from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, CASSANDRA_VERSION, \ |
| 38 | + execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \ |
| 39 | + get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP |
39 | 40 | from tests.integration.util import assert_quiescent_pool_state |
40 | 41 | import sys |
41 | 42 |
|
@@ -974,6 +975,64 @@ def test_add_profile_timeout(self): |
974 | 975 | else: |
975 | 976 | raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) |
976 | 977 |
|
| 978 | + def test_replicas_are_queried(self): |
| 979 | + """ |
| 980 | + Test that replicas are queried first for TokenAwarePolicy. A table with RF 1 |
| 981 | + is created. All the queries should go to that replica when TokenAwarePolicy |
| 982 | + is used. |
| 983 | + Then using HostFilterPolicy the replica is excluded from the considered hosts. |
| 984 | + By checking the trace we verify that there are no more replicas. |
| 985 | +
|
| 986 | + @since 3.5 |
| 987 | + @jira_ticket PYTHON-653 |
| 988 | + @expected_result the replicas are queried for HostFilterPolicy |
| 989 | +
|
| 990 | + @test_category metadata |
| 991 | + """ |
| 992 | + queried_hosts = set() |
| 993 | + with Cluster(protocol_version=PROTOCOL_VERSION, |
| 994 | + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) as cluster: |
| 995 | + session = cluster.connect() |
| 996 | + session.execute(''' |
| 997 | + CREATE TABLE test1rf.table_with_big_key ( |
| 998 | + k1 int, |
| 999 | + k2 int, |
| 1000 | + k3 int, |
| 1001 | + k4 int, |
| 1002 | + PRIMARY KEY((k1, k2, k3), k4))''') |
| 1003 | + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key |
| 1004 | + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") |
| 1005 | + for i in range(10): |
| 1006 | + result = session.execute(prepared, (i, i, i, i), trace=True) |
| 1007 | + queried_hosts = self._assert_replica_queried(result.get_query_trace(), only_replicas=True) |
| 1008 | + last_i = i |
| 1009 | + |
| 1010 | + only_replica = queried_hosts.pop() |
| 1011 | + available_hosts = [host for host in ["127.0.0.1", "127.0.0.2", "127.0.0.3"] if host != only_replica] |
| 1012 | + with Cluster(contact_points=available_hosts, |
| 1013 | + protocol_version=PROTOCOL_VERSION, |
| 1014 | + load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), |
| 1015 | + predicate=lambda host: host.address != only_replica)) as cluster: |
| 1016 | + |
| 1017 | + session = cluster.connect() |
| 1018 | + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key |
| 1019 | + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") |
| 1020 | + for _ in range(10): |
| 1021 | + result = session.execute(prepared, (last_i, last_i, last_i, last_i), trace=True) |
| 1022 | + self._assert_replica_queried(result.get_query_trace(), only_replicas=False) |
| 1023 | + |
| 1024 | + session.execute('''DROP TABLE test1rf.table_with_big_key''') |
| 1025 | + |
| 1026 | + def _assert_replica_queried(self, trace, only_replicas=True): |
| 1027 | + queried_hosts = set() |
| 1028 | + for row in trace.events: |
| 1029 | + queried_hosts.add(row.source) |
| 1030 | + if only_replicas: |
| 1031 | + self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts)) |
| 1032 | + else: |
| 1033 | + self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts)) |
| 1034 | + return queried_hosts |
| 1035 | + |
977 | 1036 |
|
978 | 1037 | class LocalHostAdressTranslator(AddressTranslator): |
979 | 1038 |
|
|
0 commit comments