Skip to content

Commit bcfa812

Browse files
rootvipjml
authored andcommitted
improve ExponentialReconnectionPolicy,now can custom max attempts time
1 parent 201d2c6 commit bcfa812

2 files changed

Lines changed: 23 additions & 3 deletions

File tree

cassandra/policies.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,22 +528,33 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy):
528528
a set maximum delay.
529529
"""
530530

531-
def __init__(self, base_delay, max_delay):
531+
def __init__(self, base_delay, max_delay, max_attempts=64):
532532
"""
533533
`base_delay` and `max_delay` should be in floating point units of
534534
seconds.
535+
536+
`max_attempts` should be a total number of attempts to be made before
537+
giving up, or :const:`None` to continue reconnection attempts forever.
538+
The default is 64.
535539
"""
536540
if base_delay < 0 or max_delay < 0:
537541
raise ValueError("Delays may not be negative")
538542

539543
if max_delay < base_delay:
540544
raise ValueError("Max delay must be greater than base delay")
541545

546+
if max_attempts is not None and max_attempts < 0:
547+
raise ValueError("max_attempts must not be negative")
548+
542549
self.base_delay = base_delay
543550
self.max_delay = max_delay
551+
self.max_attempts = max_attempts
544552

545553
def new_schedule(self):
546-
return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64))
554+
i=0
555+
while self.max_attempts == None or i < self.max_attempts:
556+
yield min(self.base_delay * (2 ** i), self.max_delay)
557+
i += 1
547558

548559

549560
class WriteType(object):

tests/unit/test_policies.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,9 +814,18 @@ def test_bad_vals(self):
814814
self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0)
815815
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1)
816816
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1)
817+
self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2,-1)
817818

818819
def test_schedule(self):
819-
policy = ExponentialReconnectionPolicy(base_delay=2, max_delay=100)
820+
policy = ExponentialReconnectionPolicy(base_delay=2, max_delay=100, max_attempts=None)
821+
i=0;
822+
for delay in policy.new_schedule():
823+
i += 1
824+
if i > 10000:
825+
break;
826+
self.assertEqual(i, 10001)
827+
828+
policy = ExponentialReconnectionPolicy(base_delay=2, max_delay=100, max_attempts=64)
820829
schedule = list(policy.new_schedule())
821830
self.assertEqual(len(schedule), 64)
822831
for i, delay in enumerate(schedule):

0 commit comments

Comments
 (0)