Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix(knowledge): fix search embedding test mocks, parallelize billing …
…lookups
  • Loading branch information
waleedlatif1 committed Apr 4, 2026
commit 4e6cb8f46a709ad12f24277a08e44486b44dfa5c
103 changes: 43 additions & 60 deletions apps/sim/app/api/knowledge/search/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* @vitest-environment node
*/
import { createEnvMock, databaseMock, loggerMock } from '@sim/testing'
import { mockNextFetchResponse } from '@sim/testing/mocks'
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('drizzle-orm')
Expand All @@ -14,16 +15,6 @@ vi.mock('@/lib/knowledge/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))

vi.stubGlobal(
'fetch',
vi.fn().mockResolvedValue({
ok: true,
json: async () => ({
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
})
)

vi.mock('@/lib/core/config/env', () => createEnvMock())

import {
Expand Down Expand Up @@ -178,17 +169,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})

const result = await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
expect.objectContaining({
headers: expect.objectContaining({
Expand All @@ -209,17 +199,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})

const result = await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
'https://api.openai.com/v1/embeddings',
expect.objectContaining({
headers: expect.objectContaining({
Expand All @@ -243,17 +232,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})

await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
expect.stringContaining('api-version='),
expect.any(Object)
)
Expand All @@ -273,17 +261,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})

await generateSearchEmbedding('test query', 'text-embedding-3-small')

expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
'https://test.openai.azure.com/openai/deployments/custom-embedding-model/embeddings?api-version=2024-12-01-preview',
expect.any(Object)
)
Expand Down Expand Up @@ -311,13 +298,12 @@ describe('Knowledge Search Utils', () => {
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
mockNextFetchResponse({
ok: false,
status: 404,
statusText: 'Not Found',
text: async () => 'Deployment not found',
} as any)
text: 'Deployment not found',
})

await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')

Expand All @@ -332,13 +318,12 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
mockNextFetchResponse({
ok: false,
status: 429,
statusText: 'Too Many Requests',
text: async () => 'Rate limit exceeded',
} as any)
text: 'Rate limit exceeded',
})

await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')

Expand All @@ -356,17 +341,16 @@ describe('Knowledge Search Utils', () => {
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})

await generateSearchEmbedding('test query')

expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: JSON.stringify({
Expand All @@ -387,17 +371,16 @@ describe('Knowledge Search Utils', () => {
OPENAI_API_KEY: 'test-openai-key',
})

const fetchSpy = vi.mocked(fetch)
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
mockNextFetchResponse({
json: {
data: [{ embedding: [0.1, 0.2, 0.3] }],
}),
} as any)
usage: { prompt_tokens: 1, total_tokens: 1 },
},
})

await generateSearchEmbedding('test query', 'text-embedding-3-small')

expect(fetchSpy).toHaveBeenCalledWith(
expect(vi.mocked(fetch)).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: JSON.stringify({
Expand Down
12 changes: 8 additions & 4 deletions apps/sim/lib/billing/core/subscription.ts
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,11 @@ export async function hasInboxAccess(userId: string): Promise<boolean> {
if (!isProd) {
return true
}
const sub = await getHighestPrioritySubscription(userId)
const [sub, billingStatus] = await Promise.all([
getHighestPrioritySubscription(userId),
getEffectiveBillingStatus(userId),
])
if (!sub) return false
const billingStatus = await getEffectiveBillingStatus(userId)
if (!hasUsableSubscriptionAccess(sub.status, billingStatus.billingBlocked)) return false
return getPlanTierCredits(sub.plan) >= 25000 || checkEnterprisePlan(sub)
} catch (error) {
Expand All @@ -470,9 +472,11 @@ export async function hasLiveSyncAccess(userId: string): Promise<boolean> {
if (!isHosted) {
return true
}
const sub = await getHighestPrioritySubscription(userId)
const [sub, billingStatus] = await Promise.all([
getHighestPrioritySubscription(userId),
getEffectiveBillingStatus(userId),
])
if (!sub) return false
const billingStatus = await getEffectiveBillingStatus(userId)
if (!hasUsableSubscriptionAccess(sub.status, billingStatus.billingBlocked)) return false
return getPlanTierCredits(sub.plan) >= 25000 || checkEnterprisePlan(sub)
} catch (error) {
Expand Down