Skip to content
Prev Previous commit
Next Next commit
Rewrite using list as stack.
  • Loading branch information
barneygale committed May 30, 2024
commit 8d42ea89661599ceebf6ae31f8262faf316c3e61
79 changes: 36 additions & 43 deletions Lib/shutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,32 +635,27 @@ def onerror(err):
onexc(os.rmdir, path, err)

# Version using fd-based APIs to protect against races
def _rmtree_safe_fd(is_root, dir_fd, path, name, onexc):
# Note: To guard against symlink races, we use the standard
# lstat()/open()/fstat() trick.
def _rmtree_safe_fd(stack, onexc):
func, dir_fd, name, path = stack.pop()
try:
orig_st = os.lstat(name, dir_fd=dir_fd)
except OSError as err:
if is_root or not isinstance(err, FileNotFoundError):
onexc(os.lstat, path, err)
return
try:
fd = os.open(name, os.O_RDONLY | os.O_NONBLOCK, dir_fd=dir_fd)
except OSError as err:
if is_root or not isinstance(err, FileNotFoundError):
onexc(os.open, path, err)
return

try:
if not os.path.samestat(orig_st, os.fstat(fd)):
if func is os.close:
os.close(dir_fd)
elif func is os.rmdir:
os.rmdir(name, dir_fd=dir_fd)
else:
assert func is os.lstat
orig_st = os.lstat(name, dir_fd=dir_fd)
func = os.open # For error reporting.
fd = os.open(name, os.O_RDONLY | os.O_NONBLOCK, dir_fd=dir_fd)
try:
# symlinks to directories are forbidden, see bug #1669
raise OSError("Cannot call rmtree on a symbolic link")
except OSError as err:
onexc(os.path.islink, path, err)
return

try:
func = os.path.islink # For error reporting.
if not os.path.samestat(orig_st, os.fstat(fd)):
raise OSError("Cannot call rmtree on a symbolic link")
stack.append((os.rmdir, dir_fd, name, path))
finally:
stack.append((os.close, fd, name, path))

func = os.scandir # For error reporting.
with os.scandir(fd) as scandir_it:
entries = list(scandir_it)
for entry in entries:
Expand All @@ -670,32 +665,20 @@ def _rmtree_safe_fd(is_root, dir_fd, path, name, onexc):
except OSError:
is_dir = False
if is_dir:
_rmtree_safe_fd(False, fd, fullname, entry.name, onexc)
stack.append((os.lstat, fd, entry.name, fullname))
else:
try:
os.unlink(entry.name, dir_fd=fd)
except FileNotFoundError:
continue
except OSError as err:
err.filename = fullname
onexc(os.unlink, fullname, err)

if is_root or not isinstance(err, FileNotFoundError):
onexc(os.unlink, fullname, err)
except OSError as err:
if is_root or not isinstance(err, FileNotFoundError):
err.filename = path
onexc(os.scandir, path, err)

finally:
try:
os.close(fd)
except OSError as err:
onexc(os.close, path, err)

try:
os.rmdir(name, dir_fd=dir_fd)
except OSError as err:
if is_root or not isinstance(err, FileNotFoundError):
onexc(os.rmdir, path, err)
if func is os.close or name == path or not isinstance(err, FileNotFoundError):
err.filename = path
onexc(func, path, err)

_use_fd_functions = ({os.open, os.stat, os.unlink, os.rmdir} <=
os.supports_dir_fd and
Expand Down Expand Up @@ -748,7 +731,17 @@ def onexc(*args):
# While the unsafe rmtree works fine on bytes, the fd based does not.
if isinstance(path, bytes):
path = os.fsdecode(path)
_rmtree_safe_fd(True, dir_fd, path, path, onexc)
# Note: To guard against symlink races, we use the standard
# lstat()/open()/fstat() trick.
stack = [(os.lstat, dir_fd, path, path)]
try:
while stack:
_rmtree_safe_fd(stack, onexc)
finally:
while stack:
func, dir_fd, name, path = stack.pop()
if func is os.close:
os.close(dir_fd)
else:
if dir_fd is not None:
raise NotImplementedError("dir_fd unavailable on this platform")
Expand Down