Skip to content

Commit fca3e7a

Browse files
Merge pull request RustPython#165 from azmeuk/membership
Membership test fixes, and set implementation
2 parents 96d35ad + b18156a commit fca3e7a

5 files changed

Lines changed: 55 additions & 18 deletions

File tree

tests/snippets/membership.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
assert 3 not in (1, 2)
2020

2121
# test set
22-
# TODO: uncomment this when sets are implemented
23-
# assert 1 in set(1, 2)
24-
# assert 3 not in set(1, 2)
22+
assert 1 in set([1, 2])
23+
assert 3 not in set([1, 2])
2524

2625
# test dicts
2726
# TODO: test dicts when keys other than strings will be allowed

vm/src/obj/objiter.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::super::pyobject::{
77
TypeProtocol,
88
};
99
use super::super::vm::VirtualMachine;
10+
use super::objbool;
1011
use super::objstr;
1112
use super::objtype; // Required for arg_check! to use isinstance
1213

@@ -61,13 +62,16 @@ fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
6162
);
6263
loop {
6364
match vm.call_method(&iter, "__next__", vec![]) {
64-
Ok(element) => {
65-
if &element == needle {
66-
return Ok(vm.new_bool(true));
67-
} else {
68-
continue;
65+
Ok(element) => match vm.call_method(needle, "__eq__", vec![element.clone()]) {
66+
Ok(value) => {
67+
if objbool::get_value(&value) {
68+
return Ok(vm.new_bool(true));
69+
} else {
70+
continue;
71+
}
6972
}
70-
}
73+
Err(_) => return Err(vm.new_type_error("".to_string())),
74+
},
7175
Err(_) => return Ok(vm.new_bool(false)),
7276
}
7377
}

vm/src/obj/objlist.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::super::pyobject::{
22
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol,
33
};
44
use super::super::vm::VirtualMachine;
5+
use super::objbool;
56
use super::objiter;
67
use super::objsequence::{seq_equal, PySliceableSequence};
78
use super::objstr;
@@ -162,16 +163,21 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
162163
}
163164
}
164165

165-
pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
166-
trace!("list.len called with: {:?}", args);
166+
fn list_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
167+
trace!("list.contains called with: {:?}", args);
167168
arg_check!(
168169
vm,
169170
args,
170-
required = [(list, Some(vm.ctx.list_type())), (x, None)]
171+
required = [(list, Some(vm.ctx.list_type())), (needle, None)]
171172
);
172173
for element in get_elements(list).iter() {
173-
if x == element {
174-
return Ok(vm.new_bool(true));
174+
match vm.call_method(needle, "__eq__", vec![element.clone()]) {
175+
Ok(value) => {
176+
if objbool::get_value(&value) {
177+
return Ok(vm.new_bool(true));
178+
}
179+
}
180+
Err(_) => return Err(vm.new_type_error("".to_string())),
175181
}
176182
}
177183

@@ -181,7 +187,7 @@ pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
181187
pub fn init(context: &PyContext) {
182188
let ref list_type = context.list_type;
183189
list_type.set_attr("__add__", context.new_rustfunc(list_add));
184-
list_type.set_attr("__contains__", context.new_rustfunc(contains));
190+
list_type.set_attr("__contains__", context.new_rustfunc(list_contains));
185191
list_type.set_attr("__eq__", context.new_rustfunc(list_eq));
186192
list_type.set_attr("__len__", context.new_rustfunc(list_len));
187193
list_type.set_attr("__new__", context.new_rustfunc(list_new));

vm/src/obj/objset.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::super::pyobject::{
77
PyResult, TypeProtocol,
88
};
99
use super::super::vm::VirtualMachine;
10+
use super::objbool;
1011
use super::objiter;
1112
use super::objstr;
1213
use super::objtype;
@@ -88,8 +89,29 @@ fn set_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
8889
Ok(vm.new_str(s))
8990
}
9091

92+
pub fn set_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
93+
arg_check!(
94+
vm,
95+
args,
96+
required = [(set, Some(vm.ctx.set_type())), (needle, None)]
97+
);
98+
for element in get_elements(set).iter() {
99+
match vm.call_method(needle, "__eq__", vec![element.1.clone()]) {
100+
Ok(value) => {
101+
if objbool::get_value(&value) {
102+
return Ok(vm.new_bool(true));
103+
}
104+
}
105+
Err(_) => return Err(vm.new_type_error("".to_string())),
106+
}
107+
}
108+
109+
Ok(vm.new_bool(false))
110+
}
111+
91112
pub fn init(context: &PyContext) {
92113
let ref set_type = context.set_type;
114+
set_type.set_attr("__contains__", context.new_rustfunc(set_contains));
93115
set_type.set_attr("__len__", context.new_rustfunc(set_len));
94116
set_type.set_attr("__new__", context.new_rustfunc(set_new));
95117
set_type.set_attr("__repr__", context.new_rustfunc(set_repr));

vm/src/obj/objtuple.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::super::pyobject::{
22
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol,
33
};
44
use super::super::vm::VirtualMachine;
5+
use super::objbool;
56
use super::objsequence::seq_equal;
67
use super::objstr;
78
use super::objtype;
@@ -54,11 +55,16 @@ pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
5455
arg_check!(
5556
vm,
5657
args,
57-
required = [(tuple, Some(vm.ctx.tuple_type())), (x, None)]
58+
required = [(tuple, Some(vm.ctx.tuple_type())), (needle, None)]
5859
);
5960
for element in get_elements(tuple).iter() {
60-
if x == element {
61-
return Ok(vm.new_bool(true));
61+
match vm.call_method(needle, "__eq__", vec![element.clone()]) {
62+
Ok(value) => {
63+
if objbool::get_value(&value) {
64+
return Ok(vm.new_bool(true));
65+
}
66+
}
67+
Err(_) => return Err(vm.new_type_error("".to_string())),
6268
}
6369
}
6470

0 commit comments

Comments
 (0)