|
9 | 9 | from copy import copy |
10 | 10 | from typing import Callable, Mapping, Optional, Union |
11 | 11 | from ssl import SSLContext |
| 12 | +import urllib.parse |
| 13 | +import abc |
12 | 14 |
|
13 | 15 | from werkzeug.http import parse_authorization_header |
14 | 16 | from werkzeug.serving import make_server |
15 | 17 | from werkzeug.wrappers import Request, Response |
| 18 | +import werkzeug.urls |
| 19 | +from werkzeug.datastructures import MultiDict |
16 | 20 |
|
17 | 21 | URI_DEFAULT = "" |
18 | 22 | METHOD_ALL = "__ALL" |
@@ -123,6 +127,83 @@ def __call__(self, header_name: str, actual: str, expected: str) -> bool: |
123 | 127 | ) |
124 | 128 |
|
125 | 129 |
|
| 130 | +class QueryMatcher(abc.ABC): |
| 131 | + def match(self, request_query_string: bytes) -> bool: |
| 132 | + values = self.get_comparing_values(request_query_string) |
| 133 | + return values[0] == values[1] |
| 134 | + |
| 135 | + @abc.abstractmethod |
| 136 | + def get_comparing_values(self, request_query_string: bytes) -> tuple: |
| 137 | + pass |
| 138 | + |
| 139 | +class StringQueryMatcher(QueryMatcher): |
| 140 | + def __init__(self, query_string: Union[bytes, str]): |
| 141 | + if query_string is not None and not isinstance(query_string, (str, bytes)): |
| 142 | + raise TypeError("query_string must be a string, or a bytes-like object") |
| 143 | + |
| 144 | + self.query_string = query_string |
| 145 | + |
| 146 | + def get_comparing_values(self, request_query_string: bytes) -> tuple: |
| 147 | + if self.query_string is not None: |
| 148 | + if isinstance(self.query_string, str): |
| 149 | + query_string = self.query_string.encode() |
| 150 | + elif isinstance(self.query_string, bytes): |
| 151 | + query_string = self.query_string |
| 152 | + else: |
| 153 | + raise TypeError("query_string must be a string, or a bytes-like object") |
| 154 | + |
| 155 | + return (request_query_string, query_string) |
| 156 | + |
| 157 | + |
| 158 | +class MappingQueryMatcher(QueryMatcher): |
| 159 | + def __init__(self, query_dict: [Mapping, MultiDict]): |
| 160 | + self.query_dict = query_dict |
| 161 | + |
| 162 | + def get_comparing_values(self, request_query_string: bytes) -> tuple: |
| 163 | + query = werkzeug.urls.url_decode(request_query_string) |
| 164 | + if isinstance(self.query_dict, MultiDict): |
| 165 | + return (query, self.query_dict) |
| 166 | + else: |
| 167 | + return (query.to_dict(), dict(self.query_dict)) |
| 168 | + |
| 169 | + |
| 170 | +class BooleanQueryMatcher(QueryMatcher): |
| 171 | + def __init__(self, result: bool): |
| 172 | + self.result = result |
| 173 | + |
| 174 | + def get_comparing_values(self, request_query_string): |
| 175 | + if self.result: |
| 176 | + return (True, True) |
| 177 | + else: |
| 178 | + return (True, False) |
| 179 | + |
| 180 | + |
| 181 | +def _get_dict_type(d: Mapping, default=bytes): |
| 182 | + try: |
| 183 | + first_key = next(iter(d.keys())) |
| 184 | + key_type = type(first_key) |
| 185 | + except StopIteration: |
| 186 | + key_type = default |
| 187 | + |
| 188 | + return key_type |
| 189 | + |
| 190 | + |
| 191 | +def _create_query_matcher(query_string: Union[None, QueryMatcher, str, bytes, Mapping]) -> QueryMatcher: |
| 192 | + if isinstance(query_string, QueryMatcher): |
| 193 | + return query_string |
| 194 | + |
| 195 | + if query_string is None: |
| 196 | + return BooleanQueryMatcher(True) |
| 197 | + |
| 198 | + if isinstance(query_string, (str, bytes)): |
| 199 | + return StringQueryMatcher(query_string) |
| 200 | + |
| 201 | + if isinstance(query_string, Mapping): |
| 202 | + return MappingQueryMatcher(query_string) |
| 203 | + |
| 204 | + raise TypeError("Unable to cast this type to QueryMatcher: {!r}".format(type(query_string))) |
| 205 | + |
| 206 | + |
126 | 207 | class RequestMatcher: |
127 | 208 | """ |
128 | 209 | Matcher object for the incoming request. |
@@ -151,12 +232,10 @@ def __init__( |
151 | 232 | query_string: Union[None, bytes, str] = None, |
152 | 233 | header_value_matcher: Optional[HeaderValueMatcher] = None): |
153 | 234 |
|
154 | | - if query_string is not None and not isinstance(query_string, (str, bytes)): |
155 | | - raise TypeError("query_string must be a string, or a bytes-like object") |
156 | | - |
157 | 235 | self.uri = uri |
158 | 236 | self.method = method |
159 | 237 | self.query_string = query_string |
| 238 | + self.query_matcher = _create_query_matcher(self.query_string) |
160 | 239 |
|
161 | 240 | if headers is None: |
162 | 241 | self.headers = {} |
@@ -211,15 +290,7 @@ def difference(self, request: Request) -> list: |
211 | 290 | if self.method != METHOD_ALL and self.method != request.method: |
212 | 291 | retval.append(("method", request.method, self.method)) |
213 | 292 |
|
214 | | - if self.query_string is not None: |
215 | | - if isinstance(self.query_string, str): |
216 | | - query_string = self.query_string.encode() |
217 | | - elif isinstance(self.query_string, bytes): |
218 | | - query_string = self.query_string |
219 | | - else: |
220 | | - raise TypeError("query_string must be a string, or a bytes-like object") |
221 | | - |
222 | | - if self.query_string is not None and query_string != request.query_string: |
| 293 | + if not self.query_matcher.match(request.query_string): |
223 | 294 | retval.append(("query_string", request.query_string, self.query_string)) |
224 | 295 |
|
225 | 296 | request_headers = {} |
|
0 commit comments