Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Align copy module behaviour with pickle module
  • Loading branch information
eendebakpt committed Sep 16, 2023
commit 1535893ec9a0f17c52cc86ff77f4875166b46306
30 changes: 16 additions & 14 deletions Lib/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class Error(Exception):
pass
error = Error # backward compatibility

_NoValue = object()

__all__ = ["Error", "copy", "deepcopy"]

def copy(x):
Expand All @@ -74,20 +76,20 @@ def copy(x):
# treat it as a regular class:
return _copy_immutable(x)

copier = getattr(cls, "__copy__", None)
if copier is not None:
copier = getattr(cls, "__copy__", _NoValue)
if copier is not _NoValue:
return copier(x)

reductor = dispatch_table.get(cls)
if reductor is not None:
reductor = dispatch_table.get(cls, _NoValue)
if reductor is not _NoValue:
rv = reductor(x)
else:
reductor = getattr(x, "__reduce_ex__", None)
if reductor is not None:
reductor = getattr(x, "__reduce_ex__", _NoValue)
if reductor is not _NoValue:
rv = reductor(4)
else:
reductor = getattr(x, "__reduce__", None)
if reductor:
reductor = getattr(x, "__reduce__", _NoValue)
if reductor is not _NoValue:
rv = reductor()
else:
raise Error("un(shallow)copyable object of type %s" % cls)
Expand Down Expand Up @@ -138,20 +140,20 @@ def deepcopy(x, memo=None, _nil=[]):
if issubclass(cls, type):
y = _deepcopy_atomic(x, memo)
else:
copier = getattr(x, "__deepcopy__", None)
if copier is not None:
copier = getattr(x, "__deepcopy__", _NoValue)
if copier is not _NoValue:
y = copier(memo)
else:
reductor = dispatch_table.get(cls)
if reductor:
rv = reductor(x)
else:
reductor = getattr(x, "__reduce_ex__", None)
if reductor is not None:
reductor = getattr(x, "__reduce_ex__", _NoValue)
if reductor is not _NoValue:
rv = reductor(4)
else:
reductor = getattr(x, "__reduce__", None)
if reductor:
reductor = getattr(x, "__reduce__", _NoValue)
if reductor is not _NoValue:
rv = reductor()
else:
raise Error(
Expand Down
20 changes: 20 additions & 0 deletions Lib/test/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,26 @@ def __reduce__(self):
self.assertIs(y, x)
self.assertEqual(c, [1])

def test_copy_invalid_reduction_methods(self):
class C(object):
__copy__ = None
x = C()
with self.assertRaises(TypeError):
copy.copy(x)

class C(object):
__reduce_ex__ = None
x = C()
with self.assertRaises(TypeError):
copy.copy(x)

class C(object):
__reduce_ex__ = copy._NoValue
__reduce__ = None
x = C()
with self.assertRaises(TypeError):
copy.copy(x)

def test_copy_reduce(self):
class C(object):
def __reduce__(self):
Expand Down