Skip to content

Commit 69658bd

Browse files
committed
Update binascii with base64, add with_ref method to PyBytesLike
1 parent 73f56af commit 69658bd

File tree

7 files changed

+137
-58
lines changed

7 files changed

+137
-58
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ wtf8 = "0.0.3"
6767
arr_macro = "0.1.2"
6868
csv = "1.1.1"
6969
paste = "0.1"
70+
base64 = "0.11"
7071

7172
flame = { version = "0.2", optional = true }
7273
flamer = { version = "0.3", optional = true }

vm/src/obj/objbyteinner.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,16 +1418,13 @@ fn split_slice_reverse<'a>(slice: &'a [u8], sep: &[u8], maxsplit: i32) -> Vec<&'
14181418
pub enum PyBytesLike {
14191419
Bytes(PyBytesRef),
14201420
Bytearray(PyByteArrayRef),
1421-
Vec(Vec<u8>),
14221421
}
14231422

14241423
impl TryFromObject for PyBytesLike {
14251424
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
14261425
match_class!(match obj {
14271426
b @ PyBytes => Ok(PyBytesLike::Bytes(b)),
14281427
b @ PyByteArray => Ok(PyBytesLike::Bytearray(b)),
1429-
m @ PyMemoryView => Ok(PyBytesLike::Vec(m.get_obj_value().unwrap())),
1430-
l @ PyList => Ok(PyBytesLike::Vec(l.get_byte_inner(vm)?.elements)),
14311428
obj => Err(vm.new_type_error(format!(
14321429
"a bytes-like object is required, not {}",
14331430
obj.class()
@@ -1441,7 +1438,14 @@ impl PyBytesLike {
14411438
match self {
14421439
PyBytesLike::Bytes(b) => b.get_value().into(),
14431440
PyBytesLike::Bytearray(b) => b.inner.borrow().elements.clone().into(),
1444-
PyBytesLike::Vec(b) => b.as_slice().into(),
1441+
}
1442+
}
1443+
1444+
#[inline]
1445+
pub fn with_ref<R>(&self, f: impl FnOnce(&[u8]) -> R) -> R {
1446+
match self {
1447+
PyBytesLike::Bytes(b) => f(b.get_value()),
1448+
PyBytesLike::Bytearray(b) => f(&b.inner.borrow().elements),
14451449
}
14461450
}
14471451
}

vm/src/stdlib/binascii.rs

Lines changed: 106 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,52 @@
11
use crate::function::OptionalArg;
2-
use crate::obj::objbytes::PyBytesRef;
3-
use crate::pyobject::{PyObjectRef, PyResult};
2+
use crate::obj::objbytearray::{PyByteArray, PyByteArrayRef};
3+
use crate::obj::objbyteinner::PyBytesLike;
4+
use crate::obj::objbytes::{PyBytes, PyBytesRef};
5+
use crate::obj::objstr::{PyString, PyStringRef};
6+
use crate::pyobject::{PyObjectRef, PyResult, TryFromObject, TypeProtocol};
47
use crate::vm::VirtualMachine;
8+
59
use crc::{crc32, Hasher32};
10+
use itertools::Itertools;
11+
12+
enum SerializedData {
13+
Bytes(PyBytesRef),
14+
Buffer(PyByteArrayRef),
15+
Ascii(PyStringRef),
16+
}
17+
18+
impl TryFromObject for SerializedData {
19+
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
20+
match_class!(match obj {
21+
b @ PyBytes => Ok(SerializedData::Bytes(b)),
22+
b @ PyByteArray => Ok(SerializedData::Buffer(b)),
23+
a @ PyString => {
24+
if a.as_str().is_ascii() {
25+
Ok(SerializedData::Ascii(a))
26+
} else {
27+
Err(vm.new_value_error(
28+
"string argument should contain only ASCII characters".to_string(),
29+
))
30+
}
31+
}
32+
obj => Err(vm.new_type_error(format!(
33+
"argument should be bytes, buffer or ASCII string, not '{}'",
34+
obj.class().name,
35+
))),
36+
})
37+
}
38+
}
39+
40+
impl SerializedData {
41+
#[inline]
42+
pub fn with_ref<R>(&self, f: impl FnOnce(&[u8]) -> R) -> R {
43+
match self {
44+
SerializedData::Bytes(b) => f(b.get_value()),
45+
SerializedData::Buffer(b) => f(&b.inner.borrow().elements),
46+
SerializedData::Ascii(a) => f(a.as_str().as_bytes()),
47+
}
48+
}
49+
}
650

751
fn hex_nibble(n: u8) -> u8 {
852
match n {
@@ -12,15 +56,15 @@ fn hex_nibble(n: u8) -> u8 {
1256
}
1357
}
1458

15-
fn binascii_hexlify(data: PyBytesRef, vm: &VirtualMachine) -> PyResult {
16-
let bytes = data.get_value();
17-
let mut hex = Vec::<u8>::with_capacity(bytes.len() * 2);
18-
for b in bytes.iter() {
19-
hex.push(hex_nibble(b >> 4));
20-
hex.push(hex_nibble(b & 0xf));
21-
}
22-
23-
Ok(vm.ctx.new_bytes(hex))
59+
fn binascii_hexlify(data: PyBytesLike, _vm: &VirtualMachine) -> Vec<u8> {
60+
data.with_ref(|bytes| {
61+
let mut hex = Vec::<u8>::with_capacity(bytes.len() * 2);
62+
for b in bytes.iter() {
63+
hex.push(hex_nibble(b >> 4));
64+
hex.push(hex_nibble(b & 0xf));
65+
}
66+
hex
67+
})
2468
}
2569

2670
fn unhex_nibble(c: u8) -> Option<u8> {
@@ -32,37 +76,66 @@ fn unhex_nibble(c: u8) -> Option<u8> {
3276
}
3377
}
3478

35-
fn binascii_unhexlify(hexstr: PyBytesRef, vm: &VirtualMachine) -> PyResult {
36-
// TODO: allow 'str' hexstrings as well
37-
let hex_bytes = hexstr.get_value();
38-
if hex_bytes.len() % 2 != 0 {
39-
return Err(vm.new_value_error("Odd-length string".to_string()));
40-
}
79+
fn binascii_unhexlify(data: SerializedData, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
80+
data.with_ref(|hex_bytes| {
81+
if hex_bytes.len() % 2 != 0 {
82+
return Err(vm.new_value_error("Odd-length string".to_string()));
83+
}
4184

42-
let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2);
43-
for i in (0..hex_bytes.len()).step_by(2) {
44-
let n1 = unhex_nibble(hex_bytes[i]);
45-
let n2 = unhex_nibble(hex_bytes[i + 1]);
46-
if let (Some(n1), Some(n2)) = (n1, n2) {
47-
unhex.push(n1 << 4 | n2);
48-
} else {
49-
return Err(vm.new_value_error("Non-hexadecimal digit found".to_string()));
85+
let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2);
86+
for (n1, n2) in hex_bytes.iter().tuples() {
87+
if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) {
88+
unhex.push(n1 << 4 | n2);
89+
} else {
90+
return Err(vm.new_value_error("Non-hexadecimal digit found".to_string()));
91+
}
5092
}
51-
}
5293

53-
Ok(vm.ctx.new_bytes(unhex))
94+
Ok(unhex)
95+
})
5496
}
5597

56-
fn binascii_crc32(data: PyBytesRef, value: OptionalArg<u32>, vm: &VirtualMachine) -> PyResult {
57-
let bytes = data.get_value();
58-
let crc = value.unwrap_or(0u32);
98+
fn binascii_crc32(data: SerializedData, value: OptionalArg<u32>, vm: &VirtualMachine) -> PyResult {
99+
let crc = value.unwrap_or(0);
59100

60101
let mut digest = crc32::Digest::new_with_initial(crc32::IEEE, crc);
61-
digest.write(&bytes);
102+
data.with_ref(|bytes| digest.write(&bytes));
62103

63104
Ok(vm.ctx.new_int(digest.sum32()))
64105
}
65106

107+
#[derive(FromArgs)]
108+
struct NewlineArg {
109+
#[pyarg(keyword_only, default = "true")]
110+
newline: bool,
111+
}
112+
113+
/// trim a newline from the end of the bytestring, if it exists
114+
fn trim_newline(b: &[u8]) -> &[u8] {
115+
if b.ends_with(b"\n") {
116+
&b[..b.len() - 1]
117+
} else {
118+
b
119+
}
120+
}
121+
122+
fn binascii_a2b_base64(s: SerializedData, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
123+
s.with_ref(|b| base64::decode(trim_newline(b)))
124+
.map_err(|err| vm.new_value_error(format!("error decoding base64: {}", err)))
125+
}
126+
127+
fn binascii_b2a_base64(
128+
data: PyBytesLike,
129+
NewlineArg { newline }: NewlineArg,
130+
_vm: &VirtualMachine,
131+
) -> Vec<u8> {
132+
let mut encoded = data.with_ref(base64::encode).into_bytes();
133+
if newline {
134+
encoded.push(b'\n');
135+
}
136+
encoded
137+
}
138+
66139
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
67140
let ctx = &vm.ctx;
68141

@@ -72,5 +145,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
72145
"unhexlify" => ctx.new_rustfunc(binascii_unhexlify),
73146
"a2b_hex" => ctx.new_rustfunc(binascii_unhexlify),
74147
"crc32" => ctx.new_rustfunc(binascii_crc32),
148+
"a2b_base64" => ctx.new_rustfunc(binascii_a2b_base64),
149+
"b2a_base64" => ctx.new_rustfunc(binascii_b2a_base64),
75150
})
76151
}

vm/src/stdlib/io.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ impl PyBytesIORef {
244244
}
245245

246246
fn write(self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult<u64> {
247-
match self.buffer(vm)?.write(data.to_cow().as_ref()) {
247+
let mut buffer = self.buffer(vm)?;
248+
match data.with_ref(|b| buffer.write(b)) {
248249
Some(value) => Ok(value),
249250
None => Err(vm.new_type_error("Error Writing Bytes".to_string())),
250251
}
@@ -358,11 +359,10 @@ fn io_base_readline(
358359
let read = vm.get_attribute(instance, "read")?;
359360
while size < 0 || res.len() < size as usize {
360361
let read_res = PyBytesLike::try_from_object(vm, vm.invoke(&read, vec![vm.new_int(1)])?)?;
361-
let b = read_res.to_cow();
362-
if b.is_empty() {
362+
if read_res.with_ref(|b| b.is_empty()) {
363363
break;
364364
}
365-
res.extend_from_slice(b.as_ref());
365+
read_res.with_ref(|b| res.extend_from_slice(b));
366366
if res.ends_with(b"\n") {
367367
break;
368368
}
@@ -643,8 +643,8 @@ mod fileio {
643643
) -> PyResult<usize> {
644644
let mut handle = fio_get_fileno(&instance, vm)?;
645645

646-
let len = handle
647-
.write(obj.to_cow().as_ref())
646+
let len = obj
647+
.with_ref(|b| handle.write(b))
648648
.map_err(|e| os::convert_io_error(vm, e))?;
649649

650650
fio_set_fileno(&instance, handle, vm)?;

vm/src/stdlib/multiprocessing.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,9 @@ fn multiprocessing_send(
4141
buf: PyBytesLike,
4242
vm: &VirtualMachine,
4343
) -> PyResult<libc::c_int> {
44-
let buf = buf.to_cow();
45-
let ret = unsafe {
46-
winsock2::send(
47-
socket as SOCKET,
48-
buf.as_ptr() as *const _,
49-
buf.len() as i32,
50-
0,
51-
)
52-
};
44+
let ret = buf.with_ref(|b| unsafe {
45+
winsock2::send(socket as SOCKET, b.as_ptr() as *const _, b.len() as i32, 0)
46+
});
5347
if ret < 0 {
5448
Err(super::os::convert_io_error(
5549
vm,

vm/src/stdlib/socket.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,25 +204,23 @@ impl PySocket {
204204

205205
#[pymethod]
206206
fn send(&self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
207-
// TODO: use PyBytesLike.with_ref() instead of to_cow()
208-
self.sock()
209-
.send(bytes.to_cow().as_ref())
207+
bytes
208+
.with_ref(|b| self.sock().send(b))
210209
.map_err(|err| convert_sock_error(vm, err))
211210
}
212211

213212
#[pymethod]
214213
fn sendall(&self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult<()> {
215-
self.sock
216-
.borrow_mut()
217-
.write_all(bytes.to_cow().as_ref())
214+
bytes
215+
.with_ref(|b| self.sock.borrow_mut().write_all(b))
218216
.map_err(|err| convert_sock_error(vm, err))
219217
}
220218

221219
#[pymethod]
222220
fn sendto(&self, bytes: PyBytesLike, address: Address, vm: &VirtualMachine) -> PyResult<()> {
223221
let addr = get_addr(vm, address)?;
224-
self.sock()
225-
.send_to(bytes.to_cow().as_ref(), &addr)
222+
bytes
223+
.with_ref(|b| self.sock().send_to(b, &addr))
226224
.map_err(|err| convert_sock_error(vm, err))?;
227225
Ok(())
228226
}

0 commit comments

Comments
 (0)