diff --git a/enterprise/aibridgedserver/aibridgedserver.go b/enterprise/aibridgedserver/aibridgedserver.go index 156f3aa9d05da..cc096cf3446e4 100644 --- a/enterprise/aibridgedserver/aibridgedserver.go +++ b/enterprise/aibridgedserver/aibridgedserver.go @@ -37,12 +37,13 @@ var ( // matching. // TODO: return these errors to the client in a more structured/comparable // way. - ErrInvalidKey = xerrors.New("invalid key") - ErrUnknownKey = xerrors.New("unknown key") - ErrExpired = xerrors.New("expired") - ErrUnknownUser = xerrors.New("unknown user") - ErrDeletedUser = xerrors.New("deleted user") - ErrSystemUser = xerrors.New("system user") + ErrInvalidKey = xerrors.New("invalid key") + ErrUnknownKey = xerrors.New("unknown key") + ErrExpired = xerrors.New("expired") + ErrUnknownUser = xerrors.New("unknown user") + ErrDeletedUser = xerrors.New("deleted user") + ErrInactiveUser = xerrors.New("inactive user") + ErrSystemUser = xerrors.New("system user") ErrNoExternalAuthLinkFound = xerrors.New("no external auth link found") ) @@ -399,10 +400,13 @@ func (s *Server) IsAuthorized(ctx context.Context, in *proto.IsAuthorizedRequest return nil, ErrUnknownUser } - // User is not deleted or a system user. + // User is active, not deleted, and not a system user. if user.Deleted { return nil, ErrDeletedUser } + if user.Status != database.UserStatusActive { + return nil, ErrInactiveUser + } if user.IsSystem { return nil, ErrSystemUser } diff --git a/enterprise/aibridgedserver/aibridgedserver_test.go b/enterprise/aibridgedserver/aibridgedserver_test.go index b871bfb3f8e54..f14a4995198bd 100644 --- a/enterprise/aibridgedserver/aibridgedserver_test.go +++ b/enterprise/aibridgedserver/aibridgedserver_test.go @@ -94,16 +94,36 @@ func TestAuthorization(t *testing.T) { name: "deleted user", expectedErr: aibridgedserver.ErrDeletedUser, mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) { + user.Deleted = true db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil) - db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(database.User{ID: user.ID, Deleted: true}, nil) + db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(user, nil) + }, + }, + { + name: "suspended user", + expectedErr: aibridgedserver.ErrInactiveUser, + mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) { + user.Status = database.UserStatusSuspended + db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil) + db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(user, nil) + }, + }, + { + name: "dormant user", + expectedErr: aibridgedserver.ErrInactiveUser, + mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) { + user.Status = database.UserStatusDormant + db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil) + db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(user, nil) }, }, { name: "system user", expectedErr: aibridgedserver.ErrSystemUser, mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) { + user.IsSystem = true db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil) - db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(database.User{ID: user.ID, IsSystem: true}, nil) + db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(user, nil) }, }, {