Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1538,7 +1538,7 @@ func (s *MethodTestSuite) TestUser() {
}))
s.Run("UpdateExternalAuthLinkRefreshToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
link := testutil.Fake(s.T(), faker, database.ExternalAuthLink{})
arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt}
arg := database.UpdateExternalAuthLinkRefreshTokenParams{OAuthRefreshToken: "", OAuthRefreshTokenKeyID: "", ProviderID: link.ProviderID, UserID: link.UserID, UpdatedAt: link.UpdatedAt, OldOauthRefreshToken: link.OAuthRefreshToken}
dbm.EXPECT().GetExternalAuthLink(gomock.Any(), database.GetExternalAuthLinkParams{ProviderID: link.ProviderID, UserID: link.UserID}).Return(link, nil).AnyTimes()
dbm.EXPECT().UpdateExternalAuthLinkRefreshToken(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(link, policy.ActionUpdatePersonal)
Expand Down
4 changes: 4 additions & 0 deletions coderd/database/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion coderd/database/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions coderd/database/queries/externalauth.sql
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ UPDATE external_auth_links SET
WHERE provider_id = $1 AND user_id = $2 RETURNING *;

-- name: UpdateExternalAuthLinkRefreshToken :exec
-- Optimistic lock: only update the row if the refresh token in the database
-- still matches the one we read before attempting the refresh. This prevents
-- a concurrent caller that lost a token-refresh race from overwriting a valid
-- token stored by the winner.
UPDATE
external_auth_links
SET
Expand All @@ -60,6 +64,8 @@ WHERE
provider_id = @provider_id
AND
user_id = @user_id
AND
oauth_refresh_token = @old_oauth_refresh_token
AND
-- Required for sqlc to generate a parameter for the oauth_refresh_token_key_id
@oauth_refresh_token_key_id :: text = @oauth_refresh_token_key_id :: text;
5 changes: 3 additions & 2 deletions coderd/externalauth/externalauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ func IsInvalidTokenError(err error) bool {
}

// RefreshToken automatically refreshes the token if expired and permitted.
// If an error is returned, the token is either invalid, or an error occurred.
// Use 'IsInvalidTokenError(err)' to determine the difference.
func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, error) {
// If the token is expired and refresh is disabled, we prompt
// the user to authenticate again.
Expand Down Expand Up @@ -195,6 +193,9 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
UpdatedAt: dbtime.Now(),
ProviderID: externalAuthLink.ProviderID,
UserID: externalAuthLink.UserID,
// Optimistic lock: only clear the token if it hasn't been
// updated by a concurrent caller that won the refresh race.
OldOauthRefreshToken: externalAuthLink.OAuthRefreshToken,
})
if dbExecErr != nil {
// This error should be rare.
Expand Down
3 changes: 2 additions & 1 deletion coderd/externalauth/externalauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func TestRefreshToken(t *testing.T) {

// Zero time used
link.OAuthExpiry = time.Time{}

_, err := config.RefreshToken(ctx, nil, link)
require.NoError(t, err)
require.True(t, validated, "token should have been validated")
Expand All @@ -107,6 +108,7 @@ func TestRefreshToken(t *testing.T) {
},
},
}

_, err := config.RefreshToken(context.Background(), nil, database.ExternalAuthLink{
OAuthExpiry: expired,
})
Expand Down Expand Up @@ -344,7 +346,6 @@ func TestRefreshToken(t *testing.T) {
require.NoError(t, err)
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
})

t.Run("WithExtra", func(t *testing.T) {
t.Parallel()

Expand Down
33 changes: 33 additions & 0 deletions enterprise/dbcrypt/dbcrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,39 @@ func (db *dbCrypt) UpdateExternalAuthLink(ctx context.Context, params database.U
}

func (db *dbCrypt) UpdateExternalAuthLinkRefreshToken(ctx context.Context, params database.UpdateExternalAuthLinkRefreshTokenParams) error {
// The SQL query uses an optimistic lock:
// WHERE oauth_refresh_token = @old_oauth_refresh_token
// The caller supplies the plaintext old token (since dbcrypt
// decrypts on read), but the DB stores the encrypted value.
// Because AES-GCM is non-deterministic, we cannot simply
// re-encrypt the old token — the ciphertext would differ.
// Instead, read the current row from the inner (raw) store
// and use the actual encrypted value for the WHERE clause.
if params.OldOauthRefreshToken != "" && db.ciphers != nil && db.primaryCipherDigest != "" {
raw, err := db.Store.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
ProviderID: params.ProviderID,
UserID: params.UserID,
})
if err != nil {
return err
}
// Decrypt the stored token so we can compare with the
// caller-supplied plaintext.
decrypted := raw.OAuthRefreshToken
if err := db.decryptField(&decrypted, raw.OAuthRefreshTokenKeyID); err != nil {
return err
}
if decrypted != params.OldOauthRefreshToken {
// The token has changed since the caller read it;
// the optimistic lock should fail (no rows updated).
// Return nil to match the :exec semantics of the SQL
// query, which silently updates zero rows.
return nil
}
// Use the raw encrypted value so the WHERE clause matches.
params.OldOauthRefreshToken = raw.OAuthRefreshToken
}

// We would normally use a sql.NullString here, but sqlc does not want to make
// a params struct with a nullable string.
var digest sql.NullString
Expand Down
1 change: 1 addition & 0 deletions enterprise/dbcrypt/dbcrypt_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func TestUserLinks(t *testing.T) {
err := crypt.UpdateExternalAuthLinkRefreshToken(ctx, database.UpdateExternalAuthLinkRefreshTokenParams{
OAuthRefreshToken: "",
OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID.String,
OldOauthRefreshToken: link.OAuthRefreshToken,
UpdatedAt: dbtime.Now(),
ProviderID: link.ProviderID,
UserID: link.UserID,
Expand Down
Loading