Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion apps/sim/app/api/a2a/agents/[agentId]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ export async function GET(request: NextRequest, { params }: { params: Promise<Ro

if (!agent.agent.isPublished) {
const auth = await checkSessionOrInternalAuth(request, { requireWorkflowId: false })
if (!auth.success) {
if (!auth.success || !auth.userId) {
return NextResponse.json({ error: 'Agent not published' }, { status: 404 })
}

const workspaceAccess = await checkWorkspaceAccess(agent.agent.workspaceId, auth.userId)
if (!workspaceAccess.exists || !workspaceAccess.hasAccess) {
return NextResponse.json({ error: 'Agent not published' }, { status: 404 })
}
}
Expand Down
17 changes: 14 additions & 3 deletions apps/sim/app/api/a2a/agents/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { sanitizeAgentName } from '@/lib/a2a/utils'
import { checkSessionOrInternalAuth } from '@/lib/auth/hybrid'
import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/persistence/utils'
import { hasValidStartBlockInState } from '@/lib/workflows/triggers/trigger-utils'
import { getWorkspaceById } from '@/lib/workspaces/permissions/utils'
import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils'

const logger = createLogger('A2AAgentsAPI')

Expand All @@ -39,10 +39,13 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ error: 'workspaceId is required' }, { status: 400 })
}

