Skip to content

Commit 03a2aad

Browse files
Merge pull request #410 from OddCoincidence/complex
Add complex.{__abs__, __eq__, __neg__}
2 parents c2db23d + 94db145 commit 03a2aad

File tree

4 files changed

+81
-0
lines changed

4 files changed

+81
-0
lines changed

tests/snippets/builtin_complex.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# __abs__
2+
3+
assert abs(complex(3, 4)) == 5
4+
assert abs(complex(3, -4)) == 5
5+
assert abs(complex(1.5, 2.5)) == 2.9154759474226504
6+
7+
# __eq__
8+
9+
assert complex(1, -1) == complex(1, -1)
10+
assert complex(1, 0) == 1
11+
assert not complex(1, 1) == 1
12+
assert complex(1, 0) == 1.0
13+
assert not complex(1, 1) == 1.0
14+
assert not complex(1, 0) == 1.5
15+
assert bool(complex(1, 0))
16+
assert not complex(1, 2) == complex(1, 1)
17+
# Currently broken - see issue #419
18+
# assert complex(1, 2) != 'foo'
19+
assert complex(1, 2).__eq__('foo') == NotImplemented
20+
21+
# __neg__
22+
23+
assert -complex(1, -1) == complex(-1, 1)
24+
assert -complex(0, 0) == complex(0, 0)

vm/src/builtins.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,9 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
718718
ctx.set_attr(&py_mod, "type", ctx.type_type());
719719
ctx.set_attr(&py_mod, "zip", ctx.zip_type());
720720

721+
// Constants
722+
ctx.set_attr(&py_mod, "NotImplemented", ctx.not_implemented.clone());
723+
721724
// Exceptions:
722725
ctx.set_attr(
723726
&py_mod,

vm/src/obj/objcomplex.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ use super::super::pyobject::{
33
};
44
use super::super::vm::VirtualMachine;
55
use super::objfloat;
6+
use super::objint;
67
use super::objtype;
78
use num_complex::Complex64;
9+
use num_traits::ToPrimitive;
810

911
pub fn init(context: &PyContext) {
1012
let complex_type = &context.complex_type;
@@ -13,7 +15,10 @@ pub fn init(context: &PyContext) {
1315
"Create a complex number from a real part and an optional imaginary part.\n\n\
1416
This is equivalent to (real + imag*1j) where imag defaults to 0.";
1517

18+
context.set_attr(&complex_type, "__abs__", context.new_rustfunc(complex_abs));
1619
context.set_attr(&complex_type, "__add__", context.new_rustfunc(complex_add));
20+
context.set_attr(&complex_type, "__eq__", context.new_rustfunc(complex_eq));
21+
context.set_attr(&complex_type, "__neg__", context.new_rustfunc(complex_neg));
1722
context.set_attr(&complex_type, "__new__", context.new_rustfunc(complex_new));
1823
context.set_attr(
1924
&complex_type,
@@ -70,6 +75,13 @@ fn complex_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
7075
))
7176
}
7277

78+
fn complex_abs(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
79+
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.complex_type()))]);
80+
81+
let Complex64 { re, im } = get_value(zelf);
82+
Ok(vm.ctx.new_float(re.hypot(im)))
83+
}
84+
7385
fn complex_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
7486
arg_check!(
7587
vm,
@@ -92,6 +104,36 @@ fn complex_conjugate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
92104
Ok(vm.ctx.new_complex(v1.conj()))
93105
}
94106

107+
fn complex_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
108+
arg_check!(
109+
vm,
110+
args,
111+
required = [(zelf, Some(vm.ctx.complex_type())), (other, None)]
112+
);
113+
114+
let z = get_value(zelf);
115+
116+
let result = if objtype::isinstance(other, &vm.ctx.complex_type()) {
117+
z == get_value(other)
118+
} else if objtype::isinstance(other, &vm.ctx.int_type()) {
119+
match objint::get_value(other).to_f64() {
120+
Some(f) => z.im == 0.0f64 && z.re == f,
121+
None => false,
122+
}
123+
} else if objtype::isinstance(other, &vm.ctx.float_type()) {
124+
z.im == 0.0 && z.re == objfloat::get_value(other)
125+
} else {
126+
return Ok(vm.ctx.not_implemented());
127+
};
128+
129+
Ok(vm.ctx.new_bool(result))
130+
}
131+
132+
fn complex_neg(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
133+
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.complex_type()))]);
134+
Ok(vm.ctx.new_complex(-get_value(zelf)))
135+
}
136+
95137
fn complex_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
96138
arg_check!(vm, args, required = [(obj, Some(vm.ctx.complex_type()))]);
97139
let v = get_value(obj);

vm/src/pyobject.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ pub struct PyContext {
131131
pub map_type: PyObjectRef,
132132
pub memoryview_type: PyObjectRef,
133133
pub none: PyObjectRef,
134+
pub not_implemented: PyObjectRef,
134135
pub tuple_type: PyObjectRef,
135136
pub set_type: PyObjectRef,
136137
pub staticmethod_type: PyObjectRef,
@@ -226,6 +227,11 @@ impl PyContext {
226227
create_type("NoneType", &type_type, &object_type, &dict_type),
227228
);
228229

230+
let not_implemented = PyObject::new(
231+
PyObjectPayload::NotImplemented,
232+
create_type("NotImplementedType", &type_type, &object_type, &dict_type),
233+
);
234+
229235
let true_value = PyObject::new(
230236
PyObjectPayload::Integer { value: One::one() },
231237
bool_type.clone(),
@@ -261,6 +267,7 @@ impl PyContext {
261267
zip_type,
262268
dict_type,
263269
none,
270+
not_implemented,
264271
str_type,
265272
range_type,
266273
slice_type,
@@ -432,6 +439,9 @@ impl PyContext {
432439
pub fn none(&self) -> PyObjectRef {
433440
self.none.clone()
434441
}
442+
pub fn not_implemented(&self) -> PyObjectRef {
443+
self.not_implemented.clone()
444+
}
435445
pub fn object(&self) -> PyObjectRef {
436446
self.object.clone()
437447
}
@@ -965,6 +975,7 @@ pub enum PyObjectPayload {
965975
dict: PyObjectRef,
966976
},
967977
None,
978+
NotImplemented,
968979
Class {
969980
name: String,
970981
dict: RefCell<PyAttributes>,
@@ -1011,6 +1022,7 @@ impl fmt::Debug for PyObjectPayload {
10111022
PyObjectPayload::Module { .. } => write!(f, "module"),
10121023
PyObjectPayload::Scope { .. } => write!(f, "scope"),
10131024
PyObjectPayload::None => write!(f, "None"),
1025+
PyObjectPayload::NotImplemented => write!(f, "NotImplemented"),
10141026
PyObjectPayload::Class { ref name, .. } => write!(f, "class {:?}", name),
10151027
PyObjectPayload::Instance { .. } => write!(f, "instance"),
10161028
PyObjectPayload::RustFunction { .. } => write!(f, "rust function"),

0 commit comments

Comments
 (0)