Skip to content
This repository was archived by the owner on Jul 16, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions atomicwrites/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import io
import os
import sys
import tempfile
Expand Down Expand Up @@ -118,7 +119,7 @@ class AtomicWriter(object):
subclass.
'''

def __init__(self, path, mode='w', overwrite=False):
def __init__(self, path, mode='w', overwrite=False, **open_kwargs):
if 'a' in mode:
raise ValueError(
'Appending to an existing file is not supported, because that '
Expand All @@ -134,6 +135,7 @@ def __init__(self, path, mode='w', overwrite=False):
self._path = path
self._mode = mode
self._overwrite = overwrite
self._open_kwargs = open_kwargs

def open(self):
'''
Expand All @@ -146,7 +148,7 @@ def _open(self, get_fileobject):
f = None # make sure f exists even if get_fileobject() fails
try:
success = False
with get_fileobject() as f:
with get_fileobject(**self._open_kwargs) as f:
yield f
self.sync(f)
self.commit(f)
Expand All @@ -162,8 +164,14 @@ def get_fileobject(self, dir=None, **kwargs):
'''Return the temporary file to use.'''
if dir is None:
dir = os.path.normpath(os.path.dirname(self._path))
return tempfile.NamedTemporaryFile(mode=self._mode, dir=dir,
delete=False, **kwargs)
descriptor, name = tempfile.mkstemp(dir=dir)
# io.open() will take either the descriptor or the name, but we need
# the name later for commit()/replace_atomic() and couldn't find a way
# to get the filename from the descriptor.
os.close(descriptor)
kwargs['mode'] = self._mode
kwargs['file'] = name
return io.open(**kwargs)

def sync(self, f):
'''responsible for clearing as many file caches as possible before
Expand Down
18 changes: 9 additions & 9 deletions tests/test_atomicwrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def test_atomic_write(tmpdir):
fname = tmpdir.join('ha')
for i in range(2):
with atomic_write(str(fname), overwrite=True) as f:
f.write('hoho')
f.write(u'hoho')

with pytest.raises(OSError) as excinfo:
with atomic_write(str(fname), overwrite=False) as f:
f.write('haha')
f.write(u'haha')

assert excinfo.value.errno == errno.EEXIST

Expand All @@ -34,7 +34,7 @@ def test_teardown(tmpdir):
def test_replace_simultaneously_created_file(tmpdir):
fname = tmpdir.join('ha')
with atomic_write(str(fname), overwrite=True) as f:
f.write('hoho')
f.write(u'hoho')
fname.write('harhar')
assert fname.read() == 'harhar'
assert fname.read() == 'hoho'
Expand All @@ -45,7 +45,7 @@ def test_dont_remove_simultaneously_created_file(tmpdir):
fname = tmpdir.join('ha')
with pytest.raises(OSError) as excinfo:
with atomic_write(str(fname), overwrite=False) as f:
f.write('hoho')
f.write(u'hoho')
fname.write('harhar')
assert fname.read() == 'harhar'

Expand All @@ -60,10 +60,10 @@ def test_open_reraise(tmpdir):
fname = tmpdir.join('ha')
with pytest.raises(AssertionError):
with atomic_write(str(fname), overwrite=False) as f:
# Mess with f, so rollback will trigger an OSError. We're testing
# Mess with f, so commit will trigger a ValueError. We're testing
# that the initial AssertionError triggered below is propagated up
# the stack, not the second exception triggered during rollback.
f.name = "asdf"
# the stack, not the second exception triggered during commit.
f.close()
# Now trigger our own exception.
assert False, "Intentional failure for testing purposes"

Expand All @@ -75,11 +75,11 @@ def test_atomic_write_in_pwd(tmpdir):
fname = 'ha'
for i in range(2):
with atomic_write(str(fname), overwrite=True) as f:
f.write('hoho')
f.write(u'hoho')

with pytest.raises(OSError) as excinfo:
with atomic_write(str(fname), overwrite=False) as f:
f.write('haha')
f.write(u'haha')

assert excinfo.value.errno == errno.EEXIST

Expand Down