@@ -65,6 +65,7 @@ type Options struct {
6565 WebRTCDialer WebRTCDialer
6666 FetchMetadata FetchMetadata
6767
68+ StatsReporter StatsReporter
6869 ReconnectingPTYTimeout time.Duration
6970 EnvironmentVariables map [string ]string
7071 Logger slog.Logger
@@ -100,6 +101,8 @@ func New(options Options) io.Closer {
100101 envVars : options .EnvironmentVariables ,
101102 coordinatorDialer : options .CoordinatorDialer ,
102103 fetchMetadata : options .FetchMetadata ,
104+ stats : & Stats {},
105+ statsReporter : options .StatsReporter ,
103106 }
104107 server .init (ctx )
105108 return server
@@ -125,6 +128,8 @@ type agent struct {
125128
126129 network * tailnet.Conn
127130 coordinatorDialer CoordinatorDialer
131+ stats * Stats
132+ statsReporter StatsReporter
128133}
129134
130135func (a * agent ) run (ctx context.Context ) {
@@ -194,6 +199,13 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
194199 a .logger .Critical (ctx , "create tailnet" , slog .Error (err ))
195200 return
196201 }
202+ a .network .SetForwardTCPCallback (func (conn net.Conn , listenerExists bool ) net.Conn {
203+ if listenerExists {
204+ // If a listener already exists, we would double-wrap the conn.
205+ return conn
206+ }
207+ return a .stats .wrapConn (conn )
208+ })
197209 go a .runCoordinator (ctx )
198210
199211 sshListener , err := a .network .Listen ("tcp" , ":" + strconv .Itoa (tailnetSSHPort ))
@@ -207,7 +219,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
207219 if err != nil {
208220 return
209221 }
210- go a .sshServer .HandleConn (conn )
222+ a .sshServer .HandleConn (a . stats . wrapConn ( conn ) )
211223 }
212224 }()
213225 reconnectingPTYListener , err := a .network .Listen ("tcp" , ":" + strconv .Itoa (tailnetReconnectingPTYPort ))
@@ -219,8 +231,10 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
219231 for {
220232 conn , err := reconnectingPTYListener .Accept ()
221233 if err != nil {
234+ a .logger .Debug (ctx , "accept pty failed" , slog .Error (err ))
222235 return
223236 }
237+ conn = a .stats .wrapConn (conn )
224238 // This cannot use a JSON decoder, since that can
225239 // buffer additional data that is required for the PTY.
226240 rawLen := make ([]byte , 2 )
@@ -364,17 +378,17 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
364378 return nil
365379}
366380
367- func (a * agent ) handlePeerConn (ctx context.Context , conn * peer.Conn ) {
381+ func (a * agent ) handlePeerConn (ctx context.Context , peerConn * peer.Conn ) {
368382 go func () {
369383 select {
370384 case <- a .closed :
371- case <- conn .Closed ():
385+ case <- peerConn .Closed ():
372386 }
373- _ = conn .Close ()
387+ _ = peerConn .Close ()
374388 a .connCloseWait .Done ()
375389 }()
376390 for {
377- channel , err := conn .Accept (ctx )
391+ channel , err := peerConn .Accept (ctx )
378392 if err != nil {
379393 if errors .Is (err , peer .ErrClosed ) || a .isClosed () {
380394 return
@@ -383,9 +397,11 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
383397 return
384398 }
385399
400+ conn := channel .NetConn ()
401+
386402 switch channel .Protocol () {
387403 case ProtocolSSH :
388- go a .sshServer .HandleConn (channel . NetConn ( ))
404+ go a .sshServer .HandleConn (a . stats . wrapConn ( conn ))
389405 case ProtocolReconnectingPTY :
390406 rawID := channel .Label ()
391407 // The ID format is referenced in conn.go.
@@ -418,9 +434,9 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
418434 Height : uint16 (height ),
419435 Width : uint16 (width ),
420436 Command : idParts [3 ],
421- }, channel . NetConn ( ))
437+ }, a . stats . wrapConn ( conn ))
422438 case ProtocolDial :
423- go a .handleDial (ctx , channel .Label (), channel . NetConn ( ))
439+ go a .handleDial (ctx , channel .Label (), a . stats . wrapConn ( conn ))
424440 default :
425441 a .logger .Warn (ctx , "unhandled protocol from channel" ,
426442 slog .F ("protocol" , channel .Protocol ()),
@@ -514,6 +530,21 @@ func (a *agent) init(ctx context.Context) {
514530 }
515531
516532 go a .run (ctx )
533+ if a .statsReporter != nil {
534+ cl , err := a .statsReporter (ctx , a .logger , func () * Stats {
535+ return a .stats .Copy ()
536+ })
537+ if err != nil {
538+ a .logger .Error (ctx , "report stats" , slog .Error (err ))
539+ return
540+ }
541+ a .connCloseWait .Add (1 )
542+ go func () {
543+ defer a .connCloseWait .Done ()
544+ <- a .closed
545+ cl .Close ()
546+ }()
547+ }
517548}
518549
519550// createCommand processes raw command input with OpenSSH-like behavior.
0 commit comments