Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions apps/sim/lib/core/security/input-validation.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type { LookupFunction } from 'net'
import { createLogger } from '@sim/logger'
import { toError } from '@sim/utils/errors'
import * as ipaddr from 'ipaddr.js'
import { Agent, type RequestInit as UndiciRequestInit, fetch as undiciFetch } from 'undici'
import { isHosted } from '@/lib/core/config/feature-flags'
import { type ValidationResult, validateExternalUrl } from '@/lib/core/security/input-validation'
import { PayloadSizeLimitError } from '@/lib/core/utils/stream-limits'
Expand Down Expand Up @@ -400,6 +401,40 @@ export function createPinnedLookup(resolvedIP: string): LookupFunction {
}
}

/**
* Builds a standard `fetch`-compatible function that pins every outbound
* connection to `resolvedIP`, preventing DNS-rebinding (TOCTOU) between URL
* validation and connection. The original hostname is preserved for TLS SNI and
* the `Host` header so it still matches the certificate. This is the single
* source of truth for pinned outbound fetches — both the LLM providers and the
* MCP transport consume it.
*
* Pass the returned function as the `fetch` option to the OpenAI/Anthropic SDKs
* (or call it directly) after validating the URL with {@link validateUrlWithDNS}
* and capturing `resolvedIP`. Because the pinned lookup always returns
* `resolvedIP` regardless of hostname, any redirect the server returns also
* connects to the validated IP — an attacker cannot rebind a redirect target to
* an internal address.
*
* The `Agent` is captured for the lifetime of the returned function, so repeated
* calls (e.g. a provider tool loop) reuse its keep-alive connections.
*/
export function createPinnedFetch(resolvedIP: string): typeof fetch {
const dispatcher = new Agent({ connect: { lookup: createPinnedLookup(resolvedIP) } })

const pinned = async (input: RequestInfo | URL, init?: RequestInit): Promise<Response> => {
// double-cast-allowed: DOM RequestInfo/URL and undici fetch input types differ but are structurally compatible at runtime (Node's global fetch IS undici)
const undiciInput = input as unknown as Parameters<typeof undiciFetch>[0]
// double-cast-allowed: DOM RequestInit and undici RequestInit are structurally compatible at runtime but the TS types differ
const undiciInit: UndiciRequestInit = { ...(init as unknown as UndiciRequestInit), dispatcher }
const response = await undiciFetch(undiciInput, undiciInit)
// double-cast-allowed: undici Response and DOM Response are structurally compatible at runtime
return response as unknown as Response
}

return pinned
}

/**
* Performs a fetch with IP pinning to prevent DNS rebinding attacks.
* Uses the pre-resolved IP address while preserving the original hostname for TLS SNI.
Expand Down
126 changes: 126 additions & 0 deletions apps/sim/lib/core/security/pinned-fetch.server.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/**
* @vitest-environment node
*/
import { featureFlagsMock } from '@sim/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'

const { mockAgent, mockUndiciFetch, capturedAgentOptions, agentCloses } = vi.hoisted(() => {
const capturedAgentOptions: unknown[] = []
const agentCloses: unknown[] = []
class MockAgent {
constructor(options: unknown) {
capturedAgentOptions.push(options)
}
close() {
agentCloses.push(this)
return Promise.resolve()
}
}
return {
mockAgent: MockAgent,
mockUndiciFetch: vi.fn(),
capturedAgentOptions,
agentCloses,
}
})

vi.mock('undici', () => ({ Agent: mockAgent, fetch: mockUndiciFetch }))
vi.mock('@/lib/core/config/feature-flags', () => featureFlagsMock)

import { createPinnedFetch } from '@/lib/core/security/input-validation.server'

type LookupCallback = (err: Error | null, address: string, family: number) => void
type PinnedLookup = (hostname: string, options: { all?: boolean }, callback: LookupCallback) => void

