Skip to content

Commit 1228ebd

Browse files
fix(db): close optional-executor contract traps (#4989)
1 parent ebf434f commit 1228ebd

12 files changed

Lines changed: 172 additions & 110 deletions

File tree

apps/sim/lib/billing/core/billing.ts

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import {
2222
isOrgScopedSubscription,
2323
} from '@/lib/billing/subscriptions/utils'
2424
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'
25-
import type { DbOrTx } from '@/lib/db/types'
25+
import type { DbClient } from '@/lib/db/types'
2626

2727
export { getPlanPricing }
2828

@@ -32,6 +32,8 @@ const logger = createLogger('Billing')
3232

3333
interface GetOrganizationSubscriptionOptions {
3434
onError?: 'return-null' | 'throw'
35+
/** Read-routing client (primary or replica); defaults to the primary. */
36+
executor?: DbClient
3537
}
3638

3739
/**
@@ -42,14 +44,18 @@ interface GetOrganizationSubscriptionOptions {
4244
* For product-access gating use `getOrganizationSubscriptionUsable`
4345
* (from `core/subscription.ts`), which excludes `past_due`.
4446
* Returns `null` when there is no entitled sub.
47+
*
48+
* `options.executor` exists for replica routing on display/summary read
49+
* paths only. Enforcement and webhook callers must read the primary —
50+
* omit the executor (or pass `db`).
4551
*/
4652
export async function getOrganizationSubscription(
4753
organizationId: string,
4854
options: GetOrganizationSubscriptionOptions = {}
4955
) {
50-
const { onError = 'return-null' } = options
56+
const { onError = 'return-null', executor = db } = options
5157
try {
52-
const orgSubs = await db
58+
const orgSubs = await executor
5359
.select()
5460
.from(subscription)
5561
.where(
@@ -111,13 +117,16 @@ export async function isSubscriptionOrgScoped(sub: { referenceId: string }): Pro
111117
* column is `NOT NULL DEFAULT '0'` and mixing scopes would break
112118
* current-period billing math.
113119
*/
114-
async function aggregateOrgMemberStats(organizationId: string): Promise<{
120+
async function aggregateOrgMemberStats(
121+
organizationId: string,
122+
executor: DbClient = db
123+
): Promise<{
115124
memberIds: string[]
116125
currentPeriodCost: number
117126
currentPeriodCopilotCost: number
118127
lastPeriodCopilotCost: number
119128
}> {
120-
const rows = await db
129+
const rows = await executor
121130
.select({
122131
userId: member.userId,
123132
currentPeriodCost: userStats.currentPeriodCost,
@@ -386,7 +395,7 @@ export async function calculateSubscriptionOverage(sub: {
386395
export async function getSimplifiedBillingSummary(
387396
userId: string,
388397
organizationId?: string,
389-
executor: DbOrTx = db
398+
executor: DbClient = db
390399
): Promise<{
391400
type: 'individual' | 'organization'
392401
plan: string
@@ -432,8 +441,8 @@ export async function getSimplifiedBillingSummary(
432441
// Get subscription and usage data upfront
433442
const [subscription, usageData] = await Promise.all([
434443
organizationId
435-
? getOrganizationSubscription(organizationId)
436-
: getHighestPrioritySubscription(userId),
444+
? getOrganizationSubscription(organizationId, { executor })
445+
: getHighestPrioritySubscription(userId, { executor }),
437446
getUserUsageData(userId, executor),
438447
])
439448

@@ -455,7 +464,7 @@ export async function getSimplifiedBillingSummary(
455464
// Pool usage/copilot across all members in one query. Must not use
456465
// `getUserUsageData` per-member — it now returns the pool itself
457466
// for org-scoped subs, which would N-times-count.
458-
const pooled = await aggregateOrgMemberStats(organizationId)
467+
const pooled = await aggregateOrgMemberStats(organizationId, executor)
459468

460469
const rawCurrentUsage = pooled.currentPeriodCost
461470
const totalLastPeriodCopilotCost = pooled.lastPeriodCopilotCost
@@ -495,7 +504,8 @@ export async function getSimplifiedBillingSummary(
495504
if (planDollars > 0) {
496505
const userBounds = await getOrgMemberRefreshBounds(
497506
organizationId,
498-
subscription.periodStart
507+
subscription.periodStart,
508+
executor
499509
)
500510
refreshDeduction = await computeDailyRefreshConsumed(
501511
{
@@ -516,7 +526,8 @@ export async function getSimplifiedBillingSummary(
516526
const { limit: orgUsageLimit } = await getOrgUsageLimit(
517527
organizationId,
518528
plan,
519-
subscription.seats ?? null
529+
subscription.seats ?? null,
530+
executor
520531
)
521532

522533
const percentUsed =
@@ -532,7 +543,7 @@ export async function getSimplifiedBillingSummary(
532543
)
533544
: 0
534545

535-
const orgCredits = await getCreditBalance(userId)
546+
const orgCredits = await getCreditBalance(userId, executor)
536547
const orgBillingInterval = getBillingInterval(subscription.metadata as SubscriptionMetadata)
537548

538549
return {
@@ -576,7 +587,7 @@ export async function getSimplifiedBillingSummary(
576587
}
577588
}
578589

579-
const userStatsRows = await db
590+
const userStatsRows = await executor
580591
.select({
581592
currentPeriodCopilotCost: userStats.currentPeriodCopilotCost,
582593
lastPeriodCopilotCost: userStats.lastPeriodCopilotCost,
@@ -597,7 +608,7 @@ export async function getSimplifiedBillingSummary(
597608
let totalCopilotCost = copilotCost
598609
let totalLastPeriodCopilotCost = lastPeriodCopilotCost
599610
if (orgScoped && subscription?.referenceId) {
600-
const pooled = await aggregateOrgMemberStats(subscription.referenceId)
611+
const pooled = await aggregateOrgMemberStats(subscription.referenceId, executor)
601612
totalCopilotCost = pooled.currentPeriodCopilotCost
602613
totalLastPeriodCopilotCost = pooled.lastPeriodCopilotCost
603614
}
@@ -631,7 +642,7 @@ export async function getSimplifiedBillingSummary(
631642
)
632643
: 0
633644

634-
const userCredits = await getCreditBalance(userId)
645+
const userCredits = await getCreditBalance(userId, executor)
635646
const individualBillingInterval = getBillingInterval(
636647
subscription?.metadata as SubscriptionMetadata
637648
)

apps/sim/lib/billing/core/organization.ts

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import {
1919
hasUsableSubscriptionStatus,
2020
} from '@/lib/billing/subscriptions/utils'
2121
import { toDecimal, toNumber } from '@/lib/billing/utils/decimal'
22-
import type { DbOrTx } from '@/lib/db/types'
22+
import type { DbClient } from '@/lib/db/types'
2323

2424
const logger = createLogger('OrganizationBilling')
2525

@@ -66,11 +66,11 @@ interface MemberUsageData {
6666
export async function getOrgMemberLedgerByUser(
6767
organizationId: string,
6868
period?: { start: Date; end: Date } | null,
69-
executor: DbOrTx = db
69+
executor: DbClient = db
7070
): Promise<Map<string, number>> {
7171
let billingPeriod = period ?? null
7272
if (period === undefined) {
73-
const subscription = await getOrganizationSubscription(organizationId)
73+
const subscription = await getOrganizationSubscription(organizationId, { executor })
7474
billingPeriod =
7575
subscription?.periodStart && subscription?.periodEnd
7676
? { start: subscription.periodStart, end: subscription.periodEnd }
@@ -90,11 +90,11 @@ export async function getOrgMemberLedgerByUser(
9090
*/
9191
export async function getOrganizationBillingData(
9292
organizationId: string,
93-
executor: DbOrTx = db
93+
executor: DbClient = db
9494
): Promise<OrganizationUsageData | null> {
9595
try {
9696
// Get organization info
97-
const orgRecord = await db
97+
const orgRecord = await executor
9898
.select()
9999
.from(organization)
100100
.where(eq(organization.id, organizationId))
@@ -108,15 +108,15 @@ export async function getOrganizationBillingData(
108108
const organizationData = orgRecord[0]
109109

110110
// Get organization subscription directly (referenceId = organizationId)
111-
const subscription = await getOrganizationSubscription(organizationId)
111+
const subscription = await getOrganizationSubscription(organizationId, { executor })
112112

113113
if (!subscription) {
114114
logger.warn('No subscription found for organization', { organizationId })
115115
return null
116116
}
117117

118118
// Get all organization members with their usage data
119-
const membersWithUsage = await db
119+
const membersWithUsage = await executor
120120
.select({
121121
userId: member.userId,
122122
userName: user.name,
@@ -185,7 +185,8 @@ export async function getOrganizationBillingData(
185185
const memberIds = members.map((m) => m.userId)
186186
const userBounds = await getOrgMemberRefreshBounds(
187187
subscription.referenceId,
188-
subscription.periodStart
188+
subscription.periodStart,
189+
executor
189190
)
190191
const refreshConsumed = await computeDailyRefreshConsumed(
191192
{
@@ -233,7 +234,7 @@ export async function getOrganizationBillingData(
233234

234235
const averageUsagePerMember = members.length > 0 ? totalCurrentUsage / members.length : 0
235236

236-
const [pendingInvitationCount] = await db
237+
const [pendingInvitationCount] = await executor
237238
.select({ count: count() })
238239
.from(invitation)
239240
.where(

apps/sim/lib/billing/core/plan.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ import {
88
checkTeamPlan,
99
ENTITLED_SUBSCRIPTION_STATUSES,
1010
} from '@/lib/billing/subscriptions/utils'
11+
import type { DbClient } from '@/lib/db/types'
1112

1213
const logger = createLogger('PlanLookup')
1314

1415
export type HighestPrioritySubscription = Awaited<ReturnType<typeof getHighestPrioritySubscription>>
1516

1617
interface GetHighestPrioritySubscriptionOptions {
1718
onError?: 'return-null' | 'throw'
19+
/** Read-routing client (primary or replica); defaults to the primary. */
20+
executor?: DbClient
1821
}
1922

2023
function pickHighestPrioritySubscription<TSubscription>(
@@ -33,9 +36,9 @@ export async function getHighestPriorityPersonalSubscription(
3336
userId: string,
3437
options: GetHighestPrioritySubscriptionOptions = {}
3538
) {
36-
const { onError = 'return-null' } = options
39+
const { onError = 'return-null', executor = db } = options
3740
try {
38-
const personalSubs = await db
41+
const personalSubs = await executor
3942
.select()
4043
.from(subscription)
4144
.where(
@@ -77,9 +80,9 @@ export async function getHighestPrioritySubscription(
7780
userId: string,
7881
options: GetHighestPrioritySubscriptionOptions = {}
7982
) {
80-
const { onError = 'return-null' } = options
83+
const { onError = 'return-null', executor = db } = options
8184
try {
82-
const personalSubs = await db
85+
const personalSubs = await executor
8386
.select()
8487
.from(subscription)
8588
.where(
@@ -89,7 +92,7 @@ export async function getHighestPrioritySubscription(
8992
)
9093
)
9194

92-
const memberships = await db
95+
const memberships = await executor
9396
.select({ organizationId: member.organizationId })
9497
.from(member)
9598
.where(eq(member.userId, userId))
@@ -99,15 +102,15 @@ export async function getHighestPrioritySubscription(
99102
let orgSubs: typeof personalSubs = []
100103
if (orgIds.length > 0) {
101104
// Verify orgs exist to filter out orphaned subscriptions
102-
const existingOrgs = await db
105+
const existingOrgs = await executor
103106
.select({ id: organization.id })
104107
.from(organization)
105108
.where(inArray(organization.id, orgIds))
106109

107110
const validOrgIds = existingOrgs.map((o) => o.id)
108111

109112
if (validOrgIds.length > 0) {
110-
orgSubs = await db
113+
orgSubs = await executor
111114
.select()
112115
.from(subscription)
113116
.where(

apps/sim/lib/billing/core/usage-log.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { and, desc, eq, gte, inArray, lt, lte, sql } from 'drizzle-orm'
88
import { defaultBillingPeriod } from '@/lib/billing/core/billing-period'
99
import { getHighestPrioritySubscription } from '@/lib/billing/core/plan'
1010
import { isOrgScopedSubscription } from '@/lib/billing/subscriptions/utils'
11-
import type { DbOrTx } from '@/lib/db/types'
11+
import type { DbClient, DbOrTx } from '@/lib/db/types'
1212

1313
const logger = createLogger('UsageLog')
1414

@@ -184,7 +184,7 @@ export async function getBillingPeriodUsageCost(
184184
billingEntity: BillingEntity,
185185
billingPeriod: { start: Date; end: Date },
186186
source?: UsageLogSource | UsageLogSource[],
187-
executor: DbOrTx = db
187+
executor: DbClient = db
188188
): Promise<number> {
189189
const conditions = [
190190
eq(usageLog.billingEntityType, billingEntity.type),
@@ -212,7 +212,7 @@ export async function getBillingPeriodUsageCostByUser(
212212
billingEntity: BillingEntity,
213213
billingPeriod: { start: Date; end: Date },
214214
source?: UsageLogSource | UsageLogSource[],
215-
executor: DbOrTx = db
215+
executor: DbClient = db
216216
): Promise<Map<string, number>> {
217217
const conditions = [
218218
eq(usageLog.billingEntityType, billingEntity.type),

0 commit comments

Comments
 (0)