From 72c0270d44eff4a938a9b3ae5a8eed07c11b8209 Mon Sep 17 00:00:00 2001 From: ShaharNaveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Mon, 11 May 2026 12:18:52 +0300 Subject: [PATCH 1/2] Code nits in `csv.rs` --- crates/stdlib/src/csv.rs | 280 +++++++++++++++++++++++---------------- 1 file changed, 166 insertions(+), 114 deletions(-) diff --git a/crates/stdlib/src/csv.rs b/crates/stdlib/src/csv.rs index 5c2d2662cc5..fe35130698f 100644 --- a/crates/stdlib/src/csv.rs +++ b/crates/stdlib/src/csv.rs @@ -22,16 +22,22 @@ mod _csv { #[pyattr] const QUOTE_MINIMAL: i32 = QuoteStyle::Minimal as i32; + #[pyattr] const QUOTE_ALL: i32 = QuoteStyle::All as i32; + #[pyattr] const QUOTE_NONNUMERIC: i32 = QuoteStyle::Nonnumeric as i32; + #[pyattr] const QUOTE_NONE: i32 = QuoteStyle::None as i32; + #[pyattr] const QUOTE_STRINGS: i32 = QuoteStyle::Strings as i32; + #[pyattr] const QUOTE_NOTNULL: i32 = QuoteStyle::Notnull as i32; + #[pyattr(name = "__version__")] const __VERSION__: &str = "1.0"; @@ -67,6 +73,7 @@ mod _csv { quoting: QuoteStyle, strict: bool, } + impl Constructor for PyDialect { type Args = PyObjectRef; @@ -74,24 +81,29 @@ mod _csv { Self::try_from_object(vm, ctx) } } + #[pyclass(with(Constructor))] impl PyDialect { #[pygetset] fn delimiter(&self, vm: &VirtualMachine) -> PyRef { vm.ctx.new_str(format!("{}", self.delimiter as char)) } + #[pygetset] fn quotechar(&self, vm: &VirtualMachine) -> Option> { Some(vm.ctx.new_str(format!("{}", self.quotechar? as char))) } + #[pygetset] const fn doublequote(&self) -> bool { self.doublequote } + #[pygetset] const fn skipinitialspace(&self) -> bool { self.skipinitialspace } + #[pygetset] fn lineterminator(&self, vm: &VirtualMachine) -> PyRef { match self.lineterminator { @@ -100,19 +112,23 @@ mod _csv { _ => unreachable!(), } } + #[pygetset] fn quoting(&self) -> isize { self.quoting.into() } + #[pygetset] fn escapechar(&self, vm: &VirtualMachine) -> Option> { Some(vm.ctx.new_str(format!("{}", self.escapechar? as char))) } + #[pygetset(name = "strict")] const fn get_strict(&self) -> bool { self.strict } } + /// Parses the delimiter from a Python object and returns its ASCII value. /// /// This function attempts to extract the 'delimiter' attribute from the given Python object and ensures that the attribute is a single-character string. If successful, it returns the ASCII value of the character. If the attribute is not a single-character string, an error is returned. @@ -146,11 +162,10 @@ mod _csv { })?) } attr => { - let msg = format!( + Err(vm.new_type_error(format!( r#""delimiter" must be a unicode character, not {}"#, - attr.class() - ); - Err(vm.new_type_error(msg)) + attr.class().name() + ))) } }) } @@ -171,12 +186,13 @@ mod _csv { vm, format!( r#""quotechar" must be a unicode character or None, not {}"#, - attr.class() + attr.class().name() ), )) } }) } + fn parse_escapechar_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult> { match_class!(match obj.get_attr("escapechar", vm)? { s @ PyStr => { @@ -191,14 +207,14 @@ mod _csv { Ok(None) } attr => { - let msg = format!( + Err(vm.new_type_error(format!( r#""escapechar" must be a unicode character or None, not {}"#, - attr.class() - ); - Err(vm.new_type_error(msg)) + attr.class().name() + ))) } }) } + fn prase_lineterminator_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult { match_class!(match obj.get_attr("lineterminator", vm)? { s @ PyStr => { @@ -216,25 +232,28 @@ mod _csv { attr => { Err(vm.new_type_error(format!( r#""lineterminator" must be a string, not {}"#, - attr.class() + attr.class().name() ))) } }) } + fn prase_quoting_from_obj(vm: &VirtualMachine, obj: &PyObject) -> PyResult { match_class!(match obj.get_attr("quoting", vm)? { i @ PyInt => { - Ok(i.try_to_primitive::(vm)?.try_into().map_err(|_| { - let msg = r#"bad "quoting" value"#; - vm.new_type_error(msg.to_owned()) - })?) + Ok(i.try_to_primitive::(vm)? + .try_into() + .map_err(|_| vm.new_type_error(r#"bad "quoting" value"#))?) } attr => { - let msg = format!(r#""quoting" must be string or None, not {}"#, attr.class()); - Err(vm.new_type_error(msg)) + Err(vm.new_type_error(format!( + r#""quoting" must be string or None, not {}"#, + attr.class().name() + ))) } }) } + impl TryFromObject for PyDialect { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let delimiter = parse_delimiter_from_obj(vm, &obj)?; @@ -244,6 +263,7 @@ mod _csv { let skipinitialspace = obj.get_attr("skipinitialspace", vm)?.try_to_bool(vm)?; let lineterminator = prase_lineterminator_from_obj(vm, &obj)?; let quoting = prase_quoting_from_obj(vm, &obj)?; + let strict = if let Ok(t) = obj.get_attr("strict", vm) { t.try_to_bool(vm).unwrap_or(false) } else { @@ -275,16 +295,20 @@ mod _csv { let name = name .downcast::() .map_err(|_| vm.new_type_error("argument 0 must be a string"))?; + let name: PyUtf8StrRef = name.try_into_utf8(vm)?; + let dialect = match dialect { OptionalArg::Present(d) => PyDialect::try_from_object(vm, d) .map_err(|_| vm.new_type_error("argument 1 must be a dialect object"))?, OptionalArg::Missing => opts.result(vm)?, }; + let dialect = opts.update_py_dialect(dialect); GLOBAL_HASHMAP .lock() .insert(name.as_str().to_owned(), dialect); + Ok(()) } @@ -297,14 +321,17 @@ mod _csv { let name = name.downcast::().map_err(|obj| { new_csv_error( vm, - format!("argument 0 must be a string, not '{}'", obj.class()), + format!("argument 0 must be a string, not '{}'", obj.class().name()), ) })?; + let name: PyUtf8StrRef = name.try_into_utf8(vm)?; let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name.as_str()) { return Ok(*dialect); } + Err(new_csv_error(vm, "unknown dialect")) } @@ -317,14 +344,17 @@ mod _csv { let name = name.downcast::().map_err(|obj| { new_csv_error( vm, - format!("argument 0 must be a string, not '{}'", obj.class()), + format!("argument 0 must be a string, not '{}'", obj.class().name()), ) })?; + let name: PyUtf8StrRef = name.try_into_utf8(vm)?; let mut g = GLOBAL_HASHMAP.lock(); + if let Some(_removed) = g.remove(name.as_str()) { return Ok(()); } + Err(new_csv_error(vm, "unknown dialect")) } @@ -398,7 +428,7 @@ mod _csv { Some(write_meth) => write_meth, None if file.is_callable() => file, None => { - return Err(vm.new_type_error("argument 1 must have a \"write\" method")); + return Err(vm.new_type_error(r#"argument 1 must have a "write" method"#)); } }; @@ -419,7 +449,7 @@ mod _csv { } #[repr(i32)] - #[derive(Debug, Clone, Copy)] + #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum QuoteStyle { Minimal = 0, All = 1, @@ -428,6 +458,7 @@ mod _csv { Strings = 4, Notnull = 5, } + impl From for csv_core::QuoteStyle { fn from(val: QuoteStyle) -> Self { match val { @@ -440,6 +471,7 @@ mod _csv { } } } + impl TryFromObject for QuoteStyle { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let num = obj.try_int(vm)?.try_to_primitive::(vm)?; @@ -448,20 +480,23 @@ mod _csv { }) } } + impl TryFrom for QuoteStyle { type Error = (); + fn try_from(num: isize) -> Result { - match num { - 0 => Ok(Self::Minimal), - 1 => Ok(Self::All), - 2 => Ok(Self::Nonnumeric), - 3 => Ok(Self::None), - 4 => Ok(Self::Strings), - 5 => Ok(Self::Notnull), - _ => Err(()), - } + Ok(match num { + 0 => Self::Minimal, + 1 => Self::All, + 2 => Self::Nonnumeric, + 3 => Self::None, + 4 => Self::Strings, + 5 => Self::Notnull, + _ => return Err(()), + }) } } + impl From for isize { fn from(val: QuoteStyle) -> Self { match val { @@ -475,12 +510,15 @@ mod _csv { } } + #[derive(Default)] enum DialectItem { Str(String), Obj(PyDialect), + #[default] None, } + #[derive(Default)] struct FormatOptions { dialect: DialectItem, delimiter: Option, @@ -492,21 +530,7 @@ mod _csv { quoting: Option, strict: Option, } - impl Default for FormatOptions { - fn default() -> Self { - Self { - dialect: DialectItem::None, - delimiter: None, - quotechar: None, - escapechar: None, - doublequote: None, - skipinitialspace: None, - lineterminator: None, - quoting: None, - strict: None, - } - } - } + /// prase a dialect item from a Python argument and returns a `DialectItem` or an `ArgumentError`. /// /// This function takes a reference to the VirtualMachine and a PyObjectRef as input and attempts to parse a dialect item from the provided Python argument. It returns a `DialectItem` if successful, or an `ArgumentError` if unsuccessful. @@ -555,8 +579,7 @@ mod _csv { if let Ok(cur_dialect_item) = PyDialect::try_from_object(vm, obj) { Ok(DialectItem::Obj(cur_dialect_item)) } else { - let msg = "dialect".to_string(); - Err(ArgumentError::InvalidKeywordArgument(msg)) + Err(ArgumentError::InvalidKeywordArgument("dialect".to_string())) } } }) @@ -565,12 +588,12 @@ mod _csv { impl FromArgs for FormatOptions { fn from_args(vm: &VirtualMachine, args: &mut FuncArgs) -> Result { let mut res = Self::default(); - if let Some(dialect) = args.kwargs.swap_remove("dialect") { - res.dialect = prase_dialect_item_from_arg(vm, dialect)?; + res.dialect = if let Some(dialect) = args.kwargs.swap_remove("dialect") { + prase_dialect_item_from_arg(vm, dialect)? } else if let Some(dialect) = args.args.first() { - res.dialect = prase_dialect_item_from_arg(vm, dialect.clone())?; + prase_dialect_item_from_arg(vm, dialect.clone())? } else { - res.dialect = DialectItem::None; + DialectItem::None }; if let Some(delimiter) = args.kwargs.swap_remove("delimiter") { @@ -581,12 +604,12 @@ mod _csv { res.escapechar = match_class!(match escapechar { s @ PyStr => Some(s.as_bytes().iter().copied().exactly_one().map_err(|_| { - let msg = r#""escapechar" must be a 1-character string"#; - vm.new_type_error(msg.to_owned()) + vm.new_type_error(r#""escapechar" must be a 1-character string"#) })?), _ => None, }) }; + if let Some(lineterminator) = args.kwargs.swap_remove("lineterminator") { res.lineterminator = Some(csv_core::Terminator::Any( lineterminator @@ -594,23 +617,27 @@ mod _csv { .bytes() .exactly_one() .map_err(|_| { - let msg = r#""lineterminator" must be a 1-character string"#; - vm.new_type_error(msg.to_owned()) + vm.new_type_error(r#""lineterminator" must be a 1-character string"#) })?, )) }; + if let Some(doublequote) = args.kwargs.swap_remove("doublequote") { - res.doublequote = Some(doublequote.try_to_bool(vm).map_err(|_| { - let msg = r#""doublequote" must be a bool"#; - vm.new_type_error(msg.to_owned()) - })?) + res.doublequote = Some( + doublequote + .try_to_bool(vm) + .map_err(|_| vm.new_type_error(r#""doublequote" must be a bool"#))?, + ) }; + if let Some(skipinitialspace) = args.kwargs.swap_remove("skipinitialspace") { - res.skipinitialspace = Some(skipinitialspace.try_to_bool(vm).map_err(|_| { - let msg = r#""skipinitialspace" must be a bool"#; - vm.new_type_error(msg.to_owned()) - })?) + res.skipinitialspace = Some( + skipinitialspace + .try_to_bool(vm) + .map_err(|_| vm.new_type_error(r#""skipinitialspace" must be a bool"#))?, + ) }; + if let Some(quoting) = args.kwargs.swap_remove("quoting") { res.quoting = match_class!(match quoting { i @ PyInt => @@ -623,47 +650,47 @@ mod _csv { } }); }; + if let Some(quotechar) = args.kwargs.swap_remove("quotechar") { res.quotechar = match_class!(match quotechar { s @ PyStr => Some(Some(s.as_bytes().iter().copied().exactly_one().map_err( - |_| { - let msg = r#""quotechar" must be a 1-character string"#; - vm.new_type_error(msg.to_owned()) - } + |_| { vm.new_type_error(r#""quotechar" must be a 1-character string"#) } )?)), PyNone => { if let Some(QuoteStyle::All) = res.quoting { - let msg = "quotechar must be set if quoting enabled"; return Err(ArgumentError::Exception( - vm.new_type_error(msg.to_owned()), + vm.new_type_error("quotechar must be set if quoting enabled"), )); } Some(None) } _o => { - let msg = r#"quotechar"#; return Err( rustpython_vm::function::ArgumentError::InvalidKeywordArgument( - msg.to_string(), + "quotechar".to_string(), ), ); } }) }; + if let Some(strict) = args.kwargs.swap_remove("strict") { - res.strict = Some(strict.try_to_bool(vm).map_err(|_| { - let msg = r#""strict" must be a int enum"#; - vm.new_type_error(msg.to_owned()) - })?) + res.strict = Some( + strict + .try_to_bool(vm) + .map_err(|_| vm.new_type_error(r#""strict" must be a int enum"#))?, + ) }; if let Some(last_arg) = args.kwargs.pop() { - let msg = format!( - r#"'{}' is an invalid keyword argument for this function"#, - last_arg.0 + return Err( + rustpython_vm::function::ArgumentError::InvalidKeywordArgument(format!( + "'{}' is an invalid keyword argument for this function", + last_arg.0 + )), ); - return Err(rustpython_vm::function::ArgumentError::InvalidKeywordArgument(msg)); } + Ok(res) } } @@ -677,21 +704,21 @@ mod _csv { } }}; } + check_and_fill!(res, delimiter); // check_and_fill!(res, quotechar); check_and_fill!(res, delimiter); check_and_fill!(res, doublequote); check_and_fill!(res, skipinitialspace); + if let Some(t) = self.escapechar { res.escapechar = Some(t); }; + if let Some(t) = self.quotechar { - if let Some(u) = t { - res.quotechar = Some(u); - } else { - res.quotechar = None; - } + res.quotechar = t; }; + check_and_fill!(res, quoting); check_and_fill!(res, lineterminator); check_and_fill!(res, strict); @@ -707,8 +734,7 @@ mod _csv { } else { Err(new_csv_error(vm, format!("{name} is not registered."))) } - // TODO - // Maybe need to update the obj from HashMap + // TODO: Maybe need to update the obj from HashMap } DialectItem::Obj(o) => Ok(self.update_py_dialect(*o)), DialectItem::None => { @@ -718,14 +744,14 @@ mod _csv { } } } + fn get_skipinitialspace(&self) -> bool { let mut skipinitialspace = match &self.dialect { DialectItem::Str(name) => { let g = GLOBAL_HASHMAP.lock(); if let Some(dialect) = g.get(name) { dialect.skipinitialspace - // RustPython todo - // todo! Perfecting the remaining attributes. + // TODO: RUSTPYTHON; Perfecting the remaining attributes. } else { false } @@ -733,11 +759,14 @@ mod _csv { DialectItem::Obj(obj) => obj.skipinitialspace, _ => false, }; + if let Some(attr) = self.skipinitialspace { skipinitialspace = attr } + skipinitialspace } + fn get_delimiter(&self) -> u8 { let mut delimiter = match &self.dialect { DialectItem::Str(name) => { @@ -753,11 +782,14 @@ mod _csv { DialectItem::Obj(obj) => obj.delimiter, _ => b',', }; + if let Some(attr) = self.delimiter { delimiter = attr } + delimiter } + fn to_reader(&self) -> csv_core::Reader { let mut builder = csv_core::ReaderBuilder::new(); let mut reader = match &self.dialect { @@ -803,37 +835,34 @@ mod _csv { if let Some(t) = self.delimiter { reader = reader.delimiter(t); } + if let Some(t) = self.quotechar { - if let Some(u) = t { - reader = reader.quote(u); + reader = if let Some(u) = t { + reader.quote(u) } else { - reader = reader.quoting(false); + reader.quoting(false) } } else { - match self.quoting { - Some(QuoteStyle::None) => { - reader = reader.quoting(false); - } - // None => reader = reader.quoting(true), - _ => reader = reader.quoting(true), - } + reader = reader.quoting(self.quoting != Some(QuoteStyle::None)); } if let Some(t) = self.lineterminator { reader = reader.terminator(t); } + if let Some(t) = self.doublequote { reader = reader.double_quote(t); } + if self.escapechar.is_some() { reader = reader.escape(self.escapechar); } - reader = match self.lineterminator { - Some(u) => reader.terminator(u), - None => reader.terminator(Terminator::CRLF), - }; + + reader = reader.terminator(self.lineterminator.unwrap_or(Terminator::CRLF)); + reader.build() } + fn to_writer(&self) -> csv_core::Writer { let mut builder = csv_core::WriterBuilder::new(); let mut writer = match &self.dialect { @@ -844,13 +873,14 @@ mod _csv { .delimiter(dialect.delimiter) .double_quote(dialect.doublequote) .terminator(dialect.lineterminator); + if let Some(t) = dialect.quotechar { builder = builder.quote(t); } + builder - // RustPython todo - // todo! Perfecting the remaining attributes. + // TODO: RUSTPYTHON; Perfecting the remaining attributes. } else { &mut builder } @@ -860,16 +890,20 @@ mod _csv { .delimiter(obj.delimiter) .double_quote(obj.doublequote) .terminator(obj.lineterminator); + if let Some(t) = obj.quotechar { builder = builder.quote(t); } + builder } _ => &mut builder, }; + if let Some(t) = self.delimiter { writer = writer.delimiter(t); } + if let Some(t) = self.quotechar { if let Some(u) = t { writer = writer.quote(u); @@ -877,19 +911,21 @@ mod _csv { todo!() } } + if let Some(t) = self.doublequote { writer = writer.double_quote(t); } - writer = match self.lineterminator { - Some(u) => writer.terminator(u), - None => writer.terminator(Terminator::CRLF), - }; + + writer = writer.terminator(self.lineterminator.unwrap_or(Terminator::CRLF)); + if let Some(e) = self.escapechar { writer = writer.escape(e); } + if let Some(e) = self.quoting { writer = writer.quote_style(e.into()); } + writer.build() } } @@ -925,12 +961,15 @@ mod _csv { fn line_num(&self) -> u64 { self.state.lock().line_num } + #[pygetset] const fn dialect(&self, _vm: &VirtualMachine) -> PyDialect { self.dialect } } + impl SelfIter for Reader {} + impl IterNext for Reader { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let string = raise_if_stop!(zelf.iter.next(vm)?); @@ -961,6 +1000,7 @@ mod _csv { let mut output_offset = 0; let mut output_ends_offset = 0; let field_limit = GLOBAL_FIELD_LIMIT.lock().to_owned(); + #[inline] fn trim_spaces(input: &[u8]) -> &[u8] { let trimmed_start = input.iter().position(|&x| x != b' ').unwrap_or(input.len()); @@ -971,6 +1011,7 @@ mod _csv { .unwrap_or(0); &input[trimmed_start..trimmed_end] } + let input = if *skipinitialspace { let t = input.split(|x| x == delimiter); t.map(|x| { @@ -981,6 +1022,7 @@ mod _csv { } else { String::from_utf8(input.to_vec()).unwrap() }; + loop { let (res, n_read, n_written, n_ends) = reader.read_record( &input.as_bytes()[input_offset..], @@ -1000,13 +1042,15 @@ mod _csv { } } } + let rest = &input.as_bytes()[input_offset..]; if !rest.iter().all(|&c| matches!(c, b'\r' | b'\n')) { return Err(new_csv_error( vm, - "new-line character seen in unquoted field - \ - do you need to open the file in universal-newline mode?" - .to_owned(), + concat!( + "new-line character seen in unquoted field", + " - do you need to open the file in universal-newline mode?" + ), )); } @@ -1016,14 +1060,15 @@ mod _csv { .map(|&end| { let range = prev_end..end; if range.len() > field_limit as usize { - return Err(new_csv_error(vm, "filed too long to read".to_string())); + return Err(new_csv_error(vm, "filed too long to read")); } + prev_end = end; let s = core::str::from_utf8(&buffer[range.clone()]) // not sure if this is possible - the input was all strings .map_err(|_e| vm.new_unicode_decode_error("csv not utf8"))?; - // Rustpython TODO! - // Incomplete implementation + + // TODO: RUSTPYTHON; Incomplete implementation if let QuoteStyle::Nonnumeric = zelf.dialect.quoting { if let Ok(t) = String::from_utf8(trim_spaces(&buffer[range]).to_vec()) .unwrap() @@ -1075,6 +1120,7 @@ mod _csv { const fn get_dialect(&self, _vm: &VirtualMachine) -> PyDialect { self.dialect } + #[pymethod] fn writerow(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut state = self.state.lock(); @@ -1094,8 +1140,12 @@ mod _csv { } let row = ArgIterable::try_from_object(vm, row.clone()).map_err(|_e| { - new_csv_error(vm, format!("\'{}\' object is not iterable", row.class())) + new_csv_error( + vm, + format!("'{}' object is not iterable", row.class().name()), + ) })?; + let mut first_flag = true; for field in row.iter(vm)? { let field: PyObjectRef = field?; @@ -1109,6 +1159,7 @@ mod _csv { } }); let mut input_offset = 0; + if first_flag { first_flag = false; } else { @@ -1128,6 +1179,7 @@ mod _csv { loop { handle_res!(writer.terminator(&mut buffer[buffer_offset..])); } + let s = core::str::from_utf8(&buffer[..buffer_offset]) .map_err(|_| vm.new_unicode_decode_error("csv not utf8"))?; From 6f7c5aeaa3e6b56d14dbdb5832f272a684de8fb6 Mon Sep 17 00:00:00 2001 From: ShaharNaveh <50263213+ShaharNaveh@users.noreply.github.com> Date: Mon, 11 May 2026 12:46:50 +0300 Subject: [PATCH 2/2] clippy --- crates/stdlib/src/csv.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/stdlib/src/csv.rs b/crates/stdlib/src/csv.rs index fe35130698f..43e056fb4c8 100644 --- a/crates/stdlib/src/csv.rs +++ b/crates/stdlib/src/csv.rs @@ -587,8 +587,7 @@ mod _csv { impl FromArgs for FormatOptions { fn from_args(vm: &VirtualMachine, args: &mut FuncArgs) -> Result { - let mut res = Self::default(); - res.dialect = if let Some(dialect) = args.kwargs.swap_remove("dialect") { + let dialect = if let Some(dialect) = args.kwargs.swap_remove("dialect") { prase_dialect_item_from_arg(vm, dialect)? } else if let Some(dialect) = args.args.first() { prase_dialect_item_from_arg(vm, dialect.clone())? @@ -596,6 +595,11 @@ mod _csv { DialectItem::None }; + let mut res = Self { + dialect, + ..Default::default() + }; + if let Some(delimiter) = args.kwargs.swap_remove("delimiter") { res.delimiter = Some(parse_delimiter_from_obj(vm, &delimiter)?); }