88 "net/http/httptest"
99 "net/url"
1010 "strings"
11+ "sync/atomic"
1112 "testing"
1213 "time"
1314
@@ -27,6 +28,7 @@ import (
2728 "github.com/coder/coder/v2/coderd/database/dbauthz"
2829 "github.com/coder/coder/v2/coderd/database/dbmock"
2930 "github.com/coder/coder/v2/coderd/database/dbtestutil"
31+ "github.com/coder/coder/v2/coderd/database/dbtime"
3032 "github.com/coder/coder/v2/coderd/externalauth"
3133 "github.com/coder/coder/v2/coderd/promoauth"
3234 "github.com/coder/coder/v2/codersdk"
@@ -120,6 +122,11 @@ func TestRefreshToken(t *testing.T) {
120122 t .Run ("ValidateServerError" , func (t * testing.T ) {
121123 t .Parallel ()
122124
125+ ctrl := gomock .NewController (t )
126+ mDB := dbmock .NewMockStore (ctrl )
127+ mDB .EXPECT ().UpdateExternalAuthLink (gomock .Any (), gomock .Any ()).
128+ Return (database.ExternalAuthLink {}, nil ).AnyTimes ()
129+
123130 const staticError = "static error"
124131 validated := false
125132 fake , config , link := setupOauth2Test (t , testConfig {
@@ -136,7 +143,7 @@ func TestRefreshToken(t *testing.T) {
136143 ctx := oidc .ClientContext (context .Background (), fake .HTTPClient (nil ))
137144 link .OAuthExpiry = expired
138145
139- _ , err := config .RefreshToken (ctx , nil , link )
146+ _ , err := config .RefreshToken (ctx , mDB , link )
140147 require .ErrorContains (t , err , staticError )
141148 // Unsure if this should be the correct behavior. It's an invalid token because
142149 // 'ValidateToken()' failed with a runtime error. This was the previous behavior,
@@ -223,6 +230,11 @@ func TestRefreshToken(t *testing.T) {
223230 t .Run ("ValidateFailure" , func (t * testing.T ) {
224231 t .Parallel ()
225232
233+ ctrl := gomock .NewController (t )
234+ mDB := dbmock .NewMockStore (ctrl )
235+ mDB .EXPECT ().UpdateExternalAuthLink (gomock .Any (), gomock .Any ()).
236+ Return (database.ExternalAuthLink {}, nil ).AnyTimes ()
237+
226238 const staticError = "static error"
227239 validated := false
228240 fake , config , link := setupOauth2Test (t , testConfig {
@@ -239,7 +251,7 @@ func TestRefreshToken(t *testing.T) {
239251 ctx := oidc .ClientContext (context .Background (), fake .HTTPClient (nil ))
240252 link .OAuthExpiry = expired
241253
242- _ , err := config .RefreshToken (ctx , nil , link )
254+ _ , err := config .RefreshToken (ctx , mDB , link )
243255 require .ErrorContains (t , err , "token failed to validate" )
244256 require .True (t , externalauth .IsInvalidTokenError (err ))
245257 require .True (t , validated , "token should have been attempted to be validated" )
@@ -380,6 +392,235 @@ func TestRefreshToken(t *testing.T) {
380392 require .True (t , ok )
381393 require .Equal (t , updated .OAuthAccessToken , mapping ["access_token" ])
382394 })
395+
396+ // SaveBeforeValidate tests that a successfully refreshed token is
397+ // persisted to the DB even when post-refresh validation fails. This
398+ // prevents the data-loss scenario where GitHub rotates the refresh
399+ // token on use but the new token is silently discarded because a
400+ // rate-limited validation endpoint returns 403.
401+ t .Run ("SaveBeforeValidate" , func (t * testing.T ) {
402+ t .Parallel ()
403+
404+ db , _ := dbtestutil .NewDB (t )
405+
406+ // simulateRateLimit controls whether the validate endpoint
407+ // returns 403 (true) or 200 (false).
408+ var simulateRateLimit atomic.Bool
409+ simulateRateLimit .Store (true )
410+
411+ var refreshCalls atomic.Int64
412+ fake , config , link := setupOauth2Test (t , testConfig {
413+ FakeIDPOpts : []oidctest.FakeIDPOpt {
414+ oidctest .WithRefresh (func (_ string ) error {
415+ refreshCalls .Add (1 )
416+ return nil
417+ }),
418+ oidctest .WithDynamicUserInfo (func (_ string ) (jwt.MapClaims , error ) {
419+ if simulateRateLimit .Load () {
420+ return jwt.MapClaims {}, oidctest .StatusError (http .StatusForbidden , xerrors .New ("rate limit exceeded" ))
421+ }
422+ return jwt.MapClaims {}, nil
423+ }),
424+ },
425+ ExternalAuthOpt : func (cfg * externalauth.Config ) {
426+ cfg .Type = codersdk .EnhancedExternalAuthProviderGitHub .String ()
427+ },
428+ DB : db ,
429+ })
430+
431+ ctx := oidc .ClientContext (context .Background (), fake .HTTPClient (nil ))
432+
433+ oldAccessToken := link .OAuthAccessToken
434+ oldRefreshToken := link .OAuthRefreshToken
435+
436+ // Expire the token to force a refresh.
437+ link .OAuthExpiry = expired
438+
439+ // First call: refresh succeeds, validation fails (403).
440+ _ , err := config .RefreshToken (ctx , db , link )
441+ require .Error (t , err , "expected error because validation returned 403" )
442+ require .True (t , externalauth .IsInvalidTokenError (err ))
443+ require .Equal (t , int64 (1 ), refreshCalls .Load (), "IDP refresh should have been called exactly once" )
444+
445+ // Critical assertion: the DB must contain the NEW tokens from the
446+ // successful refresh, not the old (now-stale) ones.
447+ dbLink , err := db .GetExternalAuthLink (context .Background (), database.GetExternalAuthLinkParams {
448+ ProviderID : link .ProviderID ,
449+ UserID : link .UserID ,
450+ })
451+ require .NoError (t , err )
452+ require .NotEqual (t , oldAccessToken , dbLink .OAuthAccessToken ,
453+ "DB should have the new access token from the successful refresh" )
454+ require .NotEqual (t , oldRefreshToken , dbLink .OAuthRefreshToken ,
455+ "DB should have the new refresh token (old one was rotated by the IDP)" )
456+
457+ // Second call: uses the saved token from DB, no re-refresh.
458+ // The saved token has a future expiry, so TokenSource should return
459+ // it without contacting the IDP. Validation should succeed now.
460+ simulateRateLimit .Store (false )
461+ updated , err := config .RefreshToken (ctx , db , dbLink )
462+ require .NoError (t , err , "second call should succeed because rate limit lifted" )
463+ require .Equal (t , int64 (1 ), refreshCalls .Load (),
464+ "IDP refresh should NOT have been called again; the saved token is not expired" )
465+ require .Equal (t , dbLink .OAuthAccessToken , updated .OAuthAccessToken ,
466+ "returned token should match what was saved in the DB" )
467+ })
468+
469+ // SaveBeforeValidate_ContextCanceled verifies the early DB save
470+ // uses a detached context. The parent context is canceled inside
471+ // the refresh hook (after TokenSource.Token() but before the DB
472+ // write), and the test asserts the new token is still persisted.
473+ t .Run ("SaveBeforeValidate_ContextCanceled" , func (t * testing.T ) {
474+ t .Parallel ()
475+
476+ db , _ := dbtestutil .NewDB (t )
477+
478+ var refreshCalls atomic.Int64
479+ cancelOnRefresh , cancel := context .WithCancel (context .Background ())
480+ defer cancel ()
481+
482+ fake , config , link := setupOauth2Test (t , testConfig {
483+ FakeIDPOpts : []oidctest.FakeIDPOpt {
484+ oidctest .WithRefresh (func (_ string ) error {
485+ refreshCalls .Add (1 )
486+ // Cancel the parent context after refresh succeeds
487+ // but before the DB save and validation.
488+ cancel ()
489+ return nil
490+ }),
491+ oidctest .WithDynamicUserInfo (func (_ string ) (jwt.MapClaims , error ) {
492+ return jwt.MapClaims {}, nil
493+ }),
494+ },
495+ ExternalAuthOpt : func (cfg * externalauth.Config ) {
496+ cfg .Type = codersdk .EnhancedExternalAuthProviderGitHub .String ()
497+ },
498+ DB : db ,
499+ })
500+
501+ ctx := oidc .ClientContext (cancelOnRefresh , fake .HTTPClient (nil ))
502+
503+ oldAccessToken := link .OAuthAccessToken
504+ oldRefreshToken := link .OAuthRefreshToken
505+ link .OAuthExpiry = expired
506+
507+ _ , err := config .RefreshToken (ctx , db , link )
508+ require .NoError (t , err )
509+ require .Equal (t , int64 (1 ), refreshCalls .Load ())
510+
511+ dbLink , err := db .GetExternalAuthLink (context .Background (), database.GetExternalAuthLinkParams {
512+ ProviderID : link .ProviderID ,
513+ UserID : link .UserID ,
514+ })
515+ require .NoError (t , err )
516+ require .NotEqual (t , oldAccessToken , dbLink .OAuthAccessToken ,
517+ "DB should have the new access token despite context cancellation" )
518+ require .NotEqual (t , oldRefreshToken , dbLink .OAuthRefreshToken ,
519+ "DB should have the new refresh token despite context cancellation" )
520+ })
521+
522+ // SaveBeforeValidate_DBError tests that when the early DB save
523+ // fails after a successful IDP refresh, the error is surfaced
524+ // as a non-InvalidTokenError. This is a degraded state (token
525+ // issued by IDP but not persisted), and callers should see a
526+ // real error, not a "please re-authenticate" prompt.
527+ t .Run ("SaveBeforeValidate_DBError" , func (t * testing.T ) {
528+ t .Parallel ()
529+
530+ ctrl := gomock .NewController (t )
531+ mDB := dbmock .NewMockStore (ctrl )
532+
533+ fake , config , link := setupOauth2Test (t , testConfig {
534+ FakeIDPOpts : []oidctest.FakeIDPOpt {
535+ oidctest .WithRefresh (func (_ string ) error {
536+ return nil
537+ }),
538+ },
539+ ExternalAuthOpt : func (cfg * externalauth.Config ) {
540+ cfg .Type = codersdk .EnhancedExternalAuthProviderGitHub .String ()
541+ },
542+ })
543+
544+ ctx := oidc .ClientContext (context .Background (), fake .HTTPClient (nil ))
545+ link .OAuthExpiry = expired
546+
547+ mDB .EXPECT ().
548+ UpdateExternalAuthLink (gomock .Any (), gomock .Any ()).
549+ Return (database.ExternalAuthLink {}, xerrors .New ("db connection lost" ))
550+
551+ _ , err := config .RefreshToken (ctx , mDB , link )
552+ require .Error (t , err )
553+ require .Contains (t , err .Error (), "persist refreshed token" )
554+ require .False (t , externalauth .IsInvalidTokenError (err ),
555+ "DB errors should not be treated as invalid token" )
556+ })
557+
558+ // OptimisticLockPreventsStaleOverwrite verifies that the
559+ // UpdateExternalAuthLinkRefreshToken WHERE clause prevents a
560+ // stale caller from overwriting a valid refresh token saved
561+ // by a concurrent winner.
562+ t .Run ("OptimisticLockPreventsStaleOverwrite" , func (t * testing.T ) {
563+ t .Parallel ()
564+
565+ db , _ := dbtestutil .NewDB (t )
566+
567+ fake , config , link := setupOauth2Test (t , testConfig {
568+ FakeIDPOpts : []oidctest.FakeIDPOpt {
569+ oidctest .WithRefresh (func (_ string ) error {
570+ return nil
571+ }),
572+ oidctest .WithDynamicUserInfo (func (_ string ) (jwt.MapClaims , error ) {
573+ return jwt.MapClaims {}, nil
574+ }),
575+ },
576+ ExternalAuthOpt : func (cfg * externalauth.Config ) {
577+ cfg .Type = codersdk .EnhancedExternalAuthProviderGitHub .String ()
578+ },
579+ DB : db ,
580+ })
581+
582+ ctx := oidc .ClientContext (context .Background (), fake .HTTPClient (nil ))
583+
584+ // Snapshot the original tokens before any refresh.
585+ oldRefreshToken := link .OAuthRefreshToken
586+
587+ // Expire the token to force a refresh.
588+ link .OAuthExpiry = expired
589+
590+ // Caller A: refresh and save successfully.
591+ updated , err := config .RefreshToken (ctx , db , link )
592+ require .NoError (t , err )
593+ require .NotEqual (t , oldRefreshToken , updated .OAuthRefreshToken ,
594+ "caller A should have a new refresh token" )
595+
596+ // Caller B had a stale read of the original link. It tries to
597+ // destroy the refresh token using the OLD refresh token in the
598+ // optimistic lock. Because caller A already wrote a different
599+ // refresh token, this WHERE clause matches nothing.
600+ err = db .UpdateExternalAuthLinkRefreshToken (ctx , database.UpdateExternalAuthLinkRefreshTokenParams {
601+ OauthRefreshFailureReason : "simulated failure from stale caller B" ,
602+ OAuthRefreshToken : "" ,
603+ OAuthRefreshTokenKeyID : "" ,
604+ UpdatedAt : dbtime .Now (),
605+ ProviderID : link .ProviderID ,
606+ UserID : link .UserID ,
607+ OldOauthRefreshToken : oldRefreshToken ,
608+ })
609+ require .NoError (t , err , "optimistic lock write should not error, it is a no-op" )
610+
611+ // Verify DB still has caller A's valid token.
612+ dbLink , err := db .GetExternalAuthLink (context .Background (), database.GetExternalAuthLinkParams {
613+ ProviderID : link .ProviderID ,
614+ UserID : link .UserID ,
615+ })
616+ require .NoError (t , err )
617+ require .Equal (t , updated .OAuthAccessToken , dbLink .OAuthAccessToken ,
618+ "caller A's access token should still be in DB" )
619+ require .Equal (t , updated .OAuthRefreshToken , dbLink .OAuthRefreshToken ,
620+ "caller A's refresh token should still be in DB" )
621+ require .Empty (t , dbLink .OauthRefreshFailureReason ,
622+ "caller B's failure reason should not have been written" )
623+ })
383624}
384625
385626func TestRevokeToken (t * testing.T ) {
0 commit comments