Skip to content

Commit c6ac393

Browse files
Add {int,float}.{__radd__,__rsub__,__rmul__,__rtruediv__}
1 parent d76c86e commit c6ac393

4 files changed

Lines changed: 197 additions & 60 deletions

File tree

tests/snippets/floats.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,21 @@ def assert_raises(expr, exc_type):
8181

8282
assert_raises(lambda _: float('foo'), ValueError)
8383
assert_raises(lambda _: float(2**10000), OverflowError)
84+
85+
# check that magic methods are implemented for ints and floats
86+
87+
assert 1.0.__add__(1.0) == 2.0
88+
assert 1.0.__radd__(1.0) == 2.0
89+
assert 2.0.__sub__(1.0) == 1.0
90+
assert 2.0.__rmul__(1.0) == 2.0
91+
assert 1.0.__truediv__(2.0) == 0.5
92+
assert 1.0.__rtruediv__(2.0) == 2.0
93+
94+
assert 1.0.__add__(1) == 2.0
95+
assert 1.0.__radd__(1) == 2.0
96+
assert 2.0.__sub__(1) == 1.0
97+
assert 2.0.__rmul__(1) == 2.0
98+
assert 1.0.__truediv__(2) == 0.5
99+
assert 1.0.__rtruediv__(2) == 2.0
100+
assert 2.0.__mul__(1) == 2.0
101+
assert 2.0.__rsub__(1) == -1.0

tests/snippets/ints.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,34 @@
1515
assert 1 >= 1.0
1616
assert 1 <= 1.0
1717

18+
# magic methods should only be implemented for other ints
19+
20+
assert (1).__eq__(1) == True
21+
assert (1).__ne__(1) == False
22+
assert (1).__gt__(1) == False
23+
assert (1).__ge__(1) == True
24+
assert (1).__lt__(1) == False
25+
assert (1).__le__(1) == True
26+
assert (1).__add__(1) == 2
27+
assert (1).__radd__(1) == 2
28+
assert (2).__sub__(1) == 1
29+
assert (2).__rsub__(1) == -1
30+
assert (2).__mul__(1) == 2
31+
assert (2).__rmul__(1) == 2
32+
assert (2).__truediv__(1) == 2.0
33+
assert (2).__rtruediv__(1) == 0.5
34+
1835
assert (1).__eq__(1.0) == NotImplemented
1936
assert (1).__ne__(1.0) == NotImplemented
2037
assert (1).__gt__(1.0) == NotImplemented
2138
assert (1).__ge__(1.0) == NotImplemented
2239
assert (1).__lt__(1.0) == NotImplemented
2340
assert (1).__le__(1.0) == NotImplemented
41+
assert (1).__add__(1.0) == NotImplemented
42+
assert (2).__sub__(1.0) == NotImplemented
43+
assert (1).__radd__(1.0) == NotImplemented
44+
assert (2).__rsub__(1.0) == NotImplemented
45+
assert (2).__mul__(1.0) == NotImplemented
46+
assert (2).__rmul__(1.0) == NotImplemented
47+
assert (2).__truediv__(1.0) == NotImplemented
48+
assert (2).__rtruediv__(1.0) == NotImplemented

vm/src/obj/objfloat.rs

Lines changed: 85 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,25 @@ fn float_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
198198
arg_check!(
199199
vm,
200200
args,
201-
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
201+
required = [(zelf, Some(vm.ctx.float_type())), (other, None)]
202202
);
203203

204-
let v1 = get_value(i);
205-
if objtype::isinstance(i2, &vm.ctx.float_type()) {
206-
Ok(vm.ctx.new_float(v1 + get_value(i2)))
207-
} else if objtype::isinstance(i2, &vm.ctx.int_type()) {
204+
let v1 = get_value(zelf);
205+
if objtype::isinstance(other, &vm.ctx.float_type()) {
206+
Ok(vm.ctx.new_float(v1 + get_value(other)))
207+
} else if objtype::isinstance(other, &vm.ctx.int_type()) {
208208
Ok(vm
209209
.ctx
210-
.new_float(v1 + objint::get_value(i2).to_f64().unwrap()))
210+
.new_float(v1 + objint::get_value(other).to_f64().unwrap()))
211211
} else {
212212
Ok(vm.ctx.not_implemented())
213213
}
214214
}
215215

216+
fn float_radd(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
217+
float_add(vm, args)
218+
}
219+
216220
fn float_divmod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
217221
arg_check!(
218222
vm,
@@ -259,15 +263,33 @@ fn float_sub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
259263
arg_check!(
260264
vm,
261265
args,
262-
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
266+
required = [(zelf, Some(vm.ctx.float_type())), (other, None)]
263267
);
264-
let v1 = get_value(i);
265-
if objtype::isinstance(i2, &vm.ctx.float_type()) {
266-
Ok(vm.ctx.new_float(v1 - get_value(i2)))
267-
} else if objtype::isinstance(i2, &vm.ctx.int_type()) {
268+
let v1 = get_value(zelf);
269+
if objtype::isinstance(other, &vm.ctx.float_type()) {
270+
Ok(vm.ctx.new_float(v1 - get_value(other)))
271+
} else if objtype::isinstance(other, &vm.ctx.int_type()) {
268272
Ok(vm
269273
.ctx
270-
.new_float(v1 - objint::get_value(i2).to_f64().unwrap()))
274+
.new_float(v1 - objint::get_value(other).to_f64().unwrap()))
275+
} else {
276+
Ok(vm.ctx.not_implemented())
277+
}
278+
}
279+
280+
fn float_rsub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
281+
arg_check!(
282+
vm,
283+
args,
284+
required = [(zelf, Some(vm.ctx.float_type())), (other, None)]
285+
);
286+
let v1 = get_value(zelf);
287+
if objtype::isinstance(other, &vm.ctx.float_type()) {
288+
Ok(vm.ctx.new_float(get_value(other) - v1))
289+
} else if objtype::isinstance(other, &vm.ctx.int_type()) {
290+
Ok(vm
291+
.ctx
292+
.new_float(objint::get_value(other).to_f64().unwrap() - v1))
271293
} else {
272294
Ok(vm.ctx.not_implemented())
273295
}
@@ -328,18 +350,18 @@ fn float_truediv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
328350
arg_check!(
329351
vm,
330352
args,
331-
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
353+
required = [(zelf, Some(vm.ctx.float_type())), (other, None)]
332354
);
333355

