|
1 | 1 | from unittest import mock |
| 2 | +from test import support |
| 3 | +from test.test_httpservers import NoLogRequestHandler |
2 | 4 | from unittest import TestCase |
3 | 5 | from wsgiref.util import setup_testing_defaults |
4 | 6 | from wsgiref.headers import Headers |
5 | | -from wsgiref.handlers import BaseHandler, BaseCGIHandler |
| 7 | +from wsgiref.handlers import BaseHandler, BaseCGIHandler, SimpleHandler |
6 | 8 | from wsgiref import util |
7 | 9 | from wsgiref.validate import validator |
8 | 10 | from wsgiref.simple_server import WSGIServer, WSGIRequestHandler |
9 | 11 | from wsgiref.simple_server import make_server |
| 12 | +from http.client import HTTPConnection |
10 | 13 | from io import StringIO, BytesIO, BufferedReader |
11 | 14 | from socketserver import BaseServer |
12 | 15 | from platform import python_implementation |
13 | 16 |
|
14 | 17 | import os |
15 | 18 | import re |
| 19 | +import signal |
16 | 20 | import sys |
17 | 21 | import unittest |
18 | 22 |
|
@@ -245,6 +249,56 @@ def app(e, s): |
245 | 249 | ], |
246 | 250 | out.splitlines()) |
247 | 251 |
|
| 252 | + def test_interrupted_write(self): |
| 253 | + # BaseHandler._write() and _flush() have to write all data, even if |
| 254 | + # it takes multiple send() calls. Test this by interrupting a send() |
| 255 | + # call with a Unix signal. |
| 256 | + threading = support.import_module("threading") |
| 257 | + pthread_kill = support.get_attribute(signal, "pthread_kill") |
| 258 | + |
| 259 | + def app(environ, start_response): |
| 260 | + start_response("200 OK", []) |
| 261 | + return [bytes(support.SOCK_MAX_SIZE)] |
| 262 | + |
| 263 | + class WsgiHandler(NoLogRequestHandler, WSGIRequestHandler): |
| 264 | + pass |
| 265 | + |
| 266 | + server = make_server(support.HOST, 0, app, handler_class=WsgiHandler) |
| 267 | + self.addCleanup(server.server_close) |
| 268 | + interrupted = threading.Event() |
| 269 | + |
| 270 | + def signal_handler(signum, frame): |
| 271 | + interrupted.set() |
| 272 | + |
| 273 | + original = signal.signal(signal.SIGUSR1, signal_handler) |
| 274 | + self.addCleanup(signal.signal, signal.SIGUSR1, original) |
| 275 | + received = None |
| 276 | + main_thread = threading.get_ident() |
| 277 | + |
| 278 | + def run_client(): |
| 279 | + http = HTTPConnection(*server.server_address) |
| 280 | + http.request("GET", "/") |
| 281 | + with http.getresponse() as response: |
| 282 | + response.read(100) |
| 283 | + # The main thread should now be blocking in a send() system |
| 284 | + # call. But in theory, it could get interrupted by other |
| 285 | + # signals, and then retried. So keep sending the signal in a |
| 286 | + # loop, in case an earlier signal happens to be delivered at |
| 287 | + # an inconvenient moment. |
| 288 | + while True: |
| 289 | + pthread_kill(main_thread, signal.SIGUSR1) |
| 290 | + if interrupted.wait(timeout=float(1)): |
| 291 | + break |
| 292 | + nonlocal received |
| 293 | + received = len(response.read()) |
| 294 | + http.close() |
| 295 | + |
| 296 | + background = threading.Thread(target=run_client) |
| 297 | + background.start() |
| 298 | + server.handle_request() |
| 299 | + background.join() |
| 300 | + self.assertEqual(received, support.SOCK_MAX_SIZE - 100) |
| 301 | + |
248 | 302 |
|
249 | 303 | class UtilityTests(TestCase): |
250 | 304 |
|
@@ -701,6 +755,31 @@ def close(self): |
701 | 755 | h.run(error_app) |
702 | 756 | self.assertEqual(side_effects['close_called'], True) |
703 | 757 |
|
| 758 | + def testPartialWrite(self): |
| 759 | + written = bytearray() |
| 760 | + |
| 761 | + class PartialWriter: |
| 762 | + def write(self, b): |
| 763 | + partial = b[:7] |
| 764 | + written.extend(partial) |
| 765 | + return len(partial) |
| 766 | + |
| 767 | + def flush(self): |
| 768 | + pass |
| 769 | + |
| 770 | + environ = {"SERVER_PROTOCOL": "HTTP/1.0"} |
| 771 | + h = SimpleHandler(BytesIO(), PartialWriter(), sys.stderr, environ) |
| 772 | + msg = "should not do partial writes" |
| 773 | + with self.assertWarnsRegex(DeprecationWarning, msg): |
| 774 | + h.run(hello_app) |
| 775 | + self.assertEqual(b"HTTP/1.0 200 OK\r\n" |
| 776 | + b"Content-Type: text/plain\r\n" |
| 777 | + b"Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" |
| 778 | + b"Content-Length: 13\r\n" |
| 779 | + b"\r\n" |
| 780 | + b"Hello, world!", |
| 781 | + written) |
| 782 | + |
704 | 783 |
|
705 | 784 | if __name__ == "__main__": |
706 | 785 | unittest.main() |
0 commit comments