Skip to content

Commit a91470d

Browse files
authored
Merge pull request #1171 from mpajkowski/iter_str
Implement str iterator support
2 parents ea71428 + 1f59532 commit a91470d

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

tests/snippets/strings.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,33 @@ def try_mutate_str():
263263
assert "\u00BE" == "¾"
264264
assert "\u9487" == "钇"
265265
assert "\U0001F609" == "😉"
266+
267+
# test str iter
268+
iterable_str = "123456789"
269+
str_iter = iter(iterable_str)
270+
271+
assert next(str_iter) == "1"
272+
assert next(str_iter) == "2"
273+
assert next(str_iter) == "3"
274+
assert next(str_iter) == "4"
275+
assert next(str_iter) == "5"
276+
assert next(str_iter) == "6"
277+
assert next(str_iter) == "7"
278+
assert next(str_iter) == "8"
279+
assert next(str_iter) == "9"
280+
assert next(str_iter, None) == None
281+
assert_raises(StopIteration, lambda: next(str_iter))
282+
283+
str_iter_reversed = reversed(iterable_str)
284+
285+
assert next(str_iter_reversed) == "9"
286+
assert next(str_iter_reversed) == "8"
287+
assert next(str_iter_reversed) == "7"
288+
assert next(str_iter_reversed) == "6"
289+
assert next(str_iter_reversed) == "5"
290+
assert next(str_iter_reversed) == "4"
291+
assert next(str_iter_reversed) == "3"
292+
assert next(str_iter_reversed) == "2"
293+
assert next(str_iter_reversed) == "1"
294+
assert next(str_iter_reversed, None) == None
295+
assert_raises(StopIteration, lambda: next(str_iter_reversed))

vm/src/obj/objstr.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
extern crate unicode_categories;
22
extern crate unicode_xid;
33

4+
use std::cell::Cell;
45
use std::char;
56
use std::fmt;
67
use std::ops::Range;
@@ -28,6 +29,7 @@ use crate::vm::VirtualMachine;
2829
use super::objbytes::PyBytes;
2930
use super::objdict::PyDict;
3031
use super::objint::{self, PyInt};
32+
use super::objiter;
3133
use super::objnone::PyNone;
3234
use super::objsequence::PySliceableSequence;
3335
use super::objslice::PySlice;
@@ -90,6 +92,79 @@ impl TryIntoRef<PyString> for &str {
9092
}
9193
}
9294

