@@ -25,6 +25,14 @@ import (
2525 "github.com/coder/serpent"
2626)
2727
28+ var (
29+ // noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
30+ // when the local address is not specified in port-forward flags.
31+ noAddr netip.Addr
32+ ipv6Loopback = netip .MustParseAddr ("::1" )
33+ ipv4Loopback = netip .MustParseAddr ("127.0.0.1" )
34+ )
35+
2836func (r * RootCmd ) portForward () * serpent.Command {
2937 var (
3038 tcpForwards []string // <port>:<port>
@@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
122130 // Start all listeners.
123131 var (
124132 wg = new (sync.WaitGroup )
125- listeners = make ([]net.Listener , len (specs ))
133+ listeners = make ([]net.Listener , 0 , len (specs )* 2 )
126134 closeAllListeners = func () {
127135 logger .Debug (ctx , "closing all listeners" )
128136 for _ , l := range listeners {
@@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
135143 )
136144 defer closeAllListeners ()
137145
138- for i , spec := range specs {
146+ for _ , spec := range specs {
147+ if spec .listenHost == noAddr {
148+ // first, opportunistically try to listen on IPv6
149+ spec6 := spec
150+ spec6 .listenHost = ipv6Loopback
151+ l6 , err6 := listenAndPortForward (ctx , inv , conn , wg , spec6 , logger )
152+ if err6 != nil {
153+ logger .Info (ctx , "failed to opportunistically listen on IPv6" , slog .F ("spec" , spec ), slog .Error (err6 ))
154+ } else {
155+ listeners = append (listeners , l6 )
156+ }
157+ spec .listenHost = ipv4Loopback
158+ }
139159 l , err := listenAndPortForward (ctx , inv , conn , wg , spec , logger )
140160 if err != nil {
141161 logger .Error (ctx , "failed to listen" , slog .F ("spec" , spec ), slog .Error (err ))
142162 return err
143163 }
144- listeners [ i ] = l
164+ listeners = append ( listeners , l )
145165 }
146166
147167 stopUpdating := client .UpdateWorkspaceUsageContext (ctx , workspace .ID )
@@ -206,12 +226,19 @@ func listenAndPortForward(
206226 spec portForwardSpec ,
207227 logger slog.Logger ,
208228) (net.Listener , error ) {
209- logger = logger .With (slog .F ("network" , spec .listenNetwork ), slog .F ("address" , spec .listenAddress ))
210- _ , _ = fmt .Fprintf (inv .Stderr , "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n " , spec .listenNetwork , spec .listenAddress , spec .dialNetwork , spec .dialAddress )
229+ logger = logger .With (
230+ slog .F ("network" , spec .network ),
231+ slog .F ("listen_host" , spec .listenHost ),
232+ slog .F ("listen_port" , spec .listenPort ),
233+ )
234+ listenAddress := netip .AddrPortFrom (spec .listenHost , spec .listenPort )
235+ dialAddress := fmt .Sprintf ("127.0.0.1:%d" , spec .dialPort )
236+ _ , _ = fmt .Fprintf (inv .Stderr , "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n " ,
237+ spec .network , listenAddress , spec .network , dialAddress )
211238
212- l , err := inv .Net .Listen (spec .listenNetwork , spec . listenAddress )
239+ l , err := inv .Net .Listen (spec .network , listenAddress . String () )
213240 if err != nil {
214- return nil , xerrors .Errorf ("listen '%v ://%v ': %w" , spec .listenNetwork , spec . listenAddress , err )
241+ return nil , xerrors .Errorf ("listen '%s ://%s ': %w" , spec .network , listenAddress . String () , err )
215242 }
216243 logger .Debug (ctx , "listening" )
217244
@@ -226,24 +253,31 @@ func listenAndPortForward(
226253 logger .Debug (ctx , "listener closed" )
227254 return
228255 }
229- _ , _ = fmt .Fprintf (inv .Stderr , "Error accepting connection from '%v://%v': %v\n " , spec .listenNetwork , spec .listenAddress , err )
256+ _ , _ = fmt .Fprintf (inv .Stderr ,
257+ "Error accepting connection from '%s://%s': %v\n " ,
258+ spec .network , listenAddress .String (), err )
230259 _ , _ = fmt .Fprintln (inv .Stderr , "Killing listener" )
231260 return
232261 }
233- logger .Debug (ctx , "accepted connection" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
262+ logger .Debug (ctx , "accepted connection" ,
263+ slog .F ("remote_addr" , netConn .RemoteAddr ()))
234264
235265 go func (netConn net.Conn ) {
236266 defer netConn .Close ()
237- remoteConn , err := conn .DialContext (ctx , spec .dialNetwork , spec . dialAddress )
267+ remoteConn , err := conn .DialContext (ctx , spec .network , dialAddress )
238268 if err != nil {
239- _ , _ = fmt .Fprintf (inv .Stderr , "Failed to dial '%v://%v' in workspace: %s\n " , spec .dialNetwork , spec .dialAddress , err )
269+ _ , _ = fmt .Fprintf (inv .Stderr ,
270+ "Failed to dial '%s://%s' in workspace: %s\n " ,
271+ spec .network , dialAddress , err )
240272 return
241273 }
242274 defer remoteConn .Close ()
243- logger .Debug (ctx , "dialed remote" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
275+ logger .Debug (ctx ,
276+ "dialed remote" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
244277
245278 agentssh .Bicopy (ctx , netConn , remoteConn )
246- logger .Debug (ctx , "connection closing" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
279+ logger .Debug (ctx ,
280+ "connection closing" , slog .F ("remote_addr" , netConn .RemoteAddr ()))
247281 }(netConn )
248282 }
249283 }(spec )
@@ -252,58 +286,48 @@ func listenAndPortForward(
252286}
253287
254288type portForwardSpec struct {
255- listenNetwork string // tcp, udp
256- listenAddress string // <ip>:<port> or path
257-
258- dialNetwork string // tcp, udp
259- dialAddress string // <ip>:<port> or path
289+ network string // tcp, udp
290+ listenHost netip.Addr
291+ listenPort , dialPort uint16
260292}
261293
262294func parsePortForwards (tcpSpecs , udpSpecs []string ) ([]portForwardSpec , error ) {
263295 specs := []portForwardSpec {}
264296
265297 for _ , specEntry := range tcpSpecs {
266298 for _ , spec := range strings .Split (specEntry , "," ) {
267- ports , err := parseSrcDestPorts (strings .TrimSpace (spec ))
299+ pfSpecs , err := parseSrcDestPorts (strings .TrimSpace (spec ))
268300 if err != nil {
269301 return nil , xerrors .Errorf ("failed to parse TCP port-forward specification %q: %w" , spec , err )
270302 }
271303
272- for _ , port := range ports {
273- specs = append (specs , portForwardSpec {
274- listenNetwork : "tcp" ,
275- listenAddress : port .local .String (),
276- dialNetwork : "tcp" ,
277- dialAddress : port .remote .String (),
278- })
304+ for _ , pfSpec := range pfSpecs {
305+ pfSpec .network = "tcp"
306+ specs = append (specs , pfSpec )
279307 }
280308 }
281309 }
282310
283311 for _ , specEntry := range udpSpecs {
284312 for _ , spec := range strings .Split (specEntry , "," ) {
285- ports , err := parseSrcDestPorts (strings .TrimSpace (spec ))
313+ pfSpecs , err := parseSrcDestPorts (strings .TrimSpace (spec ))
286314 if err != nil {
287315 return nil , xerrors .Errorf ("failed to parse UDP port-forward specification %q: %w" , spec , err )
288316 }
289317
290- for _ , port := range ports {
291- specs = append (specs , portForwardSpec {
292- listenNetwork : "udp" ,
293- listenAddress : port .local .String (),
294- dialNetwork : "udp" ,
295- dialAddress : port .remote .String (),
296- })
318+ for _ , pfSpec := range pfSpecs {
319+ pfSpec .network = "udp"
320+ specs = append (specs , pfSpec )
297321 }
298322 }
299323 }
300324
301325 // Check for duplicate entries.
302326 locals := map [string ]struct {}{}
303327 for _ , spec := range specs {
304- localStr := fmt .Sprintf ("%v:%v " , spec .listenNetwork , spec .listenAddress )
328+ localStr := fmt .Sprintf ("%s:%s:%d " , spec .network , spec .listenHost , spec . listenPort )
305329 if _ , ok := locals [localStr ]; ok {
306- return nil , xerrors .Errorf ("local %v %v is specified twice" , spec .listenNetwork , spec .listenAddress )
330+ return nil , xerrors .Errorf ("local %s host:%s port:%d is specified twice" , spec .network , spec .listenHost , spec . listenPort )
307331 }
308332 locals [localStr ] = struct {}{}
309333 }
@@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
323347 return uint16 (port ), nil
324348}
325349
326- type parsedSrcDestPort struct {
327- local , remote netip.AddrPort
328- }
329-
330350// specRegexp matches port specs. It handles all the following formats:
331351//
332352// 8000
@@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
347367// 9: end or remote port range
348368var specRegexp = regexp .MustCompile (`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$` )
349369
350- func parseSrcDestPorts (in string ) ([]parsedSrcDestPort , error ) {
351- var (
352- err error
353- localAddr = netip .AddrFrom4 ([4 ]byte {127 , 0 , 0 , 1 })
354- remoteAddr = netip .AddrFrom4 ([4 ]byte {127 , 0 , 0 , 1 })
355- )
370+ func parseSrcDestPorts (in string ) ([]portForwardSpec , error ) {
356371 groups := specRegexp .FindStringSubmatch (in )
357372 if len (groups ) == 0 {
358373 return nil , xerrors .Errorf ("invalid port specification %q" , in )
359374 }
375+
376+ var localAddr netip.Addr
360377 if groups [2 ] != "" {
361- localAddr , err = netip .ParseAddr (strings .Trim (groups [2 ], "[]" ))
378+ parsedAddr , err : = netip .ParseAddr (strings .Trim (groups [2 ], "[]" ))
362379 if err != nil {
363380 return nil , xerrors .Errorf ("invalid IP address %q" , groups [2 ])
364381 }
382+ localAddr = parsedAddr
365383 }
366384
367385 local , err := parsePortRange (groups [3 ], groups [5 ])
@@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
378396 if len (local ) != len (remote ) {
379397 return nil , xerrors .Errorf ("port ranges must be the same length, got %d ports forwarded to %d ports" , len (local ), len (remote ))
380398 }
381- var out []parsedSrcDestPort
399+ var out []portForwardSpec
382400 for i := range local {
383- out = append (out , parsedSrcDestPort {
384- local : netip .AddrPortFrom (localAddr , local [i ]),
385- remote : netip .AddrPortFrom (remoteAddr , remote [i ]),
401+ out = append (out , portForwardSpec {
402+ listenHost : localAddr ,
403+ listenPort : local [i ],
404+ dialPort : remote [i ],
386405 })
387406 }
388407 return out , nil
0 commit comments