describe('createPinnedFetch', () => {
beforeEach(() => {
vi.clearAllMocks()
capturedAgentOptions.length = 0
agentCloses.length = 0
mockUndiciFetch.mockResolvedValue(new Response('ok'))
})

it('builds an undici Agent whose pinned lookup always resolves to the validated IP', async () => {
createPinnedFetch('203.0.113.10')

expect(capturedAgentOptions).toHaveLength(1)
const { connect } = capturedAgentOptions[0] as { connect: { lookup: PinnedLookup } }
expect(typeof connect.lookup).toBe('function')

const resolved = await new Promise<{ address: string; family: number }>((resolve) => {
connect.lookup('rebind.attacker.tld', {}, (_err, address, family) =>
resolve({ address, family })
)
})
expect(resolved).toEqual({ address: '203.0.113.10', family: 4 })
})

it('uses IPv6 family when the validated IP is IPv6', async () => {
createPinnedFetch('2606:4700:4700::1111')
const { connect } = capturedAgentOptions[0] as { connect: { lookup: PinnedLookup } }
const resolved = await new Promise<{ address: string; family: number }>((resolve) => {
connect.lookup('example.com', {}, (_err, address, family) => resolve({ address, family }))
})
expect(resolved).toEqual({ address: '2606:4700:4700::1111', family: 6 })
})

it('forwards the pinned dispatcher on every call while preserving init options', async () => {
const pinned = createPinnedFetch('203.0.113.10')
const controller = new AbortController()

await pinned('https://myresource.openai.azure.com/openai/v1/responses', {
method: 'POST',
headers: { 'api-key': 'secret' },
body: '{}',
signal: controller.signal,
})

expect(mockUndiciFetch).toHaveBeenCalledTimes(1)
const [url, init] = mockUndiciFetch.mock.calls[0]
expect(url).toBe('https://myresource.openai.azure.com/openai/v1/responses')
const typedInit = init as RequestInit & { dispatcher?: unknown }
expect(typedInit.dispatcher).toBeInstanceOf(mockAgent)
expect(typedInit.method).toBe('POST')
expect(typedInit.headers).toEqual({ 'api-key': 'secret' })
expect(typedInit.body).toBe('{}')
expect(typedInit.signal).toBe(controller.signal)
})

it('handles an undefined init by still attaching the dispatcher', async () => {
const pinned = createPinnedFetch('203.0.113.10')
await pinned('https://example.com')
const init = mockUndiciFetch.mock.calls[0][1] as { dispatcher?: unknown }
expect(init.dispatcher).toBeInstanceOf(mockAgent)
})

it('reuses one captured dispatcher across all calls of a single instance', async () => {
const pinned = createPinnedFetch('203.0.113.10')
await pinned('https://example.com/a')
await pinned('https://example.com/b')

expect(capturedAgentOptions).toHaveLength(1)
const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher
const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher
expect(d1).toBe(d2)
})

it('creates an independent dispatcher per instance', async () => {
const a = createPinnedFetch('203.0.113.10')
const b = createPinnedFetch('203.0.113.10')
await a('https://example.com/a')
await b('https://example.com/b')

expect(capturedAgentOptions).toHaveLength(2)
const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher
const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher
expect(d1).not.toBe(d2)
})

it('returns the response produced by undici fetch', async () => {
mockUndiciFetch.mockResolvedValueOnce(new Response('pong', { status: 201 }))
const pinned = createPinnedFetch('203.0.113.10')
const response = await pinned('https://example.com')
expect(response.status).toBe(201)
expect(await response.text()).toBe('pong')
})
})
4 changes: 2 additions & 2 deletions apps/sim/lib/mcp/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import {
import { createLogger } from '@sim/logger'
import { getErrorMessage } from '@sim/utils/errors'
import { getMaxExecutionTimeout } from '@/lib/core/execution-limits'
import { createPinnedFetch } from '@/lib/core/security/input-validation.server'
import { McpOauthRedirectRequired } from '@/lib/mcp/oauth'
import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch'
import {
type McpClientOptions,
McpConnectionError,
Expand Down Expand Up @@ -70,7 +70,7 @@ export class McpClient {
this.transport = new StreamableHTTPClientTransport(new URL(this.config.url), {
authProvider: useOauth ? this.authProvider : undefined,
requestInit: { headers: this.config.headers },
...(resolvedIP ? { fetch: createMcpPinnedFetch(resolvedIP) } : {}),
...(resolvedIP ? { fetch: createPinnedFetch(resolvedIP) } : {}),
})

this.client = new Client(
Expand Down
36 changes: 17 additions & 19 deletions apps/sim/lib/mcp/oauth/probe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'

const {
mockCreateMcpPinnedFetch,
mockCreateSsrfGuardedMcpFetch,
mockPinnedFetch,
mockGuardedFetch,
} = vi.hoisted(() => {
const mockPinnedFetch = vi.fn()
const mockGuardedFetch = vi.fn()
return {
mockPinnedFetch,
mockGuardedFetch,
mockCreateMcpPinnedFetch: vi.fn(() => mockPinnedFetch),
mockCreateSsrfGuardedMcpFetch: vi.fn(() => mockGuardedFetch),
}
})
const { mockCreatePinnedFetch, mockCreateSsrfGuardedMcpFetch, mockPinnedFetch, mockGuardedFetch } =
vi.hoisted(() => {
const mockPinnedFetch = vi.fn()
const mockGuardedFetch = vi.fn()
return {
mockPinnedFetch,
mockGuardedFetch,
mockCreatePinnedFetch: vi.fn(() => mockPinnedFetch),
mockCreateSsrfGuardedMcpFetch: vi.fn(() => mockGuardedFetch),
}
})

vi.mock('@/lib/core/security/input-validation.server', () => ({
createPinnedFetch: mockCreatePinnedFetch,
}))
vi.mock('@/lib/mcp/pinned-fetch', () => ({
createMcpPinnedFetch: mockCreateMcpPinnedFetch,
createSsrfGuardedMcpFetch: mockCreateSsrfGuardedMcpFetch,
}))

Expand Down Expand Up @@ -50,7 +48,7 @@ describe('detectMcpAuthType — connection pinning (SSRF / DNS-rebinding)', () =
const authType = await detectMcpAuthType('https://rebind.example.com/mcp', '203.0.113.10')

expect(authType).toBe('none')
expect(mockCreateMcpPinnedFetch).toHaveBeenCalledWith('203.0.113.10')
expect(mockCreatePinnedFetch).toHaveBeenCalledWith('203.0.113.10')
expect(mockCreateSsrfGuardedMcpFetch).not.toHaveBeenCalled()
expect(mockPinnedFetch).toHaveBeenCalledTimes(1)
// The unpinned global fetch must never be used — that was the SSRF sink.
Expand All @@ -64,7 +62,7 @@ describe('detectMcpAuthType — connection pinning (SSRF / DNS-rebinding)', () =

expect(authType).toBe('none')
expect(mockCreateSsrfGuardedMcpFetch).toHaveBeenCalledTimes(1)
expect(mockCreateMcpPinnedFetch).not.toHaveBeenCalled()
expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
expect(mockGuardedFetch).toHaveBeenCalledTimes(1)
expect(globalFetchSpy).not.toHaveBeenCalled()
})
Expand All @@ -90,7 +88,7 @@ describe('detectMcpAuthType — connection pinning (SSRF / DNS-rebinding)', () =
const authType = await detectMcpAuthType('http://example.com/mcp', '203.0.113.10')

expect(authType).toBe('headers')
expect(mockCreateMcpPinnedFetch).not.toHaveBeenCalled()
expect(mockCreatePinnedFetch).not.toHaveBeenCalled()
expect(mockCreateSsrfGuardedMcpFetch).not.toHaveBeenCalled()
expect(globalFetchSpy).not.toHaveBeenCalled()
})
Expand Down
5 changes: 3 additions & 2 deletions apps/sim/lib/mcp/oauth/probe.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { extractWWWAuthenticateParams } from '@modelcontextprotocol/sdk/client/auth.js'
import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js'
import { createLogger } from '@sim/logger'
import { createPinnedFetch } from '@/lib/core/security/input-validation.server'
import { isLoopbackHostname } from '@/lib/core/utils/urls'
import { createMcpPinnedFetch, createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch'
import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch'
import type { McpAuthType } from '@/lib/mcp/types'

const logger = createLogger('McpOauthProbe')
Expand Down Expand Up @@ -33,7 +34,7 @@ export async function detectMcpAuthType(
}

const probeFetch: FetchLike = resolvedIP
? createMcpPinnedFetch(resolvedIP)
? createPinnedFetch(resolvedIP)
: createSsrfGuardedMcpFetch()

const controller = new AbortController()
Expand Down
30 changes: 9 additions & 21 deletions apps/sim/lib/mcp/oauth/revoke.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,23 @@ const PUBLIC_SERVER_URL = 'https://mcp.attacker.com'
const PUBLIC_SERVER_IP = '203.0.113.10'

const {
MockAgent,
mockUndiciFetch,
mockValidateMcpServerSsrf,
mockDiscoverOAuthServerInfo,
mockLoadOauthRow,
mockDecryptSecret,
mockDbSelect,
} = vi.hoisted(() => {
class MockAgent {
close() {
return Promise.resolve()
}
}
return {
MockAgent,
mockUndiciFetch: vi.fn(),
mockValidateMcpServerSsrf: vi.fn(),
mockDiscoverOAuthServerInfo: vi.fn(),
mockLoadOauthRow: vi.fn(),
mockDecryptSecret: vi.fn(),
mockDbSelect: vi.fn(),
}
})
} = vi.hoisted(() => ({
mockUndiciFetch: vi.fn(),
mockValidateMcpServerSsrf: vi.fn(),
mockDiscoverOAuthServerInfo: vi.fn(),
mockLoadOauthRow: vi.fn(),
mockDecryptSecret: vi.fn(),
mockDbSelect: vi.fn(),
}))

vi.mock('undici', () => ({ Agent: MockAgent, fetch: mockUndiciFetch }))
vi.mock('@/lib/core/security/input-validation.server', () => ({
createPinnedLookup: vi.fn(() => 'pinned-lookup-fn'),
createPinnedFetch: vi.fn(() => mockUndiciFetch),
}))
vi.mock('@/lib/mcp/domain-check', () => ({
validateMcpServerSsrf: mockValidateMcpServerSsrf,
Expand All @@ -59,7 +49,6 @@ vi.mock('@sim/db', () => ({
db: { select: mockDbSelect },
}))

import { __resetPinnedAgentsForTests } from '@/lib/mcp/pinned-fetch'
import { revokeMcpOauthTokens } from './revoke'

function wireServerRow(row: Record<string, unknown>) {
Expand All @@ -74,7 +63,6 @@ function wireServerRow(row: Record<string, unknown>) {
describe('revokeMcpOauthTokens — SSRF guard', () => {
beforeEach(() => {
vi.clearAllMocks()
__resetPinnedAgentsForTests()

mockLoadOauthRow.mockResolvedValue({
tokens: { access_token: 'access-secret', refresh_token: 'refresh-secret' },
Expand Down
Loading
Loading