95+
#[pyclass]
96+
#[derive(Debug)]
97+
pub struct PyStringIterator {
98+
pub string: PyStringRef,
99+
position: Cell<usize>,
100+
}
101+
102+
impl PyValue for PyStringIterator {
103+
fn class(vm: &VirtualMachine) -> PyClassRef {
104+
vm.ctx.striterator_type()
105+
}
106+
}
107+
108+
#[pyimpl]
109+
impl PyStringIterator {
110+
#[pymethod(name = "__next__")]
111+
fn next(&self, vm: &VirtualMachine) -> PyResult {
112+
let pos = self.position.get();
113+
114+
if pos < self.string.value.chars().count() {
115+
self.position.set(self.position.get() + 1);
116+
117+
#[allow(clippy::range_plus_one)]
118+
let value = self.string.value.do_slice(pos..pos + 1);
119+
120+
value.into_pyobject(vm)
121+
} else {
122+
Err(objiter::new_stop_iteration(vm))
123+
}
124+
}
125+
126+
#[pymethod(name = "__iter__")]
127+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
128+
zelf
129+
}
130+
}
131+
132+
#[pyclass]
133+
#[derive(Debug)]
134+
pub struct PyStringReverseIterator {
135+
pub position: Cell<usize>,
136+
pub string: PyStringRef,
137+
}
138+
139+
impl PyValue for PyStringReverseIterator {
140+
fn class(vm: &VirtualMachine) -> PyClassRef {
141+
vm.ctx.strreverseiterator_type()
142+
}
143+
}
144+
145+
#[pyimpl]
146+
impl PyStringReverseIterator {
147+
#[pymethod(name = "__next__")]
148+
fn next(&self, vm: &VirtualMachine) -> PyResult {
149+
if self.position.get() > 0 {
150+
let position: usize = self.position.get() - 1;
151+
152+
#[allow(clippy::range_plus_one)]
153+
let value = self.string.value.do_slice(position..position + 1);
154+
155+
self.position.set(position);
156+
value.into_pyobject(vm)
157+
} else {
158+
Err(objiter::new_stop_iteration(vm))
159+
}
160+
}
161+
162+
#[pymethod(name = "__iter__")]
163+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
164+
zelf
165+
}
166+
}
167+
93168
#[pyimpl]
94169
impl PyString {
95170
// TODO: should with following format
@@ -1025,6 +1100,24 @@ impl PyString {
10251100
let encoded = PyBytes::from_string(&self.value, &encoding, vm)?;
10261101
Ok(encoded.into_pyobject(vm)?)
10271102
}
1103+
1104+
#[pymethod(name = "__iter__")]
1105+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyStringIterator {
1106+
PyStringIterator {
1107+
position: Cell::new(0),
1108+
string: zelf,
1109+
}
1110+
}
1111+
1112+
#[pymethod(name = "__reversed__")]
1113+
fn reversed(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyStringReverseIterator {
1114+
let begin = zelf.value.chars().count();
1115+
1116+
PyStringReverseIterator {
1117+
position: Cell::new(begin),
1118+
string: zelf,
1119+
}
1120+
}
10281121
}
10291122

10301123
impl PyValue for PyString {
@@ -1053,6 +1146,9 @@ impl IntoPyObject for &String {
10531146

10541147
pub fn init(ctx: &PyContext) {
10551148
PyString::extend_class(ctx, &ctx.str_type);
1149+
1150+
PyStringIterator::extend_class(ctx, &ctx.striterator_type);
1151+
PyStringReverseIterator::extend_class(ctx, &ctx.strreverseiterator_type);
10561152
}
10571153

10581154
pub fn get_value(obj: &PyObjectRef) -> String {

vm/src/pyobject.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ pub struct PyContext {
135135
pub list_type: PyClassRef,
136136
pub listiterator_type: PyClassRef,
137137
pub listreverseiterator_type: PyClassRef,
138+
pub striterator_type: PyClassRef,
139+
pub strreverseiterator_type: PyClassRef,
138140
pub dictkeyiterator_type: PyClassRef,
139141
pub dictvalueiterator_type: PyClassRef,
140142
pub dictitemiterator_type: PyClassRef,
@@ -274,6 +276,8 @@ impl PyContext {
274276
let listiterator_type = create_type("list_iterator", &type_type, &object_type);
275277
let listreverseiterator_type =
276278
create_type("list_reverseiterator", &type_type, &object_type);
279+
let striterator_type = create_type("str_iterator", &type_type, &object_type);
280+
let strreverseiterator_type = create_type("str_reverseiterator", &type_type, &object_type);
277281
let dictkeys_type = create_type("dict_keys", &type_type, &object_type);
278282
let dictvalues_type = create_type("dict_values", &type_type, &object_type);
279283
let dictitems_type = create_type("dict_items", &type_type, &object_type);
@@ -341,6 +345,8 @@ impl PyContext {
341345
list_type,
342346
listiterator_type,
343347
listreverseiterator_type,
348+
striterator_type,
349+
strreverseiterator_type,
344350
dictkeys_type,
345351
dictvalues_type,
346352
dictitems_type,
@@ -476,6 +482,14 @@ impl PyContext {
476482
self.listreverseiterator_type.clone()
477483
}
478484

485+
pub fn striterator_type(&self) -> PyClassRef {
486+
self.striterator_type.clone()
487+
}
488+
489+
pub fn strreverseiterator_type(&self) -> PyClassRef {
490+
self.strreverseiterator_type.clone()
491+
}
492+
479493
pub fn module_type(&self) -> PyClassRef {
480494
self.module_type.clone()
481495
}

0 commit comments

Comments
 (0)