Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_sqlite3/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,6 @@ def test_null_character(self):
self.assertRaisesRegex(sqlite.ProgrammingError, "null char",
cur.execute, query)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_surrogates(self):
con = sqlite.connect(":memory:")
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
Expand Down
9 changes: 5 additions & 4 deletions stdlib/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ mod _sqlite {
type Args = (PyStrRef,);

fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
if let Some(stmt) = Statement::new(zelf, &args.0, vm)? {
if let Some(stmt) = Statement::new(zelf, args.0, vm)? {
Ok(stmt.into_ref(&vm.ctx).into())
} else {
Ok(vm.ctx.none())
Expand Down Expand Up @@ -1480,7 +1480,7 @@ mod _sqlite {
stmt.lock().reset();
}

let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else {
let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else {
drop(inner);
return Ok(zelf);
};
Expand Down Expand Up @@ -1552,7 +1552,7 @@ mod _sqlite {
stmt.lock().reset();
}

let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else {
let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else {
drop(inner);
return Ok(zelf);
};
Expand Down Expand Up @@ -2291,9 +2291,10 @@ mod _sqlite {
impl Statement {
fn new(
connection: &Connection,
sql: &PyStr,
sql: PyStrRef,
vm: &VirtualMachine,
) -> PyResult<Option<Self>> {
let sql = sql.try_into_utf8(vm)?;
let sql_cstr = sql.to_cstring(vm)?;
let sql_len = sql.byte_len() + 1;

Expand Down
48 changes: 42 additions & 6 deletions vm/src/builtins/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ use rustpython_common::{
str::DeduceStrKind,
wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk},
};
use std::sync::LazyLock;
use std::{borrow::Cow, char, fmt, ops::Range};
use std::{mem, sync::LazyLock};
use unic_ucd_bidi::BidiClass;
use unic_ucd_category::GeneralCategory;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
Expand Down Expand Up @@ -80,6 +80,29 @@ impl fmt::Debug for PyStr {
}
}

#[repr(transparent)]
#[derive(Debug)]
pub struct PyUtf8Str(PyStr);

impl std::ops::Deref for PyUtf8Str {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
impl std::ops::Deref for PyUtf8Str {
// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str
impl std::ops::Deref for PyUtf8Str {

type Target = PyStr;
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl PyUtf8Str {
/// Returns the underlying string slice. This is safe because the
/// type invariant guarantees UTF-8 validity.
pub fn as_str(&self) -> &str {
debug_assert!(
self.0.is_utf8(),
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
);
unsafe { self.0.to_str().unwrap_unchecked() }
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The safety is trivial for users, but not for devs

Suggested change
/// Returns the underlying string slice. This is safe because the
/// type invariant guarantees UTF-8 validity.
pub fn as_str(&self) -> &str {
debug_assert!(
self.0.is_utf8(),
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
);
unsafe { self.0.to_str().unwrap_unchecked() }
}
/// Returns the underlying string slice.
pub fn as_str(&self) -> &str {
debug_assert!(
self.0.is_utf8(),
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
);
// Safety: This is safe because the type invariant guarantees UTF-8 validity.
unsafe { self.0.to_str().unwrap_unchecked() }
}

}

impl AsRef<str> for PyStr {
#[track_caller] // <- can remove this once it doesn't panic
fn as_ref(&self) -> &str {
Expand Down Expand Up @@ -433,21 +456,29 @@ impl PyStr {
self.data.as_str()
}

pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
self.to_str().ok_or_else(|| {
fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.is_utf8() {
Ok(())
} else {
let start = self
.as_wtf8()
.code_points()
.position(|c| c.to_char().is_none())
.unwrap();
vm.new_unicode_encode_error_real(
Err(vm.new_unicode_encode_error_real(
identifier!(vm, utf_8).to_owned(),
vm.ctx.new_str(self.data.clone()),
start,
start + 1,
vm.ctx.new_str("surrogates not allowed"),
)
})
))
}
}

pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
self.ensure_valid_utf8(vm)?;
// SAFETY: ensure_valid_utf8 passed, so unwrap is safe.
Ok(unsafe { self.to_str().unwrap_unchecked() })
}

pub fn to_string_lossy(&self) -> Cow<'_, str> {
Expand Down Expand Up @@ -1486,6 +1517,11 @@ impl PyStrRef {
s.push_wtf8(other);
*self = PyStr::from(s).into_ref(&vm.ctx);
}

pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult<PyRef<PyUtf8Str>> {
self.ensure_valid_utf8(vm)?;
Ok(unsafe { mem::transmute::<PyRef<PyStr>, PyRef<PyUtf8Str>>(self) })
}
}

impl Representable for PyStr {
Expand Down
Loading