Skip to content

Commit c1ae29e

Browse files
committed
Fancy args structs for codec functions
1 parent e21a447 commit c1ae29e

File tree

2 files changed

+76
-21
lines changed

2 files changed

+76
-21
lines changed

common/src/encodings.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
use std::ops::Range;
22

3+
// Shorthand for the return type of the `ErrorHandler` encode-error callback: on success an
// `EncodeReplace<S, B>` paired with a `usize`, otherwise the handler's error type `E`.
// NOTE(review): the `usize` is presumably the position to resume from -- confirm.
pub type EncodeErrorResult<S, B, E> = Result<(EncodeReplace<S, B>, usize), E>;
4+
5+
// Shorthand for the return type of `ErrorHandler::handle_decode_error`: on success a string
// buffer `S`, an optional bytes buffer `B` and a `usize`, otherwise the error type `E`.
pub type DecodeErrorResult<S, B, E> = Result<(S, Option<B>, usize), E>;
6+
37
pub trait ErrorHandler {
48
type Error;
59
type StrBuf: AsRef<str>;
@@ -8,13 +12,13 @@ pub trait ErrorHandler {
812
&self,
913
byte_range: Range<usize>,
1014
reason: &str,
11-
) -> Result<(EncodeReplace<Self::StrBuf, Self::BytesBuf>, usize), Self::Error>;
15+
) -> EncodeErrorResult<Self::StrBuf, Self::BytesBuf, Self::Error>;
1216
fn handle_decode_error(
1317
&self,
1418
data: &[u8],
1519
byte_range: Range<usize>,
1620
reason: &str,
17-
) -> Result<(Self::StrBuf, Option<Self::BytesBuf>, usize), Self::Error>;
21+
) -> DecodeErrorResult<Self::StrBuf, Self::BytesBuf, Self::Error>;
1822
fn error_oob_restart(&self, i: usize) -> Self::Error;
1923
}
2024
pub enum EncodeReplace<S, B> {
@@ -25,6 +29,9 @@ pub enum EncodeReplace<S, B> {
2529
pub mod utf8 {
2630
use super::*;
2731

32+
// Name of this codec as a string; the `do_codec!` macro in `vm/src/stdlib/codecs.rs` reads it
// as `codec::ENCODING_NAME` and forwards it as the encoding name.
pub const ENCODING_NAME: &str = "utf-8";
33+
34+
// Encodes `s` by copying its bytes into a freshly allocated `Vec<u8>`.
// The error handler `_errors` is never consulted and the result is always `Ok`.
#[inline]
2835
pub fn encode<E: ErrorHandler>(s: &str, _errors: &E) -> Result<Vec<u8>, E::Error> {
2936
Ok(s.as_bytes().to_vec())
3037
}

vm/src/stdlib/codecs.rs

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ mod _codecs {
77
use crate::builtins::{PyBytesRef, PyStr, PyStrRef, PyTuple};
88
use crate::byteslike::PyBytesLike;
99
use crate::codecs;
10-
use crate::common::encodings::{self, utf8};
10+
use crate::common::encodings;
1111
use crate::exceptions::PyBaseExceptionRef;
12-
use crate::function::{FuncArgs, OptionalArg};
12+
use crate::function::FuncArgs;
1313
use crate::VirtualMachine;
1414
use crate::{IdProtocol, PyObjectRef, PyResult, TryFromObject};
1515

@@ -78,24 +78,25 @@ mod _codecs {
7878
// Bundles what is needed to resolve a codec error handler: the VM, the encoding name, the
// optional `errors` argument, and a `OnceCell` that caches the looked-up handler object.
struct ErrorsHandler<'a> {
7979
vm: &'a VirtualMachine,
8080
encoding: &'a str,
81-
// (removed line) `errors` used to be a borrowed `Option<&'a PyStrRef>`.
errors: Option<&'a PyStrRef>,
81+
// (added line) `errors` is now an owned `Option<PyStrRef>`.
errors: Option<PyStrRef>,
8282
// Filled on first use by `handler_func`.
handler: once_cell::unsync::OnceCell<PyObjectRef>,
8383
}
8484
impl<'a> ErrorsHandler<'a> {
85-
// (removed line) previous signature, taking `errors` by reference.
fn new(encoding: &'a str, errors: Option<&'a PyStrRef>, vm: &'a VirtualMachine) -> Self {
85+
#[inline]
86+
// (added line) stores the arguments as given and starts with an empty `handler` cache;
// no registry lookup happens here.
fn new(encoding: &'a str, errors: Option<PyStrRef>, vm: &'a VirtualMachine) -> Self {
8687
ErrorsHandler {
8788
vm,
8889
encoding,
8990
errors,
9091
handler: Default::default(),
9192
}
9293
}
94+
#[inline]
9395
// Returns the error-handler object, initialising the `handler` cell on first call via
// `vm.state.codec_registry.lookup_error`. When `errors` is `None` the name "strict" is used.
// A failed lookup is propagated through the `PyResult`.
fn handler_func(&self) -> PyResult<&PyObjectRef> {
9496
let vm = self.vm;
9597
self.handler.get_or_try_init(|| {
96-
vm.state
97-
.codec_registry
98-
.lookup_error(self.errors.map_or("strict", |s| s.as_ref()), vm)
98+
// (added lines) `as_ref()` borrows the now-owned `errors` before mapping it to a `&str`.
let errors = self.errors.as_ref().map_or("strict", |s| s.as_str());
99+
vm.state.codec_registry.lookup_error(errors, vm)
99100
})
100101
}
101102
}
@@ -109,8 +110,8 @@ mod _codecs {
109110
_byte_range: Range<usize>,
110111
_reason: &str,
111112
) -> PyResult<(encodings::EncodeReplace<PyStrRef, PyBytesRef>, usize)> {
112-
// we don't use common::encodings to encode anything yet, so this can't
113-
// get called until we do
113+
// we don't use common::encodings to really encode anything yet (utf8 can't error
114+
// because PyStr is always utf8), so this can't get called until we do
114115
todo!()
115116
}
116117

@@ -174,20 +175,67 @@ mod _codecs {
174175
}
175176
}
176177

178+
// Return type of the encode entry points: the encoded bytes plus a `usize`
// (`EncodeArgs::encode` fills the latter with `s.char_len()`).
type EncodeResult = PyResult<(Vec<u8>, usize)>;
179+
180+
// Argument bundle for codec encode functions, extracted through `FromArgs`.
#[derive(FromArgs)]
181+
struct EncodeArgs {
182+
// The string to encode (required, positional).
#[pyarg(positional)]
183+
s: PyStrRef,
184+
// Name of the error handler (optional, positional); `None` when not supplied.
#[pyarg(positional, optional)]
185+
errors: Option<PyStrRef>,
186+
}
187+
188+
impl EncodeArgs {
189+
// Runs the codec function `encode` over `self.s`, handing it an `ErrorsHandler` built from
// `name`, the optional `errors` argument and `vm`. On success returns the encoded bytes
// together with `self.s.char_len()`; an error from `encode` is propagated with `?`.
#[inline]
190+
fn encode<'a, F>(self, name: &'a str, encode: F, vm: &'a VirtualMachine) -> EncodeResult
191+
where
192+
F: FnOnce(&str, &ErrorsHandler<'a>) -> PyResult<Vec<u8>>,
193+
{
194+
// `self.errors` is moved into the handler; `self.s` stays usable below.
let errors = ErrorsHandler::new(name, self.errors, vm);
195+
let encoded = encode(self.s.as_str(), &errors)?;
196+
Ok((encoded, self.s.char_len()))
197+
}
198+
}
199+
200+
// Return type of the decode entry points: the decoded `String` plus a `usize`.
// NOTE(review): the `usize` is presumably the number of input bytes consumed -- confirm.
type DecodeResult = PyResult<(String, usize)>;
201+
202+
// Argument bundle for codec decode functions, extracted through `FromArgs`.
#[derive(FromArgs)]
203+
struct DecodeArgs {
204+
// The bytes-like input to decode (required, positional).
#[pyarg(positional)]
205+
data: PyBytesLike,
206+
// Name of the error handler (optional, positional); `None` when not supplied.
#[pyarg(positional, optional)]
207+
errors: Option<PyStrRef>,
208+
// Positional flag declared with a default of `false`; forwarded unchanged to the decoder.
#[pyarg(positional, default = "false")]
209+
final_decode: bool,
210+
}
211+
212+
impl DecodeArgs {
213+
// Runs the codec function `decode` over the input buffer, handing it an `ErrorsHandler`
// built from `name`, the optional `errors` argument and `vm`, plus the `final_decode` flag.
// Whatever `decode` returns is returned unchanged.
#[inline]
214+
fn decode<'a, F>(self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult
215+
where
216+
F: FnOnce(&[u8], &ErrorsHandler<'a>, bool) -> DecodeResult,
217+
{
218+
// `data` is obtained from `borrow_buf()` and passed to `decode` as `&[u8]`.
let data = self.data.borrow_buf();
219+
let errors = ErrorsHandler::new(name, self.errors, vm);
220+
decode(&data, &errors, self.final_decode)
221+
}
222+
}
223+
224+
// `do_codec!(module::func, args, vm)` expands to `args.func(NAME, module::func, vm)`, where
// `module` is resolved under `encodings` and `NAME` is that module's `ENCODING_NAME`.
// The same identifier `$func` names both the method on the args struct and the codec function,
// so e.g. `utf8::encode` pairs `EncodeArgs::encode` with `encodings::utf8::encode`.
macro_rules! do_codec {
225+
($module:ident :: $func:ident, $args: expr, $vm:expr) => {{
226+
use encodings::$module as codec;
227+
$args.$func(codec::ENCODING_NAME, codec::$func, $vm)
228+
}};
229+
}
230+
177231
// Python-visible `utf_8_encode`. The removed version (lines marked `-`) copied the string's
// bytes inline and ignored `_errors`; the added version (lines marked `+`) takes an
// `EncodeArgs` bundle and delegates to `encodings::utf8::encode` through `do_codec!`.
#[pyfunction]
178-
fn utf_8_encode(s: PyStrRef, _errors: OptionalArg<PyStrRef>) -> (Vec<u8>, usize) {
179-
(s.as_str().as_bytes().to_vec(), s.char_len())
232+
fn utf_8_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult {
233+
do_codec!(utf8::encode, args, vm)
180234
}
181235

182236
// Python-visible `utf_8_decode`. The removed version (lines marked `-`) took three separate
// arguments, built the `ErrorsHandler` with the literal "utf-8" and defaulted `final_decode`
// to `false` by hand; the added version (lines marked `+`) takes a `DecodeArgs` bundle and
// delegates to `encodings::utf8::decode` through `do_codec!`.
#[pyfunction]
183-
fn utf_8_decode(
184-
data: PyBytesLike,
185-
errors: OptionalArg<PyStrRef>,
186-
final_decode: OptionalArg<bool>,
187-
vm: &VirtualMachine,
188-
) -> PyResult<(String, usize)> {
189-
let errors = ErrorsHandler::new("utf-8", errors.as_ref().into_option(), vm);
190-
data.with_ref(|data| utf8::decode(data, &errors, final_decode.unwrap_or(false)))
237+
fn utf_8_decode(args: DecodeArgs, vm: &VirtualMachine) -> DecodeResult {
238+
do_codec!(utf8::decode, args, vm)
191239
}
192240

193241
// TODO: implement these codecs in Rust!

0 commit comments

Comments
 (0)