From 9700f31fa4662646040015aa260536b9ca9db4dc Mon Sep 17 00:00:00 2001 From: Jeffrey Bosboom Date: Sun, 17 May 2026 13:23:44 -0700 Subject: [PATCH] gh-84943: Add support for 'directonly' and 'innocuous' flags for user-defined functions Co-authored-by: Erlend E. Aasland --- Doc/library/sqlite3.rst | 73 +++++++- .../pycore_global_objects_fini_generated.h | 2 + Include/internal/pycore_global_strings.h | 2 + .../internal/pycore_runtime_init_generated.h | 2 + .../internal/pycore_unicodeobject_generated.h | 8 + Lib/test/test_sqlite3/test_userfunctions.py | 107 +++++++++++ ...6-05-17-11-55-52.gh-issue-89483.l6ShbK.rst | 2 + Modules/_sqlite/clinic/connection.c.h | 168 +++++++++++++++--- Modules/_sqlite/connection.c | 95 +++++++++- 9 files changed, 423 insertions(+), 36 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2026-05-17-11-55-52.gh-issue-89483.l6ShbK.rst diff --git a/Doc/library/sqlite3.rst b/Doc/library/sqlite3.rst index 484260e63dd5f2..902412e26bedbc 100644 --- a/Doc/library/sqlite3.rst +++ b/Doc/library/sqlite3.rst @@ -699,7 +699,7 @@ Connection objects :meth:`~Cursor.executescript` on it with the given *sql_script*. Return the new cursor object. - .. method:: create_function(name, narg, func, /, *, deterministic=False) + .. method:: create_function(name, narg, func, /, *, deterministic=False, innocuous=False, directonly=False) Create or remove a user-defined SQL function. @@ -722,12 +722,31 @@ Connection objects `deterministic `_, which allows SQLite to perform additional optimizations. + :param bool innocuous: + If ``True``, the created SQL function is marked as + `innocuous `__, + making it usable in views, triggers and schema structures even when + the ``trusted_schema`` pragma is disabled. + + :param bool directonly: + If ``True``, the created SQL function is marked as + `directonly `__, + restricting its use to top-level SQL statements regardless of the + value of the ``trusted_schema`` pragma. + + :raises NotSupportedError: + If called with *innocuous* or *directonly* equal to True on a version + of SQLite older than 3.31.0. + .. versionchanged:: 3.8 Added the *deterministic* parameter. .. versionchanged:: 3.15 The first three parameters are now positional-only. + .. versionchanged:: next + Added the *innocuous* and *directonly* parameters. + Example: .. doctest:: @@ -743,7 +762,7 @@ Connection objects >>> con.close() - .. method:: create_aggregate(name, n_arg, aggregate_class, /) + .. method:: create_aggregate(name, n_arg, aggregate_class, /, *, deterministic=False, innocuous=False, directonly=False) Create or remove a user-defined SQL aggregate function. @@ -767,9 +786,33 @@ Connection objects Set to ``None`` to remove an existing SQL aggregate function. :type aggregate_class: :term:`class` | None + :param bool deterministic: + If ``True``, the created SQL function is marked as + `deterministic `__, + which allows SQLite to perform additional optimizations. + + :param bool innocuous: + If ``True``, the created SQL function is marked as + `innocuous `__, + making it usable in views, triggers and schema structures even when + the ``trusted_schema`` pragma is disabled. + + :param bool directonly: + If ``True``, the created SQL function is marked as + `directonly `__, + restricting its use to top-level SQL statements regardless of the + value of the ``trusted_schema`` pragma. + + :raises NotSupportedError: + If called with *innocuous* or *directonly* equal to True on a version + of SQLite older than 3.31.0. + .. versionchanged:: 3.15 All three parameters are now positional-only. + .. versionchanged:: next + Added the *deterministic*, *innocuous* and *directonly* parameters. + Example: .. testcode:: @@ -800,7 +843,7 @@ Connection objects 3 - .. method:: create_window_function(name, num_params, aggregate_class, /) + .. method:: create_window_function(name, num_params, aggregate_class, /, *, deterministic=False, innocuous=False, directonly=False) Create or remove a user-defined aggregate window function. @@ -825,14 +868,38 @@ Connection objects Set to ``None`` to remove an existing SQL aggregate window function. + :param bool deterministic: + If ``True``, the created SQL function is marked as + `deterministic `__, + which allows SQLite to perform additional optimizations. + + :param bool innocuous: + If ``True``, the created SQL function is marked as + `innocuous `__, + making it usable in views, triggers and schema structures even when + the ``trusted_schema`` pragma is disabled. + + :param bool directonly: + If ``True``, the created SQL function is marked as + `directonly `__, + restricting its use to top-level SQL statements regardless of the + value of the ``trusted_schema`` pragma. + :raises NotSupportedError: If used with a version of SQLite older than 3.25.0, which does not support aggregate window functions. + :raises NotSupportedError: + If called with *innocuous* or *directonly* equal to True on a version + of SQLite older than 3.31.0. + :type aggregate_class: :term:`class` | None .. versionadded:: 3.11 + .. versionchanged:: next + Added the *deterministic*, *innocuous* and *directonly* parameters. + Example: .. testcode:: diff --git a/Include/internal/pycore_global_objects_fini_generated.h b/Include/internal/pycore_global_objects_fini_generated.h index f7d3dcd440aaf1..702c08a273a6ce 100644 --- a/Include/internal/pycore_global_objects_fini_generated.h +++ b/Include/internal/pycore_global_objects_fini_generated.h @@ -1712,6 +1712,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(digest_size)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(digestmod)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(dir_fd)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(directonly)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(discard)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(dispatch_table)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(displayhook)); @@ -1836,6 +1837,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(initial_value)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(initval)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(inner_size)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(innocuous)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(input)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(insert_comments)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(insert_pis)); diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h index 22494b1798cc53..cf143dad36c26b 100644 --- a/Include/internal/pycore_global_strings.h +++ b/Include/internal/pycore_global_strings.h @@ -435,6 +435,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(digest_size) STRUCT_FOR_ID(digestmod) STRUCT_FOR_ID(dir_fd) + STRUCT_FOR_ID(directonly) STRUCT_FOR_ID(discard) STRUCT_FOR_ID(dispatch_table) STRUCT_FOR_ID(displayhook) @@ -559,6 +560,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(initial_value) STRUCT_FOR_ID(initval) STRUCT_FOR_ID(inner_size) + STRUCT_FOR_ID(innocuous) STRUCT_FOR_ID(input) STRUCT_FOR_ID(insert_comments) STRUCT_FOR_ID(insert_pis) diff --git a/Include/internal/pycore_runtime_init_generated.h b/Include/internal/pycore_runtime_init_generated.h index 892c3cdd9623a2..a343bd52a8796d 100644 --- a/Include/internal/pycore_runtime_init_generated.h +++ b/Include/internal/pycore_runtime_init_generated.h @@ -1710,6 +1710,7 @@ extern "C" { INIT_ID(digest_size), \ INIT_ID(digestmod), \ INIT_ID(dir_fd), \ + INIT_ID(directonly), \ INIT_ID(discard), \ INIT_ID(dispatch_table), \ INIT_ID(displayhook), \ @@ -1834,6 +1835,7 @@ extern "C" { INIT_ID(initial_value), \ INIT_ID(initval), \ INIT_ID(inner_size), \ + INIT_ID(innocuous), \ INIT_ID(input), \ INIT_ID(insert_comments), \ INIT_ID(insert_pis), \ diff --git a/Include/internal/pycore_unicodeobject_generated.h b/Include/internal/pycore_unicodeobject_generated.h index f0fc3c4f5b0900..08aa8f9d8a5cb6 100644 --- a/Include/internal/pycore_unicodeobject_generated.h +++ b/Include/internal/pycore_unicodeobject_generated.h @@ -1520,6 +1520,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(directonly); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(discard); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); @@ -2016,6 +2020,10 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); assert(PyUnicode_GET_LENGTH(string) != 1); + string = &_Py_ID(innocuous); + _PyUnicode_InternStatic(interp, &string); + assert(_PyUnicode_CheckConsistency(string, 1)); + assert(PyUnicode_GET_LENGTH(string) != 1); string = &_Py_ID(input); _PyUnicode_InternStatic(interp, &string); assert(_PyUnicode_CheckConsistency(string, 1)); diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py index 11cf877a011c78..674c63a5803f84 100644 --- a/Lib/test/test_sqlite3/test_userfunctions.py +++ b/Lib/test/test_sqlite3/test_userfunctions.py @@ -363,6 +363,40 @@ def test_func_deterministic_keyword_only(self): with self.assertRaises(TypeError): self.con.create_function("deterministic", 0, int, True) + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or higher") + def test_func_non_innocuous_in_trusted_env(self): + mock = Mock(return_value=None) + self.con.create_function("noninnocuous", 0, mock, innocuous=False) + self.con.execute("pragma trusted_schema = 0") + self.con.execute("create view notallowed as select noninnocuous() = noninnocuous()") + with self.assertRaises(sqlite.OperationalError) as cm: + self.con.execute("select * from notallowed") + self.assertEqual(str(cm.exception), 'unsafe use of noninnocuous()') + + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or higher") + def test_func_innocuous_in_trusted_env(self): + mock = Mock(return_value=None) + self.con.create_function("innocuous", 0, mock, innocuous=True) + self.con.execute("pragma trusted_schema = 0") + self.con.execute("create view allowed as select innocuous() = innocuous()") + self.con.execute("select * from allowed") + self.assertEqual(mock.call_count, 2) + + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or higher") + def test_func_direct_only(self): + mock = Mock(return_value=None) + self.con.create_function("directonly", 0, mock, directonly=True) + self.con.execute("pragma trusted_schema = 1") + self.con.execute("select directonly() = directonly()") + self.assertEqual(mock.call_count, 2) + self.con.execute("create view notallowed as select directonly() = directonly()") + with self.assertRaises(sqlite.OperationalError) as cm: + self.con.execute("select * from notallowed") + self.assertEqual(str(cm.exception), 'unsafe use of directonly()') + def test_function_destructor_via_gc(self): # See bpo-44304: The destructor of the user function can # crash if is called without the GIL from the gc functions @@ -479,6 +513,9 @@ def setUp(self): from test order by x """ self.con.create_window_function("sumint", 1, WindowSumInt) + if sqlite.sqlite_version_info >= (3, 31, 0): + self.con.create_window_function("sumintInnocuous", 1, WindowSumInt, innocuous=True) + self.con.create_window_function("sumintDirectOnly", 1, WindowSumInt, directonly=True) def tearDown(self): self.cur.close() @@ -488,6 +525,34 @@ def test_win_sum_int(self): self.cur.execute(self.query % "sumint") self.assertEqual(self.cur.fetchall(), self.expected) + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or newer") + def test_win_non_innocuous(self): + self.cur.execute("pragma trusted_schema = 0") + self.cur.execute("create view notallowed as " + self.query % "sumint") + with self.assertRaises(sqlite.OperationalError) as cm: + self.cur.execute("select * from notallowed") + self.assertEqual(str(cm.exception), 'unsafe use of sumint()') + + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or newer") + def test_win_innocuous(self): + self.cur.execute("pragma trusted_schema = 0") + self.cur.execute("create view allowed as " + self.query % "sumintInnocuous") + self.cur.execute("select * from allowed") + self.assertEqual(self.cur.fetchall(), self.expected) + + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or newer") + def test_win_directonly(self): + self.cur.execute("pragma trusted_schema = 1") + self.cur.execute("create view notallowed as " + self.query % "sumintDirectOnly") + with self.assertRaises(sqlite.OperationalError) as cm: + self.cur.execute("select * from notallowed") + self.assertEqual(str(cm.exception), 'unsafe use of sumintDirectOnly()') + self.cur.execute(self.query % "sumintDirectOnly") + self.assertEqual(self.cur.fetchall(), self.expected) + def test_win_error_on_create(self): with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"): self.con.create_window_function("shouldfail", -100, WindowSumInt) @@ -614,6 +679,9 @@ def setUp(self): self.con.create_aggregate("checkTypes", -1, AggrCheckTypes) self.con.create_aggregate("mysum", 1, AggrSum) self.con.create_aggregate("aggtxt", 1, AggrText) + if sqlite.sqlite_version_info >= (3, 31, 0): + self.con.create_aggregate("mysumInnocuous", 1, AggrSum, innocuous=True) + self.con.create_aggregate("mysumDirectOnly", 1, AggrSum, directonly=True) def tearDown(self): self.con.close() @@ -705,6 +773,45 @@ def test_aggr_check_aggr_sum(self): val = cur.fetchone()[0] self.assertEqual(val, 60) + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or newer") + def test_aggr_non_innocuous(self): + cur = self.con.cursor() + cur.execute("pragma trusted_schema = 0") + cur.execute("delete from test") + cur.execute("insert into test(i) values (?)", (10,)) + cur.execute("create view notallowed as select mysum(i) from test") + with self.assertRaises(sqlite.OperationalError) as cm: + cur.execute("select * from notallowed") + self.assertEqual(str(cm.exception), 'unsafe use of mysum()') + + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or newer") + def test_aggr_innocuous(self): + cur = self.con.cursor() + cur.execute("pragma trusted_schema = 0") + cur.execute("delete from test") + cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)]) + cur.execute("create view allowed as select mysumInnocuous(i) from test") + cur.execute("select * from allowed") + val = cur.fetchone()[0] + self.assertEqual(val, 60) + + @unittest.skipIf(sqlite.sqlite_version_info < (3, 31, 0), + "Requires SQLite 3.31.0 or newer") + def test_aggr_directonly(self): + cur = self.con.cursor() + cur.execute("pragma trusted_schema = 1") + cur.execute("delete from test") + cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)]) + cur.execute("create view notallowed as select mysumDirectOnly(i) from test") + with self.assertRaises(sqlite.OperationalError) as cm: + cur.execute("select * from notallowed") + self.assertEqual(str(cm.exception), 'unsafe use of mysumDirectOnly()') + cur.execute("select mysumDirectOnly(i) from test") + val = cur.fetchone()[0] + self.assertEqual(val, 60) + def test_aggr_no_match(self): cur = self.con.execute("select mysum(i) from (select 1 as i) where i == 0") val = cur.fetchone()[0] diff --git a/Misc/NEWS.d/next/Library/2026-05-17-11-55-52.gh-issue-89483.l6ShbK.rst b/Misc/NEWS.d/next/Library/2026-05-17-11-55-52.gh-issue-89483.l6ShbK.rst new file mode 100644 index 00000000000000..3d9b1310c0d693 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2026-05-17-11-55-52.gh-issue-89483.l6ShbK.rst @@ -0,0 +1,2 @@ +Add support for ``SQLITE_INNOCUOUS`` and ``SQLITE_DIRECTONLY`` flags in +:mod:`sqlite3`. diff --git a/Modules/_sqlite/clinic/connection.c.h b/Modules/_sqlite/clinic/connection.c.h index abb864eb030757..208911b0c51600 100644 --- a/Modules/_sqlite/clinic/connection.c.h +++ b/Modules/_sqlite/clinic/connection.c.h @@ -398,7 +398,8 @@ pysqlite_connection_rollback(PyObject *self, PyObject *Py_UNUSED(ignored)) } PyDoc_STRVAR(pysqlite_connection_create_function__doc__, -"create_function($self, name, narg, func, /, *, deterministic=False)\n" +"create_function($self, name, narg, func, /, *, deterministic=False,\n" +" innocuous=False, directonly=False)\n" "--\n" "\n" "Creates a new function."); @@ -410,7 +411,8 @@ static PyObject * pysqlite_connection_create_function_impl(pysqlite_Connection *self, PyTypeObject *cls, const char *name, int narg, PyObject *func, - int deterministic); + int deterministic, int innocuous, + int directonly); static PyObject * pysqlite_connection_create_function(PyObject *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) @@ -418,7 +420,7 @@ pysqlite_connection_create_function(PyObject *self, PyTypeObject *cls, PyObject PyObject *return_value = NULL; #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) - #define NUM_KEYWORDS 1 + #define NUM_KEYWORDS 3 static struct { PyGC_Head _this_is_not_used; PyObject_VAR_HEAD @@ -427,7 +429,7 @@ pysqlite_connection_create_function(PyObject *self, PyTypeObject *cls, PyObject } _kwtuple = { .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) .ob_hash = -1, - .ob_item = { &_Py_ID(deterministic), }, + .ob_item = { &_Py_ID(deterministic), &_Py_ID(innocuous), &_Py_ID(directonly), }, }; #undef NUM_KEYWORDS #define KWTUPLE (&_kwtuple.ob_base.ob_base) @@ -436,19 +438,21 @@ pysqlite_connection_create_function(PyObject *self, PyTypeObject *cls, PyObject # define KWTUPLE NULL #endif // !Py_BUILD_CORE - static const char * const _keywords[] = {"", "", "", "deterministic", NULL}; + static const char * const _keywords[] = {"", "", "", "deterministic", "innocuous", "directonly", NULL}; static _PyArg_Parser _parser = { .keywords = _keywords, .fname = "create_function", .kwtuple = KWTUPLE, }; #undef KWTUPLE - PyObject *argsbuf[4]; + PyObject *argsbuf[6]; Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 3; const char *name; int narg; PyObject *func; int deterministic = 0; + int innocuous = 0; + int directonly = 0; args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, /*minpos*/ 3, /*maxpos*/ 3, /*minkw*/ 0, /*varpos*/ 0, argsbuf); @@ -476,12 +480,30 @@ pysqlite_connection_create_function(PyObject *self, PyTypeObject *cls, PyObject if (!noptargs) { goto skip_optional_kwonly; } - deterministic = PyObject_IsTrue(args[3]); - if (deterministic < 0) { + if (args[3]) { + deterministic = PyObject_IsTrue(args[3]); + if (deterministic < 0) { + goto exit; + } + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + if (args[4]) { + innocuous = PyObject_IsTrue(args[4]); + if (innocuous < 0) { + goto exit; + } + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + directonly = PyObject_IsTrue(args[5]); + if (directonly < 0) { goto exit; } skip_optional_kwonly: - return_value = pysqlite_connection_create_function_impl((pysqlite_Connection *)self, cls, name, narg, func, deterministic); + return_value = pysqlite_connection_create_function_impl((pysqlite_Connection *)self, cls, name, narg, func, deterministic, innocuous, directonly); exit: return return_value; @@ -490,7 +512,9 @@ pysqlite_connection_create_function(PyObject *self, PyTypeObject *cls, PyObject #if defined(HAVE_WINDOW_FUNCTIONS) PyDoc_STRVAR(create_window_function__doc__, -"create_window_function($self, name, num_params, aggregate_class, /)\n" +"create_window_function($self, name, num_params, aggregate_class, /, *,\n" +" deterministic=False, innocuous=False,\n" +" directonly=False)\n" "--\n" "\n" "Creates or redefines an aggregate window function. Non-standard.\n" @@ -510,29 +534,48 @@ PyDoc_STRVAR(create_window_function__doc__, static PyObject * create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, const char *name, int num_params, - PyObject *aggregate_class); + PyObject *aggregate_class, int deterministic, + int innocuous, int directonly); static PyObject * create_window_function(PyObject *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) { PyObject *return_value = NULL; #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) - # define KWTUPLE (PyObject *)&_Py_SINGLETON(tuple_empty) - #else + + #define NUM_KEYWORDS 3 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(deterministic), &_Py_ID(innocuous), &_Py_ID(directonly), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE # define KWTUPLE NULL - #endif + #endif // !Py_BUILD_CORE - static const char * const _keywords[] = {"", "", "", NULL}; + static const char * const _keywords[] = {"", "", "", "deterministic", "innocuous", "directonly", NULL}; static _PyArg_Parser _parser = { .keywords = _keywords, .fname = "create_window_function", .kwtuple = KWTUPLE, }; #undef KWTUPLE - PyObject *argsbuf[3]; + PyObject *argsbuf[6]; + Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 3; const char *name; int num_params; PyObject *aggregate_class; + int deterministic = 0; + int innocuous = 0; + int directonly = 0; args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, /*minpos*/ 3, /*maxpos*/ 3, /*minkw*/ 0, /*varpos*/ 0, argsbuf); @@ -557,7 +600,33 @@ create_window_function(PyObject *self, PyTypeObject *cls, PyObject *const *args, goto exit; } aggregate_class = args[2]; - return_value = create_window_function_impl((pysqlite_Connection *)self, cls, name, num_params, aggregate_class); + if (!noptargs) { + goto skip_optional_kwonly; + } + if (args[3]) { + deterministic = PyObject_IsTrue(args[3]); + if (deterministic < 0) { + goto exit; + } + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + if (args[4]) { + innocuous = PyObject_IsTrue(args[4]); + if (innocuous < 0) { + goto exit; + } + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + directonly = PyObject_IsTrue(args[5]); + if (directonly < 0) { + goto exit; + } +skip_optional_kwonly: + return_value = create_window_function_impl((pysqlite_Connection *)self, cls, name, num_params, aggregate_class, deterministic, innocuous, directonly); exit: return return_value; @@ -566,7 +635,8 @@ create_window_function(PyObject *self, PyTypeObject *cls, PyObject *const *args, #endif /* defined(HAVE_WINDOW_FUNCTIONS) */ PyDoc_STRVAR(pysqlite_connection_create_aggregate__doc__, -"create_aggregate($self, name, n_arg, aggregate_class, /)\n" +"create_aggregate($self, name, n_arg, aggregate_class, /, *,\n" +" deterministic=False, innocuous=False, directonly=False)\n" "--\n" "\n" "Creates a new aggregate."); @@ -578,29 +648,49 @@ static PyObject * pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, PyTypeObject *cls, const char *name, int n_arg, - PyObject *aggregate_class); + PyObject *aggregate_class, + int deterministic, int innocuous, + int directonly); static PyObject * pysqlite_connection_create_aggregate(PyObject *self, PyTypeObject *cls, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) { PyObject *return_value = NULL; #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE) - # define KWTUPLE (PyObject *)&_Py_SINGLETON(tuple_empty) - #else + + #define NUM_KEYWORDS 3 + static struct { + PyGC_Head _this_is_not_used; + PyObject_VAR_HEAD + Py_hash_t ob_hash; + PyObject *ob_item[NUM_KEYWORDS]; + } _kwtuple = { + .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS) + .ob_hash = -1, + .ob_item = { &_Py_ID(deterministic), &_Py_ID(innocuous), &_Py_ID(directonly), }, + }; + #undef NUM_KEYWORDS + #define KWTUPLE (&_kwtuple.ob_base.ob_base) + + #else // !Py_BUILD_CORE # define KWTUPLE NULL - #endif + #endif // !Py_BUILD_CORE - static const char * const _keywords[] = {"", "", "", NULL}; + static const char * const _keywords[] = {"", "", "", "deterministic", "innocuous", "directonly", NULL}; static _PyArg_Parser _parser = { .keywords = _keywords, .fname = "create_aggregate", .kwtuple = KWTUPLE, }; #undef KWTUPLE - PyObject *argsbuf[3]; + PyObject *argsbuf[6]; + Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 3; const char *name; int n_arg; PyObject *aggregate_class; + int deterministic = 0; + int innocuous = 0; + int directonly = 0; args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, /*minpos*/ 3, /*maxpos*/ 3, /*minkw*/ 0, /*varpos*/ 0, argsbuf); @@ -625,7 +715,33 @@ pysqlite_connection_create_aggregate(PyObject *self, PyTypeObject *cls, PyObject goto exit; } aggregate_class = args[2]; - return_value = pysqlite_connection_create_aggregate_impl((pysqlite_Connection *)self, cls, name, n_arg, aggregate_class); + if (!noptargs) { + goto skip_optional_kwonly; + } + if (args[3]) { + deterministic = PyObject_IsTrue(args[3]); + if (deterministic < 0) { + goto exit; + } + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + if (args[4]) { + innocuous = PyObject_IsTrue(args[4]); + if (innocuous < 0) { + goto exit; + } + if (!--noptargs) { + goto skip_optional_kwonly; + } + } + directonly = PyObject_IsTrue(args[5]); + if (directonly < 0) { + goto exit; + } +skip_optional_kwonly: + return_value = pysqlite_connection_create_aggregate_impl((pysqlite_Connection *)self, cls, name, n_arg, aggregate_class, deterministic, innocuous, directonly); exit: return return_value; @@ -1722,4 +1838,4 @@ getconfig(PyObject *self, PyObject *arg) #ifndef DESERIALIZE_METHODDEF #define DESERIALIZE_METHODDEF #endif /* !defined(DESERIALIZE_METHODDEF) */ -/*[clinic end generated code: output=16d44c1d8a45e622 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=bf2679e8e0c88e60 input=a9049054013a1b77]*/ diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index bd44ff31b87c67..d8b1bba309d03d 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1140,6 +1140,48 @@ check_num_params(pysqlite_Connection *self, const int n, const char *name) return 0; } +static int +apply_innocuous_flag_if_supported(pysqlite_Connection *self, + int *flags, int is_set) +{ + if (is_set) { +#if SQLITE_VERSION_NUMBER < 3031000 + PyErr_SetString(self->NotSupportedError, + "innocuous=True requires SQLite 3.31.0 or higher"); + return -1; +#else + if (sqlite3_libversion_number() < 3031000) { + PyErr_SetString(self->NotSupportedError, + "innocuous=True requires SQLite 3.31.0 or higher"); + return -1; + } + *flags |= SQLITE_INNOCUOUS; +#endif + } + return 0; +} + +static int +apply_directonly_flag_if_supported(pysqlite_Connection *self, + int *flags, int is_set) +{ + if (is_set) { +#if SQLITE_VERSION_NUMBER < 3031000 + PyErr_SetString(self->NotSupportedError, + "directonly=True requires SQLite 3.31.0 or higher"); + return -1; +#else + if (sqlite3_libversion_number() < 3031000) { + PyErr_SetString(self->NotSupportedError, + "directonly=True requires SQLite 3.31.0 or higher"); + return -1; + } + *flags |= SQLITE_DIRECTONLY; +#endif + } + return 0; +} + /*[clinic input] _sqlite3.Connection.create_function as pysqlite_connection_create_function @@ -1150,6 +1192,8 @@ _sqlite3.Connection.create_function as pysqlite_connection_create_function / * deterministic: bool = False + innocuous: bool = False + directonly: bool = False Creates a new function. [clinic start generated code]*/ @@ -1158,8 +1202,9 @@ static PyObject * pysqlite_connection_create_function_impl(pysqlite_Connection *self, PyTypeObject *cls, const char *name, int narg, PyObject *func, - int deterministic) -/*[clinic end generated code: output=8a811529287ad240 input=a896096ed5390ae1]*/ + int deterministic, int innocuous, + int directonly) +/*[clinic end generated code: output=f38fe8fd62fe1331 input=474f56fa8e971f7f]*/ { int rc; int flags = SQLITE_UTF8; @@ -1174,6 +1219,12 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self, if (deterministic) { flags |= SQLITE_DETERMINISTIC; } + if (apply_innocuous_flag_if_supported(self, &flags, innocuous) < 0) { + return NULL; + } + if (apply_directonly_flag_if_supported(self, &flags, directonly) < 0) { + return NULL; + } callback_context *ctx = create_callback_context(cls, func); if (ctx == NULL) { return NULL; @@ -1297,6 +1348,10 @@ _sqlite3.Connection.create_window_function as create_window_function A class with step(), finalize(), value(), and inverse() methods. Set to None to clear the window function. / + * + deterministic: bool = False + innocuous: bool = False + directonly: bool = False Creates or redefines an aggregate window function. Non-standard. [clinic start generated code]*/ @@ -1304,8 +1359,9 @@ Creates or redefines an aggregate window function. Non-standard. static PyObject * create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, const char *name, int num_params, - PyObject *aggregate_class) -/*[clinic end generated code: output=5332cd9464522235 input=46d57a54225b5228]*/ + PyObject *aggregate_class, int deterministic, + int innocuous, int directonly) +/*[clinic end generated code: output=9d3081a3b22b83d3 input=ac82d0db9fd8d774]*/ { if (sqlite3_libversion_number() < 3025000) { PyErr_SetString(self->NotSupportedError, @@ -1321,6 +1377,15 @@ create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, } int flags = SQLITE_UTF8; + if (deterministic) { + flags |= SQLITE_DETERMINISTIC; + } + if (apply_innocuous_flag_if_supported(self, &flags, innocuous) < 0) { + return NULL; + } + if (apply_directonly_flag_if_supported(self, &flags, directonly) < 0) { + return NULL; + } int rc; if (Py_IsNone(aggregate_class)) { rc = sqlite3_create_window_function(self->db, name, num_params, flags, @@ -1358,6 +1423,10 @@ _sqlite3.Connection.create_aggregate as pysqlite_connection_create_aggregate n_arg: int aggregate_class: object / + * + deterministic: bool = False + innocuous: bool = False + directonly: bool = False Creates a new aggregate. [clinic start generated code]*/ @@ -1366,8 +1435,10 @@ static PyObject * pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, PyTypeObject *cls, const char *name, int n_arg, - PyObject *aggregate_class) -/*[clinic end generated code: output=1b02d0f0aec7ff96 input=aa2773f6a42f7e17]*/ + PyObject *aggregate_class, + int deterministic, int innocuous, + int directonly) +/*[clinic end generated code: output=9058dcf3da395d17 input=5e3bd0a8266575cd]*/ { int rc; @@ -1378,11 +1449,21 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, return NULL; } + int flags = SQLITE_UTF8; + if (deterministic) { + flags |= SQLITE_DETERMINISTIC; + } + if (apply_innocuous_flag_if_supported(self, &flags, innocuous) < 0) { + return NULL; + } + if (apply_directonly_flag_if_supported(self, &flags, directonly) < 0) { + return NULL; + } callback_context *ctx = create_callback_context(cls, aggregate_class); if (ctx == NULL) { return NULL; } - rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, ctx, + rc = sqlite3_create_function_v2(self->db, name, n_arg, flags, ctx, 0, &step_callback, &final_callback,