Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix(knowledge): record embedding usage cost for KB document processing
Adds billing tracking to the KB embedding pipeline, which was previously
generating OpenAI API calls with no cost recorded. Token counts are now
captured from the actual API response and recorded via recordUsage after
successful embedding insertion. BYOK workspaces are excluded from billing.
Applies to all execution paths: direct, BullMQ, and Trigger.dev.
  • Loading branch information
waleedlatif1 committed Apr 4, 2026
commit 987411ca48d2458566ea65762178a6fc3e0fb35a
5 changes: 4 additions & 1 deletion apps/sim/app/api/knowledge/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ vi.stubGlobal(
{ embedding: [0.1, 0.2], index: 0 },
{ embedding: [0.3, 0.4], index: 1 },
],
usage: { prompt_tokens: 2, total_tokens: 2 },
}),
})
)
Expand Down Expand Up @@ -294,7 +295,7 @@ describe('Knowledge Utils', () => {
it.concurrent('should return same length as input', async () => {
const result = await generateEmbeddings(['a', 'b'])

expect(result.length).toBe(2)
expect(result.embeddings.length).toBe(2)
})

it('should use Azure OpenAI when Azure config is provided', async () => {
Expand All @@ -313,6 +314,7 @@ describe('Knowledge Utils', () => {
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2], index: 0 }],
usage: { prompt_tokens: 1, total_tokens: 1 },
}),
} as any)

Expand Down Expand Up @@ -342,6 +344,7 @@ describe('Knowledge Utils', () => {
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2], index: 0 }],
usage: { prompt_tokens: 1, total_tokens: 1 },
}),
} as any)

Expand Down
1 change: 1 addition & 0 deletions apps/sim/lib/billing/core/usage-log.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export type UsageLogSource =
| 'workspace-chat'
| 'mcp_copilot'
| 'mothership_block'
| 'knowledge-base'

