Skip to content

Commit 51ae261

Browse files
committed
Modify unary functions in number protocol
Signed-off-by: snowapril <sinjihng@gmail.com>
1 parent 1b66332 commit 51ae261

File tree

2 files changed

+58
-47
lines changed

2 files changed

+58
-47
lines changed

vm/src/protocol/number.rs

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{
1111
VirtualMachine,
1212
};
1313

14-
pub type PyNumberUnaryFunc<R = PyObjectRef> = fn(PyNumber, &VirtualMachine) -> PyResult<R>;
14+
pub type PyNumberUnaryFunc = fn(&PyObject, &VirtualMachine) -> PyResult;
1515
pub type PyNumberBinaryFunc = fn(&PyObject, &PyObject, &VirtualMachine) -> PyResult;
1616

1717
impl PyObject {
@@ -125,15 +125,15 @@ pub struct PyNumberMethods {
125125
pub negative: Option<PyNumberUnaryFunc>,
126126
pub positive: Option<PyNumberUnaryFunc>,
127127
pub absolute: Option<PyNumberUnaryFunc>,
128-
pub boolean: Option<PyNumberUnaryFunc<bool>>,
128+
pub boolean: Option<PyNumberUnaryFunc>,
129129
pub invert: Option<PyNumberUnaryFunc>,
130130
pub lshift: Option<PyNumberBinaryFunc>,
131131
pub rshift: Option<PyNumberBinaryFunc>,
132132
pub and: Option<PyNumberBinaryFunc>,
133133
pub xor: Option<PyNumberBinaryFunc>,
134134
pub or: Option<PyNumberBinaryFunc>,
135-
pub int: Option<PyNumberUnaryFunc<PyRef<PyInt>>>,
136-
pub float: Option<PyNumberUnaryFunc<PyRef<PyFloat>>>,
135+
pub int: Option<PyNumberUnaryFunc>,
136+
pub float: Option<PyNumberUnaryFunc>,
137137

138138
pub inplace_add: Option<PyNumberBinaryFunc>,
139139
pub inplace_subtract: Option<PyNumberBinaryFunc>,
@@ -151,7 +151,7 @@ pub struct PyNumberMethods {
151151
pub inplace_floor_divide: Option<PyNumberBinaryFunc>,
152152
pub inplace_true_divide: Option<PyNumberBinaryFunc>,
153153

154-
pub index: Option<PyNumberUnaryFunc<PyRef<PyInt>>>,
154+
pub index: Option<PyNumberUnaryFunc>,
155155

156156
pub matrix_multiply: Option<PyNumberBinaryFunc>,
157157
pub inplace_matrix_multiply: Option<PyNumberBinaryFunc>,
@@ -245,15 +245,15 @@ pub struct PyNumberSlots {
245245
pub negative: AtomicCell<Option<PyNumberUnaryFunc>>,
246246
pub positive: AtomicCell<Option<PyNumberUnaryFunc>>,
247247
pub absolute: AtomicCell<Option<PyNumberUnaryFunc>>,
248-
pub boolean: AtomicCell<Option<PyNumberUnaryFunc<bool>>>,
248+
pub boolean: AtomicCell<Option<PyNumberUnaryFunc>>,
249249
pub invert: AtomicCell<Option<PyNumberUnaryFunc>>,
250250
pub lshift: AtomicCell<Option<PyNumberBinaryFunc>>,
251251
pub rshift: AtomicCell<Option<PyNumberBinaryFunc>>,
252252
pub and: AtomicCell<Option<PyNumberBinaryFunc>>,
253253
pub xor: AtomicCell<Option<PyNumberBinaryFunc>>,
254254
pub or: AtomicCell<Option<PyNumberBinaryFunc>>,
255-
pub int: AtomicCell<Option<PyNumberUnaryFunc<PyRef<PyInt>>>>,
256-
pub float: AtomicCell<Option<PyNumberUnaryFunc<PyRef<PyFloat>>>>,
255+
pub int: AtomicCell<Option<PyNumberUnaryFunc>>,
256+
pub float: AtomicCell<Option<PyNumberUnaryFunc>>,
257257

258258
pub right_add: AtomicCell<Option<PyNumberBinaryFunc>>,
259259
pub right_subtract: AtomicCell<Option<PyNumberBinaryFunc>>,
@@ -285,7 +285,7 @@ pub struct PyNumberSlots {
285285
pub inplace_floor_divide: AtomicCell<Option<PyNumberBinaryFunc>>,
286286
pub inplace_true_divide: AtomicCell<Option<PyNumberBinaryFunc>>,
287287

288-
pub index: AtomicCell<Option<PyNumberUnaryFunc<PyRef<PyInt>>>>,
288+
pub index: AtomicCell<Option<PyNumberUnaryFunc>>,
289289

290290
pub matrix_multiply: AtomicCell<Option<PyNumberBinaryFunc>>,
291291
pub right_matrix_multiply: AtomicCell<Option<PyNumberBinaryFunc>>,
@@ -440,8 +440,16 @@ impl PyNumber<'_> {
440440
#[inline]
441441
pub fn int(self, vm: &VirtualMachine) -> Option<PyResult<PyIntRef>> {
442442
self.class().slots.as_number.int.load().map(|f| {
443-
let ret = f(self, vm)?;
444-
let value = if !ret.class().is(PyInt::class(&vm.ctx)) {
443+
let ret = f(self.obj(), vm)?;
444+
let value: PyRef<PyInt> = if !ret.class().is(PyInt::class(&vm.ctx)) {
445+
if !ret.class().fast_issubclass(vm.ctx.types.int_type) {
446+
return Err(vm.new_type_error(format!(
447+
"{}.__int__ returned non-int(type {})",
448+
self.class().name(),
449+
ret.class().name()
450+
)));
451+
}
452+
445453
warnings::warn(
446454
vm.ctx.exceptions.deprecation_warning,
447455
format!(
@@ -453,9 +461,11 @@ impl PyNumber<'_> {
453461
1,
454462
vm,
455463
)?;
456-
vm.ctx.new_bigint(ret.as_bigint())
464+
// TODO(snowapril) : modify to proper conversion method
465+
unsafe { ret.downcast_unchecked() }
457466
} else {
458-
ret
467+
// TODO(snowapril) : modify to proper conversion method
468+
unsafe { ret.downcast_unchecked() }
459469
};
460470
Ok(value)
461471
})
@@ -464,8 +474,16 @@ impl PyNumber<'_> {
464474
#[inline]
465475
pub fn index(self, vm: &VirtualMachine) -> Option<PyResult<PyIntRef>> {
466476
self.class().slots.as_number.index.load().map(|f| {
467-
let ret = f(self, vm)?;
468-
let value = if !ret.class().is(PyInt::class(&vm.ctx)) {
477+
let ret = f(self.obj(), vm)?;
478+
let value: PyRef<PyInt> = if !ret.class().is(PyInt::class(&vm.ctx)) {
479+
if !ret.class().fast_issubclass(vm.ctx.types.int_type) {
480+
return Err(vm.new_type_error(format!(
481+
"{}.__index__ returned non-int(type {})",
482+
self.class().name(),
483+
ret.class().name()
484+
)));
485+
}
486+
469487
warnings::warn(
470488
vm.ctx.exceptions.deprecation_warning,
471489
format!(
@@ -477,9 +495,11 @@ impl PyNumber<'_> {
477495
1,
478496
vm,
479497
)?;
480-
vm.ctx.new_bigint(ret.as_bigint())
498+
// TODO(snowapril) : modify to proper conversion method
499+
unsafe { ret.downcast_unchecked() }
481500
} else {
482-
ret
501+
// TODO(snowapril) : modify to proper conversion method
502+
unsafe { ret.downcast_unchecked() }
483503
};
484504
Ok(value)
485505
})
@@ -488,8 +508,16 @@ impl PyNumber<'_> {
488508
#[inline]
489509
pub fn float(self, vm: &VirtualMachine) -> Option<PyResult<PyRef<PyFloat>>> {
490510
self.class().slots.as_number.float.load().map(|f| {
491-
let ret = f(self, vm)?;
492-
let value = if !ret.class().is(PyFloat::class(&vm.ctx)) {
511+
let ret = f(self.obj(), vm)?;
512+
let value: PyRef<PyFloat> = if !ret.class().is(PyFloat::class(&vm.ctx)) {
513+
if !ret.class().fast_issubclass(vm.ctx.types.float_type) {
514+
return Err(vm.new_type_error(format!(
515+
"{}.__float__ returned non-float(type {})",
516+
self.class().name(),
517+
ret.class().name()
518+
)));
519+
}
520+
493521
warnings::warn(
494522
vm.ctx.exceptions.deprecation_warning,
495523
format!(
@@ -501,9 +529,11 @@ impl PyNumber<'_> {
501529
1,
502530
vm,
503531
)?;
504-
vm.ctx.new_float(ret.to_f64())
532+
// TODO(snowapril) : modify to proper conversion method
533+
unsafe { ret.downcast_unchecked() }
505534
} else {
506-
ret
535+
// TODO(snowapril) : modify to proper conversion method
536+
unsafe { ret.downcast_unchecked() }
507537
};
508538
Ok(value)
509539
})

vm/src/types/slot.rs

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -197,30 +197,11 @@ pub(crate) fn len_wrapper(obj: &PyObject, vm: &VirtualMachine) -> PyResult<usize
197197
Ok(len as usize)
198198
}
199199

200-
fn int_wrapper(num: PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyInt>> {
201-
let ret = vm.call_special_method(num.deref(), identifier!(vm, __int__), ())?;
202-
ret.downcast::<PyInt>().map_err(|obj| {
203-
vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class()))
204-
})
205-
}
206-
207-
fn index_wrapper(num: PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyInt>> {
208-
let ret = vm.call_special_method(num.deref(), identifier!(vm, __index__), ())?;
209-
ret.downcast::<PyInt>().map_err(|obj| {
210-
vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class()))
211-
})
212-
}
213-
214-
fn float_wrapper(num: PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyFloat>> {
215-
let ret = vm.call_special_method(num.deref(), identifier!(vm, __float__), ())?;
216-
ret.downcast::<PyFloat>().map_err(|obj| {
217-
vm.new_type_error(format!(
218-
"__float__ returned non-float (type {})",
219-
obj.class()
220-
))
221-
})
200+
macro_rules! number_unary_op_wrapper {
201+
($name:ident) => {
202+
|a, vm| vm.call_special_method(a, identifier!(vm, $name), ())
203+
};
222204
}
223-
224205
macro_rules! number_binary_op_wrapper {
225206
($name:ident) => {
226207
|a, b, vm| vm.call_special_method(a, identifier!(vm, $name), (b.to_owned(),))
@@ -505,13 +486,13 @@ impl PyType {
505486
toggle_slot!(del, del_wrapper);
506487
}
507488
_ if name == identifier!(ctx, __int__) => {
508-
toggle_subslot!(as_number, int, int_wrapper);
489+
toggle_subslot!(as_number, int, number_unary_op_wrapper!(__int__));
509490
}
510491
_ if name == identifier!(ctx, __index__) => {
511-
toggle_subslot!(as_number, index, index_wrapper);
492+
toggle_subslot!(as_number, index, number_unary_op_wrapper!(__index__));
512493
}
513494
_ if name == identifier!(ctx, __float__) => {
514-
toggle_subslot!(as_number, float, float_wrapper);
495+
toggle_subslot!(as_number, float, number_unary_op_wrapper!(__float__));
515496
}
516497
_ if name == identifier!(ctx, __add__) => {
517498
toggle_subslot!(as_number, add, number_binary_op_wrapper!(__add__));

0 commit comments

Comments
 (0)