334-
let v1 = get_value(i);
335-
let v2 = if objtype::isinstance(i2, &vm.ctx.float_type) {
336-
get_value(i2)
337-
} else if objtype::isinstance(i2, &vm.ctx.int_type) {
338-
objint::get_value(i2)
356+
let v1 = get_value(zelf);
357+
let v2 = if objtype::isinstance(other, &vm.ctx.float_type) {
358+
get_value(other)
359+
} else if objtype::isinstance(other, &vm.ctx.int_type) {
360+
objint::get_value(other)
339361
.to_f64()
340362
.ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_string()))?
341363
} else {
342-
return Err(vm.new_type_error(format!("Cannot divide {} and {}", i.borrow(), i2.borrow())));
364+
return Ok(vm.ctx.not_implemented());
343365
};
344366

345367
if v2 != 0.0 {
@@ -349,28 +371,53 @@ fn float_truediv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
349371
}
350372
}
351373

374+
fn float_rtruediv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
375+
arg_check!(
376+
vm,
377+
args,
378+
required = [(zelf, Some(vm.ctx.float_type())), (other, None)]
379+
);
380+
381+
let v1 = get_value(zelf);
382+
let v2 = if objtype::isinstance(other, &vm.ctx.float_type) {
383+
get_value(other)
384+
} else if objtype::isinstance(other, &vm.ctx.int_type) {
385+
objint::get_value(other)
386+
.to_f64()
387+
.ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_string()))?
388+
} else {
389+
return Ok(vm.ctx.not_implemented());
390+
};
391+
392+
if v1 != 0.0 {
393+
Ok(vm.ctx.new_float(v2 / v1))
394+
} else {
395+
Err(vm.new_zero_division_error("float division by zero".to_string()))
396+
}
397+
}
398+
352399
fn float_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
353400
arg_check!(
354401
vm,
355402
args,
356-
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
403+
required = [(zelf, Some(vm.ctx.float_type())), (other, None)]
357404
);
358-
let v1 = get_value(i);
359-
if objtype::isinstance(i2, &vm.ctx.float_type) {
360-
Ok(vm.ctx.new_float(v1 * get_value(i2)))
361-
} else if objtype::isinstance(i2, &vm.ctx.int_type) {
405+
let v1 = get_value(zelf);
406+
if objtype::isinstance(other, &vm.ctx.float_type) {
407+
Ok(vm.ctx.new_float(v1 * get_value(other)))
408+
} else if objtype::isinstance(other, &vm.ctx.int_type) {
362409
Ok(vm
363410
.ctx
364-
.new_float(v1 * objint::get_value(i2).to_f64().unwrap()))
411+
.new_float(v1 * objint::get_value(other).to_f64().unwrap()))
365412
} else {
366-
Err(vm.new_type_error(format!(
367-
"Cannot multiply {} and {}",
368-
i.borrow(),
369-
i2.borrow()
370-
)))
413+
Ok(vm.ctx.not_implemented())
371414
}
372415
}
373416

417+
fn float_rmul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
418+
float_mul(vm, args)
419+
}
420+
374421
pub fn init(context: &PyContext) {
375422
let float_type = &context.float_type;
376423

@@ -383,6 +430,7 @@ pub fn init(context: &PyContext) {
383430
context.set_attr(&float_type, "__ge__", context.new_rustfunc(float_ge));
384431
context.set_attr(&float_type, "__abs__", context.new_rustfunc(float_abs));
385432
context.set_attr(&float_type, "__add__", context.new_rustfunc(float_add));
433+
context.set_attr(&float_type, "__radd__", context.new_rustfunc(float_radd));
386434
context.set_attr(
387435
&float_type,
388436
"__divmod__",
@@ -398,6 +446,7 @@ pub fn init(context: &PyContext) {
398446
context.set_attr(&float_type, "__neg__", context.new_rustfunc(float_neg));
399447
context.set_attr(&float_type, "__pow__", context.new_rustfunc(float_pow));
400448
context.set_attr(&float_type, "__sub__", context.new_rustfunc(float_sub));
449+
context.set_attr(&float_type, "__rsub__", context.new_rustfunc(float_rsub));
401450
context.set_attr(&float_type, "__repr__", context.new_rustfunc(float_repr));
402451
context.set_attr(
403452
&float_type,
@@ -409,5 +458,11 @@ pub fn init(context: &PyContext) {
409458
"__truediv__",
410459
context.new_rustfunc(float_truediv),
411460
);
461+
context.set_attr(
462+
&float_type,
463+
"__rtruediv__",
464+
context.new_rustfunc(float_rtruediv),
465+
);
412466
context.set_attr(&float_type, "__mul__", context.new_rustfunc(float_mul));
467+
context.set_attr(&float_type, "__rmul__", context.new_rustfunc(float_rmul));
413468
}

0 commit comments

Comments
 (0)