Skip to content

Commit e572371

Browse files
committed
Fix behavior of pow builtin
1 parent 2b0c7f8 commit e572371

File tree

4 files changed

+118
-63
lines changed

4 files changed

+118
-63
lines changed

vm/src/builtins/complex.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,26 @@ fn inner_div(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Comp
6464
Ok(v1.fdiv(v2))
6565
}
6666

67+
fn inner_pow(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Complex64> {
68+
if v1.is_zero() {
69+
return if v2.im != 0.0 {
70+
let msg = format!("{} cannot be raised to a negative or complex power", v1);
71+
Err(vm.new_zero_division_error(msg))
72+
} else if v2.is_zero() {
73+
Ok(Complex64::new(1.0, 0.0))
74+
} else {
75+
Ok(Complex64::new(0.0, 0.0))
76+
}
77+
}
78+
79+
let ans = v1.powc(v2);
80+
if ans.is_infinite() && !(v1.is_infinite() || v2.is_infinite()) {
81+
Err(vm.new_overflow_error("complex exponentiation overflow".to_owned()))
82+
} else {
83+
Ok(ans)
84+
}
85+
}
86+
6787
#[pyimpl(flags(BASETYPE), with(Comparable, Hashable))]
6888
impl PyComplex {
6989
pub fn to_complex(&self) -> Complex64 {
@@ -215,9 +235,14 @@ impl PyComplex {
215235
fn pow(
216236
&self,
217237
other: PyObjectRef,
238+
mod_val: OptionalOption<PyObjectRef>,
218239
vm: &VirtualMachine,
219240
) -> PyResult<PyArithmaticValue<Complex64>> {
220-
self.op(other, |a, b| Ok(a.powc(b)), vm)
241+
if mod_val.flatten().is_some() {
242+
Err(vm.new_value_error("complex modulo not allowed".to_owned()))
243+
} else {
244+
self.op(other, |a, b| Ok(inner_pow(a, b, vm)?), vm)
245+
}
221246
}
222247

223248
#[pymethod(name = "__rpow__")]
@@ -226,7 +251,7 @@ impl PyComplex {
226251
other: PyObjectRef,
227252
vm: &VirtualMachine,
228253
) -> PyResult<PyArithmaticValue<Complex64>> {
229-
self.op(other, |a, b| Ok(b.powc(a)), vm)
254+
self.op(other, |a, b| Ok(inner_pow(b, a, vm)?), vm)
230255
}
231256

232257
#[pymethod(name = "__bool__")]
@@ -294,6 +319,7 @@ impl PyComplex {
294319
}
295320
};
296321

322+
297323
let final_real = if imag_was_complex {
298324
real.re - imag.im
299325
} else {

vm/src/builtins/make_module.rs

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,9 @@ mod decl {
2424
#[cfg(feature = "rustpython-compiler")]
2525
use crate::compile;
2626
use crate::exceptions::PyBaseExceptionRef;
27-
use crate::function::{single_or_tuple_any, Args, FuncArgs, KwArgs, OptionalArg};
27+
use crate::function::{single_or_tuple_any, Args, FuncArgs, KwArgs, OptionalArg, OptionalOption};
2828
use crate::iterator;
29-
use crate::pyobject::{
30-
BorrowValue, Either, IdProtocol, ItemProtocol, PyCallable, PyIterable, PyObjectRef,
31-
PyResult, PyValue, TryFromObject, TypeProtocol,
32-
};
29+
use crate::pyobject::{BorrowValue, Either, IdProtocol, ItemProtocol, PyCallable, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject, TypeProtocol, PyArithmaticValue};
3330
use crate::readline::{Readline, ReadlineResult};
3431
use crate::scope::Scope;
3532
use crate::sliceable;
@@ -38,6 +35,7 @@ mod decl {
3835
use crate::{py_io, sysmodule};
3936
use num_bigint::Sign;
4037
use num_traits::{Signed, ToPrimitive, Zero};
38+
use crate::builtins::PyInt;
4139

4240
#[pyfunction]
4341
fn abs(x: PyObjectRef, vm: &VirtualMachine) -> PyResult {
@@ -177,7 +175,7 @@ mod decl {
177175
#[pyfunction]
178176
fn divmod(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
179177
vm.call_or_reflection(&a, &b, "__divmod__", "__rdivmod__", |vm, a, b| {
180-
Err(vm.new_unsupported_operand_error(a, b, "divmod"))
178+
Err(vm.new_unsupported_binop_error(a, b, "divmod"))
181179
})
182180
}
183181

@@ -586,36 +584,46 @@ mod decl {
586584
fn pow(
587585
x: PyObjectRef,
588586
y: PyObjectRef,
589-
mod_value: OptionalArg<PyIntRef>,
587+
mod_value: OptionalOption<PyObjectRef>,
590588
vm: &VirtualMachine,
591589
) -> PyResult {
592-
match mod_value {
593-
OptionalArg::Missing => {
590+
match mod_value.flatten() {
591+
None => {
594592
vm.call_or_reflection(&x, &y, "__pow__", "__rpow__", |vm, x, y| {
595-
Err(vm.new_unsupported_operand_error(x, y, "pow"))
593+
Err(vm.new_unsupported_binop_error(x, y, "pow"))
596594
})
597595
}
598-
OptionalArg::Present(m) => {
599-
// Check if the 3rd argument is defined and perform modulus on the result
600-
if !(x.isinstance(&vm.ctx.types.int_type) && y.isinstance(&vm.ctx.types.int_type)) {
601-
return Err(vm.new_type_error(
602-
"pow() 3rd argument not allowed unless all arguments are integers"
603-
.to_owned(),
604-
));
596+
Some(z) => {
597+
let try_pow_value = |obj: &PyObjectRef, args: (PyObjectRef, PyObjectRef, PyObjectRef)| -> Option<PyResult> {
598+
if let Some(method) = obj.get_class_attr("__pow__") {
599+
let result = match vm.invoke(&method, args) {
600+
Ok(x) => x,
601+
Err(e) => return Some(Err(e)),
602+
};
603+
if let PyArithmaticValue::Implemented(x) = PyArithmaticValue::from_object(vm, result) {
604+
return Some(Ok(x))
605+
}
606+
}
607+
None
608+
};
609+
610+
if let Some(val) = try_pow_value(&x, (x.clone(), y.clone(), z.clone())) {
611+
return val
605612
}
606-
let y = int::get_value(&y);
607-
if y.sign() == Sign::Minus {
608-
return Err(vm.new_value_error(
609-
"pow() 2nd argument cannot be negative when 3rd argument specified"
610-
.to_owned(),
611-
));
613+
614+
if !x.class().is(&y.class()) {
615+
if let Some(val) = try_pow_value(&y, (x.clone(), y.clone(), z.clone())) {
616+
return val
617+
}
612618
}
613-
let m = m.borrow_value();
614-
if m.is_zero() {
615-
return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_owned()));
619+
620+
if !x.class().is(&z.class()) && !y.class().is(&z.class()) {
621+
if let Some(val) = try_pow_value(&z, (x.clone(), y.clone(), z.clone())) {
622+
return val
623+
}
616624
}
617-
let x = int::get_value(&x);
618-
Ok(vm.ctx.new_int(x.modpow(&y, &m)))
625+
626+
Err(vm.new_unsupported_ternop_error(&x, &y, &z, "pow"))
619627
}
620628
}
621629
}

0 commit comments

Comments
 (0)