Skip to content

Commit 64242af

Browse files
authored
Merge pull request scylladb#112 from k0machi/iss111-lock-races
Fix race conditions in HostConnection that happen during shutdown
2 parents d2ab67c + f3753a7 commit 64242af

3 files changed

Lines changed: 103 additions & 39 deletions

File tree

cassandra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def emit(self, record):
2222

2323
logging.getLogger('cassandra').addHandler(NullHandler())
2424

25-
__version_info__ = (3, 24, 5)
25+
__version_info__ = (3, 24, 6)
2626
__version__ = '.'.join(map(str, __version_info__))
2727

2828

cassandra/pool.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616
Connection pooling and host management.
1717
"""
18-
18+
from concurrent.futures import Future
1919
from functools import total_ordering
2020
import logging
2121
import socket
@@ -411,6 +411,7 @@ def __init__(self, host, host_distance, session):
411411
# and are waiting until all requests time out or complete
412412
# so that we can dispose of them.
413413
self._trash = set()
414+
self._shard_connections_futures = []
414415

415416
if host_distance == HostDistance.IGNORED:
416417
log.debug("Not opening connection to ignored host %s", self.host)
@@ -537,9 +538,9 @@ def return_connection(self, connection, stream_was_orphaned=False):
537538
if is_down:
538539
self.shutdown()
539540
else:
540-
connection.close()
541-
del self._connections[connection.shard_id]
542541
with self._lock:
542+
connection.close()
543+
self._connections.pop(connection.shard_id, None)
543544
if self._is_replacing:
544545
return
545546
self._is_replacing = True
@@ -568,23 +569,22 @@ def _replace(self, connection):
568569
if self.is_shutdown:
569570
return
570571

571-
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
572-
try:
573-
if connection.shard_id in self._connections.keys():
574-
del self._connections[connection.shard_id]
575-
if self.host.sharding_info:
576-
self._connecting.add(connection.shard_id)
577-
self._open_connection_to_missing_shard(connection.shard_id)
572+
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
573+
try:
574+
if connection.shard_id in self._connections.keys():
575+
del self._connections[connection.shard_id]
576+
if self.host.sharding_info:
577+
self._connecting.add(connection.shard_id)
578+
self._session.submit(self._open_connection_to_missing_shard, connection.shard_id)
579+
else:
580+
connection = self._session.cluster.connection_factory(self.host.endpoint, owning_pool=self)
581+
if self._keyspace:
582+
connection.set_keyspace_blocking(self._keyspace)
583+
self._connections[connection.shard_id] = connection
584+
except Exception:
585+
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
586+
self._session.submit(self._replace, connection)
578587
else:
579-
connection = self._session.cluster.connection_factory(self.host.endpoint, owning_pool=self)
580-
if self._keyspace:
581-
connection.set_keyspace_blocking(self._keyspace)
582-
self._connections[connection.shard_id] = connection
583-
except Exception:
584-
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
585-
self._session.submit(self._replace, connection)
586-
else:
587-
with self._lock:
588588
self._is_replacing = False
589589
self._stream_available_condition.notify()
590590

@@ -597,11 +597,14 @@ def shutdown(self):
597597
self.is_shutdown = True
598598
self._stream_available_condition.notify_all()
599599

600-
if self._connections:
601-
for c in self._connections.values():
602-
log.debug("Closing connection (%s) to %s", id(c), self.host)
603-
c.close()
604-
self._connections = {}
600+
for future in self._shard_connections_futures:
601+
future.cancel()
602+
603+
if self._connections:
604+
for connection in self._connections.values():
605+
log.debug("Closing connection (%s) to %s", id(connection), self.host)
606+
connection.close()
607+
self._connections.clear()
605608

606609
self._close_excess_connections()
607610

@@ -620,7 +623,7 @@ def _close_excess_connections(self):
620623
if not self._excess_connections:
621624
return
622625
conns = self._excess_connections
623-
self._excess_connections = set()
626+
self._excess_connections.clear()
624627

625628
for c in conns:
626629
log.debug("Closing excess connection (%s) to %s", id(c), self.host)
@@ -653,7 +656,9 @@ def _open_connection_to_missing_shard(self, shard_id):
653656
if self.is_shutdown:
654657
log.debug("Pool for host %s is in shutdown, closing the new connection (%s)", id(conn), self.host)
655658
conn.close()
656-
elif conn.shard_id not in self._connections.keys() or self._connections[conn.shard_id].orphaned_threshold_reached:
659+
return
660+
old_conn = self._connections.get(conn.shard_id)
661+
if old_conn is None or old_conn.orphaned_threshold_reached:
657662
log.debug(
658663
"New connection (%s) created to shard_id=%i on host %s",
659664
id(conn),
@@ -698,7 +703,8 @@ def _open_connection_to_missing_shard(self, shard_id):
698703
else:
699704
self._trash.add(old_conn)
700705
if self._keyspace:
701-
self._connections[conn.shard_id].set_keyspace_blocking(self._keyspace)
706+
if old_conn := self._connections.get(conn.shard_id):
707+
old_conn.set_keyspace_blocking(self._keyspace)
702708
num_missing_or_needing_replacement = self.num_missing_or_needing_replacement
703709
log.debug(
704710
"Connected to %s/%i shards on host %s (%i missing or needs replacement)",
@@ -750,9 +756,11 @@ def _open_connections_for_all_shards(self):
750756
if self.is_shutdown:
751757
return
752758

753-
for shard_id in range(self.host.sharding_info.shards_count):
754-
self._connecting.add(shard_id)
755-
self._session.submit(self._open_connection_to_missing_shard, shard_id)
759+
for shard_id in range(self.host.sharding_info.shards_count):
760+
future = self._session.submit(self._open_connection_to_missing_shard, shard_id)
761+
if isinstance(future, Future):
762+
self._connecting.add(shard_id)
763+
self._shard_connections_futures.append(future)
756764

757765
def _set_keyspace_for_all_conns(self, keyspace, callback):
758766
"""
@@ -779,15 +787,15 @@ def connection_finished_setting_keyspace(conn, error):
779787
callback(self, errors)
780788

