@@ -204,11 +204,14 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
204204 local_dc = None
205205 used_hosts_per_remote_dc = 0
206206
207- def __init__ (self , local_dc , used_hosts_per_remote_dc = 0 ):
207+ def __init__ (self , local_dc = '' , used_hosts_per_remote_dc = 0 ):
208208 """
209209 The `local_dc` parameter should be the name of the datacenter
210210 (such as is reported by ``nodetool ring``) that should
211- be considered local.
211+ be considered local. If not specified, the driver will choose
212+ a local_dc based on the first host among :attr:`.Cluster.contact_points`
213+ having a valid DC. If relying on this mechanism, all specified
214+ contact points should be nodes in a single, local DC.
212215
213216 `used_hosts_per_remote_dc` controls how many nodes in
214217 each remote datacenter will have connections opened
@@ -220,6 +223,8 @@ def __init__(self, local_dc, used_hosts_per_remote_dc=0):
220223 self .local_dc = local_dc
221224 self .used_hosts_per_remote_dc = used_hosts_per_remote_dc
222225 self ._dc_live_hosts = {}
226+ self ._position = 0
227+ self ._contact_points = []
223228 LoadBalancingPolicy .__init__ (self )
224229
225230 def _dc (self , host ):
@@ -229,14 +234,10 @@ def populate(self, cluster, hosts):
229234 for dc , dc_hosts in groupby (hosts , lambda h : self ._dc (h )):
230235 self ._dc_live_hosts [dc ] = tuple (set (dc_hosts ))
231236
232- # position is currently only used for local hosts
233- local_live = self ._dc_live_hosts .get (self .local_dc )
234- if not local_live :
235- self ._position = 0
236- elif len (local_live ) == 1 :
237- self ._position = 0
238- else :
239- self ._position = randint (0 , len (local_live ) - 1 )
237+ if not self .local_dc :
238+ self ._contact_points = cluster .contact_points
239+
240+ self ._position = randint (0 , len (hosts ) - 1 ) if hosts else 0
240241
241242 def distance (self , host ):
242243 dc = self ._dc (host )
@@ -274,32 +275,40 @@ def make_query_plan(self, working_keyspace=None, query=None):
274275 yield host
275276
276277 def on_up (self , host ):
278+
279+ # not worrying about threads because this will happen during
280+ # control connection startup/refresh
281+ if not self .local_dc and host .datacenter :
282+ if host .address in self ._contact_points :
283+ self .local_dc = host .datacenter
284+ log .info ("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
285+ "if incorrect, please specify a local_dc to the constructor, "
286+ "or limit contact points to local cluster nodes" %
287+ (self .local_dc , host .address ))
288+ del self ._contact_points
289+
277290 dc = self ._dc (host )
278291 with self ._hosts_lock :
279- current_hosts = self ._dc_live_hosts .setdefault (dc , ())
292+ current_hosts = self ._dc_live_hosts .get (dc , ())
280293 if host not in current_hosts :
281294 self ._dc_live_hosts [dc ] = current_hosts + (host , )
282295
283296 def on_down (self , host ):
284297 dc = self ._dc (host )
285298 with self ._hosts_lock :
286- current_hosts = self ._dc_live_hosts .setdefault (dc , ())
299+ current_hosts = self ._dc_live_hosts .get (dc , ())
287300 if host in current_hosts :
288- self ._dc_live_hosts [dc ] = tuple (h for h in current_hosts if h != host )
301+ hosts = tuple (h for h in current_hosts if h != host )
302+ if hosts :
303+ self ._dc_live_hosts [dc ] = tuple (h for h in current_hosts if h != host )
304+ else :
305+ self ._dc_live_hosts .pop (dc , None )
289306
290307 def on_add (self , host ):
291- dc = self ._dc (host )
292- with self ._hosts_lock :
293- current_hosts = self ._dc_live_hosts .setdefault (dc , ())
294- if host not in current_hosts :
295- self ._dc_live_hosts [dc ] = current_hosts + (host , )
308+ self .on_up (host )
296309
297310 def on_remove (self , host ):
298- dc = self ._dc (host )
299- with self ._hosts_lock :
300- current_hosts = self ._dc_live_hosts .setdefault (dc , ())
301- if host in current_hosts :
302- self ._dc_live_hosts [dc ] = tuple (h for h in current_hosts if h != host )
311+ self .on_down (host )
303312
304313
305314class TokenAwarePolicy (LoadBalancingPolicy ):
@@ -333,8 +342,8 @@ def check_supported(self):
333342 '%s cannot be used with the cluster partitioner (%s) because '
334343 'the relevant C extension for this driver was not compiled. '
335344 'See the installation instructions for details on building '
336- 'and installing the C extensions.' % ( self . __class__ . __name__ ,
337- self ._cluster_metadata .partitioner ))
345+ 'and installing the C extensions.' %
346+ ( self . __class__ . __name__ , self ._cluster_metadata .partitioner ))
338347
339348 def distance (self , * args , ** kwargs ):
340349 return self ._child_policy .distance (* args , ** kwargs )
0 commit comments