Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2320,7 +2320,27 @@ def run_add_or_renew_pool():
return False

previous = self._pools.get(host)
self._pools[host] = new_pool
with self._lock:
while new_pool._keyspace != self.keyspace:
self._lock.release()
set_keyspace_event = Event()
errors_returned = []

def callback(pool, errors):
errors_returned.extend(errors)
set_keyspace_event.set()

new_pool._set_keyspace_for_all_conns(self.keyspace, callback)
set_keyspace_event.wait(self.cluster.connect_timeout)
if not set_keyspace_event.is_set() or errors_returned:
log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned)
self.cluster.on_down(host, is_host_addition)
new_pool.shutdown()
self._lock.acquire()
return False
self._lock.acquire()
self._pools[host] = new_pool

log.debug("Added pool for host %s to session", host)
if previous:
previous.shutdown()
Expand Down Expand Up @@ -2397,9 +2417,9 @@ def _set_keyspace_for_all_pools(self, keyspace, callback):
called with a dictionary of all errors that occurred, keyed
by the `Host` that they occurred against.
"""
self.keyspace = keyspace

remaining_callbacks = set(self._pools.values())
with self._lock:
self.keyspace = keyspace
remaining_callbacks = set(self._pools.values())
errors = {}

if not remaining_callbacks:
Expand Down
20 changes: 13 additions & 7 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class HostConnection(object):
_session = None
_connection = None
_lock = None
_keyspace = None

def __init__(self, host, host_distance, session):
self.host = host
Expand All @@ -326,8 +327,9 @@ def __init__(self, host, host_distance, session):

log.debug("Initializing connection for host %s", self.host)
self._connection = session.cluster.connection_factory(host.address)
if session.keyspace:
self._connection.set_keyspace_blocking(session.keyspace)
self._keyspace = session.keyspace
if self._keyspace:
self._connection.set_keyspace_blocking(self._keyspace)
log.debug("Finished initializing connection for host %s", self.host)

def borrow_connection(self, timeout):
Expand Down Expand Up @@ -381,8 +383,8 @@ def _replace(self, connection):
log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try:
conn = self._session.cluster.connection_factory(self.host.address)
if self._session.keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
self._connection = conn
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.address,))
Expand Down Expand Up @@ -412,6 +414,7 @@ def connection_finished_setting_keyspace(conn, error):
errors = [] if not error else [error]
callback(self, errors)

self._keyspace = keyspace
self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)

def get_connections(self):
Expand Down Expand Up @@ -445,6 +448,7 @@ class HostConnectionPool(object):
open_count = 0
_scheduled_for_creation = 0
_next_trash_allowed_at = 0
_keyspace = None

def __init__(self, host, host_distance, session):
self.host = host
Expand All @@ -459,9 +463,10 @@ def __init__(self, host, host_distance, session):
self._connections = [session.cluster.connection_factory(host.address)
for i in range(core_conns)]

if session.keyspace:
self._keyspace = session.keyspace
if self._keyspace:
for conn in self._connections:
conn.set_keyspace_blocking(session.keyspace)
conn.set_keyspace_blocking(self._keyspace)

self._trash = set()
self._next_trash_allowed_at = time.time()
Expand Down Expand Up @@ -560,7 +565,7 @@ def _add_conn_if_under_max(self):
log.debug("Going to open new connection to host %s", self.host)
try:
conn = self._session.cluster.connection_factory(self.host.address)
if self._session.keyspace:
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
with self._lock:
Expand Down Expand Up @@ -761,6 +766,7 @@ def connection_finished_setting_keyspace(conn, error):
if not remaining_callbacks:
callback(self, errors)

self._keyspace = keyspace
for conn in self._connections:
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)

Expand Down