diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 8627862..bf6d374 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -29,7 +29,7 @@ jobs: run: uv sync --frozen - name: Run the benchmarks - uses: CodSpeedHQ/action@db35df748deb45fdef0960669f57d627c1956c30 # v4.13.1 + uses: CodSpeedHQ/action@3194d9a39c4d46684cb44bf7207fc56626aad8fd # v4.15.1 with: mode: instrumentation run: uv run pytest tests/test_benchmarks.py --codspeed diff --git a/.github/workflows/cifuzz.yml b/.github/workflows/cifuzz.yml index 1c1f26a..89ce530 100644 --- a/.github/workflows/cifuzz.yml +++ b/.github/workflows/cifuzz.yml @@ -47,7 +47,7 @@ jobs: - name: Upload Sarif if: always() && steps.build.outcome == 'success' - uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2 + uses: github/codeql-action/upload-sarif@7211b7c8077ea37d8641b6271f6a365a22a5fbfa # v4.36.0 with: sarif_file: cifuzz-sarif/results.sarif checkout_path: cifuzz-sarif diff --git a/.github/workflows/zizmor.yml b/.github/workflows/zizmor.yml index 6c18a20..8147373 100644 --- a/.github/workflows/zizmor.yml +++ b/.github/workflows/zizmor.yml @@ -22,4 +22,4 @@ jobs: persist-credentials: false - name: Run zizmor 🌈 - uses: zizmorcore/zizmor-action@b1d7e1fb5de872772f31590499237e7cce841e8e # v0.5.3 + uses: zizmorcore/zizmor-action@5f14fd08f7cf1cb1609c1e344975f152c7ee938d # v0.5.6 diff --git a/CHANGELOG.md b/CHANGELOG.md index f8bda92..8d470c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ ## Unreleased +## 0.0.31 (2026-06-04) + +* Speed up multipart header parsing and callback dispatch [#295](https://github.com/Kludex/python-multipart/pull/295). +* Bound header field name size before validating [#296](https://github.com/Kludex/python-multipart/pull/296). +* Validate `Content-Length` is non-negative in `parse_form` [#297](https://github.com/Kludex/python-multipart/pull/297). + ## 0.0.30 (2026-05-31) * Parse `application/x-www-form-urlencoded` bodies per the WHATWG URL standard, treating only `&` as a field separator [#290](https://github.com/Kludex/python-multipart/pull/290). diff --git a/pyproject.toml b/pyproject.toml index 7d1236a..368188a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,8 +36,8 @@ dependencies = [] dev = [ "atomicwrites==1.4.1", "attrs==26.1.0", - "coverage==7.13.5", - "more-itertools==11.0.2", + "coverage>=7.14.0", + "more-itertools>=11.1.0", "pbr==7.0.3", "pluggy==1.6.0", "py==1.11.0", @@ -47,7 +47,7 @@ dev = [ "PyYAML==6.0.3", "invoke==3.0.3", "pytest-timeout==2.4.0", - "ruff==0.15.11", + "ruff>=0.15.14", "mypy", "types-PyYAML", "atheris==2.3.0; python_version <= '3.11'", diff --git a/python_multipart/__init__.py b/python_multipart/__init__.py index f55f976..931c87d 100644 --- a/python_multipart/__init__.py +++ b/python_multipart/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.30" +__version__ = "0.0.31" from .multipart import ( BaseParser, diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 1e63a2b..6769d9f 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -133,11 +133,12 @@ class MultipartState(IntEnum): # Mask for ASCII characters that can be http tokens. # Per RFC7230 - 3.2.6, this is all alpha-numeric characters # and these: !#$%&'*+-.^_`|~ -TOKEN_CHARS_SET = frozenset( +TOKEN_CHARS = ( b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"abcdefghijklmnopqrstuvwxyz" b"0123456789" b"!#$%&'*+-.^_`|~") +TOKEN_CHARS_SET = frozenset(TOKEN_CHARS) # fmt: on DEFAULT_MAX_HEADER_COUNT = 8 @@ -647,8 +648,7 @@ def callback( end: An integer that is passed to the data callback. start: An integer that is passed to the data callback. """ - on_name = "on_" + name - func = self.callbacks.get(on_name) + func = self.callbacks.get("on_" + name) if func is None: return func = cast("Callable[..., Any]", func) @@ -657,11 +657,8 @@ def callback( # Don't do anything if we have start == end. if start is not None and start == end: return - - self.logger.debug("Calling %s with data[%d:%d]", on_name, start, end) func(data, start, end) else: - self.logger.debug("Calling %s with no data", on_name) func() def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None: @@ -1078,6 +1075,7 @@ def write(self, data: bytes) -> int: def _internal_write(self, data: bytes, length: int) -> int: # Get values from locals. boundary = self.boundary + boundary_length = len(boundary) # Get our state, flags and index. These are persisted between calls to # this function. @@ -1128,7 +1126,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No # We need to use self.flags (and not flags) because we care about # the state when we entered the loop. lookbehind_len = -marked_index - if lookbehind_len <= len(boundary): + if lookbehind_len <= boundary_length: self.callback(name, boundary, 0, lookbehind_len) elif self.flags & FLAG_PART_BOUNDARY: lookback = boundary + b"\r\n" @@ -1173,7 +1171,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No elif state == MultipartState.START_BOUNDARY: # Check to ensure that the last 2 characters in our boundary # are CRLF. - if index == len(boundary) - 2: + if index == boundary_length - 2: if c == HYPHEN: # Potential empty message. state = MultipartState.END_BOUNDARY @@ -1185,7 +1183,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No index += 1 - elif index == len(boundary) - 2 + 1: + elif index == boundary_length - 1: if c != LF: msg = "Did not find LF at end of boundary (%d)" % (i,) self.logger.warning(msg) @@ -1247,31 +1245,41 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No i += 1 continue - # Increment our index in the header. - index += 1 + # The field name runs until the colon; jump straight to it and + # validate the whole span at once instead of byte by byte. + colon = data.find(b":", i, length) + end = colon if colon != -1 else length - # If we've reached a colon, we're done with this header. - if c == COLON: - advance_header_size() + # Enforce the size limit before slicing and validating, so an oversized header + # name fails fast instead of copying and scanning a potentially huge span. + advance_header_size(end - i if colon == -1 else end - i + 1) + + field = data[i:end] + if field.translate(None, TOKEN_CHARS): + bad = next(b for b in field if b not in TOKEN_CHARS_SET) + bad_i = i + field.index(bad) + msg = "Found invalid character %r in header at %d" % (bad, bad_i) + self.logger.warning(msg) + raise MultipartParseError(msg, offset=bad_i) + + index += end - i + if colon == -1: + # Field name continues into the next chunk. + i = length + else: # A 0-length header is an error. - if index == 1: + if index == 0: msg = "Found 0-length header at %d" % (i,) self.logger.warning(msg) raise MultipartParseError(msg, offset=i) # Call our callback with the header field. + i = colon data_callback("header_field", i) # Move to parsing the header value. state = MultipartState.HEADER_VALUE_START - elif c not in TOKEN_CHARS_SET: - msg = "Found invalid character %r in header at %d" % (c, i) - self.logger.warning(msg) - raise MultipartParseError(msg, offset=i) - else: - advance_header_size() - elif state == MultipartState.HEADER_VALUE_START: # Skip leading spaces. if c == SPACE: @@ -1287,15 +1295,19 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No i -= 1 elif state == MultipartState.HEADER_VALUE: - # If we've got a CR, we're nearly done our headers. Otherwise, - # we do nothing and just move past this character. - if c == CR: + # The value runs until the terminating CR; jump straight to it + # instead of inspecting every byte. + cr = data.find(b"\r", i, length) + end = cr if cr != -1 else length + advance_header_size(end - i) + if cr != -1: + i = cr data_callback("header_value", i) self.callback("header_end") current_header_size = 0 state = MultipartState.HEADER_VALUE_ALMOST_DONE else: - advance_header_size() + i = length elif state == MultipartState.HEADER_VALUE_ALMOST_DONE: # The last character should be a LF. If not, it's an error. @@ -1338,17 +1350,13 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No # find part of a boundary, but it doesn't match fully. prev_index = index - # Set up variables. - boundary_length = len(boundary) - data_length = length - # If our index is 0, we're starting a new part, so start our # search. if index == 0: # The most common case is likely to be that the whole # boundary is present in the buffer. # Calling `find` is much faster than iterating here. - i0 = data.find(boundary, i, data_length) + i0 = data.find(boundary, i, length) if i0 >= 0: # We matched the whole boundary string. index = boundary_length - 1 @@ -1360,9 +1368,9 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No # Since the length to be searched is limited to the # boundary length, scan the tail for boundary[0] via # bytes.find (C-level) to keep cost off the Python loop. - i = max(i, data_length - boundary_length) - j = data.find(boundary[:1], i, data_length - 1) - i = j if j >= 0 else data_length - 1 + i = max(i, length - boundary_length) + j = data.find(boundary[:1], i, length - 1) + i = j if j >= 0 else length - 1 c = data[i] @@ -1456,7 +1464,7 @@ def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> No i -= 1 elif state == MultipartState.END_BOUNDARY: - if index == len(boundary) - 2 + 1: + if index == boundary_length - 1: if c != HYPHEN: msg = "Did not find - at end of boundary (%d)" % (i,) self.logger.warning(msg) @@ -1892,6 +1900,8 @@ def parse_form( content_length: int | float | bytes | None = headers.get("Content-Length") if content_length is not None: content_length = int(content_length) + if content_length < 0: + raise ValueError("Content-Length must be non-negative") else: content_length = float("inf") bytes_read = 0 diff --git a/tests/test_multipart.py b/tests/test_multipart.py index aa371ad..ae2bb01 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1650,6 +1650,15 @@ def test_parse_form_invalid_chunk_size(self) -> None: chunk_size=0, ) + def test_parse_form_negative_content_length(self) -> None: + with self.assertRaisesRegex(ValueError, "Content-Length must be non-negative"): + parse_form( + {"Content-Type": b"application/octet-stream", "Content-Length": b"-1"}, + BytesIO(b"123456789012345"), + lambda _: None, + lambda _: None, + ) + def suite() -> unittest.TestSuite: suite = unittest.TestSuite()