Skip to content
Prev Previous commit
Next Next commit
gh-142830: prevent crashes when replacing sqlite3 callbacks
  • Loading branch information
picnixz committed Dec 28, 2025
commit 4dd06525f881d8a19a5b1b6a1334db1d60d2dd32
99 changes: 99 additions & 0 deletions Lib/test/test_sqlite3/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2029,5 +2029,104 @@ def test_row_is_a_sequence(self):
self.assertIsInstance(row, Sequence)


class CallbackTests(unittest.TestCase):

def setUp(self):
super().setUp()
self.cx = sqlite.connect(":memory:")
self.addCleanup(self.cx.close)
self.cu = self.cx.cursor()
self.cu.execute("create table test(a number)")

class Handler:
cx = self.cx

self.handler_class = Handler

def assert_not_authorized(self, func, /, *args, **kwargs):
with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
func(*args, **kwargs)

def assert_interrupted(self, func, /, *args, **kwargs):
with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
func(*args, **kwargs)

def assert_invalid_trace(self, func, /, *args, **kwargs):
# Exception in trace callbacks are entirely suppressed.
pass

# When a handler has an invalid signature, the exception raised is
# the same that would be raised if the handler "negatively" replied.

def test_authorizer_invalid_signature(self):
self.cx.set_authorizer(lambda: None)
self.assert_not_authorized(self.cx.execute, "select * from test")

def test_progress_handler_invalid_signature(self):
self.cx.set_progress_handler(lambda x: None, 1)
self.assert_interrupted(self.cx.execute, "select * from test")

def test_trace_callback_invalid_signature_traceback(self):
self.cx.set_trace_callback(lambda: None)
self.assert_invalid_trace(self.cx.execute, "select * from test")

# Tests for checking that callback context mutations do not crash.
# Regression tests for https://github.com/python/cpython/issues/142830.

def test_authorizer_concurrent_mutation_in_call(self):
class Handler(self.handler_class):
def __call__(self, *a, **kw):
self.cx.set_authorizer(None)
raise ValueError

self.cx.set_authorizer(Handler())
self.assert_not_authorized(self.cx.execute, "select * from test")

def test_authorizer_concurrent_mutation_with_overflown_value(self):
_testcapi = import_helper.import_module("_testcapi")

class Handler(self.handler_class):
def __call__(self, *a, **kw):
self.cx.set_authorizer(None)
# We expect 'int' at the C level, so this one will raise
# when converting via PyLong_Int().
return _testcapi.INT_MAX + 1

self.cx.set_authorizer(Handler())
self.assert_not_authorized(self.cx.execute, "select * from test")

def test_progress_handler_concurrent_mutation_in_call(self):
class Handler(self.handler_class):
def __call__(self, *a, **kw):
self.cx.set_authorizer(None)
raise ValueError

self.cx.set_progress_handler(Handler(), 1)
self.assert_interrupted(self.cx.execute, "select * from test")

def test_progress_handler_concurrent_mutation_in_conversion(self):
class Handler(self.handler_class):
def __bool__(self):
# clear the progress handler
self.cx.set_progress_handler(None, 1)
raise ValueError # force PyObject_True() to fail

self.cx.set_progress_handler(Handler.__init__, 1)
self.assert_interrupted(self.cx.execute, "select * from test")

def test_trace_callback_concurrent_mutation_in_call(self):
class Handler:
def __call__(self, statement):
# clear the progress handler
self.cx.set_progress_handler(None, 1)
Comment thread
picnixz marked this conversation as resolved.
Outdated
raise ValueError

self.cx.set_trace_callback(Handler())
self.assert_invalid_trace(self.cx.execute, "select * from test")

# TODO(picnixz): increase test coverage for other callbacks
# such as 'func', 'step', 'finalize', and 'collation'.


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks
are mutated during a callback execution. Patch by Bénédikt Tran.
41 changes: 34 additions & 7 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
if (args) {
pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
Comment thread
picnixz marked this conversation as resolved.
Outdated
Py_INCREF(ctx);
py_retval = PyObject_CallObject(ctx->callable, args);
Py_DECREF(ctx);
Py_DECREF(args);
}

Expand Down Expand Up @@ -942,6 +944,8 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)

pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
// Hold a reference to 'ctx' to prevent concurrent mutations.
Py_INCREF(ctx);

aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
if (aggregate_instance == NULL) {
Expand Down Expand Up @@ -971,6 +975,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
}