781789
self._keyspace = keyspace
782-
for conn in self._connections.values():
790+
for conn in list(self._connections.values()):
783791
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
784792

785793
def get_connections(self):
786-
c = self._connections
787-
return list(self._connections.values()) if c else []
794+
connections = self._connections
795+
return list(connections.values()) if connections else []
788796

789797
def get_state(self):
790-
in_flights = [c.in_flight for c in self._connections.values()]
798+
in_flights = [c.in_flight for c in list(self._connections.values())]
791799
return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}
792800

793801
@property
@@ -797,7 +805,7 @@ def num_missing_or_needing_replacement(self):
797805

798806
@property
799807
def open_count(self):
800-
return sum([1 if c and not (c.is_closed or c.is_defunct) else 0 for c in self._connections.values()])
808+
return sum([1 if c and not (c.is_closed or c.is_defunct) else 0 for c in list(self._connections.values())])
801809

802810
@property
803811
def _excess_connection_limit(self):

tests/unit/test_host_connection_pool.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,19 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from concurrent.futures import ThreadPoolExecutor
15+
import logging
16+
import time
17+
18+
from cassandra.shard_info import _ShardingInfo
1419

1520
try:
1621
import unittest2 as unittest
1722
except ImportError:
18-
import unittest # noqa
23+
import unittest # noqa
24+
import unittest.mock as mock
1925

20-
from mock import Mock, NonCallableMagicMock
26+
from mock import Mock, NonCallableMagicMock, MagicMock
2127
from threading import Thread, Event, Lock
2228

2329
from cassandra.cluster import Session
@@ -26,6 +32,8 @@
2632
from cassandra.pool import Host, NoConnectionsAvailable
2733
from cassandra.policies import HostDistance, SimpleConvictionPolicy
2834

35+
LOGGER = logging.getLogger(__name__)
36+
2937

3038
class _PoolTests(unittest.TestCase):
3139
__test__ = False
@@ -79,7 +87,8 @@ def test_failed_wait_for_connection(self):
7987
def test_successful_wait_for_connection(self):
8088
host = Mock(spec=Host, address='ip1')
8189
session = self.make_session()
82-
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, lock=Lock())
90+
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100,
91+
lock=Lock())
8392
session.cluster.connection_factory.return_value = conn
8493

8594
pool = self.PoolImpl(host, HostDistance.LOCAL, session)
@@ -266,3 +275,50 @@ class HostConnectionTests(_PoolTests):
266275
PoolImpl = HostConnection
267276
uses_single_connection = True
268277

278+
def test_fast_shutdown(self):
279+
class MockSession(MagicMock):
280+
is_shutdown = False
281+
keyspace = "reprospace"
282+
283+
def __init__(self, *args, **kwargs):
284+
super().__init__(*args, **kwargs)
285+
self.cluster = MagicMock()
286+
self.cluster.executor = ThreadPoolExecutor(max_workers=2, initializer=self.executor_init)
287+
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
288+
self.cluster.connection_factory = self.mock_connection_factory
289+
self.connection_counter = 0
290+
291+
def submit(self, fn, *args, **kwargs):
292+
LOGGER.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
293+
if not self.is_shutdown:
294+
return self.cluster.executor.submit(fn, *args, **kwargs)
295+
296+
def mock_connection_factory(self, *args, **kwargs):
297+
connection = MagicMock()
298+
connection.is_shutdown = False
299+
connection.is_defunct = False
300+
connection.is_closed = False
301+
connection.shard_id = self.connection_counter
302+
self.connection_counter += 1
303+
connection.sharding_info = _ShardingInfo(shard_id=1, shards_count=14,
304+
partitioner="", sharding_algorithm="", sharding_ignore_msb=0)
305+
306+
return connection
307+
308+
def executor_init(self, *args):
309+
time.sleep(0.5)
310+
LOGGER.info("Future start: %s", args)
311+
312+
for attempt_num in range(20):
313+
LOGGER.info("Testing fast shutdown %d / 20 times", attempt_num + 1)
314+
host = MagicMock()
315+
host.endpoint = "1.2.3.4"
316+
session = MockSession()
317+
318+
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
319+
LOGGER.info("Initialized pool %s", pool)
320+
LOGGER.info("Connections: %s", pool._connections)
321+
time.sleep(0.5)
322+
pool.shutdown()
323+
time.sleep(3)
324+
session.cluster.executor.shutdown()

0 commit comments

Comments
 (0)