Skip to content

Commit 7a72db2

Browse files
committed
Fix set in-place operators with self argument
1 parent 9e43966 commit 7a72db2

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

crates/vm/src/builtins/set.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -730,8 +730,10 @@ impl PySet {
730730

731731
#[pymethod]
732732
fn __iand__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
733-
zelf.inner
734-
.intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?;
733+
if !set.is_same_object(zelf.as_object()) {
734+
zelf.inner
735+
.intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?;
736+
}
735737
Ok(zelf)
736738
}
737739

@@ -743,8 +745,12 @@ impl PySet {
743745

744746
#[pymethod]
745747
fn __isub__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
746-
zelf.inner
747-
.difference_update(set.into_iterable_iter(vm)?, vm)?;
748+
if set.is_same_object(zelf.as_object()) {
749+
zelf.inner.clear();
750+
} else {
751+
zelf.inner
752+
.difference_update(set.into_iterable_iter(vm)?, vm)?;
753+
}
748754
Ok(zelf)
749755
}
750756

@@ -761,8 +767,12 @@ impl PySet {
761767

762768
#[pymethod]
763769
fn __ixor__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
764-
zelf.inner
765-
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
770+
if set.is_same_object(zelf.as_object()) {
771+
zelf.inner.clear();
772+
} else {
773+
zelf.inner
774+
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
775+
}
766776
Ok(zelf)
767777
}
768778

@@ -1251,6 +1261,10 @@ struct AnySet {
12511261
}
12521262

12531263
impl AnySet {
1264+
fn is_same_object(&self, other: &PyObject) -> bool {
1265+
self.object.is(other)
1266+
}
1267+
12541268
fn into_iterable(self, vm: &VirtualMachine) -> PyResult<ArgIterable> {
12551269
self.object.try_into_value(vm)
12561270
}

extra_tests/snippets/builtin_set.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,18 @@ class S(set):
200200
with assert_raises(TypeError):
201201
a &= [1, 2, 3]
202202

203+
a = set([1, 2, 3])
204+
a &= a
205+
assert a == set([1, 2, 3])
206+
207+
a = set([1, 2, 3])
208+
a -= a
209+
assert a == set()
210+
211+
a = set([1, 2, 3])
212+
a ^= a
213+
assert a == set()
214+
203215
a = set([1, 2, 3])
204216
a.difference_update([3, 4, 5])
205217
assert a == set([1, 2])

0 commit comments

Comments
 (0)