Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adapt create_*() functions
  • Loading branch information
Erlend E. Aasland committed Jul 29, 2021
commit 48a954d88e59bf4639f139ed9cfb4c741c4a0685
63 changes: 38 additions & 25 deletions Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,9 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv

threadstate = PyGILState_Ensure();

py_func = (PyObject*)sqlite3_user_data(context);
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
py_func = ctx->obj;

args = _pysqlite_build_py_params(context, argc, argv);
if (args) {
Expand All @@ -631,8 +633,8 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv
Py_DECREF(py_retval);
}
if (!ok) {
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
assert(ctx->state != NULL);
if (ctx->state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
Expand All @@ -656,7 +658,9 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_

threadstate = PyGILState_Ensure();

aggregate_class = (PyObject*)sqlite3_user_data(context);
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
aggregate_class = ctx->obj;

aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));

Expand All @@ -666,8 +670,8 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
if (PyErr_Occurred()) {
*aggregate_instance = 0;

pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
assert(ctx->state != NULL);
if (ctx->state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
Expand All @@ -692,8 +696,8 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
Py_DECREF(args);

if (!function_result) {
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
assert(ctx->state != NULL);
if (ctx->state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
Expand Down Expand Up @@ -747,8 +751,10 @@ _pysqlite_final_callback(sqlite3_context *context)
Py_DECREF(function_result);
}
if (!ok) {
pysqlite_state *state = pysqlite_get_state(NULL);
if (state->enable_callback_tracebacks) {
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
assert(ctx->state != NULL);
if (ctx->state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
Expand Down Expand Up @@ -853,7 +859,7 @@ static void _destructor(void* args)
// that we destroy 'args' with the GIL
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
Py_DECREF((PyObject*)args);
free_callback_context((callback_context *)args);
PyGILState_Release(gstate);
}

Expand Down Expand Up @@ -896,11 +902,11 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
flags |= SQLITE_DETERMINISTIC;
#endif
}
rc = sqlite3_create_function_v2(self->db,
name,
narg,
flags,
(void*)Py_NewRef(func),
callback_context *ctx = create_callback_context(self->state, func);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_function_v2(self->db, name, narg, flags, ctx,
_pysqlite_func_callback,
NULL,
NULL,
Expand Down Expand Up @@ -936,11 +942,12 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
return NULL;
}

rc = sqlite3_create_function_v2(self->db,
name,
n_arg,
SQLITE_UTF8,
(void*)Py_NewRef(aggregate_class),
callback_context *ctx = create_callback_context(self->state,
aggregate_class);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, ctx,
0,
&_pysqlite_step_callback,
&_pysqlite_final_callback,
Expand Down Expand Up @@ -1507,7 +1514,6 @@ pysqlite_collation_callback(
int text1_length, const void* text1_data,
int text2_length, const void* text2_data)
{
PyObject* callback = (PyObject*)context;
PyObject* string1 = 0;
PyObject* string2 = 0;
PyGILState_STATE gilstate;
Expand All @@ -1527,6 +1533,9 @@ pysqlite_collation_callback(
goto finally; /* failed to allocate strings */
}

callback_context *ctx = (callback_context *)context;
assert(ctx != NULL);
PyObject *callback = ctx->obj;
PyObject *args[] = { string1, string2 }; // Borrowed refs.
retval = PyObject_Vectorcall(callback, args, 2, NULL);
if (retval == NULL) {
Expand Down Expand Up @@ -1758,6 +1767,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
return NULL;
}

callback_context *ctx = NULL;
int rc;
int flags = SQLITE_UTF8;
if (callable == Py_None) {
Expand All @@ -1769,8 +1779,11 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
PyErr_SetString(PyExc_TypeError, "parameter must be callable");
return NULL;
}
rc = sqlite3_create_collation_v2(self->db, name, flags,
Py_NewRef(callable),
ctx = create_callback_context(self->state, callable);
if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_collation_v2(self->db, name, flags, ctx,
&pysqlite_collation_callback,
&_destructor);
}
Expand All @@ -1781,7 +1794,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
* the context before returning.
*/
if (callable != Py_None) {
Py_DECREF(callable);
free_callback_context(ctx);
}
_pysqlite_seterror(self->state, self->db);
return NULL;
Expand Down