diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 03b4aff42fc9e..c3904a427f919 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -44,6 +44,7 @@ import ( "github.com/coder/coder/v2/coderd/wspubsub" "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" "github.com/coder/coder/v2/coderd/x/gitsync" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" @@ -2551,10 +2552,12 @@ func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Req return } - err = api.Database.InTx(func(tx database.Store) error { - locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID) + machine := chatstate.NewChatMachine(api.Database, api.Pubsub, chat.ID, chatstate.Options{}) + err = machine.Update(sysCtx, func(tx *chatstate.Tx) error { + store := tx.Store() + locked, err := store.GetChatByID(sysCtx, chat.ID) if err != nil { - return xerrors.Errorf("lock chat: %w", err) + return xerrors.Errorf("load chat: %w", err) } if !isActiveAgentChat(locked) { return errChatNotActive @@ -2565,26 +2568,30 @@ func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Req if locked.OwnerID != workspace.OwnerID { return errChatDoesNotBelongToWorkspaceOwner } - apiKeyID, err := resolveAgentChatContextAPIKeyID(sysCtx, tx, locked) + apiKeyID, err := resolveAgentChatContextAPIKeyID(sysCtx, store, locked) if err != nil { return err } - if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleUserChatMessageInsertParams( - chat.ID, - apiKeyID, - content, - database.ChatMessageVisibilityBoth, - locked.LastModelConfigID, - chatprompt.CurrentContentVersion, - uuid.Nil, - )); err != nil { - return xerrors.Errorf("insert context message: %w", err) - } - if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil { + sendResult, err := tx.SendMessage(chatstate.SendMessageInput{ + Message: agentChatContextStateMessage( + content, + locked.LastModelConfigID, + locked.OwnerID, + apiKeyID, + ), + BusyBehavior: chatstate.BusyBehaviorInterrupt, + }) + if err != nil { + return err + } + if len(sendResult.InsertedMessages) == 0 { + return nil + } + if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, store, chat.ID); err != nil { return xerrors.Errorf("rebuild injected context cache: %w", err) } return nil - }, nil) + }) if err != nil { if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) { writeAgentChatError(ctx, rw, err) @@ -2596,6 +2603,35 @@ func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Req }) return } + if errors.Is(err, chatstate.ErrMessageQueueFull) { + var queueFull *chatstate.MessageQueueFullError + detail := "" + if errors.As(err, &queueFull) { + detail = fmt.Sprintf("Maximum %d messages can be queued.", queueFull.Max) + } + httpapi.Write(ctx, rw, http.StatusTooManyRequests, codersdk.Response{ + Message: "Message queue is full.", + Detail: detail, + }) + return + } + if errors.Is(err, chatstate.ErrInvalidState) { + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "Chat is in an invalid state.", + }) + return + } + if errors.Is(err, chatstate.ErrTransitionNotAllowed) { + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "Chat is not in a state that accepts new context.", + Detail: err.Error(), + }) + return + } + if errors.Is(err, chatstate.ErrChatNotFound) { + writeAgentChatError(ctx, rw, errChatNotFound) + return + } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to persist context message.", Detail: err.Error(), @@ -2817,6 +2853,23 @@ func resolveAgentChatContextAPIKeyID(ctx context.Context, db database.Store, cha return newest.ID, nil } +func agentChatContextStateMessage( + content pqtype.NullRawMessage, + modelConfigID uuid.UUID, + ownerID uuid.UUID, + apiKeyID string, +) chatstate.Message { + return chatstate.Message{ + Role: database.ChatMessageRoleUser, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: modelConfigID != uuid.Nil}, + CreatedBy: uuid.NullUUID{UUID: ownerID, Valid: ownerID != uuid.Nil}, + ContentVersion: chatprompt.CurrentContentVersion, + APIKeyID: sql.NullString{String: apiKeyID, Valid: apiKeyID != ""}, + } +} + func clearAgentChatContext( ctx context.Context, db database.Store, diff --git a/coderd/workspaceagents_chat_context_test.go b/coderd/workspaceagents_chat_context_test.go index 2067fe3ff4e9d..e5ac467d7fa2c 100644 --- a/coderd/workspaceagents_chat_context_test.go +++ b/coderd/workspaceagents_chat_context_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "net/http" "strings" "testing" @@ -19,8 +20,11 @@ import ( "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/testutil" @@ -70,16 +74,6 @@ func TestAgentChatContext(t *testing.T) { ContextFilePath: "/workspace/AGENTS.md", ContextFileContent: "context from the agent", } - fileAPart := codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeContextFile, - ContextFilePath: "/workspace/file-a.md", - ContextFileContent: "file A context", - } - fileBPart := codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeContextFile, - ContextFilePath: "/workspace/file-b.md", - ContextFileContent: "file B context", - } repoHelperSkillPart := codersdk.ChatMessagePart{ Type: codersdk.ChatMessagePartTypeSkill, SkillName: "repo-helper", @@ -96,14 +90,6 @@ func TestAgentChatContext(t *testing.T) { Type: codersdk.ChatMessagePartTypeContextFile, ContextFilePath: agentInstructionsPart.ContextFilePath, } - cachedFileAPart := codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeContextFile, - ContextFilePath: fileAPart.ContextFilePath, - } - cachedFileBPart := codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeContextFile, - ContextFilePath: fileBPart.ContextFilePath, - } cachedRepoHelperSkillPart := codersdk.ChatMessagePart{ Type: codersdk.ChatMessagePartTypeSkill, SkillName: repoHelperSkillPart.SkillName, @@ -123,14 +109,6 @@ func TestAgentChatContext(t *testing.T) { wantCached: []codersdk.ChatMessagePart{cachedAgentInstructionsPart}, cachedOrdered: true, }, - { - name: "AddSuccessIsAdditive", - steps: []addSuccessStep{{req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{fileAPart}}, wantCount: 1}, {req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{fileBPart}}, wantCount: 1}}, - wantStored: [][]codersdk.ChatMessagePart{{fileAPart}, {fileBPart}}, - storedOrdered: false, - wantCached: []codersdk.ChatMessagePart{cachedFileAPart, cachedFileBPart}, - cachedOrdered: false, - }, { name: "AddSuccessWithSkillOnlyPartsGetsSentinel", steps: []addSuccessStep{{req: agentsdk.AddChatContextRequest{Parts: []codersdk.ChatMessagePart{repoHelperSkillPart}}, wantCount: 1}}, @@ -249,6 +227,184 @@ func TestAgentChatContext(t *testing.T) { require.Equal(t, updatedModel.ID, persistedChat.LastModelConfigID) }) + t.Run("AddSuccessUpdatesChatStateVersionsAndPublishes", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + baseDB, pubsub := dbtestutil.NewDB(t) + client := coderdtest.New(t, &coderdtest.Options{ + Database: baseDB, + Pubsub: pubsub, + }) + user := coderdtest.CreateFirstUser(t, client) + workspace := dbfake.WorkspaceBuild(t, baseDB, database.WorkspaceTable{ + OrganizationID: user.OrganizationID, + OwnerID: user.UserID, + }).WithAgent().Do() + agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(workspace.AgentToken)) + model := coderd.InsertAgentChatTestModelConfig(t, baseDB, user.UserID) + chat := createAgentChatContextChat(t, baseDB, user.OrganizationID, user.UserID, model.ID, workspace.Agents[0].ID, t.Name()) + + updateCh := make(chan []byte, 1) + cancelSub, err := pubsub.Subscribe(coderdpubsub.ChatStateUpdateChannel(chat.ID), func(_ context.Context, msg []byte) { + updateCh <- msg + }) + require.NoError(t, err) + defer cancelSub() + + resp, err := agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/instructions.md", + ContextFileContent: "remember this file", + }}, + }) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + require.Equal(t, 1, resp.Count) + + persisted, err := baseDB.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, chat.SnapshotVersion+1, persisted.SnapshotVersion) + require.Equal(t, persisted.SnapshotVersion, persisted.HistoryVersion) + + messages := requireAgentChatContextMessages(ctx, t, baseDB, chat.ID) + require.Len(t, messages, 1) + require.Equal(t, persisted.SnapshotVersion, messages[0].Revision) + + cached := requireAgentChatContextCachedParts(ctx, t, baseDB, chat.ID) + require.Len(t, cached, 1) + require.Equal(t, "/workspace/instructions.md", cached[0].ContextFilePath) + + select { + case raw := <-updateCh: + var update coderdpubsub.ChatStateUpdateMessage + require.NoError(t, json.Unmarshal(raw, &update)) + require.Equal(t, persisted.SnapshotVersion, update.SnapshotVersion) + require.Equal(t, persisted.HistoryVersion, update.HistoryVersion) + case <-ctx.Done(): + t.Fatal("timed out waiting for chat state update") + } + }) + + func(t *testing.T) { + t.Run("AddInterruptsAndQueuesWhenChatIsRunning", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(t, setup.db, setup.user.OrganizationID, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + chat = setAgentChatContextChatStatus(ctx, t, setup.db, chat.ID, database.ChatStatusRunning) + chat = acquireAgentChatContextChat(ctx, t, setup.db, chat.ID) + apiKeyID := currentAgentChatContextAPIKeyID(t, setup.client) + + resp, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/queued.md", + ContextFileContent: "queued context", + }}, + }) + require.NoError(t, err) + require.Equal(t, chat.ID, resp.ChatID) + require.Equal(t, 1, resp.Count) + + require.Empty(t, requireAgentChatContextMessages(ctx, t, setup.db, chat.ID)) + + queued, err := setup.db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Len(t, queued, 1) + require.Equal(t, setup.user.UserID, queued[0].CreatedBy) + require.True(t, queued[0].ModelConfigID.Valid) + require.Equal(t, model.ID, queued[0].ModelConfigID.UUID) + require.True(t, queued[0].APIKeyID.Valid) + require.Equal(t, apiKeyID, queued[0].APIKeyID.String) + + parts := requireAgentChatContextParts(t, queued[0].Content) + require.Len(t, parts, 1) + require.Equal(t, "/workspace/queued.md", parts[0].ContextFilePath) + require.Equal(t, "queued context", parts[0].ContextFileContent) + require.Equal(t, uuid.NullUUID{UUID: setup.workspace.Agents[0].ID, Valid: true}, parts[0].ContextFileAgentID) + + persisted, err := setup.db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.False(t, persisted.LastInjectedContext.Valid) + require.Equal(t, database.ChatStatusInterrupting, persisted.Status) + require.Equal(t, chat.SnapshotVersion+1, persisted.SnapshotVersion) + require.Equal(t, chat.HistoryVersion, persisted.HistoryVersion) + require.Equal(t, persisted.SnapshotVersion, persisted.QueueVersion) + }) + }(t) + + func(t *testing.T) { + t.Run("AddFailsWhenQueueIsFull", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(t, setup.db, setup.user.OrganizationID, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + chat = setAgentChatContextChatStatus(ctx, t, setup.db, chat.ID, database.ChatStatusRunning) + chat = acquireAgentChatContextChat(ctx, t, setup.db, chat.ID) + apiKeyID := currentAgentChatContextAPIKeyID(t, setup.client) + for i := range int(chatstate.MaxQueueSize) { + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(fmt.Sprintf("queued %d", i)), + }) + require.NoError(t, err) + _, err = setup.db.InsertChatQueuedMessageWithCreator( + dbauthz.AsSystemRestricted(ctx), + database.InsertChatQueuedMessageWithCreatorParams{ + ChatID: chat.ID, + Content: content.RawMessage, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + APIKeyID: sql.NullString{String: apiKeyID, Valid: true}, + CreatedBy: setup.user.UserID, + }, + ) + require.NoError(t, err) + } + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/overflow.md", + ContextFileContent: "overflow context", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusTooManyRequests) + require.Equal(t, "Message queue is full.", sdkErr.Message) + require.Contains(t, sdkErr.Detail, "Maximum") + }) + }(t) + + func(t *testing.T) { + t.Run("AddFailsWhenChatStateIsInvalid", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + setup := newAgentChatContextTestSetup(t) + model := coderd.InsertAgentChatTestModelConfig(t, setup.db, setup.user.UserID) + chat := createAgentChatContextChat(t, setup.db, setup.user.OrganizationID, setup.user.UserID, model.ID, setup.workspace.Agents[0].ID, t.Name()) + _ = setAgentChatContextChatStatus(ctx, t, setup.db, chat.ID, database.ChatStatusPending) + + _, err := setup.agentClient.AddChatContext(ctx, agentsdk.AddChatContextRequest{ + ChatID: chat.ID, + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: "/workspace/invalid.md", + ContextFileContent: "invalid state context", + }}, + }) + sdkErr := requireSDKError(t, err, http.StatusConflict) + require.Equal(t, "Chat is in an invalid state.", sdkErr.Message) + }) + }(t) + t.Run("ClearDeletesSkillMessages", func(t *testing.T) { t.Parallel() @@ -986,6 +1142,45 @@ func newAgentChatContextTestSetup(t *testing.T) agentChatContextTestSetup { } } +func currentAgentChatContextAPIKeyID(t testing.TB, client *codersdk.Client) string { + t.Helper() + + apiKeyID, _, ok := strings.Cut(client.SessionToken(), "-") + require.True(t, ok) + require.NotEmpty(t, apiKeyID) + return apiKeyID +} + +func setAgentChatContextChatStatus( + ctx context.Context, + t testing.TB, + db database.Store, + chatID uuid.UUID, + status database.ChatStatus, +) database.Chat { + t.Helper() + + chat, err := db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chatID, + Status: status, + }) + require.NoError(t, err) + return chat +} + +func acquireAgentChatContextChat(ctx context.Context, t testing.TB, db database.Store, chatID uuid.UUID) database.Chat { + t.Helper() + + machine := chatstate.NewChatMachine(db, dbpubsub.NewInMemory(), chatID, chatstate.Options{}) + require.NoError(t, machine.Update(dbauthz.AsSystemRestricted(ctx), func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: uuid.New(), RunnerID: uuid.New()}) + return err + })) + chat, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chatID) + require.NoError(t, err) + return chat +} + func createAgentChatContextChat( t testing.TB, db database.Store,