diff --git a/apps/sim/app/api/__test-utils__/utils.ts b/apps/sim/app/api/__test-utils__/utils.ts index f6e867ad78b..e2581ca469c 100644 --- a/apps/sim/app/api/__test-utils__/utils.ts +++ b/apps/sim/app/api/__test-utils__/utils.ts @@ -1,7 +1,6 @@ import { NextRequest } from 'next/server' import { vi } from 'vitest' -// Add type definitions for better type safety export interface MockUser { id: string email: string @@ -14,7 +13,6 @@ export interface MockAuthResult { mockUnauthenticated: () => void } -// Database result types export interface DatabaseSelectResult { id: string [key: string]: any @@ -234,7 +232,6 @@ export function createMockRequest( ): NextRequest { const url = 'http://localhost:3000/api/test' - // Use the URL constructor to create a proper URL object return new NextRequest(new URL(url), { method, headers: new Headers(headers), @@ -248,7 +245,6 @@ export function mockExecutionDependencies() { return { ...(actual as any), decryptSecret: vi.fn().mockImplementation((encrypted: string) => { - // Map from encrypted to decrypted const entries = Object.entries(mockEnvironmentVars) const found = entries.find(([_, val]) => val === encrypted) const key = found ? found[0] : null @@ -570,6 +566,7 @@ export function mockDrizzleOrm() { asc: vi.fn((field) => ({ field, type: 'asc' })), desc: vi.fn((field) => ({ field, type: 'desc' })), isNull: vi.fn((field) => ({ field, type: 'isNull' })), + count: vi.fn((field) => ({ field, type: 'count' })), sql: vi.fn((strings, ...values) => ({ type: 'sql', sql: strings, @@ -578,6 +575,57 @@ export function mockDrizzleOrm() { })) } +/** + * Mock knowledge-related database schemas + */ +export function mockKnowledgeSchemas() { + vi.doMock('@/db/schema', () => ({ + knowledgeBase: { + id: 'kb_id', + userId: 'user_id', + name: 'kb_name', + description: 'description', + tokenCount: 'token_count', + embeddingModel: 'embedding_model', + embeddingDimension: 'embedding_dimension', + chunkingConfig: 'chunking_config', + workspaceId: 'workspace_id', + createdAt: 'created_at', + updatedAt: 'updated_at', + deletedAt: 'deleted_at', + }, + document: { + id: 'doc_id', + knowledgeBaseId: 'kb_id', + filename: 'filename', + fileUrl: 'file_url', + fileSize: 'file_size', + mimeType: 'mime_type', + chunkCount: 'chunk_count', + tokenCount: 'token_count', + characterCount: 'character_count', + processingStatus: 'processing_status', + processingStartedAt: 'processing_started_at', + processingCompletedAt: 'processing_completed_at', + processingError: 'processing_error', + enabled: 'enabled', + uploadedAt: 'uploaded_at', + deletedAt: 'deleted_at', + }, + embedding: { + id: 'embedding_id', + documentId: 'doc_id', + knowledgeBaseId: 'kb_id', + chunkIndex: 'chunk_index', + content: 'content', + embedding: 'embedding', + tokenCount: 'token_count', + characterCount: 'character_count', + createdAt: 'created_at', + }, + })) +} + /** * Mock console logger */ diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/[chunkId]/route.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/[chunkId]/route.ts index a5f81b793b3..ba011e5173f 100644 --- a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/[chunkId]/route.ts +++ b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/[chunkId]/route.ts @@ -1,3 +1,4 @@ +import crypto from 'node:crypto' import { eq, sql } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' @@ -12,8 +13,6 @@ const logger = createLogger('ChunkByIdAPI') const UpdateChunkSchema = z.object({ content: z.string().min(1, 'Content is required').optional(), enabled: z.boolean().optional(), - searchRank: z.number().min(0).optional(), - qualityScore: z.number().min(0).max(1).optional(), }) export async function GET( @@ -103,21 +102,27 @@ export async function PUT( try { const validatedData = UpdateChunkSchema.parse(body) - const updateData: any = { - updatedAt: new Date(), - } + const updateData: Partial<{ + content: string + contentLength: number + tokenCount: number + chunkHash: string + enabled: boolean + updatedAt: Date + }> = {} - if (validatedData.content !== undefined) { + if (validatedData.content) { updateData.content = validatedData.content updateData.contentLength = validatedData.content.length // Update token count estimation (rough approximation: 4 chars per token) updateData.tokenCount = Math.ceil(validatedData.content.length / 4) + updateData.chunkHash = crypto + .createHash('sha256') + .update(validatedData.content) + .digest('hex') } + if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled - if (validatedData.searchRank !== undefined) - updateData.searchRank = validatedData.searchRank.toString() - if (validatedData.qualityScore !== undefined) - updateData.qualityScore = validatedData.qualityScore.toString() await db.update(embedding).set(updateData).where(eq(embedding.id, chunkId)) diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts index 715fbad6496..f35ec4d7b88 100644 --- a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts +++ b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts @@ -1,5 +1,5 @@ import crypto from 'crypto' -import { and, asc, eq, ilike, sql } from 'drizzle-orm' +import { and, asc, eq, ilike, inArray, sql } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' @@ -11,7 +11,6 @@ import { checkDocumentAccess, generateEmbeddings } from '../../../../utils' const logger = createLogger('DocumentChunksAPI') -// Schema for query parameters const GetChunksQuerySchema = z.object({ search: z.string().optional(), enabled: z.enum(['true', 'false', 'all']).optional().default('all'), @@ -19,12 +18,19 @@ const GetChunksQuerySchema = z.object({ offset: z.coerce.number().min(0).optional().default(0), }) -// Schema for creating manual chunks const CreateChunkSchema = z.object({ content: z.string().min(1, 'Content is required').max(10000, 'Content too long'), enabled: z.boolean().optional().default(true), }) +const BatchOperationSchema = z.object({ + operation: z.enum(['enable', 'disable', 'delete']), + chunkIds: z + .array(z.string()) + .min(1, 'At least one chunk ID is required') + .max(100, 'Cannot operate on more than 100 chunks at once'), +}) + export async function GET( req: NextRequest, { params }: { params: Promise<{ id: string; documentId: string }> } @@ -112,10 +118,7 @@ export async function GET( enabled: embedding.enabled, startOffset: embedding.startOffset, endOffset: embedding.endOffset, - overlapTokens: embedding.overlapTokens, metadata: embedding.metadata, - searchRank: embedding.searchRank, - qualityScore: embedding.qualityScore, createdAt: embedding.createdAt, updatedAt: embedding.updatedAt, }) @@ -236,12 +239,7 @@ export async function POST( embeddingModel: 'text-embedding-3-small', startOffset: 0, // Manual chunks don't have document offsets endOffset: validatedData.content.length, - overlapTokens: 0, metadata: { manual: true }, // Mark as manually created - searchRank: '1.0', - accessCount: 0, - lastAccessedAt: null, - qualityScore: null, enabled: validatedData.enabled, createdAt: now, updatedAt: now, @@ -286,3 +284,144 @@ export async function POST( return NextResponse.json({ error: 'Failed to create chunk' }, { status: 500 }) } } + +export async function PATCH( + req: NextRequest, + { params }: { params: Promise<{ id: string; documentId: string }> } +) { + const requestId = crypto.randomUUID().slice(0, 8) + const { id: knowledgeBaseId, documentId } = await params + + try { + const session = await getSession() + if (!session?.user?.id) { + logger.warn(`[${requestId}] Unauthorized batch chunk operation attempt`) + return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) + } + + const accessCheck = await checkDocumentAccess(knowledgeBaseId, documentId, session.user.id) + + if (!accessCheck.hasAccess) { + if (accessCheck.notFound) { + logger.warn( + `[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}` + ) + return NextResponse.json({ error: accessCheck.reason }, { status: 404 }) + } + logger.warn( + `[${requestId}] User ${session.user.id} attempted unauthorized batch chunk operation: ${accessCheck.reason}` + ) + return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) + } + + const body = await req.json() + + try { + const validatedData = BatchOperationSchema.parse(body) + const { operation, chunkIds } = validatedData + + logger.info( + `[${requestId}] Starting batch ${operation} operation on ${chunkIds.length} chunks for document ${documentId}` + ) + + const results = [] + let successCount = 0 + const errorCount = 0 + + if (operation === 'delete') { + // Handle batch delete with transaction for consistency + await db.transaction(async (tx) => { + // Get chunks to delete for statistics update + const chunksToDelete = await tx + .select({ + id: embedding.id, + tokenCount: embedding.tokenCount, + contentLength: embedding.contentLength, + }) + .from(embedding) + .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds))) + + if (chunksToDelete.length === 0) { + throw new Error('No valid chunks found to delete') + } + + // Delete chunks + await tx + .delete(embedding) + .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds))) + + // Update document statistics + const totalTokens = chunksToDelete.reduce((sum, chunk) => sum + chunk.tokenCount, 0) + const totalCharacters = chunksToDelete.reduce( + (sum, chunk) => sum + chunk.contentLength, + 0 + ) + + await tx + .update(document) + .set({ + chunkCount: sql`${document.chunkCount} - ${chunksToDelete.length}`, + tokenCount: sql`${document.tokenCount} - ${totalTokens}`, + characterCount: sql`${document.characterCount} - ${totalCharacters}`, + }) + .where(eq(document.id, documentId)) + + successCount = chunksToDelete.length + results.push({ + operation: 'delete', + deletedCount: chunksToDelete.length, + chunkIds: chunksToDelete.map((c) => c.id), + }) + }) + } else { + // Handle batch enable/disable + const enabled = operation === 'enable' + + // Update chunks in a single query + const updateResult = await db + .update(embedding) + .set({ + enabled, + updatedAt: new Date(), + }) + .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds))) + .returning({ id: embedding.id }) + + successCount = updateResult.length + results.push({ + operation, + updatedCount: updateResult.length, + chunkIds: updateResult.map((r) => r.id), + }) + } + + logger.info( + `[${requestId}] Batch ${operation} operation completed: ${successCount} successful, ${errorCount} errors` + ) + + return NextResponse.json({ + success: true, + data: { + operation, + successCount, + errorCount, + results, + }, + }) + } catch (validationError) { + if (validationError instanceof z.ZodError) { + logger.warn(`[${requestId}] Invalid batch operation data`, { + errors: validationError.errors, + }) + return NextResponse.json( + { error: 'Invalid request data', details: validationError.errors }, + { status: 400 } + ) + } + throw validationError + } + } catch (error) { + logger.error(`[${requestId}] Error in batch chunk operation`, error) + return NextResponse.json({ error: 'Failed to perform batch operation' }, { status: 500 }) + } +} diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/retry/route.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/retry/route.ts deleted file mode 100644 index df113acb26d..00000000000 --- a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/retry/route.ts +++ /dev/null @@ -1,101 +0,0 @@ -import { eq } from 'drizzle-orm' -import { type NextRequest, NextResponse } from 'next/server' -import { getSession } from '@/lib/auth' -import { createLogger } from '@/lib/logs/console-logger' -import { db } from '@/db' -import { document, embedding } from '@/db/schema' -import { checkDocumentAccess, processDocumentAsync } from '../../../../utils' - -const logger = createLogger('DocumentRetryAPI') - -export async function POST( - req: NextRequest, - { params }: { params: Promise<{ id: string; documentId: string }> } -) { - const requestId = crypto.randomUUID().slice(0, 8) - const { id: knowledgeBaseId, documentId } = await params - - try { - const session = await getSession() - if (!session?.user?.id) { - logger.warn(`[${requestId}] Unauthorized document retry attempt`) - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - - const accessCheck = await checkDocumentAccess(knowledgeBaseId, documentId, session.user.id) - - if (!accessCheck.hasAccess) { - if (accessCheck.notFound) { - logger.warn( - `[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}` - ) - return NextResponse.json({ error: accessCheck.reason }, { status: 404 }) - } - logger.warn( - `[${requestId}] User ${session.user.id} attempted unauthorized document retry: ${accessCheck.reason}` - ) - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - - const doc = accessCheck.document - - if (doc.processingStatus !== 'failed') { - logger.warn( - `[${requestId}] Document ${documentId} is not in failed state (current: ${doc.processingStatus})` - ) - return NextResponse.json({ error: 'Document is not in failed state' }, { status: 400 }) - } - - await db.transaction(async (tx) => { - await tx.delete(embedding).where(eq(embedding.documentId, documentId)) - - await tx - .update(document) - .set({ - processingStatus: 'pending', - processingStartedAt: null, - processingCompletedAt: null, - processingError: null, - chunkCount: 0, - tokenCount: 0, - characterCount: 0, - }) - .where(eq(document.id, documentId)) - }) - - const processingOptions = { - chunkSize: 1024, - minCharactersPerChunk: 24, - recipe: 'default', - lang: 'en', - } - - const docData = { - filename: doc.filename, - fileUrl: doc.fileUrl, - fileSize: doc.fileSize, - mimeType: doc.mimeType, - fileHash: doc.fileHash, - } - - processDocumentAsync(knowledgeBaseId, documentId, docData, processingOptions).catch( - (error: unknown) => { - logger.error(`[${requestId}] Background retry processing error:`, error) - } - ) - - logger.info(`[${requestId}] Document retry initiated: ${documentId}`) - - return NextResponse.json({ - success: true, - data: { - documentId, - status: 'pending', - message: 'Document retry processing started', - }, - }) - } catch (error) { - logger.error(`[${requestId}] Error retrying document processing`, error) - return NextResponse.json({ error: 'Failed to retry document processing' }, { status: 500 }) - } -} diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.test.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.test.ts new file mode 100644 index 00000000000..1d5c68170ae --- /dev/null +++ b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.test.ts @@ -0,0 +1,550 @@ +/** + * Tests for document by ID API route + * + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + createMockRequest, + mockAuth, + mockConsoleLogger, + mockDrizzleOrm, + mockKnowledgeSchemas, +} from '@/app/api/__test-utils__/utils' + +mockKnowledgeSchemas() + +vi.mock('../../../utils', () => ({ + checkDocumentAccess: vi.fn(), + processDocumentAsync: vi.fn(), +})) + +// Setup common mocks +mockDrizzleOrm() +mockConsoleLogger() + +describe('Document By ID API Route', () => { + const mockAuth$ = mockAuth() + + const mockDbChain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + limit: vi.fn().mockReturnThis(), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + delete: vi.fn().mockReturnThis(), + transaction: vi.fn(), + } + + const mockCheckDocumentAccess = vi.fn() + const mockProcessDocumentAsync = vi.fn() + + const mockDocument = { + id: 'doc-123', + knowledgeBaseId: 'kb-123', + filename: 'test-document.pdf', + fileUrl: 'https://example.com/test-document.pdf', + fileSize: 1024, + mimeType: 'application/pdf', + chunkCount: 5, + tokenCount: 100, + characterCount: 500, + processingStatus: 'completed', + processingStartedAt: new Date('2023-01-01T10:00:00Z'), + processingCompletedAt: new Date('2023-01-01T10:05:00Z'), + processingError: null, + enabled: true, + uploadedAt: new Date('2023-01-01T09:00:00Z'), + deletedAt: null, + } + + const resetMocks = () => { + vi.clearAllMocks() + Object.values(mockDbChain).forEach((fn) => { + if (typeof fn === 'function') { + fn.mockClear().mockReset() + if (fn !== mockDbChain.transaction) { + fn.mockReturnThis() + } + } + }) + mockCheckDocumentAccess.mockClear().mockReset() + mockProcessDocumentAsync.mockClear().mockReset() + } + + beforeEach(async () => { + resetMocks() + + vi.doMock('@/db', () => ({ + db: mockDbChain, + })) + + vi.doMock('../../../utils', () => ({ + checkDocumentAccess: mockCheckDocumentAccess, + processDocumentAsync: mockProcessDocumentAsync, + })) + + vi.stubGlobal('crypto', { + randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'), + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('GET /api/knowledge/[id]/documents/[documentId]', () => { + const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' }) + + it('should retrieve document successfully for authenticated user', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: mockDocument, + }) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.id).toBe('doc-123') + expect(data.data.filename).toBe('test-document.pdf') + expect(mockCheckDocumentAccess).toHaveBeenCalledWith('kb-123', 'doc-123', 'user-123') + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent document', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: false, + notFound: true, + reason: 'Document not found', + }) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Document not found') + }) + + it('should return unauthorized for document without access', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: false, + reason: 'Access denied', + }) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should handle database errors', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockRejectedValue(new Error('Database error')) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to fetch document') + }) + }) + + describe('PUT /api/knowledge/[id]/documents/[documentId] - Regular Updates', () => { + const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' }) + const validUpdateData = { + filename: 'updated-document.pdf', + enabled: false, + chunkCount: 10, + tokenCount: 200, + } + + it('should update document successfully', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: mockDocument, + }) + + // Create a sequence of mocks for the database operations + const updateChain = { + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue(undefined), // Update operation completes + }), + } + + const selectChain = { + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([{ ...mockDocument, ...validUpdateData }]), + }), + }), + } + + // Mock db operations in sequence + mockDbChain.update.mockReturnValue(updateChain) + mockDbChain.select.mockReturnValue(selectChain) + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.filename).toBe('updated-document.pdf') + expect(data.data.enabled).toBe(false) + expect(mockDbChain.update).toHaveBeenCalled() + expect(mockDbChain.select).toHaveBeenCalled() + }) + + it('should validate update data', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: mockDocument, + }) + + const invalidData = { + filename: '', // Invalid: empty filename + chunkCount: -1, // Invalid: negative count + processingStatus: 'invalid', // Invalid: not in enum + } + + const req = createMockRequest('PUT', invalidData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + expect(data.details).toBeDefined() + }) + }) + + describe('PUT /api/knowledge/[id]/documents/[documentId] - Mark Failed Due to Timeout', () => { + const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' }) + + it('should mark document as failed due to timeout successfully', async () => { + const processingDocument = { + ...mockDocument, + processingStatus: 'processing', + processingStartedAt: new Date(Date.now() - 200000), // 200 seconds ago + } + + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: processingDocument, + }) + + // Create a sequence of mocks for the database operations + const updateChain = { + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue(undefined), // Update operation completes + }), + } + + const selectChain = { + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi + .fn() + .mockResolvedValue([{ ...processingDocument, processingStatus: 'failed' }]), + }), + }), + } + + // Mock db operations in sequence + mockDbChain.update.mockReturnValue(updateChain) + mockDbChain.select.mockReturnValue(selectChain) + + const req = createMockRequest('PUT', { markFailedDueToTimeout: true }) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(mockDbChain.update).toHaveBeenCalled() + expect(updateChain.set).toHaveBeenCalledWith( + expect.objectContaining({ + processingStatus: 'failed', + processingError: 'Processing timed out - background process may have been terminated', + processingCompletedAt: expect.any(Date), + }) + ) + }) + + it('should reject marking failed for non-processing document', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: { ...mockDocument, processingStatus: 'completed' }, + }) + + const req = createMockRequest('PUT', { markFailedDueToTimeout: true }) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toContain('Document is not in processing state') + }) + + it('should reject marking failed for recently started processing', async () => { + const recentProcessingDocument = { + ...mockDocument, + processingStatus: 'processing', + processingStartedAt: new Date(Date.now() - 60000), // 60 seconds ago + } + + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: recentProcessingDocument, + }) + + const req = createMockRequest('PUT', { markFailedDueToTimeout: true }) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toContain('Document has not been processing long enough') + }) + }) + + describe('PUT /api/knowledge/[id]/documents/[documentId] - Retry Processing', () => { + const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' }) + + it('should retry processing successfully', async () => { + const failedDocument = { + ...mockDocument, + processingStatus: 'failed', + processingError: 'Previous processing failed', + } + + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: failedDocument, + }) + + // Mock transaction + mockDbChain.transaction.mockImplementation(async (callback) => { + const mockTx = { + delete: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue(undefined), + }), + update: vi.fn().mockReturnValue({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue(undefined), + }), + }), + } + return await callback(mockTx) + }) + + mockProcessDocumentAsync.mockResolvedValue(undefined) + + const req = createMockRequest('PUT', { retryProcessing: true }) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.status).toBe('pending') + expect(data.data.message).toBe('Document retry processing started') + expect(mockDbChain.transaction).toHaveBeenCalled() + expect(mockProcessDocumentAsync).toHaveBeenCalled() + }) + + it('should reject retry for non-failed document', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: { ...mockDocument, processingStatus: 'completed' }, + }) + + const req = createMockRequest('PUT', { retryProcessing: true }) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Document is not in failed state') + }) + }) + + describe('PUT /api/knowledge/[id]/documents/[documentId] - Authentication & Authorization', () => { + const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' }) + const validUpdateData = { filename: 'updated-document.pdf' } + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent document', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: false, + notFound: true, + reason: 'Document not found', + }) + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Document not found') + }) + + it('should handle database errors during update', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: mockDocument, + }) + mockDbChain.set.mockRejectedValue(new Error('Database error')) + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to update document') + }) + }) + + describe('DELETE /api/knowledge/[id]/documents/[documentId]', () => { + const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' }) + + it('should delete document successfully', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: mockDocument, + }) + + // Properly chain the mock database operations for soft delete + mockDbChain.update.mockReturnValue(mockDbChain) + mockDbChain.set.mockReturnValue(mockDbChain) + mockDbChain.where.mockResolvedValue(undefined) // Update operation resolves + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.message).toBe('Document deleted successfully') + expect(mockDbChain.update).toHaveBeenCalled() + expect(mockDbChain.set).toHaveBeenCalledWith( + expect.objectContaining({ + deletedAt: expect.any(Date), + }) + ) + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent document', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: false, + notFound: true, + reason: 'Document not found', + }) + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Document not found') + }) + + it('should return unauthorized for document without access', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: false, + reason: 'Access denied', + }) + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should handle database errors during deletion', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckDocumentAccess.mockResolvedValue({ + hasAccess: true, + document: mockDocument, + }) + mockDbChain.set.mockRejectedValue(new Error('Database error')) + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to delete document') + }) + }) +}) diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.ts index a709edaafa7..1e466d68881 100644 --- a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.ts +++ b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/route.ts @@ -4,8 +4,8 @@ import { z } from 'zod' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console-logger' import { db } from '@/db' -import { document } from '@/db/schema' -import { checkDocumentAccess } from '../../../utils' +import { document, embedding } from '@/db/schema' +import { checkDocumentAccess, processDocumentAsync } from '../../../utils' const logger = createLogger('DocumentByIdAPI') @@ -15,6 +15,10 @@ const UpdateDocumentSchema = z.object({ chunkCount: z.number().min(0).optional(), tokenCount: z.number().min(0).optional(), characterCount: z.number().min(0).optional(), + processingStatus: z.enum(['pending', 'processing', 'completed', 'failed']).optional(), + processingError: z.string().optional(), + markFailedDueToTimeout: z.boolean().optional(), + retryProcessing: z.boolean().optional(), }) export async function GET( @@ -96,12 +100,113 @@ export async function PUT( const updateData: any = {} - if (validatedData.filename !== undefined) updateData.filename = validatedData.filename - if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled - if (validatedData.chunkCount !== undefined) updateData.chunkCount = validatedData.chunkCount - if (validatedData.tokenCount !== undefined) updateData.tokenCount = validatedData.tokenCount - if (validatedData.characterCount !== undefined) - updateData.characterCount = validatedData.characterCount + // Handle special operations first + if (validatedData.markFailedDueToTimeout) { + // Mark document as failed due to timeout (replaces mark-failed endpoint) + const doc = accessCheck.document + + if (doc.processingStatus !== 'processing') { + return NextResponse.json( + { error: `Document is not in processing state (current: ${doc.processingStatus})` }, + { status: 400 } + ) + } + + if (!doc.processingStartedAt) { + return NextResponse.json( + { error: 'Document has no processing start time' }, + { status: 400 } + ) + } + + const now = new Date() + const processingDuration = now.getTime() - new Date(doc.processingStartedAt).getTime() + const DEAD_PROCESS_THRESHOLD_MS = 150 * 1000 + + if (processingDuration <= DEAD_PROCESS_THRESHOLD_MS) { + return NextResponse.json( + { error: 'Document has not been processing long enough to be considered dead' }, + { status: 400 } + ) + } + + updateData.processingStatus = 'failed' + updateData.processingError = + 'Processing timed out - background process may have been terminated' + updateData.processingCompletedAt = now + + logger.info( + `[${requestId}] Marked document ${documentId} as failed due to dead process (processing time: ${Math.round(processingDuration / 1000)}s)` + ) + } else if (validatedData.retryProcessing) { + // Retry processing (replaces retry endpoint) + const doc = accessCheck.document + + if (doc.processingStatus !== 'failed') { + return NextResponse.json({ error: 'Document is not in failed state' }, { status: 400 }) + } + + // Clear existing embeddings and reset document state + await db.transaction(async (tx) => { + await tx.delete(embedding).where(eq(embedding.documentId, documentId)) + + await tx + .update(document) + .set({ + processingStatus: 'pending', + processingStartedAt: null, + processingCompletedAt: null, + processingError: null, + chunkCount: 0, + tokenCount: 0, + characterCount: 0, + }) + .where(eq(document.id, documentId)) + }) + + const processingOptions = { + chunkSize: 1024, + minCharactersPerChunk: 24, + recipe: 'default', + lang: 'en', + } + + const docData = { + filename: doc.filename, + fileUrl: doc.fileUrl, + fileSize: doc.fileSize, + mimeType: doc.mimeType, + } + + processDocumentAsync(knowledgeBaseId, documentId, docData, processingOptions).catch( + (error: unknown) => { + logger.error(`[${requestId}] Background retry processing error:`, error) + } + ) + + logger.info(`[${requestId}] Document retry initiated: ${documentId}`) + + return NextResponse.json({ + success: true, + data: { + documentId, + status: 'pending', + message: 'Document retry processing started', + }, + }) + } else { + // Regular field updates + if (validatedData.filename !== undefined) updateData.filename = validatedData.filename + if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled + if (validatedData.chunkCount !== undefined) updateData.chunkCount = validatedData.chunkCount + if (validatedData.tokenCount !== undefined) updateData.tokenCount = validatedData.tokenCount + if (validatedData.characterCount !== undefined) + updateData.characterCount = validatedData.characterCount + if (validatedData.processingStatus !== undefined) + updateData.processingStatus = validatedData.processingStatus + if (validatedData.processingError !== undefined) + updateData.processingError = validatedData.processingError + } await db.update(document).set(updateData).where(eq(document.id, documentId)) diff --git a/apps/sim/app/api/knowledge/[id]/documents/route.test.ts b/apps/sim/app/api/knowledge/[id]/documents/route.test.ts new file mode 100644 index 00000000000..9ba218cbe71 --- /dev/null +++ b/apps/sim/app/api/knowledge/[id]/documents/route.test.ts @@ -0,0 +1,424 @@ +/** + * Tests for knowledge base documents API route + * + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + createMockRequest, + mockAuth, + mockConsoleLogger, + mockDrizzleOrm, + mockKnowledgeSchemas, +} from '@/app/api/__test-utils__/utils' + +mockKnowledgeSchemas() + +vi.mock('../../utils', () => ({ + checkKnowledgeBaseAccess: vi.fn(), + processDocumentAsync: vi.fn(), +})) + +mockDrizzleOrm() +mockConsoleLogger() + +describe('Knowledge Base Documents API Route', () => { + const mockAuth$ = mockAuth() + + const mockDbChain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockReturnThis(), + insert: vi.fn().mockReturnThis(), + values: vi.fn().mockReturnThis(), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + transaction: vi.fn(), + } + + const mockCheckKnowledgeBaseAccess = vi.fn() + const mockProcessDocumentAsync = vi.fn() + + const mockDocument = { + id: 'doc-123', + knowledgeBaseId: 'kb-123', + filename: 'test-document.pdf', + fileUrl: 'https://example.com/test-document.pdf', + fileSize: 1024, + mimeType: 'application/pdf', + chunkCount: 5, + tokenCount: 100, + characterCount: 500, + processingStatus: 'completed', + processingStartedAt: new Date(), + processingCompletedAt: new Date(), + processingError: null, + enabled: true, + uploadedAt: new Date(), + } + + const resetMocks = () => { + vi.clearAllMocks() + Object.values(mockDbChain).forEach((fn) => { + if (typeof fn === 'function') { + fn.mockClear().mockReset() + if (fn !== mockDbChain.transaction) { + fn.mockReturnThis() + } + } + }) + mockCheckKnowledgeBaseAccess.mockClear().mockReset() + mockProcessDocumentAsync.mockClear().mockReset() + } + + beforeEach(async () => { + resetMocks() + + vi.doMock('@/db', () => ({ + db: mockDbChain, + })) + + vi.doMock('../../utils', () => ({ + checkKnowledgeBaseAccess: mockCheckKnowledgeBaseAccess, + processDocumentAsync: mockProcessDocumentAsync, + })) + + vi.stubGlobal('crypto', { + randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'), + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('GET /api/knowledge/[id]/documents', () => { + const mockParams = Promise.resolve({ id: 'kb-123' }) + + it('should retrieve documents successfully for authenticated user', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + mockDbChain.orderBy.mockResolvedValue([mockDocument]) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data).toHaveLength(1) + expect(data.data[0].id).toBe('doc-123') + expect(mockDbChain.select).toHaveBeenCalled() + expect(mockCheckKnowledgeBaseAccess).toHaveBeenCalledWith('kb-123', 'user-123') + }) + + it('should filter disabled documents by default', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + mockDbChain.orderBy.mockResolvedValue([mockDocument]) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + + expect(response.status).toBe(200) + expect(mockDbChain.where).toHaveBeenCalled() + }) + + it('should include disabled documents when requested', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + mockDbChain.orderBy.mockResolvedValue([mockDocument]) + + const url = 'http://localhost:3000/api/knowledge/kb-123/documents?includeDisabled=true' + const req = new Request(url, { method: 'GET' }) as any + + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + + expect(response.status).toBe(200) + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent knowledge base', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: false, notFound: true }) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Knowledge base not found') + }) + + it('should return unauthorized for knowledge base without access', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: false }) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should handle database errors', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + mockDbChain.orderBy.mockRejectedValue(new Error('Database error')) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to fetch documents') + }) + }) + + describe('POST /api/knowledge/[id]/documents - Single Document', () => { + const mockParams = Promise.resolve({ id: 'kb-123' }) + const validDocumentData = { + filename: 'test-document.pdf', + fileUrl: 'https://example.com/test-document.pdf', + fileSize: 1024, + mimeType: 'application/pdf', + } + + it('should create single document successfully', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + mockDbChain.values.mockResolvedValue(undefined) + + const req = createMockRequest('POST', validDocumentData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.filename).toBe(validDocumentData.filename) + expect(data.data.fileUrl).toBe(validDocumentData.fileUrl) + expect(mockDbChain.insert).toHaveBeenCalled() + }) + + it('should validate single document data', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + + const invalidData = { + filename: '', // Invalid: empty filename + fileUrl: 'invalid-url', // Invalid: not a valid URL + fileSize: 0, // Invalid: size must be > 0 + mimeType: '', // Invalid: empty mime type + } + + const req = createMockRequest('POST', invalidData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + expect(data.details).toBeDefined() + }) + }) + + describe('POST /api/knowledge/[id]/documents - Bulk Documents', () => { + const mockParams = Promise.resolve({ id: 'kb-123' }) + const validBulkData = { + bulk: true, + documents: [ + { + filename: 'doc1.pdf', + fileUrl: 'https://example.com/doc1.pdf', + fileSize: 1024, + mimeType: 'application/pdf', + }, + { + filename: 'doc2.pdf', + fileUrl: 'https://example.com/doc2.pdf', + fileSize: 2048, + mimeType: 'application/pdf', + }, + ], + processingOptions: { + chunkSize: 1024, + minCharactersPerChunk: 100, + recipe: 'default', + lang: 'en', + chunkOverlap: 200, + }, + } + + it('should create bulk documents successfully', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + + // Mock transaction to return the created documents + mockDbChain.transaction.mockImplementation(async (callback) => { + const mockTx = { + insert: vi.fn().mockReturnValue({ + values: vi.fn().mockResolvedValue(undefined), + }), + } + return await callback(mockTx) + }) + + mockProcessDocumentAsync.mockResolvedValue(undefined) + + const req = createMockRequest('POST', validBulkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.total).toBe(2) + expect(data.data.documentsCreated).toHaveLength(2) + expect(data.data.processingMethod).toBe('background') + expect(mockDbChain.transaction).toHaveBeenCalled() + }) + + it('should validate bulk document data', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + + const invalidBulkData = { + bulk: true, + documents: [ + { + filename: '', // Invalid: empty filename + fileUrl: 'invalid-url', + fileSize: 0, + mimeType: '', + }, + ], + processingOptions: { + chunkSize: 50, // Invalid: too small + minCharactersPerChunk: 10, // Invalid: too small + recipe: 'default', + lang: 'en', + chunkOverlap: 1000, // Invalid: too large + }, + } + + const req = createMockRequest('POST', invalidBulkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + expect(data.details).toBeDefined() + }) + + it('should handle processing errors gracefully', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + + // Mock transaction to succeed but processing to fail + mockDbChain.transaction.mockImplementation(async (callback) => { + const mockTx = { + insert: vi.fn().mockReturnValue({ + values: vi.fn().mockResolvedValue(undefined), + }), + } + return await callback(mockTx) + }) + + // Don't reject the promise - the processing is async and catches errors internally + mockProcessDocumentAsync.mockResolvedValue(undefined) + + const req = createMockRequest('POST', validBulkData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + // The endpoint should still return success since documents are created + // and processing happens asynchronously + expect(response.status).toBe(200) + expect(data.success).toBe(true) + }) + }) + + describe('POST /api/knowledge/[id]/documents - Authentication & Authorization', () => { + const mockParams = Promise.resolve({ id: 'kb-123' }) + const validDocumentData = { + filename: 'test-document.pdf', + fileUrl: 'https://example.com/test-document.pdf', + fileSize: 1024, + mimeType: 'application/pdf', + } + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('POST', validDocumentData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent knowledge base', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: false, notFound: true }) + + const req = createMockRequest('POST', validDocumentData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Knowledge base not found') + }) + + it('should return unauthorized for knowledge base without access', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: false }) + + const req = createMockRequest('POST', validDocumentData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should handle database errors during creation', async () => { + mockAuth$.mockAuthenticatedUser() + mockCheckKnowledgeBaseAccess.mockResolvedValue({ hasAccess: true }) + mockDbChain.values.mockRejectedValue(new Error('Database error')) + + const req = createMockRequest('POST', validDocumentData) + const { POST } = await import('./route') + const response = await POST(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to create document') + }) + }) +}) diff --git a/apps/sim/app/api/knowledge/[id]/documents/route.ts b/apps/sim/app/api/knowledge/[id]/documents/route.ts index 75eee293a73..c7568f29df3 100644 --- a/apps/sim/app/api/knowledge/[id]/documents/route.ts +++ b/apps/sim/app/api/knowledge/[id]/documents/route.ts @@ -1,20 +1,178 @@ -import { and, eq, isNull } from 'drizzle-orm' +import crypto from 'node:crypto' +import { and, desc, eq, inArray, isNull } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { getSession } from '@/lib/auth' import { createLogger } from '@/lib/logs/console-logger' import { db } from '@/db' import { document } from '@/db/schema' -import { checkKnowledgeBaseAccess } from '../../utils' +import { checkKnowledgeBaseAccess, processDocumentAsync } from '../../utils' const logger = createLogger('DocumentsAPI') +const PROCESSING_CONFIG = { + maxConcurrentDocuments: 3, + batchSize: 5, + delayBetweenBatches: 1000, + delayBetweenDocuments: 500, +} + +async function processDocumentsWithConcurrencyControl( + createdDocuments: Array<{ + documentId: string + filename: string + fileUrl: string + fileSize: number + mimeType: string + }>, + knowledgeBaseId: string, + processingOptions: { + chunkSize: number + minCharactersPerChunk: number + recipe: string + lang: string + chunkOverlap: number + }, + requestId: string +): Promise { + const totalDocuments = createdDocuments.length + const batches = [] + + for (let i = 0; i < totalDocuments; i += PROCESSING_CONFIG.batchSize) { + batches.push(createdDocuments.slice(i, i + PROCESSING_CONFIG.batchSize)) + } + + logger.info(`[${requestId}] Processing ${totalDocuments} documents in ${batches.length} batches`) + + for (const [batchIndex, batch] of batches.entries()) { + logger.info( + `[${requestId}] Starting batch ${batchIndex + 1}/${batches.length} with ${batch.length} documents` + ) + + await processBatchWithConcurrency(batch, knowledgeBaseId, processingOptions, requestId) + + if (batchIndex < batches.length - 1) { + await new Promise((resolve) => setTimeout(resolve, PROCESSING_CONFIG.delayBetweenBatches)) + } + } + + logger.info(`[${requestId}] Completed processing initiation for all ${totalDocuments} documents`) +} + +async function processBatchWithConcurrency( + batch: Array<{ + documentId: string + filename: string + fileUrl: string + fileSize: number + mimeType: string + }>, + knowledgeBaseId: string, + processingOptions: { + chunkSize: number + minCharactersPerChunk: number + recipe: string + lang: string + chunkOverlap: number + }, + requestId: string +): Promise { + const semaphore = new Array(PROCESSING_CONFIG.maxConcurrentDocuments).fill(0) + const processingPromises = batch.map(async (doc, index) => { + if (index > 0) { + await new Promise((resolve) => + setTimeout(resolve, index * PROCESSING_CONFIG.delayBetweenDocuments) + ) + } + + await new Promise((resolve) => { + const checkSlot = () => { + const availableIndex = semaphore.findIndex((slot) => slot === 0) + if (availableIndex !== -1) { + semaphore[availableIndex] = 1 + resolve() + } else { + setTimeout(checkSlot, 100) + } + } + checkSlot() + }) + + try { + logger.info(`[${requestId}] Starting processing for document: ${doc.filename}`) + + await processDocumentAsync( + knowledgeBaseId, + doc.documentId, + { + filename: doc.filename, + fileUrl: doc.fileUrl, + fileSize: doc.fileSize, + mimeType: doc.mimeType, + }, + processingOptions + ) + + logger.info(`[${requestId}] Successfully initiated processing for document: ${doc.filename}`) + } catch (error: unknown) { + logger.error(`[${requestId}] Failed to process document: ${doc.filename}`, { + documentId: doc.documentId, + filename: doc.filename, + error: error instanceof Error ? error.message : 'Unknown error', + }) + + try { + await db + .update(document) + .set({ + processingStatus: 'failed', + processingError: + error instanceof Error ? error.message : 'Failed to initiate processing', + processingCompletedAt: new Date(), + }) + .where(eq(document.id, doc.documentId)) + } catch (dbError: unknown) { + logger.error( + `[${requestId}] Failed to update document status for failed document: ${doc.documentId}`, + dbError + ) + } + } finally { + const slotIndex = semaphore.findIndex((slot) => slot === 1) + if (slotIndex !== -1) { + semaphore[slotIndex] = 0 + } + } + }) + + await Promise.allSettled(processingPromises) +} + const CreateDocumentSchema = z.object({ filename: z.string().min(1, 'Filename is required'), fileUrl: z.string().url('File URL must be valid'), fileSize: z.number().min(1, 'File size must be greater than 0'), mimeType: z.string().min(1, 'MIME type is required'), - fileHash: z.string().optional(), +}) + +const BulkCreateDocumentsSchema = z.object({ + documents: z.array(CreateDocumentSchema), + processingOptions: z.object({ + chunkSize: z.number().min(100).max(4000), + minCharactersPerChunk: z.number().min(50).max(2000), + recipe: z.string(), + lang: z.string(), + chunkOverlap: z.number().min(0).max(500), + }), + bulk: z.literal(true), +}) + +const BulkUpdateDocumentsSchema = z.object({ + operation: z.enum(['enable', 'disable', 'delete']), + documentIds: z + .array(z.string()) + .min(1, 'At least one document ID is required') + .max(100, 'Cannot operate on more than 100 documents at once'), }) export async function GET(req: NextRequest, { params }: { params: Promise<{ id: string }> }) { @@ -58,12 +216,10 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id: const documents = await db .select({ id: document.id, - knowledgeBaseId: document.knowledgeBaseId, filename: document.filename, fileUrl: document.fileUrl, fileSize: document.fileSize, mimeType: document.mimeType, - fileHash: document.fileHash, chunkCount: document.chunkCount, tokenCount: document.tokenCount, characterCount: document.characterCount, @@ -76,7 +232,7 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id: }) .from(document) .where(and(...whereConditions)) - .orderBy(document.uploadedAt) + .orderBy(desc(document.uploadedAt)) logger.info( `[${requestId}] Retrieved ${documents.length} documents for knowledge base ${knowledgeBaseId}` @@ -118,63 +274,251 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id: const body = await req.json() - try { - const validatedData = CreateDocumentSchema.parse(body) + // Check if this is a bulk operation + if (body.bulk === true) { + // Handle bulk processing (replaces process-documents endpoint) + try { + const validatedData = BulkCreateDocumentsSchema.parse(body) - // Check for duplicate file hash if provided - if (validatedData.fileHash) { - const existingDocument = await db - .select({ id: document.id }) - .from(document) - .where( - and( - eq(document.knowledgeBaseId, knowledgeBaseId), - eq(document.fileHash, validatedData.fileHash), - isNull(document.deletedAt) + const createdDocuments = await db.transaction(async (tx) => { + const documentPromises = validatedData.documents.map(async (docData) => { + const documentId = crypto.randomUUID() + const now = new Date() + + const newDocument = { + id: documentId, + knowledgeBaseId, + filename: docData.filename, + fileUrl: docData.fileUrl, + fileSize: docData.fileSize, + mimeType: docData.mimeType, + chunkCount: 0, + tokenCount: 0, + characterCount: 0, + processingStatus: 'pending' as const, + enabled: true, + uploadedAt: now, + } + + await tx.insert(document).values(newDocument) + logger.info( + `[${requestId}] Document record created: ${documentId} for file: ${docData.filename}` ) + return { documentId, ...docData } + }) + + return await Promise.all(documentPromises) + }) + + logger.info( + `[${requestId}] Starting controlled async processing of ${createdDocuments.length} documents` + ) + + processDocumentsWithConcurrencyControl( + createdDocuments, + knowledgeBaseId, + validatedData.processingOptions, + requestId + ).catch((error: unknown) => { + logger.error(`[${requestId}] Critical error in document processing pipeline:`, error) + }) + + return NextResponse.json({ + success: true, + data: { + total: createdDocuments.length, + documentsCreated: createdDocuments.map((doc) => ({ + documentId: doc.documentId, + filename: doc.filename, + status: 'pending', + })), + processingMethod: 'background', + processingConfig: { + maxConcurrentDocuments: PROCESSING_CONFIG.maxConcurrentDocuments, + batchSize: PROCESSING_CONFIG.batchSize, + totalBatches: Math.ceil(createdDocuments.length / PROCESSING_CONFIG.batchSize), + }, + }, + }) + } catch (validationError) { + if (validationError instanceof z.ZodError) { + logger.warn(`[${requestId}] Invalid bulk processing request data`, { + errors: validationError.errors, + }) + return NextResponse.json( + { error: 'Invalid request data', details: validationError.errors }, + { status: 400 } ) - .limit(1) + } + throw validationError + } + } else { + // Handle single document creation + try { + const validatedData = CreateDocumentSchema.parse(body) - if (existingDocument.length > 0) { - logger.warn(`[${requestId}] Duplicate file hash detected: ${validatedData.fileHash}`) + const documentId = crypto.randomUUID() + const now = new Date() + + const newDocument = { + id: documentId, + knowledgeBaseId, + filename: validatedData.filename, + fileUrl: validatedData.fileUrl, + fileSize: validatedData.fileSize, + mimeType: validatedData.mimeType, + chunkCount: 0, + tokenCount: 0, + characterCount: 0, + enabled: true, + uploadedAt: now, + } + + await db.insert(document).values(newDocument) + + logger.info( + `[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}` + ) + + return NextResponse.json({ + success: true, + data: newDocument, + }) + } catch (validationError) { + if (validationError instanceof z.ZodError) { + logger.warn(`[${requestId}] Invalid document data`, { + errors: validationError.errors, + }) return NextResponse.json( - { error: 'Document with this file hash already exists' }, - { status: 409 } + { error: 'Invalid request data', details: validationError.errors }, + { status: 400 } ) } + throw validationError } + } + } catch (error) { + logger.error(`[${requestId}] Error creating document`, error) + return NextResponse.json({ error: 'Failed to create document' }, { status: 500 }) + } +} - const documentId = crypto.randomUUID() - const now = new Date() +export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id: string }> }) { + const requestId = crypto.randomUUID().slice(0, 8) + const { id: knowledgeBaseId } = await params - const newDocument = { - id: documentId, - knowledgeBaseId, - filename: validatedData.filename, - fileUrl: validatedData.fileUrl, - fileSize: validatedData.fileSize, - mimeType: validatedData.mimeType, - fileHash: validatedData.fileHash || null, - chunkCount: 0, - tokenCount: 0, - characterCount: 0, - enabled: true, - uploadedAt: now, + try { + const session = await getSession() + if (!session?.user?.id) { + logger.warn(`[${requestId}] Unauthorized bulk document operation attempt`) + return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) + } + + const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id) + + if (!accessCheck.hasAccess) { + if ('notFound' in accessCheck && accessCheck.notFound) { + logger.warn(`[${requestId}] Knowledge base not found: ${knowledgeBaseId}`) + return NextResponse.json({ error: 'Knowledge base not found' }, { status: 404 }) + } + logger.warn( + `[${requestId}] User ${session.user.id} attempted to perform bulk operation on unauthorized knowledge base ${knowledgeBaseId}` + ) + return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) + } + + const body = await req.json() + + try { + const validatedData = BulkUpdateDocumentsSchema.parse(body) + const { operation, documentIds } = validatedData + + logger.info( + `[${requestId}] Starting bulk ${operation} operation on ${documentIds.length} documents in knowledge base ${knowledgeBaseId}` + ) + + // Verify all documents belong to this knowledge base and user has access + const documentsToUpdate = await db + .select({ + id: document.id, + enabled: document.enabled, + }) + .from(document) + .where( + and( + eq(document.knowledgeBaseId, knowledgeBaseId), + inArray(document.id, documentIds), + isNull(document.deletedAt) + ) + ) + + if (documentsToUpdate.length === 0) { + return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 }) + } + + if (documentsToUpdate.length !== documentIds.length) { + logger.warn( + `[${requestId}] Some documents not found or don't belong to knowledge base. Requested: ${documentIds.length}, Found: ${documentsToUpdate.length}` + ) } - await db.insert(document).values(newDocument) + // Perform the bulk operation + let updateResult: Array<{ id: string; enabled?: boolean; deletedAt?: Date | null }> + let successCount: number + + if (operation === 'delete') { + // Handle bulk soft delete + updateResult = await db + .update(document) + .set({ + deletedAt: new Date(), + }) + .where( + and( + eq(document.knowledgeBaseId, knowledgeBaseId), + inArray(document.id, documentIds), + isNull(document.deletedAt) + ) + ) + .returning({ id: document.id, deletedAt: document.deletedAt }) + + successCount = updateResult.length + } else { + // Handle bulk enable/disable + const enabled = operation === 'enable' + + updateResult = await db + .update(document) + .set({ + enabled, + }) + .where( + and( + eq(document.knowledgeBaseId, knowledgeBaseId), + inArray(document.id, documentIds), + isNull(document.deletedAt) + ) + ) + .returning({ id: document.id, enabled: document.enabled }) + + successCount = updateResult.length + } logger.info( - `[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}` + `[${requestId}] Bulk ${operation} operation completed: ${successCount} documents updated in knowledge base ${knowledgeBaseId}` ) return NextResponse.json({ success: true, - data: newDocument, + data: { + operation, + successCount, + updatedDocuments: updateResult, + }, }) } catch (validationError) { if (validationError instanceof z.ZodError) { - logger.warn(`[${requestId}] Invalid document data`, { + logger.warn(`[${requestId}] Invalid bulk operation data`, { errors: validationError.errors, }) return NextResponse.json( @@ -185,7 +529,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id: throw validationError } } catch (error) { - logger.error(`[${requestId}] Error creating document`, error) - return NextResponse.json({ error: 'Failed to create document' }, { status: 500 }) + logger.error(`[${requestId}] Error in bulk document operation`, error) + return NextResponse.json({ error: 'Failed to perform bulk operation' }, { status: 500 }) } } diff --git a/apps/sim/app/api/knowledge/[id]/process-documents/route.ts b/apps/sim/app/api/knowledge/[id]/process-documents/route.ts deleted file mode 100644 index 2d636b0c569..00000000000 --- a/apps/sim/app/api/knowledge/[id]/process-documents/route.ts +++ /dev/null @@ -1,299 +0,0 @@ -import { eq } from 'drizzle-orm' -import { type NextRequest, NextResponse } from 'next/server' -import { z } from 'zod' -import { getSession } from '@/lib/auth' -import { createLogger } from '@/lib/logs/console-logger' -import { db } from '@/db' -import { document } from '@/db/schema' -import { checkKnowledgeBaseAccess, processDocumentAsync } from '../../utils' - -const logger = createLogger('ProcessDocumentsAPI') - -const ProcessDocumentsSchema = z.object({ - documents: z.array( - z.object({ - filename: z.string().min(1, 'Filename is required'), - fileUrl: z.string().url('File URL must be valid'), - fileSize: z.number().min(1, 'File size must be greater than 0'), - mimeType: z.string().min(1, 'MIME type is required'), - fileHash: z.string().optional(), - }) - ), - processingOptions: z.object({ - chunkSize: z.number(), - minCharactersPerChunk: z.number(), - recipe: z.string(), - lang: z.string(), - chunkOverlap: z.number().optional(), - }), -}) - -const PROCESSING_CONFIG = { - maxConcurrentDocuments: 3, // Limit concurrent processing to prevent resource exhaustion - batchSize: 5, // Process documents in batches - delayBetweenBatches: 1000, // 1 second delay between batches - delayBetweenDocuments: 500, // 500ms delay between individual documents in a batch -} - -/** - * Process documents with concurrency control and batching - */ -async function processDocumentsWithConcurrencyControl( - createdDocuments: Array<{ - documentId: string - filename: string - fileUrl: string - fileSize: number - mimeType: string - fileHash?: string - }>, - knowledgeBaseId: string, - processingOptions: any, - requestId: string -): Promise { - const totalDocuments = createdDocuments.length - const batches = [] - - // Create batches - for (let i = 0; i < totalDocuments; i += PROCESSING_CONFIG.batchSize) { - batches.push(createdDocuments.slice(i, i + PROCESSING_CONFIG.batchSize)) - } - - logger.info(`[${requestId}] Processing ${totalDocuments} documents in ${batches.length} batches`) - - for (const [batchIndex, batch] of batches.entries()) { - logger.info( - `[${requestId}] Starting batch ${batchIndex + 1}/${batches.length} with ${batch.length} documents` - ) - - // Process batch with limited concurrency - await processBatchWithConcurrency(batch, knowledgeBaseId, processingOptions, requestId) - - // Add delay between batches (except for the last batch) - if (batchIndex < batches.length - 1) { - await new Promise((resolve) => setTimeout(resolve, PROCESSING_CONFIG.delayBetweenBatches)) - } - } - - logger.info(`[${requestId}] Completed processing initiation for all ${totalDocuments} documents`) -} - -/** - * Process a batch of documents with controlled concurrency - */ -async function processBatchWithConcurrency( - batch: Array<{ - documentId: string - filename: string - fileUrl: string - fileSize: number - mimeType: string - fileHash?: string - }>, - knowledgeBaseId: string, - processingOptions: any, - requestId: string -): Promise { - const semaphore = new Array(PROCESSING_CONFIG.maxConcurrentDocuments).fill(0) - const processingPromises = batch.map(async (doc, index) => { - // Add staggered delay to prevent overwhelming the system - if (index > 0) { - await new Promise((resolve) => - setTimeout(resolve, index * PROCESSING_CONFIG.delayBetweenDocuments) - ) - } - - // Wait for available slot - await new Promise((resolve) => { - const checkSlot = () => { - const availableIndex = semaphore.findIndex((slot) => slot === 0) - if (availableIndex !== -1) { - semaphore[availableIndex] = 1 - resolve() - } else { - setTimeout(checkSlot, 100) - } - } - checkSlot() - }) - - try { - logger.info(`[${requestId}] Starting processing for document: ${doc.filename}`) - - await processDocumentAsync( - knowledgeBaseId, - doc.documentId, - { - filename: doc.filename, - fileUrl: doc.fileUrl, - fileSize: doc.fileSize, - mimeType: doc.mimeType, - fileHash: doc.fileHash, - }, - processingOptions - ) - - logger.info(`[${requestId}] Successfully initiated processing for document: ${doc.filename}`) - } catch (error: unknown) { - logger.error(`[${requestId}] Failed to process document: ${doc.filename}`, { - documentId: doc.documentId, - filename: doc.filename, - fileSize: doc.fileSize, - mimeType: doc.mimeType, - error: error instanceof Error ? error.message : 'Unknown error', - stack: error instanceof Error ? error.stack : undefined, - }) - - try { - await db - .update(document) - .set({ - processingStatus: 'failed', - processingError: - error instanceof Error ? error.message : 'Failed to initiate processing', - processingCompletedAt: new Date(), - }) - .where(eq(document.id, doc.documentId)) - } catch (dbError: unknown) { - logger.error( - `[${requestId}] Failed to update document status for failed document: ${doc.documentId}`, - dbError - ) - } - } finally { - const slotIndex = semaphore.findIndex((slot) => slot === 1) - if (slotIndex !== -1) { - semaphore[slotIndex] = 0 - } - } - }) - - await Promise.allSettled(processingPromises) -} - -export async function POST(req: NextRequest, { params }: { params: Promise<{ id: string }> }) { - const requestId = crypto.randomUUID().slice(0, 8) - const { id: knowledgeBaseId } = await params - - try { - const session = await getSession() - if (!session?.user?.id) { - logger.warn(`[${requestId}] Unauthorized document processing attempt`) - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - - const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id) - - if (!accessCheck.hasAccess) { - if ('notFound' in accessCheck && accessCheck.notFound) { - logger.warn(`[${requestId}] Knowledge base not found: ${knowledgeBaseId}`) - return NextResponse.json({ error: 'Knowledge base not found' }, { status: 404 }) - } - logger.warn( - `[${requestId}] User ${session.user.id} attempted to process documents in unauthorized knowledge base ${knowledgeBaseId}` - ) - return NextResponse.json({ error: 'Unauthorized' }, { status: 401 }) - } - - const body = await req.json() - - try { - const validatedData = ProcessDocumentsSchema.parse(body) - - const createdDocuments = await db.transaction(async (tx) => { - const documentPromises = validatedData.documents.map(async (docData, index) => { - const documentId = crypto.randomUUID() - const now = new Date() - - const newDocument = { - id: documentId, - knowledgeBaseId, - filename: docData.filename, - fileUrl: docData.fileUrl, - fileSize: docData.fileSize, - mimeType: docData.mimeType, - fileHash: docData.fileHash || null, - chunkCount: 0, - tokenCount: 0, - characterCount: 0, - processingStatus: 'pending' as const, - enabled: true, - uploadedAt: now, - } - - try { - await tx.insert(document).values(newDocument) - logger.info( - `[${requestId}] Document record created: ${documentId} for file: ${docData.filename}` - ) - return { documentId, ...docData } - } catch (dbError) { - logger.error( - `[${requestId}] Failed to create document record for ${docData.filename}:`, - dbError - ) - throw new Error( - `Failed to create document record for ${docData.filename}: ${dbError instanceof Error ? dbError.message : 'Unknown database error'}` - ) - } - }) - - const results = await Promise.all(documentPromises) - - // Validate that all documents were created successfully - const invalidResults = results.filter((result) => !result.documentId || !result.filename) - if (invalidResults.length > 0) { - logger.error(`[${requestId}] Some documents failed to create properly:`, invalidResults) - throw new Error(`Failed to create ${invalidResults.length} document records`) - } - - return results - }) - - logger.info( - `[${requestId}] Starting controlled async processing of ${createdDocuments.length} documents` - ) - - processDocumentsWithConcurrencyControl( - createdDocuments, - knowledgeBaseId, - validatedData.processingOptions, - requestId - ).catch((error: unknown) => { - logger.error(`[${requestId}] Critical error in document processing pipeline:`, error) - }) - - return NextResponse.json({ - success: true, - data: { - total: createdDocuments.length, - documentsCreated: createdDocuments.map((doc) => ({ - documentId: doc.documentId, - filename: doc.filename, - status: 'pending', - })), - processingMethod: 'background', - processingConfig: { - maxConcurrentDocuments: PROCESSING_CONFIG.maxConcurrentDocuments, - batchSize: PROCESSING_CONFIG.batchSize, - totalBatches: Math.ceil(createdDocuments.length / PROCESSING_CONFIG.batchSize), - }, - }, - }) - } catch (validationError) { - if (validationError instanceof z.ZodError) { - logger.warn(`[${requestId}] Invalid processing request data`, { - errors: validationError.errors, - }) - return NextResponse.json( - { error: 'Invalid request data', details: validationError.errors }, - { status: 400 } - ) - } - throw validationError - } - } catch (error) { - logger.error(`[${requestId}] Error processing documents`, error) - return NextResponse.json({ error: 'Failed to process documents' }, { status: 500 }) - } -} diff --git a/apps/sim/app/api/knowledge/[id]/route.test.ts b/apps/sim/app/api/knowledge/[id]/route.test.ts new file mode 100644 index 00000000000..97cf4c23953 --- /dev/null +++ b/apps/sim/app/api/knowledge/[id]/route.test.ts @@ -0,0 +1,332 @@ +/** + * Tests for knowledge base by ID API route + * + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + createMockRequest, + mockAuth, + mockConsoleLogger, + mockDrizzleOrm, + mockKnowledgeSchemas, +} from '@/app/api/__test-utils__/utils' + +mockKnowledgeSchemas() +mockDrizzleOrm() +mockConsoleLogger() + +describe('Knowledge Base By ID API Route', () => { + const mockAuth$ = mockAuth() + + const mockDbChain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + limit: vi.fn().mockReturnThis(), + update: vi.fn().mockReturnThis(), + set: vi.fn().mockReturnThis(), + } + + const mockKnowledgeBase = { + id: 'kb-123', + userId: 'user-123', + name: 'Test Knowledge Base', + description: 'Test description', + tokenCount: 100, + embeddingModel: 'text-embedding-3-small', + embeddingDimension: 1536, + chunkingConfig: { maxSize: 1024, minSize: 100, overlap: 200 }, + createdAt: new Date(), + updatedAt: new Date(), + workspaceId: null, + deletedAt: null, + } + + const resetMocks = () => { + vi.clearAllMocks() + Object.values(mockDbChain).forEach((fn) => { + if (typeof fn === 'function') { + fn.mockClear().mockReset().mockReturnThis() + } + }) + } + + beforeEach(async () => { + vi.clearAllMocks() + + vi.doMock('@/db', () => ({ + db: mockDbChain, + })) + + vi.stubGlobal('crypto', { + randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'), + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('GET /api/knowledge/[id]', () => { + const mockParams = Promise.resolve({ id: 'kb-123' }) + + it('should retrieve knowledge base successfully for authenticated user', async () => { + mockAuth$.mockAuthenticatedUser() + + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }]) + + mockDbChain.limit.mockResolvedValueOnce([mockKnowledgeBase]) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.id).toBe('kb-123') + expect(data.data.name).toBe('Test Knowledge Base') + expect(mockDbChain.select).toHaveBeenCalled() + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent knowledge base', async () => { + mockAuth$.mockAuthenticatedUser() + + mockDbChain.limit.mockResolvedValueOnce([]) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Knowledge base not found') + }) + + it('should return unauthorized for knowledge base owned by different user', async () => { + mockAuth$.mockAuthenticatedUser() + + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }]) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should handle database errors', async () => { + mockAuth$.mockAuthenticatedUser() + mockDbChain.limit.mockRejectedValueOnce(new Error('Database error')) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to fetch knowledge base') + }) + }) + + describe('PUT /api/knowledge/[id]', () => { + const mockParams = Promise.resolve({ id: 'kb-123' }) + const validUpdateData = { + name: 'Updated Knowledge Base', + description: 'Updated description', + } + + it('should update knowledge base successfully', async () => { + mockAuth$.mockAuthenticatedUser() + + resetMocks() + + mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }]) + + mockDbChain.where.mockResolvedValueOnce(undefined) + + mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain + mockDbChain.limit.mockResolvedValueOnce([{ ...mockKnowledgeBase, ...validUpdateData }]) + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.name).toBe('Updated Knowledge Base') + expect(mockDbChain.update).toHaveBeenCalled() + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent knowledge base', async () => { + mockAuth$.mockAuthenticatedUser() + + resetMocks() + + mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain + mockDbChain.limit.mockResolvedValueOnce([]) + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Knowledge base not found') + }) + + it('should validate update data', async () => { + mockAuth$.mockAuthenticatedUser() + + resetMocks() + + mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }]) + + const invalidData = { + name: '', + } + + const req = createMockRequest('PUT', invalidData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + expect(data.details).toBeDefined() + }) + + it('should handle database errors during update', async () => { + mockAuth$.mockAuthenticatedUser() + + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }]) + + mockDbChain.where.mockRejectedValueOnce(new Error('Database error')) + + const req = createMockRequest('PUT', validUpdateData) + const { PUT } = await import('./route') + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to update knowledge base') + }) + }) + + describe('DELETE /api/knowledge/[id]', () => { + const mockParams = Promise.resolve({ id: 'kb-123' }) + + it('should delete knowledge base successfully', async () => { + mockAuth$.mockAuthenticatedUser() + + resetMocks() + + mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }]) + + mockDbChain.where.mockResolvedValueOnce(undefined) + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.message).toBe('Knowledge base deleted successfully') + expect(mockDbChain.update).toHaveBeenCalled() + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for non-existent knowledge base', async () => { + mockAuth$.mockAuthenticatedUser() + + resetMocks() + + mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain + mockDbChain.limit.mockResolvedValueOnce([]) + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Knowledge base not found') + }) + + it('should return unauthorized for knowledge base owned by different user', async () => { + mockAuth$.mockAuthenticatedUser() + + resetMocks() + + mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }]) + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should handle database errors during delete', async () => { + mockAuth$.mockAuthenticatedUser() + + mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }]) + + mockDbChain.where.mockRejectedValueOnce(new Error('Database error')) + + const req = createMockRequest('DELETE') + const { DELETE } = await import('./route') + const response = await DELETE(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to delete knowledge base') + }) + }) +}) diff --git a/apps/sim/app/api/knowledge/[id]/route.ts b/apps/sim/app/api/knowledge/[id]/route.ts index 51ca1e20919..04d34fd5720 100644 --- a/apps/sim/app/api/knowledge/[id]/route.ts +++ b/apps/sim/app/api/knowledge/[id]/route.ts @@ -8,7 +8,6 @@ import { knowledgeBase } from '@/db/schema' const logger = createLogger('KnowledgeBaseByIdAPI') -// Schema for knowledge base updates const UpdateKnowledgeBaseSchema = z.object({ name: z.string().min(1, 'Name is required').optional(), description: z.string().optional(), diff --git a/apps/sim/app/api/knowledge/route.test.ts b/apps/sim/app/api/knowledge/route.test.ts new file mode 100644 index 00000000000..94db6a8358b --- /dev/null +++ b/apps/sim/app/api/knowledge/route.test.ts @@ -0,0 +1,220 @@ +/** + * Tests for knowledge base API route + * + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + createMockRequest, + mockAuth, + mockConsoleLogger, + mockDrizzleOrm, + mockKnowledgeSchemas, +} from '@/app/api/__test-utils__/utils' + +mockKnowledgeSchemas() +mockDrizzleOrm() +mockConsoleLogger() + +describe('Knowledge Base API Route', () => { + const mockAuth$ = mockAuth() + + const mockDbChain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + leftJoin: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + groupBy: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockResolvedValue([]), + insert: vi.fn().mockReturnThis(), + values: vi.fn().mockResolvedValue(undefined), + } + + beforeEach(async () => { + vi.clearAllMocks() + + vi.doMock('@/db', () => ({ + db: mockDbChain, + })) + + Object.values(mockDbChain).forEach((fn) => { + if (typeof fn === 'function') { + fn.mockClear() + if (fn !== mockDbChain.orderBy && fn !== mockDbChain.values) { + fn.mockReturnThis() + } + } + }) + + vi.stubGlobal('crypto', { + randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'), + }) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('GET /api/knowledge', () => { + it('should return knowledge bases with document counts for authenticated user', async () => { + const mockKnowledgeBases = [ + { + id: 'kb-1', + name: 'Test KB 1', + description: 'Test description', + tokenCount: 100, + embeddingModel: 'text-embedding-3-small', + embeddingDimension: 1536, + chunkingConfig: { maxSize: 1024, minSize: 100, overlap: 200 }, + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + workspaceId: null, + docCount: 5, + }, + ] + + mockAuth$.mockAuthenticatedUser() + mockDbChain.orderBy.mockResolvedValue(mockKnowledgeBases) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data).toEqual(mockKnowledgeBases) + expect(mockDbChain.select).toHaveBeenCalled() + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should handle database errors', async () => { + mockAuth$.mockAuthenticatedUser() + mockDbChain.orderBy.mockRejectedValue(new Error('Database error')) + + const req = createMockRequest('GET') + const { GET } = await import('./route') + const response = await GET(req) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to fetch knowledge bases') + }) + }) + + describe('POST /api/knowledge', () => { + const validKnowledgeBaseData = { + name: 'Test Knowledge Base', + description: 'Test description', + chunkingConfig: { + maxSize: 1024, + minSize: 100, + overlap: 200, + }, + } + + it('should create knowledge base successfully', async () => { + mockAuth$.mockAuthenticatedUser() + + const req = createMockRequest('POST', validKnowledgeBaseData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.name).toBe(validKnowledgeBaseData.name) + expect(data.data.description).toBe(validKnowledgeBaseData.description) + expect(mockDbChain.insert).toHaveBeenCalled() + }) + + it('should return unauthorized for unauthenticated user', async () => { + mockAuth$.mockUnauthenticated() + + const req = createMockRequest('POST', validKnowledgeBaseData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should validate required fields', async () => { + mockAuth$.mockAuthenticatedUser() + + const req = createMockRequest('POST', { description: 'Missing name' }) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + expect(data.details).toBeDefined() + }) + + it('should validate chunking config constraints', async () => { + mockAuth$.mockAuthenticatedUser() + + const invalidData = { + name: 'Test KB', + chunkingConfig: { + maxSize: 100, + minSize: 200, // Invalid: minSize > maxSize + overlap: 50, + }, + } + + const req = createMockRequest('POST', invalidData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + }) + + it('should use default values for optional fields', async () => { + mockAuth$.mockAuthenticatedUser() + + const minimalData = { name: 'Test KB' } + const req = createMockRequest('POST', minimalData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.data.embeddingModel).toBe('text-embedding-3-small') + expect(data.data.embeddingDimension).toBe(1536) + expect(data.data.chunkingConfig).toEqual({ + maxSize: 1024, + minSize: 100, + overlap: 200, + }) + }) + + it('should handle database errors during creation', async () => { + mockAuth$.mockAuthenticatedUser() + mockDbChain.values.mockRejectedValue(new Error('Database error')) + + const req = createMockRequest('POST', validKnowledgeBaseData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to create knowledge base') + }) + }) +}) diff --git a/apps/sim/app/api/knowledge/route.ts b/apps/sim/app/api/knowledge/route.ts index 09b6a395700..8d57c3a710a 100644 --- a/apps/sim/app/api/knowledge/route.ts +++ b/apps/sim/app/api/knowledge/route.ts @@ -17,11 +17,18 @@ const CreateKnowledgeBaseSchema = z.object({ embeddingDimension: z.literal(1536).default(1536), chunkingConfig: z .object({ - maxSize: z.number().default(1024), - minSize: z.number().default(100), - overlap: z.number().default(200), + maxSize: z.number().min(100).max(4000).default(1024), + minSize: z.number().min(50).max(2000).default(100), + overlap: z.number().min(0).max(500).default(200), }) - .default({}), + .default({ + maxSize: 1024, + minSize: 100, + overlap: 200, + }) + .refine((data) => data.minSize < data.maxSize, { + message: 'Min chunk size must be less than max chunk size', + }), }) export async function GET(req: NextRequest) { @@ -101,7 +108,11 @@ export async function POST(req: NextRequest) { tokenCount: 0, embeddingModel: validatedData.embeddingModel, embeddingDimension: validatedData.embeddingDimension, - chunkingConfig: validatedData.chunkingConfig, + chunkingConfig: validatedData.chunkingConfig || { + maxSize: 1024, + minSize: 100, + overlap: 200, + }, docCount: 0, createdAt: now, updatedAt: now, diff --git a/apps/sim/app/api/knowledge/search/route.test.ts b/apps/sim/app/api/knowledge/search/route.test.ts new file mode 100644 index 00000000000..8cf86c202ca --- /dev/null +++ b/apps/sim/app/api/knowledge/search/route.test.ts @@ -0,0 +1,399 @@ +/** + * Tests for knowledge search API route + * + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { + createMockRequest, + mockConsoleLogger, + mockKnowledgeSchemas, +} from '@/app/api/__test-utils__/utils' + +vi.mock('drizzle-orm', () => ({ + and: vi.fn().mockImplementation((...args) => ({ and: args })), + eq: vi.fn().mockImplementation((a, b) => ({ eq: [a, b] })), + inArray: vi.fn().mockImplementation((field, values) => ({ inArray: [field, values] })), + isNull: vi.fn().mockImplementation((arg) => ({ isNull: arg })), + sql: vi.fn().mockImplementation((strings, ...values) => ({ + sql: strings, + values, + as: vi.fn().mockReturnValue({ sql: strings, values, alias: 'mocked_alias' }), + })), +})) + +mockKnowledgeSchemas() + +vi.mock('@/lib/env', () => ({ + env: { + OPENAI_API_KEY: 'test-api-key', + }, +})) + +vi.mock('@/lib/documents/utils', () => ({ + retryWithExponentialBackoff: vi.fn().mockImplementation((fn) => fn()), +})) + +mockConsoleLogger() + +describe('Knowledge Search API Route', () => { + const mockDbChain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + orderBy: vi.fn().mockReturnThis(), + limit: vi.fn().mockReturnThis(), + } + + const mockGetUserId = vi.fn() + const mockFetch = vi.fn() + + const mockEmbedding = [0.1, 0.2, 0.3, 0.4, 0.5] + const mockSearchResults = [ + { + id: 'chunk-1', + content: 'This is a test chunk', + documentId: 'doc-1', + chunkIndex: 0, + metadata: { title: 'Test Document' }, + distance: 0.2, + }, + { + id: 'chunk-2', + content: 'Another test chunk', + documentId: 'doc-2', + chunkIndex: 1, + metadata: { title: 'Another Document' }, + distance: 0.3, + }, + ] + + beforeEach(async () => { + vi.clearAllMocks() + + vi.doMock('@/db', () => ({ + db: mockDbChain, + })) + + vi.doMock('@/app/api/auth/oauth/utils', () => ({ + getUserId: mockGetUserId, + })) + + Object.values(mockDbChain).forEach((fn) => { + if (typeof fn === 'function') { + fn.mockClear().mockReturnThis() + } + }) + + vi.stubGlobal('crypto', { + randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'), + }) + + vi.stubGlobal('fetch', mockFetch) + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('POST /api/knowledge/search', () => { + const validSearchData = { + knowledgeBaseIds: 'kb-123', + query: 'test search query', + topK: 10, + } + + const mockKnowledgeBases = [ + { + id: 'kb-123', + userId: 'user-123', + name: 'Test KB', + deletedAt: null, + }, + ] + + it('should perform search successfully with single knowledge base', async () => { + mockGetUserId.mockResolvedValue('user-123') + + mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) + + mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.results).toHaveLength(2) + expect(data.data.results[0].similarity).toBe(0.8) // 1 - 0.2 + expect(data.data.query).toBe(validSearchData.query) + expect(data.data.knowledgeBaseIds).toEqual(['kb-123']) + expect(mockDbChain.select).toHaveBeenCalled() + }) + + it('should perform search successfully with multiple knowledge bases', async () => { + const multiKbData = { + ...validSearchData, + knowledgeBaseIds: ['kb-123', 'kb-456'], + } + + const multiKbs = [ + ...mockKnowledgeBases, + { id: 'kb-456', userId: 'user-123', name: 'Test KB 2', deletedAt: null }, + ] + + mockGetUserId.mockResolvedValue('user-123') + + mockDbChain.where.mockResolvedValueOnce(multiKbs) + + mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', multiKbData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(data.data.knowledgeBaseIds).toEqual(['kb-123', 'kb-456']) + }) + + it('should handle workflow-based authentication', async () => { + const workflowData = { + ...validSearchData, + workflowId: 'workflow-123', + } + + mockGetUserId.mockResolvedValue('user-123') + + mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases + + mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', workflowData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.success).toBe(true) + expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123') + }) + + it('should return unauthorized for unauthenticated request', async () => { + mockGetUserId.mockResolvedValue(null) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(401) + expect(data.error).toBe('Unauthorized') + }) + + it('should return not found for workflow that does not exist', async () => { + const workflowData = { + ...validSearchData, + workflowId: 'nonexistent-workflow', + } + + mockGetUserId.mockResolvedValue(null) + + const req = createMockRequest('POST', workflowData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Workflow not found') + }) + + it('should return not found for non-existent knowledge base', async () => { + mockGetUserId.mockResolvedValue('user-123') + + mockDbChain.where.mockResolvedValueOnce([]) // No knowledge bases found + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Knowledge base not found or access denied') + }) + + it('should return not found for some missing knowledge bases', async () => { + const multiKbData = { + ...validSearchData, + knowledgeBaseIds: ['kb-123', 'kb-missing'], + } + + mockGetUserId.mockResolvedValue('user-123') + + mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // Only kb-123 found + + const req = createMockRequest('POST', multiKbData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(404) + expect(data.error).toBe('Knowledge bases not found: kb-missing') + }) + + it('should validate search parameters', async () => { + const invalidData = { + knowledgeBaseIds: '', // Empty string + query: '', // Empty query + topK: 150, // Too high + } + + const req = createMockRequest('POST', invalidData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(400) + expect(data.error).toBe('Invalid request data') + expect(data.details).toBeDefined() + }) + + it('should use default topK value when not provided', async () => { + const dataWithoutTopK = { + knowledgeBaseIds: 'kb-123', + query: 'test search query', + } + + mockGetUserId.mockResolvedValue('user-123') + + mockDbChain.where.mockResolvedValueOnce(mockKnowledgeBases) // First call: get knowledge bases + + mockDbChain.limit.mockResolvedValueOnce(mockSearchResults) // Second call: search results + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', dataWithoutTopK) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(200) + expect(data.data.topK).toBe(10) // Default value + }) + + it('should handle OpenAI API errors', async () => { + mockGetUserId.mockResolvedValue('user-123') + mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases) + + mockFetch.mockResolvedValue({ + ok: false, + status: 401, + statusText: 'Unauthorized', + text: () => Promise.resolve('Invalid API key'), + }) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to perform vector search') + }) + + it('should handle missing OpenAI API key', async () => { + vi.doMock('@/lib/env', () => ({ + env: { + OPENAI_API_KEY: undefined, + }, + })) + + mockGetUserId.mockResolvedValue('user-123') + mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to perform vector search') + }) + + it('should handle database errors during search', async () => { + mockGetUserId.mockResolvedValue('user-123') + mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases) + mockDbChain.limit.mockRejectedValueOnce(new Error('Database error')) + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [{ embedding: mockEmbedding }], + }), + }) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to perform vector search') + }) + + it('should handle invalid OpenAI response format', async () => { + mockGetUserId.mockResolvedValue('user-123') + mockDbChain.limit.mockResolvedValueOnce(mockKnowledgeBases) + + mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [], // Empty data array + }), + }) + + const req = createMockRequest('POST', validSearchData) + const { POST } = await import('./route') + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(500) + expect(data.error).toBe('Failed to perform vector search') + }) + }) +}) diff --git a/apps/sim/app/api/knowledge/search/route.ts b/apps/sim/app/api/knowledge/search/route.ts index 7b6b69b3f42..381809ed59f 100644 --- a/apps/sim/app/api/knowledge/search/route.ts +++ b/apps/sim/app/api/knowledge/search/route.ts @@ -1,4 +1,4 @@ -import { and, eq, isNull, sql } from 'drizzle-orm' +import { and, eq, inArray, isNull, sql } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { retryWithExponentialBackoff } from '@/lib/documents/utils' @@ -20,9 +20,11 @@ class APIError extends Error { } } -// Schema for vector search request const VectorSearchSchema = z.object({ - knowledgeBaseId: z.string().min(1, 'Knowledge base ID is required'), + knowledgeBaseIds: z.union([ + z.string().min(1, 'Knowledge base ID is required'), + z.array(z.string().min(1)).min(1, 'At least one knowledge base ID is required'), + ]), query: z.string().min(1, 'Search query is required'), topK: z.number().min(1).max(100).default(10), }) @@ -34,7 +36,7 @@ async function generateSearchEmbedding(query: string): Promise { } try { - return await retryWithExponentialBackoff( + const embedding = await retryWithExponentialBackoff( async () => { const response = await fetch('https://api.openai.com/v1/embeddings', { method: 'POST', @@ -69,10 +71,12 @@ async function generateSearchEmbedding(query: string): Promise { { maxRetries: 5, initialDelayMs: 1000, - maxDelayMs: 30000, // Max 30 seconds delay for search queries + maxDelayMs: 30000, backoffMultiplier: 2, } ) + + return embedding } catch (error) { logger.error('Failed to generate search embedding:', error) throw new Error( @@ -81,12 +85,91 @@ async function generateSearchEmbedding(query: string): Promise { } } +function getQueryStrategy(kbCount: number, topK: number) { + const useParallel = kbCount > 4 || (kbCount > 2 && topK > 50) + const distanceThreshold = kbCount > 3 ? 0.8 : 1.0 + const parallelLimit = Math.ceil(topK / kbCount) + 5 + + return { + useParallel, + distanceThreshold, + parallelLimit, + singleQueryOptimized: kbCount <= 2, + } +} + +async function executeParallelQueries( + knowledgeBaseIds: string[], + queryVector: string, + topK: number, + distanceThreshold: number +) { + const parallelLimit = Math.ceil(topK / knowledgeBaseIds.length) + 5 + + const queryPromises = knowledgeBaseIds.map(async (kbId) => { + const results = await db + .select({ + id: embedding.id, + content: embedding.content, + documentId: embedding.documentId, + chunkIndex: embedding.chunkIndex, + metadata: embedding.metadata, + distance: sql`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'), + knowledgeBaseId: embedding.knowledgeBaseId, + }) + .from(embedding) + .where( + and( + eq(embedding.knowledgeBaseId, kbId), + eq(embedding.enabled, true), + sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}` + ) + ) + .orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`) + .limit(parallelLimit) + + return results + }) + + const parallelResults = await Promise.all(queryPromises) + return parallelResults.flat() +} + +async function executeSingleQuery( + knowledgeBaseIds: string[], + queryVector: string, + topK: number, + distanceThreshold: number +) { + return await db + .select({ + id: embedding.id, + content: embedding.content, + documentId: embedding.documentId, + chunkIndex: embedding.chunkIndex, + metadata: embedding.metadata, + distance: sql`${embedding.embedding} <=> ${queryVector}::vector`.as('distance'), + }) + .from(embedding) + .where( + and( + inArray(embedding.knowledgeBaseId, knowledgeBaseIds), + eq(embedding.enabled, true), + sql`${embedding.embedding} <=> ${queryVector}::vector < ${distanceThreshold}` + ) + ) + .orderBy(sql`${embedding.embedding} <=> ${queryVector}::vector`) + .limit(topK) +} + +function mergeAndRankResults(results: any[], topK: number) { + return results.sort((a, b) => a.distance - b.distance).slice(0, topK) +} + export async function POST(request: NextRequest) { const requestId = crypto.randomUUID().slice(0, 8) try { - logger.info(`[${requestId}] Processing vector search request`) - const body = await request.json() const { workflowId, ...searchParams } = body @@ -95,63 +178,71 @@ export async function POST(request: NextRequest) { if (!userId) { const errorMessage = workflowId ? 'Workflow not found' : 'Unauthorized' const statusCode = workflowId ? 404 : 401 - logger.warn(`[${requestId}] Authentication failed: ${errorMessage}`) return NextResponse.json({ error: errorMessage }, { status: statusCode }) } try { const validatedData = VectorSearchSchema.parse(searchParams) - // Verify the knowledge base exists and user has access - const kb = await db - .select() - .from(knowledgeBase) - .where( - and( - eq(knowledgeBase.id, validatedData.knowledgeBaseId), - eq(knowledgeBase.userId, userId), - isNull(knowledgeBase.deletedAt) - ) - ) - .limit(1) + const knowledgeBaseIds = Array.isArray(validatedData.knowledgeBaseIds) + ? validatedData.knowledgeBaseIds + : [validatedData.knowledgeBaseIds] + + const [kb, queryEmbedding] = await Promise.all([ + db + .select() + .from(knowledgeBase) + .where( + and( + inArray(knowledgeBase.id, knowledgeBaseIds), + eq(knowledgeBase.userId, userId), + isNull(knowledgeBase.deletedAt) + ) + ), + generateSearchEmbedding(validatedData.query), + ]) if (kb.length === 0) { - logger.warn( - `[${requestId}] Knowledge base not found or access denied: ${validatedData.knowledgeBaseId}` - ) return NextResponse.json( { error: 'Knowledge base not found or access denied' }, { status: 404 } ) } - // Generate embedding for the search query - logger.info(`[${requestId}] Generating embedding for search query`) - const queryEmbedding = await generateSearchEmbedding(validatedData.query) - - // Perform vector similarity search using pgvector cosine similarity - logger.info(`[${requestId}] Performing vector search with topK=${validatedData.topK}`) - - const results = await db - .select({ - id: embedding.id, - content: embedding.content, - documentId: embedding.documentId, - chunkIndex: embedding.chunkIndex, - metadata: embedding.metadata, - similarity: sql`1 - (${embedding.embedding} <=> ${JSON.stringify(queryEmbedding)}::vector)`, - }) - .from(embedding) - .where( - and( - eq(embedding.knowledgeBaseId, validatedData.knowledgeBaseId), - eq(embedding.enabled, true) - ) + const foundKbIds = kb.map((k) => k.id) + const missingKbIds = knowledgeBaseIds.filter((id) => !foundKbIds.includes(id)) + + if (missingKbIds.length > 0) { + return NextResponse.json( + { error: `Knowledge bases not found: ${missingKbIds.join(', ')}` }, + { status: 404 } ) - .orderBy(sql`${embedding.embedding} <=> ${JSON.stringify(queryEmbedding)}::vector`) - .limit(validatedData.topK) + } + + // Adaptive query strategy based on KB count and parameters + const strategy = getQueryStrategy(foundKbIds.length, validatedData.topK) + const queryVector = JSON.stringify(queryEmbedding) + + let results: any[] - logger.info(`[${requestId}] Vector search completed. Found ${results.length} results`) + if (strategy.useParallel) { + // Execute parallel queries for better performance with many KBs + const parallelResults = await executeParallelQueries( + foundKbIds, + queryVector, + validatedData.topK, + strategy.distanceThreshold + ) + results = mergeAndRankResults(parallelResults, validatedData.topK) + } else { + // Execute single optimized query for fewer KBs + results = await executeSingleQuery( + foundKbIds, + queryVector, + validatedData.topK, + strategy.distanceThreshold + ) + } return NextResponse.json({ success: true, @@ -162,19 +253,17 @@ export async function POST(request: NextRequest) { documentId: result.documentId, chunkIndex: result.chunkIndex, metadata: result.metadata, - similarity: result.similarity, + similarity: 1 - result.distance, })), query: validatedData.query, - knowledgeBaseId: validatedData.knowledgeBaseId, + knowledgeBaseIds: foundKbIds, + knowledgeBaseId: foundKbIds[0], topK: validatedData.topK, totalResults: results.length, }, }) } catch (validationError) { if (validationError instanceof z.ZodError) { - logger.warn(`[${requestId}] Invalid vector search data`, { - errors: validationError.errors, - }) return NextResponse.json( { error: 'Invalid request data', details: validationError.errors }, { status: 400 } @@ -183,7 +272,6 @@ export async function POST(request: NextRequest) { throw validationError } } catch (error) { - logger.error(`[${requestId}] Error performing vector search`, error) return NextResponse.json( { error: 'Failed to perform vector search', diff --git a/apps/sim/app/api/knowledge/utils.ts b/apps/sim/app/api/knowledge/utils.ts index 4848bbcb48a..6a061e0820b 100644 --- a/apps/sim/app/api/knowledge/utils.ts +++ b/apps/sim/app/api/knowledge/utils.ts @@ -9,6 +9,12 @@ import { document, embedding, knowledgeBase } from '@/db/schema' const logger = createLogger('KnowledgeUtils') +// Timeout constants (in milliseconds) +const TIMEOUTS = { + OVERALL_PROCESSING: 150000, // 150 seconds (2.5 minutes) + EMBEDDINGS_API: 60000, // 60 seconds per batch +} as const + class APIError extends Error { public status: number @@ -19,6 +25,22 @@ class APIError extends Error { } } +/** + * Create a timeout wrapper for async operations + */ +function withTimeout( + promise: Promise, + timeoutMs: number, + operation = 'Operation' +): Promise { + return Promise.race([ + promise, + new Promise((_, reject) => + setTimeout(() => reject(new Error(`${operation} timed out after ${timeoutMs}ms`)), timeoutMs) + ), + ]) +} + export interface KnowledgeBaseData { id: string userId: string @@ -41,7 +63,6 @@ export interface DocumentData { fileUrl: string fileSize: number mimeType: string - fileHash?: string | null chunkCount: number tokenCount: number characterCount: number @@ -67,12 +88,7 @@ export interface EmbeddingData { embeddingModel: string startOffset: number endOffset: number - overlapTokens: number metadata: unknown - searchRank?: string | null - accessCount: number - lastAccessedAt?: Date | null - qualityScore?: string | null enabled: boolean createdAt: Date updatedAt: Date @@ -316,30 +332,44 @@ export async function generateEmbeddings( const batchEmbeddings = await retryWithExponentialBackoff( async () => { - const response = await fetch('https://api.openai.com/v1/embeddings', { - method: 'POST', - headers: { - Authorization: `Bearer ${openaiApiKey}`, - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - input: batch, - model: embeddingModel, - encoding_format: 'float', - }), - }) - - if (!response.ok) { - const errorText = await response.text() - const error = new APIError( - `OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`, - response.status - ) + const controller = new AbortController() + const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.EMBEDDINGS_API) + + try { + const response = await fetch('https://api.openai.com/v1/embeddings', { + method: 'POST', + headers: { + Authorization: `Bearer ${openaiApiKey}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + input: batch, + model: embeddingModel, + encoding_format: 'float', + }), + signal: controller.signal, + }) + + clearTimeout(timeoutId) + + if (!response.ok) { + const errorText = await response.text() + const error = new APIError( + `OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`, + response.status + ) + throw error + } + + const data: OpenAIEmbeddingResponse = await response.json() + return data.data.map((item) => item.embedding) + } catch (error) { + clearTimeout(timeoutId) + if (error instanceof Error && error.name === 'AbortError') { + throw new Error('OpenAI API request timed out') + } throw error } - - const data: OpenAIEmbeddingResponse = await response.json() - return data.data.map((item) => item.embedding) }, { maxRetries: 5, @@ -370,7 +400,6 @@ export async function processDocumentAsync( fileUrl: string fileSize: number mimeType: string - fileHash?: string | null }, processingOptions: { chunkSize?: number @@ -396,78 +425,78 @@ export async function processDocumentAsync( logger.info(`[${documentId}] Status updated to 'processing', starting document processor`) - const processed = await processDocument( - docData.fileUrl, - docData.filename, - docData.mimeType, - processingOptions.chunkSize || 1000, - processingOptions.chunkOverlap || 200 - ) - - const now = new Date() - - logger.info( - `[${documentId}] Document parsed successfully, generating embeddings for ${processed.chunks.length} chunks` - ) + // Wrap the entire processing operation with a 5-minute timeout + await withTimeout( + (async () => { + const processed = await processDocument( + docData.fileUrl, + docData.filename, + docData.mimeType, + processingOptions.chunkSize || 1000, + processingOptions.chunkOverlap || 200 + ) + + const now = new Date() + + logger.info( + `[${documentId}] Document parsed successfully, generating embeddings for ${processed.chunks.length} chunks` + ) + + const chunkTexts = processed.chunks.map((chunk) => chunk.text) + const embeddings = chunkTexts.length > 0 ? await generateEmbeddings(chunkTexts) : [] + + logger.info(`[${documentId}] Embeddings generated, updating document record`) + + const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({ + id: crypto.randomUUID(), + knowledgeBaseId, + documentId, + chunkIndex, + chunkHash: crypto.createHash('sha256').update(chunk.text).digest('hex'), + content: chunk.text, + contentLength: chunk.text.length, + tokenCount: Math.ceil(chunk.text.length / 4), + embedding: embeddings[chunkIndex] || null, + embeddingModel: 'text-embedding-3-small', + startOffset: chunk.metadata.startIndex, + endOffset: chunk.metadata.endIndex, + metadata: {}, + createdAt: now, + updatedAt: now, + })) - const chunkTexts = processed.chunks.map((chunk) => chunk.text) - const embeddings = chunkTexts.length > 0 ? await generateEmbeddings(chunkTexts) : [] - - logger.info(`[${documentId}] Embeddings generated, updating document record`) - - const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({ - id: crypto.randomUUID(), - knowledgeBaseId, - documentId, - chunkIndex, - chunkHash: crypto.createHash('sha256').update(chunk.text).digest('hex'), - content: chunk.text, - contentLength: chunk.text.length, - tokenCount: Math.ceil(chunk.text.length / 4), - embedding: embeddings[chunkIndex] || null, - embeddingModel: 'text-embedding-3-small', - startOffset: chunk.metadata.startIndex, - endOffset: chunk.metadata.endIndex, - overlapTokens: 0, - metadata: {}, - searchRank: '1.0', - accessCount: 0, - lastAccessedAt: null, - qualityScore: null, - createdAt: now, - updatedAt: now, - })) - - await db.transaction(async (tx) => { - if (embeddingRecords.length > 0) { - await tx.insert(embedding).values(embeddingRecords) - } - - await tx - .update(document) - .set({ - chunkCount: processed.metadata.chunkCount, - tokenCount: processed.metadata.tokenCount, - characterCount: processed.metadata.characterCount, - processingStatus: 'completed', - processingCompletedAt: now, - processingError: null, - }) - .where(eq(document.id, documentId)) + await db.transaction(async (tx) => { + if (embeddingRecords.length > 0) { + await tx.insert(embedding).values(embeddingRecords) + } - await tx - .update(knowledgeBase) - .set({ - tokenCount: sql`${knowledgeBase.tokenCount} + ${processed.metadata.tokenCount}`, - updatedAt: now, + await tx + .update(document) + .set({ + chunkCount: processed.metadata.chunkCount, + tokenCount: processed.metadata.tokenCount, + characterCount: processed.metadata.characterCount, + processingStatus: 'completed', + processingCompletedAt: now, + processingError: null, + }) + .where(eq(document.id, documentId)) + + await tx + .update(knowledgeBase) + .set({ + tokenCount: sql`${knowledgeBase.tokenCount} + ${processed.metadata.tokenCount}`, + updatedAt: now, + }) + .where(eq(knowledgeBase.id, knowledgeBaseId)) }) - .where(eq(knowledgeBase.id, knowledgeBaseId)) - }) + })(), + TIMEOUTS.OVERALL_PROCESSING, + 'Document processing' + ) const processingTime = Date.now() - startTime - logger.info( - `[${documentId}] Successfully processed document with ${processed.metadata.chunkCount} chunks in ${processingTime}ms` - ) + logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`) } catch (error) { const processingTime = Date.now() - startTime logger.error(`[${documentId}] Failed to process document after ${processingTime}ms:`, { diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/document-selector/document-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/document-selector/document-selector.tsx index 60475ab9510..d1a3ab06f4d 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/document-selector/document-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/document-selector/document-selector.tsx @@ -22,7 +22,6 @@ interface DocumentData { fileUrl: string fileSize: number mimeType: string - fileHash: string | null chunkCount: number tokenCount: number characterCount: number diff --git a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/knowledge-base-selector/knowledge-base-selector.tsx b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/knowledge-base-selector/knowledge-base-selector.tsx index 01e78aeec1f..8012fb32fc3 100644 --- a/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/knowledge-base-selector/knowledge-base-selector.tsx +++ b/apps/sim/app/w/[id]/components/workflow-block/components/sub-block/components/knowledge-base-selector/knowledge-base-selector.tsx @@ -1,7 +1,7 @@ 'use client' import { useCallback, useEffect, useState } from 'react' -import { Check, ChevronDown, RefreshCw } from 'lucide-react' +import { Check, ChevronDown, RefreshCw, X } from 'lucide-react' import { PackageSearchIcon } from '@/components/icons' import { Button } from '@/components/ui/button' import { @@ -21,7 +21,7 @@ interface KnowledgeBaseSelectorProps { blockId: string subBlock: SubBlockConfig disabled?: boolean - onKnowledgeBaseSelect?: (knowledgeBaseId: string) => void + onKnowledgeBaseSelect?: (knowledgeBaseId: string | string[]) => void isPreview?: boolean previewValue?: string | null } @@ -42,9 +42,8 @@ export function KnowledgeBaseSelector({ const [loading, setLoading] = useState(false) const [error, setError] = useState(null) const [open, setOpen] = useState(false) - const [selectedKnowledgeBase, setSelectedKnowledgeBase] = useState(null) + const [selectedKnowledgeBases, setSelectedKnowledgeBases] = useState([]) const [initialFetchDone, setInitialFetchDone] = useState(false) - const [knowledgeBaseInfo, setKnowledgeBaseInfo] = useState(null) // Get the current value from the store const storeValue = getValue(blockId, subBlock.id) @@ -52,6 +51,8 @@ export function KnowledgeBaseSelector({ // Use preview value when in preview mode, otherwise use store value const value = isPreview ? previewValue : storeValue + const isMultiSelect = subBlock.multiSelect === true + // Fetch knowledge bases const fetchKnowledgeBases = useCallback(async () => { setLoading(true) @@ -82,12 +83,11 @@ export function KnowledgeBaseSelector({ } } - // Handle knowledge base selection - const handleSelectKnowledgeBase = (knowledgeBase: KnowledgeBaseData) => { + // Handle single knowledge base selection (for backward compatibility) + const handleSelectSingleKnowledgeBase = (knowledgeBase: KnowledgeBaseData) => { if (isPreview) return - setSelectedKnowledgeBase(knowledgeBase) - setKnowledgeBaseInfo(knowledgeBase) + setSelectedKnowledgeBases([knowledgeBase]) if (!isPreview) { setValue(blockId, subBlock.id, knowledgeBase.id) @@ -97,20 +97,65 @@ export function KnowledgeBaseSelector({ setOpen(false) } - // Sync selected knowledge base with value prop + // Handle multi-select knowledge base selection + const handleToggleKnowledgeBase = (knowledgeBase: KnowledgeBaseData) => { + if (isPreview) return + + const isCurrentlySelected = selectedKnowledgeBases.some((kb) => kb.id === knowledgeBase.id) + let newSelected: KnowledgeBaseData[] + + if (isCurrentlySelected) { + // Remove from selection + newSelected = selectedKnowledgeBases.filter((kb) => kb.id !== knowledgeBase.id) + } else { + // Add to selection + newSelected = [...selectedKnowledgeBases, knowledgeBase] + } + + setSelectedKnowledgeBases(newSelected) + + if (!isPreview) { + const selectedIds = newSelected.map((kb) => kb.id) + const valueToStore = selectedIds.length === 1 ? selectedIds[0] : selectedIds.join(',') + setValue(blockId, subBlock.id, valueToStore) + } + + onKnowledgeBaseSelect?.(newSelected.map((kb) => kb.id)) + } + + // Remove selected knowledge base (for multi-select tags) + const handleRemoveKnowledgeBase = (knowledgeBaseId: string) => { + if (isPreview) return + + const newSelected = selectedKnowledgeBases.filter((kb) => kb.id !== knowledgeBaseId) + setSelectedKnowledgeBases(newSelected) + + if (!isPreview) { + const selectedIds = newSelected.map((kb) => kb.id) + const valueToStore = selectedIds.length === 1 ? selectedIds[0] : selectedIds.join(',') + setValue(blockId, subBlock.id, valueToStore) + } + + onKnowledgeBaseSelect?.(newSelected.map((kb) => kb.id)) + } + + // Sync selected knowledge bases with value prop useEffect(() => { if (value && knowledgeBases.length > 0) { - const kbInfo = knowledgeBases.find((kb) => kb.id === value) - if (kbInfo) { - setSelectedKnowledgeBase(kbInfo) - setKnowledgeBaseInfo(kbInfo) - } else { - setSelectedKnowledgeBase(null) - setKnowledgeBaseInfo(null) - } + const selectedIds = + typeof value === 'string' + ? value.includes(',') + ? value + .split(',') + .map((id) => id.trim()) + .filter((id) => id.length > 0) + : [value] + : [] + + const selectedKbs = knowledgeBases.filter((kb) => selectedIds.includes(kb.id)) + setSelectedKnowledgeBases(selectedKbs) } else if (!value) { - setSelectedKnowledgeBase(null) - setKnowledgeBaseInfo(null) + setSelectedKnowledgeBases([]) } }, [value, knowledgeBases]) @@ -124,10 +169,23 @@ export function KnowledgeBaseSelector({ // If we have a value but no knowledge base info and haven't fetched yet, fetch useEffect(() => { - if (value && !selectedKnowledgeBase && !loading && !initialFetchDone && !isPreview) { + if ( + value && + selectedKnowledgeBases.length === 0 && + !loading && + !initialFetchDone && + !isPreview + ) { fetchKnowledgeBases() } - }, [value, selectedKnowledgeBase, loading, initialFetchDone, fetchKnowledgeBases, isPreview]) + }, [ + value, + selectedKnowledgeBases.length, + loading, + initialFetchDone, + fetchKnowledgeBases, + isPreview, + ]) const formatKnowledgeBaseName = (knowledgeBase: KnowledgeBaseData) => { return knowledgeBase.name @@ -141,10 +199,38 @@ export function KnowledgeBaseSelector({ return knowledgeBase.description || 'No description' } - const label = subBlock.placeholder || 'Select knowledge base' + const isKnowledgeBaseSelected = (knowledgeBaseId: string) => { + return selectedKnowledgeBases.some((kb) => kb.id === knowledgeBaseId) + } + + const label = + subBlock.placeholder || (isMultiSelect ? 'Select knowledge bases' : 'Select knowledge base') return (
+ {/* Selected knowledge bases display (for multi-select) */} + {isMultiSelect && selectedKnowledgeBases.length > 0 && ( +
+ {selectedKnowledgeBases.map((kb) => ( +
+ + {formatKnowledgeBaseName(kb)} + {!disabled && !isPreview && ( + + )} +
+ ))} +
+ )} +