Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions apps/sim/lib/billing/core/billing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
isOrgScopedSubscription,
} from '@/lib/billing/subscriptions/utils'
import { Decimal, toDecimal, toNumber } from '@/lib/billing/utils/decimal'
import type { DbOrTx } from '@/lib/db/types'
import type { DbClient } from '@/lib/db/types'

export { getPlanPricing }

Expand All @@ -32,6 +32,8 @@ const logger = createLogger('Billing')

interface GetOrganizationSubscriptionOptions {
onError?: 'return-null' | 'throw'
/** Read-routing client (primary or replica); defaults to the primary. */
executor?: DbClient
}

/**
Expand All @@ -42,14 +44,18 @@ interface GetOrganizationSubscriptionOptions {
* For product-access gating use `getOrganizationSubscriptionUsable`
* (from `core/subscription.ts`), which excludes `past_due`.
* Returns `null` when there is no entitled sub.
*
* `options.executor` exists for replica routing on display/summary read
* paths only. Enforcement and webhook callers must read the primary —
* omit the executor (or pass `db`).
*/
export async function getOrganizationSubscription(
organizationId: string,
options: GetOrganizationSubscriptionOptions = {}
) {
const { onError = 'return-null' } = options
const { onError = 'return-null', executor = db } = options
try {
const orgSubs = await db
const orgSubs = await executor
.select()
.from(subscription)
.where(
Expand Down Expand Up @@ -111,13 +117,16 @@ export async function isSubscriptionOrgScoped(sub: { referenceId: string }): Pro
* column is `NOT NULL DEFAULT '0'` and mixing scopes would break
* current-period billing math.
*/
async function aggregateOrgMemberStats(organizationId: string): Promise<{
async function aggregateOrgMemberStats(
organizationId: string,
executor: DbClient = db
): Promise<{
memberIds: string[]
currentPeriodCost: number
currentPeriodCopilotCost: number
lastPeriodCopilotCost: number
}> {
const rows = await db
const rows = await executor
.select({
userId: member.userId,
currentPeriodCost: userStats.currentPeriodCost,
Expand Down Expand Up @@ -386,7 +395,7 @@ export async function calculateSubscriptionOverage(sub: {
export async function getSimplifiedBillingSummary(
userId: string,
organizationId?: string,
executor: DbOrTx = db
executor: DbClient = db
): Promise<{
type: 'individual' | 'organization'
plan: string
Expand Down Expand Up @@ -432,8 +441,8 @@ export async function getSimplifiedBillingSummary(
// Get subscription and usage data upfront
const [subscription, usageData] = await Promise.all([
organizationId
? getOrganizationSubscription(organizationId)
: getHighestPrioritySubscription(userId),
? getOrganizationSubscription(organizationId, { executor })
: getHighestPrioritySubscription(userId, { executor }),
getUserUsageData(userId, executor),
])

Expand All @@ -455,7 +464,7 @@ export async function getSimplifiedBillingSummary(
// Pool usage/copilot across all members in one query. Must not use
// `getUserUsageData` per-member — it now returns the pool itself
// for org-scoped subs, which would N-times-count.
const pooled = await aggregateOrgMemberStats(organizationId)
const pooled = await aggregateOrgMemberStats(organizationId, executor)

const rawCurrentUsage = pooled.currentPeriodCost
const totalLastPeriodCopilotCost = pooled.lastPeriodCopilotCost
Expand Down Expand Up @@ -495,7 +504,8 @@ export async function getSimplifiedBillingSummary(
if (planDollars > 0) {
const userBounds = await getOrgMemberRefreshBounds(
organizationId,
subscription.periodStart
subscription.periodStart,
executor
)
refreshDeduction = await computeDailyRefreshConsumed(
{
Expand All @@ -516,7 +526,8 @@ export async function getSimplifiedBillingSummary(
const { limit: orgUsageLimit } = await getOrgUsageLimit(
organizationId,
plan,
subscription.seats ?? null
subscription.seats ?? null,
executor
)

const percentUsed =
Expand All @@ -532,7 +543,7 @@ export async function getSimplifiedBillingSummary(
)
: 0

const orgCredits = await getCreditBalance(userId)
const orgCredits = await getCreditBalance(userId, executor)
const orgBillingInterval = getBillingInterval(subscription.metadata as SubscriptionMetadata)

return {
Expand Down Expand Up @@ -576,7 +587,7 @@ export async function getSimplifiedBillingSummary(
}
}

const userStatsRows = await db
const userStatsRows = await executor
.select({
currentPeriodCopilotCost: userStats.currentPeriodCopilotCost,
lastPeriodCopilotCost: userStats.lastPeriodCopilotCost,
Expand All @@ -597,7 +608,7 @@ export async function getSimplifiedBillingSummary(
let totalCopilotCost = copilotCost
let totalLastPeriodCopilotCost = lastPeriodCopilotCost
if (orgScoped && subscription?.referenceId) {
const pooled = await aggregateOrgMemberStats(subscription.referenceId)
const pooled = await aggregateOrgMemberStats(subscription.referenceId, executor)
totalCopilotCost = pooled.currentPeriodCopilotCost
totalLastPeriodCopilotCost = pooled.lastPeriodCopilotCost
}
Expand Down Expand Up @@ -631,7 +642,7 @@ export async function getSimplifiedBillingSummary(
)
: 0

const userCredits = await getCreditBalance(userId)
const userCredits = await getCreditBalance(userId, executor)
const individualBillingInterval = getBillingInterval(
subscription?.metadata as SubscriptionMetadata
)
Expand Down
19 changes: 10 additions & 9 deletions apps/sim/lib/billing/core/organization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
hasUsableSubscriptionStatus,
} from '@/lib/billing/subscriptions/utils'
import { toDecimal, toNumber } from '@/lib/billing/utils/decimal'
import type { DbOrTx } from '@/lib/db/types'
import type { DbClient } from '@/lib/db/types'

const logger = createLogger('OrganizationBilling')

Expand Down Expand Up @@ -66,11 +66,11 @@ interface MemberUsageData {
export async function getOrgMemberLedgerByUser(
organizationId: string,
period?: { start: Date; end: Date } | null,
executor: DbOrTx = db
executor: DbClient = db
): Promise<Map<string, number>> {
let billingPeriod = period ?? null
if (period === undefined) {
const subscription = await getOrganizationSubscription(organizationId)
const subscription = await getOrganizationSubscription(organizationId, { executor })
billingPeriod =
subscription?.periodStart && subscription?.periodEnd
? { start: subscription.periodStart, end: subscription.periodEnd }
Expand All @@ -90,11 +90,11 @@ export async function getOrgMemberLedgerByUser(
*/
export async function getOrganizationBillingData(
organizationId: string,
executor: DbOrTx = db
executor: DbClient = db
): Promise<OrganizationUsageData | null> {
try {
// Get organization info
const orgRecord = await db
const orgRecord = await executor
.select()
.from(organization)
.where(eq(organization.id, organizationId))
Expand All @@ -108,15 +108,15 @@ export async function getOrganizationBillingData(
const organizationData = orgRecord[0]

// Get organization subscription directly (referenceId = organizationId)
const subscription = await getOrganizationSubscription(organizationId)
const subscription = await getOrganizationSubscription(organizationId, { executor })

if (!subscription) {
logger.warn('No subscription found for organization', { organizationId })
return null
}

// Get all organization members with their usage data
const membersWithUsage = await db
const membersWithUsage = await executor
.select({
userId: member.userId,
userName: user.name,
Expand Down Expand Up @@ -185,7 +185,8 @@ export async function getOrganizationBillingData(
const memberIds = members.map((m) => m.userId)
const userBounds = await getOrgMemberRefreshBounds(
subscription.referenceId,
subscription.periodStart
subscription.periodStart,
executor
)
const refreshConsumed = await computeDailyRefreshConsumed(
{
Expand Down Expand Up @@ -233,7 +234,7 @@ export async function getOrganizationBillingData(

const averageUsagePerMember = members.length > 0 ? totalCurrentUsage / members.length : 0

const [pendingInvitationCount] = await db
const [pendingInvitationCount] = await executor
.select({ count: count() })
.from(invitation)
.where(
Expand Down
17 changes: 10 additions & 7 deletions apps/sim/lib/billing/core/plan.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import {
checkTeamPlan,
ENTITLED_SUBSCRIPTION_STATUSES,
} from '@/lib/billing/subscriptions/utils'
import type { DbClient } from '@/lib/db/types'

const logger = createLogger('PlanLookup')

export type HighestPrioritySubscription = Awaited<ReturnType<typeof getHighestPrioritySubscription>>

interface GetHighestPrioritySubscriptionOptions {
onError?: 'return-null' | 'throw'
/** Read-routing client (primary or replica); defaults to the primary. */
executor?: DbClient
}

function pickHighestPrioritySubscription<TSubscription>(
Expand All @@ -33,9 +36,9 @@ export async function getHighestPriorityPersonalSubscription(
userId: string,
options: GetHighestPrioritySubscriptionOptions = {}
) {
const { onError = 'return-null' } = options
const { onError = 'return-null', executor = db } = options
try {
const personalSubs = await db
const personalSubs = await executor
.select()
.from(subscription)
.where(
Expand Down Expand Up @@ -77,9 +80,9 @@ export async function getHighestPrioritySubscription(
userId: string,
options: GetHighestPrioritySubscriptionOptions = {}
) {
const { onError = 'return-null' } = options
const { onError = 'return-null', executor = db } = options
try {
const personalSubs = await db
const personalSubs = await executor
.select()
.from(subscription)
.where(
Expand All @@ -89,7 +92,7 @@ export async function getHighestPrioritySubscription(
)
)

const memberships = await db
const memberships = await executor
.select({ organizationId: member.organizationId })
.from(member)
.where(eq(member.userId, userId))
Expand All @@ -99,15 +102,15 @@ export async function getHighestPrioritySubscription(
let orgSubs: typeof personalSubs = []
if (orgIds.length > 0) {
// Verify orgs exist to filter out orphaned subscriptions
const existingOrgs = await db
const existingOrgs = await executor
.select({ id: organization.id })
.from(organization)
.where(inArray(organization.id, orgIds))

const validOrgIds = existingOrgs.map((o) => o.id)

if (validOrgIds.length > 0) {
orgSubs = await db
orgSubs = await executor
.select()
.from(subscription)
.where(
Expand Down
6 changes: 3 additions & 3 deletions apps/sim/lib/billing/core/usage-log.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { and, desc, eq, gte, inArray, lt, lte, sql } from 'drizzle-orm'
import { defaultBillingPeriod } from '@/lib/billing/core/billing-period'
import { getHighestPrioritySubscription } from '@/lib/billing/core/plan'
import { isOrgScopedSubscription } from '@/lib/billing/subscriptions/utils'
import type { DbOrTx } from '@/lib/db/types'
import type { DbClient, DbOrTx } from '@/lib/db/types'

const logger = createLogger('UsageLog')

Expand Down Expand Up @@ -184,7 +184,7 @@ export async function getBillingPeriodUsageCost(
billingEntity: BillingEntity,
billingPeriod: { start: Date; end: Date },
source?: UsageLogSource | UsageLogSource[],
executor: DbOrTx = db
executor: DbClient = db
): Promise<number> {
const conditions = [
eq(usageLog.billingEntityType, billingEntity.type),
Expand Down Expand Up @@ -212,7 +212,7 @@ export async function getBillingPeriodUsageCostByUser(
billingEntity: BillingEntity,
billingPeriod: { start: Date; end: Date },
source?: UsageLogSource | UsageLogSource[],
executor: DbOrTx = db
executor: DbClient = db
): Promise<Map<string, number>> {
const conditions = [
eq(usageLog.billingEntityType, billingEntity.type),
Expand Down
Loading
Loading