const ws = await getWorkspaceById(workspaceId)
if (!ws) {
const workspaceAccess = await checkWorkspaceAccess(workspaceId, auth.userId)
if (!workspaceAccess.exists) {
return NextResponse.json({ error: 'Workspace not found' }, { status: 404 })
}
if (!workspaceAccess.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}

const agents = await db
.select({
Expand Down Expand Up @@ -103,6 +106,14 @@ export async function POST(request: NextRequest) {
)
}

const workspaceAccess = await checkWorkspaceAccess(workspaceId, auth.userId)
if (!workspaceAccess.exists) {
return NextResponse.json({ error: 'Workspace not found' }, { status: 404 })
}
if (!workspaceAccess.canWrite) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}

const [wf] = await db
.select({
id: workflow.id,
Expand Down
119 changes: 100 additions & 19 deletions apps/sim/app/api/a2a/serve/[agentId]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server'
import { SSE_HEADERS } from '@/lib/core/utils/sse'
import { getBaseUrl } from '@/lib/core/utils/urls'
import { markExecutionCancelled } from '@/lib/execution/cancellation'
import { checkWorkspaceAccess } from '@/lib/workspaces/permissions/utils'
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
import {
A2A_ERROR_CODES,
A2A_METHODS,
Expand Down Expand Up @@ -191,6 +193,9 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R

const authSchemes = (agent.authentication as { schemes?: string[] })?.schemes || []
const requiresAuth = !authSchemes.includes('none')
let authenticatedUserId: string | null = null
let authenticatedAuthType: 'session' | 'api_key' | 'internal_jwt' | undefined
let authenticatedApiKeyType: 'personal' | 'workspace' | undefined

if (requiresAuth) {
const auth = await checkHybridAuth(request, { requireWorkflowId: false })
Expand All @@ -200,6 +205,17 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
{ status: 401 }
)
}
authenticatedUserId = auth.userId
authenticatedAuthType = auth.authType
authenticatedApiKeyType = auth.apiKeyType

const workspaceAccess = await checkWorkspaceAccess(agent.workspaceId, authenticatedUserId)
if (!workspaceAccess.exists || !workspaceAccess.hasAccess) {
return NextResponse.json(
createError(null, A2A_ERROR_CODES.AUTHENTICATION_REQUIRED, 'Access denied'),
{ status: 403 }
)
}
}

const [wf] = await db
Expand All @@ -225,34 +241,61 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
}

const { id, method, params: rpcParams } = body
const apiKey = request.headers.get('X-API-Key')
const requestApiKey = request.headers.get('X-API-Key')
const apiKey = authenticatedAuthType === 'api_key' ? requestApiKey : null
const isPersonalApiKeyCaller =
authenticatedAuthType === 'api_key' && authenticatedApiKeyType === 'personal'
const billedUserId = await getWorkspaceBilledAccountUserId(agent.workspaceId)
if (!billedUserId) {
logger.error('Unable to resolve workspace billed account for A2A execution', {
agentId: agent.id,
workspaceId: agent.workspaceId,
})
return NextResponse.json(
createError(
id,
A2A_ERROR_CODES.INTERNAL_ERROR,
'Unable to resolve billing account for this workspace'
),
{ status: 500 }
)
}
const executionUserId =
isPersonalApiKeyCaller && authenticatedUserId ? authenticatedUserId : billedUserId
Comment thread
icecrasher321 marked this conversation as resolved.

logger.info(`A2A request: ${method} for agent ${agentId}`)

switch (method) {
case A2A_METHODS.MESSAGE_SEND:
return handleMessageSend(id, agent, rpcParams as MessageSendParams, apiKey)
return handleMessageSend(id, agent, rpcParams as MessageSendParams, apiKey, executionUserId)

case A2A_METHODS.MESSAGE_STREAM:
return handleMessageStream(request, id, agent, rpcParams as MessageSendParams, apiKey)
return handleMessageStream(
request,
id,
agent,
rpcParams as MessageSendParams,
apiKey,
executionUserId
)

case A2A_METHODS.TASKS_GET:
return handleTaskGet(id, rpcParams as TaskIdParams)
return handleTaskGet(id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.TASKS_CANCEL:
return handleTaskCancel(id, rpcParams as TaskIdParams)
return handleTaskCancel(id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.TASKS_RESUBSCRIBE:
return handleTaskResubscribe(request, id, rpcParams as TaskIdParams)
return handleTaskResubscribe(request, id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.PUSH_NOTIFICATION_SET:
return handlePushNotificationSet(id, rpcParams as PushNotificationSetParams)
return handlePushNotificationSet(id, agent.id, rpcParams as PushNotificationSetParams)

case A2A_METHODS.PUSH_NOTIFICATION_GET:
return handlePushNotificationGet(id, rpcParams as TaskIdParams)
return handlePushNotificationGet(id, agent.id, rpcParams as TaskIdParams)

case A2A_METHODS.PUSH_NOTIFICATION_DELETE:
return handlePushNotificationDelete(id, rpcParams as TaskIdParams)
return handlePushNotificationDelete(id, agent.id, rpcParams as TaskIdParams)

default:
return NextResponse.json(
Expand All @@ -268,6 +311,14 @@ export async function POST(request: NextRequest, { params }: { params: Promise<R
}
}

async function getTaskForAgent(taskId: string, agentId: string) {
const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, taskId)).limit(1)
if (!task || task.agentId !== agentId) {
return null
}
return task
}

/**
* Handle message/send - Send a message (v0.3)
*/
Expand All @@ -280,7 +331,8 @@ async function handleMessageSend(
workspaceId: string
},
params: MessageSendParams,
apiKey?: string | null
apiKey?: string | null,
executionUserId?: string
): Promise<NextResponse> {
if (!params?.message) {
return NextResponse.json(
Expand Down Expand Up @@ -318,6 +370,13 @@ async function handleMessageSend(
)
}

if (existingTask.agentId !== agent.id) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'),
{ status: 404 }
)
}

if (isTerminalState(existingTask.status as TaskState)) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.TASK_ALREADY_COMPLETE, 'Task already in terminal state'),
Expand Down Expand Up @@ -363,6 +422,7 @@ async function handleMessageSend(
} = await buildExecuteRequest({
workflowId: agent.workflowId,
apiKey,
userId: executionUserId,
})

logger.info(`Executing workflow ${agent.workflowId} for A2A task ${taskId}`)
Expand Down Expand Up @@ -475,7 +535,8 @@ async function handleMessageStream(
workspaceId: string
},
params: MessageSendParams,
apiKey?: string | null
apiKey?: string | null,
executionUserId?: string
): Promise<NextResponse> {
if (!params?.message) {
return NextResponse.json(
Expand Down Expand Up @@ -522,6 +583,13 @@ async function handleMessageStream(
})
}

if (existingTask.agentId !== agent.id) {
await releaseLock(lockKey, lockValue)
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
status: 404,
})
}

