Skip to content

Commit 36bace4

Browse files
committed
Fix set in-place operators with self argument
1 parent 4fbf617 commit 36bace4

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

crates/vm/src/builtins/set.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -719,8 +719,10 @@ impl PySet {
719719
}
720720

721721
fn __iand__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
722-
zelf.inner
723-
.intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?;
722+
if !set.is_same_object(zelf.as_object()) {
723+
zelf.inner
724+
.intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?;
725+
}
724726
Ok(zelf)
725727
}
726728

@@ -731,8 +733,12 @@ impl PySet {
731733
}
732734

733735
fn __isub__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
734-
zelf.inner
735-
.difference_update(set.into_iterable_iter(vm)?, vm)?;
736+
if set.is_same_object(zelf.as_object()) {
737+
zelf.inner.clear();
738+
} else {
739+
zelf.inner
740+
.difference_update(set.into_iterable_iter(vm)?, vm)?;
741+
}
736742
Ok(zelf)
737743
}
738744

@@ -748,8 +754,12 @@ impl PySet {
748754
}
749755

750756
fn __ixor__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
751-
zelf.inner
752-
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
757+
if set.is_same_object(zelf.as_object()) {
758+
zelf.inner.clear();
759+
} else {
760+
zelf.inner
761+
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
762+
}
753763
Ok(zelf)
754764
}
755765

@@ -1305,6 +1315,10 @@ impl AnySet {
13051315
obj.fast_isinstance(ctx.types.set_type) || obj.fast_isinstance(ctx.types.frozenset_type)
13061316
}
13071317

1318+
fn is_same_object(&self, other: &PyObject) -> bool {
1319+
self.object.is(other)
1320+
}
1321+
13081322
fn into_iterable(self, vm: &VirtualMachine) -> PyResult<ArgIterable> {
13091323
self.object.try_into_value(vm)
13101324
}

extra_tests/snippets/builtin_set.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,18 @@ class S(set):
200200
with assert_raises(TypeError):
201201
a &= [1, 2, 3]
202202

203+
a = set([1, 2, 3])
204+
a &= a
205+
assert a == set([1, 2, 3])
206+
207+
a = set([1, 2, 3])
208+
a -= a
209+
assert a == set()
210+
211+
a = set([1, 2, 3])
212+
a ^= a
213+
assert a == set()
214+
203215
a = set([1, 2, 3])
204216
a.difference_update([3, 4, 5])
205217
assert a == set([1, 2])

0 commit comments

Comments
 (0)