Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
slot inheritance
  • Loading branch information
youknowone committed May 26, 2022
commit 9ee5ac768518f4c408ff118d5faca2074efebb79
18 changes: 3 additions & 15 deletions vm/src/builtins/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ impl PyType {
.map(|x| x.iter_mro().cloned().collect())
.collect();
let mro = linearise_mro(mros)?;
slots.inherits(&mro);
slots.inherits(&mro.iter().map(|t| -> &PyType { &t }).collect::<Vec<_>>());
debug_assert!(slots.hash.load().is_some(), "{}", name);

if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) {
slots.flags |= PyTypeFlags::HAS_DICT
Expand Down Expand Up @@ -142,19 +143,6 @@ impl PyType {
std::iter::once(self).chain(self.mro.iter().map(|cls| -> &PyType { cls }))
}

pub(crate) fn mro_find_map<F, R>(&self, f: F) -> Option<R>
where
F: Fn(&Self) -> Option<R>,
{
// the hot path will be primitive types which usually hit the result from itself.
// try std::intrinsics::likely once it is stablized
if let Some(r) = f(self) {
Some(r)
} else {
self.mro.iter().find_map(|cls| f(cls))
}
}

// This is used for class initialisation where the vm is not yet available.
pub fn set_str_attr<V: Into<PyObjectRef>>(&self, attr_name: &str, value: V) {
self._set_str_attr(attr_name, value.into())
Expand Down Expand Up @@ -717,7 +705,7 @@ impl Callable for PyType {
return Ok(obj);
}

if let Some(init_method) = obj.class().mro_find_map(|cls| cls.slots.init.load()) {
if let Some(init_method) = obj.class().slots.init.load() {
init_method(obj.clone(), args, vm)?;
}
Ok(obj)
Expand Down
2 changes: 1 addition & 1 deletion vm/src/function/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ where
let iterfn;
{
let cls = obj.class();
iterfn = cls.mro_find_map(|x| x.slots.iter.load());
iterfn = cls.slots.iter.load();
if iterfn.is_none() && !cls.has_attr("__getitem__") {
return Err(vm.new_type_error(format!("'{}' object is not iterable", cls.name())));
}
Expand Down
13 changes: 8 additions & 5 deletions vm/src/object/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ impl PyObject {
}

// CPython-compatible drop implementation
if let Some(slot_del) = self.class().mro_find_map(|cls| cls.slots.del.load()) {
if let Some(slot_del) = self.class().slots.del.load() {
call_slot_del(self, slot_del)?;
}
if let Some(wrl) = self.weak_ref_list() {
Expand Down Expand Up @@ -1081,22 +1081,25 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) {
static_assertions::assert_eq_size!(MaybeUninit<PyInner<PyType>>, PyInner<PyType>);
static_assertions::assert_eq_align!(MaybeUninit<PyInner<PyType>>, PyInner<PyType>);

let type_payload = PyType {
let object_payload = PyType {
base: None,
bases: vec![],
mro: vec![],
subclasses: PyRwLock::default(),
attributes: PyRwLock::new(Default::default()),
slots: PyType::make_slots(),
slots: object::PyBaseObject::make_slots(),
};
let object_payload = PyType {
let mut type_slots = PyType::make_slots();
type_slots.inherits(&[&object_payload]);
let type_payload = PyType {
base: None,
bases: vec![],
mro: vec![],
subclasses: PyRwLock::default(),
attributes: PyRwLock::new(Default::default()),
slots: object::PyBaseObject::make_slots(),
slots: type_slots,
};

let type_type_ptr = Box::into_raw(Box::new(partially_init!(
PyInner::<PyType> {
ref_count: RefCount::new(),
Expand Down
2 changes: 1 addition & 1 deletion vm/src/protocol/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl PyBuffer {
impl TryFromBorrowedObject for PyBuffer {
fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult<Self> {
let cls = obj.class();
if let Some(f) = cls.mro_find_map(|cls| cls.slots.as_buffer) {
if let Some(f) = cls.slots.as_buffer {
return f(obj, vm);
}
Err(vm.new_type_error(format!(
Expand Down
10 changes: 5 additions & 5 deletions vm/src/protocol/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ where

impl PyIter<PyObjectRef> {
pub fn check(obj: &PyObject) -> bool {
obj.class()
.mro_find_map(|x| x.slots.iternext.load())
.is_some()
obj.class().slots.iternext.load().is_some()
}
}

Expand All @@ -34,7 +32,9 @@ where
self.0
.borrow()
.class()
.mro_find_map(|x| x.slots.iternext.load())
.slots
.iternext
.load()
.ok_or_else(|| {
vm.new_type_error(format!(
"'{}' object is not an iterator",
Expand Down Expand Up @@ -120,7 +120,7 @@ impl TryFromObject for PyIter<PyObjectRef> {
fn try_from_object(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult<Self> {
let getiter = {
let cls = iter_target.class();
cls.mro_find_map(|x| x.slots.iter.load())
cls.slots.iter.load()
};
if let Some(getiter) = getiter {
let iter = getiter(iter_target, vm)?;
Expand Down
6 changes: 1 addition & 5 deletions vm/src/protocol/mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,7 @@ impl PyMapping<'_> {

pub fn methods(&self, vm: &VirtualMachine) -> &PyMappingMethods {
self.methods.get_or_init(|| {
if let Some(f) = self
.obj
.class()
.mro_find_map(|cls| cls.slots.as_mapping.load())
{
if let Some(f) = self.obj.class().slots.as_mapping.load() {
f(self.obj, vm)
} else {
PyMappingMethods::default()
Expand Down
38 changes: 14 additions & 24 deletions vm/src/protocol/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,7 @@ impl PyObject {
#[inline]
fn _get_attr(&self, attr_name: PyStrRef, vm: &VirtualMachine) -> PyResult {
vm_trace!("object.__getattribute__: {:?} {:?}", obj, attr_name);
let getattro = self
.class()
.mro_find_map(|cls| cls.slots.getattro.load())
.unwrap();
let getattro = self.class().slots.getattro.load().unwrap();
getattro(self, attr_name.clone(), vm).map_err(|exc| {
vm.set_attribute_error_context(&exc, self.to_owned(), attr_name);
exc
Expand All @@ -108,18 +105,17 @@ impl PyObject {
) -> PyResult<()> {
let setattro = {
let cls = self.class();
cls.mro_find_map(|cls| cls.slots.setattro.load())
.ok_or_else(|| {
let assign = attr_value.is_some();
let has_getattr = cls.mro_find_map(|cls| cls.slots.getattro.load()).is_some();
vm.new_type_error(format!(
"'{}' object has {} attributes ({} {})",
cls.name(),
if has_getattr { "only read-only" } else { "no" },
if assign { "assign to" } else { "del" },
attr_name
))
})?
cls.slots.setattro.load().ok_or_else(|| {
let assign = attr_value.is_some();
let has_getattr = cls.slots.getattro.load().is_some();
vm.new_type_error(format!(
"'{}' object has {} attributes ({} {})",
cls.name(),
if has_getattr { "only read-only" } else { "no" },
if assign { "assign to" } else { "del" },
attr_name
))
})?
};
setattro(self, attr_name, attr_value, vm)
}
Expand Down Expand Up @@ -251,10 +247,7 @@ impl PyObject {
) -> PyResult<Either<PyObjectRef, bool>> {
let swapped = op.swapped();
let call_cmp = |obj: &PyObject, other: &PyObject, op| {
let cmp = obj
.class()
.mro_find_map(|cls| cls.slots.richcompare.load())
.unwrap();
let cmp = obj.class().slots.richcompare.load().unwrap();
let r = match cmp(obj, other, op, vm)? {
Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A),
Either::B(arithmetic) => arithmetic.map(Either::B),
Expand Down Expand Up @@ -493,10 +486,7 @@ impl PyObject {
}

pub fn hash(&self, vm: &VirtualMachine) -> PyResult<PyHash> {
let hash = self
.class()
.mro_find_map(|cls| cls.slots.hash.load())
.unwrap(); // hash always exist
let hash = self.class().slots.hash.load().unwrap(); // hash always exist
hash(self, vm)
}

Expand Down
42 changes: 26 additions & 16 deletions vm/src/types/slot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,19 @@ impl PyTypeSlots {
}
}

pub fn inherits(&mut self, mro: &[PyTypeRef]) {
pub fn inherits(&mut self, mro: &[&PyType]) {
macro_rules! inherit {
($name:ident) => {
if self.$name.is_none() {
for ty in mro {
if let Some(func) = ty.slots.$name {
self.$name = Some(func);
break;
}
}
}
};
($name:ident, "atomic") => {
if self.$name.load().is_none() {
for ty in mro {
if let Some(func) = ty.slots.$name.load() {
Expand All @@ -98,21 +108,21 @@ impl PyTypeSlots {
}
};
}
inherit!(as_sequence);
inherit!(as_mapping);
inherit!(hash);
inherit!(call);
inherit!(getattro);
inherit!(setattro);
// inherit!(as_buffer);
inherit!(richcompare);
inherit!(iter);
inherit!(iternext);
inherit!(descr_get);
inherit!(descr_set);
inherit!(init);
inherit!(new);
inherit!(del);
inherit!(as_sequence, "atomic");
inherit!(as_mapping, "atomic");
inherit!(hash, "atomic");
inherit!(call, "atomic");
inherit!(getattro, "atomic");
inherit!(setattro, "atomic");
inherit!(as_buffer);
inherit!(richcompare, "atomic");
inherit!(iter, "atomic");
inherit!(iternext, "atomic");
inherit!(descr_get, "atomic");
inherit!(descr_set, "atomic");
inherit!(init, "atomic");
inherit!(new, "atomic");
inherit!(del, "atomic");
}
}

Expand Down
2 changes: 1 addition & 1 deletion vm/src/vm/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub enum PyMethod {
impl PyMethod {
pub fn get(obj: PyObjectRef, name: PyStrRef, vm: &VirtualMachine) -> PyResult<Self> {
let cls = obj.class();
let getattro = cls.mro_find_map(|cls| cls.slots.getattro.load()).unwrap();
let getattro = cls.slots.getattro.load().unwrap();
if getattro as usize != PyBaseObject::getattro as usize {
drop(cls);
return obj.get_attr(name, vm).map(Self::Attribute);
Expand Down
4 changes: 1 addition & 3 deletions vm/src/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,7 @@ impl VirtualMachine {
}

pub fn is_callable(&self, obj: &PyObject) -> bool {
obj.class()
.mro_find_map(|cls| cls.slots.call.load())
.is_some()
obj.class().slots.call.load().is_some()
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion vm/src/vm/vm_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl VirtualMachine {

fn _invoke(&self, callable: &PyObject, args: FuncArgs) -> PyResult {
vm_trace!("Invoke: {:?} {:?}", callable, args);
let slot_call = callable.class().mro_find_map(|cls| cls.slots.call.load());
let slot_call = callable.class().slots.call.load();
match slot_call {
Some(slot_call) => {
self.trace_event(TraceEvent::Call)?;
Expand Down