Skip to content

Commit 3a58ead

Browse files
committed
Rewrite sqlite3 UTF8 validation
1 parent a4fab0d commit 3a58ead

1 file changed

Lines changed: 14 additions & 15 deletions

File tree

stdlib/src/sqlite.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ mod _sqlite {
5959
builtins::{
6060
PyBaseException, PyBaseExceptionRef, PyByteArray, PyBytes, PyDict, PyDictRef, PyFloat,
6161
PyInt, PyIntRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef,
62+
PyUtf8Str, PyUtf8StrRef,
6263
},
6364
convert::IntoObject,
6465
function::{
@@ -851,7 +852,7 @@ mod _sqlite {
851852
}
852853

853854
impl Callable for Connection {
854-
type Args = (PyStrRef,);
855+
type Args = (PyUtf8StrRef,);
855856

856857
fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
857858
if let Some(stmt) = Statement::new(zelf, args.0, vm)? {
@@ -986,7 +987,7 @@ mod _sqlite {
986987
#[pymethod]
987988
fn execute(
988989
zelf: PyRef<Self>,
989-
sql: PyStrRef,
990+
sql: PyUtf8StrRef,
990991
parameters: OptionalArg<PyObjectRef>,
991992
vm: &VirtualMachine,
992993
) -> PyResult<PyRef<Cursor>> {
@@ -998,7 +999,7 @@ mod _sqlite {
998999
#[pymethod]
9991000
fn executemany(
10001001
zelf: PyRef<Self>,
1001-
sql: PyStrRef,
1002+
sql: PyUtf8StrRef,
10021003
seq_of_params: ArgIterable,
10031004
vm: &VirtualMachine,
10041005
) -> PyResult<PyRef<Cursor>> {
@@ -1010,7 +1011,7 @@ mod _sqlite {
10101011
#[pymethod]
10111012
fn executescript(
10121013
zelf: PyRef<Self>,
1013-
script: PyStrRef,
1014+
script: PyUtf8StrRef,
10141015
vm: &VirtualMachine,
10151016
) -> PyResult<PyRef<Cursor>> {
10161017
let row_factory = zelf.row_factory.to_owned();
@@ -1159,11 +1160,10 @@ mod _sqlite {
11591160
#[pymethod]
11601161
fn create_collation(
11611162
&self,
1162-
name: PyStrRef,
1163+
name: PyUtf8StrRef,
11631164
callable: PyObjectRef,
11641165
vm: &VirtualMachine,
11651166
) -> PyResult<()> {
1166-
name.ensure_valid_utf8(vm)?;
11671167
let name = name.to_cstring(vm)?;
11681168
let db = self.db_lock(vm)?;
11691169
let Some(data) = CallbackData::new(callable.clone(), vm) else {
@@ -1491,7 +1491,7 @@ mod _sqlite {
14911491
#[pymethod]
14921492
fn execute(
14931493
zelf: PyRef<Self>,
1494-
sql: PyStrRef,
1494+
sql: PyUtf8StrRef,
14951495
parameters: OptionalArg<PyObjectRef>,
14961496
vm: &VirtualMachine,
14971497
) -> PyResult<PyRef<Self>> {
@@ -1563,7 +1563,7 @@ mod _sqlite {
15631563
#[pymethod]
15641564
fn executemany(
15651565
zelf: PyRef<Self>,
1566-
sql: PyStrRef,
1566+
sql: PyUtf8StrRef,
15671567
seq_of_params: ArgIterable,
15681568
vm: &VirtualMachine,
15691569
) -> PyResult<PyRef<Self>> {
@@ -1637,11 +1637,9 @@ mod _sqlite {
16371637
#[pymethod]
16381638
fn executescript(
16391639
zelf: PyRef<Self>,
1640-
script: PyStrRef,
1640+
script: PyUtf8StrRef,
16411641
vm: &VirtualMachine,
16421642
) -> PyResult<PyRef<Self>> {
1643-
script.ensure_valid_utf8(vm)?;
1644-
16451643
let db = zelf.connection.db_lock(vm)?;
16461644

16471645
db.sql_limit(script.byte_len(), vm)?;
@@ -2375,10 +2373,9 @@ mod _sqlite {
23752373
impl Statement {
23762374
fn new(
23772375
connection: &Connection,
2378-
sql: PyStrRef,
2376+
sql: PyUtf8StrRef,
23792377
vm: &VirtualMachine,
23802378
) -> PyResult<Option<Self>> {
2381-
let sql = sql.try_into_utf8(vm)?;
23822379
if sql.as_str().contains('\0') {
23832380
return Err(new_programming_error(
23842381
vm,
@@ -2731,6 +2728,7 @@ mod _sqlite {
27312728
let val = val.to_f64();
27322729
unsafe { sqlite3_bind_double(self.st, pos, val) }
27332730
} else if let Some(val) = obj.downcast_ref::<PyStr>() {
2731+
let val = val.try_as_utf8(vm)?;
27342732
let (ptr, len) = str_to_ptr_len(val, vm)?;
27352733
unsafe { sqlite3_bind_text(self.st, pos, ptr, len, SQLITE_TRANSIENT()) }
27362734
} else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, obj) {
@@ -2990,6 +2988,7 @@ mod _sqlite {
29902988
} else if let Some(val) = val.downcast_ref::<PyFloat>() {
29912989
sqlite3_result_double(self.ctx, val.to_f64())
29922990
} else if let Some(val) = val.downcast_ref::<PyStr>() {
2991+
let val = val.try_as_utf8(vm)?;
29932992
let (ptr, len) = str_to_ptr_len(val, vm)?;
29942993
sqlite3_result_text(self.ctx, ptr, len, SQLITE_TRANSIENT())
29952994
} else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, val) {
@@ -3070,8 +3069,8 @@ mod _sqlite {
30703069
}
30713070
}
30723071

3073-
fn str_to_ptr_len(s: &PyStr, vm: &VirtualMachine) -> PyResult<(*const libc::c_char, i32)> {
3074-
let s_str = s.try_to_str(vm)?;
3072+
fn str_to_ptr_len(s: &PyUtf8Str, vm: &VirtualMachine) -> PyResult<(*const libc::c_char, i32)> {
3073+
let s_str = s.as_str();
30753074
let len = c_int::try_from(s_str.len())
30763075
.map_err(|_| vm.new_overflow_error("TEXT longer than INT_MAX bytes"))?;
30773076
let ptr = s_str.as_ptr().cast();

0 commit comments

Comments
 (0)