Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions coderd/exp_chats.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import (
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/agentselect"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
Expand Down Expand Up @@ -471,7 +472,46 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
return
}

httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatRowsWithChildren(chatRows, childRows, diffStatusesByChatID))
sdkChats := db2sdk.ChatRowsWithChildren(chatRows, childRows, diffStatusesByChatID)
api.enrichChatAgentIDs(ctx, sdkChats)
httpapi.Write(ctx, rw, http.StatusOK, sdkChats)
}

// enrichChatAgentIDs fills AgentID on chats (and their embedded children)
// that have a bound workspace but no persisted agent binding, which chatd
// only persists once a turn dials the workspace. Clients rely on this field
// instead of selecting an agent themselves. Enrichment is response-only and
// best-effort: on error the field stays null.
func (api *API) enrichChatAgentIDs(ctx context.Context, chats []codersdk.Chat) {
agentIDByWorkspace := make(map[uuid.UUID]*uuid.UUID)
resolve := func(workspaceID uuid.UUID) *uuid.UUID {
if agentID, ok := agentIDByWorkspace[workspaceID]; ok {
return agentID
}
var agentID *uuid.UUID
agents, err := api.Database.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID)
if err != nil {
api.Logger.Debug(ctx, "failed to fetch workspace agents for chat agent enrichment",
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
} else if agent, err := agentselect.FindChatAgent(agents); err == nil {
agentID = &agent.ID
}
agentIDByWorkspace[workspaceID] = agentID
return agentID
}
for i := range chats {
if chats[i].AgentID == nil && chats[i].WorkspaceID != nil {
chats[i].AgentID = resolve(*chats[i].WorkspaceID)
}
for j := range chats[i].Children {
child := &chats[i].Children[j]
if child.AgentID == nil && child.WorkspaceID != nil {
child.AgentID = resolve(*child.WorkspaceID)
}
}
}
}

func (api *API) getChatDiffStatusesByChatID(
Expand Down Expand Up @@ -2104,6 +2144,10 @@ func (api *API) getChat(rw http.ResponseWriter, r *http.Request) {
sdkChat.Children = db2sdk.ChildChatRows(childRows, childDiffStatuses)
}

enriched := []codersdk.Chat{sdkChat}
api.enrichChatAgentIDs(ctx, enriched)
sdkChat = enriched[0]

httpapi.Write(ctx, rw, http.StatusOK, sdkChat)
}

Expand Down Expand Up @@ -2373,11 +2417,19 @@ func (api *API) watchChatGit(rw http.ResponseWriter, r *http.Request) {
})
return
}
agent, err := agentselect.FindChatAgent(agents)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: codersdk.ChatGitWatchWorkspaceNoAgentsMessage,
Detail: err.Error(),
})
return
}

apiAgent, err := db2sdk.WorkspaceAgent(
api.DERPMap(),
*api.TailnetCoordinator.Load(),
agents[0],
agent,
nil,
nil,
nil,
Expand All @@ -2401,7 +2453,7 @@ func (api *API) watchChatGit(rw http.ResponseWriter, r *http.Request) {
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
defer dialCancel()

agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error dialing workspace agent.",
Expand Down Expand Up @@ -2528,11 +2580,19 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
})
return
}
agent, err := agentselect.FindChatAgent(agents)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat workspace has no eligible agents.",
Detail: err.Error(),
})
return
}

apiAgent, err := db2sdk.WorkspaceAgent(
api.DERPMap(),
*api.TailnetCoordinator.Load(),
agents[0],
agent,
nil,
nil,
nil,
Expand All @@ -2556,7 +2616,7 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
defer dialCancel()

agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agents[0].ID)
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to dial workspace agent.",
Expand Down
123 changes: 123 additions & 0 deletions coderd/exp_chats_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,135 @@ package coderd
import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"

"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)

func TestEnrichChatAgentIDs(t *testing.T) {
t.Parallel()

newAPI := func(t *testing.T) (*API, *dbmock.MockStore) {
t.Helper()
mDB := dbmock.NewMockStore(gomock.NewController(t))
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
return &API{
Options: &Options{
Database: mDB,
Logger: logger,
},
}, mDB
}

t.Run("ResolvesRootAgentSkippingSubAgent", func(t *testing.T) {
t.Parallel()

var (
ctx = testutil.Context(t, testutil.WaitShort)
workspaceID = uuid.New()
rootAgentID = uuid.New()
)
api, mDB := newAPI(t)

// The sub-agent is returned first to prove selection is not
// positional.
mDB.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{
{
ID: uuid.New(),
ParentID: uuid.NullUUID{UUID: rootAgentID, Valid: true},
Name: "dev-container",
},
{
ID: rootAgentID,
Name: "main",
},
}, nil)

chats := []codersdk.Chat{{WorkspaceID: &workspaceID}}
api.enrichChatAgentIDs(ctx, chats)

require.NotNil(t, chats[0].AgentID)
require.Equal(t, rootAgentID, *chats[0].AgentID)
})

