Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions coderd/agentapi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,15 @@ type Options struct {
OrganizationID uuid.UUID
TemplateVersionID uuid.UUID

AuthenticatedCtx context.Context
Log slog.Logger
Clock quartz.Clock
Database database.Store
NotificationsEnqueuer notifications.Enqueuer
Pubsub pubsub.Pubsub
AuthenticatedCtx context.Context
Log slog.Logger
Clock quartz.Clock
Database database.Store
NotificationsEnqueuer notifications.Enqueuer
Pubsub pubsub.Pubsub
// ContextDirtyMarker is the chatd-backed hydrate/dirty fan-out invoked
// from PushContextState. Nil when chatd is disabled.
ContextDirtyMarker ContextDirtyMarker
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
DerpMapFn func() *tailcfg.DERPMap
TailnetCoordinator *atomic.Pointer[tailnet.Coordinator]
Expand Down Expand Up @@ -248,11 +251,12 @@ func New(opts Options, workspace database.Workspace, agent database.WorkspaceAge
}

api.ContextAPI = &ContextAPI{
AgentID: agent.ID,
Workspace: api.cachedWorkspaceFields,
Log: opts.Log,
Clock: opts.Clock,
Database: opts.Database,
AgentID: agent.ID,
Workspace: api.cachedWorkspaceFields,
Log: opts.Log,
Clock: opts.Clock,
Database: opts.Database,
DirtyMarker: opts.ContextDirtyMarker,
}

// Start background cache refresh loop to handle workspace changes
Expand Down
41 changes: 41 additions & 0 deletions coderd/agentapi/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"math"
"sort"
"time"

"github.com/google/uuid"
"golang.org/x/xerrors"
Expand Down Expand Up @@ -62,6 +63,25 @@ type ContextAPI struct {
Log slog.Logger
Clock quartz.Clock
Database database.Store
// DirtyMarker hydrates chats from, and marks chats dirty against, the
// snapshot persisted by a push. It is nil when chatd is not running,
// in which case PushContextState stays a pure write path.
DirtyMarker ContextDirtyMarker
}

// ContextDirtyMarker hydrates chats from, and marks chats dirty against, a
// freshly persisted agent context snapshot. It is implemented by chatd and
// injected at coderd construction so this package neither imports the chat
// domain nor performs chat-authorized writes directly.
type ContextDirtyMarker interface {
// HydrateAndMarkChatsDirty runs inside the PushContextState
// transaction using the supplied store. It hydrates chats for the
// agent that have no pinned hash yet (no dirty event) and flips
// already-pinned chats whose hash differs from aggregateHash. It
// returns a callback that publishes the resulting dirty watch events;
// the caller invokes it only after the transaction commits. The
// callback is nil when nothing transitioned to dirty.
HydrateAndMarkChatsDirty(ctx context.Context, tx database.Store, agentID uuid.UUID, aggregateHash []byte, snapshotError string, now time.Time) (publishDirty func(), err error)
}

// PushContextState persists a snapshot pushed by the workspace
Expand Down Expand Up @@ -120,10 +140,15 @@ func (a *ContextAPI) PushContextState(ctx context.Context, req *agentproto.PushC
sort.Strings(activeSources)

var accepted bool
// publishDirty is captured from the final (committed) attempt and
// invoked after the transaction commits; ReadModifyUpdate may re-run
// the closure on serialization conflicts.
var publishDirty func()
err = database.ReadModifyUpdate(a.Database, func(tx database.Store) error {
// The closure re-runs on serialization conflicts; reset any
// state carried over from a rolled-back attempt.
accepted = false
publishDirty = nil

existing, err := tx.GetLatestWorkspaceAgentContextSnapshot(ctx, a.AgentID)
switch {
Expand Down Expand Up @@ -171,6 +196,16 @@ func (a *ContextAPI) PushContextState(ctx context.Context, req *agentproto.PushC
return xerrors.Errorf("delete stale resources: %w", err)
}

// Hydrate and dirty chats against the snapshot just written, in the
// same transaction so a concurrent refresh cannot interleave with
// the version gate. Events are published only after commit.
if a.DirtyMarker != nil {
publishDirty, err = a.DirtyMarker.HydrateAndMarkChatsDirty(ctx, tx, a.AgentID, req.AggregateHash, req.SnapshotError, now)
if err != nil {
return xerrors.Errorf("hydrate and mark chats dirty: %w", err)
}
}

accepted = true
return nil
})
Expand All @@ -187,6 +222,12 @@ func (a *ContextAPI) PushContextState(ctx context.Context, req *agentproto.PushC
return &agentproto.PushContextStateResponse{Accepted: false}, nil
}

// The snapshot committed; fan out dirty watch events to chats whose
// pinned context drifted from this push.
if publishDirty != nil {
publishDirty()
}

a.Log.Debug(ctx, "PushContextState accepted",
slog.F("agent_id", a.AgentID),
slog.F("version", req.Version),
Expand Down
83 changes: 83 additions & 0 deletions coderd/agentapi/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,70 @@ func TestPushContextState(t *testing.T) {
require.True(t, resp.GetAccepted())
})

t.Run("DirtyMarkerInvokedAfterCommit", func(t *testing.T) {
t.Parallel()

api, dbm := makeAPI(t)
marker := &fakeDirtyMarker{}
api.DirtyMarker = marker
expectInTx(dbm)

dbm.EXPECT().GetLatestWorkspaceAgentContextSnapshot(gomock.Any(), agentID).
Return(database.WorkspaceAgentContextSnapshot{}, errNoRows())
dbm.EXPECT().UpsertWorkspaceAgentContextSnapshot(gomock.Any(), gomock.Any()).
Return(database.WorkspaceAgentContextSnapshot{}, nil)
dbm.EXPECT().UpsertWorkspaceAgentContextResource(gomock.Any(), gomock.Any()).
Return(database.WorkspaceAgentContextResource{}, nil).Times(1)
dbm.EXPECT().DeleteStaleWorkspaceAgentContextResources(gomock.Any(), gomock.Any()).
Return(nil)

resp, err := api.PushContextState(context.Background(), &agentproto.PushContextStateRequest{
Version: 1,
AggregateHash: []byte{0xaa, 0xbb},
SnapshotError: "watcher degraded",
Initial: true,
Resources: []*agentproto.ContextResource{
instructionResource("/home/coder/AGENTS.md", "hello"),
},
})
require.NoError(t, err)
require.True(t, resp.GetAccepted())
// The marker runs inside the push transaction and its returned
// callback publishes only after the transaction commits.
require.Equal(t, 1, marker.called)
require.Equal(t, 1, marker.published)
require.Equal(t, agentID, marker.gotAgent)
require.Equal(t, []byte{0xaa, 0xbb}, marker.gotHash)
require.Equal(t, "watcher degraded", marker.gotErr)
})

t.Run("DirtyMarkerSkippedOnDrop", func(t *testing.T) {
t.Parallel()

api, dbm := makeAPI(t)
marker := &fakeDirtyMarker{}
api.DirtyMarker = marker
expectInTx(dbm)

// A non-initial push at a version not strictly greater than the
// stored one is dropped before any write; hydration and the
// dirty fan-out must not run.
dbm.EXPECT().GetLatestWorkspaceAgentContextSnapshot(gomock.Any(), agentID).
Return(database.WorkspaceAgentContextSnapshot{Version: 5}, nil)

resp, err := api.PushContextState(context.Background(), &agentproto.PushContextStateRequest{
Version: 2,
AggregateHash: []byte{0x01},
Resources: []*agentproto.ContextResource{
instructionResource("/home/coder/AGENTS.md", "hello"),
},
})
require.NoError(t, err)
require.False(t, resp.GetAccepted())
require.Equal(t, 0, marker.called)
require.Equal(t, 0, marker.published)
})

t.Run("RejectsEmptyAndDuplicateSources", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -598,3 +662,22 @@ func mcpServerResource(source, serverName, description string) *agentproto.Conte
},
}
}

// fakeDirtyMarker is a test double for agentapi.ContextDirtyMarker. It records
// the in-transaction call and counts callback invocations so tests can assert
// the marker runs inside the push transaction and publishes only after commit.
type fakeDirtyMarker struct {
called int
published int
gotAgent uuid.UUID
gotHash []byte
gotErr string
}

func (f *fakeDirtyMarker) HydrateAndMarkChatsDirty(_ context.Context, _ database.Store, agentID uuid.UUID, aggregateHash []byte, snapshotError string, _ time.Time) (func(), error) {
f.called++
f.gotAgent = agentID
f.gotHash = aggregateHash
f.gotErr = snapshotError
return func() { f.published++ }, nil
}
30 changes: 30 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -5660,6 +5660,16 @@ func (q *querier) GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]datab
return q.db.GetWorkspacesForWorkspaceMetrics(ctx)
}

