@@ -16,6 +16,7 @@ import (
1616 "github.com/spf13/cobra"
1717 "golang.org/x/xerrors"
1818 "tailscale.com/tailcfg"
19+ "tailscale.com/types/netlogtype"
1920
2021 "github.com/coder/coder/codersdk"
2122)
@@ -92,6 +93,7 @@ func vscodeSSH() *cobra.Command {
9293 if err != nil {
9394 return xerrors .Errorf ("find workspace: %w" , err )
9495 }
96+
9597 var agent codersdk.WorkspaceAgent
9698 var found bool
9799 for _ , resource := range workspace .LatestBuild .Resources {
@@ -117,61 +119,78 @@ func vscodeSSH() *cobra.Command {
117119 break
118120 }
119121 }
120- agentConn , err := client .DialWorkspaceAgent (ctx , agent .ID , & codersdk.DialWorkspaceAgentOptions {
121- EnableTrafficStats : true ,
122- })
122+
123+ agentConn , err := client .DialWorkspaceAgent (ctx , agent .ID , & codersdk.DialWorkspaceAgentOptions {})
123124 if err != nil {
124125 return xerrors .Errorf ("dial workspace agent: %w" , err )
125126 }
126127 defer agentConn .Close ()
128+
127129 agentConn .AwaitReachable (ctx )
128130 rawSSH , err := agentConn .SSH (ctx )
129131 if err != nil {
130132 return err
131133 }
132134 defer rawSSH .Close ()
135+
133136 // Copy SSH traffic over stdio.
134137 go func () {
135138 _ , _ = io .Copy (cmd .OutOrStdout (), rawSSH )
136139 }()
137140 go func () {
138141 _ , _ = io .Copy (rawSSH , cmd .InOrStdin ())
139142 }()
143+
140144 // The VS Code extension obtains the PID of the SSH process to
141145 // read the file below which contains network information to display.
142146 //
143147 // We get the parent PID because it's assumed `ssh` is calling this
144148 // command via the ProxyCommand SSH option.
145149 networkInfoFilePath := filepath .Join (networkInfoDir , fmt .Sprintf ("%d.json" , os .Getppid ()))
146- ticker := time . NewTicker ( networkInfoInterval )
147- defer ticker . Stop ( )
148- lastCollected := time .Now ()
149- for {
150- select {
151- case <- ctx . Done () :
152- return nil
153- case <- ticker . C :
150+
151+ statsErrChan := make ( chan error , 1 )
152+ cb := func ( start , end time.Time , virtual , _ map [netlogtype. Connection ]netlogtype. Counts ) {
153+ sendErr := func ( err error ) {
154+ select {
155+ case statsErrChan <- err :
156+ default :
157+ }
154158 }
155- stats , err := collectNetworkStats (ctx , agentConn , lastCollected )
159+
160+ stats , err := collectNetworkStats (ctx , agentConn , start , end , virtual )
156161 if err != nil {
157- return err
162+ sendErr (err )
163+ return
158164 }
165+
159166 rawStats , err := json .Marshal (stats )
160167 if err != nil {
161- return err
168+ sendErr (err )
169+ return
162170 }
163171 err = afero .WriteFile (fs , networkInfoFilePath , rawStats , 0600 )
164172 if err != nil {
165- return err
173+ sendErr (err )
174+ return
166175 }
167- lastCollected = time .Now ()
176+ }
177+
178+ now := time .Now ()
179+ cb (now , now .Add (time .Nanosecond ), map [netlogtype.Connection ]netlogtype.Counts {}, map [netlogtype.Connection ]netlogtype.Counts {})
180+ agentConn .SetConnStatsCallback (networkInfoInterval , 2048 , cb )
181+
182+ select {
183+ case <- ctx .Done ():
184+ return nil
185+ case err := <- statsErrChan :
186+ return err
168187 }
169188 },
170189 }
171190 cmd .Flags ().StringVarP (& networkInfoDir , "network-info-dir" , "" , "" , "Specifies a directory to write network information periodically." )
172191 cmd .Flags ().StringVarP (& sessionTokenFile , "session-token-file" , "" , "" , "Specifies a file that contains a session token." )
173192 cmd .Flags ().StringVarP (& urlFile , "url-file" , "" , "" , "Specifies a file that contains the Coder URL." )
174- cmd .Flags ().DurationVarP (& networkInfoInterval , "network-info-interval" , "" , 3 * time .Second , "Specifies the interval to update network information." )
193+ cmd .Flags ().DurationVarP (& networkInfoInterval , "network-info-interval" , "" , 5 * time .Second , "Specifies the interval to update network information." )
175194 return cmd
176195}
177196
@@ -184,7 +203,7 @@ type sshNetworkStats struct {
184203 DownloadBytesSec int64 `json:"download_bytes_sec"`
185204}
186205
187- func collectNetworkStats (ctx context.Context , agentConn * codersdk.WorkspaceAgentConn , lastCollected time.Time ) (* sshNetworkStats , error ) {
206+ func collectNetworkStats (ctx context.Context , agentConn * codersdk.WorkspaceAgentConn , start , end time.Time , counts map [netlogtype. Connection ]netlogtype. Counts ) (* sshNetworkStats , error ) {
188207 latency , p2p , err := agentConn .Ping (ctx )
189208 if err != nil {
190209 return nil , err
@@ -216,13 +235,13 @@ func collectNetworkStats(ctx context.Context, agentConn *codersdk.WorkspaceAgent
216235
217236 totalRx := uint64 (0 )
218237 totalTx := uint64 (0 )
219- for _ , stat := range agentConn . ExtractTrafficStats () {
238+ for _ , stat := range counts {
220239 totalRx += stat .RxBytes
221240 totalTx += stat .TxBytes
222241 }
223242 // Tracking the time since last request is required because
224243 // ExtractTrafficStats() resets its counters after each call.
225- dur := time . Since ( lastCollected )
244+ dur := end . Sub ( start )
226245 uploadSecs := float64 (totalTx ) / dur .Seconds ()
227246 downloadSecs := float64 (totalRx ) / dur .Seconds ()
228247
0 commit comments