Skip to content

Commit 8cdf19c

Browse files
committed
FileIO.write support ByteArray
1 parent dc7eb24 commit 8cdf19c

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

vm/src/stdlib/io.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ use crate::function::{OptionalArg, PyFuncArgs};
1616
use crate::import;
1717
use crate::obj::objbytearray::PyByteArray;
1818
use crate::obj::objbytes;
19+
use crate::obj::objbytes::PyBytes;
1920
use crate::obj::objint;
2021
use crate::obj::objstr;
2122
use crate::obj::objtype;
2223
use crate::obj::objtype::PyClassRef;
24+
use crate::pyobject::TypeProtocol;
2325
use crate::pyobject::{BufferProtocol, PyObjectRef, PyRef, PyResult, PyValue};
2426
use crate::vm::VirtualMachine;
2527

@@ -383,11 +385,7 @@ fn file_io_readinto(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
383385
}
384386

385387
fn file_io_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
386-
arg_check!(
387-
vm,
388-
args,
389-
required = [(file_io, None), (obj, Some(vm.ctx.bytes_type()))]
390-
);
388+
arg_check!(vm, args, required = [(file_io, None), (obj, None)]);
391389

392390
let file_no = vm.get_attribute(file_io.clone(), "fileno")?;
393391
let raw_fd = objint::get_value(&file_no).to_i64().unwrap();
@@ -397,22 +395,25 @@ fn file_io_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
397395
//to support windows - i.e. raw file_handles
398396
let mut handle = os::rust_file(raw_fd);
399397

400-
match obj.payload::<PyByteArray>() {
401-
Some(bytes) => {
402-
let value_mut = &mut bytes.inner.borrow_mut().elements;
403-
match handle.write(&value_mut[..]) {
404-
Ok(len) => {
405-
//reset raw fd on the FileIO object
406-
let updated = os::raw_file_number(handle);
407-
vm.set_attr(file_io, "fileno", vm.ctx.new_int(updated))?;
408-
409-
//return number of bytes written
410-
Ok(vm.ctx.new_int(len))
411-
}
412-
Err(_) => Err(vm.new_value_error("Error Writing Bytes to Handle".to_string())),
413-
}
398+
let bytes = match_class!(obj.clone(),
399+
i @ PyBytes => Ok(i.get_value().to_vec()),
400+
j @ PyByteArray => Ok(j.inner.borrow().elements.to_vec()),
401+
obj => Err(vm.new_type_error(format!(
402+
"a bytes-like object is required, not {}",
403+
obj.class()
404+
)))
405+
);
406+
407+
match handle.write(&bytes?) {
408+
Ok(len) => {
409+
//reset raw fd on the FileIO object
410+
let updated = os::raw_file_number(handle);
411+
vm.set_attr(file_io, "fileno", vm.ctx.new_int(updated))?;
412+
413+
//return number of bytes written
414+
Ok(vm.ctx.new_int(len))
414415
}
415-
None => Err(vm.new_value_error("Expected Bytes Object".to_string())),
416+
Err(_) => Err(vm.new_value_error("Error Writing Bytes to Handle".to_string())),
416417
}
417418
}
418419

0 commit comments

Comments
 (0)