func (q *querier) HydrateAgentChatsContext(ctx context.Context, arg database.HydrateAgentChatsContextParams) error {
// System-level operation: an agent context push fans hydration out
// across every not-yet-pinned chat for the agent, so it authorizes at
// the resource level rather than per-chat.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return err
}
return q.db.HydrateAgentChatsContext(ctx, arg)
}

func (q *querier) IncrementChatGenerationAttempt(ctx context.Context, id uuid.UUID) (int64, error) {
chat, err := q.db.GetChatByID(ctx, id)
if err != nil {
Expand Down Expand Up @@ -6642,6 +6652,15 @@ func (q *querier) MarkAllInboxNotificationsAsRead(ctx context.Context, arg datab
return q.db.MarkAllInboxNotificationsAsRead(ctx, arg)
}

func (q *querier) MarkChatsContextDirtyByAgent(ctx context.Context, arg database.MarkChatsContextDirtyByAgentParams) ([]database.MarkChatsContextDirtyByAgentRow, error) {
// System-level operation: the dirty fan-out runs across every active
// chat for the agent in response to a context push.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
}
return q.db.MarkChatsContextDirtyByAgent(ctx, arg)
}

func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
resource := rbac.ResourceIdpsyncSettings
if args.OrganizationID != uuid.Nil {
Expand Down Expand Up @@ -6772,6 +6791,17 @@ func (q *querier) SelectUsageEventsForPublishing(ctx context.Context, arg time.T
return q.db.SelectUsageEventsForPublishing(ctx, arg)
}

func (q *querier) SetChatContextSnapshot(ctx context.Context, arg database.SetChatContextSnapshotParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.SetChatContextSnapshot(ctx, arg)
}

func (q *querier) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
msg, err := q.db.GetChatMessageByID(ctx, id)
if err != nil {
Expand Down
18 changes: 18 additions & 0 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,24 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().AcquireChats(gomock.Any(), arg).Return([]database.Chat{chat}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]database.Chat{chat})
}))
s.Run("HydrateAgentChatsContext", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.HydrateAgentChatsContextParams{AgentID: uuid.New()}
dbm.EXPECT().HydrateAgentChatsContext(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate)
}))
s.Run("MarkChatsContextDirtyByAgent", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.MarkChatsContextDirtyByAgentParams{AgentID: uuid.New()}
rows := []database.MarkChatsContextDirtyByAgentRow{{ID: uuid.New(), OwnerID: uuid.New()}}
dbm.EXPECT().MarkChatsContextDirtyByAgent(gomock.Any(), arg).Return(rows, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(rows)
}))
s.Run("SetChatContextSnapshot", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.SetChatContextSnapshotParams{ID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().SetChatContextSnapshot(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate)
}))
s.Run("GetChatWorkerAcquisitionCandidates", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetChatWorkerAcquisitionCandidatesParams{
StaleSeconds: 30,
Expand Down
24 changes: 24 additions & 0 deletions coderd/database/dbmetrics/querymetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading