diff --git a/crates/stdlib/src/csv.rs b/crates/stdlib/src/csv.rs
index f728002768d..bd80cf7eccf 100644
--- a/crates/stdlib/src/csv.rs
+++ b/crates/stdlib/src/csv.rs
@@ -466,8 +466,7 @@ mod _csv {
                 QuoteStyle::All => Self::Always,
                 QuoteStyle::Nonnumeric => Self::NonNumeric,
                 QuoteStyle::None => Self::Never,
-                QuoteStyle::Strings => todo!(),
-                QuoteStyle::Notnull => todo!(),
+                QuoteStyle::Strings | QuoteStyle::Notnull => Self::Necessary,
             }
         }
     }
@@ -661,7 +660,10 @@ mod _csv {
                         |_| { vm.new_type_error(r#""quotechar" must be a 1-character string"#) }
                     )?)),
                     PyNone => {
-                        if let Some(QuoteStyle::All) = res.quoting {
+                        if res
+                            .quoting
+                            .is_some_and(|quoting| quoting != QuoteStyle::None)
+                        {
                             return Err(ArgumentError::Exception(
                                 vm.new_type_error("quotechar must be set if quoting enabled"),
                             ));
@@ -1114,6 +1116,64 @@ mod _csv {
         }
     }
 
+    /// Appends `data` to `output` enclosed in the dialect's quote character.
+    ///
+    /// An embedded quote character is doubled when `doublequote` is set and is otherwise
+    /// preceded by the escape character; an embedded escape character is written twice.
+    /// Returns a `TypeError` if the dialect has no quote character, and a csv error if a
+    /// quote has to be escaped but no escape character is set.
+    fn write_quoted_field(
+        output: &mut Vec<u8>,
+        data: &[u8],
+        dialect: PyDialect,
+        vm: &VirtualMachine,
+    ) -> PyResult<()> {
+        let quotechar = dialect
+            .quotechar
+            .ok_or_else(|| vm.new_type_error("quotechar must be set if quoting enabled"))?;
+        output.push(quotechar);
+        for &byte in data {
+            if byte == quotechar {
+                if dialect.doublequote {
+                    output.push(quotechar);
+                    output.push(quotechar);
+                } else if let Some(escapechar) = dialect.escapechar {
+                    output.push(escapechar);
+                    output.push(byte);
+                } else {
+                    return Err(new_csv_error(vm, "need to escape, but no escapechar set"));
+                }
+            } else {
+                if dialect.escapechar == Some(byte) {
+                    output.push(byte);
+                }
+                output.push(byte);
+            }
+        }
+        output.push(quotechar);
+        Ok(())
+    }
+
+    /// Reports whether `data` contains the delimiter, the quote character, `\r`, `\n`, or
+    /// a single-byte line terminator.
+    fn field_needs_quotes(data: &[u8], dialect: PyDialect) -> bool {
+        data.iter().any(|&byte| {
+            byte == dialect.delimiter
+                || dialect.quotechar == Some(byte)
+                || matches!(byte, b'\r' | b'\n')
+                || matches!(dialect.lineterminator, Terminator::Any(t) if byte == t)
+        })
+    }
+
+    /// Appends the line terminator to `output`: `\r\n` for `CRLF`, otherwise its single byte.
+    fn write_lineterminator(output: &mut Vec<u8>, terminator: Terminator) {
+        match terminator {
+            Terminator::CRLF => output.extend_from_slice(b"\r\n"),
+            Terminator::Any(byte) => output.push(byte),
+            _ => unreachable!(),
+        }
+    }
+
     #[pyclass(flags(DISALLOW_INSTANTIATION))]
     impl Writer {
         #[pygetset(name = "dialect")]
@@ -1121,8 +1181,70 @@ mod _csv {
             self.dialect
         }
 
+        /// Writes one row when quoting is `QuoteStyle::Strings` or `QuoteStyle::Notnull`.
+        ///
+        /// Under `Strings`, `str` fields are always quoted and other fields only when
+        /// `field_needs_quotes` says so; under `Notnull`, every field except `None` is
+        /// quoted. `None` contributes empty data. The finished line goes to `self.write`.
+        fn writerow_quoted_strings(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult {
+            let _state = self.state.lock();
+            let row: ArgIterable = ArgIterable::try_from_object(vm, row.clone()).map_err(|_e| {
+                new_csv_error(
+                    vm,
+                    format!("'{}' object is not iterable", row.class().name()),
+                )
+            })?;
+            let fields = row.iter(vm)?.collect::<PyResult<Vec<_>>>()?;
+            let single_field = fields.len() == 1;
+            let mut output = Vec::new();
+
+            for (index, field) in fields.into_iter().enumerate() {
+                if index > 0 {
+                    output.push(self.dialect.delimiter);
+                }
+
+                let stringified;
+                let (data, is_str, is_none): (&[u8], bool, bool) = match_class!(match field {
+                    ref s @ PyStr => (s.as_bytes(), true, false),
+                    crate::builtins::PyNone => (b"", false, true),
+                    ref obj => {
+                        stringified = obj.str(vm)?;
+                        (stringified.as_bytes(), false, false)
+                    }
+                });
+
+                let should_quote = match self.dialect.quoting {
+                    QuoteStyle::Strings => is_str || field_needs_quotes(data, self.dialect),
+                    QuoteStyle::Notnull => !is_none,
+                    _ => unreachable!(),
+                };
+                if should_quote {
+                    write_quoted_field(&mut output, data, self.dialect, vm)?;
+                } else if single_field && data.is_empty() {
+                    return Err(new_csv_error(
+                        vm,
+                        "single empty field record must be quoted",
+                    ));
+                } else {
+                    output.extend_from_slice(data);
+                }
+            }
+
+            write_lineterminator(&mut output, self.dialect.lineterminator);
+            let s = core::str::from_utf8(&output)
+                .map_err(|_| vm.new_unicode_decode_error("csv not utf8"))?;
+            self.write.call((s,), vm)
+        }
+
         #[pymethod]
         fn writerow(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult {
+            if matches!(
+                self.dialect.quoting,
+                QuoteStyle::Strings | QuoteStyle::Notnull
+            ) {
+                return self.writerow_quoted_strings(row, vm);
+            }
+
             let mut state = self.state.lock();
             let WriteState { buffer, writer } = &mut *state;
 
diff --git a/extra_tests/snippets/stdlib_csv.py b/extra_tests/snippets/stdlib_csv.py
index eb3461e9082..aa7b41223b6 100644
--- a/extra_tests/snippets/stdlib_csv.py
+++ b/extra_tests/snippets/stdlib_csv.py
@@ -1,4 +1,5 @@
 import csv
+import io
 
 from testutils import assert_raises
 
@@ -47,3 +48,27 @@ def test_delim():
 
 
 test_delim()
+
+
+def test_quote_strings_and_notnull_writer():
+    string_buf = io.StringIO()
+    csv.writer(string_buf, quoting=csv.QUOTE_STRINGS).writerow(["x", 1, None, ""])
+    assert string_buf.getvalue() == '"x",1,,""\r\n'
+
+    notnull_buf = io.StringIO()
+    csv.writer(notnull_buf, quoting=csv.QUOTE_NOTNULL).writerow(["x", 1, None, ""])
+    assert notnull_buf.getvalue() == '"x","1",,""\r\n'
+
+    for quoting in (csv.QUOTE_STRINGS, csv.QUOTE_NOTNULL):
+        buf = io.StringIO()
+        csv.writer(buf, quoting=quoting).writerow([None, None])
+        assert buf.getvalue() == ",\r\n"
+
+        with assert_raises(csv.Error):
+            csv.writer(io.StringIO(), quoting=quoting).writerow([None])
+
+        with assert_raises(TypeError):
+            csv.writer(io.StringIO(), quoting=quoting, quotechar=None)
+
+
+test_quote_strings_and_notnull_writer()