11
22import threading
33import json
4- from typing import Mapping , Optional , Union , Callable
4+ from collections import defaultdict
5+ from typing import Callable , Mapping , Optional , Union
56from ssl import SSLContext
67
7- from werkzeug .wrappers import Request , Response
8+ from werkzeug .http import parse_authorization_header
89from werkzeug .serving import make_server
10+ from werkzeug .wrappers import Request , Response
911
1012URI_DEFAULT = ""
1113METHOD_ALL = "__ALL"
@@ -35,6 +37,49 @@ class HTTPServerError(Error):
3537 pass
3638
3739
40+ class NoMethodFoundForMatchingHeaderValueError (Error ):
41+ """
42+ Raised when a :py:class:`HeaderValueMatcher` has no registered method to match the header value.
43+ """
44+
45+ pass
46+
47+
48+ class HeaderValueMatcher :
49+ """
50+ Matcher object for the header value of incoming request.
51+
52+ :param matchers: mapping from header name to comparator function that accepts actual and expected header values
53+ and return whether they are equal as bool.
54+ """
55+ DEFAULT_MATCHERS = {}
56+
57+ def __init__ (self , matchers : Optional [Mapping [str , Callable [[str , str ], bool ]]] = None ):
58+ self .matchers = self .DEFAULT_MATCHERS if matchers is None else matchers
59+
60+ @staticmethod
61+ def authorization_header_value_matcher (actual : str , expected : str ) -> bool :
62+ return parse_authorization_header (actual ) == parse_authorization_header (expected )
63+
64+ @staticmethod
65+ def default_header_value_matcher (actual : str , expected : str ) -> bool :
66+ return actual == expected
67+
68+ def __call__ (self , header_name : str , actual : str , expected : str ) -> bool :
69+ try :
70+ matcher = self .matchers [header_name ]
71+ except KeyError :
72+ raise NoMethodFoundForMatchingHeaderValueError (
73+ "No method found for matching header value: {}" .format (header_name ))
74+ return matcher (actual , expected )
75+
76+
77+ HeaderValueMatcher .DEFAULT_MATCHERS = defaultdict (
78+ lambda : HeaderValueMatcher .default_header_value_matcher ,
79+ {'Authorization' : HeaderValueMatcher .authorization_header_value_matcher }
80+ )
81+
82+
3883class RequestMatcher :
3984 """
4085 Matcher object for the incoming request.
@@ -58,7 +103,8 @@ def __init__(
58103 data : Union [str , bytes , None ] = None ,
59104 data_encoding : str = "utf-8" ,
60105 headers : Optional [Mapping [str , str ]] = None ,
61- query_string : Optional [str ] = None ):
106+ query_string : Optional [str ] = None ,
107+ header_value_matcher : Optional [HeaderValueMatcher ] = None ):
62108
63109 self .uri = uri
64110 self .method = method
@@ -74,6 +120,8 @@ def __init__(
74120
75121 self .data = data
76122
123+ self .header_value_matcher = HeaderValueMatcher () if header_value_matcher is None else header_value_matcher
124+
77125 def __repr__ (self ):
78126 """
79127 Returns the string representation of the object, with the known parameters.
@@ -121,7 +169,7 @@ def difference(self, request: Request) -> list:
121169 request_headers = {}
122170 expected_headers = {}
123171 for key , value in self .headers .items ():
124- if request .headers .get (key ) != value :
172+ if not self . header_value_matcher ( key , request .headers .get (key ), value ) :
125173 request_headers [key ] = request .headers .get (key )
126174 expected_headers [key ] = value
127175
@@ -359,7 +407,8 @@ def expect_oneshot_request(
359407 headers : Optional [Mapping [str , str ]] = None ,
360408 query_string : Optional [str ] = None ,
361409 * ,
362- ordered = False ) -> RequestHandler :
410+ ordered = False ,
411+ header_value_matcher : Optional [HeaderValueMatcher ] = None ) -> RequestHandler :
363412 """
364413 Create and register a oneshot request handler.
365414
@@ -382,11 +431,20 @@ def expect_oneshot_request(
382431 :param headers: dictionary of the headers of the request to be matched
383432 :param query_string: the http query string starting with ``?``, such as ``?username=user``
384433 :param ordered: specifies whether to create an ordered handler or not. See above for details.
434+ :param header_value_matcher: :py:class:`HeaderValueMatcher` that matches values of headers.
385435
386436 :return: Created and register :py:class:`RequestHandler`.
387437 """
388438
389- matcher = self .create_matcher (uri , method = method , data = data , data_encoding = data_encoding , headers = headers , query_string = query_string )
439+ matcher = self .create_matcher (
440+ uri ,
441+ method = method ,
442+ data = data ,
443+ data_encoding = data_encoding ,
444+ headers = headers ,
445+ query_string = query_string ,
446+ header_value_matcher = header_value_matcher ,
447+ )
390448 request_handler = RequestHandler (matcher )
391449 if ordered :
392450 self .ordered_handlers .append (request_handler )
@@ -402,7 +460,8 @@ def expect_request(
402460 data : Union [str , bytes , None ] = None ,
403461 data_encoding : str = "utf-8" ,
404462 headers : Optional [Mapping [str , str ]] = None ,
405- query_string : Optional [str ] = None ) -> RequestHandler :
463+ query_string : Optional [str ] = None ,
464+ header_value_matcher : Optional [HeaderValueMatcher ] = None ) -> RequestHandler :
406465 """
407466 Create and register a permanent request handler.
408467
@@ -417,11 +476,20 @@ def expect_request(
417476 :param data_encoding: the encoding used for data parameter if data is a string.
418477 :param headers: dictionary of the headers of the request to be matched
419478 :param ordered: specifies whether to create an ordered handler or not. See above for details.
479+ :param header_value_matcher: :py:class:`HeaderValueMatcher` that matches values of headers.
420480
421481 :return: Created and register :py:class:`RequestHandler`.
422482 """
423483
424- matcher = self .create_matcher (uri , method = method , data = data , data_encoding = data_encoding , headers = headers , query_string = query_string )
484+ matcher = self .create_matcher (
485+ uri ,
486+ method = method ,
487+ data = data ,
488+ data_encoding = data_encoding ,
489+ headers = headers ,
490+ query_string = query_string ,
491+ header_value_matcher = header_value_matcher ,
492+ )
425493 request_handler = RequestHandler (matcher )
426494 self .handlers .append (request_handler )
427495 return request_handler
0 commit comments