function_result = PyObject_CallObject(stepmethod, args);
Py_CLEAR(ctx);
Py_DECREF(args);

if (!function_result) {
Expand All @@ -979,6 +984,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
}

error:
Py_XDECREF(ctx);
Py_XDECREF(stepmethod);
Py_XDECREF(function_result);

Expand Down Expand Up @@ -1011,8 +1017,10 @@ final_callback(sqlite3_context *context)

pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
Py_INCREF(ctx);
function_result = PyObject_CallMethodNoArgs(*aggregate_instance,
ctx->state->str_finalize);
Py_DECREF(ctx);
Py_DECREF(*aggregate_instance);

ok = 0;
Expand Down Expand Up @@ -1163,6 +1171,8 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)

pysqlite_CallbackContext *ctx = sqlite3_user_data(context);
assert(ctx != NULL);
// Hold a reference to 'ctx' to prevent concurrent mutations.
Py_INCREF(ctx);

int size = sizeof(PyObject *);
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
Expand Down Expand Up @@ -1191,9 +1201,11 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
"user-defined aggregate's 'inverse' method raised error");
goto exit;
}
Py_CLEAR(ctx);
Py_DECREF(res);

exit:
Py_XDECREF(ctx);
Py_XDECREF(method);
PyGILState_Release(gilstate);
}
Expand All @@ -1217,7 +1229,10 @@ value_callback(sqlite3_context *context)
assert(cls != NULL);
assert(*cls != NULL);

Py_INCREF(ctx);
PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
Py_DECREF(ctx);

if (res == NULL) {
int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
set_sqlite_error(context, attr_err
Expand Down Expand Up @@ -1360,10 +1375,11 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1,

assert(ctx_vp != NULL);
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
PyObject *callable = ctx->callable;
ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
access_attempt_source);
// Hold a reference to 'ctx' to prevent concurrent mutations.
Py_INCREF(ctx);

ret = PyObject_CallFunction(ctx->callable, "issss", action, arg1, arg2,
dbname, access_attempt_source);
if (ret == NULL) {
print_or_clear_traceback(ctx);
rc = SQLITE_DENY;
Expand All @@ -1381,6 +1397,7 @@ authorizer_callback(void *ctx_vp, int action, const char *arg1,
}
Py_DECREF(ret);
}
Py_DECREF(ctx);

PyGILState_Release(gilstate);
return rc;
Expand All @@ -1396,8 +1413,10 @@ progress_callback(void *ctx_vp)

assert(ctx_vp != NULL);
pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(ctx_vp);
PyObject *callable = ctx->callable;
ret = PyObject_CallNoArgs(callable);
// Hold a reference to 'ctx' to prevent concurrent mutations.
Py_INCREF(ctx);

ret = PyObject_CallNoArgs(ctx->callable);
if (!ret) {
/* abort query if error occurred */
rc = -1;
Expand All @@ -1409,7 +1428,7 @@ progress_callback(void *ctx_vp)
if (rc < 0) {
print_or_clear_traceback(ctx);
}

Py_DECREF(ctx);
PyGILState_Release(gilstate);
return rc;
}
Expand Down Expand Up @@ -1455,7 +1474,9 @@ trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
sqlite3_free((void *)expanded_sql);
}
if (py_statement) {
Py_INCREF(ctx);
PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
Py_DECREF(ctx);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}
Expand Down Expand Up @@ -1889,6 +1910,7 @@ collation_callback(void *context, int text1_length, const void *text1_data,
{
PyGILState_STATE gilstate = PyGILState_Ensure();

pysqlite_CallbackContext *ctx = NULL;
PyObject* string1 = 0;
PyObject* string2 = 0;
PyObject* retval = NULL;
Expand All @@ -1910,8 +1932,11 @@ collation_callback(void *context, int text1_length, const void *text1_data,
goto finally;
}

pysqlite_CallbackContext *ctx = pysqlite_CallbackContext_CAST(context);
ctx = pysqlite_CallbackContext_CAST(context);
assert(ctx != NULL);
// Hold a reference to 'ctx' to prevent concurrent mutations.
Py_INCREF(ctx);

PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs.
size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET;
retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL);
Expand All @@ -1931,8 +1956,10 @@ collation_callback(void *context, int text1_length, const void *text1_data,
else if (longval < 0)
result = -1;
}
Py_CLEAR(ctx);

finally:
Py_XDECREF(ctx);
Py_XDECREF(string1);
Py_XDECREF(string2);
Py_XDECREF(retval);
Expand Down
Loading