11import asyncio
22import logging
3+ import os
34from concurrent .futures import Executor , ProcessPoolExecutor
45from datetime import datetime , timezone
56from functools import cache , partial
5556
5657# Response headers
5758BLACK_VERSION_HEADER = "X-Black-Version"
59+ DEFAULT_MAX_BODY_SIZE = 5 * 1024 * 1024
60+ DEFAULT_WORKERS = os .cpu_count () or 1
5861
5962
6063class HeaderError (Exception ):
@@ -76,10 +79,28 @@ class InvalidVariantHeader(Exception):
7679@click .option (
7780 "--bind-port" , type = int , help = "Port to listen on" , default = 45484 , show_default = True
7881)
82+ @click .option (
83+ "--cors-allow-origin" ,
84+ "cors_allow_origins" ,
85+ multiple = True ,
86+ help = "Origin allowed to access blackd over CORS. Can be passed multiple times." ,
87+ )
88+ @click .option (
89+ "--max-body-size" ,
90+ type = click .IntRange (min = 1 ),
91+ default = DEFAULT_MAX_BODY_SIZE ,
92+ show_default = True ,
93+ help = "Maximum request body size in bytes." ,
94+ )
7995@click .version_option (version = black .__version__ )
80- def main (bind_host : str , bind_port : int ) -> None :
96+ def main (
97+ bind_host : str ,
98+ bind_port : int ,
99+ cors_allow_origins : tuple [str , ...],
100+ max_body_size : int ,
101+ ) -> None :
81102 logging .basicConfig (level = logging .INFO )
82- app = make_app ()
103+ app = make_app (cors_allow_origins = cors_allow_origins , max_body_size = max_body_size )
83104 ver = black .__version__
84105 black .out (f"blackd version { ver } listening on { bind_host } port { bind_port } " )
85106 loop = maybe_use_uvloop ()
@@ -99,18 +120,42 @@ def main(bind_host: str, bind_port: int) -> None:
99120
100121@cache
101122def executor () -> Executor :
102- return ProcessPoolExecutor ()
123+ return ProcessPoolExecutor (max_workers = DEFAULT_WORKERS )
103124
104125
105- def make_app () -> web .Application :
126+ def make_app (
127+ * ,
128+ cors_allow_origins : tuple [str , ...] = (),
129+ max_body_size : int = DEFAULT_MAX_BODY_SIZE ,
130+ ) -> web .Application :
106131 app = web .Application (
107- middlewares = [cors (allow_headers = (* BLACK_HEADERS , "Content-Type" ))]
132+ client_max_size = max_body_size ,
133+ middlewares = [
134+ cors (
135+ allow_headers = (* BLACK_HEADERS , "Content-Type" ),
136+ allow_origins = frozenset (cors_allow_origins ),
137+ expose_headers = (BLACK_VERSION_HEADER ,),
138+ )
139+ ],
108140 )
109- app .add_routes ([web .post ("/" , partial (handle , executor = executor ()))])
141+ app .add_routes ([
142+ web .post (
143+ "/" ,
144+ partial (
145+ handle ,
146+ executor = executor (),
147+ executor_semaphore = asyncio .BoundedSemaphore (DEFAULT_WORKERS ),
148+ ),
149+ )
150+ ])
110151 return app
111152
112153
113- async def handle (request : web .Request , executor : Executor ) -> web .Response :
154+ async def handle (
155+ request : web .Request ,
156+ executor : Executor ,
157+ executor_semaphore : asyncio .BoundedSemaphore ,
158+ ) -> web .Response :
114159 headers = {BLACK_VERSION_HEADER : __version__ }
115160 try :
116161 if request .headers .get (PROTOCOL_VERSION_HEADER , "1" ) != "1" :
@@ -125,7 +170,7 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
125170 mode = parse_mode (request .headers )
126171 except HeaderError as e :
127172 return web .Response (status = 400 , text = e .args [0 ])
128- req_bytes = await request .content . read ()
173+ req_bytes = await request .read ()
129174 charset = request .charset if request .charset is not None else "utf8"
130175 req_str = req_bytes .decode (charset )
131176 then = datetime .now (timezone .utc )
@@ -136,27 +181,21 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
136181 header = req_str [:first_newline_position ]
137182 req_str = req_str [first_newline_position :]
138183
139- loop = asyncio .get_event_loop ()
140- formatted_str = await loop .run_in_executor (
141- executor , partial (black .format_file_contents , req_str , fast = fast , mode = mode )
184+ only_diff = bool (request .headers .get (DIFF_HEADER , False ))
185+ formatted_str = await format_code (
186+ req_str = req_str ,
187+ fast = fast ,
188+ mode = mode ,
189+ then = then ,
190+ only_diff = only_diff ,
191+ executor = executor ,
192+ executor_semaphore = executor_semaphore ,
142193 )
143194
144195 # Put the source first line back
145196 req_str = header + req_str
146197 formatted_str = header + formatted_str
147198
148- # Only output the diff in the HTTP response
149- only_diff = bool (request .headers .get (DIFF_HEADER , False ))
150- if only_diff :
151- now = datetime .now (timezone .utc )
152- src_name = f"In\t { then } "
153- dst_name = f"Out\t { now } "
154- loop = asyncio .get_event_loop ()
155- formatted_str = await loop .run_in_executor (
156- executor ,
157- partial (black .diff , req_str , formatted_str , src_name , dst_name ),
158- )
159-
160199 return web .Response (
161200 content_type = request .content_type ,
162201 charset = charset ,
@@ -167,11 +206,41 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
167206 return web .Response (status = 204 , headers = headers )
168207 except black .InvalidInput as e :
169208 return web .Response (status = 400 , headers = headers , text = str (e ))
209+ except web .HTTPException :
210+ raise
170211 except Exception as e :
171212 logging .exception ("Exception during handling a request" )
172213 return web .Response (status = 500 , headers = headers , text = str (e ))
173214
174215
216+ async def format_code (
217+ * ,
218+ req_str : str ,
219+ fast : bool ,
220+ mode : black .FileMode ,
221+ then : datetime ,
222+ only_diff : bool ,
223+ executor : Executor ,
224+ executor_semaphore : asyncio .BoundedSemaphore ,
225+ ) -> str :
226+ async with executor_semaphore :
227+ loop = asyncio .get_event_loop ()
228+ formatted_str = await loop .run_in_executor (
229+ executor , partial (black .format_file_contents , req_str , fast = fast , mode = mode )
230+ )
231+
232+ if not only_diff :
233+ return formatted_str
234+
235+ now = datetime .now (timezone .utc )
236+ src_name = f"In\t { then } "
237+ dst_name = f"Out\t { now } "
238+ return await loop .run_in_executor (
239+ executor ,
240+ partial (black .diff , req_str , formatted_str , src_name , dst_name ),
241+ )
242+
243+
175244def parse_mode (headers : MultiMapping [str ]) -> black .Mode :
176245 try :
177246 line_length = int (headers .get (LINE_LENGTH_HEADER , black .DEFAULT_LINE_LENGTH ))
0 commit comments