Skip to content

Commit 979e125

Browse files
committed
Fix set/frozenset comparison
1 parent 31c8872 commit 979e125

2 files changed

Lines changed: 73 additions & 45 deletions

File tree

tests/snippets/set.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@
2727
assert not set([1,2]) < set([1,2])
2828
assert not set([1,3]) < set([1,2])
2929

30+
assert (set() == []) is False
31+
assert set().__eq__([]) == NotImplemented
32+
assert_raises(TypeError, lambda: set() < [], "'<' not supported between instances of 'set' and 'list'")
33+
assert_raises(TypeError, lambda: set() <= [], "'<=' not supported between instances of 'set' and 'list'")
34+
assert_raises(TypeError, lambda: set() > [], "'>' not supported between instances of 'set' and 'list'")
35+
assert_raises(TypeError, lambda: set() >= [], "'>=' not supported between instances of 'set' and 'list'")
36+
assert set().issuperset([])
37+
assert set().issubset([])
38+
assert not set().issuperset([1, 2, 3])
39+
assert set().issubset([1, 2])
40+
41+
assert (set() == 3) is False
42+
assert set().__eq__(3) == NotImplemented
43+
assert_raises(TypeError, lambda: set() < 3, "'int' object is not iterable")
44+
assert_raises(TypeError, lambda: set() <= 3, "'int' object is not iterable")
45+
assert_raises(TypeError, lambda: set() > 3, "'int' object is not iterable")
46+
assert_raises(TypeError, lambda: set() >= 3, "'int' object is not iterable")
47+
assert_raises(TypeError, lambda: set().issuperset(3), "'int' object is not iterable")
48+
assert_raises(TypeError, lambda: set().issubset(3), "'int' object is not iterable")
49+
3050
class Hashable(object):
3151
def __init__(self, obj):
3252
self.obj = obj

vm/src/obj/objset.rs

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,22 @@ struct PySetInner {
7373
}
7474

