Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions coderd/x/chatd/auto_archive_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,7 @@ func (f *workerTestFixture) newArchiveWorkerWithOptions(t *testing.T, opts chatW
if opts.NotificationsEnqueuer == nil {
opts.NotificationsEnqueuer = notificationstest.NewFakeEnqueuer()
}
worker, err := newChatWorker(nil, opts)
require.NoError(t, err)
return worker
return newChatWorker(nil, opts.WorkerID, opts.Store, opts.Pubsub, opts.MessagePartBuffer, opts)
}

func mockAuditorPtr(auditor *audit.MockAuditor) *atomic.Pointer[audit.Auditor] {
Expand Down
61 changes: 27 additions & 34 deletions coderd/x/chatd/chatd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2275,31 +2275,27 @@ type manualTitleGenerationError struct {
// read; the title_change pubsub event it publishes remains the source of
// truth for clients.
type generatedChatTitle struct {
mu sync.RWMutex
title string
title atomic.Value // string
}

// Store records title as the generated chat title. It is a no-op when the
// receiver is nil or when title is empty, so an empty value never replaces
// a previously stored title.
func (t *generatedChatTitle) Store(title string) {
if t == nil || title == "" {
return
}

t.mu.Lock()
t.title = title
t.mu.Unlock()
t.title.Store(title)
}

// Load returns the stored title and true. It returns ("", false) when the
// receiver is nil or when no non-empty title has been stored yet.
func (t *generatedChatTitle) Load() (string, bool) {
if t == nil {
return "", false
}

t.mu.RLock()
defer t.mu.RUnlock()
if t.title == "" {
title, ok := t.title.Load().(string)
if !ok || title == "" {
return "", false
}
return t.title, true
return title, true
}

func (e *manualTitleGenerationError) Error() string {
Expand Down Expand Up @@ -3421,32 +3417,32 @@ func New(ps pubsub.Pubsub, cfg Config) *Server {
p.metrics = chatloop.NopMetrics()
}
p.messagePartBuffer = messagepartbuffer.New(messagepartbuffer.Options{Clock: clk})
localStreamPartsDialer := NewLocalStreamPartsDialer(LocalStreamPartsDialerConfig{
Buffer: p.messagePartBuffer,
Logger: cfg.Logger,
})
p.streamPartsDialer = streamPartsDialerForServer(workerID, localStreamPartsDialer, cfg.StreamPartsDialer)
p.streamSyncPoller = newStreamSyncPoller(ctx, cfg.Database, clk, cfg.Logger.Named("chatstream"))
p.streamSyncPoller.Start()
chatWorker, err := newChatWorker(p, chatWorkerOptions{
WorkerID: workerID,
Store: cfg.Database,
Pubsub: ps,
workerOpts := chatWorkerOptions{
Logger: cfg.Logger.Named("chatworker"),
Clock: clk,
MessagePartBuffer: p.messagePartBuffer,
AcquisitionInterval: pendingChatAcquireInterval,
AcquisitionBatchSize: maxChatsPerAcquire,
HeartbeatInterval: chatHeartbeatInterval,
HeartbeatStaleSeconds: int32(inFlightChatStaleAfter.Seconds()),
NotificationsEnqueuer: notificationsEnqueuer,
Auditor: cfg.Auditor,
AutoArchiveRecords: chatAutoArchiveRecords,
})
if err != nil {
panic("chatd: create chat worker: " + err.Error())
}
p.chatWorker = chatWorker
localStreamPartsDialer := NewLocalStreamPartsDialer(LocalStreamPartsDialerConfig{
Buffer: p.messagePartBuffer,
Logger: cfg.Logger,
})
p.streamPartsDialer = streamPartsDialerForServer(workerID, localStreamPartsDialer, cfg.StreamPartsDialer)
p.streamSyncPoller = newStreamSyncPoller(ctx, cfg.Database, clk, cfg.Logger.Named("chatstream"))
p.streamSyncPoller.Start()
p.chatWorker = newChatWorker(
p,
workerID,
cfg.Database,
ps,
p.messagePartBuffer,
workerOpts,
)

//nolint:gocritic // The chat processor uses a scoped chatd context.
ctx = dbauthz.AsChatd(ctx)
Expand Down Expand Up @@ -4134,7 +4130,7 @@ func (p *Server) loadPersonalSkillBody(
return parsed, nil
}

func (p *Server) appendRootChatTools(
func (p *Server) appendRootChatToolsWithoutWorkspaceContextPersistence(
ctx context.Context,
tools []fantasy.AgentTool,
opts rootChatToolsOptions,
Expand All @@ -4145,14 +4141,11 @@ func (p *Server) appendRootChatTools(
// build logs before the tool completes.
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)

// Note: we intentionally do not insert AGENTS.md / workspace
// context here. Local tool callbacks must not mutate chat
// history while a local-tool generation task is in flight,
// because that advances history_version before the tool
// result is committed and exits the local-tool commit as
// stale. Workspace context is persisted by the
// persist_workspace_context generation action in a later
// pass.
// Do not persist workspace context from this callback. Local
// tool callbacks run while a generation task is fenced by
// history_version; mutating chat history here would make that
// commit stale. The generation state machine runs
// persist_workspace_context after this tool result commits.

// Prime the workspace MCP tools cache while the create_workspace
// or start_workspace tool is still running. The AgentID guard
Expand Down
2 changes: 1 addition & 1 deletion coderd/x/chatd/generation_preparer.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func (server *Server) prepareGeneration(
tools = append(tools, chattool.NewAskUserQuestionTool())
}
if isRootChat {
tools = server.appendRootChatTools(ctx, tools, rootChatToolsOptions{
tools = server.appendRootChatToolsWithoutWorkspaceContextPersistence(ctx, tools, rootChatToolsOptions{
chat: chat,
modelConfigID: modelConfig.ID,
workspaceCtx: &workspaceCtx,
Expand Down
27 changes: 24 additions & 3 deletions coderd/x/chatd/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,31 @@ func testOptions(t *testing.T, f *workerTestFixture, starter chatWorkerTaskStart
}
}

func startWorker(t *testing.T, opts chatWorkerOptions) *chatWorker {
// testWorkerDeps derives the explicit worker dependencies for a test worker
// from opts. Any dependency left unset in opts falls back to a test default:
// a freshly generated worker ID, the fixture's database, and the fixture's
// pubsub. The message part buffer is passed through unchanged.
func testWorkerDeps(f *workerTestFixture, opts chatWorkerOptions) chatWorkerDependencies {
	// Start from whatever the caller configured, then patch the gaps.
	deps := chatWorkerDependencies{
		WorkerID:          opts.WorkerID,
		Store:             opts.Store,
		Pubsub:            opts.Pubsub,
		MessagePartBuffer: opts.MessagePartBuffer,
	}
	if opts.WorkerID == uuid.Nil {
		deps.WorkerID = uuid.New()
	}
	if opts.Store == nil {
		deps.Store = f.db
	}
	if opts.Pubsub == nil {
		deps.Pubsub = f.pubsub
	}
	return deps
}

func startWorker(t *testing.T, f *workerTestFixture, opts chatWorkerOptions) *chatWorker {
t.Helper()
worker, err := newChatWorker(nil, opts)
require.NoError(t, err)
deps := testWorkerDeps(f, opts)
worker := newChatWorker(nil, deps.WorkerID, deps.Store, deps.Pubsub, deps.MessagePartBuffer, opts)
require.NoError(t, worker.Start(context.Background()))
t.Cleanup(func() { require.NoError(t, worker.Close()) })
return worker
Expand Down
28 changes: 13 additions & 15 deletions coderd/x/chatd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"

"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/audit"
Expand Down Expand Up @@ -94,19 +93,18 @@ type chatWorkerOptions struct {
TaskRetryMaxBackoff time.Duration
}

func (o chatWorkerOptions) withDefaults() (chatWorkerOptions, error) {
if o.Store == nil {
return chatWorkerOptions{}, xerrors.New("chatworker: store is required")
}
if o.Pubsub == nil {
return chatWorkerOptions{}, xerrors.New("chatworker: pubsub is required")
}
if o.TaskStarter == nil && o.MessagePartBuffer == nil {
return chatWorkerOptions{}, xerrors.New("chatworker: task starter or message part buffer is required")
}
if o.WorkerID == uuid.Nil {
return chatWorkerOptions{}, xerrors.New("chatworker: worker ID is required")
}
// chatWorkerDependencies groups the values that withDefaults copies onto
// chatWorkerOptions before it fills in the remaining defaults: the worker
// ID, the store, the pubsub, and the message part buffer. newChatWorker
// builds it from its explicit parameters.
type chatWorkerDependencies struct {
WorkerID uuid.UUID
Store database.Store
Pubsub chatWorkerPubsub
MessagePartBuffer *messagepartbuffer.Buffer
}

func (o chatWorkerOptions) withDefaults(deps chatWorkerDependencies) chatWorkerOptions {
o.WorkerID = deps.WorkerID
o.Store = deps.Store
o.Pubsub = deps.Pubsub
o.MessagePartBuffer = deps.MessagePartBuffer
if o.Clock == nil {
o.Clock = quartz.NewReal()
}
Expand Down Expand Up @@ -155,5 +153,5 @@ func (o chatWorkerOptions) withDefaults() (chatWorkerOptions, error) {
if o.TaskRetryMaxBackoff < o.TaskRetryInitialBackoff {
o.TaskRetryMaxBackoff = o.TaskRetryInitialBackoff
}
return o, nil
return o
}
18 changes: 9 additions & 9 deletions coderd/x/chatd/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestRunner_IgnoresDuplicateStateNotifications(t *testing.T) {
f := newWorkerTestFixture(t)
chat := f.createRunningChat(t)
starter := newBlockingTaskStarter(false)
startWorker(t, testOptions(t, f, starter))
startWorker(t, f, testOptions(t, f, starter))
starter.waitCall(t, taskKindGeneration, chat.ID)
latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
require.NoError(t, err)
Expand All @@ -33,7 +33,7 @@ func TestRunner_CancelsActiveTaskWhenHistoryChanges(t *testing.T) {
f := newWorkerTestFixture(t)
chat := f.createRunningChat(t)
starter := newBlockingTaskStarter(false)
startWorker(t, testOptions(t, f, starter))
startWorker(t, f, testOptions(t, f, starter))
first := starter.waitCall(t, taskKindGeneration, chat.ID)

updated := commitAssistantStep(t, f, chat.ID, "first step")
Expand All @@ -49,7 +49,7 @@ func TestRunner_CancelsActiveTaskWhenStatusChanges(t *testing.T) {
f := newWorkerTestFixture(t)
chat := f.createRunningChat(t)
starter := newBlockingTaskStarter(false)
startWorker(t, testOptions(t, f, starter))
startWorker(t, f, testOptions(t, f, starter))
first := starter.waitCall(t, taskKindGeneration, chat.ID)

updated := interruptChat(t, f, chat.ID)
Expand All @@ -64,7 +64,7 @@ func TestRunner_CleansUpOnOwnershipTakeover(t *testing.T) {
f := newWorkerTestFixture(t)
chat := f.createRunningChat(t)
starter := newBlockingTaskStarter(false)
startWorker(t, testOptions(t, f, starter))
startWorker(t, f, testOptions(t, f, starter))
first := starter.waitCall(t, taskKindGeneration, chat.ID)

acquireChat(t, f, chat.ID, uuid.New(), uuid.New())
Expand All @@ -78,7 +78,7 @@ func TestRunner_SerializesReplacementTasksForSameHistoryAndStatus(t *testing.T)
chat := f.createRunningChat(t)
starter := newBlockingTaskStarter(true)
defer starter.releaseAll()
startWorker(t, testOptions(t, f, starter))
startWorker(t, f, testOptions(t, f, starter))
first := starter.waitCall(t, taskKindGeneration, chat.ID)

forceExecutionStateAndPublish(t, f, chat.ID, database.ChatStatusInterrupting, false)
Expand All @@ -97,7 +97,7 @@ func TestRunner_AllowsReplacementForDifferentHistoryOrStatus(t *testing.T) {
chat := f.createRunningChat(t)
starter := newBlockingTaskStarter(true)
defer starter.releaseAll()
startWorker(t, testOptions(t, f, starter))
startWorker(t, f, testOptions(t, f, starter))
first := starter.waitCall(t, taskKindGeneration, chat.ID)

updated := commitAssistantStep(t, f, chat.ID, "different history")
Expand All @@ -117,7 +117,7 @@ func TestRunner_TaskTimeoutRetries(t *testing.T) {
opts.Clock = clock
opts.TaskRetryInitialBackoff = time.Minute
opts.TaskRetryMaxBackoff = time.Minute
startWorker(t, opts)
startWorker(t, f, opts)

timeoutTrap.MustWait(testutil.Context(t, testutil.WaitLong)).MustRelease(testutil.Context(t, testutil.WaitLong))
timeoutTrap.Close()
Expand All @@ -143,7 +143,7 @@ func TestWorker_RoutesDatabaseSyncStateToActiveRunner(t *testing.T) {
opts := testOptions(t, f, starter)
opts.Clock = clock
opts.RunnerSyncInterval = time.Minute
startWorker(t, opts)
startWorker(t, f, opts)
first := starter.waitCall(t, taskKindGeneration, chat.ID)

forceExecutionState(t, f, chat.ID, database.ChatStatusInterrupting, false)
Expand All @@ -157,7 +157,7 @@ func TestWorker_CleanupStopsRoutingAndCancelsTasks(t *testing.T) {
f := newWorkerTestFixture(t)
chat := f.createRunningChat(t)
starter := newBlockingTaskStarter(false)
startWorker(t, testOptions(t, f, starter))
startWorker(t, f, testOptions(t, f, starter))
first := starter.waitCall(t, taskKindGeneration, chat.ID)

latest := acquireChat(t, f, chat.ID, uuid.New(), uuid.New())
Expand Down
40 changes: 21 additions & 19 deletions coderd/x/chatd/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1124,25 +1124,27 @@ func startRealTaskWorker(t *testing.T, f *taskTestFixture) *chatWorker {
t.Helper()
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
t.Cleanup(buffer.Close)
worker, err := newChatWorker(nil, chatWorkerOptions{
WorkerID: uuid.New(),
Store: f.db,
Pubsub: f.pubsub,
Logger: slog.Make(),
MessagePartBuffer: buffer,
AcquisitionInterval: time.Hour,
AcquisitionBatchSize: 10,
RunnerSyncInterval: time.Hour,
HeartbeatInterval: time.Hour,
HeartbeatCleanupInterval: time.Hour,
HeartbeatStaleSeconds: 30,
StateChannelSize: 16,
RunnerManagerChannelSize: 16,
AcquisitionWakeChannelSize: 1,
TaskRetryInitialBackoff: time.Millisecond,
TaskRetryMaxBackoff: time.Millisecond,
})
require.NoError(t, err)
worker := newChatWorker(
nil,
uuid.New(),
f.db,
f.pubsub,
buffer,
chatWorkerOptions{
Logger: slog.Make(),
AcquisitionInterval: time.Hour,
AcquisitionBatchSize: 10,
RunnerSyncInterval: time.Hour,
HeartbeatInterval: time.Hour,
HeartbeatCleanupInterval: time.Hour,
HeartbeatStaleSeconds: 30,
StateChannelSize: 16,
RunnerManagerChannelSize: 16,
AcquisitionWakeChannelSize: 1,
TaskRetryInitialBackoff: time.Millisecond,
TaskRetryMaxBackoff: time.Millisecond,
},
)
require.NoError(t, worker.Start(context.Background()))
t.Cleanup(func() { require.NoError(t, worker.Close()) })
return worker
Expand Down
22 changes: 17 additions & 5 deletions coderd/x/chatd/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
)

// chatWorker owns chat acquisition and runner lifecycle for one process.
Expand All @@ -32,12 +33,23 @@ type chatWorker struct {

// newChatWorker constructs a chat worker. The worker is idle until Start is
// called.
func newChatWorker(server *Server, opts chatWorkerOptions) (*chatWorker, error) {
withDefaults, err := opts.withDefaults()
if err != nil {
return nil, err
func newChatWorker(
server *Server,
workerID uuid.UUID,
store database.Store,
pubsub chatWorkerPubsub,
messagePartBuffer *messagepartbuffer.Buffer,
opts chatWorkerOptions,
) *chatWorker {
return &chatWorker{
server: server,
opts: opts.withDefaults(chatWorkerDependencies{
WorkerID: workerID,
Store: store,
Pubsub: pubsub,
MessagePartBuffer: messagePartBuffer,
}),
}
return &chatWorker{server: server, opts: withDefaults}, nil
}

// chatWorkerID returns this worker's configured worker ID.
Expand Down
Loading
Loading