if (isTerminalState(existingTask.status as TaskState)) {
await releaseLock(lockKey, lockValue)
return NextResponse.json(
Expand Down Expand Up @@ -595,6 +663,7 @@ async function handleMessageStream(
} = await buildExecuteRequest({
workflowId: agent.workflowId,
apiKey,
userId: executionUserId,
stream: true,
})

Expand Down Expand Up @@ -788,7 +857,11 @@ async function handleMessageStream(
/**
* Handle tasks/get - Query task status
*/
async function handleTaskGet(id: string | number, params: TaskIdParams): Promise<NextResponse> {
async function handleTaskGet(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'),
Expand All @@ -801,7 +874,7 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise
? params.historyLength
: undefined

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand All @@ -825,15 +898,19 @@ async function handleTaskGet(id: string | number, params: TaskIdParams): Promise
/**
* Handle tasks/cancel - Cancel a running task
*/
async function handleTaskCancel(id: string | number, params: TaskIdParams): Promise<NextResponse> {
async function handleTaskCancel(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
return NextResponse.json(
createError(id, A2A_ERROR_CODES.INVALID_PARAMS, 'Task ID is required'),
{ status: 400 }
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -897,6 +974,7 @@ async function handleTaskCancel(id: string | number, params: TaskIdParams): Prom
async function handleTaskResubscribe(
request: NextRequest,
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
Expand All @@ -906,7 +984,7 @@ async function handleTaskResubscribe(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -1103,6 +1181,7 @@ async function handleTaskResubscribe(
*/
async function handlePushNotificationSet(
id: string | number,
agentId: string,
params: PushNotificationSetParams
): Promise<NextResponse> {
if (!params?.id) {
Expand Down Expand Up @@ -1130,7 +1209,7 @@ async function handlePushNotificationSet(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -1181,6 +1260,7 @@ async function handlePushNotificationSet(
*/
async function handlePushNotificationGet(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
Expand All @@ -1190,7 +1270,7 @@ async function handlePushNotificationGet(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down Expand Up @@ -1224,6 +1304,7 @@ async function handlePushNotificationGet(
*/
async function handlePushNotificationDelete(
id: string | number,
agentId: string,
params: TaskIdParams
): Promise<NextResponse> {
if (!params?.id) {
Expand All @@ -1233,7 +1314,7 @@ async function handlePushNotificationDelete(
)
}

const [task] = await db.select().from(a2aTask).where(eq(a2aTask.id, params.id)).limit(1)
const task = await getTaskForAgent(params.id, agentId)

if (!task) {
return NextResponse.json(createError(id, A2A_ERROR_CODES.TASK_NOT_FOUND, 'Task not found'), {
Expand Down
3 changes: 2 additions & 1 deletion apps/sim/app/api/a2a/serve/[agentId]/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export function formatTaskResponse(task: Task, historyLength?: number): Task {
export interface ExecuteRequestConfig {
workflowId: string
apiKey?: string | null
userId?: string
stream?: boolean
}

Expand All @@ -124,7 +125,7 @@ export async function buildExecuteRequest(
if (config.apiKey) {
headers['X-API-Key'] = config.apiKey
} else {
const internalToken = await generateInternalToken()
const internalToken = await generateInternalToken(config.userId)
headers.Authorization = `Bearer ${internalToken}`
useInternalAuth = true
}
Expand Down
Loading