Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
gh-91810: ElementTree: Use text file's encoding by default in XML declaration

The ElementTree method write() and the function tostring() now use the text
file's encoding ("UTF-8" if not available) instead of the locale encoding in
the XML declaration when encoding="unicode" is specified.
  • Loading branch information
serhiy-storchaka committed Apr 25, 2022
commit ca0f8e9cf5a9de5fddaaa1e6b273b90838f5670c
113 changes: 87 additions & 26 deletions Lib/test/test_xml_etree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import html
import io
import itertools
import locale
import operator
import os
import pickle
Expand Down Expand Up @@ -975,15 +974,13 @@ def test_tostring_xml_declaration(self):

def test_tostring_xml_declaration_unicode_encoding(self):
elem = ET.XML('<body><tag/></body>')
preferredencoding = locale.getpreferredencoding()
self.assertEqual(
f"<?xml version='1.0' encoding='{preferredencoding}'?>\n<body><tag /></body>",
ET.tostring(elem, encoding='unicode', xml_declaration=True)
ET.tostring(elem, encoding='unicode', xml_declaration=True),
"<?xml version='1.0' encoding='utf-8'?>\n<body><tag /></body>"
)

def test_tostring_xml_declaration_cases(self):
elem = ET.XML('<body><tag>ø</tag></body>')
preferredencoding = locale.getpreferredencoding()
TESTCASES = [
# (expected_retval, encoding, xml_declaration)
# ... xml_declaration = None
Expand All @@ -1010,7 +1007,7 @@ def test_tostring_xml_declaration_cases(self):
b"<body><tag>&#248;</tag></body>", 'US-ASCII', True),
(b"<?xml version='1.0' encoding='ISO-8859-1'?>\n"
b"<body><tag>\xf8</tag></body>", 'ISO-8859-1', True),
(f"<?xml version='1.0' encoding='{preferredencoding}'?>\n"
("<?xml version='1.0' encoding='utf-8'?>\n"
"<body><tag>ø</tag></body>", 'unicode', True),

]
Expand Down Expand Up @@ -1048,11 +1045,10 @@ def test_tostringlist_xml_declaration(self):
b"<?xml version='1.0' encoding='us-ascii'?>\n<body><tag /></body>"
)

preferredencoding = locale.getpreferredencoding()
stringlist = ET.tostringlist(elem, encoding='unicode', xml_declaration=True)
self.assertEqual(
''.join(stringlist),
f"<?xml version='1.0' encoding='{preferredencoding}'?>\n<body><tag /></body>"
"<?xml version='1.0' encoding='utf-8'?>\n<body><tag /></body>"
)
self.assertRegex(stringlist[0], r"^<\?xml version='1.0' encoding='.+'?>")
self.assertEqual(['<body', '>', '<tag', ' />', '</body>'], stringlist[1:])
Expand Down Expand Up @@ -3712,49 +3708,114 @@ def test_encoding(self):
"<tag key=\"åöö&lt;&gt;\" />" % enc).encode(enc))

