Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
downcastable_from
  • Loading branch information
youknowone committed Jul 30, 2025
commit 053cfeecce89148cded2106d9eb9fe9f71699139
264 changes: 215 additions & 49 deletions vm/src/builtins/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
format::{format, format_map},
function::{ArgIterable, ArgSize, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue},
intern::PyInterned,
object::{Traverse, TraverseFn},
object::{MaybeTraverse, Traverse, TraverseFn},
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
sequence::SequenceExt,
sliceable::{SequenceIndex, SliceableSequenceOp},
Expand Down Expand Up @@ -64,6 +64,9 @@ impl<'a> TryFromBorrowedObject<'a> for &'a Wtf8 {
}
}

/// Shorthand for a `PyRef` whose payload is a [`PyStr`].
pub type PyStrRef = PyRef<PyStr>;
/// Shorthand for a `PyRef` whose payload is a [`PyUtf8Str`].
pub type PyUtf8StrRef = PyRef<PyUtf8Str>;

#[pyclass(module = false, name = "str")]
pub struct PyStr {
data: StrData,
Expand All @@ -80,30 +83,6 @@ impl fmt::Debug for PyStr {
}
}

/// Newtype over [`PyStr`] with the same layout (`#[repr(transparent)]`).
/// Its `as_str` accessor treats "the wrapped string is valid UTF-8" as a type invariant:
/// it asserts it in debug builds and converts without checking.
#[repr(transparent)]
#[derive(Debug)]
pub struct PyUtf8Str(PyStr);

// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str
impl std::ops::Deref for PyUtf8Str {
type Target = PyStr;
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl PyUtf8Str {
/// Returns the underlying string slice.
pub fn as_str(&self) -> &str {
debug_assert!(
self.0.is_utf8(),
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
);
// Safety: This is safe because the type invariant guarantees UTF-8 validity.
unsafe { self.0.to_str().unwrap_unchecked() }
}
}

impl AsRef<str> for PyStr {
#[track_caller] // <- can remove this once it doesn't panic
fn as_ref(&self) -> &str {
Expand Down Expand Up @@ -241,8 +220,6 @@ impl Default for PyStr {
}
}

pub type PyStrRef = PyRef<PyStr>;

impl fmt::Display for PyStr {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -374,7 +351,7 @@ impl Constructor for PyStr {
type Args = StrArgs;

fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
let string: PyStrRef = match args.object {
let string: PyRef<PyStr> = match args.object {
OptionalArg::Present(input) => {
if let OptionalArg::Present(enc) = args.encoding {
vm.state.codec_registry.decode_text(
Expand Down Expand Up @@ -458,7 +435,7 @@ impl PyStr {
self.data.as_str()
}

pub fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.is_utf8() {
Ok(())
} else {
Expand Down Expand Up @@ -531,6 +508,22 @@ impl PyStr {
.mul(vm, value)
.map(|x| Self::from(unsafe { Wtf8Buf::from_bytes_unchecked(x) }).into_ref(&vm.ctx))
}

/// Borrows this string as a [`PyUtf8Str`] when `ensure_valid_utf8` succeeds;
/// otherwise the error produced by `ensure_valid_utf8` is returned.
pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a PyUtf8Str> {
// Reject strings that are not valid UTF-8 (e.g. ones containing surrogates).
self.ensure_valid_utf8(vm)?;
// SAFETY: `PyUtf8Str` is a `#[repr(transparent)]` wrapper around `PyStr`, so the
// reference can be reinterpreted; the UTF-8 invariant was established just above.
Ok(unsafe { &*(self as *const _ as *const PyUtf8Str) })
}
}

impl Py<PyStr> {
/// Reinterprets this `Py<PyStr>` as a `Py<PyUtf8Str>` when `ensure_valid_utf8` succeeds;
/// otherwise the error produced by `ensure_valid_utf8` is returned.
pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a Py<PyUtf8Str>> {
// Reject strings that are not valid UTF-8 (e.g. ones containing surrogates).
self.ensure_valid_utf8(vm)?;
// The check passed, so the reference is reinterpreted as `Py<PyUtf8Str>`.
// NOTE(review): soundness assumes `Py<PyStr>` and `Py<PyUtf8Str>` share a layout
// (`PyUtf8Str` is a `#[repr(transparent)]` wrapper around `PyStr`) — confirm for `Py<T>`.
Ok(unsafe { &*(self as *const _ as *const Py<PyUtf8Str>) })
}
}

#[pyclass(
Expand Down Expand Up @@ -980,7 +973,11 @@ impl PyStr {
}

#[pymethod(name = "__format__")]
fn __format__(zelf: PyRef<Self>, spec: PyStrRef, vm: &VirtualMachine) -> PyResult<PyStrRef> {
fn __format__(
zelf: PyRef<PyStr>,
spec: PyStrRef,
vm: &VirtualMachine,
) -> PyResult<PyRef<PyStr>> {
let spec = spec.as_str();
if spec.is_empty() {
return if zelf.class().is(vm.ctx.types.str_type) {
Expand All @@ -989,7 +986,7 @@ impl PyStr {
zelf.as_object().str(vm)
};
}

let zelf = zelf.try_into_utf8(vm)?;
let s = FormatSpec::parse(spec)
.and_then(|format_spec| {
format_spec.format_string(&CharLenStr(zelf.as_str(), zelf.char_len()))
Expand Down Expand Up @@ -1351,8 +1348,12 @@ impl PyStr {
}

#[pymethod]
fn expandtabs(&self, args: anystr::ExpandTabsArgs) -> String {
rustpython_common::str::expandtabs(self.as_str(), args.tabsize())
fn expandtabs(&self, args: anystr::ExpandTabsArgs, vm: &VirtualMachine) -> PyResult<String> {
// TODO: support WTF-8
Ok(rustpython_common::str::expandtabs(
self.try_as_utf8(vm)?.as_str(),
args.tabsize(),
))
}

#[pymethod]
Expand Down Expand Up @@ -1480,20 +1481,6 @@ impl PyStr {
}
}

/// Pairs a borrowed string slice (field 0) with a character count supplied by the caller (field 1).
struct CharLenStr<'a>(&'a str, usize);
impl std::ops::Deref for CharLenStr<'_> {
type Target = str;

// Dereferences to the wrapped slice.
fn deref(&self) -> &Self::Target {
self.0
}
}
impl crate::common::format::CharLen for CharLenStr<'_> {
// Returns the stored count as-is; the slice itself is not inspected.
fn char_len(&self) -> usize {
self.1
}
}

#[pyclass]
impl PyRef<PyStr> {
#[pymethod]
Expand All @@ -1504,7 +1491,7 @@ impl PyRef<PyStr> {
}
}

impl PyStrRef {
impl PyRef<PyStr> {
pub fn is_empty(&self) -> bool {
(**self).is_empty()
}
Expand All @@ -1526,6 +1513,20 @@ impl PyStrRef {
}
}

/// A borrowed string slice bundled with a character count provided by whoever builds it.
struct CharLenStr<'a>(&'a str, usize);

impl std::ops::Deref for CharLenStr<'_> {
    type Target = str;

    /// Hands out the wrapped slice.
    fn deref(&self) -> &str {
        let Self(text, _) = *self;
        text
    }
}

impl crate::common::format::CharLen for CharLenStr<'_> {
    /// Reports the stored count without looking at the slice.
    fn char_len(&self) -> usize {
        let Self(_, count) = *self;
        count
    }
}

impl Representable for PyStr {
#[inline]
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
Expand Down Expand Up @@ -1941,6 +1942,170 @@ impl AnyStrWrapper<AsciiStr> for PyStrRef {
}
}

/// Newtype over [`PyStr`] with the same layout (`#[repr(transparent)]`).
/// Its `as_str` accessor relies on the wrapped string being valid UTF-8, and its unchecked
/// constructor requires callers to supply valid UTF-8 data.
#[repr(transparent)]
#[derive(Debug)]
pub struct PyUtf8Str(PyStr);

impl fmt::Display for PyUtf8Str {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

impl MaybeTraverse for PyUtf8Str {
    /// Forwards the traversal callback to the wrapped `PyStr`.
    fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>) {
        let Self(inner) = self;
        inner.try_traverse(traverse_fn);
    }
}

impl PyPayload for PyUtf8Str {
    /// `PyUtf8Str` reports the context's `str_type` as its class.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.str_type
    }

    /// Uses `PyStr`'s `TypeId`, so the payload id is the one `PyStr` has.
    fn payload_type_id() -> std::any::TypeId {
        std::any::TypeId::of::<PyStr>()
    }

    /// True only when `obj`'s type id equals `payload_type_id()` and the `PyStr` it
    /// holds reports itself as UTF-8.
    fn downcastable_from(obj: &PyObject) -> bool {
        if obj.typeid() != Self::payload_type_id() {
            return false;
        }
        // SAFETY: the type id just matched `PyStr`'s, so the object is a `PyStr` here.
        let as_pystr = unsafe { obj.downcast_unchecked_ref::<PyStr>() };
        as_pystr.is_utf8()
    }

    /// Fails if `obj` cannot be borrowed as a `PyStr`, or if that string does not pass
    /// `ensure_valid_utf8`.
    fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> {
        obj.try_downcast_ref::<PyStr>(vm)?.ensure_valid_utf8(vm)
    }
}

// Conversions into `PyUtf8Str`. Borrowed and growable inputs are first turned into their
// boxed/owned form; the boxed forms and `char` build the `StrData` directly.

impl From<&AsciiStr> for PyUtf8Str {
    /// Copies the borrowed ASCII slice, then converts the owned `AsciiString`.
    fn from(s: &AsciiStr) -> Self {
        Self::from(s.to_owned())
    }
}

impl From<AsciiString> for PyUtf8Str {
    /// Boxes the ASCII string, then converts the `Box<AsciiStr>`.
    fn from(s: AsciiString) -> Self {
        Self::from(s.into_boxed_ascii_str())
    }
}

impl From<Box<AsciiStr>> for PyUtf8Str {
    fn from(s: Box<AsciiStr>) -> Self {
        let str_data = StrData::from(s);
        // SAFETY: the data came from an `AsciiStr`, and ASCII text is valid UTF-8.
        unsafe { Self::from_str_data_unchecked(str_data) }
    }
}

impl From<AsciiChar> for PyUtf8Str {
    /// Wraps the single ASCII character in an `AsciiString`, then converts that.
    fn from(ch: AsciiChar) -> Self {
        Self::from(AsciiString::from(ch))
    }
}

impl From<&str> for PyUtf8Str {
    /// Copies the borrowed slice, then converts the owned `String`.
    fn from(s: &str) -> Self {
        Self::from(s.to_owned())
    }
}

impl From<String> for PyUtf8Str {
    /// Boxes the string, then converts the `Box<str>`.
    fn from(s: String) -> Self {
        Self::from(s.into_boxed_str())
    }
}

impl From<char> for PyUtf8Str {
    fn from(ch: char) -> Self {
        let str_data = StrData::from(ch);
        // SAFETY: a Rust `char` is a Unicode scalar value, so it encodes as valid UTF-8.
        unsafe { Self::from_str_data_unchecked(str_data) }
    }
}

impl From<std::borrow::Cow<'_, str>> for PyUtf8Str {
    /// Takes ownership of the text (cloning if it was borrowed), then converts the `String`.
    fn from(s: std::borrow::Cow<'_, str>) -> Self {
        Self::from(s.into_owned())
    }
}

impl From<Box<str>> for PyUtf8Str {
    #[inline]
    fn from(value: Box<str>) -> Self {
        let str_data = StrData::from(value);
        // SAFETY: the data came from a `str`, which is always valid UTF-8.
        unsafe { Self::from_str_data_unchecked(str_data) }
    }
}

impl AsRef<Wtf8> for PyUtf8Str {
#[inline]
fn as_ref(&self) -> &Wtf8 {
self.0.as_wtf8()
}
}

impl AsRef<str> for PyUtf8Str {
#[inline]
fn as_ref(&self) -> &str {
self.0.as_str()
}
}

impl PyUtf8Str {
/// Creates a new `PyUtf8Str` from `StrData` without validation.
///
/// This function must only be used in this module to implement conversions.
///
/// # Safety
///
/// `data` must hold valid UTF-8 string data.
unsafe fn from_str_data_unchecked(data: StrData) -> Self {
Self(PyStr::from(data))
}

/// Returns the underlying string slice.
///
/// Debug builds assert that the wrapped `PyStr` reports itself as UTF-8; the conversion
/// itself is performed without a check.
pub fn as_str(&self) -> &str {
debug_assert!(
self.0.is_utf8(),
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
);
// SAFETY: the type invariant guarantees UTF-8 validity, so `to_str` always succeeds here.
unsafe { self.0.to_str().unwrap_unchecked() }
}

/// Forwards to `byte_len` on the wrapped `PyStr`.
#[inline]
pub fn byte_len(&self) -> usize {
self.0.byte_len()
}

/// Forwards to `is_empty` on the wrapped `PyStr`.
#[inline]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

/// Forwards to `char_len` on the wrapped `PyStr`.
#[inline]
pub fn char_len(&self) -> usize {
self.0.char_len()
}
}

impl Py<PyUtf8Str> {
    /// Upcasts to `Py<PyStr>` by reinterpreting the reference; nothing is copied.
    pub fn as_pystr(&self) -> &Py<PyStr> {
        let as_pystr_ptr = (self as *const Self).cast::<Py<PyStr>>();
        // SAFETY: `PyUtf8Str` is a wrapper around `PyStr`, so this cast is safe.
        unsafe { &*as_pystr_ptr }
    }
}

impl PartialEq for PyUtf8Str {
    /// Two values are equal exactly when their `as_str` views are equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_str(), other.as_str());
        lhs == rhs
    }
}

impl Eq for PyUtf8Str {}

impl AnyStrContainer<str> for String {
fn new() -> Self {
Self::new()
Expand Down Expand Up @@ -2302,7 +2467,8 @@ impl std::fmt::Display for PyStrInterned {
impl AsRef<str> for PyStrInterned {
#[inline(always)]
fn as_ref(&self) -> &str {
self.as_str()
self.to_str()
.expect("Interned PyStr should always be valid UTF-8")
}
Comment on lines +2464 to 2466
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify UTF-8 validity is enforced at string interning time

The expect assumes interned strings are always valid UTF-8. This invariant should be enforced when strings are interned.


🏁 Script executed:

#!/bin/bash
# Search for string interning implementation to verify UTF-8 validation
ast-grep --pattern 'impl $_ {
  $$$
  fn intern($$$) {
    $$$
  }
  $$$
}'

# Also search for intern_str or similar methods
rg -A 10 "intern_str|new_interned" --type rust

Length of output: 35264


🏁 Script executed:

#!/bin/bash
# Locate string_pool usage in Context
rg -n "string_pool" vm/src/vm/context.rs

# Show beginning of context.rs to inspect field definitions
sed -n '1,200p' vm/src/vm/context.rs

# Search for StringPool type and its implementation
rg -n "struct StringPool" -n vm/src/vm
rg -n "impl StringPool" -n vm/src/vm

Length of output: 5090


🏁 Script executed:

#!/bin/bash
# Locate the intern module file
rg --files-with-matches "InternableString" -t rust

# Inspect its contents to find StringPool implementation
FILE=$(rg --files-with-matches "InternableString" -t rust | head -n1)
echo "=== Showing top of $FILE ==="
sed -n '1,200p' "$FILE"

# Search for StringPool definition and intern method
rg -n "pub struct StringPool" -n "$FILE"
rg -n "impl StringPool" -n "$FILE"

Length of output: 5331


🏁 Script executed:

#!/bin/bash
# Find InternableString and MaybeInternedString definitions in intern.rs
rg -n "trait InternableString" -n vm/src/intern.rs
rg -n "trait MaybeInternedString" -n vm/src/intern.rs

# Show the definitions around those lines
sed -n '250,350p' vm/src/intern.rs

Length of output: 2393


Enforce UTF-8 validity when interning strings
The StringPool::intern implementation accepts Wtf8Buf (which may contain unpaired surrogates) without any check, so the later

self.to_str().expect("Interned PyStr should always be valid UTF-8")

in vm/src/builtins/str.rs can still panic. We need to validate once at interning time to uphold the invariant.

• In vm/src/intern.rs, inside impl StringPool { pub unsafe fn intern<…> }, right after you get back the PyRefExact<PyStr> (e.g. before calling miss), add:

     let str_ref = s.into_pyref_exact(typ);
+    // Validate that the underlying WTF-8 data is valid UTF-8 before inserting
+    let w = str_ref.as_wtf8();
+    debug_assert!(
+        w.to_str().is_some(),
+        "Tried to intern invalid UTF-8 string: {:?}",
+        w
+    );

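In debug builds this surfaces a violation at interning time (note that debug_assert! is compiled out of release builds), so that any panic in to_str().expect(…) reflects a broken invariant rather than a late check.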

🤖 Prompt for AI Agents
In vm/src/intern.rs inside the unsafe fn intern method, after obtaining the
PyRefExact<PyStr> and before calling miss, add a validation step to check that
the interned string is valid UTF-8. This ensures the invariant that interned
PyStr is always valid UTF-8 is enforced at interning time, preventing the later
to_str().expect(...) call in vm/src/builtins/str.rs from panicking unexpectedly.
Implement this by attempting to convert the interned string to &str and handling
any invalid UTF-8 by panicking or returning an error immediately.

}

Expand Down
10 changes: 5 additions & 5 deletions vm/src/convert/try_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ where
#[inline]
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
let class = T::class(&vm.ctx);
let result = if obj.fast_isinstance(class) {
obj.downcast()
if obj.fast_isinstance(class) {
T::try_downcast_from(&obj, vm)?;
Ok(unsafe { obj.downcast_unchecked() })
} else {
Err(obj)
};
result.map_err(|obj| vm.new_downcast_type_error(class, &obj))
Err(vm.new_downcast_type_error(class, &obj))
}
}
}

Expand Down
Loading