From e66b5078a887a622e48ced73dfaa8391bd1bdea1 Mon Sep 17 00:00:00 2001 From: Adam Kelly Date: Wed, 20 Mar 2019 14:34:41 +0000 Subject: [PATCH] Introduce TryIntoRef to make vm.get_attribute more usable. --- vm/src/builtins.rs | 7 +++---- vm/src/exceptions.rs | 8 +++----- vm/src/frame.rs | 6 +++--- vm/src/obj/objstr.rs | 18 ++++++++++++++++-- vm/src/pyobject.rs | 21 +++++++++++++++++++++ vm/src/stdlib/dis.rs | 3 +-- vm/src/vm.rs | 12 ++++++++---- 7 files changed, 55 insertions(+), 20 deletions(-) diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index eb06446f42f..55aa2a75ecb 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -310,7 +310,7 @@ fn builtin_getattr( default: OptionalArg, vm: &mut VirtualMachine, ) -> PyResult { - let ret = vm.get_attribute(obj.clone(), attr.into_object()); + let ret = vm.get_attribute(obj.clone(), attr); if let OptionalArg::Present(default) = default { ret.or_else(|ex| catch_attr_exception(ex, default, vm)) } else { @@ -323,7 +323,7 @@ fn builtin_globals(vm: &mut VirtualMachine, _args: PyFuncArgs) -> PyResult { } fn builtin_hasattr(obj: PyObjectRef, attr: PyStringRef, vm: &mut VirtualMachine) -> PyResult { - if let Err(ex) = vm.get_attribute(obj.clone(), attr.into_object()) { + if let Err(ex) = vm.get_attribute(obj.clone(), attr) { catch_attr_exception(ex, false, vm) } else { Ok(true) @@ -832,8 +832,7 @@ pub fn builtin_build_class_(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> Py let bases = vm.context().new_tuple(bases); // Prepare uses full __getattribute__ resolution chain. - let prepare_name = vm.new_str("__prepare__".to_string()); - let prepare = vm.get_attribute(metaclass.clone(), prepare_name)?; + let prepare = vm.get_attribute(metaclass.clone(), "__prepare__")?; let namespace = vm.invoke(prepare, vec![name_arg.clone(), bases.clone()])?; let cells = vm.new_dict(); diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index a4cd9144c71..68f4105c5eb 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -1,9 +1,7 @@ use crate::function::PyFuncArgs; use crate::obj::objsequence; use crate::obj::objtype; -use crate::pyobject::{ - create_type, AttributeProtocol, PyContext, PyObjectRef, PyResult, TypeProtocol, -}; +use crate::pyobject::{create_type, PyContext, PyObjectRef, PyResult, TypeProtocol}; use crate::vm::VirtualMachine; fn exception_init(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -21,7 +19,7 @@ fn exception_init(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { // Print exception including traceback: pub fn print_exception(vm: &mut VirtualMachine, exc: &PyObjectRef) { - if let Some(tb) = exc.get_attr("__traceback__") { + if let Ok(tb) = vm.get_attribute(exc.clone(), "__traceback__") { println!("Traceback (most recent call last):"); if objtype::isinstance(&tb, &vm.ctx.list_type()) { let mut elements = objsequence::get_elements(&tb).to_vec(); @@ -70,7 +68,7 @@ fn exception_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { required = [(exc, Some(vm.ctx.exceptions.exception_type.clone()))] ); let type_name = objtype::get_type_name(&exc.typ()); - let msg = if let Some(m) = exc.get_attr("msg") { + let msg = if let Ok(m) = vm.get_attribute(exc.clone(), "msg") { match vm.to_pystr(&m) { Ok(msg) => msg, _ => "".to_string(), diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 6f8653f8812..1588699456c 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -245,8 +245,9 @@ impl Frame { &exception, &vm.ctx.exceptions.base_exception_type )); - let traceback_name = vm.new_str("__traceback__".to_string()); - let traceback = vm.get_attribute(exception.clone(), traceback_name).unwrap(); + let traceback = vm + .get_attribute(exception.clone(), "__traceback__") + .unwrap(); trace!("Adding to traceback: {:?} {:?}", traceback, lineno); let pos = vm.ctx.new_tuple(vec![ vm.ctx.new_str(filename.clone()), @@ -1137,7 +1138,6 @@ impl Frame { fn load_attr(&self, vm: &mut VirtualMachine, attr_name: &str) -> FrameResult { let parent = self.pop_value(); - let attr_name = vm.new_str(attr_name.to_string()); let obj = vm.get_attribute(parent, attr_name)?; self.push_value(obj); Ok(None) diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 8569280ef97..d1ffb065bda 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -10,7 +10,7 @@ use crate::format::{FormatParseError, FormatPart, FormatString}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ IdProtocol, IntoPyObject, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, - TryFromObject, TypeProtocol, + TryFromObject, TryIntoRef, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -24,6 +24,7 @@ pub struct PyString { // TODO: shouldn't be public pub value: String, } +pub type PyStringRef = PyRef; impl From for PyString { fn from(t: T) -> PyString { @@ -33,7 +34,20 @@ impl From for PyString { } } -pub type PyStringRef = PyRef; +impl TryIntoRef for String { + fn try_into_ref(self, vm: &mut VirtualMachine) -> PyResult> { + Ok(PyString { value: self }.into_ref(vm)) + } +} + +impl TryIntoRef for &str { + fn try_into_ref(self, vm: &mut VirtualMachine) -> PyResult> { + Ok(PyString { + value: self.to_string(), + } + .into_ref(vm)) + } +} impl PyStringRef { fn add(self, rhs: PyObjectRef, vm: &mut VirtualMachine) -> PyResult { diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index cd1ca77e664..8ee49b57f0f 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1044,6 +1044,27 @@ impl TryFromObject for Option { } } +/// Allows coercion of a types into PyRefs, so that we can write functions that can take +/// refs, pyobject refs or basic types. +pub trait TryIntoRef { + fn try_into_ref(self, vm: &mut VirtualMachine) -> PyResult>; +} + +impl TryIntoRef for PyRef { + fn try_into_ref(self, _vm: &mut VirtualMachine) -> PyResult> { + Ok(self) + } +} + +impl TryIntoRef for PyObjectRef +where + T: PyValue, +{ + fn try_into_ref(self, vm: &mut VirtualMachine) -> PyResult> { + TryFromObject::try_from_object(vm, self) + } +} + /// Implemented by any type that can be created from a Python object. /// /// Any type that implements `TryFromObject` is automatically `FromArgs`, and diff --git a/vm/src/stdlib/dis.rs b/vm/src/stdlib/dis.rs index ccc8930ba07..f583cc5a8df 100644 --- a/vm/src/stdlib/dis.rs +++ b/vm/src/stdlib/dis.rs @@ -7,8 +7,7 @@ fn dis_dis(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, None)]); // Method or function: - let code_name = vm.new_str("__code__".to_string()); - if let Ok(co) = vm.get_attribute(obj.clone(), code_name) { + if let Ok(co) = vm.get_attribute(obj.clone(), "__code__") { return dis_disassemble(vm, PyFuncArgs::new(vec![co], vec![])); } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index f78d657ccf9..7f9b474d0af 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -24,12 +24,12 @@ use crate::obj::objgenerator; use crate::obj::objiter; use crate::obj::objlist::PyList; use crate::obj::objsequence; -use crate::obj::objstr::PyStringRef; +use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::objtuple::PyTuple; use crate::obj::objtype; use crate::pyobject::{ AttributeProtocol, DictProtocol, IdProtocol, PyContext, PyObjectRef, PyResult, TryFromObject, - TypeProtocol, + TryIntoRef, TypeProtocol, }; use crate::stdlib; use crate::sysmodule; @@ -549,9 +549,13 @@ impl VirtualMachine { } // get_attribute should be used for full attribute access (usually from user code). - pub fn get_attribute(&mut self, obj: PyObjectRef, attr_name: PyObjectRef) -> PyResult { + pub fn get_attribute(&mut self, obj: PyObjectRef, attr_name: T) -> PyResult + where + T: TryIntoRef, + { + let attr_name = attr_name.try_into_ref(self)?; trace!("vm.__getattribute__: {:?} {:?}", obj, attr_name); - self.call_method(&obj, "__getattribute__", vec![attr_name]) + self.call_method(&obj, "__getattribute__", vec![attr_name.into_object()]) } pub fn set_attr(