/**
* Metadata for 'model' category charges
Expand Down
5 changes: 4 additions & 1 deletion apps/sim/lib/chunkers/docs-chunker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ export class DocsChunker {
const textChunks = await this.splitContent(markdownContent)

logger.info(`Generating embeddings for ${textChunks.length} chunks in ${relativePath}`)
const embeddings = textChunks.length > 0 ? await generateEmbeddings(textChunks) : []
const { embeddings } =
textChunks.length > 0
? await generateEmbeddings(textChunks)
: { embeddings: [] as number[][] }
const embeddingModel = 'text-embedding-3-small'

const chunks: DocChunk[] = []
Expand Down
4 changes: 2 additions & 2 deletions apps/sim/lib/knowledge/chunks/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ export async function createChunk(
workspaceId?: string | null
): Promise<ChunkData> {
logger.info(`[${requestId}] Generating embedding for manual chunk`)
const embeddings = await generateEmbeddings([chunkData.content], undefined, workspaceId)
const { embeddings } = await generateEmbeddings([chunkData.content], undefined, workspaceId)

// Calculate accurate token count
const tokenCount = estimateTokenCount(chunkData.content, 'openai')
Expand Down Expand Up @@ -359,7 +359,7 @@ export async function updateChunk(
if (content !== currentChunk[0].content) {
logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`)

const embeddings = await generateEmbeddings([content], undefined, workspaceId)
const { embeddings } = await generateEmbeddings([content], undefined, workspaceId)

// Calculate accurate token count
const tokenCount = estimateTokenCount(content, 'openai')
Expand Down
43 changes: 41 additions & 2 deletions apps/sim/lib/knowledge/documents/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ import {
type SQL,
sql,
} from 'drizzle-orm'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { createBullMQJobData, isBullMQEnabled } from '@/lib/core/bullmq'
import { env } from '@/lib/core/config/env'
import { isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { getCostMultiplier, isTriggerDevEnabled } from '@/lib/core/config/feature-flags'
import { enqueueWorkspaceDispatch } from '@/lib/core/workspace-dispatch'
import { processDocument } from '@/lib/knowledge/documents/document-processor'
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
Expand All @@ -43,6 +44,7 @@ import type { ProcessedDocumentTags } from '@/lib/knowledge/types'
import { deleteFile } from '@/lib/uploads/core/storage-service'
import { extractStorageKey } from '@/lib/uploads/utils/file-utils'
import type { DocumentProcessingPayload } from '@/background/knowledge-processing'
import { getEmbeddingModelPricing } from '@/providers/models'

const logger = createLogger('DocumentService')

Expand Down Expand Up @@ -460,6 +462,9 @@ export async function processDocumentAsync(
overlap: rawConfig?.overlap ?? 200,
}

let totalEmbeddingTokens = 0
let embeddingIsBYOK = false

await withTimeout(
(async () => {
const processed = await processDocument(
Expand Down Expand Up @@ -500,10 +505,16 @@ export async function processDocumentAsync(
const batchNum = Math.floor(i / batchSize) + 1

logger.info(`[${documentId}] Processing embedding batch ${batchNum}/${totalBatches}`)
const batchEmbeddings = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
const {
embeddings: batchEmbeddings,
totalTokens: batchTokens,
isBYOK,
} = await generateEmbeddings(batch, undefined, kb[0].workspaceId)
for (const emb of batchEmbeddings) {
embeddings.push(emb)
}
totalEmbeddingTokens += batchTokens
embeddingIsBYOK = isBYOK
}
}

Expand Down Expand Up @@ -638,6 +649,34 @@ export async function processDocumentAsync(

const processingTime = Date.now() - startTime
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)

if (!embeddingIsBYOK && totalEmbeddingTokens > 0 && kb[0].userId) {
try {
const embeddingModel = 'text-embedding-3-small'
const pricing = getEmbeddingModelPricing(embeddingModel)
if (pricing) {
const cost = (totalEmbeddingTokens / 1_000_000) * pricing.input * getCostMultiplier()
await recordUsage({
userId: kb[0].userId,
workspaceId: kb[0].workspaceId ?? undefined,
entries: [
{
category: 'model',
source: 'knowledge-base',
description: embeddingModel,
cost,
metadata: { inputTokens: totalEmbeddingTokens, outputTokens: 0 },
},
],
additionalStats: {
totalTokensUsed: sql`total_tokens_used + ${totalEmbeddingTokens}`,
},
})
}
} catch (billingError) {
logger.error(`[${documentId}] Failed to record embedding usage`, { error: billingError })
}
}
} catch (error) {
const processingTime = Date.now() - startTime
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
Expand Down
35 changes: 28 additions & 7 deletions apps/sim/lib/knowledge/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ interface EmbeddingConfig {
apiUrl: string
headers: Record<string, string>
modelName: string
isBYOK: boolean
}

interface EmbeddingResponseItem {
Expand Down Expand Up @@ -71,16 +72,19 @@ async function getEmbeddingConfig(
'Content-Type': 'application/json',
},
modelName: kbModelName,
isBYOK: false,
}
}

let openaiApiKey = env.OPENAI_API_KEY
let isBYOK = false

if (workspaceId) {
const byokResult = await getBYOKKey(workspaceId, 'openai')
if (byokResult) {
logger.info('Using workspace BYOK key for OpenAI embeddings')
openaiApiKey = byokResult.apiKey
isBYOK = true
}
}

Expand All @@ -98,12 +102,16 @@ async function getEmbeddingConfig(
'Content-Type': 'application/json',
},
modelName: embeddingModel,
isBYOK,
}
}

// Timeout (ms) for a single embedding API request; presumably applied per attempt
// inside the retry-with-backoff wrapper — TODO confirm against callEmbeddingAPI usage.
const EMBEDDING_REQUEST_TIMEOUT_MS = 60_000

async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Promise<number[][]> {
async function callEmbeddingAPI(
inputs: string[],
config: EmbeddingConfig
): Promise<{ embeddings: number[][]; totalTokens: number }> {
return retryWithExponentialBackoff(
async () => {
const useDimensions = supportsCustomDimensions(config.modelName)
Expand Down Expand Up @@ -140,7 +148,10 @@ async function callEmbeddingAPI(inputs: string[], config: EmbeddingConfig): Prom
}

const data: EmbeddingAPIResponse = await response.json()
return data.data.map((item) => item.embedding)
return {
embeddings: data.data.map((item) => item.embedding),
totalTokens: data.usage.total_tokens,
}
Comment on lines 150 to +154
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing null guard on data.usage

The response is cast to EmbeddingAPIResponse at the TypeScript level, but there's no runtime validation. If an Azure deployment or OpenAI-compatible proxy omits the usage field (e.g. older API versions or proxies), data.usage.total_tokens throws TypeError: Cannot read properties of undefined (reading 'total_tokens'), which would propagate out of callEmbeddingAPI and abort the entire batch, preventing document processing from completing.

A safe fallback costs nothing in the happy path:

Suggested change
const data: EmbeddingAPIResponse = await response.json()
return data.data.map((item) => item.embedding)
return {
embeddings: data.data.map((item) => item.embedding),
totalTokens: data.usage.total_tokens,
}
const data: EmbeddingAPIResponse = await response.json()
return {
embeddings: data.data.map((item) => item.embedding),
totalTokens: data.usage?.total_tokens ?? 0,
}

},
{
maxRetries: 3,
Expand Down Expand Up @@ -178,14 +189,22 @@ async function processWithConcurrency<T, R>(
return results
}

/**
 * Result of a bulk embedding generation call.
 *
 * Bundles the embedding vectors with the billing signals callers need to
 * record usage via `recordUsage`.
 */
export interface GenerateEmbeddingsResult {
  // One embedding vector per input text, in input order.
  embeddings: number[][]
  // Actual token count summed from the embedding API's `usage` field
  // across all batches (not a local estimate).
  totalTokens: number
  // True when a workspace-provided (BYOK) OpenAI key was used, in which
  // case callers should skip recording billable usage.
  isBYOK: boolean
}

/**
* Generate embeddings for multiple texts with token-aware batching and parallel processing
* Generate embeddings for multiple texts with token-aware batching and parallel processing.
* Returns embeddings alongside the actual token count from the API and whether a BYOK key was used.
* Callers should use `totalTokens` and `isBYOK` to record billing via `recordUsage`.
*/
export async function generateEmbeddings(
texts: string[],
embeddingModel = 'text-embedding-3-small',
workspaceId?: string | null
): Promise<number[][]> {
): Promise<GenerateEmbeddingsResult> {
const config = await getEmbeddingConfig(embeddingModel, workspaceId)

const batches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel)
Expand All @@ -204,13 +223,15 @@ export async function generateEmbeddings(
)

const allEmbeddings: number[][] = []
let totalTokens = 0
for (const batch of batchResults) {
for (const emb of batch) {
for (const emb of batch.embeddings) {
allEmbeddings.push(emb)
}
totalTokens += batch.totalTokens
}

return allEmbeddings
return { embeddings: allEmbeddings, totalTokens, isBYOK: config.isBYOK }
}

/**
Expand All @@ -227,6 +248,6 @@ export async function generateSearchEmbedding(
`Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation`
)

const embeddings = await callEmbeddingAPI([query], config)
const { embeddings } = await callEmbeddingAPI([query], config)
return embeddings[0]
}