Skip to content

Commit 4bb1fd1

Browse files
f0sselkylecarbsmafredri
authored
fix(coderd/externalauth): save refreshed token before validation (#24332) (backport to 2.29) (#24900)
Backport of #24332 to `release/2.29`. Moves the `UpdateExternalAuthLink` call to immediately after `TokenSource.Token()` succeeds (before validation). GitHub rotates refresh tokens on use, so if post-refresh validation fails (e.g. rate-limited 403), the new token was previously silently discarded, forcing manual re-authentication. Original PR: #24332 Merge commit: 2a1984f **Note:** This branch includes the cherry-pick of #22904 (optimistic locking) as a prerequisite since #24332's tests depend on it. The #22904 backport PR is #24901. Once that merges, the overlapping commit in this PR will be a no-op. Cherry-picks applied cleanly with no conflicts. > Generated by Coder Agents --------- Co-authored-by: Kyle Carberry <kyle@coder.com> Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
1 parent 63a9280 commit 4bb1fd1

2 files changed

Lines changed: 284 additions & 29 deletions

File tree

coderd/externalauth/externalauth.go

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,37 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
238238
return externalAuthLink, xerrors.Errorf("generate token extra: %w", err)
239239
}
240240

241+
// Persist the refreshed token to the DB before validation. GitHub
242+
// rotates refresh tokens on every use, so the old refresh token is
243+
// already invalid on the IDP side. If we validated first and the
244+
// validation endpoint was unavailable (e.g. rate-limited 403), the
245+
// new token would be silently lost and the user would be forced to
246+
// re-authenticate manually.
247+
// Use a detached context for the DB write only. The IDP already
248+
// consumed the old refresh token, so if the caller's request
249+
// context is canceled mid-save, the new token would be lost.
250+
persistCtx, persistCancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second)
251+
defer persistCancel()
252+
253+
originalAccessToken := externalAuthLink.OAuthAccessToken
254+
if token.AccessToken != originalAccessToken {
255+
updatedAuthLink, err := db.UpdateExternalAuthLink(persistCtx, database.UpdateExternalAuthLinkParams{
256+
ProviderID: c.ID,
257+
UserID: externalAuthLink.UserID,
258+
UpdatedAt: dbtime.Now(),
259+
OAuthAccessToken: token.AccessToken,
260+
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
261+
OAuthRefreshToken: token.RefreshToken,
262+
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
263+
OAuthExpiry: token.Expiry,
264+
OAuthExtra: extra,
265+
})
266+
if err != nil {
267+
return updatedAuthLink, xerrors.Errorf("persist refreshed token: %w", err)
268+
}
269+
externalAuthLink = updatedAuthLink
270+
}
271+
241272
r := retry.New(50*time.Millisecond, 200*time.Millisecond)
242273
// See the comment below why the retry and cancel is required.
243274
retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second)
@@ -262,35 +293,18 @@ validate:
262293
return externalAuthLink, InvalidTokenError("token failed to validate")
263294
}
264295

265-
if token.AccessToken != externalAuthLink.OAuthAccessToken {
266-
updatedAuthLink, err := db.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{
267-
ProviderID: c.ID,
268-
UserID: externalAuthLink.UserID,
269-
UpdatedAt: dbtime.Now(),
270-
OAuthAccessToken: token.AccessToken,
271-
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
272-
OAuthRefreshToken: token.RefreshToken,
273-
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
274-
OAuthExpiry: token.Expiry,
275-
OAuthExtra: extra,
296+
// Update the associated user's github.com user ID if the token
297+
// is for github.com and validation returned user info.
298+
if token.AccessToken != originalAccessToken && IsGithubDotComURL(c.AuthCodeURL("")) && user != nil {
299+
err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{
300+
ID: externalAuthLink.UserID,
301+
GithubComUserID: sql.NullInt64{
302+
Int64: user.ID,
303+
Valid: true,
304+
},
276305
})
277306
if err != nil {
278-
return updatedAuthLink, xerrors.Errorf("update external auth link: %w", err)
279-
}
280-
externalAuthLink = updatedAuthLink
281-
282-
// Update the associated users github.com username if the token is for github.com.
283-
if IsGithubDotComURL(c.AuthCodeURL("")) && user != nil {
284-
err = db.UpdateUserGithubComUserID(ctx, database.UpdateUserGithubComUserIDParams{
285-
ID: externalAuthLink.UserID,
286-
GithubComUserID: sql.NullInt64{
287-
Int64: user.ID,
288-
Valid: true,
289-
},
290-
})
291-
if err != nil {
292-
return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err)
293-
}
307+
return externalAuthLink, xerrors.Errorf("update user github com user id: %w", err)
294308
}
295309
}
296310

