Skip to content

Commit 69973fd

Browse files
Harden blackd browser-facing request handling (#5039)
1 parent 4937fe6 commit 69973fd

5 files changed

Lines changed: 244 additions & 41 deletions

File tree

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848

4949
<!-- Changes to blackd -->
5050

51+
- Disable browser-originated requests by default, add configurable origin allowlisting
52+
and request body limits, and bound executor submissions to improve backpressure
53+
(#5039)
54+
5155
### Integrations
5256

5357
<!-- For example, Docker, GitHub Actions, pre-commit, editors -->

docs/usage_and_configuration/black_as_a_server.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,17 @@ if __name__ == "__main__":
5151
asyncio.run(main())
5252
```
5353

54+
Cross-origin browser requests are rejected by default. If you need to access `blackd`
55+
from a browser-based client, pass one or more `--cors-allow-origin` options to allow
56+
specific origins.
57+
5458
## Protocol
5559

5660
`blackd` only accepts `POST` requests at the `/` path. The body of the request should
5761
contain the Python source code to be formatted, encoded according to the `charset` field
5862
in the `Content-Type` request header. If no `charset` is specified, `blackd` assumes
59-
`UTF-8`.
63+
`UTF-8`. Request bodies are limited to 5 MiB by default; use `--max-body-size` to change
64+
that limit (the value is given in bytes).
6065

6166
There are a few HTTP headers that control how the source code is formatted. These
6267
correspond to command line flags for _Black_. There is one exception to this:

src/blackd/__init__.py

Lines changed: 92 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
import os
34
from concurrent.futures import Executor, ProcessPoolExecutor
45
from datetime import datetime, timezone
56
from functools import cache, partial
@@ -55,6 +56,8 @@
5556

5657
# Response headers
5758
BLACK_VERSION_HEADER = "X-Black-Version"
59+
DEFAULT_MAX_BODY_SIZE = 5 * 1024 * 1024
60+
DEFAULT_WORKERS = os.cpu_count() or 1
5861

5962

6063
class HeaderError(Exception):
@@ -76,10 +79,28 @@ class InvalidVariantHeader(Exception):
7679
@click.option(
7780
"--bind-port", type=int, help="Port to listen on", default=45484, show_default=True
7881
)
82+
@click.option(
83+
"--cors-allow-origin",
84+
"cors_allow_origins",
85+
multiple=True,
86+
help="Origin allowed to access blackd over CORS. Can be passed multiple times.",
87+
)
88+
@click.option(
89+
"--max-body-size",
90+
type=click.IntRange(min=1),
91+
default=DEFAULT_MAX_BODY_SIZE,
92+
show_default=True,
93+
help="Maximum request body size in bytes.",
94+
)
7995
@click.version_option(version=black.__version__)
80-
def main(bind_host: str, bind_port: int) -> None:
96+
def main(
97+
bind_host: str,
98+
bind_port: int,
99+
cors_allow_origins: tuple[str, ...],
100+
max_body_size: int,
101+
) -> None:
81102
logging.basicConfig(level=logging.INFO)
82-
app = make_app()
103+
app = make_app(cors_allow_origins=cors_allow_origins, max_body_size=max_body_size)
83104
ver = black.__version__
84105
black.out(f"blackd version {ver} listening on {bind_host} port {bind_port}")
85106
loop = maybe_use_uvloop()
@@ -99,18 +120,42 @@ def main(bind_host: str, bind_port: int) -> None:
99120

100121
@cache
101122
def executor() -> Executor:
102-
return ProcessPoolExecutor()
123+
return ProcessPoolExecutor(max_workers=DEFAULT_WORKERS)
103124

104125

105-
def make_app() -> web.Application:
126+
def make_app(
127+
*,
128+
cors_allow_origins: tuple[str, ...] = (),
129+
max_body_size: int = DEFAULT_MAX_BODY_SIZE,
130+
) -> web.Application:
106131
app = web.Application(
107-
middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
132+
client_max_size=max_body_size,
133+
middlewares=[
134+
cors(
135+
allow_headers=(*BLACK_HEADERS, "Content-Type"),
136+
allow_origins=frozenset(cors_allow_origins),
137+
expose_headers=(BLACK_VERSION_HEADER,),
138+
)
139+
],
108140
)
109-
app.add_routes([web.post("/", partial(handle, executor=executor()))])
141+
app.add_routes([
142+
web.post(
143+
"/",
144+
partial(
145+
handle,
146+
executor=executor(),
147+
executor_semaphore=asyncio.BoundedSemaphore(DEFAULT_WORKERS),
148+
),
149+
)
150+
])
110151
return app
111152

112153

113-
async def handle(request: web.Request, executor: Executor) -> web.Response:
154+
async def handle(
155+
request: web.Request,
156+
executor: Executor,
157+
executor_semaphore: asyncio.BoundedSemaphore,
158+
) -> web.Response:
114159
headers = {BLACK_VERSION_HEADER: __version__}
115160
try:
116161
if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
@@ -125,7 +170,7 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
125170
mode = parse_mode(request.headers)
126171
except HeaderError as e:
127172
return web.Response(status=400, text=e.args[0])
128-
req_bytes = await request.content.read()
173+
req_bytes = await request.read()
129174
charset = request.charset if request.charset is not None else "utf8"
130175
req_str = req_bytes.decode(charset)
131176
then = datetime.now(timezone.utc)
@@ -136,27 +181,21 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
136181
header = req_str[:first_newline_position]
137182
req_str = req_str[first_newline_position:]
138183

139-
loop = asyncio.get_event_loop()
140-
formatted_str = await loop.run_in_executor(
141-
executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
184+
only_diff = bool(request.headers.get(DIFF_HEADER, False))
185+
formatted_str = await format_code(
186+
req_str=req_str,
187+
fast=fast,
188+
mode=mode,
189+
then=then,
190+
only_diff=only_diff,
191+
executor=executor,
192+
executor_semaphore=executor_semaphore,
142193
)
143194

144195
# Put the source first line back
145196
req_str = header + req_str
146197
formatted_str = header + formatted_str
147198

148-
# Only output the diff in the HTTP response
149-
only_diff = bool(request.headers.get(DIFF_HEADER, False))
150-
if only_diff:
151-
now = datetime.now(timezone.utc)
152-
src_name = f"In\t{then}"
153-
dst_name = f"Out\t{now}"
154-
loop = asyncio.get_event_loop()
155-
formatted_str = await loop.run_in_executor(
156-
executor,
157-
partial(black.diff, req_str, formatted_str, src_name, dst_name),
158-
)
159-
160199
return web.Response(
161200
content_type=request.content_type,
162201
charset=charset,
@@ -167,11 +206,41 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
167206
return web.Response(status=204, headers=headers)
168207
except black.InvalidInput as e:
169208
return web.Response(status=400, headers=headers, text=str(e))
209+
except web.HTTPException:
210+
raise
170211
except Exception as e:
171212
logging.exception("Exception during handling a request")
172213
return web.Response(status=500, headers=headers, text=str(e))
173214

174215

216+
async def format_code(
    *,
    req_str: str,
    fast: bool,
    mode: black.FileMode,
    then: datetime,
    only_diff: bool,
    executor: Executor,
    executor_semaphore: asyncio.BoundedSemaphore,
) -> str:
    """Format ``req_str`` in ``executor`` and return the result or a diff.

    The semaphore is held for the whole operation, so a single request never
    has more than one job submitted to the executor at a time and the total
    number of in-flight submissions is capped by the semaphore's size.

    Args:
        req_str: Source code to format.
        fast: Forwarded to ``black.format_file_contents`` as ``fast``.
        mode: Forwarded to ``black.format_file_contents`` as ``mode``.
        then: Timestamp embedded in the "In" label of the diff.
        only_diff: When true, return a diff between ``req_str`` and the
            formatted code instead of the formatted code itself.
        executor: Executor that runs the formatting and diffing jobs.
        executor_semaphore: Semaphore acquired around the executor jobs.

    Returns:
        The formatted source, or the output of ``black.diff`` when
        ``only_diff`` is true.

    Any exception raised by the executor jobs propagates to the caller.
    """
    async with executor_semaphore:
        # We are always inside a coroutine here, so ask for the running loop
        # directly instead of going through the policy-dependent
        # ``get_event_loop()``.
        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor,
            partial(black.format_file_contents, req_str, fast=fast, mode=mode),
        )

        if not only_diff:
            return formatted_str

        # The "Out" timestamp is taken after formatting has finished.
        now = datetime.now(timezone.utc)
        src_name = f"In\t{then}"
        dst_name = f"Out\t{now}"
        return await loop.run_in_executor(
            executor,
            partial(black.diff, req_str, formatted_str, src_name, dst_name),
        )
242+
243+
175244
def parse_mode(headers: MultiMapping[str]) -> black.Mode:
176245
try:
177246
line_length = int(headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH))

src/blackd/middlewares.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from collections.abc import Awaitable, Callable, Iterable
1+
from collections.abc import Awaitable, Callable, Collection, Iterable
22

3+
from aiohttp import web
34
from aiohttp.typedefs import Middleware
45
from aiohttp.web_middlewares import middleware
56
from aiohttp.web_request import Request
@@ -8,22 +9,31 @@
89
Handler = Callable[[Request], Awaitable[StreamResponse]]
910

1011

11-
def cors(allow_headers: Iterable[str]) -> Middleware:
12+
def cors(
13+
*,
14+
allow_headers: Iterable[str],
15+
allow_origins: Collection[str],
16+
expose_headers: Iterable[str],
17+
) -> Middleware:
1218
@middleware
1319
async def impl(request: Request, handler: Handler) -> StreamResponse:
20+
origin = request.headers.get("Origin")
21+
if not origin:
22+
return await handler(request)
23+
24+
if origin not in allow_origins:
25+
return web.Response(status=403, text="CORS origin is not allowed")
26+
1427
is_options = request.method == "OPTIONS"
1528
is_preflight = is_options and "Access-Control-Request-Method" in request.headers
1629
if is_preflight:
1730
resp = StreamResponse()
1831
else:
1932
resp = await handler(request)
2033

21-
origin = request.headers.get("Origin")
22-
if not origin:
23-
return resp
24-
25-
resp.headers["Access-Control-Allow-Origin"] = "*"
26-
resp.headers["Access-Control-Expose-Headers"] = "*"
34+
resp.headers["Access-Control-Allow-Origin"] = origin
35+
if expose_headers:
36+
resp.headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
2737
if is_options:
2838
resp.headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
2939
resp.headers["Access-Control-Allow-Methods"] = ", ".join(

0 commit comments

Comments
 (0)