diff --git a/atomicwrites/__init__.py b/atomicwrites/__init__.py index a182c07..3a83761 100644 --- a/atomicwrites/__init__.py +++ b/atomicwrites/__init__.py @@ -1,4 +1,5 @@ import contextlib +import io import os import sys import tempfile @@ -118,7 +119,7 @@ class AtomicWriter(object): subclass. ''' - def __init__(self, path, mode='w', overwrite=False): + def __init__(self, path, mode='w', overwrite=False, **open_kwargs): if 'a' in mode: raise ValueError( 'Appending to an existing file is not supported, because that ' @@ -134,6 +135,7 @@ def __init__(self, path, mode='w', overwrite=False): self._path = path self._mode = mode self._overwrite = overwrite + self._open_kwargs = open_kwargs def open(self): ''' @@ -146,7 +148,7 @@ def _open(self, get_fileobject): f = None # make sure f exists even if get_fileobject() fails try: success = False - with get_fileobject() as f: + with get_fileobject(**self._open_kwargs) as f: yield f self.sync(f) self.commit(f) @@ -162,8 +164,14 @@ def get_fileobject(self, dir=None, **kwargs): '''Return the temporary file to use.''' if dir is None: dir = os.path.normpath(os.path.dirname(self._path)) - return tempfile.NamedTemporaryFile(mode=self._mode, dir=dir, - delete=False, **kwargs) + descriptor, name = tempfile.mkstemp(dir=dir) + # io.open() will take either the descriptor or the name, but we need + # the name later for commit()/replace_atomic() and couldn't find a way + # to get the filename from the descriptor. + os.close(descriptor) + kwargs['mode'] = self._mode + kwargs['file'] = name + return io.open(**kwargs) def sync(self, f): '''responsible for clearing as many file caches as possible before diff --git a/tests/test_atomicwrites.py b/tests/test_atomicwrites.py index 9577199..3bdcd5e 100644 --- a/tests/test_atomicwrites.py +++ b/tests/test_atomicwrites.py @@ -10,11 +10,11 @@ def test_atomic_write(tmpdir): fname = tmpdir.join('ha') for i in range(2): with atomic_write(str(fname), overwrite=True) as f: - f.write('hoho') + f.write(u'hoho') with pytest.raises(OSError) as excinfo: with atomic_write(str(fname), overwrite=False) as f: - f.write('haha') + f.write(u'haha') assert excinfo.value.errno == errno.EEXIST @@ -34,7 +34,7 @@ def test_teardown(tmpdir): def test_replace_simultaneously_created_file(tmpdir): fname = tmpdir.join('ha') with atomic_write(str(fname), overwrite=True) as f: - f.write('hoho') + f.write(u'hoho') fname.write('harhar') assert fname.read() == 'harhar' assert fname.read() == 'hoho' @@ -45,7 +45,7 @@ def test_dont_remove_simultaneously_created_file(tmpdir): fname = tmpdir.join('ha') with pytest.raises(OSError) as excinfo: with atomic_write(str(fname), overwrite=False) as f: - f.write('hoho') + f.write(u'hoho') fname.write('harhar') assert fname.read() == 'harhar' @@ -60,10 +60,10 @@ def test_open_reraise(tmpdir): fname = tmpdir.join('ha') with pytest.raises(AssertionError): with atomic_write(str(fname), overwrite=False) as f: - # Mess with f, so rollback will trigger an OSError. We're testing + # Mess with f, so commit will trigger a ValueError. We're testing # that the initial AssertionError triggered below is propagated up - # the stack, not the second exception triggered during rollback. - f.name = "asdf" + # the stack, not the second exception triggered during commit. + f.close() # Now trigger our own exception. assert False, "Intentional failure for testing purposes" @@ -75,11 +75,11 @@ def test_atomic_write_in_pwd(tmpdir): fname = 'ha' for i in range(2): with atomic_write(str(fname), overwrite=True) as f: - f.write('hoho') + f.write(u'hoho') with pytest.raises(OSError) as excinfo: with atomic_write(str(fname), overwrite=False) as f: - f.write('haha') + f.write(u'haha') assert excinfo.value.errno == errno.EEXIST