Skip to content

Commit 050db47

Browse files
authored
sqlite3: fix Connection.cursor() factory argument handling (#6783)
Fix test_is_instance in CursorFactoryTests by properly handling the factory argument in Connection.cursor() method. Now the factory can be passed as both positional and keyword argument, and returns the correct subclass type instead of always returning PyRef<Cursor>. - Use FromArgs derive macro with CursorArgs struct for argument parsing - Return PyObjectRef instead of PyRef<Cursor> to allow subclasses - Use fast_issubclass to validate returned cursor is a Cursor subclass - Properly differentiate between 'no argument' and 'None passed'
1 parent 9301ae2 commit 050db47

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

Lib/test/test_sqlite3/test_factory.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ def setUp(self):
8080
def tearDown(self):
8181
self.con.close()
8282

83-
# TODO: RUSTPYTHON
84-
@unittest.expectedFailure
8583
def test_is_instance(self):
8684
cur = self.con.cursor()
8785
self.assertIsInstance(cur, sqlite.Cursor)

crates/stdlib/src/sqlite.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,12 @@ mod _sqlite {
425425
name: PyStrRef,
426426
}
427427

428+
#[derive(FromArgs)]
429+
struct CursorArgs {
430+
#[pyarg(any, default)]
431+
factory: OptionalArg<PyObjectRef>,
432+
}
433+
428434
struct CallbackData {
429435
obj: NonNull<PyObject>,
430436
vm: *const VirtualMachine,
@@ -1023,22 +1029,29 @@ mod _sqlite {
10231029
#[pymethod]
10241030
fn cursor(
10251031
zelf: PyRef<Self>,
1026-
factory: OptionalArg<ArgCallable>,
1032+
args: CursorArgs,
10271033
vm: &VirtualMachine,
1028-
) -> PyResult<PyRef<Cursor>> {
1034+
) -> PyResult<PyObjectRef> {
10291035
zelf.db_lock(vm).map(drop)?;
10301036

1031-
let cursor = if let OptionalArg::Present(factory) = factory {
1032-
let cursor = factory.invoke((zelf.clone(),), vm)?;
1033-
let cursor = cursor.downcast::<Cursor>().map_err(|x| {
1034-
vm.new_type_error(format!("factory must return a cursor, not {}", x.class()))
1035-
})?;
1036-
let _ = unsafe { cursor.row_factory.swap(zelf.row_factory.to_owned()) };
1037-
cursor
1038-
} else {
1039-
let row_factory = zelf.row_factory.to_owned();
1040-
Cursor::new(zelf, row_factory, vm).into_ref(&vm.ctx)
1037+
let factory = match args.factory {
1038+
OptionalArg::Present(f) => f,
1039+
OptionalArg::Missing => Cursor::class(&vm.ctx).to_owned().into(),
10411040
};
1041+
1042+
let cursor = factory.call((zelf.clone(),), vm)?;
1043+
1044+
if !cursor.class().fast_issubclass(Cursor::class(&vm.ctx)) {
1045+
return Err(vm.new_type_error(format!(
1046+
"factory must return a cursor, not {}",
1047+
cursor.class()
1048+
)));
1049+
}
1050+
1051+
if let Some(cursor_ref) = cursor.downcast_ref::<Cursor>() {
1052+
let _ = unsafe { cursor_ref.row_factory.swap(zelf.row_factory.to_owned()) };
1053+
}
1054+
10421055
Ok(cursor)
10431056
}
10441057

0 commit comments

Comments
 (0)