t.Run("DeduplicatesLookupsAndEnrichesChildren", func(t *testing.T) {
t.Parallel()

var (
ctx = testutil.Context(t, testutil.WaitShort)
workspaceID = uuid.New()
agentID = uuid.New()
)
api, mDB := newAPI(t)

// A single lookup serves the root chat and its child; gomock
// fails the test on a second call.
mDB.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{{ID: agentID, Name: "main"}}, nil).
Times(1)

chats := []codersdk.Chat{{
WorkspaceID: &workspaceID,
Children: []codersdk.Chat{{WorkspaceID: &workspaceID}},
}}
api.enrichChatAgentIDs(ctx, chats)

require.NotNil(t, chats[0].AgentID)
require.Equal(t, agentID, *chats[0].AgentID)
require.NotNil(t, chats[0].Children[0].AgentID)
require.Equal(t, agentID, *chats[0].Children[0].AgentID)
})

t.Run("LeavesNullOnError", func(t *testing.T) {
t.Parallel()

var (
ctx = testutil.Context(t, testutil.WaitShort)
workspaceID = uuid.New()
)
api, mDB := newAPI(t)

mDB.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return(nil, xerrors.New("boom"))

chats := []codersdk.Chat{{WorkspaceID: &workspaceID}}
api.enrichChatAgentIDs(ctx, chats)

require.Nil(t, chats[0].AgentID)
})

t.Run("SkipsChatsWithoutWorkspaceOrWithAgent", func(t *testing.T) {
t.Parallel()

var (
ctx = testutil.Context(t, testutil.WaitShort)
workspaceID = uuid.New()
existing = uuid.New()
)
// No database expectations: neither chat should trigger a
// lookup.
api, _ := newAPI(t)

chats := []codersdk.Chat{
{},
{WorkspaceID: &workspaceID, AgentID: &existing},
}
api.enrichChatAgentIDs(ctx, chats)

require.Nil(t, chats[0].AgentID)
require.Equal(t, existing, *chats[1].AgentID)
})
}

func TestValidateChatModelProviderOptions_AnthropicThinkingDisplay(t *testing.T) {
t.Parallel()

Expand Down
101 changes: 101 additions & 0 deletions coderd/workspaceagents_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,107 @@ func TestWatchChatGit(t *testing.T) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})

t.Run("RootAgentPreferredOverSubAgent", func(t *testing.T) {
t.Parallel()

// This test ensures that the handler selects a root agent
// even when the database returns a dev container sub-agent
// (parent_id set) first.

var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd")

mCtrl = gomock.NewController(t)
mDB = dbmock.NewMockStore(mCtrl)
mCoordinator = tailnettest.NewMockCoordinator(mCtrl)

chatID = uuid.New()
workspaceID = uuid.New()
subAgentID = uuid.New()
rootAgentID = uuid.New()
resourceID = uuid.New()

r = chi.NewMux()

api = API{
ctx: ctx,
Options: &Options{
AgentInactiveDisconnectTimeout: testutil.WaitShort,
Database: mDB,
Logger: logger,
DeploymentValues: &codersdk.DeploymentValues{},
TailnetCoordinator: tailnettest.NewFakeCoordinator(),
},
HTTPAuth: &HTTPAuthorizer{
Authorizer: &mockAuthorizer{},
Logger: logger,
},
}
)

var tailnetCoordinator tailnet.Coordinator = mCoordinator
api.TailnetCoordinator.Store(&tailnetCoordinator)

// Setup: Return a chat with a valid workspace ID.
mDB.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{
ID: chatID,
OwnerID: uuid.New(),
WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
}, nil)

// And: Return the workspace so the handler's
// workspace-level authz check can run.
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspaceID).Return(database.Workspace{
ID: workspaceID,
}, nil)

// And: Return a sub-agent first, then the root agent. Both
// are disconnected so the handler stops after the agent
// state check.
mDB.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(gomock.Any(), workspaceID).
Return([]database.WorkspaceAgent{
{
ID: subAgentID,
ParentID: uuid.NullUUID{UUID: rootAgentID, Valid: true},
Name: "dev-container",
ResourceID: resourceID,
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
{
ID: rootAgentID,
Name: "main",
ResourceID: resourceID,
LifecycleState: database.WorkspaceAgentLifecycleStateCreated,
},
}, nil)

// Then: Node receives the root agent's ID, proving it was
// selected over the sub-agent.
mCoordinator.EXPECT().Node(rootAgentID).Return(nil)

// And: We mount the HTTP handler.
r.With(injectSystemActor, httpmw.ExtractChatParam(mDB)).
Get("/chats/{chat}/stream/git", api.watchChatGit)

// Given: We create the HTTP server.
srv := httptest.NewServer(r)
defer srv.Close()

// When: We make a request.
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
fmt.Sprintf("%s/chats/%s/stream/git", srv.URL, chatID), nil)
require.NoError(t, err)

resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

// Then: We expect a 400 response since the selected root
// agent is not connected.
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})

t.Run("BidirectionalProxyWorks", func(t *testing.T) {
t.Parallel()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
"github.com/coder/coder/v2/coderd/x/chatd/agentselect"
)

func TestFindChatAgent(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion coderd/x/chatd/chatd.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/coder/coder/v2/coderd/util/xjson"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd/agentselect"
"github.com/coder/coder/v2/coderd/x/chatd/chatadvisor"
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
Expand All @@ -49,7 +50,6 @@ import (
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
skillspkg "github.com/coder/coder/v2/coderd/x/skills"
Expand Down
Loading
Loading