Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Lib/test/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,7 +2028,6 @@ def insert(self, index, value):
self.assertEqual(len(mss), len(mss2))
self.assertEqual(list(mss), list(mss2))

@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: TypeError not raised
def test_illegal_patma_flags(self):
with self.assertRaises(TypeError):
class Both(Collection):
Expand Down
48 changes: 35 additions & 13 deletions crates/vm/src/builtins/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,38 +604,60 @@ impl PyType {
attrs: &PyAttributes,
bases: &[PyRef<Self>],
ctx: &Context,
) {
) -> Result<(), String> {
const COLLECTION_FLAGS: PyTypeFlags = PyTypeFlags::from_bits_truncate(
PyTypeFlags::SEQUENCE.bits() | PyTypeFlags::MAPPING.bits(),
);

// Don't override if flags are already set
if slots.flags.intersects(COLLECTION_FLAGS) {
return;
}

// First check in our own attributes
// Always validate this class's own __abc_tpflags__ even when slot
// flags were already inherited, otherwise a child setting both
// Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING would slip through.
let abc_tpflags_name = ctx.intern_str("__abc_tpflags__");
if let Some(abc_tpflags_obj) = attrs.get(abc_tpflags_name)
&& let Some(int_obj) = abc_tpflags_obj.downcast_ref::<crate::builtins::int::PyInt>()
{
let flags_val = int_obj.as_bigint().to_i64().unwrap_or(0);
let abc_flags = PyTypeFlags::from_bits_truncate(flags_val as u64);
slots.flags |= abc_flags & COLLECTION_FLAGS;
return;
let masked = abc_flags & COLLECTION_FLAGS;
if masked == COLLECTION_FLAGS {
return Err(
"__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING"
.to_owned(),
);
}
// Don't override flags already inherited from a base class.
if !slots.flags.intersects(COLLECTION_FLAGS) {
slots.flags |= masked;
}
return Ok(());
}

// No __abc_tpflags__ on this class — inheritance already happened
// in inherit_patma_flags, so nothing more to do if those bits are set.
if slots.flags.intersects(COLLECTION_FLAGS) {
return Ok(());
}

// Then check in base classes
// Then check in base classes (legacy path for cases that bypass
// inherit_patma_flags).
for base in bases {
if let Some(abc_tpflags_obj) = base.find_name_in_mro(abc_tpflags_name)
&& let Some(int_obj) = abc_tpflags_obj.downcast_ref::<crate::builtins::int::PyInt>()
{
let flags_val = int_obj.as_bigint().to_i64().unwrap_or(0);
let abc_flags = PyTypeFlags::from_bits_truncate(flags_val as u64);
slots.flags |= abc_flags & COLLECTION_FLAGS;
return;
let masked = abc_flags & COLLECTION_FLAGS;
if masked == COLLECTION_FLAGS {
return Err(
"__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING"
.to_owned(),
);
}
slots.flags |= masked;
return Ok(());
}
}
Ok(())
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -671,7 +693,7 @@ impl PyType {
Self::inherit_patma_flags(&mut slots, &bases);

// Check for __abc_tpflags__ from ABCMeta (for collections.abc.Sequence, Mapping, etc.)
Self::check_abc_tpflags(&mut slots, &attrs, &bases, ctx);
Self::check_abc_tpflags(&mut slots, &attrs, &bases, ctx)?;

if slots.basicsize == 0 {
slots.basicsize = base.slots.basicsize;
Expand Down
Loading