Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions mssql_python/pybind/connection/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ Connection::Connection(const std::wstring& conn_str, bool use_pool)
}

Connection::~Connection() {
    // Fallback cleanup if the user never called disconnect() explicitly.
    try {
        disconnect();
    } catch (...) {
        // Never let an exception escape a destructor — if this destructor runs
        // during stack unwinding, a propagating exception calls std::terminate().
        // Log and swallow instead.
        LOG_ERROR("Exception suppressed in ~Connection destructor");
    }
}

// Allocates connection handle
Expand Down Expand Up @@ -99,23 +105,22 @@ void Connection::disconnect() {
// When we free the DBC handle below, the ODBC driver will automatically free
// all child STMT handles. We need to tell the SqlHandle objects about this
// so they don't try to free the handles again during their destruction.

// THREAD-SAFETY: Lock mutex to safely access _childStatementHandles
// This protects against concurrent allocStatementHandle() calls or GC finalizers
{
std::lock_guard<std::mutex> lock(_childHandlesMutex);

// First compact: remove expired weak_ptrs (they're already destroyed)
size_t originalSize = _childStatementHandles.size();
_childStatementHandles.erase(
std::remove_if(_childStatementHandles.begin(), _childStatementHandles.end(),
[](const std::weak_ptr<SqlHandle>& wp) { return wp.expired(); }),
_childStatementHandles.end());

LOG("Compacted child handles: %zu -> %zu (removed %zu expired)",
originalSize, _childStatementHandles.size(),
originalSize - _childStatementHandles.size());


LOG("Compacted child handles: %zu -> %zu (removed %zu expired)", originalSize,
_childStatementHandles.size(), originalSize - _childStatementHandles.size());

LOG("Marking %zu child statement handles as implicitly freed",
_childStatementHandles.size());
for (auto& weakHandle : _childStatementHandles) {
Expand All @@ -124,8 +129,10 @@ void Connection::disconnect() {
// This is guaranteed by allocStatementHandle() which only creates STMT handles
// If this assertion fails, it indicates a serious bug in handle tracking
if (handle->type() != SQL_HANDLE_STMT) {
LOG_ERROR("CRITICAL: Non-STMT handle (type=%d) found in _childStatementHandles. "
"This will cause a handle leak!", handle->type());
LOG_ERROR(
"CRITICAL: Non-STMT handle (type=%d) found in _childStatementHandles. "
"This will cause a handle leak!",
handle->type());
continue; // Skip marking to prevent leak
}
handle->markImplicitlyFreed();
Expand All @@ -136,8 +143,24 @@ void Connection::disconnect() {
} // Release lock before potentially slow SQLDisconnect call

SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get());
checkError(ret);
// triggers SQLFreeHandle via destructor, if last owner
if (!SQL_SUCCEEDED(ret)) {
// Log the error but do NOT throw — disconnect must be safe to call
// from destructors, reset() failure paths, and pool cleanup.
// Throwing here during stack unwinding causes std::terminate().
LOG_ERROR("SQLDisconnect failed (ret=%d), forcing handle cleanup", ret);
Comment thread
subrata-ms marked this conversation as resolved.
Outdated

// Best-effort: retrieve and log ODBC diagnostics for debuggability.
// This must not throw, to keep disconnect noexcept-safe.
try {
ErrorInfo err = SQLCheckError_Wrap(SQL_HANDLE_DBC, _dbcHandle, ret);
std::string diagMsg = WideToUTF8(err.ddbcErrorMsg);
LOG_ERROR("SQLDisconnect diagnostics: %s", diagMsg.c_str());
} catch (...) {
// Swallow all exceptions: cleanup paths must not throw.
LOG_ERROR("SQLDisconnect: failed to retrieve ODBC diagnostics");
}
}
// Always free the handle regardless of SQLDisconnect result
_dbcHandle.reset();
} else {
LOG("No connection handle to disconnect");
Expand Down Expand Up @@ -221,7 +244,7 @@ SqlHandlePtr Connection::allocStatementHandle() {
// or GC finalizers running from different threads
{
std::lock_guard<std::mutex> lock(_childHandlesMutex);

// Track this child handle so we can mark it as implicitly freed when connection closes
// Use weak_ptr to avoid circular references and allow normal cleanup
_childStatementHandles.push_back(stmtHandle);
Expand All @@ -237,9 +260,8 @@ SqlHandlePtr Connection::allocStatementHandle() {
[](const std::weak_ptr<SqlHandle>& wp) { return wp.expired(); }),
_childStatementHandles.end());
_allocationsSinceCompaction = 0;
LOG("Periodic compaction: %zu -> %zu handles (removed %zu expired)",
originalSize, _childStatementHandles.size(),
originalSize - _childStatementHandles.size());
LOG("Periodic compaction: %zu -> %zu handles (removed %zu expired)", originalSize,
_childStatementHandles.size(), originalSize - _childStatementHandles.size());
}
} // Release lock

Expand Down
30 changes: 17 additions & 13 deletions mssql_python/pybind/connection/connection_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,24 @@ std::shared_ptr<Connection> ConnectionPool::acquire(const std::wstring& connStr,
auto now = std::chrono::steady_clock::now();
size_t before = _pool.size();

LOG("ConnectionPool::acquire: pool_size=%zu, max_size=%zu, idle_timeout=%d", before,
_max_size, _idle_timeout_secs);

// Phase 1: Remove stale connections, collect for later disconnect
_pool.erase(std::remove_if(_pool.begin(), _pool.end(),
[&](const std::shared_ptr<Connection>& conn) {
auto idle_time =
std::chrono::duration_cast<std::chrono::seconds>(
now - conn->lastUsed())
.count();
if (idle_time > _idle_timeout_secs) {
to_disconnect.push_back(conn);
return true;
}
return false;
}),
_pool.end());
_pool.erase(
std::remove_if(
_pool.begin(), _pool.end(),
[&](const std::shared_ptr<Connection>& conn) {
auto idle_time =
std::chrono::duration_cast<std::chrono::seconds>(now - conn->lastUsed())
.count();
if (idle_time > _idle_timeout_secs) {
to_disconnect.push_back(conn);
return true;
}
return false;
}),
_pool.end());

size_t pruned = before - _pool.size();
_current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0;
Expand Down
4 changes: 1 addition & 3 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4514,10 +4514,8 @@ SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) {
return rowCount;
}

// Configure (or reconfigure) the global connection pool.
//
// Intentionally NOT guarded by std::call_once: the previous once-only guard
// made every pooling(...) call after the first a silent no-op, so new
// max-size / idle-timeout settings were ignored. configure() itself is
// responsible for any thread-safety it needs.
//
// @param maxSize     Maximum number of pooled connections.
// @param idleTimeout Seconds a connection may sit idle before eviction.
void enable_pooling(int maxSize, int idleTimeout) {
    ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout);
}

// Thread-safe decimal separator setting
Expand Down
132 changes: 67 additions & 65 deletions tests/test_009_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,16 @@ def test_connection_pooling_isolation_level_reset(conn_str):
# Set isolation level to SERIALIZABLE (non-default)
conn1.set_attr(mssql_python.SQL_ATTR_TXN_ISOLATION, mssql_python.SQL_TXN_SERIALIZABLE)

# Verify the isolation level was set
# Verify the isolation level was set (use DBCC USEROPTIONS to avoid
# requiring VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_sessions)
cursor1 = conn1.cursor()
cursor1.execute(
"SELECT CASE transaction_isolation_level "
"WHEN 0 THEN 'Unspecified' "
"WHEN 1 THEN 'ReadUncommitted' "
"WHEN 2 THEN 'ReadCommitted' "
"WHEN 3 THEN 'RepeatableRead' "
"WHEN 4 THEN 'Serializable' "
"WHEN 5 THEN 'Snapshot' END AS isolation_level "
"FROM sys.dm_exec_sessions WHERE session_id = @@SPID"
)
isolation_level_1 = cursor1.fetchone()[0]
assert isolation_level_1 == "Serializable", f"Expected Serializable, got {isolation_level_1}"
cursor1.execute("DBCC USEROPTIONS WITH NO_INFOMSGS")
isolation_level_1 = None
for row in cursor1.fetchall():
if row[0] == "isolation level":
isolation_level_1 = row[1]
break
assert isolation_level_1 == "serializable", f"Expected serializable, got {isolation_level_1}"

# Get SPID for verification of connection reuse
cursor1.execute("SELECT @@SPID")
Expand All @@ -138,24 +134,20 @@ def test_connection_pooling_isolation_level_reset(conn_str):
# Verify connection was reused
assert spid1 == spid2, "Connection was not reused from pool"

# Check if isolation level is reset to default
cursor2.execute(
"SELECT CASE transaction_isolation_level "
"WHEN 0 THEN 'Unspecified' "
"WHEN 1 THEN 'ReadUncommitted' "
"WHEN 2 THEN 'ReadCommitted' "
"WHEN 3 THEN 'RepeatableRead' "
"WHEN 4 THEN 'Serializable' "
"WHEN 5 THEN 'Snapshot' END AS isolation_level "
"FROM sys.dm_exec_sessions WHERE session_id = @@SPID"
)
isolation_level_2 = cursor2.fetchone()[0]
# Check if isolation level is reset to default (use DBCC USEROPTIONS to avoid
# requiring VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_sessions)
cursor2.execute("DBCC USEROPTIONS WITH NO_INFOMSGS")
isolation_level_2 = None
for row in cursor2.fetchall():
if row[0] == "isolation level":
isolation_level_2 = row[1]
break

# Verify isolation level is reset to default (READ COMMITTED)
# This is the CORRECT behavior for connection pooling - we should reset
# session state to prevent settings from one usage affecting the next
assert isolation_level_2 == "ReadCommitted", (
f"Isolation level was not reset! Expected 'ReadCommitted', got '{isolation_level_2}'. "
assert isolation_level_2 == "read committed", (
f"Isolation level was not reset! Expected 'read committed', got '{isolation_level_2}'. "
f"This indicates session state leaked from the previous connection usage."
)

Expand Down Expand Up @@ -278,82 +270,92 @@ def try_overflow():
c.close()


@pytest.mark.skip("Flaky test - idle timeout behavior needs investigation")
def test_pool_idle_timeout_removes_connections(conn_str):
"""Test that idle_timeout removes connections from the pool after the timeout."""
pooling(max_size=2, idle_timeout=1)
conn1 = connect(conn_str)
spid_list = []
cursor1 = conn1.cursor()
# Use @@SPID to identify the connection without requiring
# VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_connections.
cursor1.execute("SELECT @@SPID")
spid1 = cursor1.fetchone()[0]
spid_list.append(spid1)
conn1.close()

# Wait for longer than idle_timeout
time.sleep(3)
# Wait well beyond the idle_timeout to account for slow CI and integer-second granularity
time.sleep(5)

# Get a new connection, which should not reuse the previous SPID
# Get a new connection — the idle one should have been evicted during acquire()
conn2 = connect(conn_str)
cursor2 = conn2.cursor()
cursor2.execute("SELECT @@SPID")
spid2 = cursor2.fetchone()[0]
spid_list.append(spid2)
conn2.close()

assert spid1 != spid2, "Idle timeout did not remove connection from pool"
assert spid1 != spid2, "Idle timeout did not remove connection from pool — same SPID reused"


# =============================================================================
# Error Handling and Recovery Tests
# =============================================================================


@pytest.mark.skip(
"Test causes fatal crash - forcibly closing underlying connection leads to undefined behavior"
)
def test_pool_removes_invalid_connections(conn_str):
"""Test that the pool removes connections that become invalid (simulate by closing underlying connection)."""
"""Test that the pool removes connections that become invalid and recovers gracefully.

This test simulates a connection being returned to the pool in a dirty state
(with an open transaction) by calling _conn.close() directly, bypassing the
normal Python close() which does a rollback. The pool's acquire() should detect
the bad connection during reset(), discard it, and create a fresh one.
"""
pooling(max_size=1, idle_timeout=30)
conn = connect(conn_str)
cursor = conn.cursor()
cursor.execute("SELECT 1")
# Simulate invalidation by forcibly closing the connection at the driver level
try:
# Try to access a private attribute or method to forcibly close the underlying connection
# This is implementation-specific; if not possible, skip
if hasattr(conn, "_conn") and hasattr(conn._conn, "close"):
conn._conn.close()
else:
pytest.skip("Cannot forcibly close underlying connection for this driver")
except Exception:
pass
# Safely close the connection, ignoring errors due to forced invalidation
cursor.fetchone()

# Record the SPID of the original connection (avoids requiring
# VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_connections)
cursor.execute("SELECT @@SPID")
original_spid = cursor.fetchone()[0]

# Force-return the connection to the pool WITHOUT rollback.
# This leaves the pooled connection in a dirty state (open implicit transaction)
# which will cause reset() to fail on next acquire().
conn._conn.close()

# Python close() will fail since the underlying handle is already gone
try:
conn.close()
except RuntimeError as e:
if "not initialized" not in str(e):
raise
# Now, get a new connection from the pool and ensure it works
except RuntimeError:
pass

# Now get a new connection the pool should discard the dirty one and create fresh
new_conn = connect(conn_str)
new_cursor = new_conn.cursor()
try:
new_cursor.execute("SELECT 1")
result = new_cursor.fetchone()
assert result is not None and result[0] == 1, "Pool did not remove invalid connection"
finally:
new_conn.close()
new_cursor.execute("SELECT 1")
result = new_cursor.fetchone()
assert result is not None and result[0] == 1, "Pool did not recover from invalid connection"

# Verify it's a different physical connection
new_cursor.execute("SELECT @@SPID")
new_spid = new_cursor.fetchone()[0]
assert (
original_spid != new_spid
), "Expected a new physical connection after pool discarded the dirty one"
Comment thread
subrata-ms marked this conversation as resolved.

new_conn.close()


def test_pool_recovery_after_failed_connection(conn_str):
"""Test that the pool recovers after a failed connection attempt."""
pooling(max_size=1, idle_timeout=30)
# First, try to connect with a bad password (should fail)
if "Pwd=" in conn_str:
bad_conn_str = conn_str.replace("Pwd=", "Pwd=wrongpassword")
elif "Password=" in conn_str:
bad_conn_str = conn_str.replace("Password=", "Password=wrongpassword")
else:
import re

# Replace the value of the first Pwd/Password key-value pair with "wrongpassword"
pattern = re.compile(r"(?i)(Pwd|Password\s*=\s*)([^;]*)")
Comment thread
subrata-ms marked this conversation as resolved.
Outdated
bad_conn_str, num_subs = pattern.subn(lambda m: m.group(1) + "wrongpassword", conn_str, count=1)
if num_subs == 0:
pytest.skip("No password found in connection string to modify")
with pytest.raises(Exception):
connect(bad_conn_str)
Expand Down
Loading