Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 111 additions & 3 deletions crates/stdlib/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,7 @@ mod _csv {
QuoteStyle::All => Self::Always,
QuoteStyle::Nonnumeric => Self::NonNumeric,
QuoteStyle::None => Self::Never,
QuoteStyle::Strings => todo!(),
QuoteStyle::Notnull => todo!(),
QuoteStyle::Strings | QuoteStyle::Notnull => Self::Necessary,
}
}
}
Expand Down Expand Up @@ -661,7 +660,10 @@ mod _csv {
|_| { vm.new_type_error(r#""quotechar" must be a 1-character string"#) }
)?)),
PyNone => {
if let Some(QuoteStyle::All) = res.quoting {
if res
.quoting
.is_some_and(|quoting| quoting != QuoteStyle::None)
{
return Err(ArgumentError::Exception(
vm.new_type_error("quotechar must be set if quoting enabled"),
));
Expand Down Expand Up @@ -1114,15 +1116,121 @@ mod _csv {
}
}

fn write_quoted_field(
output: &mut Vec<u8>,
data: &[u8],
dialect: PyDialect,
vm: &VirtualMachine,
) -> PyResult<()> {
let quotechar = dialect
.quotechar
.ok_or_else(|| vm.new_type_error("quotechar must be set if quoting enabled"))?;
output.push(quotechar);
for &byte in data {
if byte == quotechar {
if dialect.doublequote {
output.push(quotechar);
output.push(quotechar);
} else if let Some(escapechar) = dialect.escapechar {
output.push(escapechar);
output.push(byte);
} else {
return Err(new_csv_error(vm, "need to escape, but no escapechar set"));
}
} else {
if dialect.escapechar == Some(byte) {
output.push(byte);
}
output.push(byte);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
}
output.push(quotechar);
Ok(())
}

fn field_needs_quotes(data: &[u8], dialect: PyDialect) -> bool {
data.iter().any(|&byte| {
byte == dialect.delimiter
|| dialect.quotechar == Some(byte)
|| matches!(byte, b'\r' | b'\n')
|| matches!(dialect.lineterminator, Terminator::Any(t) if byte == t)
})
}

fn write_lineterminator(output: &mut Vec<u8>, terminator: Terminator) {
match terminator {
Terminator::CRLF => output.extend_from_slice(b"\r\n"),
Terminator::Any(byte) => output.push(byte),
_ => unreachable!(),
}
}

#[pyclass(flags(DISALLOW_INSTANTIATION))]
impl Writer {
#[pygetset(name = "dialect")]
const fn get_dialect(&self, _vm: &VirtualMachine) -> PyDialect {
self.dialect
}

fn writerow_quoted_strings(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let _state = self.state.lock();
let row: ArgIterable = ArgIterable::try_from_object(vm, row.clone()).map_err(|_e| {
new_csv_error(
vm,
format!("'{}' object is not iterable", row.class().name()),
)
})?;
let fields = row.iter(vm)?.collect::<PyResult<Vec<_>>>()?;
let single_field = fields.len() == 1;
let mut output = Vec::new();

for (index, field) in fields.into_iter().enumerate() {
if index > 0 {
output.push(self.dialect.delimiter);
}

let stringified;
let (data, is_str, is_none): (&[u8], bool, bool) = match_class!(match field {
ref s @ PyStr => (s.as_bytes(), true, false),
crate::builtins::PyNone => (b"", false, true),
ref obj => {
stringified = obj.str(vm)?;
(stringified.as_bytes(), false, false)
}
});

let should_quote = match self.dialect.quoting {
QuoteStyle::Strings => is_str || field_needs_quotes(data, self.dialect),
QuoteStyle::Notnull => !is_none,
_ => unreachable!(),
};
if should_quote {
write_quoted_field(&mut output, data, self.dialect, vm)?;
} else if single_field && data.is_empty() {
return Err(new_csv_error(
vm,
"single empty field record must be quoted",
));
} else {
output.extend_from_slice(data);
}
}

write_lineterminator(&mut output, self.dialect.lineterminator);
let s = core::str::from_utf8(&output)
.map_err(|_| vm.new_unicode_decode_error("csv not utf8"))?;
self.write.call((s,), vm)
}

#[pymethod]
fn writerow(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if matches!(
self.dialect.quoting,
QuoteStyle::Strings | QuoteStyle::Notnull
) {
return self.writerow_quoted_strings(row, vm);
}

let mut state = self.state.lock();
let WriteState { buffer, writer } = &mut *state;

Expand Down
25 changes: 25 additions & 0 deletions extra_tests/snippets/stdlib_csv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import io

from testutils import assert_raises

Expand Down Expand Up @@ -47,3 +48,27 @@ def test_delim():


test_delim()


def test_quote_strings_and_notnull_writer():
string_buf = io.StringIO()
csv.writer(string_buf, quoting=csv.QUOTE_STRINGS).writerow(["x", 1, None, ""])
assert string_buf.getvalue() == '"x",1,,""\r\n'

notnull_buf = io.StringIO()
csv.writer(notnull_buf, quoting=csv.QUOTE_NOTNULL).writerow(["x", 1, None, ""])
assert notnull_buf.getvalue() == '"x","1",,""\r\n'

for quoting in (csv.QUOTE_STRINGS, csv.QUOTE_NOTNULL):
buf = io.StringIO()
csv.writer(buf, quoting=quoting).writerow([None, None])
assert buf.getvalue() == ",\r\n"

with assert_raises(csv.Error):
csv.writer(io.StringIO(), quoting=quoting).writerow([None])

with assert_raises(TypeError):
csv.writer(io.StringIO(), quoting=quoting, quotechar=None)


test_quote_strings_and_notnull_writer()
Loading