Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions Lib/test/pickletester.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,25 @@ def tell(self):
raise io.UnsupportedOperation


# We can't very well test the extension registry without putting known stuff
# in it, but we have to be careful to restore its original state. Code
# should do this:
#
# e = ExtensionSaver(extension_code)
# try:
# fiddle w/ the extension registry's stuff for extension_code
# finally:
# e.restore()
# We can't test the extension registry well without putting known stuff
# in it, but we have to be careful to restore its original state.

class ExtensionSaver:
# Remember current registration for code (if any), and remove it (if
# there is one).
def __init__(self, code):
self.code = code
if code in copyreg._inverted_registry:
self.pair = copyreg._inverted_registry[code]
copyreg.remove_extension(self.pair[0], self.pair[1], code)
else:
self.pair = None
self.pair = None

def __enter__(self):
if self.code not in copyreg._inverted_registry:
return

self.pair = copyreg._inverted_registry[self.code]
copyreg.remove_extension(self.pair[0], self.pair[1], self.code)

# Restore previous registration for code.
def restore(self):
def __exit__(self, exc_type, exc_value, traceback):
code = self.code
curpair = copyreg._inverted_registry.get(code)
if curpair is not None:
Expand Down Expand Up @@ -1944,8 +1940,7 @@ def test_newobj_not_class(self):
# (EXT[124]) under proto 2, and not in proto 1.

def produce_global_ext(self, extcode, opcode):
e = ExtensionSaver(extcode)
try:
with ExtensionSaver(extcode):
copyreg.add_extension(__name__, "MyList", extcode)
x = MyList([1, 2, 3])
x.foo = 42
Expand All @@ -1968,8 +1963,6 @@ def produce_global_ext(self, extcode, opcode):

y = self.loads(s2)
self.assert_is_copy(x, y)
finally:
e.restore()

def test_global_ext1(self):
self.produce_global_ext(0x00000001, pickle.EXT1) # smallest EXT1 code
Expand Down