Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions agent/agentproc/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
for k, v := range req.Env {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
}
// Propagate the chat ID so child processes (e.g.
// GIT_ASKPASS) can send it back to the server.
if chatID != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_CHAT_ID=%s", chatID))
}

if err := cmd.Start(); err != nil {
cancel()
Expand Down
5 changes: 3 additions & 2 deletions coderd/exp_chats.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ const (
maxSystemPromptLenBytes = 131072 // 128 KiB
)

// chatGitRef holds the branch and remote origin reported by the
// workspace agent during a git operation.
// chatGitRef holds the branch, remote origin, and optional chat
// ID reported by the workspace agent during a git operation.
type chatGitRef struct {
Branch string
RemoteOrigin string
ChatID uuid.UUID
}

type chatRepositoryRef struct {
Expand Down
22 changes: 20 additions & 2 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/coder/coder/v2/coderd/telemetry"
maputil "github.com/coder/coder/v2/coderd/util/maps"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/coderd/x/gitsync"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
Expand Down Expand Up @@ -1840,6 +1841,11 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
Branch: strings.TrimSpace(query.Get("git_branch")),
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
}
if raw := strings.TrimSpace(query.Get("chat_id")); raw != "" {
if parsed, err := uuid.Parse(raw); err == nil {
gitRef.ChatID = parsed
}
}
// Either match or configID must be provided!
match := query.Get("match")
if match == "" {
Expand Down Expand Up @@ -1938,7 +1944,13 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
// context is retained even if the flow requires an out-of-band login.
if gitRef.Branch != "" && gitRef.RemoteOrigin != "" {
//nolint:gocritic // Chat processor context required for cross-user chat lookup
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
WorkspaceID: workspace.ID,
OwnerID: workspace.OwnerID,
Branch: gitRef.Branch,
Origin: gitRef.RemoteOrigin,
ChatID: gitRef.ChatID,
})
}

var previousToken *database.ExternalAuthLink
Expand Down Expand Up @@ -2087,7 +2099,13 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
}
// MarkStale will trigger a refresh by coderd/gitsync.
//nolint:gocritic // Chat processor context required for cross-user chat lookup
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), gitsync.MarkStaleParams{
WorkspaceID: workspace.ID,
OwnerID: workspace.OwnerID,
Branch: gitRef.Branch,
Origin: gitRef.RemoteOrigin,
ChatID: gitRef.ChatID,
})
httpapi.Write(ctx, rw, http.StatusOK, resp)
return
}
Expand Down
98 changes: 63 additions & 35 deletions coderd/x/gitsync/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,25 +274,44 @@ func (w *Worker) tick(ctx context.Context) {
}
}

// MarkStale persists the git ref on all chats for a workspace,
// setting stale_at to the past so the next tick picks them up.
// Publishes a diff status event for each affected chat.
// MarkStaleParams holds the arguments for Worker.MarkStale.
type MarkStaleParams struct {
WorkspaceID uuid.UUID
OwnerID uuid.UUID
Branch string
Origin string
// ChatID, when set, targets a single chat instead of
// broadcasting to every chat on the workspace.
ChatID uuid.UUID
}

// MarkStale persists the git ref for a chat (or all chats on a
// workspace when no ChatID is provided), setting stale_at to the
// past so the next tick picks them up. Publishes a diff status
// event for each affected chat.
// Called from workspaceagents handlers. No goroutines spawned.
func (w *Worker) MarkStale(
ctx context.Context,
workspaceID, ownerID uuid.UUID,
branch, origin string,
) {
if branch == "" || origin == "" {
func (w *Worker) MarkStale(ctx context.Context, p MarkStaleParams) {
if p.Branch == "" || p.Origin == "" {
return
}

// When a specific chat is identified, target it directly
// instead of broadcasting to every chat on the workspace.
// Note: this path does not verify that the chat belongs to
// WorkspaceID. This is safe because ChatID originates from
// chatd via the agent (trusted data flow), but differs from
// the broadcast path which filters by workspace.
if p.ChatID != uuid.Nil {
w.markStaleSingle(ctx, p.ChatID, p.Branch, p.Origin)
return
}

chatRows, err := w.store.GetChats(ctx, database.GetChatsParams{
OwnerID: ownerID,
OwnerID: p.OwnerID,
})
if err != nil {
w.logger.Warn(ctx, "list chats for git ref storage",
slog.F("workspace_id", workspaceID),
slog.F("workspace_id", p.WorkspaceID),
slog.Error(err))
return
}
Expand All @@ -302,30 +321,39 @@ func (w *Worker) MarkStale(
chats[i] = row.Chat
}

for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
_, err := w.store.UpsertChatDiffStatusReference(ctx,
database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
GitBranch: branch,
GitRemoteOrigin: origin,
StaleAt: w.clock.Now().Add(-time.Second),
Url: sql.NullString{},
},
)
if err != nil {
w.logger.Warn(ctx, "store git ref on chat diff status",
slog.F("chat_id", chat.ID),
slog.F("workspace_id", workspaceID),
slog.Error(err))
continue
}
// Notify the frontend immediately so the UI shows the
// branch info even before the worker refreshes PR data.
if w.publishDiffStatusChangeFn != nil {
if pubErr := w.publishDiffStatusChangeFn(ctx, chat.ID); pubErr != nil {
w.logger.Debug(ctx, "publish diff status after mark stale",
slog.F("chat_id", chat.ID), slog.Error(pubErr))
}
for _, chat := range filterChatsByWorkspaceID(chats, p.WorkspaceID) {
w.markStaleSingle(ctx, chat.ID, p.Branch, p.Origin)
}
}

// markStaleSingle upserts the git ref for a single chat and
// publishes a diff-status change event.
func (w *Worker) markStaleSingle(
ctx context.Context,
chatID uuid.UUID,
branch, origin string,
) {
_, err := w.store.UpsertChatDiffStatusReference(ctx,
database.UpsertChatDiffStatusReferenceParams{
ChatID: chatID,
GitBranch: branch,
GitRemoteOrigin: origin,
StaleAt: w.clock.Now().Add(-time.Second),
Url: sql.NullString{},
},
)
if err != nil {
w.logger.Warn(ctx, "store git ref on chat diff status",
slog.F("chat_id", chatID),
slog.Error(err))
return
}
// Notify the frontend immediately so the UI shows the
// branch info even before the worker refreshes PR data.
if w.publishDiffStatusChangeFn != nil {
if pubErr := w.publishDiffStatusChangeFn(ctx, chatID); pubErr != nil {
w.logger.Debug(ctx, "publish diff status after mark stale",
slog.F("chat_id", chatID), slog.Error(pubErr))
}
}
}
Expand Down
154 changes: 149 additions & 5 deletions coderd/x/gitsync/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,12 @@ func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)

worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo")
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "feature",
Origin: "https://github.com/owner/repo",
})

mu.Lock()
defer mu.Unlock()
Expand Down Expand Up @@ -683,7 +688,12 @@ func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)

worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y")
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "main",
Origin: "https://github.com/x/y",
})
}

func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
Expand Down Expand Up @@ -723,7 +733,12 @@ func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)

worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b")
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "dev",
Origin: "https://github.com/a/b",
})

assert.Equal(t, int32(1), publishCount.Load())
}
Expand All @@ -743,7 +758,12 @@ func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)

worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y")
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: uuid.New(),
OwnerID: uuid.New(),
Branch: "main",
Origin: "https://github.com/x/y",
})
}

func TestWorker_TickStoreError(t *testing.T) {
Expand Down Expand Up @@ -795,11 +815,135 @@ func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) {
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, nil, mClock, logger)

worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin)
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: uuid.New(),
OwnerID: uuid.New(),
Branch: tc.branch,
Origin: tc.origin,
})
})
}
}

func TestWorker_MarkStale_WithChatID(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

targetChat := uuid.New()

var mu sync.Mutex
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
var publishedIDs []uuid.UUID

ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)

// GetChats should NOT be called when a specific chat ID is provided.
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).Times(0)
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(1)

pub := func(_ context.Context, chatID uuid.UUID) error {
mu.Lock()
publishedIDs = append(publishedIDs, chatID)
mu.Unlock()
return nil
}

mClock := quartz.NewMock(t)
now := mClock.Now()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)

worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: uuid.New(),
OwnerID: uuid.New(),
Branch: "my-branch",
Origin: "https://github.com/org/repo",
ChatID: targetChat,
})

mu.Lock()
defer mu.Unlock()

require.Len(t, upsertRefCalls, 1)
assert.Equal(t, targetChat, upsertRefCalls[0].ChatID)
assert.Equal(t, "my-branch", upsertRefCalls[0].GitBranch)
assert.Equal(t, "https://github.com/org/repo", upsertRefCalls[0].GitRemoteOrigin)
assert.True(t, upsertRefCalls[0].StaleAt.Before(now),
"stale_at should be in the past, got %v vs now %v", upsertRefCalls[0].StaleAt, now)

require.Len(t, publishedIDs, 1)
assert.Equal(t, targetChat, publishedIDs[0])
}

func TestWorker_MarkStale_NilChatID_Broadcasts(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

workspaceID := uuid.New()
ownerID := uuid.New()
chat1 := uuid.New()

var mu sync.Mutex
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
var publishedIDs []uuid.UUID

ctrl := gomock.NewController(t)
store := dbmock.NewMockStore(ctrl)

// GetChats IS called because a nil ChatID triggers the
// workspace-wide broadcast path.
store.EXPECT().GetChats(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsParams) ([]database.GetChatsRow, error) {
require.Equal(t, ownerID, arg.OwnerID)
return []database.GetChatsRow{
{Chat: database.Chat{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}}},
}, nil
})
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(1)

pub := func(_ context.Context, chatID uuid.UUID) error {
mu.Lock()
publishedIDs = append(publishedIDs, chatID)
mu.Unlock()
return nil
}

mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)

// Zero-value ChatID (uuid.Nil) triggers broadcast.
worker.MarkStale(ctx, gitsync.MarkStaleParams{
WorkspaceID: workspaceID,
OwnerID: ownerID,
Branch: "main",
Origin: "https://github.com/org/repo",
})

mu.Lock()
defer mu.Unlock()

require.Len(t, upsertRefCalls, 1)
assert.Equal(t, chat1, upsertRefCalls[0].ChatID)
assert.Equal(t, "main", upsertRefCalls[0].GitBranch)

require.Len(t, publishedIDs, 1)
assert.Equal(t, chat1, publishedIDs[0])
}

// TestWorker exercises the worker tick against a
// real PostgreSQL database to verify that the SQL queries, foreign key
// constraints, and upsert logic work end-to-end.
Expand Down
Loading