Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions apps/sim/app/api/copilot/chat/update-messages/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,25 @@ import { authMockFns } from '@sim/testing'
import { NextRequest } from 'next/server'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

const { mockSelect, mockFrom, mockWhere, mockLimit, mockUpdate, mockSet, mockUpdateWhere } =
vi.hoisted(() => ({
mockSelect: vi.fn(),
mockFrom: vi.fn(),
mockWhere: vi.fn(),
mockLimit: vi.fn(),
mockUpdate: vi.fn(),
mockSet: vi.fn(),
mockUpdateWhere: vi.fn(),
}))
const {
mockSelect,
mockFrom,
mockWhere,
mockLimit,
mockUpdate,
mockSet,
mockUpdateWhere,
mockReturning,
} = vi.hoisted(() => ({
mockSelect: vi.fn(),
mockFrom: vi.fn(),
mockWhere: vi.fn(),
mockLimit: vi.fn(),
mockUpdate: vi.fn(),
mockSet: vi.fn(),
mockUpdateWhere: vi.fn(),
mockReturning: vi.fn(),
}))

vi.mock('@sim/db', () => ({
db: {
Expand Down Expand Up @@ -51,8 +60,9 @@ describe('Copilot Chat Update Messages API Route', () => {
mockWhere.mockReturnValue({ limit: mockLimit })
mockLimit.mockResolvedValue([])
mockUpdate.mockReturnValue({ set: mockSet })
mockUpdateWhere.mockResolvedValue(undefined)
mockSet.mockReturnValue({ where: mockUpdateWhere })
mockUpdateWhere.mockReturnValue({ returning: mockReturning })
mockReturning.mockResolvedValue([{ model: 'gpt-4' }])
})

afterEach(() => {
Expand Down
12 changes: 11 additions & 1 deletion apps/sim/app/api/copilot/chat/update-messages/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { type NextRequest, NextResponse } from 'next/server'
import { updateCopilotMessagesContract } from '@/lib/api/contracts/copilot'
import { parseRequest } from '@/lib/api/server'
import { getAccessibleCopilotChatAuth } from '@/lib/copilot/chat/lifecycle'
import { replaceCopilotChatMessages } from '@/lib/copilot/chat/messages-dual-write'
import { normalizeMessage, type PersistedMessage } from '@/lib/copilot/chat/persisted-message'
import {
authenticateCopilotRequestSessionOnly,
Expand Down Expand Up @@ -86,7 +87,16 @@ export const POST = withRouteHandler(async (req: NextRequest) => {
updateData.config = config
}

await db.update(copilotChats).set(updateData).where(eq(copilotChats.id, chatId))
const [updated] = await db
.update(copilotChats)
.set(updateData)
.where(eq(copilotChats.id, chatId))
.returning({ model: copilotChats.model })
if (updated) {
await replaceCopilotChatMessages(chatId, normalizedMessages, {
chatModel: updated.model ?? null,
})
}

logger.info(`[${tracker.requestId}] Successfully updated chat`, {
chatId,
Expand Down
3 changes: 3 additions & 0 deletions apps/sim/app/api/mothership/chats/[chatId]/fork/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { forkMothershipChatContract } from '@/lib/api/contracts/mothership-tasks'
import { parseRequest } from '@/lib/api/server'
import { appendCopilotChatMessages } from '@/lib/copilot/chat/messages-dual-write'
import type { PersistedMessage } from '@/lib/copilot/chat/persisted-message'
import { fetchGo } from '@/lib/copilot/request/go/fetch'
import {
Expand Down Expand Up @@ -102,6 +103,8 @@ export const POST = withRouteHandler(
return createInternalServerErrorResponse('Failed to create forked chat')
}

await appendCopilotChatMessages(newId, forkedMessages, { chatModel: parent.model })

// Clone copilot-service conversation state (messages, active_messages, memory files).
// Best-effort: if the copilot service doesn't have a row for the source chat yet, skip.
try {
Expand Down
36 changes: 23 additions & 13 deletions apps/sim/app/api/superuser/import-workflow/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { type NextRequest, NextResponse } from 'next/server'
import { importWorkflowAsSuperuserContract } from '@/lib/api/contracts/workflows'
import { parseRequest } from '@/lib/api/server'
import { getSession } from '@/lib/auth'
import { appendCopilotChatMessages } from '@/lib/copilot/chat/messages-dual-write'
import type { PersistedMessage } from '@/lib/copilot/chat/persisted-message'
import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
import { verifyEffectiveSuperUser } from '@/lib/templates/permissions'
import { parseWorkflowJson } from '@/lib/workflows/operations/import-export'
Expand Down Expand Up @@ -172,19 +174,27 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
let copilotChatsImported = 0

for (const chat of sourceCopilotChats) {
await db.insert(copilotChats).values({
userId: session.user.id,
workflowId: newWorkflowId,
title: chat.title ? `[Import] ${chat.title}` : null,
messages: chat.messages,
model: chat.model,
conversationId: null, // Don't copy conversation ID
previewYaml: chat.previewYaml,
planArtifact: chat.planArtifact,
config: chat.config,
createdAt: new Date(),
updatedAt: new Date(),
})
const [imported] = await db
.insert(copilotChats)
.values({
userId: session.user.id,
workflowId: newWorkflowId,
title: chat.title ? `[Import] ${chat.title}` : null,
messages: chat.messages,
model: chat.model,
conversationId: null, // Don't copy conversation ID
previewYaml: chat.previewYaml,
planArtifact: chat.planArtifact,
config: chat.config,
createdAt: new Date(),
updatedAt: new Date(),
})
.returning({ id: copilotChats.id })
if (imported && Array.isArray(chat.messages) && chat.messages.length > 0) {
await appendCopilotChatMessages(imported.id, chat.messages as PersistedMessage[], {
chatModel: chat.model,
})
}
copilotChatsImported++
}

Expand Down
151 changes: 151 additions & 0 deletions apps/sim/lib/copilot/chat/messages-dual-write.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/**
* @vitest-environment node
*/
import { dbChainMock, dbChainMockFns, resetDbChainMock } from '@sim/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@sim/db', () => dbChainMock)

import {
appendCopilotChatMessages,
replaceCopilotChatMessages,
} from '@/lib/copilot/chat/messages-dual-write'
import type { PersistedMessage } from '@/lib/copilot/chat/persisted-message'

const userMsg: PersistedMessage = {
id: 'msg-user-1',
role: 'user',
content: 'Hello',
timestamp: '2026-01-01T00:00:00.000Z',
}

const assistantMsg: PersistedMessage = {
id: 'msg-asst-1',
role: 'assistant',
content: 'Hi back',
timestamp: '2026-01-01T00:00:01.000Z',
}

describe('messages-dual-write', () => {
beforeEach(() => {
vi.clearAllMocks()
resetDbChainMock()
})

describe('appendCopilotChatMessages', () => {
it('is a no-op on empty array', async () => {
await appendCopilotChatMessages('chat-1', [])
expect(dbChainMockFns.insert).not.toHaveBeenCalled()
})

it('inserts rows built from PersistedMessage shape', async () => {
await appendCopilotChatMessages('chat-1', [userMsg, assistantMsg])

expect(dbChainMockFns.insert).toHaveBeenCalledTimes(1)
expect(dbChainMockFns.values).toHaveBeenCalledTimes(1)
const rows = dbChainMockFns.values.mock.calls[0][0]
expect(rows).toHaveLength(2)

expect(rows[0]).toMatchObject({
chatId: 'chat-1',
messageId: 'msg-user-1',
role: 'user',
content: userMsg,
model: null,
streamId: null,
})
expect(rows[0].createdAt).toEqual(new Date(userMsg.timestamp))
expect(rows[0].updatedAt).toEqual(new Date(userMsg.timestamp))

expect(rows[1]).toMatchObject({
chatId: 'chat-1',
messageId: 'msg-asst-1',
role: 'assistant',
content: assistantMsg,
})
expect(rows[1].createdAt).toEqual(new Date(assistantMsg.timestamp))
})

it('preserves per-message ordering via timestamp', async () => {
await appendCopilotChatMessages('chat-1', [userMsg, assistantMsg])
const rows = dbChainMockFns.values.mock.calls[0][0]
expect(rows[0].createdAt.getTime()).toBeLessThan(rows[1].createdAt.getTime())
})

it('passes chatModel and streamId options to every row', async () => {
await appendCopilotChatMessages('chat-1', [userMsg, assistantMsg], {
chatModel: 'claude-sonnet-4-5',
streamId: 'stream-xyz',
})

const rows = dbChainMockFns.values.mock.calls[0][0]
expect(rows[0].model).toBe('claude-sonnet-4-5')
expect(rows[0].streamId).toBe('stream-xyz')
expect(rows[1].model).toBe('claude-sonnet-4-5')
expect(rows[1].streamId).toBe('stream-xyz')
})

it('uses ON CONFLICT DO UPDATE with chat_id + message_id target', async () => {
await appendCopilotChatMessages('chat-1', [userMsg])

expect(dbChainMockFns.onConflictDoUpdate).toHaveBeenCalledTimes(1)
const conflictArg = dbChainMockFns.onConflictDoUpdate.mock.calls[0][0]
expect(conflictArg.target).toHaveLength(2)
expect(conflictArg.set).toHaveProperty('content')
expect(conflictArg.set).toHaveProperty('role')
expect(conflictArg.set).toHaveProperty('model')
expect(conflictArg.set).toHaveProperty('streamId')
expect(conflictArg.set).toHaveProperty('updatedAt')
})

it('swallows DB errors so the legacy JSONB write stays canonical', async () => {
dbChainMockFns.onConflictDoUpdate.mockRejectedValueOnce(new Error('connection lost'))

await expect(appendCopilotChatMessages('chat-1', [userMsg])).resolves.toBeUndefined()
})
})

describe('replaceCopilotChatMessages', () => {
it('deletes all chat rows when given an empty snapshot', async () => {
await replaceCopilotChatMessages('chat-1', [])

expect(dbChainMockFns.transaction).toHaveBeenCalledTimes(1)
expect(dbChainMockFns.delete).toHaveBeenCalledTimes(1)
expect(dbChainMockFns.insert).not.toHaveBeenCalled()
})

it('deletes only rows whose message_id is not in the new snapshot, then upserts', async () => {
await replaceCopilotChatMessages('chat-1', [userMsg, assistantMsg])

expect(dbChainMockFns.delete).toHaveBeenCalledTimes(1)
expect(dbChainMockFns.insert).toHaveBeenCalledTimes(1)

const rows = dbChainMockFns.values.mock.calls[0][0]
expect(rows).toHaveLength(2)
expect(rows.map((r: { messageId: string }) => r.messageId)).toEqual([
'msg-user-1',
'msg-asst-1',
])

expect(dbChainMockFns.onConflictDoUpdate).toHaveBeenCalledTimes(1)
const conflictArg = dbChainMockFns.onConflictDoUpdate.mock.calls[0][0]
expect(conflictArg.set).toHaveProperty('streamId')
expect(conflictArg.set).toHaveProperty('model')
})

it('passes chatModel to every row in the snapshot', async () => {
await replaceCopilotChatMessages('chat-1', [userMsg], {
chatModel: 'gpt-4o-mini',
})

const rows = dbChainMockFns.values.mock.calls[0][0]
expect(rows[0].model).toBe('gpt-4o-mini')
})

it('swallows DB errors so the legacy JSONB write stays canonical', async () => {
dbChainMockFns.transaction.mockRejectedValueOnce(new Error('tx aborted'))

await expect(replaceCopilotChatMessages('chat-1', [userMsg])).resolves.toBeUndefined()
})
})
})
Loading