22 * I/O core tools.
33 */
44use std:: cell:: RefCell ;
5- use std:: fs:: File ;
65use std:: io:: prelude:: * ;
7- use std:: io:: BufReader ;
86use std:: io:: Cursor ;
97use std:: io:: SeekFrom ;
108
@@ -18,9 +16,12 @@ use crate::function::{OptionalArg, PyFuncArgs};
1816use crate :: import;
1917use crate :: obj:: objbytearray:: PyByteArray ;
2018use crate :: obj:: objbytes;
19+ use crate :: obj:: objbytes:: PyBytes ;
2120use crate :: obj:: objint;
2221use crate :: obj:: objstr;
22+ use crate :: obj:: objtype;
2323use crate :: obj:: objtype:: PyClassRef ;
24+ use crate :: pyobject:: TypeProtocol ;
2425use crate :: pyobject:: { BufferProtocol , PyObjectRef , PyRef , PyResult , PyValue } ;
2526use crate :: vm:: VirtualMachine ;
2627
@@ -284,16 +285,19 @@ fn buffered_reader_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
284285}
285286
286287fn compute_c_flag ( mode : & str ) -> u32 {
287- let flags = match mode {
288- "w" => os:: FileCreationFlags :: O_WRONLY | os:: FileCreationFlags :: O_CREAT ,
289- "x" => {
290- os:: FileCreationFlags :: O_WRONLY
291- | os:: FileCreationFlags :: O_CREAT
292- | os:: FileCreationFlags :: O_EXCL
293- }
294- "a" => os:: FileCreationFlags :: O_APPEND ,
295- "+" => os:: FileCreationFlags :: O_RDWR ,
296- _ => os:: FileCreationFlags :: O_RDONLY ,
288+ let flags = match mode. chars ( ) . next ( ) {
289+ Some ( mode) => match mode {
290+ 'w' => os:: FileCreationFlags :: O_WRONLY | os:: FileCreationFlags :: O_CREAT ,
291+ 'x' => {
292+ os:: FileCreationFlags :: O_WRONLY
293+ | os:: FileCreationFlags :: O_CREAT
294+ | os:: FileCreationFlags :: O_EXCL
295+ }
296+ 'a' => os:: FileCreationFlags :: O_APPEND ,
297+ '+' => os:: FileCreationFlags :: O_RDWR ,
298+ _ => os:: FileCreationFlags :: O_RDONLY ,
299+ } ,
300+ None => os:: FileCreationFlags :: O_RDONLY ,
297301 } ;
298302 flags. bits ( )
299303}
@@ -302,47 +306,43 @@ fn file_io_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
302306 arg_check ! (
303307 vm,
304308 args,
305- required = [ ( file_io, None ) , ( name, Some ( vm . ctx . str_type ( ) ) ) ] ,
309+ required = [ ( file_io, None ) , ( name, None ) ] ,
306310 optional = [ ( mode, Some ( vm. ctx. str_type( ) ) ) ]
307311 ) ;
308312
309- let rust_mode = mode. map_or ( "r" . to_string ( ) , objstr:: get_value) ;
310-
311- match compute_c_flag ( & rust_mode) . to_bigint ( ) {
312- Some ( os_mode) => {
313- let args = vec ! [ name. clone( ) , vm. ctx. new_int( os_mode) ] ;
314- let file_no = os:: os_open ( vm, PyFuncArgs :: new ( args, vec ! [ ] ) ) ?;
315-
316- vm. set_attr ( file_io, "name" , name. clone ( ) ) ?;
317- vm. set_attr ( file_io, "fileno" , file_no) ?;
318- vm. set_attr ( file_io, "closefd" , vm. new_bool ( false ) ) ?;
319- vm. set_attr ( file_io, "closed" , vm. new_bool ( false ) ) ?;
313+ let file_no = if objtype:: isinstance ( & name, & vm. ctx . str_type ( ) ) {
314+ let rust_mode = mode. map_or ( "r" . to_string ( ) , objstr:: get_value) ;
315+ let args = vec ! [
316+ name. clone( ) ,
317+ vm. ctx
318+ . new_int( compute_c_flag( & rust_mode) . to_bigint( ) . unwrap( ) ) ,
319+ ] ;
320+ os:: os_open ( vm, PyFuncArgs :: new ( args, vec ! [ ] ) ) ?
321+ } else if objtype:: isinstance ( & name, & vm. ctx . int_type ( ) ) {
322+ name. clone ( )
323+ } else {
324+ return Err ( vm. new_type_error ( "name parameter must be string or int" . to_string ( ) ) ) ;
325+ } ;
320326
321- Ok ( vm. get_none ( ) )
322- }
323- None => Err ( vm. new_type_error ( format ! ( "invalid mode {}" , rust_mode) ) ) ,
324- }
327+ vm. set_attr ( file_io, "name" , name. clone ( ) ) ?;
328+ vm. set_attr ( file_io, "fileno" , file_no) ?;
329+ vm. set_attr ( file_io, "closefd" , vm. new_bool ( false ) ) ?;
330+ vm. set_attr ( file_io, "closed" , vm. new_bool ( false ) ) ?;
331+ Ok ( vm. get_none ( ) )
325332}
326333
327334fn file_io_read ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
328335 arg_check ! ( vm, args, required = [ ( file_io, None ) ] ) ;
329- let py_name = vm. get_attribute ( file_io. clone ( ) , "name" ) ?;
330- let f = match File :: open ( objstr:: get_value ( & py_name) ) {
331- Ok ( v) => Ok ( v) ,
332- Err ( _) => Err ( vm. new_type_error ( "Error opening file" . to_string ( ) ) ) ,
333- } ;
334336
335- let buffer = match f {
336- Ok ( v ) => Ok ( BufReader :: new ( v ) ) ,
337- Err ( _ ) => Err ( vm . new_type_error ( "Error reading from file" . to_string ( ) ) ) ,
338- } ;
337+ let file_no = vm . get_attribute ( file_io . clone ( ) , "fileno" ) ? ;
338+ let raw_fd = objint :: get_value ( & file_no ) . to_i64 ( ) . unwrap ( ) ;
339+
340+ let mut handle = os :: rust_file ( raw_fd ) ;
339341
340342 let mut bytes = vec ! [ ] ;
341- if let Ok ( mut buff) = buffer {
342- match buff. read_to_end ( & mut bytes) {
343- Ok ( _) => { }
344- Err ( _) => return Err ( vm. new_value_error ( "Error reading from Buffer" . to_string ( ) ) ) ,
345- }
343+ match handle. read_to_end ( & mut bytes) {
344+ Ok ( _) => { }
345+ Err ( _) => return Err ( vm. new_value_error ( "Error reading from Buffer" . to_string ( ) ) ) ,
346346 }
347347
348348 Ok ( vm. ctx . new_bytes ( bytes) )
@@ -385,11 +385,7 @@ fn file_io_readinto(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
385385}
386386
387387fn file_io_write ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
388- arg_check ! (
389- vm,
390- args,
391- required = [ ( file_io, None ) , ( obj, Some ( vm. ctx. bytes_type( ) ) ) ]
392- ) ;
388+ arg_check ! ( vm, args, required = [ ( file_io, None ) , ( obj, None ) ] ) ;
393389
394390 let file_no = vm. get_attribute ( file_io. clone ( ) , "fileno" ) ?;
395391 let raw_fd = objint:: get_value ( & file_no) . to_i64 ( ) . unwrap ( ) ;
@@ -399,22 +395,25 @@ fn file_io_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
399395 //to support windows - i.e. raw file_handles
400396 let mut handle = os:: rust_file ( raw_fd) ;
401397
402- match obj. payload :: < PyByteArray > ( ) {
403- Some ( bytes) => {
404- let value_mut = & mut bytes. inner . borrow_mut ( ) . elements ;
405- match handle. write ( & value_mut[ ..] ) {
406- Ok ( len) => {
407- //reset raw fd on the FileIO object
408- let updated = os:: raw_file_number ( handle) ;
409- vm. set_attr ( file_io, "fileno" , vm. ctx . new_int ( updated) ) ?;
410-
411- //return number of bytes written
412- Ok ( vm. ctx . new_int ( len) )
413- }
414- Err ( _) => Err ( vm. new_value_error ( "Error Writing Bytes to Handle" . to_string ( ) ) ) ,
415- }
398+ let bytes = match_class ! ( obj. clone( ) ,
399+ i @ PyBytes => Ok ( i. get_value( ) . to_vec( ) ) ,
400+ j @ PyByteArray => Ok ( j. inner. borrow( ) . elements. to_vec( ) ) ,
401+ obj => Err ( vm. new_type_error( format!(
402+ "a bytes-like object is required, not {}" ,
403+ obj. class( )
404+ ) ) )
405+ ) ;
406+
407+ match handle. write ( & bytes?) {
408+ Ok ( len) => {
409+ //reset raw fd on the FileIO object
410+ let updated = os:: raw_file_number ( handle) ;
411+ vm. set_attr ( file_io, "fileno" , vm. ctx . new_int ( updated) ) ?;
412+
413+ //return number of bytes written
414+ Ok ( vm. ctx . new_int ( len) )
416415 }
417- None => Err ( vm. new_value_error ( "Expected Bytes Object " . to_string ( ) ) ) ,
416+ Err ( _ ) => Err ( vm. new_value_error ( "Error Writing Bytes to Handle " . to_string ( ) ) ) ,
418417 }
419418}
420419
0 commit comments