diff --git a/coderd/x/chatd/auto_archive.go b/coderd/x/chatd/auto_archive.go index e045447632d60..bbc791a0b77e8 100644 --- a/coderd/x/chatd/auto_archive.go +++ b/coderd/x/chatd/auto_archive.go @@ -163,8 +163,8 @@ func isExpectedAutoArchiveError(err error) bool { } func (w *chatWorker) publishArchiveWatchEvents(familyChats []database.Chat) { - if w.server != nil { - w.server.publishChatPubsubEvents(familyChats, codersdk.ChatWatchEventKindDeleted) + if w.opts.Server != nil { + w.opts.Server.publishChatPubsubEvents(familyChats, codersdk.ChatWatchEventKindDeleted) return } for _, chat := range familyChats { @@ -178,10 +178,10 @@ func (w *chatWorker) publishArchiveWatchEvents(familyChats []database.Chat) { } func (w *chatWorker) scheduleArchiveDebugCleanup(ctx context.Context, familyChats []database.Chat) { - if w.server == nil || len(familyChats) == 0 { + if w.opts.Server == nil || len(familyChats) == 0 { return } - w.server.scheduleArchiveDebugCleanup(ctx, familyChats) + w.opts.Server.scheduleArchiveDebugCleanup(ctx, familyChats) } func (p *Server) scheduleArchiveDebugCleanup(ctx context.Context, familyChats []database.Chat) { diff --git a/coderd/x/chatd/auto_archive_internal_test.go b/coderd/x/chatd/auto_archive_internal_test.go index 8c2e68b924400..8e5041cfb7d54 100644 --- a/coderd/x/chatd/auto_archive_internal_test.go +++ b/coderd/x/chatd/auto_archive_internal_test.go @@ -293,7 +293,7 @@ func (f *workerTestFixture) newArchiveWorkerWithOptions(t *testing.T, opts chatW if opts.NotificationsEnqueuer == nil { opts.NotificationsEnqueuer = notificationstest.NewFakeEnqueuer() } - worker, err := newChatWorker(nil, opts) + worker, err := newChatWorker(opts) require.NoError(t, err) return worker } diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 0e18f600b60cf..ab993850069f6 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -3428,8 +3428,9 @@ func New(ps pubsub.Pubsub, cfg Config) *Server { p.streamPartsDialer = streamPartsDialerForServer(workerID, localStreamPartsDialer, cfg.StreamPartsDialer) p.streamSyncPoller = newStreamSyncPoller(ctx, cfg.Database, clk, cfg.Logger.Named("chatstream")) p.streamSyncPoller.Start() - chatWorker, err := newChatWorker(p, chatWorkerOptions{ + chatWorker, err := newChatWorker(chatWorkerOptions{ WorkerID: workerID, + Server: p, Store: cfg.Database, Pubsub: ps, Logger: cfg.Logger.Named("chatworker"), diff --git a/coderd/x/chatd/generation.go b/coderd/x/chatd/generation.go index c3afe31faa179..a6d2737168593 100644 --- a/coderd/x/chatd/generation.go +++ b/coderd/x/chatd/generation.go @@ -83,6 +83,38 @@ type generationCompaction struct { Options chatloop.GenerateCompactionOptions } +type generationTaskDeps struct { + prepareGeneration func(context.Context, generationPrepareInput) (generationPrepared, error) + buildWorkspaceContext func(context.Context, workspaceContextBuildInput) (workspaceContextBuildResult, error) + afterOutcome func(context.Context, generationOutcome) error + metrics *chatloop.Metrics +} + +func (d generationTaskDeps) withDefaults() (generationTaskDeps, error) { + if d.prepareGeneration == nil { + return generationTaskDeps{}, xerrors.New("chatworker: generation prepare callback is required") + } + if d.buildWorkspaceContext == nil { + return generationTaskDeps{}, xerrors.New("chatworker: workspace context callback is required") + } + if d.afterOutcome == nil { + d.afterOutcome = func(context.Context, generationOutcome) error { return nil } + } + if d.metrics == nil { + d.metrics = chatloop.NopMetrics() + } + return d, nil +} + +func (server *Server) generationTaskDeps() generationTaskDeps { + return generationTaskDeps{ + prepareGeneration: server.prepareGeneration, + buildWorkspaceContext: server.buildWorkspaceContext, + afterOutcome: server.afterGenerationOutcome, + metrics: server.metrics, + } +} + type generationDebug struct { Enabled bool Service *chatdebug.Service @@ -345,23 +377,20 @@ func hasExclusiveToolCall(toolCalls []fantasy.ToolCallContent, exclusiveToolName } func (s *taskStarter) StartGeneration(ctx context.Context, input chatWorkerTaskStartInput) error { - if s.server == nil { - return xerrors.New("chatworker: server is required") - } machine := chatstate.NewChatMachine(s.opts.Store, s.opts.Pubsub, input.ChatID) chainModeDisabled := false for { - locked, messages, err := loadGenerationState(ctx, machine, input) + chat, messages, err := loadGenerationState(ctx, machine, input) if err != nil { return err } prepareInput := generationPrepareInput{ - Chat: locked, + Chat: chat, Messages: messages, ChainModeDisabled: chainModeDisabled, } prepared, err := retryGenerationPhase(ctx, s, "prepare", func() (generationPrepared, error) { - return s.server.prepareGeneration(ctx, prepareInput) + return s.generation.prepareGeneration(ctx, prepareInput) }) if err != nil { if errors.Is(err, errTaskExpectedExit) || errors.Is(err, errTaskRetryable) { @@ -391,7 +420,7 @@ func (s *taskStarter) StartGeneration(ctx context.Context, input chatWorkerTaskS return xerrors.Errorf("decide generation: %w", err) } if errors.Is(err, errCompactionStillOverLimit) && prepared.Compaction != nil { - s.server.metrics.RecordCompaction( + s.generation.metrics.RecordCompaction( compactionProvider(prepared.Compaction.Options), compactionModel(prepared.Compaction.Options), false, @@ -469,35 +498,32 @@ func loadGenerationState( ctx context.Context, machine *chatstate.ChatMachine, input chatWorkerTaskStartInput, -) (database.Chat, []database.ChatMessage, error) { - var locked database.Chat - var messages []database.ChatMessage - err := machine.ReadLock(ctx, func(store database.Store) error { - chat, err := store.GetChatByID(ctx, input.ChatID) - if errors.Is(err, sql.ErrNoRows) { - return errTaskExpectedExit - } - if err != nil { - return xerrors.Errorf("load locked chat: %w", err) +) (chat database.Chat, messages []database.ChatMessage, err error) { + err = machine.ReadLock(ctx, func(store database.Store) error { + var loadErr error + chat, loadErr = store.GetChatByID(ctx, input.ChatID) + if loadErr != nil { + if errors.Is(loadErr, sql.ErrNoRows) { + return errTaskExpectedExit + } + return xerrors.Errorf("load locked chat: %w", loadErr) } if err := verifyTaskFence(chat, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { return err } - loaded, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + messages, loadErr = store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ ChatID: input.ChatID, AfterID: 0, }) - if err != nil { - return xerrors.Errorf("load chat messages: %w", err) + if loadErr != nil { + return xerrors.Errorf("load chat messages: %w", loadErr) } - locked = chat - messages = loaded return nil }) if err != nil { return database.Chat{}, nil, normalizeTaskInfrastructureError(err, "lock chat for generation") } - return locked, messages, nil + return chat, messages, nil } func (*taskStarter) recordGenerationRetry( @@ -641,11 +667,12 @@ func (s *taskStarter) generateAssistant( input chatWorkerTaskStartInput, prepared generationPrepared, ) error { - attempt, _, publish, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + episode, err := s.beginGenerationAttempt(ctx, machine, input) if err != nil { return err } - defer closeEpisode() + defer episode.Close() + publish := episode.Publish runCtx := input.DebugTurn.Ensure(ctx, prepared.Chat, prepared.Debug) outcome, err := chatloop.GenerateAssistant(runCtx, chatloop.GenerateAssistantOptions{ Model: prepared.Model, @@ -660,13 +687,13 @@ func (s *taskStarter) generateAssistant( PublishMessagePart: publish, Logger: s.opts.Logger, Clock: s.opts.Clock, - Metrics: s.server.metrics, + Metrics: s.generation.metrics, }) if err != nil { return err } if len(outcome.Step.Content) == 0 { - return s.finishGenerationTurn(ctx, machine, input, attempt, generationDecision{kind: generationActionFinishTurn, finishReason: generationFinishReasonComplete}, generationAttemptRequired) + return s.finishGenerationTurn(ctx, machine, input, episode.Attempt, generationDecision{kind: generationActionFinishTurn, finishReason: generationFinishReasonComplete}, generationAttemptRequired) } messages, err := buildCommitStepMessages(buildCommitStepMessagesInput{ modelConfigID: prepared.ModelConfigID, @@ -677,9 +704,9 @@ func (s *taskStarter) generateAssistant( contentVersion: chatprompt.CurrentContentVersion, }) if err != nil { - return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + return s.finishGenerationError(ctx, machine, input, episode.Attempt, err, generationAttemptRequired) } - return s.commitGenerationStep(ctx, machine, input, attempt, generationActionGenerateAssistant, messages) + return s.commitGenerationStep(ctx, machine, input, episode.Attempt, generationActionGenerateAssistant, messages) } func (s *taskStarter) executeLocalTools( @@ -689,11 +716,12 @@ func (s *taskStarter) executeLocalTools( prepared generationPrepared, decision generationDecision, ) error { - attempt, _, publish, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + episode, err := s.beginGenerationAttempt(ctx, machine, input) if err != nil { return err } - defer closeEpisode() + defer episode.Close() + publish := episode.Publish provider := "" modelName := "" if prepared.Model != nil { @@ -716,7 +744,7 @@ func (s *taskStarter) executeLocalTools( ModelName: modelName, PublishMessagePart: publish, Logger: s.opts.Logger, - Metrics: s.server.metrics, + Metrics: s.generation.metrics, Clock: s.opts.Clock, }) if err != nil { @@ -731,9 +759,9 @@ func (s *taskStarter) executeLocalTools( contentVersion: chatprompt.CurrentContentVersion, }) if err != nil { - return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + return s.finishGenerationError(ctx, machine, input, episode.Attempt, err, generationAttemptRequired) } - return s.commitGenerationStep(ctx, machine, input, attempt, generationActionExecuteLocalTools, messages) + return s.commitGenerationStep(ctx, machine, input, episode.Attempt, generationActionExecuteLocalTools, messages) } func (s *taskStarter) generateCompaction( @@ -742,25 +770,26 @@ func (s *taskStarter) generateCompaction( input chatWorkerTaskStartInput, prepared generationPrepared, ) error { - attempt, _, publish, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + episode, err := s.beginGenerationAttempt(ctx, machine, input) if err != nil { return err } - defer closeEpisode() + defer episode.Close() + publish := episode.Publish if prepared.Compaction == nil { - return s.finishGenerationError(ctx, machine, input, attempt, xerrors.New("compaction action missing options"), generationAttemptRequired) + return s.finishGenerationError(ctx, machine, input, episode.Attempt, xerrors.New("compaction action missing options"), generationAttemptRequired) } compactionOpts := prepared.Compaction.Options compactionOpts.PublishMessagePart = publish outcome, err := chatloop.GenerateCompaction(ctx, compactionOpts) if err != nil { - s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) + s.generation.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) return err } if strings.TrimSpace(outcome.SystemSummary) == "" || strings.TrimSpace(outcome.SummaryReport) == "" { err := xerrors.New("compaction produced no summary") - s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) - return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + s.generation.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) + return s.finishGenerationError(ctx, machine, input, episode.Attempt, err, generationAttemptRequired) } messages, err := buildCompactionMessages(buildCompactionMessagesInput{ modelConfigID: prepared.ModelConfigID, @@ -771,14 +800,14 @@ func (s *taskStarter) generateCompaction( contentVersion: chatprompt.CurrentContentVersion, }) if err != nil { - s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) - return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + s.generation.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) + return s.finishGenerationError(ctx, machine, input, episode.Attempt, err, generationAttemptRequired) } - err = s.commitGenerationStep(ctx, machine, input, attempt, generationActionCompact, stepMessagesForCommit{ + err = s.commitGenerationStep(ctx, machine, input, episode.Attempt, generationActionCompact, stepMessagesForCommit{ Messages: messages.Messages, VisibleIndexes: visibleMessageIndexes(messages.Messages), }) - s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), err == nil, err) + s.generation.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), err == nil, err) return err } @@ -807,11 +836,8 @@ func (s *taskStarter) persistWorkspaceContext( ctx context.Context, machine *chatstate.ChatMachine, input chatWorkerTaskStartInput, - locked database.Chat, + chat database.Chat, ) error { - if s.server == nil { - return errTaskExpectedExit - } messages, err := s.opts.Store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ ChatID: input.ChatID, AfterID: 0, @@ -819,14 +845,14 @@ func (s *taskStarter) persistWorkspaceContext( if err != nil { return taskRetryableError{err: xerrors.Errorf("load chat messages for workspace context: %w", err)} } - attempt, _, _, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + episode, err := s.beginGenerationAttempt(ctx, machine, input) if err != nil { return err } - defer closeEpisode() + defer episode.Close() modelOpts := modelBuildOptionsFromMessages(messages) - result, err := s.server.buildWorkspaceContext(ctx, workspaceContextBuildInput{ - Chat: locked, + result, err := s.generation.buildWorkspaceContext(ctx, workspaceContextBuildInput{ + Chat: chat, Messages: messages, ActiveAPIKeyID: modelOpts.ActiveAPIKeyID, }) @@ -839,28 +865,28 @@ func (s *taskStarter) persistWorkspaceContext( } return err } - return s.commitGenerationStep(ctx, machine, input, attempt, generationActionPersistWorkspaceContext, stepMessagesForCommit{ + return s.commitGenerationStep(ctx, machine, input, episode.Attempt, generationActionPersistWorkspaceContext, stepMessagesForCommit{ Messages: result.Messages, VisibleIndexes: visibleMessageIndexes(result.Messages), }) } +type generationAttemptEpisode struct { + Attempt int64 + Key messagepartbuffer.Key + Publish func(codersdk.ChatMessageRole, codersdk.ChatMessagePart) + Close func() +} + func (s *taskStarter) beginGenerationAttempt( ctx context.Context, machine *chatstate.ChatMachine, input chatWorkerTaskStartInput, -) (int64, messagepartbuffer.Key, func(codersdk.ChatMessageRole, codersdk.ChatMessagePart), func(), error) { +) (generationAttemptEpisode, error) { var attempt int64 var committed database.Chat err := machine.Update(ctx, func(tx *chatstate.Tx, store database.Store) error { - locked, err := store.GetChatByID(ctx, input.ChatID) - if errors.Is(err, sql.ErrNoRows) { - return errTaskExpectedExit - } - if err != nil { - return xerrors.Errorf("load chat: %w", err) - } - if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + if _, err := getChatForTask(ctx, store, input, database.ChatStatusRunning); err != nil { return err } result, err := tx.RecordGenerationAttempt(chatstate.RecordGenerationAttemptInput{}) @@ -868,14 +894,15 @@ func (s *taskStarter) beginGenerationAttempt( return err } attempt = result.GenerationAttempt - committed, err = store.GetChatByID(ctx, input.ChatID) + chat, err := store.GetChatByID(ctx, input.ChatID) if err != nil { return xerrors.Errorf("load committed chat: %w", err) } + committed = chat return nil }) if err != nil { - return 0, messagepartbuffer.Key{}, nil, nil, normalizeTaskTransitionError(err, "record generation attempt") + return generationAttemptEpisode{}, normalizeTaskTransitionError(err, "record generation attempt") } key := messagepartbuffer.Key{ ChatID: input.ChatID, @@ -883,15 +910,19 @@ func (s *taskStarter) beginGenerationAttempt( GenerationAttempt: attempt, } if err := s.opts.MessagePartBuffer.CreateEpisode(key); err != nil && ctx.Err() == nil { - return 0, messagepartbuffer.Key{}, nil, nil, taskRetryableError{err: xerrors.Errorf("create message part episode: %w", err)} + return generationAttemptEpisode{}, taskRetryableError{err: xerrors.Errorf("create message part episode: %w", err)} } - publish := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { + episode := generationAttemptEpisode{ + Attempt: attempt, + Key: key, + } + episode.Publish = func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { _ = s.opts.MessagePartBuffer.AddPart(key, role, part) } - closeEpisode := func() { + episode.Close = func() { _ = s.opts.MessagePartBuffer.CloseEpisode(key) } - return attempt, key, publish, closeEpisode, nil + return episode, nil } func (s *taskStarter) commitGenerationStep( @@ -905,42 +936,39 @@ func (s *taskStarter) commitGenerationStep( if len(messages.Messages) == 0 { return s.finishGenerationTurn(ctx, machine, input, attempt, generationDecision{kind: generationActionFinishTurn, finishReason: generationFinishReasonComplete}, generationAttemptRequired) } - var committed database.Chat - insertedMessages := []runnerActionMessage{} + var outcome generationOutcome err := machine.Update(ctx, func(tx *chatstate.Tx, store database.Store) error { - locked, err := store.GetChatByID(ctx, input.ChatID) - if errors.Is(err, sql.ErrNoRows) { - return errTaskExpectedExit - } - if err != nil { - return xerrors.Errorf("load chat: %w", err) - } - if err := verifyGenerationFence(locked, input, attempt); err != nil { + if _, err := getChatForGenerationAttempt(ctx, store, input, attempt); err != nil { return err } commitResult, err := tx.CommitStep(chatstate.CommitStepInput{Messages: messages.Messages}) if err != nil { return err } - insertedMessages = make([]runnerActionMessage, 0, len(commitResult.InsertedMessages)) - for _, msg := range commitResult.InsertedMessages { - insertedMessages = append(insertedMessages, runnerActionMessage{ID: msg.ID, Role: codersdk.ChatMessageRole(msg.Role)}) - } - committed, err = store.GetChatByID(ctx, input.ChatID) + chat, err := store.GetChatByID(ctx, input.ChatID) if err != nil { return xerrors.Errorf("load committed chat: %w", err) } + outcome = generationOutcome{ + Chat: chat, + Kind: runnerActionKind(kind), + InsertedMessages: runnerActionMessages(commitResult.InsertedMessages), + } return nil }) if err != nil { return normalizeTaskTransitionError(err, "commit generation step") } - s.routeStateHint(ctx, stateUpdateFromChat(committed)) - return s.afterGenerationOutcome(ctx, generationOutcome{ - Chat: committed, - Kind: runnerActionKind(kind), - InsertedMessages: insertedMessages, - }) + s.routeStateHint(ctx, stateUpdateFromChat(outcome.Chat)) + return s.afterGenerationOutcome(ctx, outcome) +} + +func runnerActionMessages(messages []database.ChatMessage) []runnerActionMessage { + out := make([]runnerActionMessage, 0, len(messages)) + for _, msg := range messages { + out = append(out, runnerActionMessage{ID: msg.ID, Role: codersdk.ChatMessageRole(msg.Role)}) + } + return out } func (s *taskStarter) enterRequiresAction( @@ -950,23 +978,17 @@ func (s *taskStarter) enterRequiresAction( ) error { var committed database.Chat err := machine.Update(ctx, func(tx *chatstate.Tx, store database.Store) error { - locked, err := store.GetChatByID(ctx, input.ChatID) - if errors.Is(err, sql.ErrNoRows) { - return errTaskExpectedExit - } - if err != nil { - return xerrors.Errorf("load chat: %w", err) - } - if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + if _, err := getChatForTask(ctx, store, input, database.ChatStatusRunning); err != nil { return err } if _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}); err != nil { return err } - committed, err = store.GetChatByID(ctx, input.ChatID) + chat, err := store.GetChatByID(ctx, input.ChatID) if err != nil { return xerrors.Errorf("load committed chat: %w", err) } + committed = chat return nil }) if err != nil { @@ -999,18 +1021,7 @@ func (s *taskStarter) finishGenerationTurn( ) error { var committed database.Chat err := machine.Update(ctx, func(tx *chatstate.Tx, store database.Store) error { - locked, err := store.GetChatByID(ctx, input.ChatID) - if errors.Is(err, sql.ErrNoRows) { - return errTaskExpectedExit - } - if err != nil { - return xerrors.Errorf("load chat: %w", err) - } - if attemptFence == generationAttemptRequired { - if err := verifyGenerationFence(locked, input, attempt); err != nil { - return err - } - } else if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + if err := verifyGenerationTask(ctx, store, input, attempt, attemptFence); err != nil { return err } finishResult, err := tx.FinishTurn(chatstate.FinishTurnInput{}) @@ -1067,44 +1078,37 @@ func (s *taskStarter) finishGenerationError( slog.Error(cause), ) lastError, message := generationLastError(cause) - var committed database.Chat + outcome := generationOutcome{ + Kind: runnerActionKindFinishError, + WatchEventKind: codersdk.ChatWatchEventKindStatusChange, + LastError: message, + } err := machine.Update(ctx, func(tx *chatstate.Tx, store database.Store) error { - locked, err := store.GetChatByID(ctx, input.ChatID) - if errors.Is(err, sql.ErrNoRows) { - return errTaskExpectedExit - } - if err != nil { - return xerrors.Errorf("load chat: %w", err) - } - if attemptFence == generationAttemptRequired { - if err := verifyGenerationFence(locked, input, attempt); err != nil { - return err - } - } else if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + if err := verifyGenerationTask(ctx, store, input, attempt, attemptFence); err != nil { return err } if _, err := tx.FinishError(chatstate.FinishErrorInput{LastError: lastError}); err != nil { return err } - committed, err = store.GetChatByID(ctx, input.ChatID) + chat, err := store.GetChatByID(ctx, input.ChatID) if err != nil { return xerrors.Errorf("load committed chat: %w", err) } + outcome.Chat = chat return nil }) if err != nil { - return normalizeTaskTransitionError(err, "finish generation error") + current, ok := s.committedStateAfterUpdateError(ctx, outcome.Chat) + if !ok { + return normalizeTaskTransitionError(err, "finish generation error") + } + outcome.Chat = current } input.DebugTurn.RecordOutcome(chatdebug.StatusError) - if err := s.publishWatchAndRoute(ctx, committed, codersdk.ChatWatchEventKindStatusChange); err != nil { + if err := s.publishWatchAndRoute(ctx, outcome.Chat, outcome.WatchEventKind); err != nil { return err } - return s.afterGenerationOutcome(ctx, generationOutcome{ - Chat: committed, - Kind: runnerActionKindFinishError, - WatchEventKind: codersdk.ChatWatchEventKindStatusChange, - LastError: message, - }) + return s.runAfterGenerationOutcome(ctx, outcome) } func generationLastError(err error) (pqtype.NullRawMessage, string) { @@ -1124,23 +1128,64 @@ func generationLastError(err error) (pqtype.NullRawMessage, string) { } func (s *taskStarter) afterGenerationOutcome(ctx context.Context, outcome generationOutcome) error { - if s.server == nil { - return nil - } - if err := s.server.afterGenerationOutcome(ctx, outcome); err != nil { + return s.runAfterGenerationOutcome(ctx, outcome) +} + +func (s *taskStarter) runAfterGenerationOutcome(ctx context.Context, outcome generationOutcome) error { + if err := s.generation.afterOutcome(ctx, outcome); err != nil { return taskRetryableError{err: xerrors.Errorf("generation post-outcome side effects: %w", err)} } return nil } -func verifyGenerationFence(chat database.Chat, input chatWorkerTaskStartInput, attempt int64) error { - if err := verifyTaskFence(chat, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { - return err +func getChatForTask( + ctx context.Context, + store database.Store, + input chatWorkerTaskStartInput, + status database.ChatStatus, +) (database.Chat, error) { + chat, err := store.GetChatByID(ctx, input.ChatID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return database.Chat{}, errTaskExpectedExit + } + return database.Chat{}, xerrors.Errorf("load chat: %w", err) + } + if err := verifyTaskFence(chat, input, status, taskFenceOptions{requireHistory: true}); err != nil { + return database.Chat{}, err + } + return chat, nil +} + +func getChatForGenerationAttempt( + ctx context.Context, + store database.Store, + input chatWorkerTaskStartInput, + attempt int64, +) (database.Chat, error) { + chat, err := getChatForTask(ctx, store, input, database.ChatStatusRunning) + if err != nil { + return database.Chat{}, err } if chat.GenerationAttempt != attempt { - return errTaskExpectedExit + return database.Chat{}, errTaskExpectedExit } - return nil + return chat, nil +} + +func verifyGenerationTask( + ctx context.Context, + store database.Store, + input chatWorkerTaskStartInput, + attempt int64, + attemptFence generationAttemptFence, +) error { + if attemptFence == generationAttemptRequired { + _, err := getChatForGenerationAttempt(ctx, store, input, attempt) + return err + } + _, err := getChatForTask(ctx, store, input, database.ChatStatusRunning) + return err } func stepDataFromPersisted(step chatloop.PersistedStep) stepData { diff --git a/coderd/x/chatd/helpers_test.go b/coderd/x/chatd/helpers_test.go index 352392f26c48c..f7f7b5321ba86 100644 --- a/coderd/x/chatd/helpers_test.go +++ b/coderd/x/chatd/helpers_test.go @@ -271,7 +271,7 @@ func testOptions(t *testing.T, f *workerTestFixture, starter chatWorkerTaskStart func startWorker(t *testing.T, opts chatWorkerOptions) *chatWorker { t.Helper() - worker, err := newChatWorker(nil, opts) + worker, err := newChatWorker(opts) require.NoError(t, err) require.NoError(t, worker.Start(context.Background())) t.Cleanup(func() { require.NoError(t, worker.Close()) }) diff --git a/coderd/x/chatd/options.go b/coderd/x/chatd/options.go index ff3dbdd3d9a30..cac8e8f3109d1 100644 --- a/coderd/x/chatd/options.go +++ b/coderd/x/chatd/options.go @@ -67,12 +67,14 @@ type chatWorkerTaskStartInput struct { // chatWorkerOptions configures a chatWorker. type chatWorkerOptions struct { WorkerID uuid.UUID + Server *Server Store database.Store Pubsub chatWorkerPubsub Logger slog.Logger Clock quartz.Clock TaskStarter chatWorkerTaskStarter + Generation generationTaskDeps MessagePartBuffer *messagepartbuffer.Buffer NotificationsEnqueuer notifications.Enqueuer @@ -101,8 +103,17 @@ func (o chatWorkerOptions) withDefaults() (chatWorkerOptions, error) { if o.Pubsub == nil { return chatWorkerOptions{}, xerrors.New("chatworker: pubsub is required") } - if o.TaskStarter == nil && o.MessagePartBuffer == nil { - return chatWorkerOptions{}, xerrors.New("chatworker: task starter or message part buffer is required") + if o.TaskStarter == nil { + if o.MessagePartBuffer == nil { + return chatWorkerOptions{}, xerrors.New("chatworker: task starter or message part buffer is required") + } + if o.Server == nil { + withGeneration, err := o.Generation.withDefaults() + if err != nil { + return chatWorkerOptions{}, err + } + o.Generation = withGeneration + } } if o.WorkerID == uuid.Nil { return chatWorkerOptions{}, xerrors.New("chatworker: worker ID is required") diff --git a/coderd/x/chatd/runner_manager.go b/coderd/x/chatd/runner_manager.go index dc9737c8d6d43..56ec391956f4d 100644 --- a/coderd/x/chatd/runner_manager.go +++ b/coderd/x/chatd/runner_manager.go @@ -84,9 +84,8 @@ func (r *runnerRecord) startCleanup() { } type runnerManager struct { - server *Server - opts chatWorkerOptions - ctx context.Context + opts chatWorkerOptions + ctx context.Context closed bool spawnMu sync.Mutex @@ -102,9 +101,8 @@ type runnerManager struct { wg sync.WaitGroup } -func newRunnerManager(ctx context.Context, server *Server, opts chatWorkerOptions) *runnerManager { +func newRunnerManager(ctx context.Context, opts chatWorkerOptions) *runnerManager { return &runnerManager{ - server: server, opts: opts, ctx: ctx, spawnCh: make(chan spawnRunnerRequest, opts.RunnerManagerChannelSize), diff --git a/coderd/x/chatd/tasks.go b/coderd/x/chatd/tasks.go index f370ea852f7e9..7875cd9d7d080 100644 --- a/coderd/x/chatd/tasks.go +++ b/coderd/x/chatd/tasks.go @@ -167,16 +167,16 @@ type interruptionOutcome struct { } type taskStarter struct { - server *Server opts chatWorkerOptions + generation generationTaskDeps routeStateHint func(context.Context, runnerStateUpdate) requestCleanup func(context.Context, runnerKey) afterInterruptionOutcome func(context.Context, interruptionOutcome) error } func newTaskStarter( - server *Server, opts chatWorkerOptions, + generation generationTaskDeps, routeStateHint func(context.Context, runnerStateUpdate), requestCleanup func(context.Context, runnerKey), ) (*taskStarter, error) { @@ -201,6 +201,10 @@ func newTaskStarter( if opts.TaskRetryMaxBackoff < opts.TaskRetryInitialBackoff { opts.TaskRetryMaxBackoff = opts.TaskRetryInitialBackoff } + withGeneration, err := generation.withDefaults() + if err != nil { + return nil, err + } if routeStateHint == nil { return nil, xerrors.New("chatworker: route state hint callback is required") } @@ -208,8 +212,8 @@ func newTaskStarter( return nil, xerrors.New("chatworker: cleanup callback is required") } return &taskStarter{ - server: server, opts: opts, + generation: withGeneration, routeStateHint: routeStateHint, requestCleanup: requestCleanup, }, nil @@ -325,14 +329,10 @@ func (s *taskStarter) StartInterrupt(ctx context.Context, input chatWorkerTaskSt } func (s *taskStarter) runAfterInterruptionOutcome(ctx context.Context, outcome interruptionOutcome) error { - afterOutcome := s.afterInterruptionOutcome - if afterOutcome == nil && s.server != nil { - afterOutcome = s.server.afterInterruptionOutcome - } - if afterOutcome == nil { + if s.afterInterruptionOutcome == nil { return nil } - if err := afterOutcome(ctx, outcome); err != nil { + if err := s.afterInterruptionOutcome(ctx, outcome); err != nil { return taskRetryableError{err: xerrors.Errorf("interruption post-outcome side effects: %w", err)} } return nil diff --git a/coderd/x/chatd/tasks_test.go b/coderd/x/chatd/tasks_test.go index 8fcc588c35a6f..82dd528de08b3 100644 --- a/coderd/x/chatd/tasks_test.go +++ b/coderd/x/chatd/tasks_test.go @@ -578,7 +578,7 @@ func TestGenerationTask_RecordRetryState(t *testing.T) { recorder := newTaskSideEffectRecorder() starter := newTestTaskStarter(t, f, recorder) - attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt( + episode, err := starter.beginGenerationAttempt( testutil.Context(t, testutil.WaitLong), chatstate.NewChatMachine(f.db, f.pubsub, chat.ID), chatWorkerTaskStartInput{ @@ -590,8 +590,8 @@ func TestGenerationTask_RecordRetryState(t *testing.T) { }, ) require.NoError(t, err) - closeEpisode() - require.Equal(t, int64(1), attempt) + episode.Close() + require.Equal(t, int64(1), episode.Attempt) before, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) require.NoError(t, err) require.False(t, before.RetryState.Valid) @@ -650,7 +650,7 @@ func TestGenerationTask_RecordRetryStateUsesDurableGenerationAttempt(t *testing. machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID) for range 3 { - attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt( + episode, err := starter.beginGenerationAttempt( testutil.Context(t, testutil.WaitLong), machine, chatWorkerTaskStartInput{ @@ -662,8 +662,8 @@ func TestGenerationTask_RecordRetryStateUsesDurableGenerationAttempt(t *testing. }, ) require.NoError(t, err) - closeEpisode() - require.Positive(t, attempt) + episode.Close() + require.Positive(t, episode.Attempt) } decision, err := starter.recordGenerationRetry( @@ -714,10 +714,10 @@ func TestGenerationTask_RecordRetryStateClearedByNextAttempt(t *testing.T) { Status: database.ChatStatusRunning, } - attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input) + episode, err := starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input) require.NoError(t, err) - closeEpisode() - require.Equal(t, int64(1), attempt) + episode.Close() + require.Equal(t, int64(1), episode.Attempt) _, err = starter.recordGenerationRetry( testutil.Context(t, testutil.WaitLong), machine, @@ -734,10 +734,10 @@ func TestGenerationTask_RecordRetryStateClearedByNextAttempt(t *testing.T) { require.NoError(t, err) require.True(t, withRetry.RetryState.Valid) - attempt, _, _, closeEpisode, err = starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input) + episode, err = starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input) require.NoError(t, err) - closeEpisode() - require.Equal(t, int64(2), attempt) + episode.Close() + require.Equal(t, int64(2), episode.Attempt) after, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) require.NoError(t, err) require.False(t, after.RetryState.Valid) @@ -755,7 +755,7 @@ func TestGenerationTask_RecordRetryStateStaleFenceExits(t *testing.T) { acquired := f.acquireChat(t, chat.ID, workerID, runnerID) starter := newTestTaskStarter(t, f, newTaskSideEffectRecorder()) machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID) - attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt( + episode, err := starter.beginGenerationAttempt( testutil.Context(t, testutil.WaitLong), machine, chatWorkerTaskStartInput{ @@ -767,8 +767,8 @@ func TestGenerationTask_RecordRetryStateStaleFenceExits(t *testing.T) { }, ) require.NoError(t, err) - closeEpisode() - require.Equal(t, int64(1), attempt) + episode.Close() + require.Equal(t, int64(1), episode.Attempt) otherWorkerID := uuid.New() otherRunnerID := uuid.New() @@ -1124,8 +1124,9 @@ func startRealTaskWorker(t *testing.T, f *taskTestFixture) *chatWorker { t.Helper() buffer := messagepartbuffer.New(messagepartbuffer.Options{}) t.Cleanup(buffer.Close) - worker, err := newChatWorker(nil, chatWorkerOptions{ + worker, err := newChatWorker(chatWorkerOptions{ WorkerID: uuid.New(), + Generation: testGenerationTaskDeps(), Store: f.db, Pubsub: f.pubsub, Logger: slog.Make(), @@ -1242,11 +1243,22 @@ func (r *taskSideEffectRecorder) requireInterruptionOutcome(t *testing.T, chatID t.Fatalf("missing interruption outcome chat_id=%s status=%s outcomes=%v", chatID, status, r.interrupts) } +func testGenerationTaskDeps() generationTaskDeps { + return generationTaskDeps{ + prepareGeneration: func(context.Context, generationPrepareInput) (generationPrepared, error) { + return generationPrepared{}, errTaskExpectedExit + }, + buildWorkspaceContext: func(context.Context, workspaceContextBuildInput) (workspaceContextBuildResult, error) { + return workspaceContextBuildResult{}, xerrors.New("unexpected workspace context call") + }, + } +} + func newTestTaskStarter(t *testing.T, f *taskTestFixture, recorder *taskSideEffectRecorder) *taskStarter { t.Helper() buffer := messagepartbuffer.New(messagepartbuffer.Options{}) t.Cleanup(buffer.Close) - starter, err := newTaskStarter(nil, chatWorkerOptions{ + starter, err := newTaskStarter(chatWorkerOptions{ Store: f.db, Pubsub: f.pubsub, Logger: slog.Make(), @@ -1254,7 +1266,7 @@ func newTestTaskStarter(t *testing.T, f *taskTestFixture, recorder *taskSideEffe MessagePartBuffer: buffer, TaskRetryInitialBackoff: time.Millisecond, TaskRetryMaxBackoff: time.Millisecond, - }, recorder.routeStateHint, recorder.requestCleanup) + }, testGenerationTaskDeps(), recorder.routeStateHint, recorder.requestCleanup) require.NoError(t, err) starter.afterInterruptionOutcome = recorder.afterInterruptionOutcome return starter diff --git a/coderd/x/chatd/worker.go b/coderd/x/chatd/worker.go index 7b3e8d5666fc1..37cce3b9becce 100644 --- a/coderd/x/chatd/worker.go +++ b/coderd/x/chatd/worker.go @@ -17,8 +17,7 @@ import ( // chatWorker owns chat acquisition and runner lifecycle for one process. type chatWorker struct { - server *Server - opts chatWorkerOptions + opts chatWorkerOptions mu sync.Mutex started bool @@ -32,12 +31,12 @@ type chatWorker struct { // newChatWorker constructs a chat worker. The worker is idle until Start is // called. -func newChatWorker(server *Server, opts chatWorkerOptions) (*chatWorker, error) { +func newChatWorker(opts chatWorkerOptions) (*chatWorker, error) { withDefaults, err := opts.withDefaults() if err != nil { return nil, err } - return &chatWorker{server: server, opts: withDefaults}, nil + return &chatWorker{opts: withDefaults}, nil } // chatWorkerID returns this worker's configured worker ID. @@ -54,13 +53,20 @@ func (w *chatWorker) Start(ctx context.Context) error { } workerID := w.opts.WorkerID workerCtx, cancel := context.WithCancel(ctx) - manager := newRunnerManager(workerCtx, w.server, w.opts) + manager := newRunnerManager(workerCtx, w.opts) if manager.opts.TaskStarter == nil { - starter, err := newTaskStarter(manager.server, manager.opts, manager.RouteStateHint, manager.requestCleanup) + generation := w.opts.Generation + if w.opts.Server != nil { + generation = w.opts.Server.generationTaskDeps() + } + starter, err := newTaskStarter(manager.opts, generation, manager.RouteStateHint, manager.requestCleanup) if err != nil { cancel() return err } + if w.opts.Server != nil { + starter.afterInterruptionOutcome = w.opts.Server.afterInterruptionOutcome + } manager.opts.TaskStarter = starter } wakeCh := make(chan struct{}, w.opts.AcquisitionWakeChannelSize) diff --git a/coderd/x/chatd/worker_internal_test.go b/coderd/x/chatd/worker_internal_test.go index f01bb0d69cd71..835e983a474b6 100644 --- a/coderd/x/chatd/worker_internal_test.go +++ b/coderd/x/chatd/worker_internal_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) @@ -18,16 +19,30 @@ import ( func TestWorker_NewRequiresTaskStarterOrMessagePartBuffer(t *testing.T) { t.Parallel() f := newWorkerTestFixture(t) - _, err := newChatWorker(nil, chatWorkerOptions{WorkerID: uuid.New(), Store: f.db, Pubsub: f.pubsub}) + _, err := newChatWorker(chatWorkerOptions{WorkerID: uuid.New(), Store: f.db, Pubsub: f.pubsub}) require.ErrorContains(t, err, "task starter or message part buffer is required") } +func TestWorker_NewRequiresGenerationDepsWhenUsingDefaultTaskStarter(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + t.Cleanup(buffer.Close) + _, err := newChatWorker(chatWorkerOptions{ + WorkerID: uuid.New(), + Store: f.db, + Pubsub: f.pubsub, + MessagePartBuffer: buffer, + }) + require.ErrorContains(t, err, "generation prepare callback is required") +} + func TestWorker_NewRequiresWorkerID(t *testing.T) { t.Parallel() f := newWorkerTestFixture(t) opts := testOptions(t, f, newRecordingTaskStarter()) opts.WorkerID = uuid.Nil - _, err := newChatWorker(nil, opts) + _, err := newChatWorker(opts) require.ErrorContains(t, err, "worker ID is required") } @@ -37,7 +52,7 @@ func TestWorker_UsesConfiguredWorkerID(t *testing.T) { starter := newRecordingTaskStarter() opts := testOptions(t, f, starter) workerID := opts.WorkerID - worker, err := newChatWorker(nil, opts) + worker, err := newChatWorker(opts) require.NoError(t, err) require.Equal(t, workerID, worker.chatWorkerID()) require.NoError(t, worker.Start(context.Background()))