Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_sqlite3/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,6 @@ def test_connection_reinit(self):
self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
self.assertEqual([r[0] for r in rows], ["2", "3"])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_connection_bad_reinit(self):
cx = sqlite.connect(":memory:")
with cx:
Expand Down
88 changes: 64 additions & 24 deletions crates/stdlib/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,10 +833,11 @@ mod _sqlite {
#[derive(PyPayload)]
struct Connection {
db: PyMutex<Option<Sqlite>>,
detect_types: c_int,
initialized: PyAtomic<bool>,
detect_types: PyAtomic<c_int>,
isolation_level: PyAtomicRef<Option<PyStr>>,
check_same_thread: bool,
thread_ident: ThreadId,
check_same_thread: PyAtomic<bool>,
thread_ident: PyMutex<ThreadId>, // TODO: Use atomic
row_factory: PyAtomicRef<Option<PyObject>>,
text_factory: PyAtomicRef<PyObject>,
}
Expand Down Expand Up @@ -865,12 +866,15 @@ mod _sqlite {
None
};

let initialized = db.is_some();

let conn = Self {
db: PyMutex::new(db),
detect_types: args.detect_types,
initialized: Radium::new(initialized),
detect_types: Radium::new(args.detect_types),
isolation_level: PyAtomicRef::from(args.isolation_level),
check_same_thread: args.check_same_thread,
thread_ident: std::thread::current().id(),
check_same_thread: Radium::new(args.check_same_thread),
thread_ident: PyMutex::new(std::thread::current().id()),
row_factory: PyAtomicRef::from(None),
text_factory: PyAtomicRef::from(text_factory),
};
Expand Down Expand Up @@ -899,20 +903,51 @@ mod _sqlite {
type Args = ConnectArgs;

fn init(zelf: PyRef<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
let mut guard = zelf.db.lock();
if guard.is_some() {
// Already initialized
return Ok(());
let was_initialized = Radium::swap(&zelf.initialized, false, Ordering::Relaxed);

// Reset factories to their defaults, matching CPython's behavior.
zelf.reset_factories(vm);

if was_initialized {
zelf.drop_db();
}

// Attempt to open the new database before mutating other state so failures leave
// the connection uninitialized (and subsequent operations raise ProgrammingError).
let db = Self::initialize_db(&args, vm)?;

let ConnectArgs {
detect_types,
isolation_level,
check_same_thread,
..
} = args;

zelf.detect_types.store(detect_types, Ordering::Relaxed);
zelf.check_same_thread
.store(check_same_thread, Ordering::Relaxed);
*zelf.thread_ident.lock() = std::thread::current().id();
let _ = unsafe { zelf.isolation_level.swap(isolation_level) };

let mut guard = zelf.db.lock();
*guard = Some(db);
Radium::store(&zelf.initialized, true, Ordering::Relaxed);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
Ok(())
}
}

#[pyclass(with(Constructor, Callable, Initializer), flags(BASETYPE))]
impl Connection {
fn drop_db(&self) {
self.db.lock().take();
}

fn reset_factories(&self, vm: &VirtualMachine) {
let default_text_factory = PyStr::class(&vm.ctx).to_owned().into_object();
let _ = unsafe { self.row_factory.swap(None) };
let _ = unsafe { self.text_factory.swap(default_text_factory) };
}

fn initialize_db(args: &ConnectArgs, vm: &VirtualMachine) -> PyResult<Sqlite> {
let path = args.database.to_cstring(vm)?;
let db = Sqlite::from(SqliteRaw::open(path.as_ptr(), args.uri, vm)?);
Expand Down Expand Up @@ -1003,7 +1038,7 @@ mod _sqlite {
#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
self.check_thread(vm)?;
self.db.lock().take();
self.drop_db();
Ok(())
}

Expand Down Expand Up @@ -1446,15 +1481,17 @@ mod _sqlite {
}

fn check_thread(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.check_same_thread && (std::thread::current().id() != self.thread_ident) {
Err(new_programming_error(
vm,
"SQLite objects created in a thread can only be used in that same thread."
.to_owned(),
))
} else {
Ok(())
if self.check_same_thread.load(Ordering::Relaxed) {
let creator_id = *self.thread_ident.lock();
if std::thread::current().id() != creator_id {
return Err(new_programming_error(
vm,
"SQLite objects created in a thread can only be used in that same thread."
.to_owned(),
));
}
}
Ok(())
}

#[pygetset]
Expand Down Expand Up @@ -1628,7 +1665,8 @@ mod _sqlite {

inner.row_cast_map = zelf.build_row_cast_map(&st, vm)?;

inner.description = st.columns_description(zelf.connection.detect_types, vm)?;
let detect_types = zelf.connection.detect_types.load(Ordering::Relaxed);
inner.description = st.columns_description(detect_types, vm)?;

if ret == SQLITE_ROW {
drop(st);
Expand Down Expand Up @@ -1676,7 +1714,8 @@ mod _sqlite {
));
}

inner.description = st.columns_description(zelf.connection.detect_types, vm)?;
let detect_types = zelf.connection.detect_types.load(Ordering::Relaxed);
inner.description = st.columns_description(detect_types, vm)?;

inner.rowcount = if stmt.is_dml { 0 } else { -1 };

Expand Down Expand Up @@ -1841,15 +1880,16 @@ mod _sqlite {
st: &SqliteStatementRaw,
vm: &VirtualMachine,
) -> PyResult<Vec<Option<PyObjectRef>>> {
if self.connection.detect_types == 0 {
let detect_types = self.connection.detect_types.load(Ordering::Relaxed);
if detect_types == 0 {
return Ok(vec![]);
}

let mut cast_map = vec![];
let num_cols = st.column_count();

for i in 0..num_cols {
if self.connection.detect_types & PARSE_COLNAMES != 0 {
if detect_types & PARSE_COLNAMES != 0 {
let col_name = st.column_name(i);
let col_name = ptr_to_str(col_name, vm)?;
let col_name = col_name
Expand All @@ -1864,7 +1904,7 @@ mod _sqlite {
continue;
}
}
if self.connection.detect_types & PARSE_DECLTYPES != 0 {
if detect_types & PARSE_DECLTYPES != 0 {
let decltype = st.column_decltype(i);
let decltype = ptr_to_str(decltype, vm)?;
if let Some(decltype) = decltype.split_terminator(&[' ', '(']).next() {
Expand Down
Loading