diff --git a/pytest_httpserver/__init__.py b/pytest_httpserver/__init__.py index c691a3ec..9d3692f2 100644 --- a/pytest_httpserver/__init__.py +++ b/pytest_httpserver/__init__.py @@ -5,7 +5,7 @@ # flake8: noqa -from .httpserver import HTTPServer +from .httpserver import HTTPServer, HTTPProxy from .httpserver import HTTPServerError, Error, NoHandlerError from .httpserver import WaitingSettings, HeaderValueMatcher, RequestHandler from .httpserver import URIPattern, URI_DEFAULT, METHOD_ALL diff --git a/pytest_httpserver/httpserver.py b/pytest_httpserver/httpserver.py index 14c57a31..3ef4ea7b 100644 --- a/pytest_httpserver/httpserver.py +++ b/pytest_httpserver/httpserver.py @@ -19,6 +19,8 @@ import werkzeug.urls from werkzeug.datastructures import MultiDict +from wsgiprox.wsgiprox import WSGIProxMiddleware + URI_DEFAULT = "" METHOD_ALL = "__ALL" @@ -1118,3 +1120,62 @@ def __exit__(self, *args, **kwargs): """ if self.is_running(): self.stop() + + +class HTTPProxy(HTTPServer): + """ + Proxy instance which manages handlers to serve pre-defined requests. + + :param host: the host or IP where the proxy will listen + :param port: the TCP port where the proxy will listen + :param default_waiting_settings: the waiting settings object to use as default settings for :py:meth:`wait` context + manager + + .. py:attribute:: log + + Attribute containing the list of two-element tuples. Each tuple contains + :py:class:`Request` and :py:class:`Response` object which represents the + incoming request and the outgoing response which happened during the lifetime + of the server. + """ + + DEFAULT_LISTEN_HOST = "localhost" + DEFAULT_LISTEN_PORT = 0 # Use ephemeral port + DEFAULT_PREFIX = "/proxy/" + DEFAULT_PROXY_HOST = "wsgiprox" + + def __init__( + self, + host=DEFAULT_LISTEN_HOST, + port=DEFAULT_LISTEN_PORT, + prefix=DEFAULT_PREFIX, + proxy_host=DEFAULT_PROXY_HOST, + proxy_options=None, + default_waiting_settings: Optional[WaitingSettings] = None): + """ + Initializes the instance. + + """ + + super().__init__(host, port, default_waiting_settings) + self.prefix = prefix + self.proxy_host = proxy_host + self.proxy_options = proxy_options + + @property + def ca_cert(self): + return self.proxy_options["ca_file_cache"] + + def get_proxy_url(self): + return self.url_for("") + + def start(self): + proxy = WSGIProxMiddleware( + self.application, + self.prefix, + self.proxy_host, + proxy_options=self.proxy_options) + self.server = make_server(self.host, self.port, proxy) + self.port = self.server.port # Update port (needed if `port` was set to 0) + self.server_thread = threading.Thread(target=self.thread_target) + self.server_thread.start() diff --git a/pytest_httpserver/pytest_plugin.py b/pytest_httpserver/pytest_plugin.py index 99195968..58f3b2a0 100644 --- a/pytest_httpserver/pytest_plugin.py +++ b/pytest_httpserver/pytest_plugin.py @@ -3,11 +3,12 @@ import os import pytest -from .httpserver import HTTPServer +from .httpserver import HTTPServer, HTTPProxy class Plugin: SERVER = None + PROXY = None class PluginHTTPServer(HTTPServer): @@ -20,6 +21,27 @@ def stop(self): Plugin.SERVER = None +class PluginHTTPProxy(HTTPProxy): + def start(self): + super().start() + Plugin.PROXY = self + + def stop(self): + super().stop() + Plugin.PROXY = None + + +@pytest.fixture(scope="session") +def plugin_httpserver_class(): + yield PluginHTTPServer + + +@pytest.fixture(scope="session") +def plugin_proxy_class(): + print("plugin_proxy_class") + yield PluginHTTPProxy + + def get_httpserver_listen_address(): listen_host = os.environ.get("PYTEST_HTTPSERVER_HOST") listen_port = os.environ.get("PYTEST_HTTPSERVER_PORT") @@ -35,7 +57,7 @@ def httpserver_listen_address(): @pytest.fixture -def httpserver(httpserver_listen_address): +def httpserver(httpserver_listen_address, plugin_httpserver_class): if Plugin.SERVER: Plugin.SERVER.clear() yield Plugin.SERVER @@ -47,13 +69,34 @@ def httpserver(httpserver_listen_address): if not port: port = HTTPServer.DEFAULT_LISTEN_PORT - server = PluginHTTPServer(host=host, port=port) + server = plugin_httpserver_class(host=host, port=port) + server.start() + yield server + + +@pytest.fixture +def httpproxy(httpserver_listen_address, tmp_path, plugin_proxy_class): + if Plugin.PROXY: + Plugin.PROXY.clear() + yield Plugin.PROXY + return + + host, port = httpserver_listen_address + if not host: + host = HTTPProxy.DEFAULT_LISTEN_HOST + if not port: + port = HTTPProxy.DEFAULT_LISTEN_PORT + + ca_dir = tmp_path.joinpath("httpproxy_ca") + ca_dir.mkdir(exist_ok=True) + server = plugin_proxy_class(host=host, port=port, proxy_options={"ca_file_cache": str(ca_dir.joinpath("wsgiprox-ca.pem"))}) server.start() yield server def pytest_sessionfinish(session, exitstatus): # pylint: disable=unused-argument - if Plugin.SERVER is not None: - Plugin.SERVER.clear() - if Plugin.SERVER.is_running(): - Plugin.SERVER.stop() + for instance in (Plugin.SERVER, Plugin.PROXY): + if instance is not None: + instance.clear() + if instance.is_running(): + instance.stop() diff --git a/setup.py b/setup.py index c480111c..1a783604 100755 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ python_requires=">=3.4", install_requires=[ "typing;python_version<'3.5'", + "wsgiprox", "werkzeug" ], extras_require={ diff --git a/tests/test_proxy.py b/tests/test_proxy.py new file mode 100644 index 00000000..e8db5f3c --- /dev/null +++ b/tests/test_proxy.py @@ -0,0 +1,24 @@ + +import requests +from pytest_httpserver import HTTPProxy + +def test_proxy_http(httpproxy: HTTPProxy): + httpproxy.expect_request("/proxy/http://example.com/path/file.html").respond_with_data("Hello world!") + + with requests.Session() as session: + session.proxies = {"http": httpproxy.get_proxy_url()} + resp = session.get("http://example.com/path/file.html", ) + assert resp.status_code == 200 + assert resp.text == "Hello world!" + + +def test_proxy_https(httpproxy: HTTPProxy): + httpproxy.expect_request("/proxy/https://example.com/path/file.html").respond_with_data("Hello world!") + + with requests.Session() as session: + session.verify = httpproxy.ca_cert + session.proxies = {"https": httpproxy.get_proxy_url()} + + resp = session.get("https://example.com/path/file.html") + assert resp.status_code == 200 + assert resp.text == "Hello world!"