From db3c73d8e950dc44159040ab1e221d91d98c5fa6 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Thu, 5 Jul 2018 15:35:02 +0500 Subject: [PATCH 1/6] bpo-34052: Prevent SQLite functions from setting callbacks on exceptions. --- Modules/_sqlite/connection.c | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 89a87518978003..f970d90ba38f8c 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -843,7 +843,9 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec flags |= SQLITE_DETERMINISTIC; #endif } - + if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) { + return NULL; + } rc = sqlite3_create_function(self->db, name, narg, @@ -854,15 +856,12 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec NULL); if (rc != SQLITE_OK) { + PyDict_DelItem(self->function_pinboard, func); /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(pysqlite_OperationalError, "Error creating function"); return NULL; - } else { - if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) - return NULL; - - Py_RETURN_NONE; } + Py_RETURN_NONE; } PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) @@ -883,17 +882,17 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje return NULL; } + if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) { + return NULL; + } rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback); if (rc != SQLITE_OK) { + PyDict_DelItem(self->function_pinboard, aggregate_class); /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate"); return NULL; - } else { - if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) - return NULL; - - Py_RETURN_NONE; } + Py_RETURN_NONE; } static int _authorizer_callback(void* user_arg, int action, const char* arg1, const char* arg2 , const char* dbname, const char* access_attempt_source) @@ -1006,17 +1005,16 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P return NULL; } + if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) { + return NULL; + } rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb); - if (rc != SQLITE_OK) { + PyDict_DelItem(self->function_pinboard, authorizer_cb); PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback"); return NULL; - } else { - if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) - return NULL; - - Py_RETURN_NONE; } + Py_RETURN_NONE; } static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) @@ -1039,9 +1037,9 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s /* None clears the progress handler previously set */ sqlite3_progress_handler(self->db, 0, 0, (void*)0); } else { - sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1) return NULL; + sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); } Py_RETURN_NONE; From 1d00a99845510a1b88fc5a5d1ab5a4e8bd4f0ad1 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Mon, 16 Jul 2018 18:21:56 +0500 Subject: [PATCH 2/6] Do not remove functions from the dict on fail. --- Modules/_sqlite/connection.c | 3 --- 1 file changed, 3 deletions(-) diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index f970d90ba38f8c..d25283677182fd 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -856,7 +856,6 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec NULL); if (rc != SQLITE_OK) { - PyDict_DelItem(self->function_pinboard, func); /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(pysqlite_OperationalError, "Error creating function"); return NULL; @@ -887,7 +886,6 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje } rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback); if (rc != SQLITE_OK) { - PyDict_DelItem(self->function_pinboard, aggregate_class); /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate"); return NULL; @@ -1010,7 +1008,6 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P } rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb); if (rc != SQLITE_OK) { - PyDict_DelItem(self->function_pinboard, authorizer_cb); PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback"); return NULL; } From efbeb1d4ff4f2b1e0ff5ed4127cf25607989409b Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Mon, 23 Jul 2018 15:21:15 +0500 Subject: [PATCH 3/6] Added tests. --- Lib/sqlite3/test/hooks.py | 18 ++++++++++++++ Lib/sqlite3/test/userfunctions.py | 39 +++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index d74e74bf272275..dc2b7c44a37d89 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -201,6 +201,24 @@ def progress(): con.execute("select 1 union select 2 union select 3").fetchall() self.assertEqual(action, 0, "progress handler was not cleared") + def CheckProgressHandlerUnhashable(self): + progress_calls = [] + class UnhashableFunc: + __hash__ = None + + def __call__(*args, **kwargs): + progress_calls.append(None) + return 0 + + con = sqlite.connect(":memory:") + with self.assertRaisesRegex(TypeError, "unhashable type"): + con.set_progress_handler(UnhashableFunc(), 1) + con.execute(""" + create table foo(a, b) + """) + self.assertFalse(progress_calls) + + class TraceCallbackTests(unittest.TestCase): def CheckTraceCallbackUsed(self): """ diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 9501f535c49999..91f517c76b447e 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -59,6 +59,17 @@ def func_islonglong(v): def func(*args): return len(args) + +class UnhashableFunc: + __hash__ = None + + def __init__(self, return_value=None): + self.return_value = return_value + + def __call__(self, *args, **kwargs): + return self.return_value + + class AggrNoStep: def __init__(self): pass @@ -298,6 +309,14 @@ def CheckFuncDeterministicKeywordOnly(self): with self.assertRaises(TypeError): self.con.create_function("deterministic", 0, int, True) + def CheckFuncUnhashable(self): + func_name = "func_name" + with self.assertRaisesRegex(TypeError, "unhashable type"): + self.con.create_function(func_name, 0, UnhashableFunc()) + msg = "no such function: %s" % func_name + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.con.execute("SELECT %s()" % func_name) + class AggregateTests(unittest.TestCase): def setUp(self): @@ -411,6 +430,18 @@ def CheckAggrCheckAggrSum(self): val = cur.fetchone()[0] self.assertEqual(val, 60) + def CheckAggrUnhashable(self): + class UnhashableType(type): + __hash__ = None + + sqlite.enable_callback_tracebacks(True) + aggr_name = "aggr_name" + with self.assertRaisesRegex(TypeError, "unhashable type"): + self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {})) + msg = "no such function: %s" % aggr_name + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.con.execute("SELECT %s()" % aggr_name) + class AuthorizerTests(unittest.TestCase): @staticmethod def authorizer_cb(action, arg1, arg2, dbname, source): @@ -475,6 +506,13 @@ def authorizer_cb(action, arg1, arg2, dbname, source): return sqlite.SQLITE_OK +class AuthorizerUnhashable(AuthorizerTests): + def setUp(self): + super().setUp() + with self.assertRaisesRegex(TypeError, "unhashable type"): + self.con.set_authorizer(UnhashableFunc(sqlite.SQLITE_OK)) + + def suite(): function_suite = unittest.makeSuite(FunctionTests, "Check") aggregate_suite = unittest.makeSuite(AggregateTests, "Check") @@ -486,6 +524,7 @@ def suite(): unittest.makeSuite(AuthorizerRaiseExceptionTests), unittest.makeSuite(AuthorizerIllegalTypeTests), unittest.makeSuite(AuthorizerLargeIntegerTests), + unittest.makeSuite(AuthorizerUnhashable), )) def test(): From c607ffbb40c4bb6028003debc2f89cd2f07393fa Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Tue, 24 Jul 2018 16:39:08 +0500 Subject: [PATCH 4/6] Added news entry. --- .../next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst diff --git a/Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst b/Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst new file mode 100644 index 00000000000000..5aa3cc9a81d7f3 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst @@ -0,0 +1,7 @@ +:meth:`sqlite3.Connection.create_aggregate`, +:meth:`sqlite3.Connection.create_function`, +:meth:`sqlite3.Connection.set_authorizer`, +:meth:`sqlite3.Connection.set_progress_handler` methods raises TypeError +when unhashable objects are passed as callable. These methods now don't pass +such objects to SQLite API. Previous behavior could lead to segfaults. Patch +by Sergey Fedoseev. From 34da9f4145ecd97e688f67c8ed71c4b1b9dc3196 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Tue, 24 Jul 2018 17:11:55 +0500 Subject: [PATCH 5/6] Removed unintentionally committed line. --- Lib/sqlite3/test/userfunctions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 91f517c76b447e..aa651158e273b6 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -434,7 +434,6 @@ def CheckAggrUnhashable(self): class UnhashableType(type): __hash__ = None - sqlite.enable_callback_tracebacks(True) aggr_name = "aggr_name" with self.assertRaisesRegex(TypeError, "unhashable type"): self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {})) From 362433b0ad5859538182233954d718c4c59366f6 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Sat, 17 Nov 2018 00:18:04 +0500 Subject: [PATCH 6/6] Reorganized tests. --- Lib/sqlite3/test/hooks.py | 18 ------- Lib/sqlite3/test/regression.py | 83 ++++++++++++++++++++++++------- Lib/sqlite3/test/userfunctions.py | 38 -------------- 3 files changed, 64 insertions(+), 75 deletions(-) diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index dc2b7c44a37d89..d74e74bf272275 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -201,24 +201,6 @@ def progress(): con.execute("select 1 union select 2 union select 3").fetchall() self.assertEqual(action, 0, "progress handler was not cleared") - def CheckProgressHandlerUnhashable(self): - progress_calls = [] - class UnhashableFunc: - __hash__ = None - - def __call__(*args, **kwargs): - progress_calls.append(None) - return 0 - - con = sqlite.connect(":memory:") - with self.assertRaisesRegex(TypeError, "unhashable type"): - con.set_progress_handler(UnhashableFunc(), 1) - con.execute(""" - create table foo(a, b) - """) - self.assertFalse(progress_calls) - - class TraceCallbackTests(unittest.TestCase): def CheckTraceCallbackUsed(self): """ diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 34cd233535dc16..1c59a3cd31c625 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -256,24 +256,6 @@ def CheckPragmaAutocommit(self): cur.execute("pragma page_size") row = cur.fetchone() - def CheckSetDict(self): - """ - See http://bugs.python.org/issue7478 - - It was possible to successfully register callbacks that could not be - hashed. Return codes of PyDict_SetItem were not checked properly. - """ - class NotHashable: - def __call__(self, *args, **kw): - pass - def __hash__(self): - raise TypeError() - var = NotHashable() - self.assertRaises(TypeError, self.con.create_function, var) - self.assertRaises(TypeError, self.con.create_aggregate, var) - self.assertRaises(TypeError, self.con.set_authorizer, var) - self.assertRaises(TypeError, self.con.set_progress_handler, var) - def CheckConnectionCall(self): """ Call a connection with a non-string SQL request: check error handling @@ -398,9 +380,72 @@ def callback(*args): support.gc_collect() +class UnhashableFunc: + __hash__ = None + + def __init__(self, return_value=None): + self.calls = 0 + self.return_value = return_value + + def __call__(self, *args, **kwargs): + self.calls += 1 + return self.return_value + + +class UnhashableCallbacksTestCase(unittest.TestCase): + """ + https://bugs.python.org/issue34052 + + Registering unhashable callbacks raises TypeError, callbacks are not + registered in SQLite after such registration attempt. + """ + def setUp(self): + self.con = sqlite.connect(':memory:') + + def tearDown(self): + self.con.close() + + def test_progress_handler(self): + f = UnhashableFunc(return_value=0) + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.set_progress_handler(f, 1) + self.con.execute('SELECT 1') + self.assertFalse(f.calls) + + def test_func(self): + func_name = 'func_name' + f = UnhashableFunc() + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.create_function(func_name, 0, f) + msg = 'no such function: %s' % func_name + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.con.execute('SELECT %s()' % func_name) + self.assertFalse(f.calls) + + def test_authorizer(self): + f = UnhashableFunc(return_value=sqlite.SQLITE_DENY) + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.set_authorizer(f) + self.con.execute('SELECT 1') + self.assertFalse(f.calls) + + def test_aggr(self): + class UnhashableType(type): + __hash__ = None + aggr_name = 'aggr_name' + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {})) + msg = 'no such function: %s' % aggr_name + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.con.execute('SELECT %s()' % aggr_name) + + def suite(): regression_suite = unittest.makeSuite(RegressionTests, "Check") - return unittest.TestSuite((regression_suite,)) + return unittest.TestSuite(( + regression_suite, + unittest.makeSuite(UnhashableCallbacksTestCase), + )) def test(): runner = unittest.TextTestRunner() diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index aa651158e273b6..9501f535c49999 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -59,17 +59,6 @@ def func_islonglong(v): def func(*args): return len(args) - -class UnhashableFunc: - __hash__ = None - - def __init__(self, return_value=None): - self.return_value = return_value - - def __call__(self, *args, **kwargs): - return self.return_value - - class AggrNoStep: def __init__(self): pass @@ -309,14 +298,6 @@ def CheckFuncDeterministicKeywordOnly(self): with self.assertRaises(TypeError): self.con.create_function("deterministic", 0, int, True) - def CheckFuncUnhashable(self): - func_name = "func_name" - with self.assertRaisesRegex(TypeError, "unhashable type"): - self.con.create_function(func_name, 0, UnhashableFunc()) - msg = "no such function: %s" % func_name - with self.assertRaisesRegex(sqlite.OperationalError, msg): - self.con.execute("SELECT %s()" % func_name) - class AggregateTests(unittest.TestCase): def setUp(self): @@ -430,17 +411,6 @@ def CheckAggrCheckAggrSum(self): val = cur.fetchone()[0] self.assertEqual(val, 60) - def CheckAggrUnhashable(self): - class UnhashableType(type): - __hash__ = None - - aggr_name = "aggr_name" - with self.assertRaisesRegex(TypeError, "unhashable type"): - self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {})) - msg = "no such function: %s" % aggr_name - with self.assertRaisesRegex(sqlite.OperationalError, msg): - self.con.execute("SELECT %s()" % aggr_name) - class AuthorizerTests(unittest.TestCase): @staticmethod def authorizer_cb(action, arg1, arg2, dbname, source): @@ -505,13 +475,6 @@ def authorizer_cb(action, arg1, arg2, dbname, source): return sqlite.SQLITE_OK -class AuthorizerUnhashable(AuthorizerTests): - def setUp(self): - super().setUp() - with self.assertRaisesRegex(TypeError, "unhashable type"): - self.con.set_authorizer(UnhashableFunc(sqlite.SQLITE_OK)) - - def suite(): function_suite = unittest.makeSuite(FunctionTests, "Check") aggregate_suite = unittest.makeSuite(AggregateTests, "Check") @@ -523,7 +486,6 @@ def suite(): unittest.makeSuite(AuthorizerRaiseExceptionTests), unittest.makeSuite(AuthorizerIllegalTypeTests), unittest.makeSuite(AuthorizerLargeIntegerTests), - unittest.makeSuite(AuthorizerUnhashable), )) def test():