diff --git a/apps/sim/app/api/auth/oauth/credentials/route.ts b/apps/sim/app/api/auth/oauth/credentials/route.ts index d7af4dea365..8cfe86e8b26 100644 --- a/apps/sim/app/api/auth/oauth/credentials/route.ts +++ b/apps/sim/app/api/auth/oauth/credentials/route.ts @@ -3,8 +3,8 @@ import { jwtDecode } from 'jwt-decode' import { type NextRequest, NextResponse } from 'next/server' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console-logger' -import type { OAuthService } from '@/lib/oauth' -import { parseProvider } from '@/lib/oauth' +import type { OAuthService } from '@/lib/oauth/oauth' +import { parseProvider } from '@/lib/oauth/oauth' import { db } from '@/db' import { account, user } from '@/db/schema' diff --git a/apps/sim/app/api/auth/oauth/utils.test.ts b/apps/sim/app/api/auth/oauth/utils.test.ts index 5f59e545d9e..6d06d21b973 100644 --- a/apps/sim/app/api/auth/oauth/utils.test.ts +++ b/apps/sim/app/api/auth/oauth/utils.test.ts @@ -35,7 +35,7 @@ describe('OAuth Utils', () => { db: mockDb, })) - vi.doMock('@/lib/oauth', () => ({ + vi.doMock('@/lib/oauth/oauth', () => ({ refreshOAuthToken: mockRefreshOAuthToken, })) @@ -181,13 +181,13 @@ describe('OAuth Utils', () => { providerId: 'google', } - mockRefreshOAuthToken.mockRejectedValueOnce(new Error('Refresh failed')) + mockRefreshOAuthToken.mockResolvedValueOnce(null) const { refreshTokenIfNeeded } = await import('./utils') await expect( refreshTokenIfNeeded('request-id', mockCredential, 'credential-id') - ).rejects.toThrow() + ).rejects.toThrow('Failed to refresh token') expect(mockLogger.error).toHaveBeenCalled() }) diff --git a/apps/sim/app/api/auth/oauth/utils.ts b/apps/sim/app/api/auth/oauth/utils.ts index 19e74e2c260..24ff8ad2a3f 100644 --- a/apps/sim/app/api/auth/oauth/utils.ts +++ b/apps/sim/app/api/auth/oauth/utils.ts @@ -1,7 +1,7 @@ import { and, eq } from 'drizzle-orm' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console-logger' -import { refreshOAuthToken } from '@/lib/oauth' +import { refreshOAuthToken } from '@/lib/oauth/oauth' import { db } from '@/db' import { account, workflow } from '@/db/schema' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx index 8c06924f53e..57d8f12d2e1 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/components/oauth-required-modal.tsx @@ -18,7 +18,7 @@ import { OAUTH_PROVIDERS, type OAuthProvider, parseProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' const logger = createLogger('OAuthRequiredModal') diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/credential-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/credential-selector.tsx index 717dbacd072..ad08ad9f887 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/credential-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/credential-selector.tsx @@ -20,7 +20,7 @@ import { OAUTH_PROVIDERS, type OAuthProvider, parseProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' import { OAuthRequiredModal } from './components/oauth-required-modal' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/confluence-file-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/confluence-file-selector.tsx index a7f2e9e7bd2..f453137370c 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/confluence-file-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/confluence-file-selector.tsx @@ -18,7 +18,7 @@ import { getProviderIdFromServiceId, getServiceIdFromScopes, type OAuthProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' import { OAuthRequiredModal } from '../../credential-selector/components/oauth-required-modal' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/google-drive-picker.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/google-drive-picker.tsx index 2d0938bae5d..a151b14835e 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/google-drive-picker.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/google-drive-picker.tsx @@ -23,7 +23,7 @@ import { OAUTH_PROVIDERS, type OAuthProvider, parseProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' import { OAuthRequiredModal } from '../../credential-selector/components/oauth-required-modal' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/jira-issue-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/jira-issue-selector.tsx index 2a2c2922996..cc9af8850b5 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/jira-issue-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/jira-issue-selector.tsx @@ -19,7 +19,7 @@ import { getProviderIdFromServiceId, getServiceIdFromScopes, type OAuthProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' import { OAuthRequiredModal } from '../../credential-selector/components/oauth-required-modal' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/microsoft-file-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/microsoft-file-selector.tsx index a04f7e01c6b..b13104085f7 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/microsoft-file-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/microsoft-file-selector.tsx @@ -22,7 +22,7 @@ import { OAUTH_PROVIDERS, type OAuthProvider, parseProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' import { OAuthRequiredModal } from '../../credential-selector/components/oauth-required-modal' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/teams-message-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/teams-message-selector.tsx index f71abcfdce5..df2c3a25bf0 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/teams-message-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/file-selector/components/teams-message-selector.tsx @@ -19,7 +19,7 @@ import { getProviderIdFromServiceId, getServiceIdFromScopes, type OAuthProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' import { OAuthRequiredModal } from '../../credential-selector/components/oauth-required-modal' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/folder-selector/folder-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/folder-selector/folder-selector.tsx index 2a0d9990eab..a28983b8e82 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/folder-selector/folder-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/folder-selector/folder-selector.tsx @@ -14,7 +14,11 @@ import { } from '@/components/ui/command' import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/popover' import { createLogger } from '@/lib/logs/console-logger' -import { type Credential, getProviderIdFromServiceId, getServiceIdFromScopes } from '@/lib/oauth' +import { + type Credential, + getProviderIdFromServiceId, + getServiceIdFromScopes, +} from '@/lib/oauth/oauth' import { OAuthRequiredModal } from '@/app/w/[id]/components/workflow-block/components/sub-block/components/credential-selector/components/oauth-required-modal' import { saveToStorage } from '@/stores/workflows/persistence' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/project-selector/components/jira-project-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/project-selector/components/jira-project-selector.tsx index 5d562708a2a..d9f171edb72 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/project-selector/components/jira-project-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/project-selector/components/jira-project-selector.tsx @@ -19,7 +19,7 @@ import { getProviderIdFromServiceId, getServiceIdFromScopes, type OAuthProvider, -} from '@/lib/oauth' +} from '@/lib/oauth/oauth' import { saveToStorage } from '@/stores/workflows/persistence' import { OAuthRequiredModal } from '../../credential-selector/components/oauth-required-modal' diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx index f2e5136b051..a38bff4e4f3 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/tool-input/tool-input.tsx @@ -11,7 +11,7 @@ import { } from '@/components/ui/select' import { Toggle } from '@/components/ui/toggle' import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip' -import type { OAuthProvider } from '@/lib/oauth' +import type { OAuthProvider } from '@/lib/oauth/oauth' import { cn } from '@/lib/utils' import { getAllBlocks } from '@/blocks' import { supportsToolUsageControl } from '@/providers/model-capabilities' diff --git a/apps/sim/app/w/components/sidebar/components/settings-modal/components/credentials/credentials.tsx b/apps/sim/app/w/components/sidebar/components/settings-modal/components/credentials/credentials.tsx index dab2d04b6ea..c1d5637c8a7 100644 --- a/apps/sim/app/w/components/sidebar/components/settings-modal/components/credentials/credentials.tsx +++ b/apps/sim/app/w/components/sidebar/components/settings-modal/components/credentials/credentials.tsx @@ -9,7 +9,7 @@ import { Input } from '@/components/ui/input' import { Skeleton } from '@/components/ui/skeleton' import { client, useSession } from '@/lib/auth-client' import { createLogger } from '@/lib/logs/console-logger' -import { OAUTH_PROVIDERS, type OAuthServiceConfig } from '@/lib/oauth' +import { OAUTH_PROVIDERS, type OAuthServiceConfig } from '@/lib/oauth/oauth' import { cn } from '@/lib/utils' import { loadFromStorage, removeFromStorage, saveToStorage } from '@/stores/workflows/persistence' diff --git a/apps/sim/lib/oauth/oauth.test.ts b/apps/sim/lib/oauth/oauth.test.ts new file mode 100644 index 00000000000..16b37d428fa --- /dev/null +++ b/apps/sim/lib/oauth/oauth.test.ts @@ -0,0 +1,368 @@ +import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest' + +vi.mock('../env', () => ({ + env: { + GOOGLE_CLIENT_ID: 'google_client_id', + GOOGLE_CLIENT_SECRET: 'google_client_secret', + GITHUB_CLIENT_ID: 'github_client_id', + GITHUB_CLIENT_SECRET: 'github_client_secret', + X_CLIENT_ID: 'x_client_id', + X_CLIENT_SECRET: 'x_client_secret', + CONFLUENCE_CLIENT_ID: 'confluence_client_id', + CONFLUENCE_CLIENT_SECRET: 'confluence_client_secret', + JIRA_CLIENT_ID: 'jira_client_id', + JIRA_CLIENT_SECRET: 'jira_client_secret', + AIRTABLE_CLIENT_ID: 'airtable_client_id', + AIRTABLE_CLIENT_SECRET: 'airtable_client_secret', + SUPABASE_CLIENT_ID: 'supabase_client_id', + SUPABASE_CLIENT_SECRET: 'supabase_client_secret', + NOTION_CLIENT_ID: 'notion_client_id', + NOTION_CLIENT_SECRET: 'notion_client_secret', + DISCORD_CLIENT_ID: 'discord_client_id', + DISCORD_CLIENT_SECRET: 'discord_client_secret', + MICROSOFT_CLIENT_ID: 'microsoft_client_id', + MICROSOFT_CLIENT_SECRET: 'microsoft_client_secret', + LINEAR_CLIENT_ID: 'linear_client_id', + LINEAR_CLIENT_SECRET: 'linear_client_secret', + SLACK_CLIENT_ID: 'slack_client_id', + SLACK_CLIENT_SECRET: 'slack_client_secret', + }, +})) + +vi.mock('@/lib/logs/console-logger', () => ({ + createLogger: vi.fn().mockReturnValue({ + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + }), +})) + +const mockFetch = vi.fn() +global.fetch = mockFetch + +import { refreshOAuthToken } from './oauth' + +describe('OAuth Token Refresh', () => { + beforeEach(() => { + vi.clearAllMocks() + mockFetch.mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: 'new_access_token', + expires_in: 3600, + refresh_token: 'new_refresh_token', + }), + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('Basic Auth Providers', () => { + const basicAuthProviders = [ + { + name: 'Airtable', + providerId: 'airtable', + endpoint: 'https://airtable.com/oauth2/v1/token', + }, + { name: 'X (Twitter)', providerId: 'x', endpoint: 'https://api.x.com/2/oauth2/token' }, + { + name: 'Confluence', + providerId: 'confluence', + endpoint: 'https://auth.atlassian.com/oauth/token', + }, + { name: 'Jira', providerId: 'jira', endpoint: 'https://auth.atlassian.com/oauth/token' }, + { + name: 'Discord', + providerId: 'discord', + endpoint: 'https://discord.com/api/v10/oauth2/token', + }, + { name: 'Linear', providerId: 'linear', endpoint: 'https://api.linear.app/oauth/token' }, + ] + + basicAuthProviders.forEach(({ name, providerId, endpoint }) => { + it(`should send ${name} request with Basic Auth header and no credentials in body`, async () => { + const refreshToken = 'test_refresh_token' + + await refreshOAuthToken(providerId, refreshToken) + + expect(mockFetch).toHaveBeenCalledWith( + endpoint, + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + 'Content-Type': 'application/x-www-form-urlencoded', + Authorization: expect.stringMatching(/^Basic /), + }), + body: expect.any(String), + }) + ) + + const [, requestOptions] = (mockFetch as Mock).mock.calls[0] + + // Verify Basic Auth header + const authHeader = requestOptions.headers.Authorization + expect(authHeader).toMatch(/^Basic /) + + // Decode and verify credentials + const base64Credentials = authHeader.replace('Basic ', '') + const credentials = Buffer.from(base64Credentials, 'base64').toString('utf-8') + const [clientId, clientSecret] = credentials.split(':') + + expect(clientId).toBe(`${providerId}_client_id`) + expect(clientSecret).toBe(`${providerId}_client_secret`) + + // Verify body contains only required parameters + const bodyParams = new URLSearchParams(requestOptions.body) + const bodyKeys = Array.from(bodyParams.keys()) + + expect(bodyKeys).toEqual(['grant_type', 'refresh_token']) + expect(bodyParams.get('grant_type')).toBe('refresh_token') + expect(bodyParams.get('refresh_token')).toBe(refreshToken) + + // Verify client credentials are NOT in the body + expect(bodyParams.get('client_id')).toBeNull() + expect(bodyParams.get('client_secret')).toBeNull() + }) + }) + }) + + describe('Body Credential Providers', () => { + const bodyCredentialProviders = [ + { name: 'Google', providerId: 'google', endpoint: 'https://oauth2.googleapis.com/token' }, + { + name: 'GitHub', + providerId: 'github', + endpoint: 'https://github.com/login/oauth/access_token', + }, + { + name: 'Microsoft', + providerId: 'microsoft', + endpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/token', + }, + { + name: 'Outlook', + providerId: 'outlook', + endpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/token', + }, + { + name: 'Supabase', + providerId: 'supabase', + endpoint: 'https://api.supabase.com/v1/oauth/token', + }, + { name: 'Notion', providerId: 'notion', endpoint: 'https://api.notion.com/v1/oauth/token' }, + { name: 'Slack', providerId: 'slack', endpoint: 'https://slack.com/api/oauth.v2.access' }, + ] + + bodyCredentialProviders.forEach(({ name, providerId, endpoint }) => { + it(`should send ${name} request with credentials in body and no Basic Auth`, async () => { + const refreshToken = 'test_refresh_token' + + await refreshOAuthToken(providerId, refreshToken) + + expect(mockFetch).toHaveBeenCalledWith( + endpoint, + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + 'Content-Type': 'application/x-www-form-urlencoded', + }), + body: expect.any(String), + }) + ) + + const [, requestOptions] = (mockFetch as Mock).mock.calls[0] + + // Verify no Basic Auth header + expect(requestOptions.headers.Authorization).toBeUndefined() + + // Verify body contains all required parameters + const bodyParams = new URLSearchParams(requestOptions.body) + const bodyKeys = Array.from(bodyParams.keys()).sort() + + expect(bodyKeys).toEqual(['client_id', 'client_secret', 'grant_type', 'refresh_token']) + expect(bodyParams.get('grant_type')).toBe('refresh_token') + expect(bodyParams.get('refresh_token')).toBe(refreshToken) + + // Verify client credentials are in the body + const expectedClientId = + providerId === 'outlook' ? 'microsoft_client_id' : `${providerId}_client_id` + const expectedClientSecret = + providerId === 'outlook' ? 'microsoft_client_secret' : `${providerId}_client_secret` + + expect(bodyParams.get('client_id')).toBe(expectedClientId) + expect(bodyParams.get('client_secret')).toBe(expectedClientSecret) + }) + }) + + it('should include Accept header for GitHub requests', async () => { + const refreshToken = 'test_refresh_token' + + await refreshOAuthToken('github', refreshToken) + + const [, requestOptions] = (mockFetch as Mock).mock.calls[0] + expect(requestOptions.headers.Accept).toBe('application/json') + }) + }) + + describe('Error Handling', () => { + it('should return null for unsupported provider', async () => { + const refreshToken = 'test_refresh_token' + + const result = await refreshOAuthToken('unsupported', refreshToken) + + expect(result).toBeNull() + }) + + it('should return null for API error responses', async () => { + const refreshToken = 'test_refresh_token' + + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + text: async () => + JSON.stringify({ + error: 'invalid_request', + error_description: 'Invalid refresh token', + }), + }) + + const result = await refreshOAuthToken('google', refreshToken) + + expect(result).toBeNull() + }) + + it('should return null for network errors', async () => { + const refreshToken = 'test_refresh_token' + + mockFetch.mockRejectedValueOnce(new Error('Network error')) + + const result = await refreshOAuthToken('google', refreshToken) + + expect(result).toBeNull() + }) + }) + + describe('Token Response Handling', () => { + it('should handle providers that return new refresh tokens', async () => { + const refreshToken = 'old_refresh_token' + const newRefreshToken = 'new_refresh_token' + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new_access_token', + expires_in: 3600, + refresh_token: newRefreshToken, + }), + }) + + const result = await refreshOAuthToken('airtable', refreshToken) + + expect(result).toEqual({ + accessToken: 'new_access_token', + expiresIn: 3600, + refreshToken: newRefreshToken, + }) + }) + + it('should use original refresh token when new one is not provided', async () => { + const refreshToken = 'original_refresh_token' + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new_access_token', + expires_in: 3600, + // No refresh_token in response + }), + }) + + const result = await refreshOAuthToken('google', refreshToken) + + expect(result).toEqual({ + accessToken: 'new_access_token', + expiresIn: 3600, + refreshToken: refreshToken, + }) + }) + + it('should return null when access token is missing', async () => { + const refreshToken = 'test_refresh_token' + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + expires_in: 3600, + // No access_token in response + }), + }) + + const result = await refreshOAuthToken('google', refreshToken) + + expect(result).toBeNull() + }) + + it('should use default expiration when not provided', async () => { + const refreshToken = 'test_refresh_token' + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new_access_token', + // No expires_in in response + }), + }) + + const result = await refreshOAuthToken('google', refreshToken) + + expect(result).toEqual({ + accessToken: 'new_access_token', + expiresIn: 3600, + refreshToken: refreshToken, + }) + }) + }) + + describe('Airtable Tests', () => { + it('should not have duplicate client ID issue', async () => { + const refreshToken = 'test_refresh_token' + + await refreshOAuthToken('airtable', refreshToken) + + const [, requestOptions] = (mockFetch as Mock).mock.calls[0] + + // Verify Authorization header is present and correct + expect(requestOptions.headers.Authorization).toMatch(/^Basic /) + + // Parse body and verify client credentials are NOT present + const bodyParams = new URLSearchParams(requestOptions.body) + expect(bodyParams.get('client_id')).toBeNull() + expect(bodyParams.get('client_secret')).toBeNull() + + // Verify only expected parameters are present + const bodyKeys = Array.from(bodyParams.keys()) + expect(bodyKeys).toEqual(['grant_type', 'refresh_token']) + }) + + it('should handle Airtable refresh token rotation', async () => { + const refreshToken = 'old_refresh_token' + const newRefreshToken = 'rotated_refresh_token' + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + access_token: 'new_access_token', + expires_in: 3600, + refresh_token: newRefreshToken, + }), + }) + + const result = await refreshOAuthToken('airtable', refreshToken) + + expect(result?.refreshToken).toBe(newRefreshToken) + }) + }) +}) diff --git a/apps/sim/lib/oauth.ts b/apps/sim/lib/oauth/oauth.ts similarity index 73% rename from apps/sim/lib/oauth.ts rename to apps/sim/lib/oauth/oauth.ts index 1f6c0c5803f..1513aa9369e 100644 --- a/apps/sim/lib/oauth.ts +++ b/apps/sim/lib/oauth/oauth.ts @@ -22,7 +22,7 @@ import { xIcon, } from '@/components/icons' import { createLogger } from '@/lib/logs/console-logger' -import { env } from './env' +import { env } from '../env' const logger = createLogger('OAuth') @@ -520,6 +520,216 @@ export function parseProvider(provider: OAuthProvider): ProviderConfig { } } +interface ProviderAuthConfig { + tokenEndpoint: string + clientId: string + clientSecret: string + useBasicAuth: boolean + additionalHeaders?: Record + supportsRefreshTokenRotation?: boolean +} + +/** + * Get OAuth provider configuration for token refresh + */ +function getProviderAuthConfig(provider: string): ProviderAuthConfig { + const getCredentials = (clientId: string | undefined, clientSecret: string | undefined) => { + if (!clientId || !clientSecret) { + throw new Error(`Missing client credentials for provider: ${provider}`) + } + return { clientId, clientSecret } + } + + switch (provider) { + case 'google': { + const { clientId, clientSecret } = getCredentials( + env.GOOGLE_CLIENT_ID, + env.GOOGLE_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://oauth2.googleapis.com/token', + clientId, + clientSecret, + useBasicAuth: false, + } + } + case 'github': { + const { clientId, clientSecret } = getCredentials( + env.GITHUB_CLIENT_ID, + env.GITHUB_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://github.com/login/oauth/access_token', + clientId, + clientSecret, + useBasicAuth: false, + additionalHeaders: { Accept: 'application/json' }, + } + } + case 'x': { + const { clientId, clientSecret } = getCredentials(env.X_CLIENT_ID, env.X_CLIENT_SECRET) + return { + tokenEndpoint: 'https://api.x.com/2/oauth2/token', + clientId, + clientSecret, + useBasicAuth: true, + } + } + case 'confluence': { + const { clientId, clientSecret } = getCredentials( + env.CONFLUENCE_CLIENT_ID, + env.CONFLUENCE_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://auth.atlassian.com/oauth/token', + clientId, + clientSecret, + useBasicAuth: true, + supportsRefreshTokenRotation: true, + } + } + case 'jira': { + const { clientId, clientSecret } = getCredentials(env.JIRA_CLIENT_ID, env.JIRA_CLIENT_SECRET) + return { + tokenEndpoint: 'https://auth.atlassian.com/oauth/token', + clientId, + clientSecret, + useBasicAuth: true, + supportsRefreshTokenRotation: true, + } + } + case 'airtable': { + const { clientId, clientSecret } = getCredentials( + env.AIRTABLE_CLIENT_ID, + env.AIRTABLE_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://airtable.com/oauth2/v1/token', + clientId, + clientSecret, + useBasicAuth: true, + supportsRefreshTokenRotation: true, + } + } + case 'supabase': { + const { clientId, clientSecret } = getCredentials( + env.SUPABASE_CLIENT_ID, + env.SUPABASE_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://api.supabase.com/v1/oauth/token', + clientId, + clientSecret, + useBasicAuth: false, + } + } + case 'notion': { + const { clientId, clientSecret } = getCredentials( + env.NOTION_CLIENT_ID, + env.NOTION_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://api.notion.com/v1/oauth/token', + clientId, + clientSecret, + useBasicAuth: false, + } + } + case 'discord': { + const { clientId, clientSecret } = getCredentials( + env.DISCORD_CLIENT_ID, + env.DISCORD_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://discord.com/api/v10/oauth2/token', + clientId, + clientSecret, + useBasicAuth: true, + } + } + case 'microsoft': { + const { clientId, clientSecret } = getCredentials( + env.MICROSOFT_CLIENT_ID, + env.MICROSOFT_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/token', + clientId, + clientSecret, + useBasicAuth: false, + } + } + case 'outlook': { + const { clientId, clientSecret } = getCredentials( + env.MICROSOFT_CLIENT_ID, + env.MICROSOFT_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://login.microsoftonline.com/common/oauth2/v2.0/token', + clientId, + clientSecret, + useBasicAuth: false, + } + } + case 'linear': { + const { clientId, clientSecret } = getCredentials( + env.LINEAR_CLIENT_ID, + env.LINEAR_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://api.linear.app/oauth/token', + clientId, + clientSecret, + useBasicAuth: true, + } + } + case 'slack': { + const { clientId, clientSecret } = getCredentials( + env.SLACK_CLIENT_ID, + env.SLACK_CLIENT_SECRET + ) + return { + tokenEndpoint: 'https://slack.com/api/oauth.v2.access', + clientId, + clientSecret, + useBasicAuth: false, + } + } + default: + throw new Error(`Unsupported provider: ${provider}`) + } +} + +/** + * Build the authentication request headers and body for OAuth token refresh + */ +function buildAuthRequest( + config: ProviderAuthConfig, + refreshToken: string +): { headers: Record; bodyParams: Record } { + const headers: Record = { + 'Content-Type': 'application/x-www-form-urlencoded', + ...config.additionalHeaders, + } + + const bodyParams: Record = { + grant_type: 'refresh_token', + refresh_token: refreshToken, + } + + if (config.useBasicAuth) { + // Use Basic Authentication - credentials in Authorization header only + const basicAuth = Buffer.from(`${config.clientId}:${config.clientSecret}`).toString('base64') + headers.Authorization = `Basic ${basicAuth}` + } else { + // Use body credentials - include client credentials in request body + bodyParams.client_id = config.clientId + bodyParams.client_secret = config.clientSecret + } + + return { headers, bodyParams } +} + /** * Refresh an OAuth token * This is a server-side utility function to refresh OAuth tokens @@ -535,152 +745,17 @@ export async function refreshOAuthToken( // Get the provider from the providerId (e.g., 'google-drive' -> 'google') const provider = providerId.split('-')[0] - // Determine the token endpoint based on the provider - let tokenEndpoint: string - let clientId: string | undefined - let clientSecret: string | undefined - let useBasicAuth = false - - switch (provider) { - case 'google': - tokenEndpoint = 'https://oauth2.googleapis.com/token' - clientId = env.GOOGLE_CLIENT_ID - clientSecret = env.GOOGLE_CLIENT_SECRET - break - case 'github': - tokenEndpoint = 'https://github.com/login/oauth/access_token' - clientId = env.GITHUB_CLIENT_ID - clientSecret = env.GITHUB_CLIENT_SECRET - break - case 'x': - tokenEndpoint = 'https://api.x.com/2/oauth2/token' - clientId = env.X_CLIENT_ID - clientSecret = env.X_CLIENT_SECRET - useBasicAuth = true - break - case 'confluence': - tokenEndpoint = 'https://auth.atlassian.com/oauth/token' - clientId = env.CONFLUENCE_CLIENT_ID - clientSecret = env.CONFLUENCE_CLIENT_SECRET - useBasicAuth = true - break - case 'jira': - tokenEndpoint = 'https://auth.atlassian.com/oauth/token' - clientId = env.JIRA_CLIENT_ID - clientSecret = env.JIRA_CLIENT_SECRET - useBasicAuth = true - break - case 'airtable': - tokenEndpoint = 'https://airtable.com/oauth2/v1/token' - clientId = env.AIRTABLE_CLIENT_ID - clientSecret = env.AIRTABLE_CLIENT_SECRET - useBasicAuth = true - break - case 'supabase': - tokenEndpoint = 'https://api.supabase.com/v1/oauth/token' - clientId = env.SUPABASE_CLIENT_ID - clientSecret = env.SUPABASE_CLIENT_SECRET - break - case 'notion': - tokenEndpoint = 'https://api.notion.com/v1/oauth/token' - clientId = env.NOTION_CLIENT_ID - clientSecret = env.NOTION_CLIENT_SECRET - break - case 'discord': - tokenEndpoint = 'https://discord.com/api/v10/oauth2/token' - clientId = env.DISCORD_CLIENT_ID - clientSecret = env.DISCORD_CLIENT_SECRET - useBasicAuth = true - break - case 'microsoft': - tokenEndpoint = 'https://login.microsoftonline.com/common/oauth2/v2.0/token' - clientId = env.MICROSOFT_CLIENT_ID - clientSecret = env.MICROSOFT_CLIENT_SECRET - break - case 'outlook': - tokenEndpoint = 'https://login.microsoftonline.com/common/oauth2/v2.0/token' - clientId = env.MICROSOFT_CLIENT_ID - clientSecret = env.MICROSOFT_CLIENT_SECRET - break - case 'linear': - tokenEndpoint = 'https://api.linear.app/oauth/token' - clientId = env.LINEAR_CLIENT_ID - clientSecret = env.LINEAR_CLIENT_SECRET - useBasicAuth = true - break - case 'slack': - tokenEndpoint = 'https://slack.com/api/oauth.v2.access' - clientId = env.SLACK_CLIENT_ID - clientSecret = env.SLACK_CLIENT_SECRET - break - default: - throw new Error(`Unsupported provider: ${provider}`) - } + // Get provider configuration + const config = getProviderAuthConfig(provider) - if (!clientId || !clientSecret) { - throw new Error(`Missing client credentials for provider: ${provider}`) - } - - // Prepare request headers and body - const headers: Record = { - 'Content-Type': 'application/x-www-form-urlencoded', - ...(provider === 'github' && { - Accept: 'application/json', - }), - } - - // Prepare request body - const bodyParams: Record = { - grant_type: 'refresh_token', - refresh_token: refreshToken, - } - - // For Airtable, check if we have both client ID and secret - if (provider === 'airtable') { - // Airtable requires Basic Auth with client ID and secret in the Authorization header - // Do not include client_id or client_secret in the body when using Basic Auth - if (clientId && clientSecret) { - const basicAuth = Buffer.from(`${clientId}:${clientSecret}`).toString('base64') - headers.Authorization = `Basic ${basicAuth}` - - // Make sure to include refresh_token in body params but not client_id/client_secret - // This ensures we're not sending credentials in both header and body - bodyParams.client_id = undefined - bodyParams.client_secret = undefined - } else { - throw new Error('Both client ID and client secret are required for Airtable OAuth') - } - } else if ( - provider === 'x' || - provider === 'confluence' || - provider === 'jira' || - provider === 'discord' - ) { - const authString = `${clientId}:${clientSecret}` - const basicAuth = Buffer.from(authString).toString('base64') - headers.Authorization = `Basic ${basicAuth}` - - // When using Basic Auth, don't include client_id in body - bodyParams.client_id = undefined - bodyParams.client_secret = undefined - } else { - // For other providers, use the general approach - if (useBasicAuth) { - const basicAuth = Buffer.from(`${clientId}:${clientSecret}`).toString('base64') - headers.Authorization = `Basic ${basicAuth}` - } - - if (!useBasicAuth) { - bodyParams.client_id = clientId - bodyParams.client_secret = clientSecret - } - } + // Build authentication request + const { headers, bodyParams } = buildAuthRequest(config, refreshToken) // Refresh the token - const response = await fetch(tokenEndpoint, { + const response = await fetch(config.tokenEndpoint, { method: 'POST', headers, - body: new URLSearchParams(bodyParams as Record).toString(), + body: new URLSearchParams(bodyParams).toString(), }) if (!response.ok) { @@ -708,16 +783,9 @@ export async function refreshOAuthToken( // Extract token and expiration (different providers may use different field names) const accessToken = data.access_token - // For Airtable, also capture the new refresh token if provided - // Airtable may rotate refresh tokens + // Handle refresh token rotation for providers that support it let newRefreshToken = null - if (provider === 'airtable' && data.refresh_token) { - newRefreshToken = data.refresh_token - logger.info('Received new refresh token from Airtable') - } - - // For Confluence and Jira, check if we got a new refresh token - if ((provider === 'confluence' || provider === 'jira') && data.refresh_token) { + if (config.supportsRefreshTokenRotation && data.refresh_token) { newRefreshToken = data.refresh_token logger.info(`Received new refresh token from ${provider}`) } diff --git a/apps/sim/tools/types.ts b/apps/sim/tools/types.ts index a3a777793ba..1a90514e378 100644 --- a/apps/sim/tools/types.ts +++ b/apps/sim/tools/types.ts @@ -1,4 +1,4 @@ -import type { OAuthService } from '@/lib/oauth' +import type { OAuthService } from '@/lib/oauth/oauth' export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH'