diff --git a/agent/agentproc/process.go b/agent/agentproc/process.go index 3a457387dc5b4..c172195b8bdc5 100644 --- a/agent/agentproc/process.go +++ b/agent/agentproc/process.go @@ -148,6 +148,11 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p for k, v := range req.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } + // Propagate the chat ID so child processes (e.g. + // GIT_ASKPASS) can send it back to the server. + if chatID != "" { + cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_CHAT_ID=%s", chatID)) + } if err := cmd.Start(); err != nil { cancel() diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 833233db94c5f..14e9039110e08 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -67,11 +67,12 @@ const ( maxSystemPromptLenBytes = 131072 // 128 KiB ) -// chatGitRef holds the branch and remote origin reported by the -// workspace agent during a git operation. +// chatGitRef holds the branch, remote origin, and optional chat +// ID reported by the workspace agent during a git operation. type chatGitRef struct { Branch string RemoteOrigin string + ChatID uuid.UUID } type chatRepositoryRef struct { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 0ac88255192eb..d17315228ec3c 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -42,6 +42,7 @@ import ( "github.com/coder/coder/v2/coderd/telemetry" maputil "github.com/coder/coder/v2/coderd/util/maps" "github.com/coder/coder/v2/coderd/wspubsub" + "github.com/coder/coder/v2/coderd/x/gitsync" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -1840,6 +1841,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ Branch: strings.TrimSpace(query.Get("git_branch")), RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")), } + if raw := strings.TrimSpace(query.Get("chat_id")); raw != "" { + if parsed, err := uuid.Parse(raw); err == nil { + gitRef.ChatID = parsed + } + } // Either match or configID must be provided! match := query.Get("match") if match == "" { @@ -1938,7 +1944,13 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ // context is retained even if the flow requires an out-of-band login. if gitRef.Branch != "" && gitRef.RemoteOrigin != "" { //nolint:gocritic // Chat processor context required for cross-user chat lookup - api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin) + api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{ + WorkspaceID: workspace.ID, + OwnerID: workspace.OwnerID, + Branch: gitRef.Branch, + Origin: gitRef.RemoteOrigin, + ChatID: gitRef.ChatID, + }) } var previousToken *database.ExternalAuthLink @@ -2087,7 +2099,13 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R } // MarkStale will trigger a refresh by coderd/gitsync. //nolint:gocritic // Chat processor context required for cross-user chat lookup - api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin) + api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{ + WorkspaceID: workspace.ID, + OwnerID: workspace.OwnerID, + Branch: gitRef.Branch, + Origin: gitRef.RemoteOrigin, + ChatID: gitRef.ChatID, + }) httpapi.Write(ctx, rw, http.StatusOK, resp) return } diff --git a/coderd/x/gitsync/worker.go b/coderd/x/gitsync/worker.go index e082048a4caf9..aafe120827c18 100644 --- a/coderd/x/gitsync/worker.go +++ b/coderd/x/gitsync/worker.go @@ -274,25 +274,44 @@ func (w *Worker) tick(ctx context.Context) { } } -// MarkStale persists the git ref on all chats for a workspace, -// setting stale_at to the past so the next tick picks them up. -// Publishes a diff status event for each affected chat. +// MarkStaleParams holds the arguments for Worker.MarkStale. +type MarkStaleParams struct { + WorkspaceID uuid.UUID + OwnerID uuid.UUID + Branch string + Origin string + // ChatID, when set, targets a single chat instead of + // broadcasting to every chat on the workspace. + ChatID uuid.UUID +} + +// MarkStale persists the git ref for a chat (or all chats on a +// workspace when no ChatID is provided), setting stale_at to the +// past so the next tick picks them up. Publishes a diff status +// event for each affected chat. // Called from workspaceagents handlers. No goroutines spawned. -func (w *Worker) MarkStale( - ctx context.Context, - workspaceID, ownerID uuid.UUID, - branch, origin string, -) { - if branch == "" || origin == "" { +func (w *Worker) MarkStale(ctx context.Context, p MarkStaleParams) { + if p.Branch == "" || p.Origin == "" { + return + } + + // When a specific chat is identified, target it directly + // instead of broadcasting to every chat on the workspace. + // Note: this path does not verify that the chat belongs to + // WorkspaceID. This is safe because ChatID originates from + // chatd via the agent (trusted data flow), but differs from + // the broadcast path which filters by workspace. + if p.ChatID != uuid.Nil { + w.markStaleSingle(ctx, p.ChatID, p.Branch, p.Origin) return } chatRows, err := w.store.GetChats(ctx, database.GetChatsParams{ - OwnerID: ownerID, + OwnerID: p.OwnerID, }) if err != nil { w.logger.Warn(ctx, "list chats for git ref storage", - slog.F("workspace_id", workspaceID), + slog.F("workspace_id", p.WorkspaceID), slog.Error(err)) return } @@ -302,30 +321,39 @@ func (w *Worker) MarkStale( chats[i] = row.Chat } - for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) { - _, err := w.store.UpsertChatDiffStatusReference(ctx, - database.UpsertChatDiffStatusReferenceParams{ - ChatID: chat.ID, - GitBranch: branch, - GitRemoteOrigin: origin, - StaleAt: w.clock.Now().Add(-time.Second), - Url: sql.NullString{}, - }, - ) - if err != nil { - w.logger.Warn(ctx, "store git ref on chat diff status", - slog.F("chat_id", chat.ID), - slog.F("workspace_id", workspaceID), - slog.Error(err)) - continue - } - // Notify the frontend immediately so the UI shows the - // branch info even before the worker refreshes PR data. - if w.publishDiffStatusChangeFn != nil { - if pubErr := w.publishDiffStatusChangeFn(ctx, chat.ID); pubErr != nil { - w.logger.Debug(ctx, "publish diff status after mark stale", - slog.F("chat_id", chat.ID), slog.Error(pubErr)) - } + for _, chat := range filterChatsByWorkspaceID(chats, p.WorkspaceID) { + w.markStaleSingle(ctx, chat.ID, p.Branch, p.Origin) + } +} + +// markStaleSingle upserts the git ref for a single chat and +// publishes a diff-status change event. +func (w *Worker) markStaleSingle( + ctx context.Context, + chatID uuid.UUID, + branch, origin string, +) { + _, err := w.store.UpsertChatDiffStatusReference(ctx, + database.UpsertChatDiffStatusReferenceParams{ + ChatID: chatID, + GitBranch: branch, + GitRemoteOrigin: origin, + StaleAt: w.clock.Now().Add(-time.Second), + Url: sql.NullString{}, + }, + ) + if err != nil { + w.logger.Warn(ctx, "store git ref on chat diff status", + slog.F("chat_id", chatID), + slog.Error(err)) + return + } + // Notify the frontend immediately so the UI shows the + // branch info even before the worker refreshes PR data. + if w.publishDiffStatusChangeFn != nil { + if pubErr := w.publishDiffStatusChangeFn(ctx, chatID); pubErr != nil { + w.logger.Debug(ctx, "publish diff status after mark stale", + slog.F("chat_id", chatID), slog.Error(pubErr)) } } } diff --git a/coderd/x/gitsync/worker_test.go b/coderd/x/gitsync/worker_test.go index d1e6d80036b73..0872a0adcd70f 100644 --- a/coderd/x/gitsync/worker_test.go +++ b/coderd/x/gitsync/worker_test.go @@ -644,7 +644,12 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) { refresher := newTestRefresher(t, mClock) worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) - worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo") + worker.MarkStale(ctx, gitsync.MarkStaleParams{ + WorkspaceID: workspaceID, + OwnerID: ownerID, + Branch: "feature", + Origin: "https://github.com/owner/repo", + }) mu.Lock() defer mu.Unlock() @@ -683,7 +688,12 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) { refresher := newTestRefresher(t, mClock) worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) - worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y") + worker.MarkStale(ctx, gitsync.MarkStaleParams{ + WorkspaceID: workspaceID, + OwnerID: ownerID, + Branch: "main", + Origin: "https://github.com/x/y", + }) } func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) { @@ -723,7 +733,12 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) { refresher := newTestRefresher(t, mClock) worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) - worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b") + worker.MarkStale(ctx, gitsync.MarkStaleParams{ + WorkspaceID: workspaceID, + OwnerID: ownerID, + Branch: "dev", + Origin: "https://github.com/a/b", + }) assert.Equal(t, int32(1), publishCount.Load()) } @@ -743,7 +758,12 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) { refresher := newTestRefresher(t, mClock) worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) - worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y") + worker.MarkStale(ctx, gitsync.MarkStaleParams{ + WorkspaceID: uuid.New(), + OwnerID: uuid.New(), + Branch: "main", + Origin: "https://github.com/x/y", + }) } func TestWorker_TickStoreError(t *testing.T) { @@ -795,11 +815,135 @@ func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) { refresher := newTestRefresher(t, mClock) worker := gitsync.NewWorker(store, refresher, nil, mClock, logger) - worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin) + worker.MarkStale(ctx, gitsync.MarkStaleParams{ + WorkspaceID: uuid.New(), + OwnerID: uuid.New(), + Branch: tc.branch, + Origin: tc.origin, + }) }) } } +func TestWorker_MarkStale_WithChatID(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + targetChat := uuid.New() + + var mu sync.Mutex + var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams + var publishedIDs []uuid.UUID + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + // GetChats should NOT be called when a specific chat ID is provided. + store.EXPECT().GetChats(gomock.Any(), gomock.Any()).Times(0) + store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) { + mu.Lock() + upsertRefCalls = append(upsertRefCalls, arg) + mu.Unlock() + return database.ChatDiffStatus{ChatID: arg.ChatID}, nil + }).Times(1) + + pub := func(_ context.Context, chatID uuid.UUID) error { + mu.Lock() + publishedIDs = append(publishedIDs, chatID) + mu.Unlock() + return nil + } + + mClock := quartz.NewMock(t) + now := mClock.Now() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) + + worker.MarkStale(ctx, gitsync.MarkStaleParams{ + WorkspaceID: uuid.New(), + OwnerID: uuid.New(), + Branch: "my-branch", + Origin: "https://github.com/org/repo", + ChatID: targetChat, + }) + + mu.Lock() + defer mu.Unlock() + + require.Len(t, upsertRefCalls, 1) + assert.Equal(t, targetChat, upsertRefCalls[0].ChatID) + assert.Equal(t, "my-branch", upsertRefCalls[0].GitBranch) + assert.Equal(t, "https://github.com/org/repo", upsertRefCalls[0].GitRemoteOrigin) + assert.True(t, upsertRefCalls[0].StaleAt.Before(now), + "stale_at should be in the past, got %v vs now %v", upsertRefCalls[0].StaleAt, now) + + require.Len(t, publishedIDs, 1) + assert.Equal(t, targetChat, publishedIDs[0]) +} + +func TestWorker_MarkStale_NilChatID_Broadcasts(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + + workspaceID := uuid.New() + ownerID := uuid.New() + chat1 := uuid.New() + + var mu sync.Mutex + var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams + var publishedIDs []uuid.UUID + + ctrl := gomock.NewController(t) + store := dbmock.NewMockStore(ctrl) + + // GetChats IS called because a nil ChatID triggers the + // workspace-wide broadcast path. + store.EXPECT().GetChats(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) { + require.Equal(t, ownerID, arg.OwnerID) + return []database.GetChatsRow{ + {Chat: database.Chat{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}}, + }, nil + }) + store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) { + mu.Lock() + upsertRefCalls = append(upsertRefCalls, arg) + mu.Unlock() + return database.ChatDiffStatus{ChatID: arg.ChatID}, nil + }).Times(1) + + pub := func(_ context.Context, chatID uuid.UUID) error { + mu.Lock() + publishedIDs = append(publishedIDs, chatID) + mu.Unlock() + return nil + } + + mClock := quartz.NewMock(t) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + refresher := newTestRefresher(t, mClock) + worker := gitsync.NewWorker(store, refresher, pub, mClock, logger) + + // Zero-value ChatID (uuid.Nil) triggers broadcast. + worker.MarkStale(ctx, gitsync.MarkStaleParams{ + WorkspaceID: workspaceID, + OwnerID: ownerID, + Branch: "main", + Origin: "https://github.com/org/repo", + }) + + mu.Lock() + defer mu.Unlock() + + require.Len(t, upsertRefCalls, 1) + assert.Equal(t, chat1, upsertRefCalls[0].ChatID) + assert.Equal(t, "main", upsertRefCalls[0].GitBranch) + + require.Len(t, publishedIDs, 1) + assert.Equal(t, chat1, publishedIDs[0]) +} + // TestWorker exercises the worker tick against a // real PostgreSQL database to verify that the SQL queries, foreign key // constraints, and upsert logic work end-to-end.