7575
impl PySetInner {
76-
fn new(iterable: OptionalArg<PyIterable>, vm: &VirtualMachine) -> PyResult<PySetInner> {
76+
fn new(iterable: PyIterable, vm: &VirtualMachine) -> PyResult<PySetInner> {
7777
let mut set = PySetInner::default();
78-
if let OptionalArg::Present(iterable) = iterable {
79-
for item in iterable.iter(vm)? {
80-
set.add(&item?, vm)?;
81-
}
78+
for item in iterable.iter(vm)? {
79+
set.add(&item?, vm)?;
8280
}
8381
Ok(set)
8482
}
8583

84+
fn from_arg(iterable: OptionalArg<PyIterable>, vm: &VirtualMachine) -> PyResult<PySetInner> {
85+
if let OptionalArg::Present(iterable) = iterable {
86+
Self::new(iterable, vm)
87+
} else {
88+
Ok(PySetInner::default())
89+
}
90+
}
91+
8692
fn len(&self) -> usize {
8793
self.content.len()
8894
}
@@ -97,33 +103,21 @@ impl PySetInner {
97103
self.content.contains(vm, needle)
98104
}
99105

106+
#[inline]
100107
fn _compare_inner(
101108
&self,
102109
other: &PySetInner,
103110
size_func: &Fn(usize, usize) -> bool,
104111
swap: bool,
105112
vm: &VirtualMachine,
106113
) -> PyResult {
107-
let get_zelf = |swap: bool| -> &PySetInner {
108-
if swap {
109-
other
110-
} else {
111-
self
112-
}
113-
};
114-
let get_other = |swap: bool| -> &PySetInner {
115-
if swap {
116-
self
117-
} else {
118-
other
119-
}
120-
};
114+
let (zelf, other) = if swap { (other, self) } else { (self, other) };
121115

122-
if size_func(get_zelf(swap).len(), get_other(swap).len()) {
116+
if size_func(zelf.len(), other.len()) {
123117
return Ok(vm.new_bool(false));
124118
}
125-
for key in get_other(swap).content.keys() {
126-
if !get_zelf(swap).contains(&key, vm)? {
119+
for key in other.content.keys() {
120+
if !zelf.contains(&key, vm)? {
127121
return Ok(vm.new_bool(false));
128122
}
129123
}
@@ -213,6 +207,20 @@ impl PySetInner {
213207
Ok(new_inner)
214208
}
215209

210+
fn issuperset(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<bool> {
211+
for item in other.iter(vm)? {
212+
if !self.contains(&item?, vm)? {
213+
return Ok(false);
214+
}
215+
}
216+
Ok(true)
217+
}
218+
219+
fn issubset(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult {
220+
let other_set = PySetInner::new(other, vm)?;
221+
self.le(&other_set, vm)
222+
}
223+
216224
fn isdisjoint(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<bool> {
217225
for item in other.iter(vm)? {
218226
if self.contains(&item?, vm)? {
@@ -308,7 +316,7 @@ macro_rules! try_set_inner {
308316
match_class!($other,
309317
set @ PySet => $op(&*set.inner.borrow()),
310318
frozen @ PyFrozenSet => $op(&frozen.inner),
311-
other => Err($vm.new_type_error(format!("{} is not a subtype of set or frozenset", other.class()))),
319+
_ => Ok($vm.ctx.not_implemented()),
312320
);
313321
};
314322
}
@@ -322,7 +330,7 @@ impl PySet {
322330
vm: &VirtualMachine,
323331
) -> PyResult<PySetRef> {
324332
PySet {
325-
inner: RefCell::new(PySetInner::new(iterable, vm)?),
333+
inner: RefCell::new(PySetInner::from_arg(iterable, vm)?),
326334
}
327335
.into_ref_with_type(vm, cls)
328336
}
@@ -413,6 +421,16 @@ impl PySet {
413421
))
414422
}
415423

424+
#[pymethod]
425+
fn issubset(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult {
426+
self.inner.borrow().issubset(other, vm)
427+
}
428+
429+
#[pymethod]
430+
fn issuperset(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<bool> {
431+
self.inner.borrow().issuperset(other, vm)
432+
}
433+
416434
#[pymethod]
417435
fn isdisjoint(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<bool> {
418436
self.inner.borrow().isdisjoint(other, vm)
@@ -539,16 +557,6 @@ impl PySet {
539557
Ok(zelf.as_object().clone())
540558
}
541559

542-
#[pymethod]
543-
fn issubset(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
544-
self.le(other, vm)
545-
}
546-
547-
#[pymethod]
548-
fn issuperset(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
549-
self.ge(other, vm)
550-
}
551-
552560
#[pymethod(name = "__hash__")]
553561
fn hash(&self, vm: &VirtualMachine) -> PyResult<()> {
554562
Err(vm.new_type_error("unhashable type".to_string()))
@@ -564,7 +572,7 @@ impl PyFrozenSet {
564572
vm: &VirtualMachine,
565573
) -> PyResult<PyFrozenSetRef> {
566574
PyFrozenSet {
567-
inner: PySetInner::new(iterable, vm)?,
575+
inner: PySetInner::from_arg(iterable, vm)?,
568576
}
569577
.into_ref_with_type(vm, cls)
570578
}
@@ -655,6 +663,16 @@ impl PyFrozenSet {
655663
))
656664
}
657665

666+
#[pymethod]
667+
fn issubset(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult {
668+
self.inner.issubset(other, vm)
669+
}
670+
671+
#[pymethod]
672+
fn issuperset(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<bool> {
673+
self.inner.issuperset(other, vm)
674+
}
675+
658676
#[pymethod]
659677
fn isdisjoint(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<bool> {
660678
self.inner.isdisjoint(other, vm)
@@ -697,16 +715,6 @@ impl PyFrozenSet {
697715
};
698716
Ok(vm.new_str(s))
699717
}
700-
701-
#[pymethod]
702-
fn issubset(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
703-
self.le(other, vm)
704-
}
705-
706-
#[pymethod]
707-
fn issuperset(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
708-
self.ge(other, vm)
709-
}
710718
}
711719

712720
struct SetIterable {

0 commit comments

Comments
 (0)