diff --git a/apps/sim/app/api/auth/oauth/utils.test.ts b/apps/sim/app/api/auth/oauth/utils.test.ts index 7f67d37673a..78fdecf2975 100644 --- a/apps/sim/app/api/auth/oauth/utils.test.ts +++ b/apps/sim/app/api/auth/oauth/utils.test.ts @@ -4,6 +4,7 @@ * @vitest-environment node */ +import { redisConfigMock, redisConfigMockFns } from '@sim/testing' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' vi.mock('@/lib/oauth/oauth', () => ({ @@ -11,7 +12,10 @@ vi.mock('@/lib/oauth/oauth', () => ({ OAUTH_PROVIDERS: {}, })) +vi.mock('@/lib/core/config/redis', () => redisConfigMock) + import { db } from '@sim/db' +import { __resetCoalesceLocallyForTests } from '@/lib/concurrency/singleflight' import { refreshOAuthToken } from '@/lib/oauth' import { getCredential, @@ -49,6 +53,10 @@ function mockUpdateChain() { describe('OAuth Utils', () => { beforeEach(() => { vi.clearAllMocks() + __resetCoalesceLocallyForTests() + redisConfigMockFns.mockGetRedisClient.mockReturnValue(null) + redisConfigMockFns.mockAcquireLock.mockResolvedValue(true) + redisConfigMockFns.mockReleaseLock.mockResolvedValue(true) }) afterEach(() => { @@ -107,6 +115,7 @@ describe('OAuth Utils', () => { } mockRefreshOAuthToken.mockResolvedValueOnce({ + ok: true, accessToken: 'new-token', expiresIn: 3600, refreshToken: 'new-refresh-token', @@ -130,7 +139,11 @@ describe('OAuth Utils', () => { providerId: 'google', } - mockRefreshOAuthToken.mockResolvedValueOnce(null) + mockRefreshOAuthToken.mockResolvedValueOnce({ + ok: false, + errorCode: 'invalid_grant', + message: 'Failed', + }) await expect( refreshTokenIfNeeded('request-id', mockCredential, 'credential-id') @@ -198,6 +211,7 @@ describe('OAuth Utils', () => { mockUpdateChain() mockRefreshOAuthToken.mockResolvedValueOnce({ + ok: true, accessToken: 'new-token', expiresIn: 3600, refreshToken: 'new-refresh-token', @@ -237,7 +251,11 @@ describe('OAuth Utils', () => { mockSelectChain([mockResolvedCredential]) mockSelectChain([mockAccountRow]) - mockRefreshOAuthToken.mockResolvedValueOnce(null) + mockRefreshOAuthToken.mockResolvedValueOnce({ + ok: false, + errorCode: 'invalid_grant', + message: 'Failed', + }) const token = await refreshAccessTokenIfNeeded('credential-id', 'test-user-id', 'request-id') diff --git a/apps/sim/app/api/auth/oauth/utils.ts b/apps/sim/app/api/auth/oauth/utils.ts index 4109441528d..bbfdb0135be 100644 --- a/apps/sim/app/api/auth/oauth/utils.ts +++ b/apps/sim/app/api/auth/oauth/utils.ts @@ -4,6 +4,8 @@ import { account, credential, credentialSetMember } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import { and, desc, eq, inArray } from 'drizzle-orm' +import { withLeaderLock } from '@/lib/concurrency/leader-lock' +import { coalesceLocally } from '@/lib/concurrency/singleflight' import { decryptSecret } from '@/lib/core/security/encryption' import { refreshOAuthToken } from '@/lib/oauth' import { @@ -11,6 +13,11 @@ import { isMicrosoftProvider, PROACTIVE_REFRESH_THRESHOLD_DAYS, } from '@/lib/oauth/microsoft' +import { + getRecentTerminalError, + isTerminalRefreshError, + markCredentialDead, +} from '@/lib/oauth/terminal-errors' import { ATLASSIAN_SERVICE_ACCOUNT_PROVIDER_ID, ATLASSIAN_SERVICE_ACCOUNT_SECRET_TYPE, @@ -318,6 +325,112 @@ export async function getCredential(requestId: string, credentialId: string, use return getCredentialByAccountId(requestId, resolved.accountId, userId) } +interface CoalescedRefreshOptions { + accountId: string + providerId: string + refreshToken: string + requestId?: string + userId?: string +} + +async function performCoalescedRefresh({ + accountId, + providerId, + refreshToken, + requestId, + userId, +}: CoalescedRefreshOptions): Promise { + const logContext = { + ...(requestId ? { requestId } : {}), + ...(userId ? { userId } : {}), + providerId, + accountId, + } + + const deadCode = await getRecentTerminalError(accountId) + if (deadCode) { + logger.warn('Skipping refresh: credential recently failed', { + ...logContext, + errorCode: deadCode, + }) + return null + } + + const lockKey = `oauth:refresh:${accountId}` + + return coalesceLocally(lockKey, () => + withLeaderLock({ + key: lockKey, + onLeader: async () => { + try { + const result = await refreshOAuthToken(providerId, refreshToken) + + if (!result.ok) { + logger.error('Failed to refresh token', { + ...logContext, + errorCode: result.errorCode, + }) + if (result.errorCode && isTerminalRefreshError(result.errorCode)) { + await markCredentialDead(accountId, result.errorCode) + } + return null + } + + const updateData: Record = { + accessToken: result.accessToken, + accessTokenExpiresAt: new Date(Date.now() + result.expiresIn * 1000), + updatedAt: new Date(), + } + if (result.refreshToken && result.refreshToken !== refreshToken) { + updateData.refreshToken = result.refreshToken + } + if (isMicrosoftProvider(providerId)) { + updateData.refreshTokenExpiresAt = getMicrosoftRefreshTokenExpiry() + } + + await db.update(account).set(updateData).where(eq(account.id, accountId)) + + logger.info('Successfully refreshed access token', logContext) + return result.accessToken + } catch (error) { + logger.error('Refresh failed inside leader path', { + ...logContext, + error: toError(error).message, + }) + return null + } + }, + onFollower: async () => { + try { + const [row] = await db + .select({ + accessToken: account.accessToken, + accessTokenExpiresAt: account.accessTokenExpiresAt, + }) + .from(account) + .where(eq(account.id, accountId)) + .limit(1) + if ( + row?.accessToken && + row.accessTokenExpiresAt && + row.accessTokenExpiresAt > new Date() + ) { + logger.info('Got fresh access token from coalesced refresh', logContext) + return row.accessToken + } + return null + } catch (error) { + logger.warn('Follower DB read failed during refresh poll', { + ...logContext, + error: toError(error).message, + }) + return null + } + }, + }) + ) +} + export async function getOAuthToken(userId: string, providerId: string): Promise { const connections = await db .select({ @@ -347,52 +460,12 @@ export async function getOAuthToken(userId: string, providerId: string): Promise !!credential.refreshToken && (!credential.accessToken || (tokenExpiry && tokenExpiry < now)) if (shouldAttemptRefresh) { - logger.info( - `Access token expired for user ${userId}, provider ${providerId}. Attempting to refresh.` - ) - - try { - // Use the existing refreshOAuthToken function - const refreshResult = await refreshOAuthToken(providerId, credential.refreshToken!) - - if (!refreshResult) { - logger.error(`Failed to refresh token for user ${userId}, provider ${providerId}`, { - providerId, - userId, - hasRefreshToken: !!credential.refreshToken, - }) - return null - } - - const { accessToken, expiresIn, refreshToken: newRefreshToken } = refreshResult - - // Update the database with new tokens - const updateData: any = { - accessToken, - accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Convert seconds to milliseconds - updatedAt: new Date(), - } - - // If we received a new refresh token (some providers like Airtable rotate them), save it - if (newRefreshToken && newRefreshToken !== credential.refreshToken) { - logger.info(`Updating refresh token for user ${userId}, provider ${providerId}`) - updateData.refreshToken = newRefreshToken - } - - // Update the token in the database with the actual expiration time from the provider - await db.update(account).set(updateData).where(eq(account.id, credential.id)) - - logger.info(`Successfully refreshed token for user ${userId}, provider ${providerId}`) - return accessToken - } catch (error) { - logger.error(`Error refreshing token for user ${userId}, provider ${providerId}`, { - error: toError(error).message, - stack: error instanceof Error ? error.stack : undefined, - providerId, - userId, - }) - return null - } + return performCoalescedRefresh({ + accountId: credential.id, + providerId, + refreshToken: credential.refreshToken!, + userId, + }) } if (!credential.accessToken) { @@ -472,66 +545,27 @@ export async function refreshAccessTokenIfNeeded( const accessToken = credential.accessToken if (shouldRefresh) { - logger.info(`[${requestId}] Refreshing token for credential`) - try { - const refreshedToken = await refreshOAuthToken( - credential.providerId, - credential.refreshToken! - ) - - if (!refreshedToken) { - logger.error(`[${requestId}] Failed to refresh token for credential: ${credentialId}`, { - credentialId, - providerId: credential.providerId, - userId: credential.userId, - hasRefreshToken: !!credential.refreshToken, - }) - if (!accessTokenNeedsRefresh && accessToken) { - logger.info(`[${requestId}] Proactive refresh failed but access token still valid`) - return accessToken - } - return null - } + const resolvedCredentialId = + (credential as { resolvedCredentialId?: string }).resolvedCredentialId ?? credentialId - // Prepare update data - const updateData: Record = { - accessToken: refreshedToken.accessToken, - accessTokenExpiresAt: new Date(Date.now() + refreshedToken.expiresIn * 1000), - updatedAt: new Date(), - } - - // If we received a new refresh token, update it - if (refreshedToken.refreshToken && refreshedToken.refreshToken !== credential.refreshToken) { - logger.info(`[${requestId}] Updating refresh token for credential`) - updateData.refreshToken = refreshedToken.refreshToken - } - - if (isMicrosoftProvider(credential.providerId)) { - updateData.refreshTokenExpiresAt = getMicrosoftRefreshTokenExpiry() - } + const fresh = await performCoalescedRefresh({ + accountId: resolvedCredentialId, + providerId: credential.providerId, + refreshToken: credential.refreshToken!, + requestId, + userId: credential.userId, + }) + if (fresh) return fresh - // Update the token in the database - const resolvedCredentialId = - (credential as { resolvedCredentialId?: string }).resolvedCredentialId ?? credentialId - await db.update(account).set(updateData).where(eq(account.id, resolvedCredentialId)) - - logger.info(`[${requestId}] Successfully refreshed access token for credential`) - return refreshedToken.accessToken - } catch (error) { - logger.error(`[${requestId}] Error refreshing token for credential`, { - error: toError(error).message, - stack: error instanceof Error ? error.stack : undefined, - providerId: credential.providerId, - credentialId, - userId: credential.userId, - }) - if (!accessTokenNeedsRefresh && accessToken) { - logger.info(`[${requestId}] Proactive refresh failed but access token still valid`) - return accessToken - } - return null + // If refresh was only triggered proactively (Microsoft refresh-token aging), + // the still-valid access token is a fine fallback. + if (!accessTokenNeedsRefresh && accessToken) { + logger.info(`[${requestId}] Refresh unavailable; reusing still-valid access token`) + return accessToken } - } else if (!accessToken) { + return null + } + if (!accessToken) { // We have no access token and either no refresh token or not eligible to refresh logger.error(`[${requestId}] Missing access token for credential`) return null @@ -580,65 +614,20 @@ export async function refreshTokenIfNeeded( return { accessToken: credential.accessToken, refreshed: false } } - try { - const refreshResult = await refreshOAuthToken(credential.providerId, credential.refreshToken!) - - if (!refreshResult) { - logger.error(`[${requestId}] Failed to refresh token for credential`) - if (!accessTokenNeedsRefresh && credential.accessToken) { - logger.info(`[${requestId}] Proactive refresh failed but access token still valid`) - return { accessToken: credential.accessToken, refreshed: false } - } - throw new Error('Failed to refresh token') - } - - const { accessToken: refreshedToken, expiresIn, refreshToken: newRefreshToken } = refreshResult - - // Prepare update data - const updateData: Record = { - accessToken: refreshedToken, - accessTokenExpiresAt: new Date(Date.now() + expiresIn * 1000), // Use provider's expiry - updatedAt: new Date(), - } - - // If we received a new refresh token, update it - if (newRefreshToken && newRefreshToken !== credential.refreshToken) { - logger.info(`[${requestId}] Updating refresh token`) - updateData.refreshToken = newRefreshToken - } - - if (isMicrosoftProvider(credential.providerId)) { - updateData.refreshTokenExpiresAt = getMicrosoftRefreshTokenExpiry() - } - - await db.update(account).set(updateData).where(eq(account.id, resolvedCredentialId)) - - logger.info(`[${requestId}] Successfully refreshed access token`) - return { accessToken: refreshedToken, refreshed: true } - } catch (error) { - logger.warn( - `[${requestId}] Refresh attempt failed, checking if another concurrent request succeeded` - ) - - const freshCredential = await getCredential(requestId, resolvedCredentialId, credential.userId) - if (freshCredential?.accessToken) { - const freshExpiresAt = freshCredential.accessTokenExpiresAt - const stillValid = !freshExpiresAt || freshExpiresAt > new Date() - - if (stillValid) { - logger.info(`[${requestId}] Found valid token from concurrent refresh, using it`) - return { accessToken: freshCredential.accessToken, refreshed: true } - } - } - - if (!accessTokenNeedsRefresh && credential.accessToken) { - logger.info(`[${requestId}] Proactive refresh failed but access token still valid`) - return { accessToken: credential.accessToken, refreshed: false } - } + const fresh = await performCoalescedRefresh({ + accountId: resolvedCredentialId, + providerId: credential.providerId, + refreshToken: credential.refreshToken!, + requestId, + userId: credential.userId, + }) + if (fresh) return { accessToken: fresh, refreshed: true } - logger.error(`[${requestId}] Refresh failed and no valid token found in DB`, error) - throw error + if (!accessTokenNeedsRefresh && credential.accessToken) { + logger.info(`[${requestId}] Refresh unavailable; reusing still-valid access token`) + return { accessToken: credential.accessToken, refreshed: false } } + throw new Error('Failed to refresh token') } export interface CredentialSetCredential { @@ -701,32 +690,13 @@ export async function getCredentialsForCredentialSet( let accessToken = cred.accessToken if (shouldRefresh && cred.refreshToken) { - try { - const refreshResult = await refreshOAuthToken(providerId, cred.refreshToken) - - if (refreshResult) { - accessToken = refreshResult.accessToken - - const updateData: Record = { - accessToken: refreshResult.accessToken, - accessTokenExpiresAt: new Date(Date.now() + refreshResult.expiresIn * 1000), - updatedAt: new Date(), - } - - if (refreshResult.refreshToken && refreshResult.refreshToken !== cred.refreshToken) { - updateData.refreshToken = refreshResult.refreshToken - } - - await db.update(account).set(updateData).where(eq(account.id, cred.id)) - - logger.info(`Refreshed token for user ${cred.userId}, provider ${providerId}`) - } - } catch (error) { - logger.error(`Failed to refresh token for user ${cred.userId}, provider ${providerId}`, { - error: toError(error).message, - }) - continue - } + const fresh = await performCoalescedRefresh({ + accountId: cred.id, + providerId, + refreshToken: cred.refreshToken, + userId: cred.userId, + }) + if (fresh) accessToken = fresh } if (accessToken) { diff --git a/apps/sim/lib/concurrency/__tests__/leader-lock.test.ts b/apps/sim/lib/concurrency/__tests__/leader-lock.test.ts new file mode 100644 index 00000000000..8e0cbcc6cde --- /dev/null +++ b/apps/sim/lib/concurrency/__tests__/leader-lock.test.ts @@ -0,0 +1,168 @@ +/** + * @vitest-environment node + */ +import { redisConfigMock, redisConfigMockFns } from '@sim/testing' +import { sleep } from '@sim/utils/helpers' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@/lib/core/config/redis', () => redisConfigMock) + +import { withLeaderLock } from '@/lib/concurrency/leader-lock' + +beforeEach(() => { + vi.clearAllMocks() + redisConfigMockFns.mockAcquireLock.mockResolvedValue(true) + redisConfigMockFns.mockReleaseLock.mockResolvedValue(true) +}) + +describe('withLeaderLock', () => { + it('runs onLeader exactly once when lock acquired', async () => { + const onLeader = vi.fn(async () => 'leader-result') + const onFollower = vi.fn(async () => null) + + const result = await withLeaderLock({ + key: 'k', + onLeader, + onFollower, + }) + + expect(result).toBe('leader-result') + expect(onLeader).toHaveBeenCalledTimes(1) + expect(onFollower).not.toHaveBeenCalled() + expect(redisConfigMockFns.mockReleaseLock).toHaveBeenCalledTimes(1) + }) + + it('passes a fresh owner token to acquireLock and releaseLock', async () => { + await withLeaderLock({ + key: 'k', + onLeader: async () => 'x', + onFollower: async () => null, + }) + + const [acquireKey, acquireValue] = redisConfigMockFns.mockAcquireLock.mock.calls[0]! + const [releaseKey, releaseValue] = redisConfigMockFns.mockReleaseLock.mock.calls[0]! + + expect(acquireKey).toBe('k') + expect(releaseKey).toBe('k') + expect(acquireValue).toBe(releaseValue) + expect(typeof acquireValue).toBe('string') + expect((acquireValue as string).length).toBeGreaterThan(0) + }) + + it('falls back to uncoordinated leader when acquireLock throws', async () => { + redisConfigMockFns.mockAcquireLock.mockRejectedValueOnce(new Error('redis down')) + + const onLeader = vi.fn(async () => 'fallback') + const onFollower = vi.fn(async () => null) + + const result = await withLeaderLock({ + key: 'k', + onLeader, + onFollower, + }) + + expect(result).toBe('fallback') + expect(onLeader).toHaveBeenCalledTimes(1) + expect(onFollower).not.toHaveBeenCalled() + expect(redisConfigMockFns.mockReleaseLock).not.toHaveBeenCalled() + }) + + it('does not propagate releaseLock errors out of the leader path', async () => { + redisConfigMockFns.mockReleaseLock.mockRejectedValueOnce(new Error('redis blip')) + + const result = await withLeaderLock({ + key: 'k', + onLeader: async () => 'leader-value', + onFollower: async () => null, + }) + + expect(result).toBe('leader-value') + }) + + it('releases the lock even when onLeader throws', async () => { + const onLeader = vi.fn(async () => { + throw new Error('boom') + }) + + await expect( + withLeaderLock({ + key: 'k', + onLeader, + onFollower: async () => null, + }) + ).rejects.toThrow('boom') + + expect(redisConfigMockFns.mockReleaseLock).toHaveBeenCalledTimes(1) + }) + + it('follower polls onFollower until it returns non-null', async () => { + redisConfigMockFns.mockAcquireLock.mockResolvedValueOnce(false) + + let polls = 0 + const onFollower = vi.fn(async () => { + polls += 1 + if (polls >= 2) return 'available' + return null + }) + + const result = await withLeaderLock({ + key: 'k', + pollIntervalMs: 5, + maxWaitMs: 1000, + onLeader: async () => 'should-not-run', + onFollower, + }) + + expect(result).toBe('available') + expect(onFollower.mock.calls.length).toBeGreaterThanOrEqual(2) + }) + + it('follower returns null after timeout', async () => { + redisConfigMockFns.mockAcquireLock.mockResolvedValueOnce(false) + + const result = await withLeaderLock({ + key: 'k', + pollIntervalMs: 5, + maxWaitMs: 20, + onLeader: async () => 'should-not-run', + onFollower: async () => null, + }) + + expect(result).toBeNull() + }) + + it('only one of N concurrent callers acquires the lock', async () => { + // Track which calls won the lock: first one returns true, rest return false. + let acquired = false + redisConfigMockFns.mockAcquireLock.mockImplementation(async () => { + if (acquired) return false + acquired = true + return true + }) + redisConfigMockFns.mockReleaseLock.mockImplementation(async () => { + acquired = false + return true + }) + + let leaderRuns = 0 + + const callers = Array.from({ length: 5 }, () => + withLeaderLock({ + key: 'shared', + pollIntervalMs: 5, + maxWaitMs: 200, + onLeader: async () => { + leaderRuns += 1 + await sleep(20) + return 'leader-value' + }, + onFollower: async () => (acquired ? null : 'follower-saw-released'), + }) + ) + + const results = await Promise.all(callers) + expect(leaderRuns).toBe(1) + expect(results.filter((r) => r === 'leader-value').length).toBe(1) + expect(results.filter((r) => r === 'follower-saw-released').length).toBeGreaterThan(0) + }) +}) diff --git a/apps/sim/lib/concurrency/__tests__/singleflight.test.ts b/apps/sim/lib/concurrency/__tests__/singleflight.test.ts new file mode 100644 index 00000000000..f911266be7b --- /dev/null +++ b/apps/sim/lib/concurrency/__tests__/singleflight.test.ts @@ -0,0 +1,65 @@ +/** + * @vitest-environment node + */ +import { sleep } from '@sim/utils/helpers' +import { afterEach, describe, expect, it, vi } from 'vitest' +import { __resetCoalesceLocallyForTests, coalesceLocally } from '@/lib/concurrency/singleflight' + +afterEach(() => { + __resetCoalesceLocallyForTests() + vi.restoreAllMocks() +}) + +describe('coalesceLocally', () => { + it('invokes fn once when N callers race on the same key', async () => { + const fn = vi.fn(async () => { + await sleep(5) + return 'value' + }) + + const results = await Promise.all( + Array.from({ length: 10 }, () => coalesceLocally('shared', fn)) + ) + + expect(fn).toHaveBeenCalledTimes(1) + expect(results).toEqual(Array.from({ length: 10 }, () => 'value')) + }) + + it('returns the same promise instance to concurrent callers', () => { + const fn = async () => { + await sleep(10) + return 1 + } + const a = coalesceLocally('same-key', fn) + const b = coalesceLocally('same-key', fn) + expect(a).toBe(b) + }) + + it('clears the cache after success so the next call invokes fn again', async () => { + let count = 0 + const fn = async () => { + count += 1 + return count + } + + expect(await coalesceLocally('once', fn)).toBe(1) + expect(await coalesceLocally('once', fn)).toBe(2) + }) + + it('clears the cache after rejection so the next call invokes fn again', async () => { + let count = 0 + const fn = async () => { + count += 1 + throw new Error(`fail ${count}`) + } + + await expect(coalesceLocally('rejection', fn)).rejects.toThrow('fail 1') + await expect(coalesceLocally('rejection', fn)).rejects.toThrow('fail 2') + }) + + it('does not coalesce across distinct keys', async () => { + const fn = vi.fn(async () => 'value') + await Promise.all([coalesceLocally('a', fn), coalesceLocally('b', fn)]) + expect(fn).toHaveBeenCalledTimes(2) + }) +}) diff --git a/apps/sim/lib/concurrency/leader-lock.ts b/apps/sim/lib/concurrency/leader-lock.ts new file mode 100644 index 00000000000..dd0ed0402a2 --- /dev/null +++ b/apps/sim/lib/concurrency/leader-lock.ts @@ -0,0 +1,69 @@ +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { sleep } from '@sim/utils/helpers' +import { generateShortId } from '@sim/utils/id' +import { acquireLock, releaseLock } from '@/lib/core/config/redis' + +const logger = createLogger('LeaderLock') + +const DEFAULT_TTL_SEC = 10 +const DEFAULT_POLL_INTERVAL_MS = 100 +const DEFAULT_MAX_WAIT_MS = 3_000 + +export interface LeaderLockOptions { + key: string + ttlSec?: number + pollIntervalMs?: number + maxWaitMs?: number + onLeader: () => Promise + onFollower: () => Promise +} + +export async function withLeaderLock(opts: LeaderLockOptions): Promise { + const { + key, + ttlSec = DEFAULT_TTL_SEC, + pollIntervalMs = DEFAULT_POLL_INTERVAL_MS, + maxWaitMs = DEFAULT_MAX_WAIT_MS, + onLeader, + onFollower, + } = opts + + const ownerToken = generateShortId() + + let acquired = false + try { + acquired = await acquireLock(key, ownerToken, ttlSec) + } catch (error) { + logger.warn('Lock acquisition failed; running leader path uncoordinated', { + key, + error: toError(error).message, + }) + return onLeader() + } + + if (acquired) { + try { + return await onLeader() + } finally { + try { + await releaseLock(key, ownerToken) + } catch (error) { + logger.warn('Lock release failed (will expire via TTL)', { + key, + error: toError(error).message, + }) + } + } + } + + const deadline = Date.now() + maxWaitMs + while (Date.now() < deadline) { + await sleep(pollIntervalMs) + const value = await onFollower() + if (value !== null) return value + } + + logger.warn('Follower timed out waiting for leader', { key, maxWaitMs }) + return null +} diff --git a/apps/sim/lib/concurrency/singleflight.ts b/apps/sim/lib/concurrency/singleflight.ts new file mode 100644 index 00000000000..f15ae06f9da --- /dev/null +++ b/apps/sim/lib/concurrency/singleflight.ts @@ -0,0 +1,21 @@ +const inflight = new Map>() + +export function coalesceLocally(key: string, fn: () => Promise): Promise { + const existing = inflight.get(key) as Promise | undefined + if (existing) return existing + + const promise = (async () => { + try { + return await fn() + } finally { + inflight.delete(key) + } + })() + + inflight.set(key, promise) + return promise +} + +export function __resetCoalesceLocallyForTests(): void { + inflight.clear() +} diff --git a/apps/sim/lib/credentials/draft-hooks.ts b/apps/sim/lib/credentials/draft-hooks.ts index c5768c96aee..07361e58a70 100644 --- a/apps/sim/lib/credentials/draft-hooks.ts +++ b/apps/sim/lib/credentials/draft-hooks.ts @@ -4,6 +4,7 @@ import * as schema from '@sim/db/schema' import { createLogger } from '@sim/logger' import { generateId } from '@sim/utils/id' import { and, eq, sql } from 'drizzle-orm' +import { clearDeadFlag } from '@/lib/oauth/terminal-errors' const logger = createLogger('CredentialDraftHooks') @@ -53,6 +54,8 @@ export async function handleCreateCredentialFromDraft(params: { accountId, }) + await clearDeadFlag(accountId) + recordAudit({ workspaceId: draft.workspaceId, actorId: userId, @@ -147,6 +150,8 @@ export async function handleReconnectCredential(params: { newAccountId, }) + await clearDeadFlag(newAccountId) + recordAudit({ workspaceId, actorId: userId, diff --git a/apps/sim/lib/oauth/__tests__/terminal-errors.test.ts b/apps/sim/lib/oauth/__tests__/terminal-errors.test.ts new file mode 100644 index 00000000000..3fd2787a1ea --- /dev/null +++ b/apps/sim/lib/oauth/__tests__/terminal-errors.test.ts @@ -0,0 +1,105 @@ +/** + * @vitest-environment node + */ +import { redisConfigMock, redisConfigMockFns } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@/lib/core/config/redis', () => redisConfigMock) + +import { + clearDeadFlag, + getRecentTerminalError, + isTerminalRefreshError, + markCredentialDead, +} from '@/lib/oauth/terminal-errors' + +interface FakeRedis { + store: Map + set: ReturnType + get: ReturnType + del: ReturnType +} + +function createFakeRedis(): FakeRedis { + const store = new Map() + return { + store, + set: vi.fn(async (key: string, value: string) => { + store.set(key, value) + return 'OK' + }), + get: vi.fn(async (key: string) => store.get(key) ?? null), + del: vi.fn(async (key: string) => (store.delete(key) ? 1 : 0)), + } +} + +beforeEach(() => { + vi.clearAllMocks() + redisConfigMockFns.mockGetRedisClient.mockReturnValue(null) +}) + +describe('isTerminalRefreshError', () => { + it.each([ + 'invalid_refresh_token', + 'invalid_grant', + 'access_denied', + 'bad_client_secret', + 'invalid_client_id', + 'invalid_client', + 'bad_redirect_uri', + ])('returns true for %s', (code) => { + expect(isTerminalRefreshError(code)).toBe(true) + }) + + it.each(['ratelimited', 'internal_error', 'service_unavailable', undefined, null, ''])( + 'returns false for %s', + (code) => { + expect(isTerminalRefreshError(code as string | undefined | null)).toBe(false) + } + ) +}) + +describe('markCredentialDead / getRecentTerminalError / clearDeadFlag', () => { + it('roundtrips a code through Redis', async () => { + const redis = createFakeRedis() + redisConfigMockFns.mockGetRedisClient.mockReturnValue(redis as never) + + await markCredentialDead('acc-1', 'invalid_refresh_token') + expect(await getRecentTerminalError('acc-1')).toBe('invalid_refresh_token') + }) + + it('clearDeadFlag removes the entry', async () => { + const redis = createFakeRedis() + redisConfigMockFns.mockGetRedisClient.mockReturnValue(redis as never) + + await markCredentialDead('acc-1', 'invalid_refresh_token') + await clearDeadFlag('acc-1') + expect(await getRecentTerminalError('acc-1')).toBeNull() + }) + + it('all functions are no-ops when Redis is unavailable', async () => { + await expect(markCredentialDead('acc-1', 'code')).resolves.toBeUndefined() + await expect(getRecentTerminalError('acc-1')).resolves.toBeNull() + await expect(clearDeadFlag('acc-1')).resolves.toBeUndefined() + }) + + it('absorbs Redis errors without throwing', async () => { + const redis = createFakeRedis() + redis.set.mockRejectedValueOnce(new Error('boom')) + redis.get.mockRejectedValueOnce(new Error('boom')) + redis.del.mockRejectedValueOnce(new Error('boom')) + redisConfigMockFns.mockGetRedisClient.mockReturnValue(redis as never) + + await expect(markCredentialDead('acc-1', 'code')).resolves.toBeUndefined() + await expect(getRecentTerminalError('acc-1')).resolves.toBeNull() + await expect(clearDeadFlag('acc-1')).resolves.toBeUndefined() + }) + + it('uses a 1-hour TTL on the dead flag', async () => { + const redis = createFakeRedis() + redisConfigMockFns.mockGetRedisClient.mockReturnValue(redis as never) + + await markCredentialDead('acc-1', 'invalid_refresh_token') + expect(redis.set).toHaveBeenCalledWith('oauth:dead:acc-1', 'invalid_refresh_token', 'EX', 3600) + }) +}) diff --git a/apps/sim/lib/oauth/oauth.test.ts b/apps/sim/lib/oauth/oauth.test.ts index c677c7b2c84..a5ad87e2f2b 100644 --- a/apps/sim/lib/oauth/oauth.test.ts +++ b/apps/sim/lib/oauth/oauth.test.ts @@ -326,7 +326,7 @@ describe('OAuth Token Refresh', () => { }) describe('Error Handling', () => { - it.concurrent('should return null for unsupported provider', async () => { + it.concurrent('should return failure for unsupported provider', async () => { const mockFetch = createMockFetch(defaultOAuthResponse) const refreshToken = 'test_refresh_token' @@ -334,10 +334,10 @@ describe('OAuth Token Refresh', () => { refreshOAuthToken('unsupported', refreshToken) ) - expect(result).toBeNull() + expect(result.ok).toBe(false) }) - it.concurrent('should return null for API error responses', async () => { + it.concurrent('should return failure with errorCode for HTTP error responses', async () => { const mockFetch = vi.fn().mockResolvedValue({ ok: false, status: 400, @@ -351,16 +351,36 @@ describe('OAuth Token Refresh', () => { const result = await withMockFetch(mockFetch, () => refreshOAuthToken('google', refreshToken)) - expect(result).toBeNull() + expect(result.ok).toBe(false) + if (!result.ok) { + expect(result.errorCode).toBe('invalid_request') + } + }) + + it.concurrent('should return failure for Slack-style body errors', async () => { + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ ok: false, error: 'invalid_refresh_token' }), + }) + const refreshToken = 'test_refresh_token' + + const result = await withMockFetch(mockFetch, () => refreshOAuthToken('slack', refreshToken)) + + expect(result.ok).toBe(false) + if (!result.ok) { + expect(result.errorCode).toBe('invalid_refresh_token') + } }) - it.concurrent('should return null for network errors', async () => { + it.concurrent('should return failure for network errors', async () => { const mockFetch = vi.fn().mockRejectedValue(new Error('Network error')) const refreshToken = 'test_refresh_token' const result = await withMockFetch(mockFetch, () => refreshOAuthToken('google', refreshToken)) - expect(result).toBeNull() + expect(result.ok).toBe(false) }) }) @@ -383,6 +403,7 @@ describe('OAuth Token Refresh', () => { ) expect(result).toEqual({ + ok: true, accessToken: 'new_access_token', expiresIn: 3600, refreshToken: newRefreshToken, @@ -421,6 +442,7 @@ describe('OAuth Token Refresh', () => { ) expect(result).toEqual({ + ok: true, accessToken: 'new_access_token', expiresIn: 3600, refreshToken: rotatedRefreshToken, @@ -443,13 +465,14 @@ describe('OAuth Token Refresh', () => { const result = await withMockFetch(mockFetch, () => refreshOAuthToken('google', refreshToken)) expect(result).toEqual({ + ok: true, accessToken: 'new_access_token', expiresIn: 3600, refreshToken: refreshToken, }) }) - it.concurrent('should return null when access token is missing', async () => { + it.concurrent('should return failure when access token is missing', async () => { const refreshToken = 'test_refresh_token' const mockFetch = vi.fn().mockResolvedValue({ @@ -461,7 +484,7 @@ describe('OAuth Token Refresh', () => { const result = await withMockFetch(mockFetch, () => refreshOAuthToken('google', refreshToken)) - expect(result).toBeNull() + expect(result.ok).toBe(false) }) it.concurrent('should use default expiration when not provided', async () => { @@ -477,6 +500,7 @@ describe('OAuth Token Refresh', () => { const result = await withMockFetch(mockFetch, () => refreshOAuthToken('google', refreshToken)) expect(result).toEqual({ + ok: true, accessToken: 'new_access_token', expiresIn: 3600, refreshToken: refreshToken, diff --git a/apps/sim/lib/oauth/oauth.ts b/apps/sim/lib/oauth/oauth.ts index 8702399024a..82c3ee7bb1e 100644 --- a/apps/sim/lib/oauth/oauth.ts +++ b/apps/sim/lib/oauth/oauth.ts @@ -1456,13 +1456,6 @@ function buildAuthRequest( return { headers, bodyParams, useJsonBody: config.useJsonBody } } -/** - * Refresh an OAuth token - * This is a server-side utility function to refresh OAuth tokens - * @param providerId The provider ID (e.g., 'google-drive') - * @param refreshToken The refresh token to use - * @returns Object containing the new access token and expiration time in seconds, or null if refresh failed - */ function getBaseProviderForService(providerId: string): string { if (providerId in OAUTH_PROVIDERS) { return providerId @@ -1479,10 +1472,33 @@ function getBaseProviderForService(providerId: string): string { throw new Error(`Unknown OAuth provider: ${providerId}`) } +export interface RefreshTokenSuccess { + ok: true + accessToken: string + expiresIn: number + refreshToken: string +} + +export interface RefreshTokenFailure { + ok: false + errorCode?: string + message?: string +} + +export type RefreshTokenResult = RefreshTokenSuccess | RefreshTokenFailure + +function extractErrorCode(value: unknown): string | undefined { + if (value && typeof value === 'object' && 'error' in value) { + const code = (value as { error: unknown }).error + if (typeof code === 'string') return code + } + return undefined +} + export async function refreshOAuthToken( providerId: string, refreshToken: string -): Promise<{ accessToken: string; expiresIn: number; refreshToken: string } | null> { +): Promise { try { const provider = getBaseProviderForService(providerId) @@ -1498,7 +1514,7 @@ export async function refreshOAuthToken( if (!response.ok) { const errorText = await response.text() - let errorData = errorText + let errorData: unknown = errorText try { errorData = JSON.parse(errorText) @@ -1518,11 +1534,34 @@ export async function refreshOAuthToken( hasRefreshToken: !!refreshToken, refreshTokenPrefix: refreshToken ? `${refreshToken.substring(0, 10)}...` : 'none', }) - throw new Error(`Failed to refresh token: ${response.status} ${errorText}`) + return { + ok: false, + errorCode: extractErrorCode(errorData), + message: `Failed to refresh token: ${response.status} ${errorText}`, + } } const data = await response.json() + if (data && typeof data === 'object' && data.ok === false) { + logger.error('Token refresh failed:', { + status: response.status, + statusText: response.statusText, + error: data.error, + parsedError: data, + providerId, + tokenEndpoint: config.tokenEndpoint, + hasClientId: !!config.clientId, + hasClientSecret: !!config.clientSecret, + hasRefreshToken: !!refreshToken, + }) + return { + ok: false, + errorCode: typeof data.error === 'string' ? data.error : undefined, + message: `Failed to refresh token: ${data.error ?? 'unknown'}`, + } + } + const accessToken = data.access_token let newRefreshToken = null @@ -1534,8 +1573,8 @@ export async function refreshOAuthToken( const expiresIn = data.expires_in || data.expiresIn || 3600 if (!accessToken) { - logger.warn('No access token found in refresh response', data) - return null + logger.warn('No access token found in refresh response', { providerId, response: data }) + return { ok: false, message: 'No access token in refresh response' } } logger.info('Token refreshed successfully with expiration', { @@ -1545,14 +1584,14 @@ export async function refreshOAuthToken( }) return { + ok: true, accessToken, expiresIn, refreshToken: newRefreshToken || refreshToken, // Return new refresh token if available } } catch (error) { - logger.error('Error refreshing token:', { - error: toError(error).message, - }) - return null + const message = toError(error).message + logger.error('Error refreshing token:', { error: message }) + return { ok: false, message } } } diff --git a/apps/sim/lib/oauth/terminal-errors.ts b/apps/sim/lib/oauth/terminal-errors.ts new file mode 100644 index 00000000000..25fba73205c --- /dev/null +++ b/apps/sim/lib/oauth/terminal-errors.ts @@ -0,0 +1,67 @@ +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { getRedisClient } from '@/lib/core/config/redis' + +const logger = createLogger('OAuthTerminalErrors') + +const TERMINAL_ERRORS = new Set([ + 'invalid_refresh_token', + 'invalid_grant', + 'access_denied', + 'bad_client_secret', + 'invalid_client_id', + 'invalid_client', + 'bad_redirect_uri', +]) + +const DEAD_CACHE_TTL_SEC = 60 * 60 + +function deadKey(accountId: string): string { + return `oauth:dead:${accountId}` +} + +export function isTerminalRefreshError(code: string | undefined | null): boolean { + if (!code) return false + return TERMINAL_ERRORS.has(code) +} + +export async function markCredentialDead(accountId: string, code: string): Promise { + const redis = getRedisClient() + if (!redis) return + try { + await redis.set(deadKey(accountId), code, 'EX', DEAD_CACHE_TTL_SEC) + } catch (error) { + logger.warn('Failed to mark credential dead in Redis', { + accountId, + code, + error: toError(error).message, + }) + } +} + +export async function getRecentTerminalError(accountId: string): Promise { + const redis = getRedisClient() + if (!redis) return null + try { + return await redis.get(deadKey(accountId)) + } catch (error) { + logger.warn('Failed to read terminal error flag from Redis', { + accountId, + error: toError(error).message, + }) + return null + } +} + +export async function clearDeadFlag(accountId: string): Promise { + const redis = getRedisClient() + if (!redis) return + try { + await redis.del(deadKey(accountId)) + } catch (error) { + logger.warn('Failed to clear terminal error flag from Redis', { + accountId, + error: toError(error).message, + }) + } +}