Skip to content

Commit 823b44e

Browse files
authored
fix: make Tags pickle-safe (#1170)
* fix: make Tags pickle-safe Signed-off-by: Henry Schreiner <henryfs@princeton.edu> Assisted-by: Copilot:claude-hauku-4.5 * fix: address copilot review Assisted-by: Copilot:Kimi-K2.6 Signed-off-by: Henry Schreiner <henryfs@princeton.edu> --------- Signed-off-by: Henry Schreiner <henryfs@princeton.edu>
1 parent 4bed32d commit 823b44e

2 files changed

Lines changed: 135 additions & 7 deletions

File tree

src/packaging/tags.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from importlib.machinery import EXTENSION_SUFFIXES
1616
from typing import (
1717
TYPE_CHECKING,
18-
Any,
1918
Iterable,
2019
Iterator,
2120
Sequence,
@@ -158,12 +157,37 @@ def __str__(self) -> str:
158157
def __repr__(self) -> str:
159158
return f"<{self} @ {id(self)}>"
160159

161-
def __setstate__(self, state: tuple[None, dict[str, Any]]) -> None:
162-
# The cached _hash is wrong when unpickling.
163-
_, slots = state
164-
for k, v in slots.items():
165-
setattr(self, k, v)
166-
self._hash = hash((self._interpreter, self._abi, self._platform))
160+
def __getstate__(self) -> tuple[str, str, str]:
161+
# Return state as a 3-item tuple: (interpreter, abi, platform).
162+
# Cache member _hash is excluded and will be recomputed.
163+
return (self._interpreter, self._abi, self._platform)
164+
165+
def __setstate__(self, state: object) -> None:
166+
if isinstance(state, tuple):
167+
if len(state) == 3 and all(isinstance(s, str) for s in state):
168+
# New format (26.2+): (interpreter, abi, platform)
169+
self._interpreter, self._abi, self._platform = state
170+
self._hash = hash((self._interpreter, self._abi, self._platform))
171+
return
172+
if len(state) == 2 and isinstance(state[1], dict):
173+
# Old format (packaging <= 26.1, __slots__): (None, {slot: value}).
174+
_, slots = state
175+
try:
176+
interpreter = slots["_interpreter"]
177+
abi = slots["_abi"]
178+
platform = slots["_platform"]
179+
except KeyError:
180+
raise TypeError(f"Cannot restore Tag from {state!r}") from None
181+
if not all(
182+
isinstance(value, str) for value in (interpreter, abi, platform)
183+
):
184+
raise TypeError(f"Cannot restore Tag from {state!r}")
185+
self._interpreter = interpreter.lower()
186+
self._abi = abi.lower()
187+
self._platform = platform.lower()
188+
self._hash = hash((self._interpreter, self._abi, self._platform))
189+
return
190+
raise TypeError(f"Cannot restore Tag from {state!r}")
167191

168192

169193
def parse_tag(tag: str, *, validate_order: bool = False) -> frozenset[Tag]:

tests/test_tags.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,6 +1867,110 @@ def test_pickle() -> None:
18671867
assert pickle.loads(pickle.dumps(tag)) == tag
18681868

18691869

1870+
@pytest.mark.parametrize(
1871+
("interpreter", "abi", "platform"),
1872+
[
1873+
("py3", "none", "any"),
1874+
("cp39", "cp39", "linux_x86_64"),
1875+
("cp312", "cp312", "win_amd64"),
1876+
("pp310", "pypy310_pp73", "manylinux_2_17_x86_64"),
1877+
],
1878+
)
1879+
def test_pickle_tag_roundtrip(interpreter: str, abi: str, platform: str) -> None:
1880+
# Make sure equality, str(), and hash() work between a pickle/unpickle round trip.
1881+
t = tags.Tag(interpreter, abi, platform)
1882+
loaded = pickle.loads(pickle.dumps(t))
1883+
assert loaded == t
1884+
assert str(loaded) == str(t)
1885+
assert hash(loaded) == hash(t)
1886+
1887+
1888+
def test_pickle_tag_setstate_rejects_invalid_state() -> None:
1889+
# Cover the TypeError branches in __setstate__ for invalid input.
1890+
t = tags.Tag.__new__(tags.Tag)
1891+
with pytest.raises(TypeError, match="Cannot restore Tag"):
1892+
t.__setstate__(12345)
1893+
with pytest.raises(TypeError, match="Cannot restore Tag"):
1894+
t.__setstate__((1, 2, 3)) # Wrong types, not all strings
1895+
with pytest.raises(TypeError, match="Cannot restore Tag"):
1896+
t.__setstate__((None, {"_interpreter": "cp39", "_abi": "cp39"}))
1897+
with pytest.raises(TypeError, match="Cannot restore Tag"):
1898+
t.__setstate__(
1899+
(None, {"_interpreter": 123, "_abi": "cp39", "_platform": "linux_x86_64"})
1900+
)
1901+
with pytest.raises(TypeError, match="Cannot restore Tag"):
1902+
t.__setstate__((1, 2)) # len==2 but second element not a dict
1903+
with pytest.raises(TypeError, match="Cannot restore Tag"):
1904+
t.__setstate__((1, 2, 3, 4)) # tuple length not 2 or 3
1905+
1906+
1907+
# Pickle bytes generated with packaging==26.1, Python 3.13.1, pickle protocol 2.
1908+
# Format: __slots__ (no __getstate__), state is (None, {slot: value}). The
1909+
# _hash slot contains a pre-computed integer that must be discarded on load.
1910+
_PACKAGING_26_1_PICKLE_TAG_CP39 = (
1911+
b"\x80\x02cpackaging.tags\nTag\nq\x00)\x81q\x01N}q\x02(X\x04\x00\x00"
1912+
b"\x00_abiq\x03X\x04\x00\x00\x00cp39q\x04X\x05\x00\x00\x00_hashq\x05"
1913+
b"\x8a\x08)\xb1\xe8\x9d\x90\xf8tFX\x0c\x00\x00\x00_interpreterq\x06X"
1914+
b"\x04\x00\x00\x00cp39q\x07X\t\x00\x00\x00_platformq\x08X\x0c\x00\x00"
1915+
b"\x00linux_x86_64q\tu\x86q\nb."
1916+
)
1917+
1918+
1919+
# Pickle bytes generated with packaging==26.0, Python 3.13.1, pickle protocol 2.
1920+
# Format: __slots__ (no __getstate__), state is (None, {slot: value}).
1921+
_PACKAGING_26_0_PICKLE_TAG_CP39 = (
1922+
b"\x80\x02cpackaging.tags\nTag\nq\x00)\x81q\x01N}q\x02(X\x04\x00\x00"
1923+
b"\x00_abiq\x03X\x04\x00\x00\x00cp39q\x04X\x05\x00\x00\x00_hashq\x05"
1924+
b"\x8a\x08\xc1\xdb\xa0\xe5]7z\x87X\x0c\x00\x00\x00_interpreterq\x06X"
1925+
b"\x04\x00\x00\x00cp39q\x07X\t\x00\x00\x00_platformq\x08X\x0c\x00\x00"
1926+
b"\x00linux_x86_64q\tu\x86q\nb."
1927+
)
1928+
1929+
1930+
# Pickle bytes generated with packaging==25.0, Python 3.13.1, pickle protocol 2.
1931+
# Format: plain __dict__ (no __slots__).
1932+
_PACKAGING_25_0_PICKLE_TAG_CP39 = (
1933+
b"\x80\x02cpackaging.tags\nTag\nq\x00)\x81q\x01N}q\x02(X\x04\x00\x00\x00"
1934+
b"_abiq\x03X\x04\x00\x00\x00cp39q\x04X\x05\x00\x00\x00_hashq\x05\x8a\x08"
1935+
b"\xea\xa5X\x92\xa5\xc9\x11\x0cX\x0c\x00\x00\x00_interpreterq\x06X\x04"
1936+
b"\x00\x00\x00cp39q\x07X\t\x00\x00\x00_platformq\x08X\x0c\x00\x00\x00"
1937+
b"linux_x86_64q\tu\x86q\nb."
1938+
)
1939+
1940+
1941+
def test_pickle_tag_old_format_loads() -> None:
1942+
# Verify that Tag pickles created with packaging <= 26.1 (__slots__,
1943+
# no __getstate__) can be loaded and produce correct Tag objects.
1944+
t = pickle.loads(_PACKAGING_26_1_PICKLE_TAG_CP39)
1945+
assert isinstance(t, tags.Tag)
1946+
assert str(t) == "cp39-cp39-linux_x86_64"
1947+
assert t == tags.Tag("cp39", "cp39", "linux_x86_64")
1948+
assert t.interpreter == "cp39"
1949+
assert t.abi == "cp39"
1950+
assert t.platform == "linux_x86_64"
1951+
assert t._hash == hash(("cp39", "cp39", "linux_x86_64"))
1952+
1953+
1954+
def test_pickle_tag_26_0_format_loads() -> None:
1955+
# Verify that Tag pickles created with packaging 26.0 (__slots__,
1956+
# no __getstate__) can be loaded and produce correct Tag objects.
1957+
t = pickle.loads(_PACKAGING_26_0_PICKLE_TAG_CP39)
1958+
assert isinstance(t, tags.Tag)
1959+
assert str(t) == "cp39-cp39-linux_x86_64"
1960+
assert t == tags.Tag("cp39", "cp39", "linux_x86_64")
1961+
assert t._hash == hash(("cp39", "cp39", "linux_x86_64"))
1962+
1963+
1964+
def test_pickle_tag_25_0_format_loads() -> None:
1965+
# Verify that Tag pickles created with packaging 25.0 (plain __dict__)
1966+
# can be loaded and produce correct Tag objects.
1967+
t = pickle.loads(_PACKAGING_25_0_PICKLE_TAG_CP39)
1968+
assert isinstance(t, tags.Tag)
1969+
assert str(t) == "cp39-cp39-linux_x86_64"
1970+
assert t == tags.Tag("cp39", "cp39", "linux_x86_64")
1971+
assert t._hash == hash(("cp39", "cp39", "linux_x86_64"))
1972+
1973+
18701974
@pytest.mark.parametrize(
18711975
("supported", "things", "expected"),
18721976
[

0 commit comments

Comments
 (0)