Skip to content

Commit 22fe096

Browse files
N2D4fomalhautb
andauthored
Fix source of truth for custom schemas (stack-auth#764)
Co-authored-by: Zai Shi <zaishi00@outlook.com>
1 parent f99c3e9 commit 22fe096

File tree

6 files changed

+79
-36
lines changed

6 files changed

+79
-36
lines changed

.github/workflows/e2e-source-of-truth-api-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
env:
1818
NODE_ENV: test
1919
STACK_ENABLE_HARDCODED_PASSKEY_CHALLENGE_FOR_TESTING: yes
20-
STACK_OVERRIDE_SOURCE_OF_TRUTH: '{"type": "postgres", "connectionString": "postgres://postgres:PASSWORD-PLACEHOLDER--uqfEC1hmmv@localhost:5432/stackframe2"}'
20+
STACK_OVERRIDE_SOURCE_OF_TRUTH: '{"type": "postgres", "connectionString": "postgres://postgres:PASSWORD-PLACEHOLDER--uqfEC1hmmv@localhost:5432/source-of-truth-db?schema=sot-schema"}'
2121
STACK_TEST_SOURCE_OF_TRUTH: true
2222

2323
strategy:
@@ -98,7 +98,7 @@ jobs:
9898
run: npx wait-on tcp:localhost:8113
9999

100100
- name: Initialize source of truth database
101-
run: "STACK_DIRECT_DATABASE_CONNECTION_STRING='postgres://postgres:PASSWORD-PLACEHOLDER--uqfEC1hmmv@localhost:5432/stackframe2' pnpm run prisma -- migrate reset --force --skip-seed"
101+
run: "STACK_DIRECT_DATABASE_CONNECTION_STRING='postgres://postgres:PASSWORD-PLACEHOLDER--uqfEC1hmmv@localhost:5432/source-of-truth-db?schema=sot-schema' pnpm run prisma -- migrate reset --force --skip-seed"
102102

103103
- name: Initialize database
104104
run: pnpm run prisma -- migrate reset --force
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
-- Drop triggers for project user count
2+
DROP TRIGGER project_user_insert_trigger ON "ProjectUser";
3+
DROP TRIGGER project_user_update_trigger ON "ProjectUser";
4+
DROP TRIGGER project_user_delete_trigger ON "ProjectUser";
5+
6+
-- Drop function for updating project user count
7+
DROP FUNCTION update_project_user_count();

apps/backend/src/app/api/latest/auth/sessions/crud.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { getPrismaClientForTenancy, globalPrismaClient } from "@/prisma-client";
1+
import { getPrismaClientForTenancy, getPrismaSchemaForTenancy, globalPrismaClient, sqlQuoteIdent } from "@/prisma-client";
22
import { createCrudHandlers } from "@/route-handlers/crud-handler";
33
import { SmartRequestAuth } from "@/route-handlers/smart-request";
44
import { Prisma } from "@prisma/client";
@@ -18,6 +18,7 @@ export const sessionsCrudHandlers = createLazyProxy(() => createCrudHandlers(ses
1818
}).defined(),
1919
onList: async ({ auth, query }) => {
2020
const prisma = getPrismaClientForTenancy(auth.tenancy);
21+
const schema = getPrismaSchemaForTenancy(auth.tenancy);
2122
const listImpersonations = auth.type === 'admin';
2223

2324
if (auth.type === 'client') {
@@ -38,13 +39,12 @@ export const sessionsCrudHandlers = createLazyProxy(() => createCrudHandlers(ses
3839
},
3940
});
4041

41-
4242
// Get the latest event for each session
4343
const events = await prisma.$queryRaw<Array<{ sessionId: string, lastActiveAt: Date, geo: GeoInfo | null, isEndUserIpInfoGuessTrusted: boolean }>>`
4444
WITH latest_events AS (
4545
SELECT data->>'sessionId' as "sessionId",
4646
MAX("eventStartedAt") as "lastActiveAt"
47-
FROM "Event"
47+
FROM ${sqlQuoteIdent(schema)}."Event"
4848
WHERE ${refreshTokenObjs.length > 0
4949
? Prisma.sql`data->>'sessionId' = ANY(${Prisma.sql`ARRAY[${Prisma.join(refreshTokenObjs.map(s => s.id))}]`})`
5050
: Prisma.sql`FALSE`}
@@ -55,9 +55,9 @@ export const sessionsCrudHandlers = createLazyProxy(() => createCrudHandlers(ses
5555
le."lastActiveAt",
5656
row_to_json(geo.*) as "geo",
5757
e.data->>'isEndUserIpInfoGuessTrusted' as "isEndUserIpInfoGuessTrusted"
58-
FROM "Event" e
58+
FROM ${sqlQuoteIdent(schema)}."Event" e
5959
JOIN latest_events le ON e.data->>'sessionId' = le."sessionId" AND e."eventStartedAt" = le."lastActiveAt"
60-
LEFT JOIN "EventIpInfo" geo ON geo.id = e."endUserIpInfoGuessId"
60+
LEFT JOIN ${sqlQuoteIdent(schema)}."EventIpInfo" geo ON geo.id = e."endUserIpInfoGuessId"
6161
WHERE e."systemEventTypeIds" @> '{"$session-activity"}'
6262
`;
6363

apps/backend/src/app/api/latest/internal/metrics/route.tsx

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { Tenancy } from "@/lib/tenancies";
2-
import { getPrismaClientForTenancy, globalPrismaClient } from "@/prisma-client";
2+
import { getPrismaClientForTenancy, getPrismaSchemaForTenancy, globalPrismaClient, sqlQuoteIdent } from "@/prisma-client";
33
import { createSmartRouteHandler } from "@/route-handlers/smart-route-handler";
44
import { KnownErrors } from "@stackframe/stack-shared";
55
import { UsersCrud } from "@stackframe/stack-shared/dist/interface/crud/users";
@@ -45,7 +45,9 @@ async function loadUsersByCountry(tenancy: Tenancy): Promise<Record<string, numb
4545
}
4646

4747
async function loadTotalUsers(tenancy: Tenancy, now: Date): Promise<DataPoints> {
48-
return (await getPrismaClientForTenancy(tenancy).$queryRaw<{date: Date, dailyUsers: bigint, cumUsers: bigint}[]>`
48+
const schema = getPrismaSchemaForTenancy(tenancy);
49+
const prisma = getPrismaClientForTenancy(tenancy);
50+
return (await prisma.$queryRaw<{date: Date, dailyUsers: bigint, cumUsers: bigint}[]>`
4951
WITH date_series AS (
5052
SELECT GENERATE_SERIES(
5153
${now}::date - INTERVAL '30 days',
@@ -59,7 +61,7 @@ async function loadTotalUsers(tenancy: Tenancy, now: Date): Promise<DataPoints>
5961
COALESCE(COUNT(pu."projectUserId"), 0) AS "dailyUsers",
6062
SUM(COALESCE(COUNT(pu."projectUserId"), 0)) OVER (ORDER BY ds.registration_day) AS "cumUsers"
6163
FROM date_series ds
62-
LEFT JOIN "ProjectUser" pu
64+
LEFT JOIN ${sqlQuoteIdent(schema)}."ProjectUser" pu
6365
ON DATE(pu."createdAt") = ds.registration_day AND pu."tenancyId" = ${tenancy.id}::UUID
6466
GROUP BY ds.registration_day
6567
ORDER BY ds.registration_day
@@ -104,7 +106,9 @@ async function loadDailyActiveUsers(tenancy: Tenancy, now: Date) {
104106
}
105107

106108
async function loadLoginMethods(tenancy: Tenancy): Promise<{method: string, count: number }[]> {
107-
return await getPrismaClientForTenancy(tenancy).$queryRaw<{ method: string, count: number }[]>`
109+
const schema = getPrismaSchemaForTenancy(tenancy);
110+
const prisma = getPrismaClientForTenancy(tenancy);
111+
return await prisma.$queryRaw<{ method: string, count: number }[]>`
108112
WITH tab AS (
109113
SELECT
110114
COALESCE(
@@ -116,11 +120,11 @@ async function loadLoginMethods(tenancy: Tenancy): Promise<{method: string, coun
116120
) AS "method",
117121
method.id AS id
118122
FROM
119-
"AuthMethod" method
120-
LEFT JOIN "OAuthAuthMethod" oaam ON method.id = oaam."authMethodId"
121-
LEFT JOIN "PasswordAuthMethod" pam ON method.id = pam."authMethodId"
122-
LEFT JOIN "PasskeyAuthMethod" pkm ON method.id = pkm."authMethodId"
123-
LEFT JOIN "OtpAuthMethod" oam ON method.id = oam."authMethodId"
123+
${sqlQuoteIdent(schema)}."AuthMethod" method
124+
LEFT JOIN ${sqlQuoteIdent(schema)}."OAuthAuthMethod" oaam ON method.id = oaam."authMethodId"
125+
LEFT JOIN ${sqlQuoteIdent(schema)}."PasswordAuthMethod" pam ON method.id = pam."authMethodId"
126+
LEFT JOIN ${sqlQuoteIdent(schema)}."PasskeyAuthMethod" pkm ON method.id = pkm."authMethodId"
127+
LEFT JOIN ${sqlQuoteIdent(schema)}."OtpAuthMethod" oam ON method.id = oam."authMethodId"
124128
WHERE method."tenancyId" = ${tenancy.id}::UUID)
125129
SELECT LOWER("method") AS method, COUNT(id)::int AS "count" FROM tab
126130
GROUP BY "method"

apps/backend/src/app/api/latest/users/crud.tsx

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { ensureTeamMembershipExists, ensureUserExists } from "@/lib/request-chec
55
import { getSoleTenancyFromProjectBranch, getTenancy } from "@/lib/tenancies";
66
import { PrismaTransaction } from "@/lib/types";
77
import { sendTeamMembershipDeletedWebhook, sendUserCreatedWebhook, sendUserDeletedWebhook, sendUserUpdatedWebhook } from "@/lib/webhooks";
8-
import { RawQuery, getPrismaClientForSourceOfTruth, getPrismaClientForTenancy, globalPrismaClient, rawQuery, retryTransaction } from "@/prisma-client";
8+
import { RawQuery, getPrismaClientForSourceOfTruth, getPrismaClientForTenancy, getPrismaSchemaForSourceOfTruth, getPrismaSchemaForTenancy, globalPrismaClient, rawQuery, retryTransaction, sqlQuoteIdent } from "@/prisma-client";
99
import { createCrudHandlers } from "@/route-handlers/crud-handler";
1010
import { log } from "@/utils/telemetry";
1111
import { runAsynchronouslyAndWaitUntil } from "@/utils/vercel";
@@ -167,9 +167,10 @@ export const getUsersLastActiveAtMillis = async (projectId: string, branchId: st
167167
const tenancy = await getSoleTenancyFromProjectBranch(projectId, branchId);
168168

169169
const prisma = getPrismaClientForTenancy(tenancy);
170+
const schema = getPrismaSchemaForTenancy(tenancy);
170171
const events = await prisma.$queryRaw<Array<{ userId: string, lastActiveAt: Date }>>`
171172
SELECT data->>'userId' as "userId", MAX("eventStartedAt") as "lastActiveAt"
172-
FROM "Event"
173+
FROM ${sqlQuoteIdent(schema)}."Event"
173174
WHERE data->>'userId' = ANY(${Prisma.sql`ARRAY[${Prisma.join(userIds)}]`}) AND data->>'projectId' = ${projectId} AND COALESCE("data"->>'branchId', 'main') = ${branchId} AND "systemEventTypeIds" @> '{"$user-activity"}'
174175
GROUP BY data->>'userId'
175176
`;
@@ -182,7 +183,7 @@ export const getUsersLastActiveAtMillis = async (projectId: string, branchId: st
182183
});
183184
};
184185

185-
export function getUserQuery(projectId: string, branchId: string, userId: string): RawQuery<UsersCrud["Admin"]["Read"] | null> {
186+
export function getUserQuery(projectId: string, branchId: string, userId: string, schema: string): RawQuery<UsersCrud["Admin"]["Read"] | null> {
186187
return {
187188
supportedPrismaClients: ["source-of-truth"],
188189
sql: Prisma.sql`
@@ -193,22 +194,22 @@ export function getUserQuery(projectId: string, branchId: string, userId: string
193194
jsonb_build_object(
194195
'lastActiveAt', (
195196
SELECT MAX("eventStartedAt") as "lastActiveAt"
196-
FROM "Event"
197+
FROM ${sqlQuoteIdent(schema)}."Event"
197198
WHERE data->>'projectId' = ("ProjectUser"."mirroredProjectId") AND COALESCE("data"->>'branchId', 'main') = ("ProjectUser"."mirroredBranchId") AND "data"->>'userId' = ("ProjectUser"."projectUserId")::text AND "systemEventTypeIds" @> '{"$user-activity"}'
198199
),
199200
'ContactChannels', (
200201
SELECT COALESCE(ARRAY_AGG(
201202
to_jsonb("ContactChannel") ||
202203
jsonb_build_object()
203204
), '{}')
204-
FROM "ContactChannel"
205+
FROM ${sqlQuoteIdent(schema)}."ContactChannel"
205206
WHERE "ContactChannel"."tenancyId" = "ProjectUser"."tenancyId" AND "ContactChannel"."projectUserId" = "ProjectUser"."projectUserId" AND "ContactChannel"."isPrimary" = 'TRUE'
206207
),
207208
'ProjectUserOAuthAccounts', (
208209
SELECT COALESCE(ARRAY_AGG(
209210
to_jsonb("ProjectUserOAuthAccount")
210211
), '{}')
211-
FROM "ProjectUserOAuthAccount"
212+
FROM ${sqlQuoteIdent(schema)}."ProjectUserOAuthAccount"
212213
WHERE "ProjectUserOAuthAccount"."tenancyId" = "ProjectUser"."tenancyId" AND "ProjectUserOAuthAccount"."projectUserId" = "ProjectUser"."projectUserId"
213214
),
214215
'AuthMethods', (
@@ -220,36 +221,36 @@ export function getUserQuery(projectId: string, branchId: string, userId: string
220221
to_jsonb("PasswordAuthMethod") ||
221222
jsonb_build_object()
222223
)
223-
FROM "PasswordAuthMethod"
224+
FROM ${sqlQuoteIdent(schema)}."PasswordAuthMethod"
224225
WHERE "PasswordAuthMethod"."tenancyId" = "ProjectUser"."tenancyId" AND "PasswordAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "PasswordAuthMethod"."authMethodId" = "AuthMethod"."id"
225226
),
226227
'OtpAuthMethod', (
227228
SELECT (
228229
to_jsonb("OtpAuthMethod") ||
229230
jsonb_build_object()
230231
)
231-
FROM "OtpAuthMethod"
232+
FROM ${sqlQuoteIdent(schema)}."OtpAuthMethod"
232233
WHERE "OtpAuthMethod"."tenancyId" = "ProjectUser"."tenancyId" AND "OtpAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "OtpAuthMethod"."authMethodId" = "AuthMethod"."id"
233234
),
234235
'PasskeyAuthMethod', (
235236
SELECT (
236237
to_jsonb("PasskeyAuthMethod") ||
237238
jsonb_build_object()
238239
)
239-
FROM "PasskeyAuthMethod"
240+
FROM ${sqlQuoteIdent(schema)}."PasskeyAuthMethod"
240241
WHERE "PasskeyAuthMethod"."tenancyId" = "ProjectUser"."tenancyId" AND "PasskeyAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "PasskeyAuthMethod"."authMethodId" = "AuthMethod"."id"
241242
),
242243
'OAuthAuthMethod', (
243244
SELECT (
244245
to_jsonb("OAuthAuthMethod") ||
245246
jsonb_build_object()
246247
)
247-
FROM "OAuthAuthMethod"
248+
FROM ${sqlQuoteIdent(schema)}."OAuthAuthMethod"
248249
WHERE "OAuthAuthMethod"."tenancyId" = "ProjectUser"."tenancyId" AND "OAuthAuthMethod"."projectUserId" = "ProjectUser"."projectUserId" AND "OAuthAuthMethod"."authMethodId" = "AuthMethod"."id"
249250
)
250251
)
251252
), '{}')
252-
FROM "AuthMethod"
253+
FROM ${sqlQuoteIdent(schema)}."AuthMethod"
253254
WHERE "AuthMethod"."tenancyId" = "ProjectUser"."tenancyId" AND "AuthMethod"."projectUserId" = "ProjectUser"."projectUserId"
254255
),
255256
'SelectedTeamMember', (
@@ -261,17 +262,17 @@ export function getUserQuery(projectId: string, branchId: string, userId: string
261262
to_jsonb("Team") ||
262263
jsonb_build_object()
263264
)
264-
FROM "Team"
265+
FROM ${sqlQuoteIdent(schema)}."Team"
265266
WHERE "Team"."tenancyId" = "ProjectUser"."tenancyId" AND "Team"."teamId" = "TeamMember"."teamId"
266267
)
267268
)
268269
)
269-
FROM "TeamMember"
270+
FROM ${sqlQuoteIdent(schema)}."TeamMember"
270271
WHERE "TeamMember"."tenancyId" = "ProjectUser"."tenancyId" AND "TeamMember"."projectUserId" = "ProjectUser"."projectUserId" AND "TeamMember"."isSelected" = 'TRUE'
271272
)
272273
)
273274
)
274-
FROM "ProjectUser"
275+
FROM ${sqlQuoteIdent(schema)}."ProjectUser"
275276
WHERE "ProjectUser"."mirroredProjectId" = ${projectId} AND "ProjectUser"."mirroredBranchId" = ${branchId} AND "ProjectUser"."projectUserId" = ${userId}::UUID
276277
)
277278
) AS "row_data_json"
@@ -340,7 +341,7 @@ export function getUserQuery(projectId: string, branchId: string, userId: string
340341
*/
341342
export function getUserIfOnGlobalPrismaClientQuery(projectId: string, branchId: string, userId: string): RawQuery<UsersCrud["Admin"]["Read"] | null> {
342343
return {
343-
...getUserQuery(projectId, branchId, userId),
344+
...getUserQuery(projectId, branchId, userId, "public"),
344345
supportedPrismaClients: ["global"],
345346
};
346347
}
@@ -358,7 +359,7 @@ export async function getUser(options: { userId: string } & ({ projectId: string
358359

359360
const environmentConfig = await rawQuery(globalPrismaClient, getRenderedEnvironmentConfigQuery({ projectId, branchId }));
360361
const prisma = getPrismaClientForSourceOfTruth(environmentConfig.sourceOfTruth, branchId);
361-
const result = await rawQuery(prisma, getUserQuery(projectId, branchId, options.userId));
362+
const result = await rawQuery(prisma, getUserQuery(projectId, branchId, options.userId, getPrismaSchemaForSourceOfTruth(environmentConfig.sourceOfTruth, branchId)));
362363
return result;
363364
}
364365

apps/backend/src/prisma-client.tsx

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ export type PrismaClientTransaction = PrismaClient | Parameters<Parameters<Prism
1818
const prismaClientsStore = (globalVar.__stack_prisma_clients as undefined) || {
1919
global: new PrismaClient(),
2020
neon: new Map<string, PrismaClient>(),
21-
postgres: new Map<string, PrismaClient>(),
21+
postgres: new Map<string, {
22+
client: PrismaClient,
23+
schema: string | null,
24+
}>(),
2225
};
2326
if (getNodeEnvironment().includes('development')) {
2427
globalVar.__stack_prisma_clients = prismaClientsStore; // store globally so fast refresh doesn't recreate too many Prisma clients
@@ -40,11 +43,19 @@ export function getPrismaClientForTenancy(tenancy: Tenancy) {
4043
return getPrismaClientForSourceOfTruth(tenancy.completeConfig.sourceOfTruth, tenancy.branchId);
4144
}
4245

46+
export function getPrismaSchemaForTenancy(tenancy: Tenancy) {
47+
return getPrismaSchemaForSourceOfTruth(tenancy.completeConfig.sourceOfTruth, tenancy.branchId);
48+
}
49+
4350
function getPostgresPrismaClient(connectionString: string) {
4451
let postgresPrismaClient = prismaClientsStore.postgres.get(connectionString);
4552
if (!postgresPrismaClient) {
46-
const adapter = new PrismaPg({ connectionString });
47-
postgresPrismaClient = new PrismaClient({ adapter });
53+
const schema = (new URL(connectionString)).searchParams.get('schema');
54+
const adapter = new PrismaPg({ connectionString }, schema ? { schema } : undefined);
55+
postgresPrismaClient = {
56+
client: new PrismaClient({ adapter }),
57+
schema: schema ?? null,
58+
};
4859
prismaClientsStore.postgres.set(connectionString, postgresPrismaClient);
4960
}
5061
return postgresPrismaClient;
@@ -59,7 +70,7 @@ export function getPrismaClientForSourceOfTruth(sourceOfTruth: OrganizationRende
5970
return getNeonPrismaClient(sourceOfTruth.connectionStrings[branchId]);
6071
}
6172
case 'postgres': {
62-
return getPostgresPrismaClient(sourceOfTruth.connectionString);
73+
return getPostgresPrismaClient(sourceOfTruth.connectionString).client;
6374
}
6475
case 'hosted': {
6576
return globalPrismaClient;
@@ -71,6 +82,17 @@ export function getPrismaClientForSourceOfTruth(sourceOfTruth: OrganizationRende
7182
}
7283
}
7384

85+
export function getPrismaSchemaForSourceOfTruth(sourceOfTruth: OrganizationRenderedConfig["sourceOfTruth"], branchId: string) {
86+
switch (sourceOfTruth.type) {
87+
case 'postgres': {
88+
return getPostgresPrismaClient(sourceOfTruth.connectionString).schema ?? 'public';
89+
}
90+
default: {
91+
return 'public';
92+
}
93+
}
94+
}
95+
7496

7597
class TransactionErrorThatShouldBeRetried extends Error {
7698
constructor(cause: unknown) {
@@ -296,3 +318,12 @@ export function isPrismaUniqueConstraintViolation(error: unknown, modelName: str
296318
if (!error.meta?.target) return false;
297319
return error.meta.modelName === modelName && deepPlainEquals(error.meta.target, target);
298320
}
321+
322+
export function sqlQuoteIdent(id: string) {
323+
// accept letters, numbers, underscore, $, and dash (adjust as needed)
324+
if (!/^[A-Za-z_][A-Za-z0-9_\-$]*$/.test(id)) {
325+
throw new Error(`Invalid identifier: ${id}`);
326+
}
327+
// escape embedded double quotes just in case
328+
return Prisma.raw(`"${id.replace(/"/g, '""')}"`);
329+
}

0 commit comments

Comments
 (0)