diff --git a/pproxy/server.py b/pproxy/server.py index 1209767..888d2a0 100644 --- a/pproxy/server.py +++ b/pproxy/server.py @@ -296,12 +296,15 @@ async def prepare_connection(self, reader_remote, writer_remote, host, port): whost, wport = self.jump.destination(host, port) await self.rproto.connect(reader_remote=reader_remote, writer_remote=writer_remote, rauth=self.auth, host_name=whost, port=wport, writer_cipher_r=writer_cipher_r, myhost=self.host_name, sock=writer_remote.get_extra_info('socket')) return await self.jump.prepare_connection(reader_remote, writer_remote, host, port) - def start_server(self, args, stream_handler=stream_handler): + def start_server(self, args, sock=None, stream_handler=stream_handler): handler = functools.partial(stream_handler, **vars(self), **args) if self.unix: return asyncio.start_unix_server(handler, path=self.bind) else: - return asyncio.start_server(handler, host=self.host_name, port=self.port, reuse_port=args.get('ruport')) + if sock is None: + return asyncio.start_server(handler, host=self.host_name, port=self.port, reuse_port=args.get('ruport')) + else: + return asyncio.start_server(handler, reuse_port=args.get('ruport'), sock=sock) class ProxyH2(ProxySimple): def __init__(self, sslserver, sslclient, **kw): diff --git a/tests/api_server.py b/tests/api_server.py index 23a7cd0..9d84152 100644 --- a/tests/api_server.py +++ b/tests/api_server.py @@ -1,13 +1,19 @@ +import socket import asyncio +sock = socket.socket() +sock.bind(('', 0)) + import pproxy -server = pproxy.Server('ss://0.0.0.0:1234') -remote = pproxy.Connection('ss://1.2.3.4:5678') +server = pproxy.Server('ss://chacha20:123@localhost:10') +# remote = pproxy.Connection('ss://1.2.3.4:5678') +remote = pproxy.DIRECT args = dict( rserver = [remote], verbose = print ) loop = asyncio.get_event_loop() -handler = loop.run_until_complete(server.start_server(args)) +print(sock.getsockname()) +handler = loop.run_until_complete(server.start_server(args, sock)) try: loop.run_forever() except KeyboardInterrupt: