66 "errors"
77 "fmt"
88 "io"
9- "net"
109 "net/url"
1110 "os"
1211 "os/exec"
@@ -27,7 +26,6 @@ import (
2726 "cdr.dev/slog"
2827 "cdr.dev/slog/sloggers/sloghuman"
2928
30- "github.com/coder/coder/agent/agentssh"
3129 "github.com/coder/coder/cli/clibase"
3230 "github.com/coder/coder/cli/cliui"
3331 "github.com/coder/coder/coderd/autobuild/notify"
@@ -53,6 +51,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
5351 waitEnum string
5452 noWait bool
5553 logDirPath string
54+ remoteForward string
5655 )
5756 client := new (codersdk.Client )
5857 cmd := & clibase.Cmd {
@@ -122,6 +121,16 @@ func (r *RootCmd) ssh() *clibase.Cmd {
122121 client .SetLogger (logger )
123122 }
124123
124+ if remoteForward != "" {
125+ isValid := validateRemoteForward (remoteForward )
126+ if ! isValid {
127+ return xerrors .Errorf (`invalid format of remote-forward, expected: remote_port:local_address:local_port` )
128+ }
129+ if isValid && stdio {
130+ return xerrors .Errorf (`remote-forward can't be enabled in the stdio mode` )
131+ }
132+ }
133+
125134 workspace , workspaceAgent , err := getWorkspaceAndAgent (ctx , inv , client , codersdk .Me , inv .Args [0 ])
126135 if err != nil {
127136 return err
@@ -198,6 +207,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
198207 }
199208 defer conn .Close ()
200209 conn .AwaitReachable (ctx )
210+
201211 stopPolling := tryPollWorkspaceAutostop (ctx , client , workspace )
202212 defer stopPolling ()
203213
@@ -300,6 +310,19 @@ func (r *RootCmd) ssh() *clibase.Cmd {
300310 defer closer .Close ()
301311 }
302312
313+ if remoteForward != "" {
314+ localAddr , remoteAddr , err := parseRemoteForward (remoteForward )
315+ if err != nil {
316+ return err
317+ }
318+
319+ closer , err := sshRemoteForward (ctx , inv .Stderr , sshClient , localAddr , remoteAddr )
320+ if err != nil {
321+ return xerrors .Errorf ("ssh remote forward: %w" , err )
322+ }
323+ defer closer .Close ()
324+ }
325+
303326 stdoutFile , validOut := inv .Stdout .(* os.File )
304327 stdinFile , validIn := inv .Stdin .(* os.File )
305328 if validOut && validIn && isatty .IsTerminal (stdoutFile .Fd ()) {
@@ -424,6 +447,13 @@ func (r *RootCmd) ssh() *clibase.Cmd {
424447 FlagShorthand : "l" ,
425448 Value : clibase .StringOf (& logDirPath ),
426449 },
450+ {
451+ Flag : "remote-forward" ,
452+ Description : "Enable remote port forwarding (remote_port:local_address:local_port)." ,
453+ Env : "CODER_SSH_REMOTE_FORWARD" ,
454+ FlagShorthand : "R" ,
455+ Value : clibase .StringOf (& remoteForward ),
456+ },
427457 }
428458 return cmd
429459}
@@ -568,8 +598,15 @@ func getWorkspaceAndAgent(ctx context.Context, inv *clibase.Invocation, client *
568598// of the CLI running simultaneously.
569599func tryPollWorkspaceAutostop (ctx context.Context , client * codersdk.Client , workspace codersdk.Workspace ) (stop func ()) {
570600 lock := flock .New (filepath .Join (os .TempDir (), "coder-autostop-notify-" + workspace .ID .String ()))
571- condition := notifyCondition (ctx , client , workspace .ID , lock )
572- return notify .Notify (condition , workspacePollInterval , autostopNotifyCountdown ... )
601+ conditionCtx , cancelCondition := context .WithCancel (ctx )
602+ condition := notifyCondition (conditionCtx , client , workspace .ID , lock )
603+ stopFunc := notify .Notify (condition , workspacePollInterval , autostopNotifyCountdown ... )
604+ return func () {
605+ // With many "ssh" processes running, `lock.TryLockContext` can be hanging until the context canceled.
606+ // Without this cancellation, a CLI process with failed remote-forward could be hanging indefinitely.
607+ cancelCondition ()
608+ stopFunc ()
609+ }
573610}
574611
575612// Notify the user if the workspace is due to shutdown.
@@ -752,56 +789,3 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
752789
753790 return string (bytes .TrimSpace (remoteSocket )), nil
754791}
755-
756- // cookieAddr is a special net.Addr accepted by sshForward() which includes a
757- // cookie which is written to the connection before forwarding.
758- type cookieAddr struct {
759- net.Addr
760- cookie []byte
761- }
762-
763- // sshForwardRemote starts forwarding connections from a remote listener to a
764- // local address via SSH in a goroutine.
765- //
766- // Accepts a `cookieAddr` as the local address.
767- func sshForwardRemote (ctx context.Context , stderr io.Writer , sshClient * gossh.Client , localAddr , remoteAddr net.Addr ) (io.Closer , error ) {
768- listener , err := sshClient .Listen (remoteAddr .Network (), remoteAddr .String ())
769- if err != nil {
770- return nil , xerrors .Errorf ("listen on remote SSH address %s: %w" , remoteAddr .String (), err )
771- }
772-
773- go func () {
774- for {
775- remoteConn , err := listener .Accept ()
776- if err != nil {
777- if ctx .Err () == nil {
778- _ , _ = fmt .Fprintf (stderr , "Accept SSH listener connection: %+v\n " , err )
779- }
780- return
781- }
782-
783- go func () {
784- defer remoteConn .Close ()
785-
786- localConn , err := net .Dial (localAddr .Network (), localAddr .String ())
787- if err != nil {
788- _ , _ = fmt .Fprintf (stderr , "Dial local address %s: %+v\n " , localAddr .String (), err )
789- return
790- }
791- defer localConn .Close ()
792-
793- if c , ok := localAddr .(cookieAddr ); ok {
794- _ , err = localConn .Write (c .cookie )
795- if err != nil {
796- _ , _ = fmt .Fprintf (stderr , "Write cookie to local connection: %+v\n " , err )
797- return
798- }
799- }
800-
801- agentssh .Bicopy (ctx , localConn , remoteConn )
802- }()
803- }
804- }()
805-
806- return listener , nil
807- }
0 commit comments