Skip to content

Commit 263481f

Browse files
committed
Allow DCAware LB to default local_dc based on contact_points
PYTHON-148
1 parent ad2f61f commit 263481f

2 files changed

Lines changed: 49 additions & 28 deletions

File tree

cassandra/cluster.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,18 @@ class Cluster(object):
166166
167167
"""
168168

169+
contact_points = ['127.0.0.1']
170+
"""
171+
The list of contact points to try connecting for cluster discovery.
172+
173+
Defaults to loopback interface.
174+
175+
Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
176+
local_dc set, the DC is chosen from an arbitrary host in contact_points.
177+
In this case, contact_points should contain only nodes from a single,
178+
local DC.
179+
"""
180+
169181
port = 9042
170182
"""
171183
The server-side port to open connections to. Defaults to 9042.
@@ -377,7 +389,7 @@ def auth_provider(self, value):
377389
_listener_lock = None
378390

379391
def __init__(self,
380-
contact_points=("127.0.0.1",),
392+
contact_points=["127.0.0.1"],
381393
port=9042,
382394
compression=True,
383395
auth_provider=None,
@@ -1946,8 +1958,8 @@ def _handle_results(success, result):
19461958

19471959
responses = connection.wait_for_responses(*queries, fail_on_error=False)
19481960
(ks_success, ks_result), (cf_success, cf_result), \
1949-
(col_success, col_result), (types_success, types_result), \
1950-
(trigger_success, triggers_result) = responses
1961+
(col_success, col_result), (types_success, types_result), \
1962+
(trigger_success, triggers_result) = responses
19511963

19521964
if ks_success:
19531965
ks_result = dict_factory(*ks_result.results)

cassandra/policies.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,14 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
204204
local_dc = None
205205
used_hosts_per_remote_dc = 0
206206

207-
def __init__(self, local_dc, used_hosts_per_remote_dc=0):
207+
def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
208208
"""
209209
The `local_dc` parameter should be the name of the datacenter
210210
(such as is reported by ``nodetool ring``) that should
211-
be considered local.
211+
be considered local. If not specified, the driver will choose
212+
a local_dc based on the first host among :attr:`.Cluster.contact_points`
213+
having a valid DC. If relying on this mechanism, all specified
214+
contact points should be nodes in a single, local DC.
212215
213216
`used_hosts_per_remote_dc` controls how many nodes in
214217
each remote datacenter will have connections opened
@@ -220,6 +223,8 @@ def __init__(self, local_dc, used_hosts_per_remote_dc=0):
220223
self.local_dc = local_dc
221224
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
222225
self._dc_live_hosts = {}
226+
self._position = 0
227+
self._contact_points = []
223228
LoadBalancingPolicy.__init__(self)
224229

225230
def _dc(self, host):
@@ -229,14 +234,10 @@ def populate(self, cluster, hosts):
229234
for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)):
230235
self._dc_live_hosts[dc] = tuple(set(dc_hosts))
231236

232-
# position is currently only used for local hosts
233-
local_live = self._dc_live_hosts.get(self.local_dc)
234-
if not local_live:
235-
self._position = 0
236-
elif len(local_live) == 1:
237-
self._position = 0
238-
else:
239-
self._position = randint(0, len(local_live) - 1)
237+
if not self.local_dc:
238+
self._contact_points = cluster.contact_points
239+
240+
self._position = randint(0, len(hosts) - 1) if hosts else 0
240241

241242
def distance(self, host):
242243
dc = self._dc(host)
@@ -274,32 +275,40 @@ def make_query_plan(self, working_keyspace=None, query=None):
274275
yield host
275276

276277
def on_up(self, host):
278+
279+
# not worrying about threads because this will happen during
280+
# control connection startup/refresh
281+
if not self.local_dc and host.datacenter:
282+
if host.address in self._contact_points:
283+
self.local_dc = host.datacenter
284+
log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
285+
"if incorrect, please specify a local_dc to the constructor, "
286+
"or limit contact points to local cluster nodes" %
287+
(self.local_dc, host.address))
288+
del self._contact_points
289+
277290
dc = self._dc(host)
278291
with self._hosts_lock:
279-
current_hosts = self._dc_live_hosts.setdefault(dc, ())
292+
current_hosts = self._dc_live_hosts.get(dc, ())
280293
if host not in current_hosts:
281294
self._dc_live_hosts[dc] = current_hosts + (host, )
282295

283296
def on_down(self, host):
284297
dc = self._dc(host)
285298
with self._hosts_lock:
286-
current_hosts = self._dc_live_hosts.setdefault(dc, ())
299+
current_hosts = self._dc_live_hosts.get(dc, ())
287300
if host in current_hosts:
288-
self._dc_live_hosts[dc] = tuple(h for h in current_hosts if h != host)
301+
hosts = tuple(h for h in current_hosts if h != host)
302+
if hosts:
303+
self._dc_live_hosts[dc] = tuple(h for h in current_hosts if h != host)
304+
else:
305+
self._dc_live_hosts.pop(dc, None)
289306

290307
def on_add(self, host):
291-
dc = self._dc(host)
292-
with self._hosts_lock:
293-
current_hosts = self._dc_live_hosts.setdefault(dc, ())
294-
if host not in current_hosts:
295-
self._dc_live_hosts[dc] = current_hosts + (host, )
308+
self.on_up(host)
296309

297310
def on_remove(self, host):
298-
dc = self._dc(host)
299-
with self._hosts_lock:
300-
current_hosts = self._dc_live_hosts.setdefault(dc, ())
301-
if host in current_hosts:
302-
self._dc_live_hosts[dc] = tuple(h for h in current_hosts if h != host)
311+
self.on_down(host)
303312

304313

305314
class TokenAwarePolicy(LoadBalancingPolicy):
@@ -333,8 +342,8 @@ def check_supported(self):
333342
'%s cannot be used with the cluster partitioner (%s) because '
334343
'the relevant C extension for this driver was not compiled. '
335344
'See the installation instructions for details on building '
336-
'and installing the C extensions.' % (self.__class__.__name__,
337-
self._cluster_metadata.partitioner))
345+
'and installing the C extensions.' %
346+
(self.__class__.__name__, self._cluster_metadata.partitioner))
338347

339348
def distance(self, *args, **kwargs):
340349
return self._child_policy.distance(*args, **kwargs)

0 commit comments

Comments
 (0)