Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix(billing): atomize usage_log and userStats writes via central reco…
…rdUsage()
  • Loading branch information
waleedlatif1 committed Mar 25, 2026
commit 95ea77f9d6ae9b1cf67155aa095129d38d1184d7
55 changes: 20 additions & 35 deletions apps/sim/app/api/billing/update-cost/route.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import { db } from '@sim/db'
import { userStats } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { logModelUsage } from '@/lib/billing/core/usage-log'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { checkInternalApiKey } from '@/lib/copilot/utils'
import { isBillingEnabled } from '@/lib/core/config/feature-flags'
Expand Down Expand Up @@ -87,53 +85,40 @@ export async function POST(req: NextRequest) {
source,
})

// Check if user stats record exists (same as ExecutionLogger)
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))

if (userStatsRecords.length === 0) {
logger.error(
`[${requestId}] User stats record not found - should be created during onboarding`,
{
userId,
}
)
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
}

const totalTokens = inputTokens + outputTokens

const updateFields: Record<string, unknown> = {
totalCost: sql`total_cost + ${cost}`,
currentPeriodCost: sql`current_period_cost + ${cost}`,
const additionalStats: Record<string, ReturnType<typeof sql>> = {
totalCopilotCost: sql`total_copilot_cost + ${cost}`,
currentPeriodCopilotCost: sql`current_period_copilot_cost + ${cost}`,
totalCopilotCalls: sql`total_copilot_calls + 1`,
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
lastActive: new Date(),
}

if (isMcp) {
updateFields.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
Comment thread
waleedlatif1 marked this conversation as resolved.
updateFields.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
updateFields.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
additionalStats.totalMcpCopilotCost = sql`total_mcp_copilot_cost + ${cost}`
additionalStats.currentPeriodMcpCopilotCost = sql`current_period_mcp_copilot_cost + ${cost}`
additionalStats.totalMcpCopilotCalls = sql`total_mcp_copilot_calls + 1`
}

await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))

logger.info(`[${requestId}] Updated user stats record`, {
// Atomic write: usage_log INSERT + userStats UPDATE in one transaction
await recordUsage({
userId,
addedCost: cost,
source,
entries: [
{
category: 'model',
source,
description: model,
cost,
metadata: { inputTokens, outputTokens },
},
],
additionalStats,
})
Comment thread
waleedlatif1 marked this conversation as resolved.
Outdated

// Log usage for complete audit trail with the original source for visibility
await logModelUsage({
logger.info(`[${requestId}] Recorded usage`, {
userId,
addedCost: cost,
source,
model,
inputTokens,
outputTokens,
cost,
})

// Check if user has hit overage threshold and bill incrementally
Expand Down
42 changes: 20 additions & 22 deletions apps/sim/app/api/wand/route.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { db } from '@sim/db'
import { userStats, workflow } from '@sim/db/schema'
import { workflow } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getBYOKKey } from '@/lib/api-key/byok'
import { getSession } from '@/lib/auth'
import { logModelUsage } from '@/lib/billing/core/usage-log'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { checkAndBillOverageThreshold } from '@/lib/billing/threshold-billing'
import { env } from '@/lib/core/config/env'
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
Expand Down Expand Up @@ -134,23 +134,21 @@ async function updateUserStatsForWand(
costToStore = modelCost * costMultiplier
}

await db
.update(userStats)
.set({
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
lastActive: new Date(),
})
.where(eq(userStats.userId, userId))

await logModelUsage({
// Atomic write: usage_log INSERT + userStats UPDATE in one transaction
await recordUsage({
userId,
source: 'wand',
model: modelName,
inputTokens: promptTokens,
outputTokens: completionTokens,
cost: costToStore,
entries: [
{
category: 'model',
source: 'wand',
description: modelName,
cost: costToStore,
metadata: { inputTokens: promptTokens, outputTokens: completionTokens },
},
],
additionalStats: {
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
},
})

await checkAndBillOverageThreshold(userId)
Expand Down Expand Up @@ -341,7 +339,7 @@ export async function POST(req: NextRequest) {
let finalUsage: any = null
let usageRecorded = false

const recordUsage = async () => {
const flushUsage = async () => {
if (usageRecorded || !finalUsage) {
return
}
Expand All @@ -360,7 +358,7 @@ export async function POST(req: NextRequest) {

if (done) {
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
await recordUsage()
await flushUsage()
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
controller.close()
break
Expand Down Expand Up @@ -390,7 +388,7 @@ export async function POST(req: NextRequest) {
if (data === '[DONE]') {
logger.info(`[${requestId}] Received [DONE] signal`)

await recordUsage()
await flushUsage()

controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
Expand Down Expand Up @@ -468,7 +466,7 @@ export async function POST(req: NextRequest) {
})

try {
await recordUsage()
await flushUsage()
} catch (usageError) {
logger.warn(`[${requestId}] Failed to record usage after stream error`, usageError)
}
Expand Down
Loading