Skip to content

Commit ae00d33

Browse files
committed
Implemented set.remove
1 parent 4a08067 commit ae00d33

6 files changed

Lines changed: 110 additions & 10 deletions

File tree

py/map.c

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,20 +147,27 @@ static void mp_set_rehash(mp_set_t *set) {
147147
}
148148

149149
mp_obj_t mp_set_lookup(mp_set_t *set, mp_obj_t index, mp_map_lookup_kind_t lookup_kind) {
150-
int hash = mp_obj_hash(index);
150+
int hash;
151+
int pos;
151152
if (set->alloc == 0) {
152-
if (lookup_kind == MP_MAP_LOOKUP_ADD_IF_NOT_FOUND) {
153+
if (lookup_kind & MP_MAP_LOOKUP_ADD_IF_NOT_FOUND) {
153154
mp_set_rehash(set);
154155
} else {
155156
return NULL;
156157
}
157158
}
158-
int pos = hash % set->alloc;
159+
if (lookup_kind & MP_MAP_LOOKUP_FIRST) {
160+
hash = 0;
161+
pos = 0;
162+
} else {
163+
hash = mp_obj_hash(index);;
164+
pos = hash % set->alloc;
165+
}
159166
for (;;) {
160167
mp_obj_t elem = set->table[pos];
161168
if (elem == MP_OBJ_NULL) {
162169
// not in table
163-
if (lookup_kind == MP_MAP_LOOKUP_ADD_IF_NOT_FOUND) {
170+
if (lookup_kind & MP_MAP_LOOKUP_ADD_IF_NOT_FOUND) {
164171
if (set->used + 1 >= set->alloc) {
165172
// not enough room in table, rehash it
166173
mp_set_rehash(set);
@@ -171,15 +178,16 @@ mp_obj_t mp_set_lookup(mp_set_t *set, mp_obj_t index, mp_map_lookup_kind_t looku
171178
set->table[pos] = index;
172179
return index;
173180
}
181+
} else if (lookup_kind & MP_MAP_LOOKUP_FIRST) {
182+
pos++;
174183
} else {
175184
return MP_OBJ_NULL;
176185
}
177-
} else if (mp_obj_equal(elem, index)) {
186+
} else if (lookup_kind & MP_MAP_LOOKUP_FIRST || mp_obj_equal(elem, index)) {
178187
// found it
179-
if (lookup_kind == MP_MAP_LOOKUP_REMOVE_IF_FOUND) {
188+
if (lookup_kind & MP_MAP_LOOKUP_REMOVE_IF_FOUND) {
180189
set->used--;
181190
set->table[pos] = NULL;
182-
return elem;
183191
}
184192
return elem;
185193
} else {

py/map.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ typedef struct _mp_set_t {
1919
} mp_set_t;
2020

2121
typedef enum _mp_map_lookup_kind_t {
22-
MP_MAP_LOOKUP,
23-
MP_MAP_LOOKUP_ADD_IF_NOT_FOUND,
24-
MP_MAP_LOOKUP_REMOVE_IF_FOUND,
22+
MP_MAP_LOOKUP, // 0
23+
MP_MAP_LOOKUP_ADD_IF_NOT_FOUND, // 1
24+
MP_MAP_LOOKUP_REMOVE_IF_FOUND, // 2
25+
MP_MAP_LOOKUP_FIRST = 4,
2526
} mp_map_lookup_kind_t;
2627

2728
int get_doubling_prime_greater_or_equal_to(int x);

py/objset.c

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,70 @@ static mp_obj_t set_isdisjoint(mp_obj_t self_in, mp_obj_t other) {
227227
}
228228
static MP_DEFINE_CONST_FUN_OBJ_2(set_isdisjoint_obj, set_isdisjoint);
229229

230+
static mp_obj_t set_issubset(mp_obj_t self_in, mp_obj_t other_in) {
231+
mp_obj_set_t *self;
232+
bool cleanup_self = false;
233+
if (MP_OBJ_IS_TYPE(self_in, &set_type)) {
234+
self = self_in;
235+
} else {
236+
self = set_make_new(NULL, 1, &self_in);
237+
cleanup_self = true;
238+
}
239+
240+
mp_obj_set_t *other;
241+
bool cleanup_other = false;
242+
if (MP_OBJ_IS_TYPE(other_in, &set_type)) {
243+
other = other_in;
244+
} else {
245+
other = set_make_new(NULL, 1, &other_in);
246+
cleanup_other = true;
247+
}
248+
mp_obj_t iter = set_getiter(self);
249+
mp_obj_t next;
250+
mp_obj_t out = mp_const_true;
251+
while ((next = set_it_iternext(iter)) != mp_const_stop_iteration) {
252+
if (!mp_set_lookup(&other->set, next, MP_MAP_LOOKUP)) {
253+
out = mp_const_false;
254+
break;
255+
}
256+
}
257+
if (cleanup_self) {
258+
set_clear(self);
259+
}
260+
if (cleanup_other) {
261+
set_clear(other);
262+
}
263+
return out;
264+
}
265+
static MP_DEFINE_CONST_FUN_OBJ_2(set_issubset_obj, set_issubset);
266+
267+
static mp_obj_t set_issuperset(mp_obj_t self_in, mp_obj_t other_in) {
268+
return set_issubset(other_in, self_in);
269+
}
270+
static MP_DEFINE_CONST_FUN_OBJ_2(set_issuperset_obj, set_issuperset);
271+
272+
static mp_obj_t set_pop(mp_obj_t self_in) {
273+
assert(MP_OBJ_IS_TYPE(self_in, &set_type));
274+
mp_obj_set_t *self = self_in;
275+
276+
if (self->set.used == 0) {
277+
nlr_jump(mp_obj_new_exception_msg(MP_QSTR_KeyError, "pop from an empty set"));
278+
}
279+
mp_obj_t obj = mp_set_lookup(&self->set, NULL,
280+
MP_MAP_LOOKUP_REMOVE_IF_FOUND | MP_MAP_LOOKUP_FIRST);
281+
return obj;
282+
}
283+
static MP_DEFINE_CONST_FUN_OBJ_1(set_pop_obj, set_pop);
284+
285+
static mp_obj_t set_remove(mp_obj_t self_in, mp_obj_t item) {
286+
assert(MP_OBJ_IS_TYPE(self_in, &set_type));
287+
mp_obj_set_t *self = self_in;
288+
if (mp_set_lookup(&self->set, item, MP_MAP_LOOKUP_REMOVE_IF_FOUND) == MP_OBJ_NULL) {
289+
nlr_jump(mp_obj_new_exception(MP_QSTR_KeyError));
290+
}
291+
return mp_const_none;
292+
}
293+
static MP_DEFINE_CONST_FUN_OBJ_2(set_remove_obj, set_remove);
230294

231295
/******************************************************************************/
232296
/* set constructors & public C API */
@@ -242,6 +306,10 @@ static const mp_method_t set_type_methods[] = {
242306
{ "intersection", &set_intersect_obj },
243307
{ "intersection_update", &set_intersect_update_obj },
244308
{ "isdisjoint", &set_isdisjoint_obj },
309+
{ "issubset", &set_issubset_obj },
310+
{ "issuperset", &set_issuperset_obj },
311+
{ "pop", &set_pop_obj },
312+
{ "remove", &set_remove_obj },
245313
{ NULL, NULL }, // end-of-list sentinel
246314
};
247315

tests/basics/tests/set_isfooset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
sets = [set(), {1}, {1, 2, 3}, {3, 4, 5}, {5, 6, 7}]
2+
for i in sets:
3+
for j in sets:
4+
print(i.issubset(j))
5+
print(i.issuperset(j))

tests/basics/tests/set_pop.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
s = {1}
2+
print(s.pop())
3+
try:
4+
print(s.pop(), "!!!")
5+
except KeyError:
6+
pass
7+
else:
8+
print("Failed to raise KeyError")
9+

tests/basics/tests/set_remove.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
s = {1}
2+
print(s.remove(1))
3+
print(list(s))
4+
try:
5+
print(s.remove(1), "!!!")
6+
except KeyError:
7+
pass
8+
else:
9+
print("failed to raise KeyError")

0 commit comments

Comments
 (0)