coderd/externalauth/externalauth_test.go

Lines changed: 243 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http/httptest"
99
"net/url"
1010
"strings"
11+
"sync/atomic"
1112
"testing"
1213
"time"
1314

@@ -27,6 +28,7 @@ import (
2728
"github.com/coder/coder/v2/coderd/database/dbauthz"
2829
"github.com/coder/coder/v2/coderd/database/dbmock"
2930
"github.com/coder/coder/v2/coderd/database/dbtestutil"
31+
"github.com/coder/coder/v2/coderd/database/dbtime"
3032
"github.com/coder/coder/v2/coderd/externalauth"
3133
"github.com/coder/coder/v2/coderd/promoauth"
3234
"github.com/coder/coder/v2/codersdk"
@@ -120,6 +122,11 @@ func TestRefreshToken(t *testing.T) {
120122
t.Run("ValidateServerError", func(t *testing.T) {
121123
t.Parallel()
122124

125+
ctrl := gomock.NewController(t)
126+
mDB := dbmock.NewMockStore(ctrl)
127+
mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
128+
Return(database.ExternalAuthLink{}, nil).AnyTimes()
129+
123130
const staticError = "static error"
124131
validated := false
125132
fake, config, link := setupOauth2Test(t, testConfig{
@@ -136,7 +143,7 @@ func TestRefreshToken(t *testing.T) {
136143
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
137144
link.OAuthExpiry = expired
138145

139-
_, err := config.RefreshToken(ctx, nil, link)
146+
_, err := config.RefreshToken(ctx, mDB, link)
140147
require.ErrorContains(t, err, staticError)
141148
// Unsure if this should be the correct behavior. It's an invalid token because
142149
// 'ValidateToken()' failed with a runtime error. This was the previous behavior,
@@ -223,6 +230,11 @@ func TestRefreshToken(t *testing.T) {
223230
t.Run("ValidateFailure", func(t *testing.T) {
224231
t.Parallel()
225232

233+
ctrl := gomock.NewController(t)
234+
mDB := dbmock.NewMockStore(ctrl)
235+
mDB.EXPECT().UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
236+
Return(database.ExternalAuthLink{}, nil).AnyTimes()
237+
226238
const staticError = "static error"
227239
validated := false
228240
fake, config, link := setupOauth2Test(t, testConfig{
@@ -239,7 +251,7 @@ func TestRefreshToken(t *testing.T) {
239251
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
240252
link.OAuthExpiry = expired
241253

242-
_, err := config.RefreshToken(ctx, nil, link)
254+
_, err := config.RefreshToken(ctx, mDB, link)
243255
require.ErrorContains(t, err, "token failed to validate")
244256
require.True(t, externalauth.IsInvalidTokenError(err))
245257
require.True(t, validated, "token should have been attempted to be validated")
@@ -380,6 +392,235 @@ func TestRefreshToken(t *testing.T) {
380392
require.True(t, ok)
381393
require.Equal(t, updated.OAuthAccessToken, mapping["access_token"])
382394
})
395+
396+
// SaveBeforeValidate tests that a successfully refreshed token is
397+
// persisted to the DB even when post-refresh validation fails. This
398+
// prevents the data-loss scenario where GitHub rotates the refresh
399+
// token on use but the new token is silently discarded because a
400+
// rate-limited validation endpoint returns 403.
401+
t.Run("SaveBeforeValidate", func(t *testing.T) {
402+
t.Parallel()
403+
404+
db, _ := dbtestutil.NewDB(t)
405+
406+
// simulateRateLimit controls whether the validate endpoint
407+
// returns 403 (true) or 200 (false).
408+
var simulateRateLimit atomic.Bool
409+
simulateRateLimit.Store(true)
410+
411+
var refreshCalls atomic.Int64
412+
fake, config, link := setupOauth2Test(t, testConfig{
413+
FakeIDPOpts: []oidctest.FakeIDPOpt{
414+
oidctest.WithRefresh(func(_ string) error {
415+
refreshCalls.Add(1)
416+
return nil
417+
}),
418+
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
419+
if simulateRateLimit.Load() {
420+
return jwt.MapClaims{}, oidctest.StatusError(http.StatusForbidden, xerrors.New("rate limit exceeded"))
421+
}
422+
return jwt.MapClaims{}, nil
423+
}),
424+
},
425+
ExternalAuthOpt: func(cfg *externalauth.Config) {
426+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
427+
},
428+
DB: db,
429+
})
430+
431+
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
432+
433+
oldAccessToken := link.OAuthAccessToken
434+
oldRefreshToken := link.OAuthRefreshToken
435+
436+
// Expire the token to force a refresh.
437+
link.OAuthExpiry = expired
438+
439+
// First call: refresh succeeds, validation fails (403).
440+
_, err := config.RefreshToken(ctx, db, link)
441+
require.Error(t, err, "expected error because validation returned 403")
442+
require.True(t, externalauth.IsInvalidTokenError(err))
443+
require.Equal(t, int64(1), refreshCalls.Load(), "IDP refresh should have been called exactly once")
444+
445+
// Critical assertion: the DB must contain the NEW tokens from the
446+
// successful refresh, not the old (now-stale) ones.
447+
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
448+
ProviderID: link.ProviderID,
449+
UserID: link.UserID,
450+
})
451+
require.NoError(t, err)
452+
require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken,
453+
"DB should have the new access token from the successful refresh")
454+
require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken,
455+
"DB should have the new refresh token (old one was rotated by the IDP)")
456+
457+
// Second call: uses the saved token from DB, no re-refresh.
458+
// The saved token has a future expiry, so TokenSource should return
459+
// it without contacting the IDP. Validation should succeed now.
460+
simulateRateLimit.Store(false)
461+
updated, err := config.RefreshToken(ctx, db, dbLink)
462+
require.NoError(t, err, "second call should succeed because rate limit lifted")
463+
require.Equal(t, int64(1), refreshCalls.Load(),
464+
"IDP refresh should NOT have been called again; the saved token is not expired")
465+
require.Equal(t, dbLink.OAuthAccessToken, updated.OAuthAccessToken,
466+
"returned token should match what was saved in the DB")
467+
})
468+
469+
// SaveBeforeValidate_ContextCanceled verifies the early DB save
470+
// uses a detached context. The parent context is canceled inside
471+
// the refresh hook (after TokenSource.Token() but before the DB
472+
// write), and the test asserts the new token is still persisted.
473+
t.Run("SaveBeforeValidate_ContextCanceled", func(t *testing.T) {
474+
t.Parallel()
475+
476+
db, _ := dbtestutil.NewDB(t)
477+
478+
var refreshCalls atomic.Int64
479+
cancelOnRefresh, cancel := context.WithCancel(context.Background())
480+
defer cancel()
481+
482+
fake, config, link := setupOauth2Test(t, testConfig{
483+
FakeIDPOpts: []oidctest.FakeIDPOpt{
484+
oidctest.WithRefresh(func(_ string) error {
485+
refreshCalls.Add(1)
486+
// Cancel the parent context after refresh succeeds
487+
// but before the DB save and validation.
488+
cancel()
489+
return nil
490+
}),
491+
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
492+
return jwt.MapClaims{}, nil
493+
}),
494+
},
495+
ExternalAuthOpt: func(cfg *externalauth.Config) {
496+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
497+
},
498+
DB: db,
499+
})
500+
501+
ctx := oidc.ClientContext(cancelOnRefresh, fake.HTTPClient(nil))
502+
503+
oldAccessToken := link.OAuthAccessToken
504+
oldRefreshToken := link.OAuthRefreshToken
505+
link.OAuthExpiry = expired
506+
507+
_, err := config.RefreshToken(ctx, db, link)
508+
require.NoError(t, err)
509+
require.Equal(t, int64(1), refreshCalls.Load())
510+
511+
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
512+
ProviderID: link.ProviderID,
513+
UserID: link.UserID,
514+
})
515+
require.NoError(t, err)
516+
require.NotEqual(t, oldAccessToken, dbLink.OAuthAccessToken,
517+
"DB should have the new access token despite context cancellation")
518+
require.NotEqual(t, oldRefreshToken, dbLink.OAuthRefreshToken,
519+
"DB should have the new refresh token despite context cancellation")
520+
})
521+
522+
// SaveBeforeValidate_DBError tests that when the early DB save
523+
// fails after a successful IDP refresh, the error is surfaced
524+
// as a non-InvalidTokenError. This is a degraded state (token
525+
// issued by IDP but not persisted), and callers should see a
526+
// real error, not a "please re-authenticate" prompt.
527+
t.Run("SaveBeforeValidate_DBError", func(t *testing.T) {
528+
t.Parallel()
529+
530+
ctrl := gomock.NewController(t)
531+
mDB := dbmock.NewMockStore(ctrl)
532+
533+
fake, config, link := setupOauth2Test(t, testConfig{
534+
FakeIDPOpts: []oidctest.FakeIDPOpt{
535+
oidctest.WithRefresh(func(_ string) error {
536+
return nil
537+
}),
538+
},
539+
ExternalAuthOpt: func(cfg *externalauth.Config) {
540+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
541+
},
542+
})
543+
544+
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
545+
link.OAuthExpiry = expired
546+
547+
mDB.EXPECT().
548+
UpdateExternalAuthLink(gomock.Any(), gomock.Any()).
549+
Return(database.ExternalAuthLink{}, xerrors.New("db connection lost"))
550+
551+
_, err := config.RefreshToken(ctx, mDB, link)
552+
require.Error(t, err)
553+
require.Contains(t, err.Error(), "persist refreshed token")
554+
require.False(t, externalauth.IsInvalidTokenError(err),
555+
"DB errors should not be treated as invalid token")
556+
})
557+
558+
// OptimisticLockPreventsStaleOverwrite verifies that the
559+
// UpdateExternalAuthLinkRefreshToken WHERE clause prevents a
560+
// stale caller from overwriting a valid refresh token saved
561+
// by a concurrent winner.
562+
t.Run("OptimisticLockPreventsStaleOverwrite", func(t *testing.T) {
563+
t.Parallel()
564+
565+
db, _ := dbtestutil.NewDB(t)
566+
567+
fake, config, link := setupOauth2Test(t, testConfig{
568+
FakeIDPOpts: []oidctest.FakeIDPOpt{
569+
oidctest.WithRefresh(func(_ string) error {
570+
return nil
571+
}),
572+
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
573+
return jwt.MapClaims{}, nil
574+
}),
575+
},
576+
ExternalAuthOpt: func(cfg *externalauth.Config) {
577+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
578+
},
579+
DB: db,
580+
})
581+
582+
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
583+
584+
// Snapshot the original tokens before any refresh.
585+
oldRefreshToken := link.OAuthRefreshToken
586+
587+
// Expire the token to force a refresh.
588+
link.OAuthExpiry = expired
589+
590+
// Caller A: refresh and save successfully.
591+
updated, err := config.RefreshToken(ctx, db, link)
592+
require.NoError(t, err)
593+
require.NotEqual(t, oldRefreshToken, updated.OAuthRefreshToken,
594+
"caller A should have a new refresh token")
595+
596+
// Caller B had a stale read of the original link. It tries to
597+
// destroy the refresh token using the OLD refresh token in the
598+
// optimistic lock. Because caller A already wrote a different
599+
// refresh token, this WHERE clause matches nothing.
600+
err = db.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
601+
OauthRefreshFailureReason: "simulated failure from stale caller B",
602+
OAuthRefreshToken: "",
603+
OAuthRefreshTokenKeyID: "",
604+
UpdatedAt: dbtime.Now(),
605+
ProviderID: link.ProviderID,
606+
UserID: link.UserID,
607+
OldOauthRefreshToken: oldRefreshToken,
608+
})
609+
require.NoError(t, err, "optimistic lock write should not error, it is a no-op")
610+
611+
// Verify DB still has caller A's valid token.
612+
dbLink, err := db.GetExternalAuthLink(context.Background(), database.GetExternalAuthLinkParams{
613+
ProviderID: link.ProviderID,
614+
UserID: link.UserID,
615+
})
616+
require.NoError(t, err)
617+
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken,
618+
"caller A's access token should still be in DB")
619+
require.Equal(t, updated.OAuthRefreshToken, dbLink.OAuthRefreshToken,
620+
"caller A's refresh token should still be in DB")
621+
require.Empty(t, dbLink.OauthRefreshFailureReason,
622+
"caller B's failure reason should not have been written")
623+
})
383624
}
384625

385626
func TestRevokeToken(t *testing.T) {

0 commit comments

Comments
 (0)