@@ -31,6 +31,11 @@ import (
3131
3232var validProxyByHostnameRegex = regexp .MustCompile (`^[a-zA-Z0-9._-]+$` )
3333
34+ var errForeignKeyConstraint = & pq.Error {
35+ Code : "23503" ,
36+ Message : "update or delete on table violates foreign key constraint" ,
37+ }
38+
3439var errDuplicateKey = & pq.Error {
3540 Code : "23505" ,
3641 Message : "duplicate key value violates unique constraint" ,
@@ -45,6 +50,7 @@ func New() database.Store {
4550 organizationMembers : make ([]database.OrganizationMember , 0 ),
4651 organizations : make ([]database.Organization , 0 ),
4752 users : make ([]database.User , 0 ),
53+ dbcryptKeys : make ([]database.DBCryptKey , 0 ),
4854 gitAuthLinks : make ([]database.GitAuthLink , 0 ),
4955 groups : make ([]database.Group , 0 ),
5056 groupMembers : make ([]database.GroupMember , 0 ),
@@ -117,6 +123,7 @@ type data struct {
117123 // New tables
118124 workspaceAgentStats []database.WorkspaceAgentStat
119125 auditLogs []database.AuditLog
126+ dbcryptKeys []database.DBCryptKey
120127 files []database.File
121128 gitAuthLinks []database.GitAuthLink
122129 gitSSHKey []database.GitSSHKey
@@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
665672 return false
666673}
667674
675+ func (q * FakeQuerier ) GetActiveDBCryptKeys (_ context.Context ) ([]database.DBCryptKey , error ) {
676+ q .mutex .RLock ()
677+ defer q .mutex .RUnlock ()
678+ ks := make ([]database.DBCryptKey , 0 , len (q .dbcryptKeys ))
679+ for _ , k := range q .dbcryptKeys {
680+ if ! k .ActiveKeyDigest .Valid {
681+ continue
682+ }
683+ ks = append ([]database.DBCryptKey {}, k )
684+ }
685+ return ks , nil
686+ }
687+
668688func (* FakeQuerier ) AcquireLock (_ context.Context , _ int64 ) error {
669689 return xerrors .New ("AcquireLock must only be called within a transaction" )
670690}
@@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
11511171 }, nil
11521172}
11531173
1174+ func (q * FakeQuerier ) GetDBCryptKeys (_ context.Context ) ([]database.DBCryptKey , error ) {
1175+ q .mutex .RLock ()
1176+ defer q .mutex .RUnlock ()
1177+ ks := make ([]database.DBCryptKey , 0 )
1178+ ks = append (ks , q .dbcryptKeys ... )
1179+ return ks , nil
1180+ }
1181+
11541182func (q * FakeQuerier ) GetDERPMeshKey (_ context.Context ) (string , error ) {
11551183 q .mutex .RLock ()
11561184 defer q .mutex .RUnlock ()
@@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
13931421 return database.GitAuthLink {}, sql .ErrNoRows
13941422}
13951423
1424+ func (q * FakeQuerier ) GetGitAuthLinksByUserID (_ context.Context , userID uuid.UUID ) ([]database.GitAuthLink , error ) {
1425+ q .mutex .RLock ()
1426+ defer q .mutex .RUnlock ()
1427+ gals := make ([]database.GitAuthLink , 0 )
1428+ for _ , gal := range q .gitAuthLinks {
1429+ if gal .UserID == userID {
1430+ gals = append (gals , gal )
1431+ }
1432+ }
1433+ return gals , nil
1434+ }
1435+
13961436func (q * FakeQuerier ) GetGitSSHKey (_ context.Context , userID uuid.UUID ) (database.GitSSHKey , error ) {
13971437 q .mutex .RLock ()
13981438 defer q .mutex .RUnlock ()
@@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
28332873 return database.UserLink {}, sql .ErrNoRows
28342874}
28352875
2876+ func (q * FakeQuerier ) GetUserLinksByUserID (_ context.Context , userID uuid.UUID ) ([]database.UserLink , error ) {
2877+ q .mutex .RLock ()
2878+ defer q .mutex .RUnlock ()
2879+ uls := make ([]database.UserLink , 0 )
2880+ for _ , ul := range q .userLinks {
2881+ if ul .UserID == userID {
2882+ uls = append (uls , ul )
2883+ }
2884+ }
2885+ return uls , nil
2886+ }
2887+
28362888func (q * FakeQuerier ) GetUsers (_ context.Context , params database.GetUsersParams ) ([]database.GetUsersRow , error ) {
28372889 if err := validateDatabaseType (params ); err != nil {
28382890 return nil , err
@@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
38463898 return alog , nil
38473899}
38483900
3901+ func (q * FakeQuerier ) InsertDBCryptKey (_ context.Context , arg database.InsertDBCryptKeyParams ) error {
3902+ err := validateDatabaseType (arg )
3903+ if err != nil {
3904+ return err
3905+ }
3906+
3907+ for _ , key := range q .dbcryptKeys {
3908+ if key .Number == arg .Number {
3909+ return errDuplicateKey
3910+ }
3911+ }
3912+
3913+ q .dbcryptKeys = append (q .dbcryptKeys , database.DBCryptKey {
3914+ Number : arg .Number ,
3915+ ActiveKeyDigest : sql.NullString {String : arg .ActiveKeyDigest , Valid : true },
3916+ Test : arg .Test ,
3917+ })
3918+ return nil
3919+ }
3920+
38493921func (q * FakeQuerier ) InsertDERPMeshKey (_ context.Context , id string ) error {
38503922 q .mutex .Lock ()
38513923 defer q .mutex .Unlock ()
@@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
38923964 defer q .mutex .Unlock ()
38933965 // nolint:gosimple
38943966 gitAuthLink := database.GitAuthLink {
3895- ProviderID : arg .ProviderID ,
3896- UserID : arg .UserID ,
3897- CreatedAt : arg .CreatedAt ,
3898- UpdatedAt : arg .UpdatedAt ,
3899- OAuthAccessToken : arg .OAuthAccessToken ,
3900- OAuthRefreshToken : arg .OAuthRefreshToken ,
3901- OAuthExpiry : arg .OAuthExpiry ,
3967+ ProviderID : arg .ProviderID ,
3968+ UserID : arg .UserID ,
3969+ CreatedAt : arg .CreatedAt ,
3970+ UpdatedAt : arg .UpdatedAt ,
3971+ OAuthAccessToken : arg .OAuthAccessToken ,
3972+ OAuthAccessTokenKeyID : arg .OAuthAccessTokenKeyID ,
3973+ OAuthRefreshToken : arg .OAuthRefreshToken ,
3974+ OAuthRefreshTokenKeyID : arg .OAuthRefreshTokenKeyID ,
3975+ OAuthExpiry : arg .OAuthExpiry ,
39023976 }
39033977 q .gitAuthLinks = append (q .gitAuthLinks , gitAuthLink )
39043978 return gitAuthLink , nil
@@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
43624436
43634437 //nolint:gosimple
43644438 link := database.UserLink {
4365- UserID : args .UserID ,
4366- LoginType : args .LoginType ,
4367- LinkedID : args .LinkedID ,
4368- OAuthAccessToken : args .OAuthAccessToken ,
4369- OAuthRefreshToken : args .OAuthRefreshToken ,
4370- OAuthExpiry : args .OAuthExpiry ,
4439+ UserID : args .UserID ,
4440+ LoginType : args .LoginType ,
4441+ LinkedID : args .LinkedID ,
4442+ OAuthAccessToken : args .OAuthAccessToken ,
4443+ OAuthAccessTokenKeyID : args .OAuthAccessTokenKeyID ,
4444+ OAuthRefreshToken : args .OAuthRefreshToken ,
4445+ OAuthRefreshTokenKeyID : args .OAuthRefreshTokenKeyID ,
4446+ OAuthExpiry : args .OAuthExpiry ,
43714447 }
43724448
43734449 q .userLinks = append (q .userLinks , link )
@@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
47934869 return database.WorkspaceProxy {}, sql .ErrNoRows
47944870}
47954871
4872+ func (q * FakeQuerier ) RevokeDBCryptKey (_ context.Context , activeKeyDigest string ) error {
4873+ q .mutex .Lock ()
4874+ defer q .mutex .Unlock ()
4875+
4876+ for i := range q .dbcryptKeys {
4877+ key := q .dbcryptKeys [i ]
4878+
4879+ // Is the key already revoked?
4880+ if ! key .ActiveKeyDigest .Valid {
4881+ continue
4882+ }
4883+
4884+ if key .ActiveKeyDigest .String != activeKeyDigest {
4885+ continue
4886+ }
4887+
4888+ // Check for foreign key constraints.
4889+ for _ , ul := range q .userLinks {
4890+ if (ul .OAuthAccessTokenKeyID .Valid && ul .OAuthAccessTokenKeyID .String == activeKeyDigest ) ||
4891+ (ul .OAuthRefreshTokenKeyID .Valid && ul .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4892+ return errForeignKeyConstraint
4893+ }
4894+ }
4895+ for _ , gal := range q .gitAuthLinks {
4896+ if (gal .OAuthAccessTokenKeyID .Valid && gal .OAuthAccessTokenKeyID .String == activeKeyDigest ) ||
4897+ (gal .OAuthRefreshTokenKeyID .Valid && gal .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4898+ return errForeignKeyConstraint
4899+ }
4900+ }
4901+
4902+ // Revoke the key.
4903+ q .dbcryptKeys [i ].RevokedAt = sql.NullTime {Time : dbtime .Now (), Valid : true }
4904+ q .dbcryptKeys [i ].RevokedKeyDigest = sql.NullString {String : key .ActiveKeyDigest .String , Valid : true }
4905+ q .dbcryptKeys [i ].ActiveKeyDigest = sql.NullString {}
4906+ return nil
4907+ }
4908+
4909+ return sql .ErrNoRows
4910+ }
4911+
47964912func (* FakeQuerier ) TryAcquireLock (_ context.Context , _ int64 ) (bool , error ) {
47974913 return false , xerrors .New ("TryAcquireLock must only be called within a transaction" )
47984914}
@@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
48344950 }
48354951 gitAuthLink .UpdatedAt = arg .UpdatedAt
48364952 gitAuthLink .OAuthAccessToken = arg .OAuthAccessToken
4953+ gitAuthLink .OAuthAccessTokenKeyID = arg .OAuthAccessTokenKeyID
48374954 gitAuthLink .OAuthRefreshToken = arg .OAuthRefreshToken
4955+ gitAuthLink .OAuthRefreshTokenKeyID = arg .OAuthRefreshTokenKeyID
48384956 gitAuthLink .OAuthExpiry = arg .OAuthExpiry
48394957 q .gitAuthLinks [index ] = gitAuthLink
48404958
@@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
53065424 for i , link := range q .userLinks {
53075425 if link .UserID == params .UserID && link .LoginType == params .LoginType {
53085426 link .OAuthAccessToken = params .OAuthAccessToken
5427+ link .OAuthAccessTokenKeyID = params .OAuthAccessTokenKeyID
53095428 link .OAuthRefreshToken = params .OAuthRefreshToken
5429+ link .OAuthRefreshTokenKeyID = params .OAuthRefreshTokenKeyID
53105430 link .OAuthExpiry = params .OAuthExpiry
53115431
53125432 q .userLinks [i ] = link
0 commit comments