1515"""
1616Connection pooling and host management.
1717"""
18-
18+ from concurrent . futures import Future
1919from functools import total_ordering
2020import logging
2121import socket
@@ -411,6 +411,7 @@ def __init__(self, host, host_distance, session):
411411 # and are waiting until all requests time out or complete
412412 # so that we can dispose of them.
413413 self ._trash = set ()
414+ self ._shard_connections_futures = []
414415
415416 if host_distance == HostDistance .IGNORED :
416417 log .debug ("Not opening connection to ignored host %s" , self .host )
@@ -537,9 +538,9 @@ def return_connection(self, connection, stream_was_orphaned=False):
537538 if is_down :
538539 self .shutdown ()
539540 else :
540- connection .close ()
541- del self ._connections [connection .shard_id ]
542541 with self ._lock :
542+ connection .close ()
543+ self ._connections .pop (connection .shard_id , None )
543544 if self ._is_replacing :
544545 return
545546 self ._is_replacing = True
@@ -568,23 +569,22 @@ def _replace(self, connection):
568569 if self .is_shutdown :
569570 return
570571
571- log .debug ("Replacing connection (%s) to %s" , id (connection ), self .host )
572- try :
573- if connection .shard_id in self ._connections .keys ():
574- del self ._connections [connection .shard_id ]
575- if self .host .sharding_info :
576- self ._connecting .add (connection .shard_id )
577- self ._open_connection_to_missing_shard (connection .shard_id )
572+ log .debug ("Replacing connection (%s) to %s" , id (connection ), self .host )
573+ try :
574+ if connection .shard_id in self ._connections .keys ():
575+ del self ._connections [connection .shard_id ]
576+ if self .host .sharding_info :
577+ self ._connecting .add (connection .shard_id )
578+ self ._session .submit (self ._open_connection_to_missing_shard , connection .shard_id )
579+ else :
580+ connection = self ._session .cluster .connection_factory (self .host .endpoint , owning_pool = self )
581+ if self ._keyspace :
582+ connection .set_keyspace_blocking (self ._keyspace )
583+ self ._connections [connection .shard_id ] = connection
584+ except Exception :
585+ log .warning ("Failed reconnecting %s. Retrying." % (self .host .endpoint ,))
586+ self ._session .submit (self ._replace , connection )
578587 else :
579- connection = self ._session .cluster .connection_factory (self .host .endpoint , owning_pool = self )
580- if self ._keyspace :
581- connection .set_keyspace_blocking (self ._keyspace )
582- self ._connections [connection .shard_id ] = connection
583- except Exception :
584- log .warning ("Failed reconnecting %s. Retrying." % (self .host .endpoint ,))
585- self ._session .submit (self ._replace , connection )
586- else :
587- with self ._lock :
588588 self ._is_replacing = False
589589 self ._stream_available_condition .notify ()
590590
@@ -597,11 +597,14 @@ def shutdown(self):
597597 self .is_shutdown = True
598598 self ._stream_available_condition .notify_all ()
599599
600- if self ._connections :
601- for c in self ._connections .values ():
602- log .debug ("Closing connection (%s) to %s" , id (c ), self .host )
603- c .close ()
604- self ._connections = {}
600+ for future in self ._shard_connections_futures :
601+ future .cancel ()
602+
603+ if self ._connections :
604+ for connection in self ._connections .values ():
605+ log .debug ("Closing connection (%s) to %s" , id (connection ), self .host )
606+ connection .close ()
607+ self ._connections .clear ()
605608
606609 self ._close_excess_connections ()
607610
@@ -620,7 +623,7 @@ def _close_excess_connections(self):
620623 if not self ._excess_connections :
621624 return
622625 conns = self ._excess_connections
623- self ._excess_connections = set ()
626+ self ._excess_connections . clear ()
624627
625628 for c in conns :
626629 log .debug ("Closing excess connection (%s) to %s" , id (c ), self .host )
@@ -653,7 +656,9 @@ def _open_connection_to_missing_shard(self, shard_id):
653656 if self .is_shutdown :
654657 log .debug ("Pool for host %s is in shutdown, closing the new connection (%s)" , id (conn ), self .host )
655658 conn .close ()
656- elif conn .shard_id not in self ._connections .keys () or self ._connections [conn .shard_id ].orphaned_threshold_reached :
659+ return
660+ old_conn = self ._connections .get (conn .shard_id )
661+ if old_conn is None or old_conn .orphaned_threshold_reached :
657662 log .debug (
658663 "New connection (%s) created to shard_id=%i on host %s" ,
659664 id (conn ),
@@ -698,7 +703,8 @@ def _open_connection_to_missing_shard(self, shard_id):
698703 else :
699704 self ._trash .add (old_conn )
700705 if self ._keyspace :
701- self ._connections [conn .shard_id ].set_keyspace_blocking (self ._keyspace )
706+ if old_conn := self ._connections .get (conn .shard_id ):
707+ old_conn .set_keyspace_blocking (self ._keyspace )
702708 num_missing_or_needing_replacement = self .num_missing_or_needing_replacement
703709 log .debug (
704710 "Connected to %s/%i shards on host %s (%i missing or needs replacement)" ,
@@ -750,9 +756,11 @@ def _open_connections_for_all_shards(self):
750756 if self .is_shutdown :
751757 return
752758
753- for shard_id in range (self .host .sharding_info .shards_count ):
754- self ._connecting .add (shard_id )
755- self ._session .submit (self ._open_connection_to_missing_shard , shard_id )
759+ for shard_id in range (self .host .sharding_info .shards_count ):
760+ future = self ._session .submit (self ._open_connection_to_missing_shard , shard_id )
761+ if isinstance (future , Future ):
762+ self ._connecting .add (shard_id )
763+ self ._shard_connections_futures .append (future )
756764
757765 def _set_keyspace_for_all_conns (self , keyspace , callback ):
758766 """
@@ -779,15 +787,15 @@ def connection_finished_setting_keyspace(conn, error):
779787 callback (self , errors )
780788
781789 self ._keyspace = keyspace
782- for conn in self ._connections .values ():
790+ for conn in list ( self ._connections .values () ):
783791 conn .set_keyspace_async (keyspace , connection_finished_setting_keyspace )
784792
785793 def get_connections (self ):
786- c = self ._connections
787- return list (self . _connections . values ()) if c else []
794+ connections = self ._connections
795+ return list (connections . values ()) if connections else []
788796
789797 def get_state (self ):
790- in_flights = [c .in_flight for c in self ._connections .values ()]
798+ in_flights = [c .in_flight for c in list ( self ._connections .values () )]
791799 return {'shutdown' : self .is_shutdown , 'open_count' : self .open_count , 'in_flights' : in_flights }
792800
793801 @property
@@ -797,7 +805,7 @@ def num_missing_or_needing_replacement(self):
797805
798806 @property
799807 def open_count (self ):
800- return sum ([1 if c and not (c .is_closed or c .is_defunct ) else 0 for c in self ._connections .values ()])
808+ return sum ([1 if c and not (c .is_closed or c .is_defunct ) else 0 for c in list ( self ._connections .values () )])
801809
802810 @property
803811 def _excess_connection_limit (self ):
0 commit comments