def test_write_to_filename(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
tree.write(TESTFN)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
self.assertEqual(f.read(), b'''<site>&#248;</site>''')

def test_write_to_filename_with_encoding(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
tree.write(TESTFN, encoding='utf-8')
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site>\xc3\xb8</site>''')

tree.write(TESTFN, encoding='ISO-8859-1')
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
b'''<?xml version='1.0' encoding='ISO-8859-1'?>\n'''
b'''<site>\xf8</site>''')

def test_write_to_filename_as_unicode(self):
self.addCleanup(os_helper.unlink, TESTFN)
with open(TESTFN, 'w') as f:
encoding = f.encoding
os_helper.unlink(TESTFN)

tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
tree.write(TESTFN, encoding='unicode')
with open(TESTFN, 'rb') as f:
data = f.read()
expected = "<site>\xf8</site>".encode(encoding, 'xmlcharrefreplace')
self.assertIn(
"<site>\xf8</site>".encode(encoding, 'xmlcharrefreplace'),
data)
if encoding.lower() in ('utf-8', 'ascii'):
self.assertEqual(data, expected)
else:
self.assertIn(b"<?xml version='1.0' encoding=", data)
self.assertIn(expected, data)

def test_write_to_text_file(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
with open(TESTFN, 'w', encoding='utf-8') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
self.assertEqual(f.read(), b'''<site>\xc3\xb8</site>''')

with open(TESTFN, 'w', encoding='ascii', errors='xmlcharrefreplace') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
b'''<?xml version='1.0' encoding='ascii'?>\n'''
b'''<site>&#248;</site>''')

with open(TESTFN, 'w', encoding='ISO-8859-1') as f:
tree.write(f, encoding='unicode')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
b'''<?xml version='1.0' encoding='ISO-8859-1'?>\n'''
b'''<site>\xf8</site>''')

def test_write_to_binary_file(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
with open(TESTFN, 'wb') as f:
tree.write(f)
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site />''')
self.assertEqual(f.read(), b'''<site>&#248;</site>''')

def test_write_to_binary_file_with_encoding(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
with open(TESTFN, 'wb') as f:
tree.write(f, encoding='utf-8')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(), b'''<site>\xc3\xb8</site>''')

with open(TESTFN, 'wb') as f:
tree.write(f, encoding='ISO-8859-1')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
b'''<?xml version='1.0' encoding='ISO-8859-1'?>\n'''
b'''<site>\xf8</site>''')

def test_write_to_binary_file_with_bom(self):
self.addCleanup(os_helper.unlink, TESTFN)
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
# test BOM writing to buffered file
with open(TESTFN, 'wb') as f:
tree.write(f, encoding='utf-16')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''<?xml version='1.0' encoding='utf-16'?>\n'''
'''<site />'''.encode("utf-16"))
'''<site>\xf8</site>'''.encode("utf-16"))
# test BOM writing to non-buffered file
with open(TESTFN, 'wb', buffering=0) as f:
tree.write(f, encoding='utf-16')
self.assertFalse(f.closed)
with open(TESTFN, 'rb') as f:
self.assertEqual(f.read(),
'''<?xml version='1.0' encoding='utf-16'?>\n'''
'''<site />'''.encode("utf-16"))
'''<site>\xf8</site>'''.encode("utf-16"))

def test_read_from_stringio(self):
tree = ET.ElementTree()
Expand All @@ -3763,10 +3824,10 @@ def test_read_from_stringio(self):
self.assertEqual(tree.getroot().tag, 'site')

def test_write_to_stringio(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
stream = io.StringIO()
tree.write(stream, encoding='unicode')
self.assertEqual(stream.getvalue(), '''<site />''')
self.assertEqual(stream.getvalue(), '''<site>\xf8</site>''')

def test_read_from_bytesio(self):
tree = ET.ElementTree()
Expand All @@ -3775,10 +3836,10 @@ def test_read_from_bytesio(self):
self.assertEqual(tree.getroot().tag, 'site')

def test_write_to_bytesio(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
raw = io.BytesIO()
tree.write(raw)
self.assertEqual(raw.getvalue(), b'''<site />''')
self.assertEqual(raw.getvalue(), b'''<site>&#248;</site>''')

class dummy:
pass
Expand All @@ -3792,12 +3853,12 @@ def test_read_from_user_text_reader(self):
self.assertEqual(tree.getroot().tag, 'site')

def test_write_to_user_text_writer(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
stream = io.StringIO()
writer = self.dummy()
writer.write = stream.write
tree.write(writer, encoding='unicode')
self.assertEqual(stream.getvalue(), '''<site />''')
self.assertEqual(stream.getvalue(), '''<site>\xf8</site>''')

def test_read_from_user_binary_reader(self):
raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
Expand All @@ -3809,12 +3870,12 @@ def test_read_from_user_binary_reader(self):
tree = ET.ElementTree()

def test_write_to_user_binary_writer(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
tree = ET.ElementTree(ET.XML('''<site>\xf8</site>'''))
raw = io.BytesIO()
writer = self.dummy()
writer.write = raw.write
tree.write(writer)
self.assertEqual(raw.getvalue(), b'''<site />''')
self.assertEqual(raw.getvalue(), b'''<site>&#248;</site>''')

def test_write_to_user_binary_writer_with_bom(self):
tree = ET.ElementTree(ET.XML('''<site />'''))
Expand Down
23 changes: 9 additions & 14 deletions Lib/xml/etree/ElementTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,16 +728,10 @@ def write(self, file_or_filename,
encoding = "utf-8"
else:
encoding = "us-ascii"
enc_lower = encoding.lower()
with _get_writer(file_or_filename, enc_lower) as write:
with _get_writer(file_or_filename, encoding) as (write, declared_encoding):
if method == "xml" and (xml_declaration or
(xml_declaration is None and
enc_lower not in ("utf-8", "us-ascii", "unicode"))):
declared_encoding = encoding
if enc_lower == "unicode":
# Retrieve the default encoding for the xml declaration
import locale
declared_encoding = locale.getpreferredencoding()
declared_encoding.lower() not in ("utf-8", "us-ascii"))):
write("<?xml version='1.0' encoding='%s'?>\n" % (
declared_encoding,))
if method == "text":
Expand All @@ -762,19 +756,20 @@ def _get_writer(file_or_filename, encoding):
write = file_or_filename.write
except AttributeError:
# file_or_filename is a file name
if encoding == "unicode":
file = open(file_or_filename, "w")
if encoding.lower() == "unicode":
file = open(file_or_filename, "w",
errors="xmlcharrefreplace")
else:
file = open(file_or_filename, "w", encoding=encoding,
errors="xmlcharrefreplace")
with file:
yield file.write
yield file.write, file.encoding
else:
# file_or_filename is a file-like object
# encoding determines if it is a text or binary writer
if encoding == "unicode":
if encoding.lower() == "unicode":
# use a text writer as is
yield write
yield write, getattr(file_or_filename, "encoding", None) or "utf-8"
else:
# wrap a binary writer with TextIOWrapper
with contextlib.ExitStack() as stack:
Expand Down Expand Up @@ -805,7 +800,7 @@ def _get_writer(file_or_filename, encoding):
# Keep the original file open when the TextIOWrapper is
# destroyed
stack.callback(file.detach)
yield file.write
yield file.write, encoding

def _namespaces(elem, default_namespace=None):
# identify namespaces used in this tree
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
The :class:`~xml.etree.ElementTree.ElementTree` method
:meth:`~xml.etree.ElementTree.ElementTree.write` and the function
:func:`~xml.etree.ElementTree.tostring` now use the text file's encoding
("UTF-8" if not available) instead of the locale encoding in the XML
declaration when ``encoding="unicode"`` is specified.