|
3 | 3 | HTTP_LINE = re.compile('([^ ]+) +(.+?) +(HTTP/[^ ]+)$') |
4 | 4 | packstr = lambda s, n=1: len(s).to_bytes(n, 'big') + s |
5 | 5 |
|
6 | | -def netloc_split(loc, default_port): |
| 6 | +def netloc_split(loc, default_host=None, default_port=None): |
7 | 7 | ipv6 = re.fullmatch('\[([0-9a-fA-F:]*)\](?::(\d+)?)?', loc) |
8 | 8 | if ipv6: |
9 | 9 | host_name, port = ipv6.groups() |
| 10 | + elif ':' in loc: |
| 11 | + host_name, port = loc.rsplit(':', 1) |
10 | 12 | else: |
11 | | - host_name, _, port = loc.partition(':') |
12 | | - return host_name, int(port) if port else default_port |
| 13 | + host_name, port = loc, None |
| 14 | + return host_name or default_host, int(port) if port else default_port |
13 | 15 |
|
14 | 16 | async def socks_address_stream(reader, n): |
15 | 17 | if n in (1, 17): |
@@ -311,12 +313,11 @@ async def accept(self, reader, user, writer, users, authtable, httpget=None, **k |
311 | 313 | raise Exception('Unauthorized HTTP') |
312 | 314 | authtable.set_authed(user) |
313 | 315 | if method == 'CONNECT': |
314 | | - host_name, port = path.rsplit(':', 1) |
315 | | - port = int(port) |
| 316 | + host_name, port = netloc_split(path) |
316 | 317 | return user, host_name, port, f'{ver} 200 OK\r\nConnection: close\r\n\r\n'.encode() |
317 | 318 | else: |
318 | 319 | url = urllib.parse.urlparse(path) |
319 | | - host_name, port = netloc_split(url.netloc or headers.get("Host"), 80) |
| 320 | + host_name, port = netloc_split(url.netloc or headers.get("Host"), default_port=80) |
320 | 321 | newpath = url._replace(netloc='', scheme='').geturl() |
321 | 322 | return user, host_name, port, b'', f'{method} {newpath} {ver}\r\n{lines}\r\n\r\n'.encode() |
322 | 323 | async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw): |
@@ -421,11 +422,8 @@ class Tunnel(Transparent): |
421 | 422 | def query_remote(self, sock): |
422 | 423 | if not self.param: |
423 | 424 | return 'tunnel', 0 |
424 | | - host, _, port = self.param.partition(':') |
425 | 425 | dst = sock.getsockname() |
426 | | - host = host or dst[0] |
427 | | - port = int(port) if port else dst[1] |
428 | | - return host, port |
| 426 | + return netloc_split(self.param, dst[0], dst[1]) |
429 | 427 | async def connect(self, reader_remote, writer_remote, rauth, host_name, port, **kw): |
430 | 428 | pass |
431 | 429 | def udp_connect(self, rauth, host_name, port, data, **kw): |
@@ -502,10 +500,8 @@ async def accept(self, reader, user, writer, users, authtable, sock, **kw): |
502 | 500 | self.patch_ws_stream(reader, writer, False) |
503 | 501 | if not self.param: |
504 | 502 | return 'tunnel', 0 |
505 | | - host, _, port = self.param.partition(':') |
506 | 503 | dst = sock.getsockname() |
507 | | - host = host or dst[0] |
508 | | - port = int(port) if port else dst[1] |
| 504 | + host, port = netloc_split(self.param, dst[0], dst[1]) |
509 | 505 | return user, host, port |
510 | 506 | async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw): |
511 | 507 | seckey = base64.b64encode(os.urandom(16)).decode() |
|
0 commit comments