@@ -19,6 +19,7 @@ use super::objtype;
1919#[ derive( Debug ) ]
2020pub struct PySuper {
2121 obj : PyObjectRef ,
22+ typ : PyObjectRef ,
2223}
2324
2425impl PyValue for PySuper {
@@ -68,8 +69,9 @@ fn super_getattribute(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
6869 ) ;
6970
7071 let inst = super_obj. payload :: < PySuper > ( ) . unwrap ( ) . obj . clone ( ) ;
72+ let typ = super_obj. payload :: < PySuper > ( ) . unwrap ( ) . typ . clone ( ) ;
7173
72- match inst . typ ( ) . payload :: < PyClass > ( ) {
74+ match typ. payload :: < PyClass > ( ) {
7375 Some ( PyClass { ref mro, .. } ) => {
7476 for class in mro {
7577 if let Ok ( item) = vm. get_attribute ( class. as_object ( ) . clone ( ) , name_str. clone ( ) ) {
@@ -99,6 +101,29 @@ fn super_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
99101 return Err ( vm. new_type_error ( format ! ( "{:?} is not a subtype of super" , cls) ) ) ;
100102 }
101103
104+ // Get the type:
105+ let py_type = if let Some ( ty) = py_type {
106+ ty. clone ( )
107+ } else {
108+ match vm. current_scope ( ) . get ( "__class__" ) {
109+ Some ( obj) => obj. clone ( ) ,
110+ _ => {
111+ return Err ( vm. new_type_error (
112+ "super must be called with 1 argument or from inside class method" . to_string ( ) ,
113+ ) ) ;
114+ }
115+ }
116+ } ;
117+
118+ // Check type argument:
119+ if !objtype:: isinstance ( & py_type, & vm. get_type ( ) ) {
120+ let type_name = objtype:: get_type_name ( & py_type. typ ( ) ) ;
121+ return Err ( vm. new_type_error ( format ! (
122+ "super() argument 1 must be type, not {}" ,
123+ type_name
124+ ) ) ) ;
125+ }
126+
102127 // Get the bound object:
103128 let py_obj = if let Some ( obj) = py_obj {
104129 obj. clone ( )
@@ -119,28 +144,18 @@ fn super_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
119144 }
120145 } ;
121146
122- // Get the type:
123- let py_type = if let Some ( ty) = py_type {
124- ty. clone ( )
125- } else {
126- py_obj. typ ( ) . clone ( )
127- } ;
128-
129- // Check type argument:
130- if !objtype:: isinstance ( & py_type, & vm. get_type ( ) ) {
131- let type_name = objtype:: get_type_name ( & py_type. typ ( ) ) ;
132- return Err ( vm. new_type_error ( format ! (
133- "super() argument 1 must be type, not {}" ,
134- type_name
135- ) ) ) ;
136- }
137-
138147 // Check obj type:
139148 if !( objtype:: isinstance ( & py_obj, & py_type) || objtype:: issubclass ( & py_obj, & py_type) ) {
140149 return Err ( vm. new_type_error (
141150 "super(type, obj): obj must be an instance or subtype of type" . to_string ( ) ,
142151 ) ) ;
143152 }
144153
145- Ok ( PyObject :: new ( PySuper { obj : py_obj } , cls. clone ( ) ) )
154+ Ok ( PyObject :: new (
155+ PySuper {
156+ obj : py_obj,
157+ typ : py_type,
158+ } ,
159+ cls. clone ( ) ,
160+ ) )
146161}
0 commit comments