diff --git a/doc/api.rst b/doc/api.rst index 110fdcf3..17c28576 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -57,6 +57,12 @@ HeaderValueMatcher .. autoclass:: HeaderValueMatcher :members: +MappingQueryMatcher +~~~~~~~~~~~~~~~~~~ + + .. autoclass:: MappingQueryMatcher + :members: + URIPattern ~~~~~~~~~~ diff --git a/doc/howto.rst b/doc/howto.rst index aac9857e..b19ceebf 100644 --- a/doc/howto.rst +++ b/doc/howto.rst @@ -50,6 +50,37 @@ Behind the scenes an additional step is done by the library: it parses up the query_string into the dict and then compares it with the dict provided. +Query parameter quoting +~~~~~~~~~~~~~~~~~~~~~~~ + +*pytest-httpserver* does the quoting of query parameters in the following way. + +Specifying the query parameter as a string expects a quoted version of the query +parameters: + +.. code-block:: python + + httpserver.expect_request("/test", query_string="foo=bar%20baz") + +While if you specify a dict or a multidict to it, it must not be URL encoded: + +.. code-block:: python + + httpserver.expect_request("/test", query_string={"foo": "bar baz"}) + + +The reasoning behind this is the following: + +* if it is specified as a string or bytes, *pytest-httpserver* compares this + value as-is with the incoming query parameter in the http protocol. For + strings, an implicit ``encode()`` method is called on it to convert to bytes. + If you want 100% correctness you can specify a bytes object. + +* if it is specified as a dict or other higher-level object, this is treated as + a higher-level specification so *pytest-httpserver* does the encoding for + convenience. + + URI matching ------------ diff --git a/pytest_httpserver/__init__.py b/pytest_httpserver/__init__.py index f84d01ac..512480ef 100644 --- a/pytest_httpserver/__init__.py +++ b/pytest_httpserver/__init__.py @@ -17,6 +17,7 @@ "METHOD_ALL", "BlockingHTTPServer", "BlockingRequestHandler", + "MappingQueryMatcher", ] from .blocking_httpserver import BlockingHTTPServer @@ -27,6 +28,7 @@ from .httpserver import HeaderValueMatcher from .httpserver import HTTPServer from .httpserver import HTTPServerError +from .httpserver import MappingQueryMatcher from .httpserver import NoHandlerError from .httpserver import RequestHandler from .httpserver import RequestMatcher diff --git a/pytest_httpserver/httpserver.py b/pytest_httpserver/httpserver.py index d29a6c78..207d62dc 100644 --- a/pytest_httpserver/httpserver.py +++ b/pytest_httpserver/httpserver.py @@ -188,7 +188,8 @@ def __init__(self, query_string: bytes | str): :param query_string: the query string will be compared to this string or bytes. If string is specified, it will be encoded by the encode() method. The query must not start with '?' but will be exactly (byte-by-byte) equal - the actual query string of the incoming request. + the actual query string of the incoming request. This means that it + should be URL-quoted as well. """ if not isinstance(query_string, (str, bytes)): raise TypeError("query_string must be a string, or a bytes-like object") @@ -217,7 +218,7 @@ def __init__(self, query_dict: Mapping | MultiDict): key-value mapping where both key and value should be string. If there are multiple values specified for the same key in the request, the first element will be used. If you want to match multiple values, use a MultiDict object from werkzeug, which - represents multiple values for one key. + represents multiple values for one key. Keys and values must be URL-unquoted. """ self.query_dict = query_dict diff --git a/tests/test_querystring.py b/tests/test_querystring.py index ff3bd997..607ae9c4 100644 --- a/tests/test_querystring.py +++ b/tests/test_querystring.py @@ -1,6 +1,15 @@ +from __future__ import annotations + +import typing +import urllib + import requests -from pytest_httpserver import HTTPServer +if typing.TYPE_CHECKING: + from pytest_httpserver import HTTPServer + +from pytest_httpserver.httpserver import MappingQueryMatcher +from pytest_httpserver.httpserver import QueryMatcher def test_querystring_str(httpserver: HTTPServer): @@ -32,3 +41,50 @@ def test_querystring_dict(httpserver: HTTPServer): httpserver.check_assertions() assert response.text == "example_response" assert response.status_code == 200 + + +class MyQueryStringMatcher(QueryMatcher): + def __init__(self, expected_string: str): + parsed = urllib.parse.parse_qsl(expected_string) # Parse query string into key-value pairs + self._encoded_query_string = urllib.parse.urlencode(parsed, quote_via=urllib.parse.quote) + + def get_comparing_values(self, request_query_string: bytes) -> tuple: + return (self._encoded_query_string, request_query_string.decode("utf-8")) + + +def test_query_string_with_spaces_string_fails(httpserver: HTTPServer): + httpserver.expect_request("/test", query_string=MyQueryStringMatcher("foo=bar baz")).respond_with_data("OK") + + url = httpserver.url_for("/test") + "?foo=bar baz" + requests.get(url) + httpserver.check_assertions() + + +def test_query_string_is_encoded_string_passes(httpserver: HTTPServer): + httpserver.expect_request("/test", query_string="foo=bar%20baz").respond_with_data("OK") + + url = httpserver.url_for("/test") + "?foo=bar baz" + requests.get(url) + httpserver.check_assertions() + + +def test_query_string_with_spaces_dict_passes(httpserver: HTTPServer): + httpserver.expect_request("/test", query_string={"foo": "bar baz"}).respond_with_data("OK") + + url = httpserver.url_for("/test") + "?foo=bar baz" + requests.get(url) + httpserver.check_assertions() + + +class QuotedDictMatcher(MappingQueryMatcher): + def __init__(self, query_dict: dict[str, str]): + unquoted_dict = {k: urllib.parse.unquote(v) for k, v in query_dict.items()} + super().__init__(unquoted_dict) + + +def test_query_string_is_encoded_dict_fails(httpserver: HTTPServer): + httpserver.expect_request("/test", query_string=QuotedDictMatcher({"foo": "bar%20baz"})).respond_with_data("OK") + + url = httpserver.url_for("/test") + "?foo=bar baz" + requests.get(url) + httpserver.check_assertions()