diff --git a/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts b/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts index bd9b10d4d3d..cd9c5523231 100644 --- a/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts +++ b/apps/sim/app/api/mcp/serve/[serverId]/route.test.ts @@ -197,4 +197,51 @@ describe('MCP Serve Route', () => { expect(headers['X-API-Key']).toBeUndefined() expect(mockGenerateInternalToken).toHaveBeenCalledWith('user-1') }) + + describe('initialize protocol version negotiation', () => { + async function callInitialize(protocolVersion?: string) { + dbChainMockFns.limit.mockResolvedValueOnce([ + { + id: 'server-1', + name: 'Public Server', + workspaceId: 'ws-1', + isPublic: true, + createdBy: 'owner-1', + }, + ]) + const params: Record = { + capabilities: {}, + clientInfo: { name: 'test', version: '1.0.0' }, + } + if (protocolVersion !== undefined) params.protocolVersion = protocolVersion + const req = new NextRequest('http://localhost:3000/api/mcp/serve/server-1', { + method: 'POST', + body: JSON.stringify({ jsonrpc: '2.0', id: 1, method: 'initialize', params }), + }) + const res = await POST(req, { params: Promise.resolve({ serverId: 'server-1' }) }) + return res.json() as Promise<{ result: { protocolVersion: string } }> + } + + it('echoes a supported client protocolVersion (2025-06-18)', async () => { + const body = await callInitialize('2025-06-18') + expect(body.result.protocolVersion).toBe('2025-06-18') + }) + + it('echoes a supported client protocolVersion (2024-11-05)', async () => { + const body = await callInitialize('2024-11-05') + expect(body.result.protocolVersion).toBe('2024-11-05') + }) + + it('falls back to SDK latest when client requests unknown version', async () => { + const { LATEST_PROTOCOL_VERSION } = await import('@modelcontextprotocol/sdk/types.js') + const body = await callInitialize('2099-01-01') + expect(body.result.protocolVersion).toBe(LATEST_PROTOCOL_VERSION) + }) + + it('falls back to SDK latest when client omits protocolVersion', async () => { + const { LATEST_PROTOCOL_VERSION } = await import('@modelcontextprotocol/sdk/types.js') + const body = await callInitialize(undefined) + expect(body.result.protocolVersion).toBe(LATEST_PROTOCOL_VERSION) + }) + }) }) diff --git a/apps/sim/app/api/mcp/serve/[serverId]/route.ts b/apps/sim/app/api/mcp/serve/[serverId]/route.ts index 702c9a57cf4..d876dcd0ef2 100644 --- a/apps/sim/app/api/mcp/serve/[serverId]/route.ts +++ b/apps/sim/app/api/mcp/serve/[serverId]/route.ts @@ -11,8 +11,10 @@ import { type JSONRPCError, type JSONRPCMessage, type JSONRPCResultResponse, + LATEST_PROTOCOL_VERSION, type ListToolsResult, type RequestId, + SUPPORTED_PROTOCOL_VERSIONS, type Tool, } from '@modelcontextprotocol/sdk/types.js' import { db } from '@sim/db' @@ -36,6 +38,17 @@ import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' const logger = createLogger('WorkflowMcpServeAPI') +function negotiateProtocolVersion(rpcParams: unknown): string { + const requested = + rpcParams && typeof rpcParams === 'object' && 'protocolVersion' in rpcParams + ? (rpcParams as { protocolVersion?: unknown }).protocolVersion + : undefined + if (typeof requested === 'string' && SUPPORTED_PROTOCOL_VERSIONS.includes(requested)) { + return requested + } + return LATEST_PROTOCOL_VERSION +} + export const dynamic = 'force-dynamic' interface RouteParams { @@ -214,7 +227,7 @@ export const POST = withRouteHandler( switch (method) { case 'initialize': { const result: InitializeResult = { - protocolVersion: '2024-11-05', + protocolVersion: negotiateProtocolVersion(rpcParams), capabilities: { tools: {} }, serverInfo: { name: server.name, version: '1.0.0' }, } diff --git a/apps/sim/hooks/queries/mcp.ts b/apps/sim/hooks/queries/mcp.ts index 9f483a4fef7..b87ec642ee0 100644 --- a/apps/sim/hooks/queries/mcp.ts +++ b/apps/sim/hooks/queries/mcp.ts @@ -57,9 +57,7 @@ export const mcpKeys = { export type { McpServer } -/** - * Input for creating/updating an MCP server (distinct from McpServerConfig in types.ts) - */ +/** Wire shape for create/update; distinct from runtime McpServerConfig. */ export interface McpServerInput { name: string transport: McpTransport @@ -265,11 +263,7 @@ export function useCreateMcpServer() { }) } -/** - * Result of `useStartMcpOauth`. When `popup` is set, the caller should wait - * for it to close (or for the `mcp-oauth` postMessage) before clearing any - * "connecting" UI state. - */ +/** On `redirect`, the caller must wait for `popup.closed` or the `mcp-oauth` postMessage. */ export type StartMcpOauthMutationResult = | { status: 'redirect'; popup: Window } | { status: 'already_authorized' } @@ -464,13 +458,7 @@ const sseConnections: Map = ((globalThis as Record)[SSE_KEY] as Map) ?? ((globalThis as Record)[SSE_KEY] = new Map()) -/** - * Subscribe to MCP tool-change SSE events for a workspace. - * On each `tools_changed` event, invalidates the relevant React Query caches - * so the UI refreshes automatically. - * - * Invalidates both external MCP server keys and workflow MCP server keys. - */ +/** Subscribes to `tools_changed` SSE events and invalidates the affected query keys. */ export function useMcpToolsEvents(workspaceId: string) { const queryClient = useQueryClient() @@ -598,17 +586,11 @@ export function useMcpServerTest() { } } -/** - * Fetch allowed MCP domains (admin-configured allowlist) - */ async function fetchAllowedMcpDomains(signal?: AbortSignal): Promise { const data = await requestJson(getAllowedMcpDomainsContract, { signal }) return data.allowedMcpDomains ?? null } -/** - * Hook to fetch allowed MCP domains - */ export function useAllowedMcpDomains() { return useQuery({ queryKey: mcpKeys.allowedDomains(), diff --git a/apps/sim/lib/mcp/client.ts b/apps/sim/lib/mcp/client.ts index 9f3c36d00a5..ca2b26724fa 100644 --- a/apps/sim/lib/mcp/client.ts +++ b/apps/sim/lib/mcp/client.ts @@ -1,18 +1,10 @@ -/** - * MCP (Model Context Protocol) Client - * - * Implements the client side of MCP protocol with support for: - * - Streamable HTTP transport (MCP 2025-06-18) - * - Tool execution and discovery - * - Session management and protocol version negotiation - * - Custom security/consent layer - */ - import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { + LATEST_PROTOCOL_VERSION, type ListToolsResult, + SUPPORTED_PROTOCOL_VERSIONS, type Tool, ToolListChangedNotificationSchema, } from '@modelcontextprotocol/sdk/types.js' @@ -50,12 +42,6 @@ export class McpClient { private authProvider?: McpClientOptions['authProvider'] private isConnected = false - private static readonly SUPPORTED_VERSIONS = [ - '2025-06-18', // Latest stable with elicitation and OAuth 2.1 - '2025-03-26', // Streamable HTTP support - '2024-11-05', // Initial stable release - ] - constructor(options: McpClientOptions) { this.config = options.config this.securityPolicy = options.securityPolicy ?? { @@ -135,9 +121,6 @@ export class McpClient { } } - /** - * Disconnect from MCP server - */ async disconnect(): Promise { logger.info(`Disconnecting from MCP server: ${this.config.name}`) @@ -152,16 +135,10 @@ export class McpClient { logger.info(`Disconnected from MCP server: ${this.config.name}`) } - /** - * Get current connection status - */ getStatus(): McpConnectionStatus { return { ...this.connectionStatus } } - /** - * List all available tools from the server - */ async listTools(): Promise { if (!this.isConnected) { throw new McpConnectionError('Not connected to server', this.config.name) @@ -190,9 +167,6 @@ export class McpClient { } } - /** - * Execute a tool on the MCP server - */ async callTool(toolCall: McpToolCall): Promise { if (!this.isConnected) { throw new McpConnectionError('Not connected to server', this.config.name) @@ -237,10 +211,6 @@ export class McpClient { } } - /** - * Ping the server to check if it's still alive and responsive - * Per MCP spec: servers should respond to ping requests - */ async ping(): Promise<{ _meta?: Record }> { if (!this.isConnected) { throw new McpConnectionError('Not connected to server', this.config.name) @@ -257,18 +227,11 @@ export class McpClient { } } - /** - * Check if the server declared `capabilities.tools.listChanged: true` during initialization. - */ hasListChangedCapability(): boolean { return !!this.client.getServerCapabilities()?.tools?.listChanged } - /** - * Register a callback to be invoked when the underlying transport closes. - * Used by the connection manager for reconnection logic. - * Chains with the SDK's internal onclose handler so it still performs its cleanup. - */ + /** Chains with the SDK's internal onclose handler so its cleanup still runs. */ onClose(callback: () => void): void { const existingHandler = this.transport.onclose this.transport.onclose = () => { @@ -277,26 +240,17 @@ export class McpClient { } } - /** - * Get server configuration - */ getConfig(): McpServerConfig { return { ...this.config } } - /** - * Get version information for this client - */ static getVersionInfo(): McpVersionInfo { return { - supported: [...McpClient.SUPPORTED_VERSIONS], - preferred: McpClient.SUPPORTED_VERSIONS[0], + supported: [...SUPPORTED_PROTOCOL_VERSIONS], + preferred: LATEST_PROTOCOL_VERSION, } } - /** - * Get the negotiated protocol version for this connection - */ getNegotiatedVersion(): string | undefined { const serverVersion = this.client.getServerVersion() return typeof serverVersion === 'string' ? serverVersion : undefined @@ -306,9 +260,6 @@ export class McpClient { return this.transport.sessionId } - /** - * Request user consent for tool execution - */ async requestConsent(consentRequest: McpConsentRequest): Promise { if (!this.securityPolicy.requireConsent) { return { granted: true, auditId: `audit-${Date.now()}` } diff --git a/apps/sim/lib/mcp/oauth/storage.test.ts b/apps/sim/lib/mcp/oauth/storage.test.ts index 95c7ae853c0..61455b36135 100644 --- a/apps/sim/lib/mcp/oauth/storage.test.ts +++ b/apps/sim/lib/mcp/oauth/storage.test.ts @@ -11,11 +11,27 @@ import { } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' +const { mockAcquireLock, mockReleaseLock, mockExtendLock } = vi.hoisted(() => ({ + mockAcquireLock: vi.fn(), + mockReleaseLock: vi.fn(), + mockExtendLock: vi.fn(), +})) + vi.mock('@sim/db', () => dbChainMock) vi.mock('@sim/db/schema', () => schemaMock) vi.mock('@/lib/core/security/encryption', () => encryptionMock) +vi.mock('@/lib/core/config/redis', () => ({ + acquireLock: mockAcquireLock, + releaseLock: mockReleaseLock, + extendLock: mockExtendLock, +})) -import { getOrCreateOauthRow, loadOauthRow, setOauthRowUser } from './storage' +import { + getOrCreateOauthRow, + loadOauthRow, + setOauthRowUser, + withMcpOauthRefreshLock, +} from './storage' describe('MCP OAuth storage', () => { beforeEach(() => { @@ -92,3 +108,149 @@ describe('MCP OAuth storage', () => { ) }) }) + +describe('withMcpOauthRefreshLock', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAcquireLock.mockReset() + mockReleaseLock.mockReset() + mockExtendLock.mockReset() + mockReleaseLock.mockResolvedValue(true) + mockExtendLock.mockResolvedValue(true) + }) + + it('serializes concurrent in-process callers, each running its own fn()', async () => { + mockAcquireLock.mockResolvedValue(true) + let active = 0 + let maxActive = 0 + const fn = vi.fn(async () => { + active++ + maxActive = Math.max(maxActive, active) + await new Promise((r) => setTimeout(r, 1)) + active-- + return 'tokens' + }) + + const results = await Promise.all([ + withMcpOauthRefreshLock('row-serial', fn), + withMcpOauthRefreshLock('row-serial', fn), + withMcpOauthRefreshLock('row-serial', fn), + ]) + + expect(results).toEqual(['tokens', 'tokens', 'tokens']) + // Each caller gets its own fn() invocation — critical because fn() returns + // a stateful McpClient that can't be shared across consumers. + expect(fn).toHaveBeenCalledTimes(3) + // But never two at the same time within a process. + expect(maxActive).toBe(1) + expect(mockAcquireLock).toHaveBeenCalledTimes(3) + expect(mockReleaseLock).toHaveBeenCalledTimes(3) + }) + + it('serializes cross-process callers: follower polls until leader releases', async () => { + // First acquire fails (another process holds it), second succeeds. + mockAcquireLock.mockResolvedValueOnce(false).mockResolvedValueOnce(true) + const fn = vi.fn(async () => 'fresh') + + const result = await withMcpOauthRefreshLock('row-mutex', fn) + + expect(result).toBe('fresh') + expect(mockAcquireLock).toHaveBeenCalledTimes(2) + expect(fn).toHaveBeenCalledTimes(1) + }) + + it('falls open when Redis is unavailable on acquire', async () => { + mockAcquireLock.mockRejectedValueOnce(new Error('Redis connection refused')) + const fn = vi.fn(async () => 'uncoordinated') + + const result = await withMcpOauthRefreshLock('row-redis-down', fn) + + expect(result).toBe('uncoordinated') + expect(fn).toHaveBeenCalledTimes(1) + expect(mockReleaseLock).not.toHaveBeenCalled() + }) + + it('releases the lock even when fn throws', async () => { + mockAcquireLock.mockResolvedValue(true) + const fn = vi.fn(async () => { + throw new Error('refresh failed') + }) + + await expect(withMcpOauthRefreshLock('row-throws', fn)).rejects.toThrow('refresh failed') + + expect(mockReleaseLock).toHaveBeenCalledTimes(1) + }) + + it('does not surface releaseLock failures to the caller', async () => { + mockAcquireLock.mockResolvedValue(true) + mockReleaseLock.mockRejectedValueOnce(new Error('release failed')) + const fn = vi.fn(async () => 'value') + + const result = await withMcpOauthRefreshLock('row-release-fail', fn) + expect(result).toBe('value') + }) + + it('uses per-row lock keys so different rows do not serialize', async () => { + mockAcquireLock.mockResolvedValue(true) + const fn = vi.fn(async () => 'ok') + + await Promise.all([withMcpOauthRefreshLock('row-a', fn), withMcpOauthRefreshLock('row-b', fn)]) + + expect(mockAcquireLock).toHaveBeenCalledTimes(2) + const keys = mockAcquireLock.mock.calls.map((c) => c[0]) + expect(keys).toContain('mcp:oauth:refresh:row-a') + expect(keys).toContain('mcp:oauth:refresh:row-b') + }) + + it('throws when the lock is held longer than the max wait (does not race)', async () => { + vi.useFakeTimers() + try { + // Acquire always fails — another process holds the lock with watchdog extension. + mockAcquireLock.mockResolvedValue(false) + const fn = vi.fn(async () => 'should-not-run') + + const pending = withMcpOauthRefreshLock('row-deadline', fn) + // Attach the rejection expectation before draining so Vitest doesn't see + // an unhandled rejection while timers advance. + const assertion = expect(pending).rejects.toThrow(/held longer than/) + await vi.advanceTimersByTimeAsync(31_000) + await assertion + expect(fn).not.toHaveBeenCalled() + } finally { + vi.useRealTimers() + } + }) + + it('extends the lock TTL while fn() is running so long refreshes do not lose the lock', async () => { + vi.useFakeTimers() + try { + mockAcquireLock.mockResolvedValue(true) + let resolveFn: (v: string) => void + const fn = vi.fn( + () => + new Promise((resolve) => { + resolveFn = resolve + }) + ) + + const pending = withMcpOauthRefreshLock('row-watchdog', fn) + + // Advance time past two extend intervals (5s + 5s = 10s). + await vi.advanceTimersByTimeAsync(11_000) + expect(mockExtendLock.mock.calls.length).toBeGreaterThanOrEqual(2) + for (const call of mockExtendLock.mock.calls) { + expect(call[0]).toBe('mcp:oauth:refresh:row-watchdog') + } + + resolveFn!('done') + await expect(pending).resolves.toBe('done') + + // Watchdog must stop once fn() settles — no more extend calls. + const extendCallsAtFinish = mockExtendLock.mock.calls.length + await vi.advanceTimersByTimeAsync(20_000) + expect(mockExtendLock.mock.calls.length).toBe(extendCallsAtFinish) + } finally { + vi.useRealTimers() + } + }) +}) diff --git a/apps/sim/lib/mcp/oauth/storage.ts b/apps/sim/lib/mcp/oauth/storage.ts index ee6ae0143ff..aca0fbf5ec6 100644 --- a/apps/sim/lib/mcp/oauth/storage.ts +++ b/apps/sim/lib/mcp/oauth/storage.ts @@ -7,8 +7,10 @@ import { db } from '@sim/db' import { mcpServerOauth } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' -import { generateId } from '@sim/utils/id' +import { sleep } from '@sim/utils/helpers' +import { generateId, generateShortId } from '@sim/utils/id' import { and, eq, gt } from 'drizzle-orm' +import { acquireLock, extendLock, releaseLock } from '@/lib/core/config/redis' import { decryptSecret, encryptSecret } from '@/lib/core/security/encryption' const logger = createLogger('McpOauthStorage') @@ -227,23 +229,94 @@ export async function clearState(rowId: string): Promise { } /** - * Per-process serialization for an OAuth row. Refresh tokens rotate (RFC 6749 §6, - * MCP §2.3.3), so two concurrent refreshes against the same row would race and one - * would receive `invalid_grant`, wiping the credentials. We serialize SDK calls - * that may trigger a refresh on a per-row basis. + * Serialize OAuth row access across all callers, in-process AND across + * processes. Refresh tokens rotate (RFC 6749 §6, MCP §2.3.3), so two concurrent + * refreshes against the same row would race and one would receive + * `invalid_grant`, wiping credentials. + * + * Two-tier serialization (each caller runs its OWN `fn()` — callers consume + * `McpClient` instances that can't be shared, unlike a scalar access token): + * 1) In-process: per-row Promise chain. Concurrent callers queue; each + * runs `fn()` after the previous settles. + * 2) Cross-process: Redis mutex (`acquireLock` / `releaseLock`) with a TTL + * watchdog that periodically extends the lock while `fn()` runs, so + * long-running refreshes don't drop the lock and let another process + * race onto the same refresh. + * + * Falls open if Redis is unavailable — `acquireLock` no-ops, but in-process + * serialization still holds within a single Node process. */ -const refreshLocks = new Map>() +const REFRESH_LOCK_TTL_SEC = 15 +const REFRESH_LOCK_EXTEND_INTERVAL_MS = 5_000 +const REFRESH_POLL_INTERVAL_MS = 100 +const REFRESH_MAX_WAIT_MS = 30_000 + +const inflightChains = new Map>() export async function withMcpOauthRefreshLock(rowId: string, fn: () => Promise): Promise { - const prev = refreshLocks.get(rowId) ?? Promise.resolve() - // Wait for the predecessor to settle (success or failure), discard its - // value/error, then run fn. Each caller awaits its own fn's outcome — errors - // do not propagate across callers in the chain. - const next = prev.catch(() => undefined).then(() => fn()) - refreshLocks.set(rowId, next) + const lockKey = `mcp:oauth:refresh:${rowId}` + const prev = inflightChains.get(lockKey) ?? Promise.resolve() + const next = prev.catch(() => undefined).then(() => runWithRedisMutex(lockKey, rowId, fn)) + inflightChains.set(lockKey, next) const cleanup = () => { - if (refreshLocks.get(rowId) === next) refreshLocks.delete(rowId) + if (inflightChains.get(lockKey) === next) inflightChains.delete(lockKey) } next.then(cleanup, cleanup) - return next + return next as Promise +} + +async function runWithRedisMutex( + lockKey: string, + rowId: string, + fn: () => Promise +): Promise { + const ownerToken = generateShortId() + const deadline = Date.now() + REFRESH_MAX_WAIT_MS + + while (true) { + let acquired = false + try { + acquired = await acquireLock(lockKey, ownerToken, REFRESH_LOCK_TTL_SEC) + } catch (error) { + logger.warn('Redis unavailable, running OAuth flow uncoordinated', { + rowId, + error: toError(error).message, + }) + return fn() + } + + if (acquired) { + const watchdog = setInterval(() => { + extendLock(lockKey, ownerToken, REFRESH_LOCK_TTL_SEC).catch((error) => { + logger.warn('Refresh lock extend failed', { + rowId, + error: toError(error).message, + }) + }) + }, REFRESH_LOCK_EXTEND_INTERVAL_MS) + try { + return await fn() + } finally { + clearInterval(watchdog) + await releaseLock(lockKey, ownerToken).catch((error) => { + logger.warn('Refresh lock release failed (will expire via TTL)', { + rowId, + error: toError(error).message, + }) + }) + } + } + + if (Date.now() >= deadline) { + // Lock still held by another process AND its watchdog is keeping it + // alive — falling open would let us refresh concurrently and race the + // rotating refresh token. Throw and let the caller decide whether to + // retry; the Redis-down path remains the only branch that runs `fn()` + // uncoordinated (no coordination available there). + throw new Error( + `MCP OAuth refresh lock for ${rowId} held longer than ${REFRESH_MAX_WAIT_MS}ms` + ) + } + await sleep(REFRESH_POLL_INTERVAL_MS) + } } diff --git a/apps/sim/lib/mcp/pinned-fetch.test.ts b/apps/sim/lib/mcp/pinned-fetch.test.ts index 3237ae4fe44..8a4c27be0df 100644 --- a/apps/sim/lib/mcp/pinned-fetch.test.ts +++ b/apps/sim/lib/mcp/pinned-fetch.test.ts @@ -3,34 +3,41 @@ */ import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockAgent, mockCreatePinnedLookup, mockUndiciFetch, capturedAgentOptions } = vi.hoisted( - () => { +const { mockAgent, mockCreatePinnedLookup, mockUndiciFetch, capturedAgentOptions, agentCloses } = + vi.hoisted(() => { const capturedAgentOptions: unknown[] = [] + const agentCloses: unknown[] = [] class MockAgent { constructor(options: unknown) { capturedAgentOptions.push(options) } + close() { + agentCloses.push(this) + return Promise.resolve() + } } return { mockAgent: MockAgent, mockCreatePinnedLookup: vi.fn(), mockUndiciFetch: vi.fn(), capturedAgentOptions, + agentCloses, } - } -) + }) vi.mock('undici', () => ({ Agent: mockAgent, fetch: mockUndiciFetch })) vi.mock('@/lib/core/security/input-validation.server', () => ({ createPinnedLookup: mockCreatePinnedLookup, })) -import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' +import { __resetPinnedAgentsForTests, createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' describe('createMcpPinnedFetch', () => { beforeEach(() => { vi.clearAllMocks() capturedAgentOptions.length = 0 + agentCloses.length = 0 + __resetPinnedAgentsForTests() mockCreatePinnedLookup.mockReturnValue('pinned-lookup-fn') mockUndiciFetch.mockResolvedValue(new Response('ok')) }) @@ -73,7 +80,7 @@ describe('createMcpPinnedFetch', () => { expect(init.dispatcher).toBeInstanceOf(mockAgent) }) - it('reuses the same dispatcher across calls (one Agent per fetch instance)', async () => { + it('reuses the same dispatcher across calls within a fetch instance', async () => { const fetchLike = createMcpPinnedFetch('203.0.113.10') await fetchLike('https://example.com/a') await fetchLike('https://example.com/b') @@ -82,4 +89,39 @@ describe('createMcpPinnedFetch', () => { const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher expect(d1).toBe(d2) }) + + it('pools agents by resolvedIP across createMcpPinnedFetch calls', async () => { + const a = createMcpPinnedFetch('203.0.113.10') + const b = createMcpPinnedFetch('203.0.113.10') + await a('https://example.com/a') + await b('https://example.com/b') + expect(capturedAgentOptions).toHaveLength(1) + const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + expect(d1).toBe(d2) + }) + + it('creates separate agents for different resolved IPs', async () => { + const a = createMcpPinnedFetch('203.0.113.10') + const b = createMcpPinnedFetch('198.51.100.20') + await a('https://example.com/a') + await b('https://example.com/b') + expect(capturedAgentOptions).toHaveLength(2) + const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + expect(d1).not.toBe(d2) + }) + + it('does not close evicted agents — captured closures keep working', async () => { + // Build an early closure whose agent will get evicted by later IPs. + const earlyClient = createMcpPinnedFetch('10.0.0.1') + // Fill the cache past its 64-entry limit so the early entry is evicted. + for (let i = 0; i < 64; i++) createMcpPinnedFetch(`10.1.${Math.floor(i / 256)}.${i % 256}`) + + // Eviction must NOT have closed any agents. + expect(agentCloses).toHaveLength(0) + // The early closure's captured dispatcher is still callable. + await earlyClient('https://example.com/still-works') + expect(mockUndiciFetch).toHaveBeenCalledTimes(1) + }) }) diff --git a/apps/sim/lib/mcp/pinned-fetch.ts b/apps/sim/lib/mcp/pinned-fetch.ts index 798de5710e6..236518d13ec 100644 --- a/apps/sim/lib/mcp/pinned-fetch.ts +++ b/apps/sim/lib/mcp/pinned-fetch.ts @@ -3,29 +3,47 @@ import { Agent, type RequestInit as UndiciRequestInit, fetch as undiciFetch } fr import { createPinnedLookup } from '@/lib/core/security/input-validation.server' /** - * Creates a FetchLike that pins all outbound HTTP connections to a pre-resolved - * IP address. Used by the MCP transport to prevent DNS-rebinding (TOCTOU) - * attacks: validation performs DNS once and confirms the IP is allowed; this - * fetch then forces every subsequent request (initial POST, SSE GET, redirects) - * to use that same IP, regardless of what the hostname now resolves to. + * Pins outbound HTTP connections to a pre-resolved IP to prevent DNS-rebinding + * between URL validation and connection. Hostname is preserved so TLS SNI and + * the Host header still match the certificate. * - * Uses undici's `fetch` directly so the `dispatcher` option is part of the - * real type contract — not a cast that would silently break if a future - * runtime swapped out the implementation. - * - * The original hostname is preserved on the request so TLS SNI and the Host - * header continue to match the certificate. + * Agents are pooled by `resolvedIP` so back-to-back calls to the same server + * reuse the same keep-alive connection pool instead of opening a fresh TCP + + * TLS connection per McpClient instance. */ +const MAX_POOLED_AGENTS = 64 +const pinnedAgents = new Map() + +function getPinnedAgent(resolvedIP: string): Agent { + const existing = pinnedAgents.get(resolvedIP) + if (existing) { + // LRU touch — re-insert to mark as most recently used. + pinnedAgents.delete(resolvedIP) + pinnedAgents.set(resolvedIP, existing) + return existing + } + if (pinnedAgents.size >= MAX_POOLED_AGENTS) { + // Drop the oldest entry WITHOUT closing it — existing `createMcpPinnedFetch` + // closures may still hold a reference and have in-flight requests. The + // dispatcher is GC'd (and its sockets cleaned up) when the last closure + // releases it; undici closes idle keep-alive connections after its own + // timeout (default 4s). + const oldestKey = pinnedAgents.keys().next().value + if (oldestKey !== undefined) pinnedAgents.delete(oldestKey) + } + const agent = new Agent({ connect: { lookup: createPinnedLookup(resolvedIP) } }) + pinnedAgents.set(resolvedIP, agent) + return agent +} + +export function __resetPinnedAgentsForTests(): void { + pinnedAgents.clear() +} + export function createMcpPinnedFetch(resolvedIP: string): FetchLike { - const dispatcher = new Agent({ - connect: { lookup: createPinnedLookup(resolvedIP) }, - }) + const dispatcher = getPinnedAgent(resolvedIP) return (async (url, init) => { - // DOM `RequestInit` and undici's `RequestInit` are structurally compatible - // at runtime (Node's global fetch IS undici) but differ in TS types. - // Cast the init through unknown to bridge the typing without losing the - // critical `dispatcher` typing on the call itself. const undiciInit: UndiciRequestInit = { // double-cast-allowed: DOM RequestInit and undici RequestInit are structurally compatible at runtime (Node's global fetch IS undici) but the TS types differ ...(init as unknown as UndiciRequestInit), diff --git a/apps/sim/lib/mcp/service.ts b/apps/sim/lib/mcp/service.ts index 113998c61aa..4ef53382483 100644 --- a/apps/sim/lib/mcp/service.ts +++ b/apps/sim/lib/mcp/service.ts @@ -1,7 +1,3 @@ -/** - * MCP Service - Clean stateless service for MCP operations - */ - import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { StreamableHTTPError } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { db } from '@sim/db' @@ -100,19 +96,12 @@ class McpService { } } - /** - * Dispose of the service and cleanup resources - */ dispose(): void { this.unsubscribeConnectionManager?.() this.cacheAdapter.dispose() logger.info('MCP Service disposed') } - /** - * Resolve environment variables in server config. - * Uses shared utility with strict mode (throws on missing vars). - */ private async resolveConfigEnvVars( config: McpServerConfig, userId: string, @@ -126,9 +115,6 @@ class McpService { return { config: resolvedConfig, resolvedIP } } - /** - * Get server configuration from database - */ private async getServerConfig( serverId: string, workspaceId: string @@ -171,9 +157,6 @@ class McpService { } } - /** - * Get all enabled servers for a workspace - */ private async getWorkspaceServers(workspaceId: string): Promise { const whereConditions = [ eq(mcpServers.workspaceId, workspaceId), @@ -205,9 +188,6 @@ class McpService { .filter((config) => isMcpDomainAllowed(config.url)) } - /** - * Create and connect to an MCP client - */ private async createClient( config: McpServerConfig, resolvedIP: string | null, @@ -320,12 +300,7 @@ class McpService { throw new Error(`Failed to execute tool ${toolCall.name} after ${maxRetries} attempts`) } - /** - * Detects an expired or unknown `Mcp-Session-Id` so the caller can retry. - * Per MCP spec, the server returns HTTP 404 for an unknown session id and - * may return 400 when the session header is malformed; the SDK surfaces - * both as `StreamableHTTPError` with a typed numeric `code` field. - */ + /** MCP spec: server returns 404 for unknown session id, 400 for malformed header. */ private isSessionError(error: unknown): boolean { if (error instanceof StreamableHTTPError) { return error.code === 404 || error.code === 400 @@ -333,9 +308,6 @@ class McpService { return false } - /** - * Update server connection status after discovery attempt - */ private async updateServerStatus( serverId: string, workspaceId: string, @@ -448,9 +420,6 @@ class McpService { } } - /** - * Discover tools from all workspace servers - */ async discoverTools( userId: string, workspaceId: string, @@ -601,9 +570,9 @@ class McpService { // Await cache writes so a follow-up discoverTools sees consistent state. await Promise.allSettled(cacheWrites) - Promise.allSettled(deferredSideEffects).catch((err) => { - logger.error(`[${requestId}] Error in deferred discovery work:`, err) - }) + // Each deferred side-effect self-logs failures, so we just mark the + // promises as handled to avoid unhandled-rejection warnings. + for (const p of deferredSideEffects) p.catch(() => {}) if (mcpConnectionManager) { for (const conn of liveConnections) { @@ -744,9 +713,6 @@ class McpService { throw new Error(`Failed to discover tools from server ${serverId} after ${maxRetries} attempts`) } - /** - * Get server summaries for a user - */ async getServerSummaries(userId: string, workspaceId: string): Promise { const requestId = generateRequestId() diff --git a/apps/sim/lib/mcp/types.ts b/apps/sim/lib/mcp/types.ts index 1cb4ad3e782..be506fd5b07 100644 --- a/apps/sim/lib/mcp/types.ts +++ b/apps/sim/lib/mcp/types.ts @@ -1,15 +1,8 @@ -/** - * MCP Types - for connecting to external MCP servers - */ +import type { Tool } from '@modelcontextprotocol/sdk/types.js' export type McpTransport = 'streamable-http' -/** - * Auth mode for an outbound MCP server connection. - * - `none` — server requires no auth. - * - `headers` — static header map (legacy / API-token / bearer). - * - `oauth` — OAuth 2.1 + PKCE via the SDK's authProvider, persisted per workspace server. - */ +/** `oauth` uses the SDK's authProvider; `headers` is a static map; `none` is unauthenticated. */ export type McpAuthType = 'none' | 'headers' | 'oauth' export interface McpServerStatusConfig { @@ -72,10 +65,6 @@ export interface McpSecurityPolicy { auditLevel: 'none' | 'basic' | 'detailed' } -/** - * JSON Schema property definition for tool parameters. - * Follows JSON Schema specification with description support. - */ export interface McpToolSchemaProperty { type?: string | string[] description?: string @@ -87,10 +76,7 @@ export interface McpToolSchemaProperty { [key: string]: unknown } -/** - * JSON Schema for tool input parameters. - * Aligns with MCP SDK's Tool.inputSchema structure. - */ +/** Typed view of the SDK's `Tool.inputSchema` (which is `Record`). */ export interface McpToolSchema { type: 'object' properties?: Record @@ -99,13 +85,8 @@ export interface McpToolSchema { [key: string]: unknown } -/** - * MCP Tool with server context. - * Extends the SDK's Tool type with app-specific server tracking. - */ -export interface McpTool { - name: string - description?: string +/** SDK `Tool` plus the server context Sim tracks. */ +export interface McpTool extends Pick { inputSchema: McpToolSchema serverId: string serverName: string @@ -209,9 +190,6 @@ export interface McpClientOptions { authProvider?: import('@modelcontextprotocol/sdk/client/auth.js').OAuthClientProvider } -/** - * Event emitted by the connection manager when a server's tools change. - */ export interface ToolsChangedEvent { serverId: string serverName: string @@ -219,9 +197,6 @@ export interface ToolsChangedEvent { timestamp: number } -/** - * State of a managed persistent connection. - */ export interface ManagedConnectionState { serverId: string serverName: string @@ -233,9 +208,6 @@ export interface ManagedConnectionState { lastActivity: number } -/** - * Event emitted when workflow CRUD modifies a workflow MCP server's tools. - */ export interface WorkflowToolsChangedEvent { serverId: string workspaceId: string diff --git a/apps/sim/lib/mcp/utils.test.ts b/apps/sim/lib/mcp/utils.test.ts index 29aded1358e..30990f62d4a 100644 --- a/apps/sim/lib/mcp/utils.test.ts +++ b/apps/sim/lib/mcp/utils.test.ts @@ -1,5 +1,7 @@ +import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { describe, expect, it } from 'vitest' import { DEFAULT_EXECUTION_TIMEOUT_MS } from '@/lib/core/execution-limits' +import { McpConnectionError, McpOauthAuthorizationRequiredError } from '@/lib/mcp/types' import { categorizeError, createMcpToolId, @@ -304,6 +306,32 @@ describe('categorizeError', () => { expect(result.message).toBe('Unknown error occurred') }) + it.concurrent('returns 401 for McpOauthAuthorizationRequiredError via instanceof', () => { + const error = new McpOauthAuthorizationRequiredError('mcp-a', 'A') + const result = categorizeError(error) + expect(result.status).toBe(401) + expect(result.message).toBe('Authentication required') + }) + + it.concurrent('returns 401 for SDK UnauthorizedError via instanceof', () => { + const error = new UnauthorizedError('token expired') + const result = categorizeError(error) + expect(result.status).toBe(401) + }) + + it.concurrent('returns 503 for McpConnectionError with cooldown message', () => { + const error = new McpConnectionError('Server in cooldown — try again shortly.', 'mcp-a') + const result = categorizeError(error) + expect(result.status).toBe(503) + }) + + it.concurrent('returns 502 for other McpConnectionError', () => { + const error = new McpConnectionError('connect ECONNREFUSED', 'mcp-a') + const result = categorizeError(error) + expect(result.status).toBe(502) + expect(result.message).toBe('Connection failed') + }) + it.concurrent('returns 500 for null', () => { const result = categorizeError(null) expect(result.status).toBe(500) diff --git a/apps/sim/lib/mcp/utils.ts b/apps/sim/lib/mcp/utils.ts index 6364cafb111..e5c2f9db22f 100644 --- a/apps/sim/lib/mcp/utils.ts +++ b/apps/sim/lib/mcp/utils.ts @@ -1,6 +1,11 @@ +import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { NextResponse } from 'next/server' import { DEFAULT_EXECUTION_TIMEOUT_MS } from '@/lib/core/execution-limits' -import type { McpApiResponse } from '@/lib/mcp/types' +import { + type McpApiResponse, + McpConnectionError, + McpOauthAuthorizationRequiredError, +} from '@/lib/mcp/types' import { isMcpTool, MCP } from '@/executor/constants' export const MCP_CONSTANTS = { @@ -137,28 +142,36 @@ export function categorizeError(error: unknown): { message: string; status: numb return { message: 'Unknown error occurred', status: 500 } } + // Typed dispatch first — our own classes carry definitive intent. + if (error instanceof McpOauthAuthorizationRequiredError || error instanceof UnauthorizedError) { + return { message: 'Authentication required', status: 401 } + } + if (error instanceof McpConnectionError) { + if (error.message.toLowerCase().includes('cooldown')) { + return { message: 'Server temporarily unavailable', status: 503 } + } + return { message: 'Connection failed', status: 502 } + } + + // Fall back to substring matching for SDK / third-party errors we don't + // own a typed class for. const msg = error.message.toLowerCase() if (msg.includes('timeout')) { return { message: 'Request timed out', status: 408 } } - if (msg.includes('cooldown')) { return { message: 'Server temporarily unavailable', status: 503 } } - if (msg.includes('not found') || msg.includes('not accessible')) { return { message: 'Resource not found', status: 404 } } - if (msg.includes('authentication') || msg.includes('unauthorized')) { return { message: 'Authentication required', status: 401 } } - if (msg.includes('invalid') || msg.includes('missing required') || msg.includes('validation')) { return { message: 'Invalid request parameters', status: 400 } } - return { message: 'Internal server error', status: 500 } }