diff --git a/apps/sim/lib/billing/core/billing.ts b/apps/sim/lib/billing/core/billing.ts index aca33c6e3d..f1b3ae2636 100644 --- a/apps/sim/lib/billing/core/billing.ts +++ b/apps/sim/lib/billing/core/billing.ts @@ -22,7 +22,7 @@ import { isOrgScopedSubscription, } from '@/lib/billing/subscriptions/utils' import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal' -import type { DbOrTx } from '@/lib/db/types' +import type { DbClient } from '@/lib/db/types' export { getPlanPricing } @@ -32,6 +32,8 @@ const logger = createLogger('Billing') interface GetOrganizationSubscriptionOptions { onError?: 'return-null' | 'throw' + /** Read-routing client (primary or replica); defaults to the primary. */ + executor?: DbClient } /** @@ -42,14 +44,18 @@ interface GetOrganizationSubscriptionOptions { * For product-access gating use `getOrganizationSubscriptionUsable` * (from `core/subscription.ts`), which excludes `past_due`. * Returns `null` when there is no entitled sub. + * + * `options.executor` exists for replica routing on display/summary read + * paths only. Enforcement and webhook callers must read the primary — + * omit the executor (or pass `db`). */ export async function getOrganizationSubscription( organizationId: string, options: GetOrganizationSubscriptionOptions = {} ) { - const { onError = 'return-null' } = options + const { onError = 'return-null', executor = db } = options try { - const orgSubs = await db + const orgSubs = await executor .select() .from(subscription) .where( @@ -111,13 +117,16 @@ export async function isSubscriptionOrgScoped(sub: { referenceId: string }): Pro * column is `NOT NULL DEFAULT '0'` and mixing scopes would break * current-period billing math. */ -async function aggregateOrgMemberStats(organizationId: string): Promise<{ +async function aggregateOrgMemberStats( + organizationId: string, + executor: DbClient = db +): Promise<{ memberIds: string[] currentPeriodCost: number currentPeriodCopilotCost: number lastPeriodCopilotCost: number }> { - const rows = await db + const rows = await executor .select({ userId: member.userId, currentPeriodCost: userStats.currentPeriodCost, @@ -386,7 +395,7 @@ export async function calculateSubscriptionOverage(sub: { export async function getSimplifiedBillingSummary( userId: string, organizationId?: string, - executor: DbOrTx = db + executor: DbClient = db ): Promise<{ type: 'individual' | 'organization' plan: string @@ -432,8 +441,8 @@ export async function getSimplifiedBillingSummary( // Get subscription and usage data upfront const [subscription, usageData] = await Promise.all([ organizationId - ? getOrganizationSubscription(organizationId) - : getHighestPrioritySubscription(userId), + ? getOrganizationSubscription(organizationId, { executor }) + : getHighestPrioritySubscription(userId, { executor }), getUserUsageData(userId, executor), ]) @@ -455,7 +464,7 @@ export async function getSimplifiedBillingSummary( // Pool usage/copilot across all members in one query. Must not use // `getUserUsageData` per-member — it now returns the pool itself // for org-scoped subs, which would N-times-count. - const pooled = await aggregateOrgMemberStats(organizationId) + const pooled = await aggregateOrgMemberStats(organizationId, executor) const rawCurrentUsage = pooled.currentPeriodCost const totalLastPeriodCopilotCost = pooled.lastPeriodCopilotCost @@ -495,7 +504,8 @@ export async function getSimplifiedBillingSummary( if (planDollars > 0) { const userBounds = await getOrgMemberRefreshBounds( organizationId, - subscription.periodStart + subscription.periodStart, + executor ) refreshDeduction = await computeDailyRefreshConsumed( { @@ -516,7 +526,8 @@ export async function getSimplifiedBillingSummary( const { limit: orgUsageLimit } = await getOrgUsageLimit( organizationId, plan, - subscription.seats ?? null + subscription.seats ?? null, + executor ) const percentUsed = @@ -532,7 +543,7 @@ export async function getSimplifiedBillingSummary( ) : 0 - const orgCredits = await getCreditBalance(userId) + const orgCredits = await getCreditBalance(userId, executor) const orgBillingInterval = getBillingInterval(subscription.metadata as SubscriptionMetadata) return { @@ -576,7 +587,7 @@ export async function getSimplifiedBillingSummary( } } - const userStatsRows = await db + const userStatsRows = await executor .select({ currentPeriodCopilotCost: userStats.currentPeriodCopilotCost, lastPeriodCopilotCost: userStats.lastPeriodCopilotCost, @@ -597,7 +608,7 @@ export async function getSimplifiedBillingSummary( let totalCopilotCost = copilotCost let totalLastPeriodCopilotCost = lastPeriodCopilotCost if (orgScoped && subscription?.referenceId) { - const pooled = await aggregateOrgMemberStats(subscription.referenceId) + const pooled = await aggregateOrgMemberStats(subscription.referenceId, executor) totalCopilotCost = pooled.currentPeriodCopilotCost totalLastPeriodCopilotCost = pooled.lastPeriodCopilotCost } @@ -631,7 +642,7 @@ export async function getSimplifiedBillingSummary( ) : 0 - const userCredits = await getCreditBalance(userId) + const userCredits = await getCreditBalance(userId, executor) const individualBillingInterval = getBillingInterval( subscription?.metadata as SubscriptionMetadata ) diff --git a/apps/sim/lib/billing/core/organization.ts b/apps/sim/lib/billing/core/organization.ts index f7240550eb..2000edd898 100644 --- a/apps/sim/lib/billing/core/organization.ts +++ b/apps/sim/lib/billing/core/organization.ts @@ -19,7 +19,7 @@ import { hasUsableSubscriptionStatus, } from '@/lib/billing/subscriptions/utils' import { toDecimal, toNumber } from '@/lib/billing/utils/decimal' -import type { DbOrTx } from '@/lib/db/types' +import type { DbClient } from '@/lib/db/types' const logger = createLogger('OrganizationBilling') @@ -66,11 +66,11 @@ interface MemberUsageData { export async function getOrgMemberLedgerByUser( organizationId: string, period?: { start: Date; end: Date } | null, - executor: DbOrTx = db + executor: DbClient = db ): Promise> { let billingPeriod = period ?? null if (period === undefined) { - const subscription = await getOrganizationSubscription(organizationId) + const subscription = await getOrganizationSubscription(organizationId, { executor }) billingPeriod = subscription?.periodStart && subscription?.periodEnd ? { start: subscription.periodStart, end: subscription.periodEnd } @@ -90,11 +90,11 @@ export async function getOrgMemberLedgerByUser( */ export async function getOrganizationBillingData( organizationId: string, - executor: DbOrTx = db + executor: DbClient = db ): Promise { try { // Get organization info - const orgRecord = await db + const orgRecord = await executor .select() .from(organization) .where(eq(organization.id, organizationId)) @@ -108,7 +108,7 @@ export async function getOrganizationBillingData( const organizationData = orgRecord[0] // Get organization subscription directly (referenceId = organizationId) - const subscription = await getOrganizationSubscription(organizationId) + const subscription = await getOrganizationSubscription(organizationId, { executor }) if (!subscription) { logger.warn('No subscription found for organization', { organizationId }) @@ -116,7 +116,7 @@ export async function getOrganizationBillingData( } // Get all organization members with their usage data - const membersWithUsage = await db + const membersWithUsage = await executor .select({ userId: member.userId, userName: user.name, @@ -185,7 +185,8 @@ export async function getOrganizationBillingData( const memberIds = members.map((m) => m.userId) const userBounds = await getOrgMemberRefreshBounds( subscription.referenceId, - subscription.periodStart + subscription.periodStart, + executor ) const refreshConsumed = await computeDailyRefreshConsumed( { @@ -233,7 +234,7 @@ export async function getOrganizationBillingData( const averageUsagePerMember = members.length > 0 ? totalCurrentUsage / members.length : 0 - const [pendingInvitationCount] = await db + const [pendingInvitationCount] = await executor .select({ count: count() }) .from(invitation) .where( diff --git a/apps/sim/lib/billing/core/plan.ts b/apps/sim/lib/billing/core/plan.ts index 8bbd516ee1..b4a56dab13 100644 --- a/apps/sim/lib/billing/core/plan.ts +++ b/apps/sim/lib/billing/core/plan.ts @@ -8,6 +8,7 @@ import { checkTeamPlan, ENTITLED_SUBSCRIPTION_STATUSES, } from '@/lib/billing/subscriptions/utils' +import type { DbClient } from '@/lib/db/types' const logger = createLogger('PlanLookup') @@ -15,6 +16,8 @@ export type HighestPrioritySubscription = Awaited( @@ -33,9 +36,9 @@ export async function getHighestPriorityPersonalSubscription( userId: string, options: GetHighestPrioritySubscriptionOptions = {} ) { - const { onError = 'return-null' } = options + const { onError = 'return-null', executor = db } = options try { - const personalSubs = await db + const personalSubs = await executor .select() .from(subscription) .where( @@ -77,9 +80,9 @@ export async function getHighestPrioritySubscription( userId: string, options: GetHighestPrioritySubscriptionOptions = {} ) { - const { onError = 'return-null' } = options + const { onError = 'return-null', executor = db } = options try { - const personalSubs = await db + const personalSubs = await executor .select() .from(subscription) .where( @@ -89,7 +92,7 @@ export async function getHighestPrioritySubscription( ) ) - const memberships = await db + const memberships = await executor .select({ organizationId: member.organizationId }) .from(member) .where(eq(member.userId, userId)) @@ -99,7 +102,7 @@ export async function getHighestPrioritySubscription( let orgSubs: typeof personalSubs = [] if (orgIds.length > 0) { // Verify orgs exist to filter out orphaned subscriptions - const existingOrgs = await db + const existingOrgs = await executor .select({ id: organization.id }) .from(organization) .where(inArray(organization.id, orgIds)) @@ -107,7 +110,7 @@ export async function getHighestPrioritySubscription( const validOrgIds = existingOrgs.map((o) => o.id) if (validOrgIds.length > 0) { - orgSubs = await db + orgSubs = await executor .select() .from(subscription) .where( diff --git a/apps/sim/lib/billing/core/usage-log.ts b/apps/sim/lib/billing/core/usage-log.ts index 985713bebc..034609a65f 100644 --- a/apps/sim/lib/billing/core/usage-log.ts +++ b/apps/sim/lib/billing/core/usage-log.ts @@ -8,7 +8,7 @@ import { and, desc, eq, gte, inArray, lt, lte, sql } from 'drizzle-orm' import { defaultBillingPeriod } from '@/lib/billing/core/billing-period' import { getHighestPrioritySubscription } from '@/lib/billing/core/plan' import { isOrgScopedSubscription } from '@/lib/billing/subscriptions/utils' -import type { DbOrTx } from '@/lib/db/types' +import type { DbClient, DbOrTx } from '@/lib/db/types' const logger = createLogger('UsageLog') @@ -184,7 +184,7 @@ export async function getBillingPeriodUsageCost( billingEntity: BillingEntity, billingPeriod: { start: Date; end: Date }, source?: UsageLogSource | UsageLogSource[], - executor: DbOrTx = db + executor: DbClient = db ): Promise { const conditions = [ eq(usageLog.billingEntityType, billingEntity.type), @@ -212,7 +212,7 @@ export async function getBillingPeriodUsageCostByUser( billingEntity: BillingEntity, billingPeriod: { start: Date; end: Date }, source?: UsageLogSource | UsageLogSource[], - executor: DbOrTx = db + executor: DbClient = db ): Promise> { const conditions = [ eq(usageLog.billingEntityType, billingEntity.type), diff --git a/apps/sim/lib/billing/core/usage.ts b/apps/sim/lib/billing/core/usage.ts index 7845312d31..a39d3d1309 100644 --- a/apps/sim/lib/billing/core/usage.ts +++ b/apps/sim/lib/billing/core/usage.ts @@ -34,7 +34,7 @@ import type { BillingData, UsageData, UsageLimitInfo } from '@/lib/billing/types import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal' import { isBillingEnabled } from '@/lib/core/config/feature-flags' import { getBaseUrl } from '@/lib/core/utils/urls' -import type { DbOrTx } from '@/lib/db/types' +import type { DbClient } from '@/lib/db/types' import { sendEmail } from '@/lib/messaging/email/mailer' import { getEmailPreferences } from '@/lib/messaging/email/unsubscribe' @@ -58,9 +58,10 @@ export interface OrgUsageLimitResult { * downstream refresh / bounds computations. */ export async function getPooledOrgCurrentPeriodCost( - organizationId: string + organizationId: string, + executor: DbClient = db ): Promise<{ memberIds: string[]; currentPeriodCost: number; lastPeriodCost: number }> { - const rows = await db + const rows = await executor .select({ userId: member.userId, currentPeriodCost: userStats.currentPeriodCost, @@ -95,9 +96,10 @@ export async function getPooledOrgCurrentPeriodCost( export async function getOrgUsageLimit( organizationId: string, plan: string, - seats: number | null + seats: number | null, + executor: DbClient = db ): Promise { - const orgData = await db + const orgData = await executor .select({ orgUsageLimit: organization.orgUsageLimit }) .from(organization) .where(eq(organization.id, organizationId)) @@ -164,6 +166,7 @@ export async function handleNewUser(userId: string): Promise { * This is a fallback for cases where the user.create.after hook didn't fire * (e.g., OAuth account linking to existing users). * + * Always writes to the primary — never takes a read-routing executor. */ export async function ensureUserStatsExists(userId: string): Promise { await db @@ -180,13 +183,24 @@ export async function ensureUserStatsExists(userId: string): Promise { /** * Get comprehensive usage data for a user */ -export async function getUserUsageData(userId: string, executor: DbOrTx = db): Promise { +export async function getUserUsageData( + userId: string, + executor: DbClient = db +): Promise { try { + // Write — always on the primary regardless of executor routing. await ensureUserStatsExists(userId) const [userStatsData, subscription] = await Promise.all([ - db.select().from(userStats).where(eq(userStats.userId, userId)).limit(1), - getHighestPrioritySubscription(userId), + // Read-your-write: must see the row ensureUserStatsExists may have just + // inserted, which a lagging replica can miss (this path throws on a + // missing row). Stays on the primary deliberately. + db + .select() + .from(userStats) + .where(eq(userStats.userId, userId)) + .limit(1), + getHighestPrioritySubscription(userId, { executor }), ]) if (userStatsData.length === 0) { @@ -239,11 +253,12 @@ export async function getUserUsageData(userId: string, executor: DbOrTx = db): P const orgLimit = await getOrgUsageLimit( subscription.referenceId, subscription.plan, - subscription.seats + subscription.seats, + executor ) limit = orgLimit.limit - const pooled = await getPooledOrgCurrentPeriodCost(subscription.referenceId) + const pooled = await getPooledOrgCurrentPeriodCost(subscription.referenceId, executor) orgMemberIds = pooled.memberIds lastPeriodCost = pooled.lastPeriodCost const ledgerUsage = await getBillingPeriodUsageCost( @@ -270,7 +285,8 @@ export async function getUserUsageData(userId: string, executor: DbOrTx = db): P if (orgMemberIds.length > 0) { const userBounds = await getOrgMemberRefreshBounds( subscription.referenceId, - billingPeriodStart + billingPeriodStart, + executor ) dailyRefreshConsumed = await computeDailyRefreshConsumed( { @@ -628,16 +644,16 @@ export async function syncUsageLimitsFromSubscription(userId: string): Promise { - const subscription = await getHighestPrioritySubscription(userId) + const subscription = await getHighestPrioritySubscription(userId, { executor }) const orgScoped = isOrgScopedSubscription(subscription, userId) let rawCost: number let refreshUserIds: string[] = [userId] if (orgScoped && subscription) { - const pooled = await getPooledOrgCurrentPeriodCost(subscription.referenceId) + const pooled = await getPooledOrgCurrentPeriodCost(subscription.referenceId, executor) if (pooled.memberIds.length === 0) return 0 refreshUserIds = pooled.memberIds const billingPeriod = @@ -653,7 +669,7 @@ export async function getEffectiveCurrentPeriodCost( executor )) } else { - const rows = await db + const rows = await executor .select({ current: userStats.currentPeriodCost }) .from(userStats) .where(eq(userStats.userId, userId)) @@ -683,7 +699,11 @@ export async function getEffectiveCurrentPeriodCost( const userBounds = orgScoped && subscription.periodStart - ? await getOrgMemberRefreshBounds(subscription.referenceId, subscription.periodStart) + ? await getOrgMemberRefreshBounds( + subscription.referenceId, + subscription.periodStart, + executor + ) : {} const refreshConsumed = await computeDailyRefreshConsumed( diff --git a/apps/sim/lib/billing/credits/balance.ts b/apps/sim/lib/billing/credits/balance.ts index 7a5bcea061..cdf3dc784f 100644 --- a/apps/sim/lib/billing/credits/balance.ts +++ b/apps/sim/lib/billing/credits/balance.ts @@ -10,6 +10,7 @@ import { isOrgScopedSubscription, } from '@/lib/billing/subscriptions/utils' import { Decimal, toDecimal, toFixedString, toNumber } from '@/lib/billing/utils/decimal' +import type { DbClient } from '@/lib/db/types' const logger = createLogger('CreditBalance') @@ -28,10 +29,11 @@ export interface CreditBalanceInfo { */ export async function getCreditBalanceForEntity( entityType: 'user' | 'organization', - entityId: string + entityId: string, + executor: DbClient = db ): Promise { if (entityType === 'organization') { - const rows = await db + const rows = await executor .select({ creditBalance: organization.creditBalance }) .from(organization) .where(eq(organization.id, entityId)) @@ -39,7 +41,7 @@ export async function getCreditBalanceForEntity( return rows.length > 0 ? toNumber(toDecimal(rows[0].creditBalance)) : 0 } - const rows = await db + const rows = await executor .select({ creditBalance: userStats.creditBalance }) .from(userStats) .where(eq(userStats.userId, entityId)) @@ -47,19 +49,22 @@ export async function getCreditBalanceForEntity( return rows.length > 0 ? toNumber(toDecimal(rows[0].creditBalance)) : 0 } -export async function getCreditBalance(userId: string): Promise { - const subscription = await getHighestPrioritySubscription(userId) +export async function getCreditBalance( + userId: string, + executor: DbClient = db +): Promise { + const subscription = await getHighestPrioritySubscription(userId, { executor }) if (isOrgScopedSubscription(subscription, userId) && subscription) { return { - balance: await getCreditBalanceForEntity('organization', subscription.referenceId), + balance: await getCreditBalanceForEntity('organization', subscription.referenceId, executor), entityType: 'organization', entityId: subscription.referenceId, } } return { - balance: await getCreditBalanceForEntity('user', userId), + balance: await getCreditBalanceForEntity('user', userId, executor), entityType: 'user', entityId: userId, } diff --git a/apps/sim/lib/billing/credits/daily-refresh.ts b/apps/sim/lib/billing/credits/daily-refresh.ts index 7391c74192..b39351c204 100644 --- a/apps/sim/lib/billing/credits/daily-refresh.ts +++ b/apps/sim/lib/billing/credits/daily-refresh.ts @@ -16,7 +16,7 @@ import { member, usageLog, userStats } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq, gte, inArray, lt, or, sql, sum } from 'drizzle-orm' import { DAILY_REFRESH_RATE } from '@/lib/billing/constants' -import type { DbOrTx } from '@/lib/db/types' +import type { DbClient } from '@/lib/db/types' const logger = createLogger('DailyRefresh') @@ -52,7 +52,7 @@ export async function computeDailyRefreshConsumed( userBounds?: Record billingEntity?: { type: 'user' | 'organization'; id: string } }, - executor: DbOrTx = db + executor: DbClient = db ): Promise { const { userIds, @@ -157,9 +157,10 @@ export function getDailyRefreshDollars(planDollars: number): number { export async function getOrgMemberRefreshBounds( organizationId: string, - periodStart: Date + periodStart: Date, + executor: DbClient = db ): Promise> { - const rows = await db + const rows = await executor .select({ userId: member.userId, snapshotAt: userStats.proPeriodCostSnapshotAt, diff --git a/apps/sim/lib/db/types.ts b/apps/sim/lib/db/types.ts index 8039e2b78b..8b1129deaf 100644 --- a/apps/sim/lib/db/types.ts +++ b/apps/sim/lib/db/types.ts @@ -15,3 +15,15 @@ export type DbOrTx = typeof schema, ExtractTablesWithRelations > + +/** + * Read-routing client: the primary `db` or the read replica `dbReplica`. + * + * For read-path helpers (billing summaries, dashboard aggregations) whose + * executor param exists to route SELECT fan-outs to a replica. Deliberately + * excludes transaction handles — these helpers issue multi-step query fans + * that must never run while a transaction holds a pooled connection. Use + * `DbOrTx` only for helpers genuinely designed to join a caller's + * transaction. + */ +export type DbClient = typeof db diff --git a/apps/sim/lib/mcp/workflow-mcp-sync.ts b/apps/sim/lib/mcp/workflow-mcp-sync.ts index 7a6a25d4eb..8c5d42a67d 100644 --- a/apps/sim/lib/mcp/workflow-mcp-sync.ts +++ b/apps/sim/lib/mcp/workflow-mcp-sync.ts @@ -143,7 +143,6 @@ interface SyncOptionsBase { requestId: string /** Context for logging (e.g., 'deploy', 'revert', 'activate') */ context?: string - notify?: boolean throwOnError?: boolean } @@ -151,11 +150,17 @@ interface SyncOptionsBase { * Callers running inside a transaction must preload the workflow state: * loading it lazily would issue queries on the global pool while the * transaction already holds a pooled connection. + * + * Server notification is strictly post-commit. The standalone arm notifies + * after its own transaction commits (`notify` defaults to true); the `tx` arm + * never notifies — publishing before the caller's transaction commits would + * announce state that may still roll back, so the transaction owner notifies + * after commit (see deployment-outbox). */ type SyncOptions = SyncOptionsBase & ( - | { tx: DbOrTx; state: { blocks?: Record } } - | { tx?: undefined; state?: { blocks?: Record } } + | { tx: DbOrTx; state: { blocks?: Record }; notify?: false } + | { tx?: undefined; state?: { blocks?: Record }; notify?: boolean } ) /** @@ -193,21 +198,11 @@ export async function syncMcpToolsForWorkflow( return tools } - const { - workflowId, - requestId, - state, - context = 'sync', - tx, - notify = true, - throwOnError = false, - } = options + const { workflowId, requestId, state, context = 'sync', tx, throwOnError = false } = options try { if (!hasValidStartBlockInState(state as WorkflowState | null)) { - const affectedTools = await removeMcpToolsForWorkflow(workflowId, requestId, tx, false, true) - if (notify) notifyMcpToolServers(affectedTools) - return affectedTools + return await removeMcpToolsForWorkflow(workflowId, requestId, tx, true) } const generatedParameterSchema = state.blocks @@ -324,9 +319,7 @@ export async function syncMcpToolsForWorkflow( `[${requestId}] Synced ${syncedToolCount} MCP tool(s) for workflow (${context}): ${workflowId}` ) - const affectedTools = [...affectedServerIds].map((serverId) => ({ serverId })) - if (notify) notifyMcpToolServers(affectedTools) - return affectedTools + return [...affectedServerIds].map((serverId) => ({ serverId })) } catch (error) { logger.error(`[${requestId}] Error syncing MCP tools (${context}):`, error) if (throwOnError) throw error @@ -336,20 +329,23 @@ export async function syncMcpToolsForWorkflow( /** * Remove all MCP tools for a workflow (used when undeploying). - * Queries affected tools before deleting so we can notify their servers. + * Queries affected tools before deleting so their servers can be notified. + * + * Server notification is strictly post-commit: the standalone path notifies + * after the transaction opened here commits; when `tx` is provided the + * transaction owner notifies after commit using the returned server ids. */ export async function removeMcpToolsForWorkflow( workflowId: string, requestId: string, tx?: DbOrTx, - notify = true, throwOnError = false ): Promise> { if (!tx) { const tools = await db.transaction((transaction) => - removeMcpToolsForWorkflow(workflowId, requestId, transaction, false, throwOnError) + removeMcpToolsForWorkflow(workflowId, requestId, transaction, throwOnError) ) - if (notify) notifyMcpToolServers(tools) + notifyMcpToolServers(tools) return tools } @@ -365,7 +361,6 @@ export async function removeMcpToolsForWorkflow( await tx.delete(workflowMcpTool).where(eq(workflowMcpTool.workflowId, workflowId)) logger.info(`[${requestId}] Removed MCP tools for workflow: ${workflowId}`) - if (notify) notifyMcpToolServers(tools) return tools } catch (error) { logger.error(`[${requestId}] Error removing MCP tools:`, error) diff --git a/apps/sim/lib/workflows/deployment-outbox.ts b/apps/sim/lib/workflows/deployment-outbox.ts index b33003f7a4..468e621f10 100644 --- a/apps/sim/lib/workflows/deployment-outbox.ts +++ b/apps/sim/lib/workflows/deployment-outbox.ts @@ -461,7 +461,7 @@ async function removeMcpToolsIfStillUndeployed( .limit(1) if (!workflowRecord || workflowRecord.isDeployed) return [] - return removeMcpToolsForWorkflow(workflowId, requestId, tx, false, true) + return removeMcpToolsForWorkflow(workflowId, requestId, tx, true) }) notifyMcpToolServers(tools) } diff --git a/apps/sim/lib/workflows/schedules/deploy.test.ts b/apps/sim/lib/workflows/schedules/deploy.test.ts index 8e6e5a3318..9bbac181f7 100644 --- a/apps/sim/lib/workflows/schedules/deploy.test.ts +++ b/apps/sim/lib/workflows/schedules/deploy.test.ts @@ -721,7 +721,7 @@ describe('Schedule Deploy Utilities', () => { setupMockTransaction() - const result = await createSchedulesForDeploy('workflow-1', blocks, {} as any) + const result = await createSchedulesForDeploy('workflow-1', blocks) expect(result.success).toBe(true) expect(mockTransaction).not.toHaveBeenCalled() @@ -742,7 +742,7 @@ describe('Schedule Deploy Utilities', () => { setupMockTransaction() - const result = await createSchedulesForDeploy('workflow-1', blocks, {} as any) + const result = await createSchedulesForDeploy('workflow-1', blocks) expect(result.success).toBe(true) expect(result.scheduleId).toBe('test-uuid') @@ -767,13 +767,37 @@ describe('Schedule Deploy Utilities', () => { setupMockTransaction() - const result = await createSchedulesForDeploy('workflow-1', blocks, {} as any) + const result = await createSchedulesForDeploy('workflow-1', blocks) expect(result.success).toBe(false) expect(result.error).toBe('Time is required for daily schedules') expect(mockTransaction).not.toHaveBeenCalled() }) + it('should write through a provided transaction without opening a new one', async () => { + const blocks: Record = { + 'block-1': { + id: 'block-1', + type: 'schedule', + subBlocks: { + scheduleType: { value: 'daily' }, + dailyTime: { value: '09:00' }, + timezone: { value: 'UTC' }, + }, + } as BlockState, + } + + setupMockTransaction() + const callerTx = { insert: mockInsert, delete: mockDelete, select: mockSelect } as any + + const result = await createSchedulesForDeploy('workflow-1', blocks, callerTx) + + expect(result.success).toBe(true) + expect(mockTransaction).not.toHaveBeenCalled() + expect(mockInsert).toHaveBeenCalled() + expect(mockOnConflictDoUpdate).toHaveBeenCalled() + }) + it('should use onConflictDoUpdate for existing schedules', async () => { const blocks: Record = { 'block-1': { @@ -789,7 +813,7 @@ describe('Schedule Deploy Utilities', () => { setupMockTransaction() - await createSchedulesForDeploy('workflow-1', blocks, {} as any) + await createSchedulesForDeploy('workflow-1', blocks) expect(mockOnConflictDoUpdate).toHaveBeenCalledWith({ target: expect.any(Array), @@ -818,7 +842,7 @@ describe('Schedule Deploy Utilities', () => { mockTransaction.mockRejectedValueOnce(new Error('Database error')) - const result = await createSchedulesForDeploy('workflow-1', blocks, {} as any) + const result = await createSchedulesForDeploy('workflow-1', blocks) expect(result.success).toBe(false) expect(result.error).toBe('Database error') diff --git a/apps/sim/lib/workflows/schedules/deploy.ts b/apps/sim/lib/workflows/schedules/deploy.ts index 1973a2cb8b..4a6d7893d9 100644 --- a/apps/sim/lib/workflows/schedules/deploy.ts +++ b/apps/sim/lib/workflows/schedules/deploy.ts @@ -24,12 +24,13 @@ export interface ScheduleDeployResult { /** * Create or update schedule records for a workflow during deployment. - * Uses a transaction to ensure atomicity - all schedules are created or none are. + * Atomic either way: writes run inside the caller's transaction when `tx` + * is provided, otherwise inside a transaction opened here. */ export async function createSchedulesForDeploy( workflowId: string, blocks: Record, - dbCtx: DbOrTx, + tx?: DbOrTx, deploymentVersionId?: string ): Promise { const scheduleBlocks = findScheduleBlocks(blocks) @@ -73,10 +74,10 @@ export async function createSchedulesForDeploy( } | null = null try { - const writeSchedules = async (tx: DbOrTx) => { + const writeSchedules = async (trx: DbOrTx) => { const currentBlockIds = new Set(validatedBlocks.map((b) => b.blockId)) - const existingSchedules = await tx + const existingSchedules = await trx .select({ id: workflowSchedule.id, blockId: workflowSchedule.blockId }) .from(workflowSchedule) .where( @@ -97,7 +98,7 @@ export async function createSchedulesForDeploy( logger.info( `Deleting ${orphanedScheduleIds.length} orphaned schedule(s) for workflow ${workflowId}` ) - await tx.delete(workflowSchedule).where(inArray(workflowSchedule.id, orphanedScheduleIds)) + await trx.delete(workflowSchedule).where(inArray(workflowSchedule.id, orphanedScheduleIds)) } for (const validated of validatedBlocks) { @@ -133,7 +134,7 @@ export async function createSchedulesForDeploy( infraRetryCount: 0, } - await tx + await trx .insert(workflowSchedule) .values(values) .onConflictDoUpdate({ @@ -156,11 +157,9 @@ export async function createSchedulesForDeploy( } } - if (dbCtx === db || !hasScheduleWriteMethods(dbCtx)) { - await db.transaction(writeSchedules) - } else { - await writeSchedules(dbCtx) - } + // The global client is not a transaction — wrap it so the atomicity + // contract holds even if a caller passes `db` explicitly. + await (tx && tx !== db ? writeSchedules(tx) : db.transaction(writeSchedules)) } catch (error) { logger.error(`Failed to create schedules for workflow ${workflowId}`, error) return { @@ -175,15 +174,6 @@ export async function createSchedulesForDeploy( } } -function hasScheduleWriteMethods(value: DbOrTx): boolean { - const candidate = value as Partial> - return ( - typeof candidate.select === 'function' && - typeof candidate.insert === 'function' && - typeof candidate.delete === 'function' - ) -} - /** * Delete all schedules for a workflow * This should be called within a database transaction during undeploy