Skip to content

Commit 4a07745

Browse files
fix: avoid memory leak when decoding invalid nested arrays (#671)
1 parent 378edc6 commit 4a07745

3 files changed

Lines changed: 44 additions & 0 deletions

File tree

msgpack/_unpacker.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ cdef class Unpacker:
322322
self.buf = NULL
323323

324324
def __dealloc__(self):
325+
unpack_clear(&self.ctx)
325326
PyMem_Free(self.buf)
326327
self.buf = NULL
327328

msgpack/unpack_template.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ static inline PyObject* unpack_data(unpack_context* ctx)
7272

7373
static inline void unpack_clear(unpack_context *ctx)
7474
{
75+
unsigned int i;
76+
for (i = 1; i < ctx->top; i++) {
77+
Py_CLEAR(ctx->stack[i].obj);
78+
/* map_key holds a live reference only while waiting for the value */
79+
if (ctx->stack[i].ct == CT_MAP_VALUE) {
80+
Py_CLEAR(ctx->stack[i].map_key);
81+
}
82+
}
7583
Py_CLEAR(ctx->stack[0].obj);
7684
}
7785

test/test_except.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#!/usr/bin/env python
22

33
import datetime
4+
import gc
5+
import tracemalloc
46

57
from pytest import raises
68

@@ -80,6 +82,39 @@ def test_invalidvalue():
8082
unpackb(b"\x91" * 3000) # nested fixarray(len=1)
8183

8284

85+
def test_no_memory_leak_on_nested_invalid_tag() -> None:
86+
"""Regression test: unpacking nested arrays containing an invalid tag must not leak objects."""
87+
88+
kwargs: dict = {
89+
"raw": False,
90+
"strict_map_key": False,
91+
"max_array_len": 1 << 20,
92+
"max_map_len": 1 << 20,
93+
}
94+
n = 1000
95+
96+
for depth in range(1, 15):
97+
data = bytes([0x91] * depth + [0xC1])
98+
99+
gc.collect()
100+
tracemalloc.start()
101+
s1 = tracemalloc.take_snapshot()
102+
103+
for _ in range(n):
104+
try:
105+
unpackb(data, **kwargs)
106+
except Exception:
107+
pass
108+
109+
gc.collect()
110+
s2 = tracemalloc.take_snapshot()
111+
tracemalloc.stop()
112+
113+
leaked = sum(s.count_diff for s in s2.compare_to(s1, "lineno") if s.count_diff > 0)
114+
per_call = leaked / n
115+
assert per_call < 1.0, f"depth={depth}: {per_call:.2f} leaked objects/call (expected < 1)"
116+
117+
83118
def test_strict_map_key():
84119
valid = {"unicode": 1, b"bytes": 2}
85120
packed = packb(valid, use_bin_type=True)

0 commit comments

Comments
 (0)