Skip to content

Commit 0085a74

Browse files
author
bjmb
committed
Merge branch 'PYTHON-643'
2 parents 2a23582 + 1e3c90a commit 0085a74

3 files changed

Lines changed: 178 additions & 4 deletions

File tree

cassandra/policies.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from itertools import islice, cycle, groupby, repeat
1616
import logging
17-
from random import randint
17+
from random import randint, shuffle
1818
from threading import Lock
1919
import socket
2020

@@ -315,18 +315,25 @@ class TokenAwarePolicy(LoadBalancingPolicy):
315315
This alters the child policy's behavior so that it first attempts to
316316
send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined
317317
by the child policy) based on the :class:`.Statement`'s
318-
:attr:`~.Statement.routing_key`. Once those hosts are exhausted, the
319-
remaining hosts in the child policy's query plan will be used.
318+
:attr:`~.Statement.routing_key`. If :attr:`.shuffle_replicas` is
319+
truthy, these replicas will be yielded in a random order. Once those
320+
hosts are exhausted, the remaining hosts in the child policy's query
321+
plan will be used in the order provided by the child policy.
320322
321323
If no :attr:`~.Statement.routing_key` is set on the query, the child
322324
policy's query plan will be used as is.
323325
"""
324326

325327
_child_policy = None
326328
_cluster_metadata = None
329+
shuffle_replicas = False
330+
"""
331+
Yield local replicas in a random order.
332+
"""
327333

328-
def __init__(self, child_policy):
334+
def __init__(self, child_policy, shuffle_replicas=False):
329335
self._child_policy = child_policy
336+
self.shuffle_replicas = shuffle_replicas
330337

331338
def populate(self, cluster, hosts):
332339
self._cluster_metadata = cluster.metadata
@@ -361,6 +368,8 @@ def make_query_plan(self, working_keyspace=None, query=None):
361368
yield host
362369
else:
363370
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
371+
if self.shuffle_replicas:
372+
shuffle(replicas)
364373
for replica in replicas:
365374
if replica.is_up and \
366375
child.distance(replica) == HostDistance.LOCAL:

tests/integration/long/test_loadbalancingpolicies.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,104 @@ def test_token_aware_with_local_table(self):
520520

521521
cluster.shutdown()
522522

523+
def test_token_aware_with_shuffle_rf2(self):
524+
"""
525+
Test to validate the hosts are shuffled when `shuffle_replicas` is truthy
526+
@since 3.8
527+
@jira_ticket PYTHON-676
528+
@expected_result the requests are spread across the replicas,
529+
when one of them is down, the requests target the available one
530+
531+
@test_category policy
532+
"""
533+
keyspace = 'test_token_aware_with_rf_2'
534+
cluster, session = self._set_up_shuffle_test(keyspace, replication_factor=2)
535+
536+
self._check_query_order_changes(session=session, keyspace=keyspace)
537+
538+
# check that TokenAwarePolicy still returns the remaining replicas when one goes down
539+
self.coordinator_stats.reset_counts()
540+
stop(2)
541+
self._wait_for_nodes_down([2], cluster)
542+
543+
self._query(session, keyspace)
544+
545+
self.coordinator_stats.assert_query_count_equals(self, 1, 0)
546+
self.coordinator_stats.assert_query_count_equals(self, 2, 0)
547+
self.coordinator_stats.assert_query_count_equals(self, 3, 12)
548+
549+
cluster.shutdown()
550+
551+
def test_token_aware_with_shuffle_rf3(self):
552+
"""
553+
Test to validate the hosts are shuffled when `shuffle_replicas` is truthy
554+
@since 3.8
555+
@jira_ticket PYTHON-676
556+
@expected_result the requests are spread across the replicas,
557+
when one of them is down, the requests target the other available ones
558+
559+
@test_category policy
560+
"""
561+
keyspace = 'test_token_aware_with_rf_3'
562+
cluster, session = self._set_up_shuffle_test(keyspace, replication_factor=3)
563+
564+
self._check_query_order_changes(session=session, keyspace=keyspace)
565+
566+
# check that TokenAwarePolicy still returns the remaining replicas when one goes down
567+
self.coordinator_stats.reset_counts()
568+
stop(1)
569+
self._wait_for_nodes_down([1], cluster)
570+
571+
self._query(session, keyspace)
572+
573+
self.coordinator_stats.assert_query_count_equals(self, 1, 0)
574+
query_count_two = self.coordinator_stats.get_query_count(2)
575+
query_count_three = self.coordinator_stats.get_query_count(3)
576+
self.assertEqual(query_count_two + query_count_three, 12)
577+
578+
self.coordinator_stats.reset_counts()
579+
stop(2)
580+
self._wait_for_nodes_down([2], cluster)
581+
582+
self._query(session, keyspace)
583+
584+
self.coordinator_stats.assert_query_count_equals(self, 1, 0)
585+
self.coordinator_stats.assert_query_count_equals(self, 2, 0)
586+
self.coordinator_stats.assert_query_count_equals(self, 3, 12)
587+
588+
cluster.shutdown()
589+
590+
def _set_up_shuffle_test(self, keyspace, replication_factor):
591+
use_singledc()
592+
cluster, session = self._cluster_session_with_lbp(
593+
TokenAwarePolicy(RoundRobinPolicy(), shuffle_replicas=True)
594+
)
595+
self._wait_for_nodes_up(range(1, 4), cluster)
596+
597+
create_schema(cluster, session, keyspace, replication_factor=replication_factor)
598+
return cluster, session
599+
600+
def _check_query_order_changes(self, session, keyspace):
601+
LIMIT_TRIES, tried, query_counts = 20, 0, set()
602+
603+
while len(query_counts) <= 1:
604+
tried += 1
605+
if tried >= LIMIT_TRIES:
606+
raise Exception("After {0} tries shuffle returned the same output".format(LIMIT_TRIES))
607+
608+
self._insert(session, keyspace)
609+
self._query(session, keyspace)
610+
611+
loop_qcs = (self.coordinator_stats.get_query_count(1),
612+
self.coordinator_stats.get_query_count(2),
613+
self.coordinator_stats.get_query_count(3))
614+
615+
query_counts.add(loop_qcs)
616+
self.assertEqual(sum(loop_qcs), 12)
617+
618+
# end the loop if we get more than one query ordering
619+
self.coordinator_stats.reset_counts()
620+
523621
def test_white_list(self):
524622
use_singledc()
525623
keyspace = 'test_white_list'

tests/unit/test_policies.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,73 @@ def test_statement_keyspace(self):
731731
self.assertEqual(replicas + hosts[:2], qplan)
732732
cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key)
733733

734+
def test_shuffles_if_given_keyspace_and_routing_key(self):
735+
"""
736+
Test to validate the hosts are shuffled when `shuffle_replicas` is truthy
737+
@since 3.8
738+
@jira_ticket PYTHON-676
739+
@expected_result shuffle should be called, because the keyspace and the
740+
routing key are set
741+
742+
@test_category policy
743+
"""
744+
self._assert_shuffle(keyspace='keyspace', routing_key='routing_key')
745+
746+
def test_no_shuffle_if_given_no_keyspace(self):
747+
"""
748+
Test to validate the hosts are not shuffled when no keyspace is provided
749+
@since 3.8
750+
@jira_ticket PYTHON-676
751+
@expected_result shuffle should not be called, because keyspace is None
752+
753+
@test_category policy
754+
"""
755+
self._assert_shuffle(keyspace=None, routing_key='routing_key')
756+
757+
def test_no_shuffle_if_given_no_routing_key(self):
758+
"""
759+
Test to validate the hosts are not shuffled when no routing_key is provided
760+
@since 3.8
761+
@jira_ticket PYTHON-676
762+
@expected_result shuffle should not be called, because routing_key is None
763+
764+
@test_category policy
765+
"""
766+
self._assert_shuffle(keyspace='keyspace', routing_key=None)
767+
768+
@patch('cassandra.policies.shuffle')
769+
def _assert_shuffle(self, patched_shuffle, keyspace, routing_key):
770+
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
771+
for host in hosts:
772+
host.set_up()
773+
774+
cluster = Mock(spec=Cluster)
775+
cluster.metadata = Mock(spec=Metadata)
776+
replicas = hosts[2:]
777+
cluster.metadata.get_replicas.return_value = replicas
778+
779+
child_policy = Mock()
780+
child_policy.make_query_plan.return_value = hosts
781+
child_policy.distance.return_value = HostDistance.LOCAL
782+
783+
policy = TokenAwarePolicy(child_policy, shuffle_replicas=True)
784+
policy.populate(cluster, hosts)
785+
786+
cluster.metadata.get_replicas.reset_mock()
787+
child_policy.make_query_plan.reset_mock()
788+
query = Statement(routing_key=routing_key)
789+
qplan = list(policy.make_query_plan(keyspace, query))
790+
if keyspace is None or routing_key is None:
791+
self.assertEqual(hosts, qplan)
792+
self.assertEqual(cluster.metadata.get_replicas.call_count, 0)
793+
child_policy.make_query_plan.assert_called_once_with(keyspace, query)
794+
self.assertEqual(patched_shuffle.call_count, 0)
795+
else:
796+
self.assertEqual(set(replicas), set(qplan[:2]))
797+
self.assertEqual(hosts[:2], qplan[2:])
798+
child_policy.make_query_plan.assert_called_once_with(keyspace, query)
799+
self.assertEqual(patched_shuffle.call_count, 1)
800+
734801

735802
class ConvictionPolicyTest(unittest.TestCase):
736803
def test_not_implemented(self):

0 commit comments

Comments
 (0)