diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index edd7f23560971..244e4314ff3e7 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -959,6 +959,45 @@ const docTemplate = `{ ] } }, + "/api/experimental/chats/{chat}/stream/parts": { + "get": { + "description": "Experimental: this endpoint is subject to change.", + "produces": [ + "application/json" + ], + "tags": [ + "Chats" + ], + "summary": "Stream chat parts via WebSockets", + "operationId": "stream-chat-parts-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Chat ID", + "name": "chat", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ChatStreamEvent" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ], + "x-apidocgen": { + "skip": true + } + } + }, "/api/experimental/chats/{chat}/title/regenerate": { "post": { "description": "Experimental: this endpoint is subject to change.", @@ -17414,7 +17453,9 @@ const docTemplate = `{ "error", "queue_update", "retry", - "action_required" + "action_required", + "preview_reset", + "history_reset" ], "x-enum-varnames": [ "ChatStreamEventTypeMessagePart", @@ -17423,17 +17464,28 @@ const docTemplate = `{ "ChatStreamEventTypeError", "ChatStreamEventTypeQueueUpdate", "ChatStreamEventTypeRetry", - "ChatStreamEventTypeActionRequired" + "ChatStreamEventTypeActionRequired", + "ChatStreamEventTypePreviewReset", + "ChatStreamEventTypeHistoryReset" ] }, "codersdk.ChatStreamMessagePart": { "type": "object", "properties": { + "generation_attempt": { + "type": "integer" + }, + "history_version": { + "type": "integer" + }, "part": { "$ref": "#/definitions/codersdk.ChatMessagePart" }, "role": { "$ref": "#/definitions/codersdk.ChatMessageRole" + }, + "seq": { + "type": "integer" } } }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 63dada070c524..8c2009cb039c1 100644 --- 
a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -848,6 +848,41 @@ ] } }, + "/api/experimental/chats/{chat}/stream/parts": { + "get": { + "description": "Experimental: this endpoint is subject to change.", + "produces": ["application/json"], + "tags": ["Chats"], + "summary": "Stream chat parts via WebSockets", + "operationId": "stream-chat-parts-via-websockets", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Chat ID", + "name": "chat", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.ChatStreamEvent" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ], + "x-apidocgen": { + "skip": true + } + } + }, "/api/experimental/chats/{chat}/title/regenerate": { "post": { "description": "Experimental: this endpoint is subject to change.", @@ -15711,7 +15746,9 @@ "error", "queue_update", "retry", - "action_required" + "action_required", + "preview_reset", + "history_reset" ], "x-enum-varnames": [ "ChatStreamEventTypeMessagePart", @@ -15720,17 +15757,28 @@ "ChatStreamEventTypeError", "ChatStreamEventTypeQueueUpdate", "ChatStreamEventTypeRetry", - "ChatStreamEventTypeActionRequired" + "ChatStreamEventTypeActionRequired", + "ChatStreamEventTypePreviewReset", + "ChatStreamEventTypeHistoryReset" ] }, "codersdk.ChatStreamMessagePart": { "type": "object", "properties": { + "generation_attempt": { + "type": "integer" + }, + "history_version": { + "type": "integer" + }, "part": { "$ref": "#/definitions/codersdk.ChatMessagePart" }, "role": { "$ref": "#/definitions/codersdk.ChatMessageRole" + }, + "seq": { + "type": "integer" } } }, diff --git a/coderd/coderd.go b/coderd/coderd.go index a43bedcd02c9f..3751cdcdb5caf 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -251,9 +251,10 @@ type Options struct { SSHConfig codersdk.SSHConfigResponse HTTPClient *http.Client - // ChatSubscribeFn provides cross-replica subscription merging. 
- // Set by enterprise for HA deployments. Nil in AGPL single-replica. - ChatSubscribeFn chatd.SubscribeFn + // ChatStreamPartsDialer dials remote chat stream parts. + // Set by enterprise for HA deployments. Nil uses chatd's local + // in-process channel dialer. + ChatStreamPartsDialer chatd.StreamPartsDialer // ChatProviderAPIKeys overrides deployment-derived provider keys. // Test harnesses use this to route chat models to local providers. ChatProviderAPIKeys *chatprovider.ProviderAPIKeys @@ -816,7 +817,7 @@ func New(options *Options) *API { Logger: options.Logger.Named("chatd"), Database: options.Database, ReplicaID: api.ID, - SubscribeFn: options.ChatSubscribeFn, + StreamPartsDialer: options.ChatStreamPartsDialer, MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above. ProviderAPIKeys: providerAPIKeys, AllowBYOK: options.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), @@ -1340,6 +1341,7 @@ func New(options *Options) *API { r.Get("/prompts", api.getChatUserPrompts) r.Route("/stream", func(r chi.Router) { r.Get("/", api.streamChat) + r.Get("/parts", api.streamChatParts) r.Get("/desktop", api.watchChatDesktop) r.Get("/git", api.watchChatGit) }) diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index c164c68718ea6..a134be9f1bb3b 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -7958,3 +7958,23 @@ func (api *API) getChatDebugRun(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, db2sdk.ChatDebugRunDetail(run, steps)) } + +// EXPERIMENTAL: this endpoint is experimental and is subject to change. 
+// +// @Summary Stream chat parts via WebSockets +// @ID stream-chat-parts-via-websockets +// @Security CoderSessionToken +// @Tags Chats +// @Produce json +// @Param chat path string true "Chat ID" format(uuid) +// @Success 200 {object} codersdk.ChatStreamEvent +// @Router /api/experimental/chats/{chat}/stream/parts [get] +// @x-apidocgen {"skip": true} +// @Description Experimental: this endpoint is subject to change. +func (api *API) streamChatParts(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + chat := httpmw.ChatParam(r) + if err := api.chatDaemon.ServeStreamPartsAuthorized(rw, r, chat); err != nil { + api.Logger.Named("chat_stream_parts").Debug(ctx, "chat stream parts closed", slog.Error(err)) + } +} diff --git a/coderd/pubsub/chatstreamnotify.go b/coderd/pubsub/chatstreamnotify.go deleted file mode 100644 index d53605d29c07b..0000000000000 --- a/coderd/pubsub/chatstreamnotify.go +++ /dev/null @@ -1,56 +0,0 @@ -package pubsub - -import ( - "fmt" - - "github.com/google/uuid" - - "github.com/coder/coder/v2/codersdk" -) - -// ChatStreamNotifyChannel returns the pubsub channel for per-chat -// stream notifications. Subscribers receive lightweight notifications -// and read actual content from the database. -func ChatStreamNotifyChannel(chatID uuid.UUID) string { - return fmt.Sprintf("chat:stream:%s", chatID) -} - -// ChatStreamNotifyMessage is the payload published on the per-chat -// stream notification channel. Durable message content is still read -// from the database, while transient control events can be carried -// inline for cross-replica delivery. -type ChatStreamNotifyMessage struct { - // AfterMessageID tells subscribers to query messages after this - // ID. Set when a new message is persisted. - AfterMessageID int64 `json:"after_message_id,omitempty"` - - // Status is set when the chat status changes. Subscribers use - // this to update clients and to manage relay lifecycle. 
- Status string `json:"status,omitempty"` - - // WorkerID identifies which replica is running the chat. Used - // by enterprise relay to know where to connect. - WorkerID string `json:"worker_id,omitempty"` - - // Retry carries a structured retry event for cross-replica live - // delivery. This is transient stream state and is not read back - // from the database. - Retry *codersdk.ChatStreamRetry `json:"retry,omitempty"` - - // ErrorPayload carries a structured error event for cross-replica - // live delivery. Keep Error for backward compatibility with older - // replicas during rolling deploys. - ErrorPayload *codersdk.ChatError `json:"error_payload,omitempty"` - - // Error is the legacy string-only error payload kept for mixed- - // version compatibility during rollout. - Error string `json:"error,omitempty"` - - // QueueUpdate is set when the queued messages change. - QueueUpdate bool `json:"queue_update,omitempty"` - - // FullRefresh signals that subscribers should re-fetch all - // messages from the beginning (e.g. after an edit that - // truncates message history). - FullRefresh bool `json:"full_refresh,omitempty"` -} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index ffeb6c986d8ac..831ea20c8b1dc 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "math" "net/http" "slices" "strconv" @@ -99,14 +98,6 @@ const ( DefaultChatHeartbeatInterval = 30 * time.Second maxChatSteps = 1200 - // RelaySentinelAfterID is the after_id sentinel used by cross-replica - // relay subscribers. It instructs the peer to skip the durable DB - // snapshot and only deliver buffered message_part events. The - // buffer itself filters committed parts out (see snapshotBufferLocked), - // so the sentinel resolves to "send me any in-progress streaming - // parts you have; I will receive durable messages through pubsub." 
- RelaySentinelAfterID = math.MaxInt64 - // maxConcurrentRecordingUploads caps the number of recording // stop-and-store operations that can run concurrently. Each // slot buffers up to MaxRecordingSize + MaxThumbnailSize @@ -114,25 +105,6 @@ const ( // to roughly maxConcurrentRecordingUploads * 110 MB. maxConcurrentRecordingUploads = 25 - // bufferRetainGracePeriod is how long the per-chat stream - // state is kept after processing completes. The retained - // state lets late-connecting cross-replica relay subscribers - // register against the live stream before the next worker - // run starts, preventing a race between cleanupStreamIfIdle - // and subscriber registration. The buffer itself is no - // longer useful at this point: every part has been claimed - // by its durable assistant message and is filtered out of - // the subscriber snapshot. - bufferRetainGracePeriod = 5 * time.Second - // chatStreamControlFetchTimeout bounds subscriber-owned - // control-path DB reads when the caller has no deadline. - chatStreamControlFetchTimeout = 5 * time.Second - - // streamJanitorInterval is how often sweepIdleStreams runs. - // Worst-case retention is bufferRetainGracePeriod + - // streamJanitorInterval. - streamJanitorInterval = 30 * time.Second - // agentDisconnectedRecoveryThreshold is how long the latest // workspace agent must be disconnected before chatd suggests // destructive stop/start recovery. This is intentionally longer @@ -200,7 +172,7 @@ type Server struct { workerID uuid.UUID logger slog.Logger - subscribeFn SubscribeFn + streamPartsDialer StreamPartsDialer agentConnFn AgentConnFunc agentInactiveDisconnectTimeout time.Duration @@ -221,11 +193,6 @@ type Server struct { configCache *chatConfigCache configCacheUnsubscribe func() - // chatStreams stores per-chat stream state. Using sync.Map - // gives each chat independent locking — concurrent chats - // never contend with each other. 
- chatStreams sync.Map // uuid.UUID -> *chatStreamState - // workspaceMCPToolsCache caches workspace MCP tool definitions // per chat to avoid re-fetching on every turn. The cache is // keyed by chat ID and invalidated when the agent changes. @@ -236,6 +203,7 @@ type Server struct { metrics *chatloop.Metrics chatWorker *chatWorker messagePartBuffer *messagepartbuffer.Buffer + streamSyncPoller *streamSyncPoller recordingSem chan struct{} aibridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory] @@ -1226,150 +1194,6 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces // AgentConnFunc provides access to workspace agent connections. type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) -// SubscribeFn replaces the default local-only subscription with a -// multi-replica-aware implementation that merges pubsub notifications, -// remote relay streams, and local parts into a single event channel. -// When set, Subscribe delegates the event-merge goroutine to this -// function instead of using simple local forwarding. -// -// Parameters: -// - ctx: subscription lifetime context (canceled on unsubscribe). -// - params: all state needed to build the merged stream. -// -// Returns the merged event channel. Cleanup is driven by ctx -// cancellation — the merge goroutine tears down all relay state -// in its defer when ctx is done. -// Set by enterprise for HA deployments. Nil in AGPL single-replica. -type SubscribeFn func( - ctx context.Context, - params SubscribeFnParams, -) <-chan codersdk.ChatStreamEvent - -// StatusNotification informs the enterprise relay manager of chat -// status changes so it can open or close relay connections. -type StatusNotification struct { - Status database.ChatStatus - WorkerID uuid.UUID -} - -// SubscribeFnParams carries the state that the enterprise -// SubscribeFn implementation needs from the OSS Subscribe preamble. 
-type SubscribeFnParams struct { - ChatID uuid.UUID - Chat database.Chat - WorkerID uuid.UUID - StatusNotifications <-chan StatusNotification - RequestHeader http.Header - DB database.Store - Logger slog.Logger -} - -// bufferedStreamPart is a buffered message_part event with its -// committed-message linkage. Parts that have not yet been claimed by -// a durable assistant message carry committedMessageID == 0 and are -// considered "in progress"; when an assistant message is published -// every still-in-progress part is claimed by that durable message -// ID, marking the part as redundant for any subscriber that will -// receive the durable message via REST or pubsub. -type bufferedStreamPart struct { - event codersdk.ChatStreamEvent - // committedMessageID is the durable assistant message ID that - // claimed this part, or 0 while the part belongs to the - // in-progress turn. snapshotBufferLocked drops parts with - // committedMessageID != 0 because the subscriber will receive - // the durable message through a different channel (REST snapshot, - // initial DB query in SubscribeAuthorized, or pubsub). - committedMessageID int64 -} - -type chatStreamState struct { - mu sync.Mutex - buffer []bufferedStreamPart - buffering bool - durableMessages []codersdk.ChatStreamEvent - durableEvictedBefore int64 // highest message ID evicted from durable cache - subscribers map[uuid.UUID]chan codersdk.ChatStreamEvent - bufferDropCount int64 - bufferLastWarnAt time.Time - subscriberDropCount int64 - subscriberLastWarnAt time.Time - // currentRetry records the current retry phase for late-joining - // same-replica subscribers. Nil when the stream is not waiting - // to retry. - currentRetry *codersdk.ChatStreamRetry - // bufferRetainedAt records when processing completed and - // the per-chat stream state entered the post-completion - // grace window. Zero while buffering is active. 
When - // non-zero, cleanupStreamIfIdle skips GC until the grace - // period expires so cross-replica relay subscribers can - // register without racing state deletion. The buffer - // itself does not deliver content here: every part is - // claimed by a durable assistant message before - // bufferRetainedAt is set, so snapshotBufferLocked - // returns no parts during the grace window. - bufferRetainedAt time.Time -} - -// streamStateCollector exposes scrape-time gauges derived from -// p.chatStreams. Scrape cost is O(n) with a brief per-state mutex -// held for two len() reads; acceptable at typical scrape cadences. -type streamStateCollector struct { - server *Server -} - -var ( - streamsActiveDesc = prometheus.NewDesc( - "coderd_chatd_streams_active", - "Current number of chat stream state entries (in-flight plus retained).", - nil, nil, - ) - streamBufferSizeMaxDesc = prometheus.NewDesc( - "coderd_chatd_stream_buffer_size_max", - "Maximum current buffer length across all chat streams.", - nil, nil, - ) - streamBufferEventsDesc = prometheus.NewDesc( - "coderd_chatd_stream_buffer_events", - "Sum of current buffer lengths across all chat streams.", - nil, nil, - ) - streamSubscribersDesc = prometheus.NewDesc( - "coderd_chatd_stream_subscribers", - "Current number of chat stream subscribers across all chat streams.", - nil, nil, - ) -) - -func (*streamStateCollector) Describe(ch chan<- *prometheus.Desc) { - ch <- streamsActiveDesc - ch <- streamBufferSizeMaxDesc - ch <- streamBufferEventsDesc - ch <- streamSubscribersDesc -} - -func (c *streamStateCollector) Collect(ch chan<- prometheus.Metric) { - var active, totalEvents, maxBufLen, totalSubs int - c.server.chatStreams.Range(func(_, v any) bool { - state, ok := v.(*chatStreamState) - if !ok { - return true - } - active++ - state.mu.Lock() - bufLen := len(state.buffer) - subs := len(state.subscribers) - state.mu.Unlock() - totalEvents += bufLen - totalSubs += subs - maxBufLen = max(maxBufLen, bufLen) - return true 
- }) - ch <- prometheus.MustNewConstMetric(streamsActiveDesc, prometheus.GaugeValue, float64(active)) - ch <- prometheus.MustNewConstMetric(streamBufferSizeMaxDesc, prometheus.GaugeValue, float64(maxBufLen)) - ch <- prometheus.MustNewConstMetric(streamBufferEventsDesc, prometheus.GaugeValue, float64(totalEvents)) - ch <- prometheus.MustNewConstMetric(streamSubscribersDesc, prometheus.GaugeValue, float64(totalSubs)) -} - var ( // ErrInvalidModelConfigID indicates the requested model config does not exist. ErrInvalidModelConfigID = xerrors.New("invalid model config ID") @@ -3533,10 +3357,12 @@ func BuildSingleUserChatMessageInsertParams( // Config configures a chat processor. type Config struct { - Logger slog.Logger - Database database.Store - ReplicaID uuid.UUID - SubscribeFn SubscribeFn + Logger slog.Logger + Database database.Store + ReplicaID uuid.UUID + // StreamPartsDialer dials remote stream parts. Nil uses the local + // in-process channel dialer for every stream. + StreamPartsDialer StreamPartsDialer PendingChatAcquireInterval time.Duration MaxChatsPerAcquire int32 InFlightChatStaleAfter time.Duration @@ -3625,13 +3451,11 @@ func New(cfg Config) *Server { if cfg.AllowBYOKSet { allowBYOK = cfg.AllowBYOK } - p := &Server{ cancel: cancel, db: cfg.Database, workerID: workerID, logger: cfg.Logger.Named("processor"), - subscribeFn: cfg.SubscribeFn, agentConnFn: cfg.AgentConn, agentInactiveDisconnectTimeout: cfg.AgentInactiveDisconnectTimeout, dialTimeout: defaultDialTimeout, @@ -3671,7 +3495,6 @@ func New(cfg Config) *Server { var chatAutoArchiveRecords prometheus.Counter if cfg.PrometheusRegistry != nil { p.metrics = chatloop.NewMetrics(cfg.PrometheusRegistry) - cfg.PrometheusRegistry.MustRegister(&streamStateCollector{server: p}) chatAutoArchiveRecords = prometheus.NewCounter(prometheus.CounterOpts{ Namespace: "coderd", Subsystem: "chat_auto_archive", @@ -3683,6 +3506,13 @@ func New(cfg Config) *Server { p.metrics = chatloop.NopMetrics() } p.messagePartBuffer = 
messagepartbuffer.New(messagepartbuffer.Options{Clock: clk}) + localStreamPartsDialer := NewLocalStreamPartsDialer(LocalStreamPartsDialerConfig{ + Buffer: p.messagePartBuffer, + Logger: cfg.Logger, + }) + p.streamPartsDialer = streamPartsDialerForServer(workerID, localStreamPartsDialer, cfg.StreamPartsDialer) + p.streamSyncPoller = newStreamSyncPoller(ctx, cfg.Database, clk, cfg.Logger.Named("chatstream")) + p.streamSyncPoller.Start() chatWorker, err := newChatWorker(p, chatWorkerOptions{ WorkerID: workerID, Store: cfg.Database, @@ -3735,7 +3565,6 @@ func New(cfg Config) *Server { p.ctx = ctx // Spawn background goroutines that all servers need. - p.wg.Go(func() { p.streamJanitorLoop(ctx) }) return p } @@ -3753,225 +3582,6 @@ func (p *Server) Start() *Server { return p } -// getCachedDurableMessages returns cached durable messages with IDs -// greater than afterID. Returns nil when the cache has no relevant -// entries. -func (p *Server) getCachedDurableMessages( - chatID uuid.UUID, - afterID int64, -) []codersdk.ChatStreamEvent { - state := p.getOrCreateStreamState(chatID) - state.mu.Lock() - defer state.mu.Unlock() - - if afterID < state.durableEvictedBefore { - return nil - } - - var result []codersdk.ChatStreamEvent - for _, event := range state.durableMessages { - if event.Message != nil && event.Message.ID > afterID { - result = append(result, event) - } - } - return result -} - -// snapshotBufferLocked returns the buffered message_part events that -// the caller should receive in their initial snapshot. -// -// Parts whose committedMessageID != 0 are dropped: those parts were -// claimed by a durable assistant message that the subscriber will -// receive through a different channel (REST snapshot, the initial DB -// query in SubscribeAuthorized, or pubsub catch-up). Delivering them -// here would render the same content twice on the client, once in the -// streaming UI and once as a durable message. 
-// -// Every caller receives the same view: in-progress parts are always -// delivered and committed parts are always dropped, regardless of -// cursor or relay sentinel. This keeps the buffer free of duplicate -// work for every subscriber, including cross-replica relay -// subscribers whose user-facing peers receive the durable message -// via pubsub. -// -// The caller must hold the per-chat stream state lock. -func snapshotBufferLocked(buffer []bufferedStreamPart) []codersdk.ChatStreamEvent { - if len(buffer) == 0 { - return nil - } - snapshot := make([]codersdk.ChatStreamEvent, 0, len(buffer)) - for _, part := range buffer { - if part.committedMessageID != 0 { - continue - } - snapshot = append(snapshot, part.event) - } - return snapshot -} - -// subscribeToStream registers a subscriber to the per-chat in-memory -// stream and returns a snapshot of currently in-progress message_part -// events plus the current retry phase, the live subscriber channel, -// and a cancel func. -// -// Parts that were claimed by a committed durable assistant message -// (committedMessageID != 0) are excluded from the snapshot. The -// subscriber will receive those durable messages through the REST -// snapshot, the initial DB query in SubscribeAuthorized, or pubsub, -// so re-delivering their constituent parts here would render the -// same content twice. -func (p *Server) subscribeToStream(chatID uuid.UUID) ( - []codersdk.ChatStreamEvent, - *codersdk.ChatStreamRetry, - <-chan codersdk.ChatStreamEvent, - func(), -) { - state := p.getOrCreateStreamState(chatID) - state.mu.Lock() - snapshot := snapshotBufferLocked(state.buffer) - var currentRetry *codersdk.ChatStreamRetry - if state.currentRetry != nil { - retryCopy := *state.currentRetry - currentRetry = &retryCopy - } - id := uuid.New() - ch := make(chan codersdk.ChatStreamEvent, 128) - state.subscribers[id] = ch - state.mu.Unlock() - - cancel := func() { - state.mu.Lock() - // Remove the subscriber but do not close the channel. 
- // publishToStream copies subscriber references under - // the per-chat lock then sends outside; closing here - // races with that send and can panic. The channel - // becomes unreachable once removed and will be GC'd. - delete(state.subscribers, id) - p.cleanupStreamIfIdle(chatID, state) - state.mu.Unlock() - } - - return snapshot, currentRetry, ch, cancel -} - -// getOrCreateStreamState returns the per-chat stream state, -// creating one atomically if it doesn't exist. The returned -// state has its own mutex — callers must lock state.mu for -// access. -func (p *Server) getOrCreateStreamState(chatID uuid.UUID) *chatStreamState { - if val, ok := p.chatStreams.Load(chatID); ok { - state, _ := val.(*chatStreamState) - return state - } - val, _ := p.chatStreams.LoadOrStore(chatID, &chatStreamState{ - subscribers: make(map[uuid.UUID]chan codersdk.ChatStreamEvent), - }) - state, _ := val.(*chatStreamState) - return state -} - -// cleanupStreamIfIdle removes the chat entry from the sync.Map when -// there are no subscribers, the stream is not buffering, and any -// grace period for late-connecting relay subscribers has elapsed. If -// the grace window is still open it returns without rescheduling. -// streamJanitorLoop is the backstop that re-checks on a timer. -// -// The caller must hold state.mu. The state pointer may have been -// captured outside this lock (sync.Map.Load or Range); we use -// CompareAndDelete so a stale pointer cannot evict a fresh entry -// installed by a racing getOrCreateStreamState. Returns true -// if the state was deleted, false otherwise. -func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) bool { - if state.buffering || len(state.subscribers) > 0 { - return false - } - // Keep stream state alive during the grace period so - // late-connecting cross-replica relay subscribers can - // register against this chat before GC. 
- if !state.bufferRetainedAt.IsZero() && - p.clock.Now().Before(state.bufferRetainedAt.Add(bufferRetainGracePeriod)) { - return false - } - if !p.chatStreams.CompareAndDelete(chatID, state) { - return false - } - p.workspaceMCPToolsCache.Delete(chatID) - return true -} - -// streamJanitorLoop periodically reaps idle chat stream states whose -// grace period has expired. It is the backstop for the grace-window -// early-return in cleanupStreamIfIdle; without it, a subscriber that -// detaches inside grace (the common enterprise relay-drain case, -// relayDrainTimeout = 200ms vs. 5s grace) pins the state forever. -func (p *Server) streamJanitorLoop(ctx context.Context) { - ticker := p.clock.NewTicker(streamJanitorInterval, "chatd", "stream-janitor") - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - p.safeSweepIdleStreams(ctx) - } - } -} - -// safeSweepIdleStreams runs sweepIdleStreams under a panic recovery -// so an unexpected panic in the sweep cannot kill the janitor -// goroutine and silently reintroduce the very leak it exists to -// prevent. The next tick retries. -func (p *Server) safeSweepIdleStreams(ctx context.Context) { - defer func() { - if r := recover(); r != nil { - p.logger.Error(ctx, "stream janitor sweep panicked, will retry next tick", - slog.F("panic", r)) - } - }() - p.sweepIdleStreams() -} - -// sweepIdleStreams iterates chatStreams once and delegates each entry -// to cleanupStreamIfIdle. Range may skip entries that become reapable -// concurrently. Any such entry is reaped on the next tick. 
-func (p *Server) sweepIdleStreams() { - var reaped atomic.Int64 - defer func() { - if count := reaped.Load(); count > 0 { - p.logger.Info(context.Background(), "reaped idle chat streams", slog.F("count", count)) - } - }() - p.chatStreams.Range(func(key, value any) bool { - chatID, ok := key.(uuid.UUID) - if !ok { - return true - } - state, ok := value.(*chatStreamState) - if !ok { - return true - } - // guard against any panic from cleanupStreamIfIdle locking state.mu for all time - func() { - state.mu.Lock() - defer state.mu.Unlock() - if p.cleanupStreamIfIdle(chatID, state) { - reaped.Add(1) - } - }() - return true - }) -} - -// streamSubscriberControlFetchContext keeps a control-path lookup tied to the -// requesting subscriber while applying a fallback timeout when the caller has -// no deadline. -func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) { - if _, ok := ctx.Deadline(); ok { - return ctx, func() {} - } - return context.WithTimeout(ctx, chatStreamControlFetchTimeout) -} - func subscribeWithInitialError(chatID uuid.UUID, message string) ( []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, @@ -3987,535 +3597,6 @@ func subscribeWithInitialError(chatID uuid.UUID, message string) ( }}, events, func() {}, true } -func (p *Server) Subscribe( - ctx context.Context, - chatID uuid.UUID, - requestHeader http.Header, - afterMessageID int64, -) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - bool, -) { - if p == nil { - return nil, nil, nil, false - } - - chat, err := p.db.GetChatByID(ctx, chatID) - if err != nil { - if dbauthz.IsNotAuthorizedError(err) { - return nil, nil, nil, false - } - p.logger.Warn(ctx, "failed to load chat for stream subscription", - slog.F("chat_id", chatID), - slog.Error(err), - ) - return subscribeWithInitialError(chatID, "failed to load initial snapshot") - } - return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID) -} - -// 
SubscribeAuthorized subscribes an already-authorized chat to merged stream -// updates. The passed chat row proves authorization, but SubscribeAuthorized -// still reloads the chat after the stream subscriptions are armed so the -// initial status and relay setup use fresh state. -func (p *Server) SubscribeAuthorized( - ctx context.Context, - chat database.Chat, - requestHeader http.Header, - afterMessageID int64, -) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - bool, -) { - if p == nil { - return nil, nil, nil, false - } - chatID := chat.ID - - // Subscribe to the local stream for message_parts and same-replica - // persisted messages. Capture the current retry phase under the same - // lock so the transient snapshot and subscriber registration reflect - // a single moment in time. - localSnapshot, localRetry, localParts, localCancel := p.subscribeToStream(chatID) - - // Merge all event sources. - mergedCtx, mergedCancel := context.WithCancel(ctx) - mergedEvents := make(chan codersdk.ChatStreamEvent, 128) - - var allCancels []func() - allCancels = append(allCancels, localCancel) - - // Subscribe to pubsub for durable and structured control - // events (status, messages, queue updates, retry, errors). - // If the subscription cannot be established, deliver all local - // events. - // - // This MUST happen before the DB queries below so that any - // notification published between the query and the subscription - // is not lost (subscribe-first-then-query pattern). 
- notifyCh := make(chan coderdpubsub.ChatStreamNotifyMessage, 10) - errCh := make(chan error, 1) - listener := func(_ context.Context, message []byte, listenErr error) { - if listenErr != nil { - select { - case <-mergedCtx.Done(): - case errCh <- listenErr: - } - return - } - var notify coderdpubsub.ChatStreamNotifyMessage - if unmarshalErr := json.Unmarshal(message, &notify); unmarshalErr != nil { - select { - case <-mergedCtx.Done(): - case errCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr): - } - return - } - select { - case <-mergedCtx.Done(): - case notifyCh <- notify: - } - } - - if pubsubCancel, pubsubErr := p.pubsub.SubscribeWithErr( - coderdpubsub.ChatStreamNotifyChannel(chatID), - listener, - ); pubsubErr == nil { - allCancels = append(allCancels, pubsubCancel) - } else { - p.logger.Warn(ctx, "failed to subscribe to chat stream notifications", - slog.F("chat_id", chatID), - slog.Error(pubsubErr), - ) - } - - cancel := func() { - mergedCancel() - for _, cancelFn := range allCancels { - if cancelFn != nil { - cancelFn() - } - } - } - - // Re-read the chat after the local/pubsub subscriptions are active so - // the initial status event and any enterprise relay setup use fresh - // state instead of the middleware-loaded row. - refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx) - snapshotChat, err := func() (database.Chat, error) { - defer refreshCancel() - //nolint:gocritic // SubscribeAuthorized already validated the - // caller; this refresh only loads the latest status/worker for - // the already-authorized stream subscription. - return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID) - }() - if err != nil { - p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state", - slog.F("chat_id", chatID), - slog.Error(err), - ) - snapshotChat = chat - } - - // Build initial snapshot synchronously. 
The pubsub subscription - // is already active so no notifications can be lost during this - // window. - initialSnapshot := make([]codersdk.ChatStreamEvent, 0) - delivered := map[int64]struct{}{} - // Add local same-replica message_parts to the snapshot. Retry comes - // from state.currentRetry, not the event buffer, so late joiners see - // only the latest phase rather than a stale buffered retry event. - for _, event := range localSnapshot { - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - initialSnapshot = append(initialSnapshot, event) - } - } - - var retryEvent *codersdk.ChatStreamEvent - if localRetry != nil { - retryEvent = &codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeRetry, - ChatID: chatID, - Retry: localRetry, - } - } - - // Load initial messages from DB. When afterMessageID > 0 the - // caller already has messages up to that ID (e.g. from the REST - // endpoint), so we only fetch newer ones to avoid sending - // duplicate data. - messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: afterMessageID, - }) - if err != nil { - p.logger.Error(ctx, "failed to load initial chat messages", - slog.Error(err), - slog.F("chat_id", chatID), - ) - initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatError{Message: "failed to load initial snapshot"}, - }) - } else { - for _, msg := range messages { - sdkMsg := db2sdk.ChatMessage(msg) - initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, - ChatID: chatID, - Message: &sdkMsg, - }) - delivered[msg.ID] = struct{}{} - } - } - - // Load initial queue. Queue snapshots are intentionally not - // singleflighted because a chat-scoped key cannot distinguish the - // pre- and post-notification queue state. 
- queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx) - queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID) - queueCancel() - if err != nil { - p.logger.Error(ctx, "failed to load initial queued messages", - slog.Error(err), - slog.F("chat_id", chatID), - ) - initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatError{Message: "failed to load initial snapshot"}, - }) - } else if len(queued) > 0 { - initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeQueueUpdate, - ChatID: chatID, - QueuedMessages: db2sdk.ChatQueuedMessages(queued), - }) - } - - // Include the current chat status in the snapshot so the - // frontend can gate message_part processing correctly from - // the very first batch, without waiting for a separate REST - // query. - statusEvent := codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeStatus, - ChatID: chatID, - Status: &codersdk.ChatStreamStatus{ - Status: codersdk.ChatStatus(snapshotChat.Status), - }, - } - // Prepend so the frontend sees the current stream phases - // before any message_part events. - prefix := []codersdk.ChatStreamEvent{statusEvent} - if retryEvent != nil { - prefix = append(prefix, *retryEvent) - } - initialSnapshot = append(prefix, initialSnapshot...) - - // Track the highest durable message ID delivered to this subscriber, - // whether it came from the initial DB snapshot, the same-replica local - // stream, or a later DB/cache catch-up. - lastMessageID := afterMessageID - if len(messages) > 0 { - lastMessageID = messages[len(messages)-1].ID - } - - // When an enterprise SubscribeFn is provided, call it to get relay events - // (message_parts from remote replicas). OSS owns pubsub subscription, - // message catch-up, queue updates, and status forwarding; enterprise only - // manages relay dialing. 
- var relayEvents <-chan codersdk.ChatStreamEvent - var statusNotifications chan StatusNotification - if p.subscribeFn != nil { - statusNotifications = make(chan StatusNotification, 10) - relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{ - ChatID: chatID, - Chat: snapshotChat, - WorkerID: p.workerID, - StatusNotifications: statusNotifications, - RequestHeader: requestHeader, - DB: p.db, - Logger: p.logger, - }) - } - // hasPubsubSubscription is only true when we actually subscribed - // successfully above (allCancels will contain the pubsub - // cancel func in that case). - hasPubsubSubscription := len(allCancels) > 1 - - //nolint:nestif - go func() { - defer close(mergedEvents) - if statusNotifications != nil { - defer close(statusNotifications) - } - for { - select { - case <-mergedCtx.Done(): - return - case psErr := <-errCh: - p.logger.Error(mergedCtx, "chat stream pubsub error", - slog.F("chat_id", chatID), - slog.Error(psErr), - ) - select { - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatError{ - Message: psErr.Error(), - }, - }: - case <-mergedCtx.Done(): - } - return - case notify := <-notifyCh: - // Marker for ENG-2645: subscriber received pubsub notify. - p.logger.Debug(mergedCtx, "stream subscriber received notify", - slog.F("chat_id", chatID), - slog.F("after_message_id", notify.AfterMessageID), - slog.F("status", notify.Status), - slog.F("queue_update", notify.QueueUpdate), - slog.F("last_message_id", lastMessageID), - ) - if notify.AfterMessageID > 0 || notify.FullRefresh { - if notify.FullRefresh { - lastMessageID = 0 - clear(delivered) - } - var ( - deliveredCount int - source string - ) - // Notifies can arrive out of order. Rescan from - // min(AfterMessageID, lastMessageID) to cover the gap, - // floored at afterMessageID to respect the subscription - // boundary. The delivered set deduplicates. 
- lookupAfter := lastMessageID - if !notify.FullRefresh { - lookupAfter = max(afterMessageID, min(notify.AfterMessageID, lastMessageID)) - } - cached := p.getCachedDurableMessages(chatID, lookupAfter) - if !notify.FullRefresh && len(cached) > 0 { - for _, event := range cached { - if event.Message == nil { - continue - } - if _, ok := delivered[event.Message.ID]; ok { - continue - } - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- event: - } - delivered[event.Message.ID] = struct{}{} - if event.Message.ID > lastMessageID { - lastMessageID = event.Message.ID - } - deliveredCount++ - source = "cache" - } - } - // DB pass picks up cross-replica messages the local cache - // cannot have. Delivered set dedupes against the cache pass. - newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: lookupAfter, - }) - if msgErr != nil { - p.logger.Warn(mergedCtx, "failed to get chat messages after pubsub notification", - slog.F("chat_id", chatID), - slog.Error(msgErr), - ) - } else { - for _, msg := range newMessages { - if msg.ID <= lookupAfter { - continue - } - if _, ok := delivered[msg.ID]; ok { - continue - } - sdkMsg := db2sdk.ChatMessage(msg) - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, - ChatID: chatID, - Message: &sdkMsg, - }: - } - delivered[msg.ID] = struct{}{} - if msg.ID > lastMessageID { - lastMessageID = msg.ID - } - deliveredCount++ - switch source { - case "": - source = "db" - case "cache": - source = "cache+db" - } - } - } - // Marker for ENG-2645: subscriber delivered durable messages. 
- p.logger.Debug(mergedCtx, "stream subscriber delivered messages", - slog.F("chat_id", chatID), - slog.F("after_message_id", notify.AfterMessageID), - slog.F("lookup_after", lookupAfter), - slog.F("source", source), - slog.F("delivered_count", deliveredCount), - slog.F("last_message_id", lastMessageID), - ) - } - if notify.Status != "" { - status := database.ChatStatus(notify.Status) - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeStatus, - ChatID: chatID, - Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)}, - }: - } - // Notify enterprise relay manager if present. - if statusNotifications != nil { - workerID := uuid.Nil - if notify.WorkerID != "" { - if parsed, parseErr := uuid.Parse(notify.WorkerID); parseErr == nil { - workerID = parsed - } - } - select { - case statusNotifications <- StatusNotification{Status: status, WorkerID: workerID}: - case <-mergedCtx.Done(): - return - } - } - } - if notify.Retry != nil { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeRetry, - ChatID: chatID, - Retry: notify.Retry, - }: - } - } - if notify.ErrorPayload != nil { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: notify.ErrorPayload, - }: - } - } else if notify.Error != "" { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatError{ - Message: notify.Error, - }, - }: - } - } - if notify.QueueUpdate { - queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx) - queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID) - queueCancel() - if queueErr != nil { - p.logger.Warn(mergedCtx, "failed to get queued messages after 
pubsub notification", - slog.F("chat_id", chatID), - slog.Error(queueErr), - ) - } else { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeQueueUpdate, - ChatID: chatID, - QueuedMessages: db2sdk.ChatQueuedMessages(queuedMsgs), - }: - } - } - } - case event, ok := <-localParts: - if !ok { - localParts = nil - // Local parts channel closed. If pubsub is - // active we continue with pubsub-driven events. - // Otherwise terminate. - if !hasPubsubSubscription { - return - } - continue - } - if hasPubsubSubscription { - // Forward transient events from local. - // Durable events (messages, queue updates) - // come via pubsub + cache. Status is - // included alongside message_part because - // both travel through the same ordered - // channel: publishStatus is called before - // the first message_part, so FIFO delivery - // guarantees the frontend sees - // status=running before any content. - // Pubsub will deliver a duplicate status - // later; the frontend deduplicates it - // (setChatStatus is idempotent). - // action_required is also transient and - // only published on the local stream, so - // it must be forwarded here. - if event.Type == codersdk.ChatStreamEventTypeMessagePart || - event.Type == codersdk.ChatStreamEventTypeStatus || - event.Type == codersdk.ChatStreamEventTypeActionRequired { - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- event: - } - } - } else { - // No pubsub subscription: forward all event types. - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- event: - } - } - case event, ok := <-relayEvents: - if !ok { - relayEvents = nil - continue - } - select { - case <-mergedCtx.Done(): - return - case mergedEvents <- event: - } - } - } - }() - - return initialSnapshot, mergedEvents, cancel, true -} - // publishChatPubsubEvents broadcasts a lifecycle event for each affected chat. 
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) { for _, chat := range chats { @@ -4666,7 +3747,7 @@ func (p *Server) trackWorkspaceUsage( // so no prebuild guard is needed (unlike reporter.go). // // This fires every heartbeat (~30s) but the SQL only - // writes when 5% of the deadline has elapsed — most calls + // writes when 5% of the deadline has elapsed, most calls // perform a read-only CTE lookup with no UPDATE. // // Scaling note: for 10,000 active chats, this could lead to @@ -5843,7 +4924,7 @@ func (p *Server) fetchWorkspaceContext( // Stamp server-side fields and sanitize content. The // agent cannot know its own UUID, OS metadata, or - // directory — those are added here at the trust boundary. + // directory, those are added here at the trust boundary. agentID := uuid.NullUUID{UUID: loadedAgent.ID, Valid: true} for i := range agentParts { @@ -5973,7 +5054,7 @@ func (p *Server) persistInstructionFiles( // updateLastInjectedContext persists the injected context // parts (AGENTS.md files and skills) on the chat row so they // are directly queryable without scanning messages. This is -// best-effort — a failure here is logged but does not block +// best-effort, a failure here is logged but does not block // the turn. 
func (p *Server) updateLastInjectedContext(ctx context.Context, chatID uuid.UUID, parts []codersdk.ChatMessagePart) { param := pqtype.NullRawMessage{Valid: false} @@ -6442,6 +5523,9 @@ func (p *Server) Close() error { p.logger.Warn(context.Background(), "failed to close chat worker", slog.Error(err)) } } + if p.streamSyncPoller != nil { + p.streamSyncPoller.Close() + } if p.messagePartBuffer != nil { p.messagePartBuffer.Close() } diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 86ccc09a9f776..a87621adde14d 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -2125,54 +2125,6 @@ func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferent require.Equal(t, updatedChat, currentChat) } -func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - server := newSubscribeTestServer(t, db) - - chatID := uuid.New() - staleChat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - - state := server.getOrCreateStreamState(chatID) - state.mu.Lock() - state.buffer = []bufferedStreamPart{{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - ChatID: chatID, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("thinking"), - }, - }, - }} - state.mu.Unlock() - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{}, xerrors.New("refresh failed")), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - initialSnapshot, events, cancel, ok := server.SubscribeAuthorized(ctx, staleChat, nil, 0) - require.True(t, ok) - defer 
cancel() - - require.Len(t, initialSnapshot, 2) - require.Equal(t, codersdk.ChatStreamEventTypeStatus, initialSnapshot[0].Type) - require.NotNil(t, initialSnapshot[0].Status) - require.Equal(t, codersdk.ChatStatusPending, initialSnapshot[0].Status.Status) - require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, initialSnapshot[1].Type) - require.NotNil(t, initialSnapshot[1].MessagePart) - require.Equal(t, "thinking", initialSnapshot[1].MessagePart.Part.Text) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) { t.Parallel() @@ -2190,9 +2142,6 @@ func TestSubscribeRejectsUnauthorizedCallerBeforeSharedFetches(t *testing.T) { require.Nil(t, snapshot) require.Nil(t, events) require.Nil(t, cancel) - - _, exists := server.chatStreams.Load(chatID) - require.False(t, exists) } func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) { @@ -2217,9 +2166,6 @@ func TestSubscribeSurfacesTransientLookupFailureAsInitialError(t *testing.T) { _, open := <-events require.False(t, open) - - _, exists := server.chatStreams.Load(chatID) - require.False(t, exists) } func newSubscribeTestServer(t *testing.T, db database.Store) *Server { @@ -2232,19 +2178,6 @@ func newSubscribeTestServer(t *testing.T, db database.Store) *Server { } } -func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) { - t.Helper() - - select { - case event, ok := <-events: - if !ok { - t.Fatal("chat stream closed unexpectedly") - } - t.Fatalf("unexpected chat stream event: %+v", event) - case <-time.After(wait): - } -} - func TestResolveUserCompactionThreshold(t *testing.T) { t.Parallel() @@ -3083,266 +3016,6 @@ func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) { }}, got) } -// TestSubscribeCancelDuringGrace_ReapedBySweep verifies that a -// subscriber detach inside bufferRetainGracePeriod (the OSS trigger -// for the retained-buffer leak) leaves the 
state mapped, and the -// next sweep past the grace window reaps it. -func TestSubscribeCancelDuringGrace_ReapedBySweep(t *testing.T) { - t.Parallel() - - logger := slogtest.Make(t, nil) - mClock := quartz.NewMock(t) - - server := &Server{ - logger: logger, - clock: mClock, - } - - chatID := uuid.New() - start := mClock.Now() - - // Just-finished chat: processing done, buffer retained for - // late-connecting relay subscribers. - state := &chatStreamState{ - buffering: false, - bufferRetainedAt: start, - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - buffer: []bufferedStreamPart{{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: codersdk.ChatMessageRoleAssistant, - }, - }, - }}, - } - server.chatStreams.Store(chatID, state) - - // Real subscribeToStream cancel path: the WS subscriber detach - // that leaks in prod. - snapshot, currentRetry, events, cancelSub := server.subscribeToStream(chatID) - require.Len(t, snapshot, 1) - require.Nil(t, currentRetry) - require.NotNil(t, events) - - mClock.Advance(bufferRetainGracePeriod / 2) - cancelSub() - - _, ok := server.chatStreams.Load(chatID) - require.True(t, ok, - "entry should remain during grace window after subscriber detach") - - mClock.Advance(bufferRetainGracePeriod) - server.sweepIdleStreams() - - _, ok = server.chatStreams.Load(chatID) - require.False(t, ok, - "entry should be reaped after grace period expires and sweep runs") -} - -// TestSweepIdleStreams_ReapsStaleRetainedBuffer: grace expired, no -// subscribers, not buffering -> reaped. 
-func TestSweepIdleStreams_ReapsStaleRetainedBuffer(t *testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - - chatID := uuid.New() - state := &chatStreamState{ - buffering: false, - bufferRetainedAt: mClock.Now(), - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - buffer: []bufferedStreamPart{{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{}, - }, - }}, - } - server.chatStreams.Store(chatID, state) - - mClock.Advance(bufferRetainGracePeriod + time.Second) - server.sweepIdleStreams() - - _, ok := server.chatStreams.Load(chatID) - require.False(t, ok, "stale retained state should be reaped") -} - -// TestSweepIdleStreams_DoesNotReapActiveBuffering: buffering=true -// blocks reap even long after any grace would have expired. -func TestSweepIdleStreams_DoesNotReapActiveBuffering(t *testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - - chatID := uuid.New() - state := &chatStreamState{ - buffering: true, - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - buffer: []bufferedStreamPart{{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{}, - }, - }}, - } - server.chatStreams.Store(chatID, state) - - mClock.Advance(time.Hour) - server.sweepIdleStreams() - - _, ok := server.chatStreams.Load(chatID) - require.True(t, ok, "actively-buffering state must not be reaped") -} - -// TestSweepIdleStreams_DoesNotReapWithSubscribers: attached -// subscribers block reap even when grace has expired. 
-func TestSweepIdleStreams_DoesNotReapWithSubscribers(t *testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - - chatID := uuid.New() - state := &chatStreamState{ - buffering: false, - bufferRetainedAt: mClock.Now(), - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{ - uuid.New(): make(chan codersdk.ChatStreamEvent, 1), - }, - buffer: []bufferedStreamPart{{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{}, - }, - }}, - } - server.chatStreams.Store(chatID, state) - - mClock.Advance(bufferRetainGracePeriod + time.Second) - server.sweepIdleStreams() - - _, ok := server.chatStreams.Load(chatID) - require.True(t, ok, "state with subscribers must not be reaped") -} - -// TestSweepIdleStreams_DefersDuringGracePeriod: sweep inside grace -// is a no-op; the next sweep past grace reaps. -func TestSweepIdleStreams_DefersDuringGracePeriod(t *testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - - chatID := uuid.New() - start := mClock.Now() - state := &chatStreamState{ - buffering: false, - bufferRetainedAt: start, - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - buffer: []bufferedStreamPart{{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{}, - }, - }}, - } - server.chatStreams.Store(chatID, state) - - mClock.Advance(bufferRetainGracePeriod / 2) - server.sweepIdleStreams() - - _, ok := server.chatStreams.Load(chatID) - require.True(t, ok, "sweep inside grace window must not reap") - - mClock.Advance(bufferRetainGracePeriod) - server.sweepIdleStreams() - - _, ok = server.chatStreams.Load(chatID) - require.False(t, ok, "sweep after grace window must reap") -} - -func TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry(t 
*testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - - chatID := uuid.New() - - // Stale pointer: reapable (not buffering, no subscribers, grace - // expired) but no longer the map's live entry. - stale := &chatStreamState{ - buffering: false, - bufferRetainedAt: mClock.Now(), - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - } - - // Fresh entry: the state getOrCreateStreamState would install - // after a racing processChat run. Actively buffering, so not - // reapable. Only this state is in the map. - fresh := &chatStreamState{ - buffering: true, - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - } - server.chatStreams.Store(chatID, fresh) - - mClock.Advance(bufferRetainGracePeriod + time.Second) - - // Stale caller mirrors the janitor Range callback after the map - // entry has already been replaced. - stale.mu.Lock() - server.cleanupStreamIfIdle(chatID, stale) - stale.mu.Unlock() - - got, ok := server.chatStreams.Load(chatID) - require.True(t, ok, - "fresh entry must remain mapped when cleanup is called with a stale pointer") - require.Same(t, fresh, got, - "cleanup must not replace the fresh entry with the stale one") -} - -// TestSafeSweepIdleStreams_RecoversFromPanic verifies that an -// unexpected panic inside sweepIdleStreams is recovered rather than -// killing the janitor goroutine. Without this guard, a panic would -// silently reintroduce the very leak the janitor exists to prevent. -func TestSafeSweepIdleStreams_RecoversFromPanic(t *testing.T) { - t.Parallel() - - server := &Server{ - logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), - clock: quartz.NewMock(t), - } - - chatID := uuid.New() - // A nil *chatStreamState passes the type assertion in sweepIdleStreams - // but panics on state.mu.Lock with a nil-pointer deref. Any future - // panic source in the sweep would trigger the same recovery path. 
- var nilState *chatStreamState - server.chatStreams.Store(chatID, nilState) - - require.NotPanics(t, func() { - server.safeSweepIdleStreams(context.Background()) - }, "safeSweepIdleStreams must recover panics so the janitor loop keeps running") -} - func TestGetWorkspaceConn_StaleAgentRecovery(t *testing.T) { // Regression test: when a workspace is rebuilt, the chat's stored // agent ID points to a disconnected agent from the old build. The @@ -4422,111 +4095,6 @@ func TestGetWorkspaceConn_DialErrorNotMisclassifiedAsTimeout(t *testing.T) { require.ErrorContains(t, err, "authentication failed") } -// makeInProgressPart is a small constructor for buffered message_part -// fixtures used by snapshotBufferLocked / subscribeToStream tests. It -// builds an in-progress part (committedMessageID == 0) with a -// recognizable text body so failing assertions can identify which -// part survived the filter. -func makeInProgressPart(text string) bufferedStreamPart { - return bufferedStreamPart{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: codersdk.ChatMessageRoleAssistant, - Part: codersdk.ChatMessageText(text), - }, - }, - } -} - -// makeCommittedPart builds a part already claimed by the given -// durable assistant message ID. -func makeCommittedPart(committedID int64, text string) bufferedStreamPart { - p := makeInProgressPart(text) - p.committedMessageID = committedID - return p -} - -func partText(event codersdk.ChatStreamEvent) string { - if event.MessagePart == nil { - return "" - } - return event.MessagePart.Part.Text -} - -// TestSnapshotBufferLocked_DropsCommittedParts asserts the core -// dedup contract: parts that were claimed by a durable assistant -// message (committedMessageID != 0) are dropped from the snapshot -// because the subscriber will receive that durable message through -// the REST snapshot, the initial DB query, or pubsub. 
-func TestSnapshotBufferLocked_DropsCommittedParts(t *testing.T) { - t.Parallel() - - buffer := []bufferedStreamPart{ - makeCommittedPart(100, "turnA-1"), - makeCommittedPart(100, "turnA-2"), - makeCommittedPart(200, "turnB-1"), - makeInProgressPart("in-progress-1"), - makeInProgressPart("in-progress-2"), - } - - snapshot := snapshotBufferLocked(buffer) - - require.Len(t, snapshot, 2, - "only in-progress (committedMessageID == 0) parts should be kept") - require.Equal(t, "in-progress-1", partText(snapshot[0])) - require.Equal(t, "in-progress-2", partText(snapshot[1])) -} - -// TestSnapshotBufferLocked_AllInProgressReturnsAll covers the -// fresh-load convention: when no assistant message has committed -// yet, every buffered part is in-progress and must be delivered. -func TestSnapshotBufferLocked_AllInProgressReturnsAll(t *testing.T) { - t.Parallel() - - buffer := []bufferedStreamPart{ - makeInProgressPart("a"), - makeInProgressPart("b"), - makeInProgressPart("c"), - } - - snapshot := snapshotBufferLocked(buffer) - - require.Len(t, snapshot, 3, - "all in-progress parts must be delivered to the subscriber") - require.Equal(t, "a", partText(snapshot[0])) - require.Equal(t, "b", partText(snapshot[1])) - require.Equal(t, "c", partText(snapshot[2])) -} - -// TestSnapshotBufferLocked_EmptyBufferReturnsNil documents that -// snapshotBufferLocked returns nil (not an empty slice) for an -// empty buffer, matching the prior append-from-nil behavior. -func TestSnapshotBufferLocked_EmptyBufferReturnsNil(t *testing.T) { - t.Parallel() - - require.Nil(t, snapshotBufferLocked(nil)) - require.Nil(t, snapshotBufferLocked([]bufferedStreamPart{})) -} - -// TestSnapshotBufferLocked_AllCommittedReturnsEmpty covers the -// natural resting point after an assistant turn commits and before -// the next turn starts streaming: every buffered part has been -// claimed and must be filtered out. 
The snapshot must be empty so -// reconnecting subscribers do not re-render content that is already -// available as a durable message. -func TestSnapshotBufferLocked_AllCommittedReturnsEmpty(t *testing.T) { - t.Parallel() - - buffer := []bufferedStreamPart{ - makeCommittedPart(100, "a"), - makeCommittedPart(100, "b"), - makeCommittedPart(200, "c"), - } - - require.Empty(t, snapshotBufferLocked(buffer)) -} - func TestPrimeWorkspaceMCPCache_SuccessOnFirstAttempt(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index ea539b6440cd6..bc3b65d0f8d45 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -29,7 +29,6 @@ import ( "github.com/prometheus/client_golang/prometheus" io_prometheus_client "github.com/prometheus/client_model/go" "github.com/sqlc-dev/pqtype" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "golang.org/x/xerrors" @@ -3077,8 +3076,15 @@ func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) { // Passive server: status is always Pending. require.NotEmpty(t, snapshot) - require.Equal(t, codersdk.ChatStreamEventTypeStatus, snapshot[0].Type) - require.NotNil(t, snapshot[0].Status) + statusIdx := -1 + for i, event := range snapshot { + if event.Type == codersdk.ChatStreamEventTypeStatus { + statusIdx = i + break + } + } + require.NotEqual(t, -1, statusIdx) + require.NotNil(t, snapshot[statusIdx].Status) } func TestPersistToolResultWithBinaryData(t *testing.T) { @@ -11600,7 +11606,6 @@ func TestAdvisorGating_RootChat(t *testing.T) { // covers the glue from chatd wiring -> chatadvisor.Tool -> Runtime.Run -> // nested model call -> structured result back to the outer model. func TestAdvisorHappyPath_RootChat(t *testing.T) { - t.Skip("todo: re-enable this test after pr 4 from the chatd refactor is implemented. 
it depends on subscribe being implemented.") t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -11625,10 +11630,12 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) { switch streamedCallCount.Add(1) { case 1: // Parent turn 1: call advisor solo. - return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk( + chunk := chattest.OpenAIToolCallChunk( "advisor", `{"question":"how should I approach this refactor?"}`, - )) + ) + chunk.Choices[0].ToolCalls[0].ID = "advisor-happy-path-call" + return chattest.OpenAIStreamingResponse(chunk) case 2: // Nested advisor turn. The nested call has no tools because // chatadvisor.RunAdvisor runs with MaxSteps=1 and no tool @@ -11700,6 +11707,7 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) { if event.MessagePart.Role != codersdk.ChatMessageRoleTool || part.Type != codersdk.ChatMessagePartTypeToolResult || part.ToolName != chatadvisor.ToolName || + part.ToolCallID != "advisor-happy-path-call" || part.ResultDelta == "" { continue } @@ -11766,15 +11774,23 @@ func TestAdvisorHappyPath_RootChat(t *testing.T) { require.True(t, parentSawAdvisorResult, "parent must see the advisor reply in its continuation call") - require.EventuallyWithT(t, func(c *assert.CollectT) { + // Stop the live collector and assert it captured the streaming + // advisor deltas during processing. Late subscribers no longer + // see committed parts because publishMessage claims them out of + // new snapshots, so the assertion must use the live collector. 
+ require.Eventually(t, func() bool { livePartsMu.Lock() defer livePartsMu.Unlock() - assert.Equal(c, advisorDeltas, liveAdvisorDeltas, - "advisor nested text deltas must stream into the parent tool card") - }, testutil.WaitLong, testutil.IntervalFast) - + return slices.Equal(advisorDeltas, liveAdvisorDeltas) + }, testutil.WaitLong, testutil.IntervalFast, + "advisor nested text deltas must stream into the parent tool card") cancelLive() <-liveCollectorDone + livePartsMu.Lock() + collectedAdvisorDeltas := append([]string(nil), liveAdvisorDeltas...) + livePartsMu.Unlock() + require.Equal(t, advisorDeltas, collectedAdvisorDeltas, + "advisor nested text deltas must stream into the parent tool card") persisted, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ ChatID: chat.ID, diff --git a/coderd/x/chatd/options.go b/coderd/x/chatd/options.go index 016eb3a7ffb13..53bb6e075297c 100644 --- a/coderd/x/chatd/options.go +++ b/coderd/x/chatd/options.go @@ -7,9 +7,8 @@ import ( "time" "github.com/google/uuid" - "golang.org/x/xerrors" - "github.com/prometheus/client_golang/prometheus" + "golang.org/x/xerrors" "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/audit" diff --git a/coderd/x/chatd/stream_loop.go b/coderd/x/chatd/stream_loop.go new file mode 100644 index 0000000000000..8f3f7b77f7bfc --- /dev/null +++ b/coderd/x/chatd/stream_loop.go @@ -0,0 +1,457 @@ +package chatd + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" +) + +type streamLoop struct { + chatID uuid.UUID + db database.Store + logger slog.Logger + state streamLocalState +} + +type streamLocalState struct { + snapshotVersion int64 + historyVersion int64 + 
queueVersion int64 + retryVersion int64 + + knownMessages map[int64]int64 + + status database.ChatStatus + + errorHistoryVersion int64 + actionRequiredHistoryVersion int64 + + workerID uuid.NullUUID + generationAttempt int64 + lastPartSeq int64 + + afterMessageID int64 + initialMessageSyncDone bool +} + +type streamSyncHint struct { + snapshotVersion int64 + historyVersion int64 + queueVersion int64 + retryVersion int64 + status database.ChatStatus + workerID uuid.NullUUID + generationAttempt int64 +} + +type streamDBSnapshot struct { + chat database.Chat + + historyChanged bool + changedMessages []database.ChatMessage + historyReset bool + fullHistory []database.ChatMessage + + queueChanged bool + queue []database.ChatQueuedMessage + + actionRequired *codersdk.ChatStreamActionRequired +} + +func newStreamLoop(chat database.Chat, db database.Store, logger slog.Logger, afterMessageID int64) *streamLoop { + return &streamLoop{ + chatID: chat.ID, + db: db, + logger: logger, + state: streamLocalState{ + knownMessages: make(map[int64]int64), + afterMessageID: afterMessageID, + }, + } +} + +func streamSyncHintFromUpdate(update coderdpubsub.ChatStateUpdateMessage) streamSyncHint { + hint := streamSyncHint{ + snapshotVersion: update.SnapshotVersion, + historyVersion: update.HistoryVersion, + queueVersion: update.QueueVersion, + retryVersion: update.RetryStateVersion, + status: database.ChatStatus(update.Status), + generationAttempt: update.GenerationAttempt, + } + if update.WorkerID != nil { + hint.workerID = uuid.NullUUID{UUID: *update.WorkerID, Valid: true} + } + return hint +} + +func (l *streamLoop) sync(ctx context.Context, hint streamSyncHint) ([]codersdk.ChatStreamEvent, streamRelayTarget, bool, error) { + if !l.shouldFetch(hint) { + return nil, l.currentRelayTarget(), false, nil + } + return l.syncDB(ctx) +} + +func (l *streamLoop) syncDB(ctx context.Context) ([]codersdk.ChatStreamEvent, streamRelayTarget, bool, error) { + snapshot, err := l.loadDBSnapshot(ctx) + 
if err != nil { + return nil, l.currentRelayTarget(), false, err + } + if snapshot.chat.SnapshotVersion <= l.state.snapshotVersion { + return nil, l.currentRelayTarget(), false, nil + } + return l.applyDBSnapshot(snapshot), l.currentRelayTarget(), true, nil +} + +func (l *streamLoop) shouldFetch(hint streamSyncHint) bool { + if hint.snapshotVersion <= l.state.snapshotVersion { + return false + } + if hint.historyVersion > l.state.historyVersion { + return true + } + if hint.queueVersion > l.state.queueVersion { + return true + } + if hint.retryVersion > l.state.retryVersion { + return true + } + if hint.status != l.state.status { + return true + } + if !sameNullUUID(hint.workerID, l.state.workerID) { + return true + } + if hint.generationAttempt != l.state.generationAttempt { + return true + } + return false +} + +func (l *streamLoop) loadDBSnapshot(ctx context.Context) (streamDBSnapshot, error) { + var snapshot streamDBSnapshot + machine := chatstate.NewChatMachine(l.db, nil, l.chatID, chatstate.Options{}) + err := machine.ReadLock(ctx, func(tx database.Store) error { + chat, err := tx.GetChatByID(ctx, l.chatID) + if err != nil { + return xerrors.Errorf("get chat for stream: %w", err) + } + snapshot.chat = chat + + if chat.HistoryVersion > l.state.historyVersion { + snapshot.historyChanged = true + snapshot.changedMessages, err = tx.GetChatMessagesByRevisionForStream(ctx, database.GetChatMessagesByRevisionForStreamParams{ + ChatID: l.chatID, + AfterRevision: l.state.historyVersion, + }) + if err != nil { + return xerrors.Errorf("get changed chat messages: %w", err) + } + for _, msg := range snapshot.changedMessages { + if msg.Deleted { + snapshot.historyReset = true + break + } + } + if snapshot.historyReset { + snapshot.fullHistory, err = tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: l.chatID, + AfterID: 0, + }) + if err != nil { + return xerrors.Errorf("get full chat history: %w", err) + } + } + } + + if chat.QueueVersion > 
l.state.queueVersion {
+			snapshot.queueChanged = true
+			snapshot.queue, err = tx.GetChatQueuedMessages(ctx, l.chatID)
+			if err != nil {
+				return xerrors.Errorf("get chat queue: %w", err)
+			}
+		}
+
+		if chat.Status == database.ChatStatusRequiresAction {
+			history := snapshot.fullHistory
+			if len(history) == 0 {
+				history, err = tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
+					ChatID:  l.chatID,
+					AfterID: 0,
+				})
+				if err != nil {
+					return xerrors.Errorf("get requires_action history: %w", err)
+				}
+			}
+			actionRequired, err := l.actionRequiredFromHistory(chat, history)
+			if err != nil {
+				return err
+			}
+			snapshot.actionRequired = actionRequired
+		}
+		return nil
+	})
+	if err != nil {
+		return streamDBSnapshot{}, err
+	}
+	return snapshot, nil
+}
+
+// actionRequiredFromHistory builds the action_required payload from the
+// pending calls that unresolvedToolCallsFromHistory reports for the chat's
+// dynamic tools. The returned ToolCalls slice is non-nil even when empty.
+func (*streamLoop) actionRequiredFromHistory(chat database.Chat, messages []database.ChatMessage) (*codersdk.ChatStreamActionRequired, error) {
+	toolNames, err := parseDynamicToolNames(chat.DynamicTools)
+	if err != nil {
+		return nil, xerrors.Errorf("parse dynamic tools for stream: %w", err)
+	}
+	_, unresolved, err := unresolvedToolCallsFromHistory(messages, toolNames)
+	if err != nil {
+		return nil, xerrors.Errorf("derive pending dynamic tool calls: %w", err)
+	}
+	required := &codersdk.ChatStreamActionRequired{
+		ToolCalls: make([]codersdk.ChatStreamToolCall, 0, len(unresolved)),
+	}
+	for _, pendingCall := range unresolved {
+		required.ToolCalls = append(required.ToolCalls, codersdk.ChatStreamToolCall{
+			ToolCallID: pendingCall.ToolCallID,
+			ToolName:   pendingCall.ToolName,
+			Args:       pendingCall.Args,
+		})
+	}
+	return required, nil
+}
+
+// applyDBSnapshot compares snapshot with the local state, returns the events
+// to emit and then stores the snapshot's versions as the new local state.
+//
+// Events are appended in a fixed order: messages, queue update, status,
+// error, action required, retry, preview reset. The returned slice is never
+// nil.
+func (l *streamLoop) applyDBSnapshot(snapshot streamDBSnapshot) []codersdk.ChatStreamEvent {
+	chat := snapshot.chat
+	events := make([]codersdk.ChatStreamEvent, 0)
+	// Both flags are computed before any local state is overwritten below.
+	historyChanged := chat.HistoryVersion > l.state.historyVersion
+	generationChanged := chat.GenerationAttempt != l.state.generationAttempt
+
+	if historyChanged {
+		events = append(events, l.messageEvents(snapshot)...)
+	}
+	// messageEvents reads this flag, so it is set only after the call above.
+	l.state.initialMessageSyncDone = true
+
+	if chat.QueueVersion > l.state.queueVersion {
+		events = append(events, codersdk.ChatStreamEvent{
+			Type:           codersdk.ChatStreamEventTypeQueueUpdate,
+			ChatID:         l.chatID,
+			QueuedMessages: db2sdk.ChatQueuedMessages(snapshot.queue),
+		})
+	}
+
+	if chat.Status != l.state.status {
+		events = append(events, codersdk.ChatStreamEvent{
+			Type:   codersdk.ChatStreamEventTypeStatus,
+			ChatID: l.chatID,
+			Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(chat.Status)},
+		})
+	}
+
+	// An error event is emitted only when the history version is above the
+	// one recorded at the previous error event.
+	if chat.Status == database.ChatStatusError && chat.HistoryVersion > l.state.errorHistoryVersion {
+		events = append(events, codersdk.ChatStreamEvent{
+			Type:   codersdk.ChatStreamEventTypeError,
+			ChatID: l.chatID,
+			Error:  l.chatError(chat),
+		})
+		l.state.errorHistoryVersion = chat.HistoryVersion
+	}
+
+	// Same history-version gate as the error event. A nil payload is
+	// replaced by an empty one so the event always carries ActionRequired.
+	if chat.Status == database.ChatStatusRequiresAction && chat.HistoryVersion > l.state.actionRequiredHistoryVersion {
+		actionRequired := snapshot.actionRequired
+		if actionRequired == nil {
+			actionRequired = &codersdk.ChatStreamActionRequired{}
+		}
+		events = append(events, codersdk.ChatStreamEvent{
+			Type:           codersdk.ChatStreamEventTypeActionRequired,
+			ChatID:         l.chatID,
+			ActionRequired: actionRequired,
+		})
+		l.state.actionRequiredHistoryVersion = chat.HistoryVersion
+	}
+
+	if chat.RetryStateVersion > l.state.retryVersion {
+		if retry := l.retryEvent(chat); retry != nil {
+			events = append(events, *retry)
+		}
+	}
+
+	// A new history version or a new non-zero generation attempt starts a
+	// fresh part sequence, so part numbering restarts from zero.
+	if historyChanged || (generationChanged && chat.GenerationAttempt != 0) {
+		l.state.lastPartSeq = 0
+		events = append(events, codersdk.ChatStreamEvent{
+			Type:   codersdk.ChatStreamEventTypePreviewReset,
+			ChatID: l.chatID,
+		})
+	}
+
+	l.state.snapshotVersion = chat.SnapshotVersion
+	l.state.historyVersion = chat.HistoryVersion
+	l.state.queueVersion = chat.QueueVersion
+	l.state.retryVersion = chat.RetryStateVersion
+	l.state.status = chat.Status
+	l.state.workerID = chat.WorkerID
+	l.state.generationAttempt = chat.GenerationAttempt
+	return events
+}
+
+// messageEvents turns the snapshot's message rows into stream events and
+// updates knownMessages.
+//
+// On a history reset it emits a history_reset event followed by every message
+// in fullHistory, after clearing knownMessages. Otherwise it emits one event
+// per changed message whose revision is above the known one. During the
+// initial sync, messages with an ID at or below afterMessageID are recorded
+// but not emitted.
+func (l *streamLoop) messageEvents(snapshot streamDBSnapshot) []codersdk.ChatStreamEvent {
+	toEvent := func(msg database.ChatMessage) codersdk.ChatStreamEvent {
+		converted := db2sdk.ChatMessage(msg)
+		return codersdk.ChatStreamEvent{
+			Type:    codersdk.ChatStreamEventTypeMessage,
+			ChatID:  l.chatID,
+			Message: &converted,
+		}
+	}
+	known := l.state.knownMessages
+
+	if snapshot.historyReset {
+		out := []codersdk.ChatStreamEvent{{
+			Type:   codersdk.ChatStreamEventTypeHistoryReset,
+			ChatID: l.chatID,
+		}}
+		clear(known)
+		for _, msg := range snapshot.fullHistory {
+			known[msg.ID] = msg.Revision
+			out = append(out, toEvent(msg))
+		}
+		return out
+	}
+
+	out := make([]codersdk.ChatStreamEvent, 0, len(snapshot.changedMessages))
+	for _, msg := range snapshot.changedMessages {
+		if known[msg.ID] >= msg.Revision {
+			continue
+		}
+		known[msg.ID] = msg.Revision
+		skipInitial := !l.state.initialMessageSyncDone && msg.ID <= l.state.afterMessageID
+		if skipInitial {
+			continue
+		}
+		out = append(out, toEvent(msg))
+	}
+	return out
+}
+
+func (l *streamLoop) chatError(chat database.Chat) *codersdk.ChatError {
+	if !chat.LastError.Valid || len(chat.LastError.RawMessage) == 0 {
+		return &codersdk.ChatError{
+			Message: "The chat request failed unexpectedly.",
+			Kind:    codersdk.ChatErrorKindGeneric,
+		}
+	}
+	var payload codersdk.ChatError
+	if err := json.Unmarshal(chat.LastError.RawMessage, &payload); err != nil {
+		l.logger.Warn(context.Background(), "failed to parse chat stream last_error",
+			slog.F("chat_id", l.chatID),
+			slog.Error(err),
+		)
+		return &codersdk.ChatError{
+			Message: "The chat request failed unexpectedly.",
+			Kind:    codersdk.ChatErrorKindGeneric,
+		}
+	}
+	if payload.Message == "" {
+		payload.Message = "The chat request failed unexpectedly."
+ } + if payload.Kind == "" { + payload.Kind = codersdk.ChatErrorKindGeneric + } + return &payload +} + +func (l *streamLoop) retryEvent(chat database.Chat) *codersdk.ChatStreamEvent { + if !chat.RetryState.Valid || len(chat.RetryState.RawMessage) == 0 { + return nil + } + var retry codersdk.ChatStreamRetry + if err := json.Unmarshal(chat.RetryState.RawMessage, &retry); err != nil { + l.logger.Warn(context.Background(), "failed to parse chat stream retry_state", + slog.F("chat_id", l.chatID), + slog.Error(err), + ) + return nil + } + return &codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeRetry, + ChatID: l.chatID, + Retry: &retry, + } +} + +func (l *streamLoop) part(part streamPart) (event codersdk.ChatStreamEvent, accepted bool, err error) { + if part.HistoryVersion != l.state.historyVersion || part.GenerationAttempt != l.state.generationAttempt { + return codersdk.ChatStreamEvent{}, false, nil + } + if part.Seq <= l.state.lastPartSeq { + return codersdk.ChatStreamEvent{}, false, nil + } + if part.Seq != l.state.lastPartSeq+1 { + err := xerrors.Errorf( + "chat stream message part sequence gap: got %d after %d", + part.Seq, + l.state.lastPartSeq, + ) + l.logger.Error(context.Background(), "chat stream message part sequence gap", + slog.F("chat_id", l.chatID), + slog.F("history_version", part.HistoryVersion), + slog.F("generation_attempt", part.GenerationAttempt), + slog.F("last_seq", l.state.lastPartSeq), + slog.F("seq", part.Seq), + slog.Error(err), + ) + return codersdk.ChatStreamEvent{}, false, err + } + l.state.lastPartSeq = part.Seq + return codersdk.ChatStreamEvent{ + Type: codersdk.ChatStreamEventTypeMessagePart, + ChatID: l.chatID, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: part.Role, + Part: part.Part, + HistoryVersion: part.HistoryVersion, + GenerationAttempt: part.GenerationAttempt, + Seq: part.Seq, + }, + }, true, nil +} + +func (l *streamLoop) currentRelayTarget() streamRelayTarget { + return streamRelayTarget{ + workerID: 
l.state.workerID,
+		historyVersion:    l.state.historyVersion,
+		generationAttempt: l.state.generationAttempt,
+	}
+}
+
+// sameNullUUID reports whether a and b are both invalid, or both valid with
+// equal UUIDs. The UUID of an invalid value is ignored.
+func sameNullUUID(a, b uuid.NullUUID) bool {
+	switch {
+	case a.Valid != b.Valid:
+		return false
+	case !a.Valid:
+		return true
+	default:
+		return a.UUID == b.UUID
+	}
+}
+
+// cloneHeader returns a copy of header, or nil when header is nil.
+func cloneHeader(header http.Header) http.Header {
+	// http.Header.Clone already returns nil for a nil receiver.
+	return header.Clone()
+}
+
+// ctxDone returns ctx.Done(), or a nil channel when ctx is nil.
+func ctxDone(ctx context.Context) <-chan struct{} {
+	if ctx != nil {
+		return ctx.Done()
+	}
+	return nil
+}
diff --git a/coderd/x/chatd/stream_loop_internal_test.go b/coderd/x/chatd/stream_loop_internal_test.go
new file mode 100644
index 0000000000000..eebd6d0c9781f
--- /dev/null
+++ b/coderd/x/chatd/stream_loop_internal_test.go
@@ -0,0 +1,351 @@
+package chatd
+
+import (
+	"encoding/json"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/sqlc-dev/pqtype"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+
+	"cdr.dev/slog/v3/sloggers/slogtest"
+	"github.com/coder/coder/v2/coderd/database"
+	"github.com/coder/coder/v2/coderd/database/dbmock"
+	"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
+	"github.com/coder/coder/v2/codersdk"
+	"github.com/coder/coder/v2/testutil"
+)
+
+func TestStreamLoopSyncHintDecision(t *testing.T) {
+	t.Parallel()
+
+	workerA := uuid.New()
+	workerB := uuid.New()
+	loop := &streamLoop{
+		state: streamLocalState{
+			snapshotVersion:   5,
+			historyVersion:    2,
+			queueVersion:      3,
+			retryVersion:      4,
+			status:            database.ChatStatusRunning,
+			workerID:          uuid.NullUUID{UUID: workerA, Valid: true},
+			generationAttempt: 1,
+		},
+	}
+
+	for _, tt := range []struct {
+		name string
+		hint streamSyncHint
+		want bool
+	}{
+		{
+			name: "stale snapshot ignored even with higher history",
+			hint: streamSyncHint{snapshotVersion: 5, historyVersion: 3},
+		},
+		{
+			name: "duplicate snapshot ignored",
+			hint: streamSyncHint{snapshotVersion: 5},
+		},
+		{
+			name: "new snapshot with no changed fields is ignored",
+			hint:
streamSyncHint{snapshotVersion: 6, historyVersion: 2, queueVersion: 3, retryVersion: 4, status: database.ChatStatusRunning, workerID: uuid.NullUUID{UUID: workerA, Valid: true}, generationAttempt: 1}, + }, + { + name: "new history fetches", + hint: streamSyncHint{snapshotVersion: 6, historyVersion: 3}, + want: true, + }, + { + name: "new queue fetches", + hint: streamSyncHint{snapshotVersion: 6, historyVersion: 2, queueVersion: 4}, + want: true, + }, + { + name: "new retry fetches", + hint: streamSyncHint{snapshotVersion: 6, historyVersion: 2, queueVersion: 3, retryVersion: 5}, + want: true, + }, + { + name: "new status fetches", + hint: streamSyncHint{snapshotVersion: 6, historyVersion: 2, queueVersion: 3, retryVersion: 4, status: database.ChatStatusWaiting, workerID: uuid.NullUUID{UUID: workerA, Valid: true}, generationAttempt: 1}, + want: true, + }, + { + name: "new worker fetches", + hint: streamSyncHint{snapshotVersion: 6, historyVersion: 2, queueVersion: 3, retryVersion: 4, status: database.ChatStatusRunning, workerID: uuid.NullUUID{UUID: workerB, Valid: true}, generationAttempt: 1}, + want: true, + }, + { + name: "new generation attempt fetches", + hint: streamSyncHint{snapshotVersion: 6, historyVersion: 2, queueVersion: 3, retryVersion: 4, status: database.ChatStatusRunning, workerID: uuid.NullUUID{UUID: workerA, Valid: true}, generationAttempt: 2}, + want: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, loop.shouldFetch(tt.hint)) + }) + } +} + +func TestStreamLoopMessageSyncAfterIDAndEdits(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + loop := newStreamLoop(database.Chat{ID: chatID}, nil, slogtest.Make(t, nil), 1) + initial := streamDBSnapshot{ + chat: database.Chat{ + ID: chatID, + Status: database.ChatStatusRunning, + SnapshotVersion: 1, + HistoryVersion: 1, + }, + changedMessages: []database.ChatMessage{ + streamMessage(t, chatID, 1, 1, database.ChatMessageRoleUser, "already seen", false), + 
streamMessage(t, chatID, 2, 1, database.ChatMessageRoleAssistant, "new", false), + }, + } + + events := loop.applyDBSnapshot(initial) + requireEventTypes(t, events, + codersdk.ChatStreamEventTypeMessage, + codersdk.ChatStreamEventTypeStatus, + codersdk.ChatStreamEventTypePreviewReset, + ) + require.Equal(t, int64(2), events[0].Message.ID) + + edited := streamDBSnapshot{ + chat: database.Chat{ + ID: chatID, + Status: database.ChatStatusRunning, + SnapshotVersion: 2, + HistoryVersion: 2, + }, + changedMessages: []database.ChatMessage{ + streamMessage(t, chatID, 1, 2, database.ChatMessageRoleUser, "edited", false), + }, + } + events = loop.applyDBSnapshot(edited) + requireEventTypes(t, events, + codersdk.ChatStreamEventTypeMessage, + codersdk.ChatStreamEventTypePreviewReset, + ) + require.Equal(t, int64(1), events[0].Message.ID) +} + +func TestStreamLoopHistoryReset(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + loop := newStreamLoop(database.Chat{ID: chatID}, nil, slogtest.Make(t, nil), 0) + loop.state.snapshotVersion = 1 + loop.state.historyVersion = 1 + loop.state.status = database.ChatStatusRunning + loop.state.initialMessageSyncDone = true + loop.state.knownMessages[1] = 1 + loop.state.knownMessages[2] = 1 + + events := loop.applyDBSnapshot(streamDBSnapshot{ + chat: database.Chat{ + ID: chatID, + Status: database.ChatStatusRunning, + SnapshotVersion: 2, + HistoryVersion: 2, + }, + changedMessages: []database.ChatMessage{ + streamMessage(t, chatID, 1, 2, database.ChatMessageRoleUser, "deleted", true), + }, + historyReset: true, + fullHistory: []database.ChatMessage{ + streamMessage(t, chatID, 3, 2, database.ChatMessageRoleUser, "replacement", false), + }, + }) + + requireEventTypes(t, events, + codersdk.ChatStreamEventTypeHistoryReset, + codersdk.ChatStreamEventTypeMessage, + codersdk.ChatStreamEventTypePreviewReset, + ) + require.Equal(t, int64(3), events[1].Message.ID) + require.Equal(t, map[int64]int64{3: 2}, loop.state.knownMessages) +} + +func 
TestStreamLoopQueueStatusRetryErrorActionRequiredAndPreviewReset(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + retry := codersdk.ChatStreamRetry{Attempt: 2, DelayMs: 100, Error: "retrying", RetryingAt: time.Now()} + retryRaw, err := json.Marshal(retry) + require.NoError(t, err) + chatError := codersdk.ChatError{Message: "provider failed", Kind: codersdk.ChatErrorKindConfig} + errorRaw, err := json.Marshal(chatError) + require.NoError(t, err) + + loop := newStreamLoop(database.Chat{ID: chatID}, nil, slogtest.Make(t, nil), 0) + loop.state.snapshotVersion = 1 + loop.state.historyVersion = 1 + loop.state.queueVersion = 1 + loop.state.retryVersion = 1 + loop.state.generationAttempt = 1 + loop.state.status = database.ChatStatusRunning + + events := loop.applyDBSnapshot(streamDBSnapshot{ + chat: database.Chat{ + ID: chatID, + Status: database.ChatStatusError, + SnapshotVersion: 2, + HistoryVersion: 2, + QueueVersion: 2, + RetryStateVersion: 2, + GenerationAttempt: 2, + LastError: pqtype.NullRawMessage{RawMessage: errorRaw, Valid: true}, + RetryState: pqtype.NullRawMessage{RawMessage: retryRaw, Valid: true}, + }, + queue: []database.ChatQueuedMessage{}, + }) + + requireEventTypes(t, events, + codersdk.ChatStreamEventTypeQueueUpdate, + codersdk.ChatStreamEventTypeStatus, + codersdk.ChatStreamEventTypeError, + codersdk.ChatStreamEventTypeRetry, + codersdk.ChatStreamEventTypePreviewReset, + ) + require.Equal(t, chatError.Message, events[2].Error.Message) + require.Equal(t, retry.Attempt, events[3].Retry.Attempt) + + actionLoop := newStreamLoop(database.Chat{ID: chatID}, nil, slogtest.Make(t, nil), 0) + actionEvents := actionLoop.applyDBSnapshot(streamDBSnapshot{ + chat: database.Chat{ + ID: chatID, + Status: database.ChatStatusRequiresAction, + SnapshotVersion: 1, + HistoryVersion: 1, + }, + actionRequired: &codersdk.ChatStreamActionRequired{ToolCalls: []codersdk.ChatStreamToolCall{{ToolCallID: "call-1", ToolName: "browser"}}}, + }) + requireEventTypes(t, 
actionEvents, + codersdk.ChatStreamEventTypeStatus, + codersdk.ChatStreamEventTypeActionRequired, + codersdk.ChatStreamEventTypePreviewReset, + ) + require.Equal(t, "call-1", actionEvents[1].ActionRequired.ToolCalls[0].ToolCallID) +} + +func TestStreamLoopActionRequiredFromHistory(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + toolDefs, err := json.Marshal([]codersdk.DynamicTool{{Name: "browser"}}) + require.NoError(t, err) + assistant := streamMessageParts(t, chatID, 1, 1, database.ChatMessageRoleAssistant, []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: "call-1", + ToolName: "browser", + Args: json.RawMessage(`{"url":"https://example.com"}`), + }}, false) + loop := newStreamLoop(database.Chat{ID: chatID}, nil, slogtest.Make(t, nil), 0) + action, err := loop.actionRequiredFromHistory(database.Chat{ + ID: chatID, + DynamicTools: pqtype.NullRawMessage{RawMessage: toolDefs, Valid: true}, + }, []database.ChatMessage{assistant}) + require.NoError(t, err) + require.Len(t, action.ToolCalls, 1) + require.Equal(t, "call-1", action.ToolCalls[0].ToolCallID) + require.Equal(t, "browser", action.ToolCalls[0].ToolName) +} + +func TestStreamLoopPartValidation(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + loop := newStreamLoop(database.Chat{ID: chatID}, nil, slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), 0) + loop.state.historyVersion = 7 + loop.state.generationAttempt = 3 + + event, accepted, err := loop.part(StreamPart{HistoryVersion: 7, GenerationAttempt: 3, Seq: 1, Role: codersdk.ChatMessageRoleAssistant, Part: codersdk.ChatMessageText("a")}) + require.NoError(t, err) + require.True(t, accepted) + require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, event.Type) + require.Equal(t, int64(7), event.MessagePart.HistoryVersion) + require.Equal(t, int64(3), event.MessagePart.GenerationAttempt) + require.Equal(t, int64(1), event.MessagePart.Seq) + + _, accepted, err = loop.part(StreamPart{HistoryVersion: 
6, GenerationAttempt: 3, Seq: 2, Part: codersdk.ChatMessageText("old history")}) + require.NoError(t, err) + require.False(t, accepted) + _, accepted, err = loop.part(StreamPart{HistoryVersion: 7, GenerationAttempt: 2, Seq: 2, Part: codersdk.ChatMessageText("old attempt")}) + require.NoError(t, err) + require.False(t, accepted) + _, accepted, err = loop.part(StreamPart{HistoryVersion: 7, GenerationAttempt: 3, Seq: 1, Part: codersdk.ChatMessageText("dup")}) + require.NoError(t, err) + require.False(t, accepted) + _, accepted, err = loop.part(StreamPart{HistoryVersion: 7, GenerationAttempt: 3, Seq: 3, Part: codersdk.ChatMessageText("gap")}) + require.Error(t, err) + require.False(t, accepted) +} + +func TestStreamLoopInitialSyncRecoversWithoutHint(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + tx := dbmock.NewMockStore(ctrl) + chatID := uuid.New() + loop := newStreamLoop(database.Chat{ID: chatID}, db, slogtest.Make(t, nil), 0) + loop.state.snapshotVersion = 1 + loop.state.status = database.ChatStatusRunning + + db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn( + func(fn func(database.Store) error, _ *database.TxOptions) error { return fn(tx) }, + ) + tx.EXPECT().GetChatByIDForShare(gomock.Any(), chatID).Return(database.Chat{ + ID: chatID, + Status: database.ChatStatusWaiting, + SnapshotVersion: 2, + }, nil) + tx.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{ + ID: chatID, + Status: database.ChatStatusWaiting, + SnapshotVersion: 2, + }, nil) + + events, _, changed, err := loop.syncDB(ctx) + require.NoError(t, err) + require.True(t, changed) + requireEventTypes(t, events, codersdk.ChatStreamEventTypeStatus) + require.Equal(t, codersdk.ChatStatusWaiting, events[0].Status.Status) +} + +func requireEventTypes(t *testing.T, events []codersdk.ChatStreamEvent, types ...codersdk.ChatStreamEventType) { + t.Helper() + require.Len(t, events, len(types)) + for 
i, typ := range types { + require.Equal(t, typ, events[i].Type, "event %d", i) + } +} + +func streamMessage(t *testing.T, chatID uuid.UUID, id int64, revision int64, role database.ChatMessageRole, text string, deleted bool) database.ChatMessage { + t.Helper() + return streamMessageParts(t, chatID, id, revision, role, []codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}, deleted) +} + +func streamMessageParts(t *testing.T, chatID uuid.UUID, id int64, revision int64, role database.ChatMessageRole, parts []codersdk.ChatMessagePart, deleted bool) database.ChatMessage { + t.Helper() + content, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + return database.ChatMessage{ + ID: id, + ChatID: chatID, + CreatedAt: time.Unix(id, 0), + Role: role, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + Deleted: deleted, + Revision: revision, + } +} diff --git a/coderd/x/chatd/stream_parts.go b/coderd/x/chatd/stream_parts.go new file mode 100644 index 0000000000000..cd1879ae28007 --- /dev/null +++ b/coderd/x/chatd/stream_parts.go @@ -0,0 +1,190 @@ +package chatd + +import ( + "context" + "net/http" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/websocket" +) + +type streamPartsControl struct { + HistoryVersion int64 `json:"history_version"` + GenerationAttempt int64 `json:"generation_attempt"` +} + +type streamPartsEndpoint struct { + chatID uuid.UUID + buffer *messagepartbuffer.Buffer + logger slog.Logger +} + +// ServeStreamPartsAuthorized serves the internal episode-selected parts stream +// for an already authorized chat route. 
+func (p *Server) ServeStreamPartsAuthorized(rw http.ResponseWriter, r *http.Request, chat database.Chat) error {
+	if p == nil || p.messagePartBuffer == nil {
+		return xerrors.New("message part buffer is not configured")
+	}
+	scopedLogger := p.logger.Named("chat_stream_parts").With(slog.F("chat_id", chat.ID))
+	endpoint := streamPartsEndpoint{
+		chatID: chat.ID,
+		buffer: p.messagePartBuffer,
+		logger: scopedLogger,
+	}
+	return endpoint.serveWebSocket(rw, r)
+}
+
+// serveWebSocket accepts the WebSocket, starts the heartbeat goroutine and
+// runs serve over the connection. The transport is closed when it returns.
+func (e streamPartsEndpoint) serveWebSocket(rw http.ResponseWriter, r *http.Request) error {
+	conn, err := websocket.Accept(rw, r, nil)
+	if err != nil {
+		return xerrors.Errorf("accept parts websocket: %w", err)
+	}
+	transport := streamPartsWebSocketServerTransport{conn: conn}
+	defer func() {
+		_ = transport.Close()
+	}()
+
+	ctx, cancel := context.WithCancel(r.Context())
+	defer cancel()
+	go httpapi.HeartbeatClose(ctx, e.logger, cancel, conn)
+
+	return e.serve(ctx, transport)
+}
+
+// serve runs the parts stream over transport until the context ends or the
+// transport fails.
+//
+// Each control message read from the transport selects an episode by history
+// version and generation attempt. The previous subscription is dropped and a
+// new one is opened on the buffer. Parts from the selected episode are written
+// as message_part events and must arrive with consecutive sequence numbers
+// starting at 1, otherwise serve returns an error. Context cancellation and
+// expected transport closes return nil.
+func (e streamPartsEndpoint) serve(ctx context.Context, transport streamPartsServerTransport) error {
+	if e.buffer == nil {
+		return xerrors.New("message part buffer is not configured")
+	}
+	if transport == nil {
+		return xerrors.New("stream parts transport is not configured")
+	}
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// The reader goroutine forwards control messages and the first read
+	// error. Every send also selects on ctx so it cannot block after serve
+	// has returned.
+	controlCh := make(chan streamPartsControl, 1)
+	errCh := make(chan error, 1)
+	go func() {
+		for {
+			control, err := transport.ReadControl(ctx)
+			if err != nil {
+				select {
+				case errCh <- err:
+				case <-ctx.Done():
+				}
+				return
+			}
+			select {
+			case controlCh <- control:
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
+	var (
+		parts        <-chan messagepartbuffer.Part
+		partCancel   func()
+		partCancelFn context.CancelFunc
+		selected     streamPartsControl
+		lastSeq      int64
+	)
+	// stopEpisode releases the current subscription and then its context.
+	stopEpisode := func() {
+		if partCancel != nil {
+			partCancel()
+			partCancel = nil
+		}
+		if partCancelFn != nil {
+			partCancelFn()
+			partCancelFn = nil
+		}
+	}
+	defer stopEpisode()
+
+	selectEpisode := func(control streamPartsControl) error {
+		stopEpisode()
+		parts = nil
+		selected = control
+		lastSeq = 0
+		partCtx, cancelPartCtx := context.WithCancel(ctx)
+		ch, cancelSub, err := e.buffer.SubscribeToEpisode(partCtx, messagepartbuffer.Key{
+			ChatID:            e.chatID,
+			HistoryVersion:    control.HistoryVersion,
+			GenerationAttempt: control.GenerationAttempt,
+		})
+		if err != nil {
+			cancelPartCtx()
+			return err
+		}
+		partCancelFn, partCancel, parts = cancelPartCtx, cancelSub, ch
+		return nil
+	}
+
+	// benign reports whether err should end the stream without an error.
+	benign := func(err error) bool {
+		return ctx.Err() != nil || streamPartsExpectedTransportClose(err)
+	}
+
+	for {
+		select {
+		case <-ctx.Done():
+			return nil
+		case err := <-errCh:
+			if benign(err) {
+				return nil
+			}
+			return err
+		case control := <-controlCh:
+			if err := selectEpisode(control); err != nil {
+				return err
+			}
+		case part, ok := <-parts:
+			if !ok {
+				// A closed subscription disables this case until the next
+				// control message selects a new episode.
+				parts = nil
+				continue
+			}
+			if part.Seq != lastSeq+1 {
+				return xerrors.Errorf("message part sequence gap: got %d after %d", part.Seq, lastSeq)
+			}
+			lastSeq = part.Seq
+			event := codersdk.ChatStreamEvent{
+				Type:   codersdk.ChatStreamEventTypeMessagePart,
+				ChatID: e.chatID,
+				MessagePart: &codersdk.ChatStreamMessagePart{
+					Role:              part.Role,
+					Part:              part.MessagePart,
+					HistoryVersion:    selected.HistoryVersion,
+					GenerationAttempt: selected.GenerationAttempt,
+					Seq:               part.Seq,
+				},
+			}
+			if err := transport.WriteEvents(ctx, []codersdk.ChatStreamEvent{event}); err != nil {
+				if benign(err) {
+					return nil
+				}
+				return err
+			}
+		}
+	}
+}
+
+// StreamPartFromEvent extracts a StreamPart from a message_part event. It
+// returns false for any other event type or when the event has no MessagePart.
+func StreamPartFromEvent(event codersdk.ChatStreamEvent) (StreamPart, bool) {
+	payload := event.MessagePart
+	if payload == nil || event.Type != codersdk.ChatStreamEventTypeMessagePart {
+		return StreamPart{}, false
+	}
+	return StreamPart{
+		HistoryVersion:    payload.HistoryVersion,
+		GenerationAttempt: payload.GenerationAttempt,
+		Seq:               payload.Seq,
+		Role:              payload.Role,
+		Part:              payload.Part,
+	}, true
+}
diff --git a/coderd/x/chatd/stream_parts_dialer.go
b/coderd/x/chatd/stream_parts_dialer.go
new file mode 100644
index 0000000000000..eaada8220db62
--- /dev/null
+++ b/coderd/x/chatd/stream_parts_dialer.go
@@ -0,0 +1,60 @@
+package chatd
+
+import (
+	"context"
+
+	"github.com/google/uuid"
+	"golang.org/x/xerrors"
+
+	"cdr.dev/slog/v3"
+	"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
+)
+
+// LocalStreamPartsDialerConfig configures an in-process stream parts dialer.
+type LocalStreamPartsDialerConfig struct {
+	Buffer *messagepartbuffer.Buffer
+	Logger slog.Logger
+}
+
+// NewLocalStreamPartsDialer returns a dialer that streams message parts through
+// in-process channels while using the same stream serving loop as WebSockets.
+//
+// Each dial creates a channel transport pair and runs the endpoint's serve
+// loop on the server side in a goroutine. A nil Buffer is reported as an
+// error at dial time, not at construction.
+func NewLocalStreamPartsDialer(cfg LocalStreamPartsDialerConfig) StreamPartsDialer {
+	return func(ctx context.Context, input StreamPartsDialInput) (StreamPartsSession, error) {
+		if cfg.Buffer == nil {
+			return nil, xerrors.New("message part buffer is not configured")
+		}
+		serverSide, clientSide := newStreamPartsChannelTransportPair()
+		dialLogger := cfg.Logger.Named("chat_stream_parts").With(slog.F("chat_id", input.ChatID))
+		endpoint := streamPartsEndpoint{
+			chatID: input.ChatID,
+			buffer: cfg.Buffer,
+			logger: dialLogger,
+		}
+		serveCtx, stopServe := context.WithCancel(ctx)
+		go func() {
+			// Deferred calls run in reverse order: the server transport is
+			// closed first, then serveCtx is canceled.
+			defer stopServe()
+			defer func() {
+				_ = serverSide.Close()
+			}()
+			serveErr := endpoint.serve(serveCtx, serverSide)
+			if serveErr != nil && !streamPartsExpectedTransportClose(serveErr) {
+				dialLogger.Debug(serveCtx, "chat stream parts closed", slog.Error(serveErr))
+			}
+		}()
+		return newStreamPartsTransportSession(serveCtx, clientSide), nil
+	}
+}
+
+func streamPartsDialerForServer(workerID uuid.UUID, local StreamPartsDialer, remote StreamPartsDialer) StreamPartsDialer {
+	return func(ctx context.Context, input StreamPartsDialInput) (StreamPartsSession, error) {
+		if local == nil && remote == nil {
+			return nil, xerrors.New("stream parts dialer is not configured")
+		}
+		if remote == nil ||
input.WorkerID == uuid.Nil || input.WorkerID == workerID { + if local == nil { + return nil, xerrors.New("local stream parts dialer is not configured") + } + return local(ctx, input) + } + return remote(ctx, input) + } +} diff --git a/coderd/x/chatd/stream_parts_internal_test.go b/coderd/x/chatd/stream_parts_internal_test.go new file mode 100644 index 0000000000000..cc855daddd519 --- /dev/null +++ b/coderd/x/chatd/stream_parts_internal_test.go @@ -0,0 +1,354 @@ +package chatd + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +func TestStreamPartsEndpointReplayLiveAndReselect(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + chatID := uuid.New() + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + endpoint := streamPartsEndpoint{ + chatID: chatID, + buffer: buffer, + logger: slogtest.Make(t, nil), + } + serverTransport, clientTransport := newStreamPartsChannelTransportPair() + serveDone := serveStreamPartsEndpoint(ctx, t, endpoint, serverTransport) + defer func() { + require.NoError(t, clientTransport.Close()) + <-serveDone + }() + + firstKey := messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 1, GenerationAttempt: 1} + secondKey := messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 2, GenerationAttempt: 1} + require.NoError(t, buffer.CreateEpisode(firstKey)) + require.NoError(t, buffer.AddPart(firstKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("replayed"))) + require.NoError(t, buffer.CreateEpisode(secondKey)) + require.NoError(t, buffer.AddPart(secondKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("second"))) + + 
require.NoError(t, clientTransport.WriteControl(ctx, streamPartsControl{HistoryVersion: 1, GenerationAttempt: 1})) + got := readStreamPartsTransportBatch(ctx, t, clientTransport) + require.Len(t, got, 1) + require.Equal(t, "replayed", got[0].MessagePart.Part.Text) + require.Equal(t, int64(1), got[0].MessagePart.Seq) + require.Equal(t, int64(1), got[0].MessagePart.HistoryVersion) + require.Equal(t, int64(1), got[0].MessagePart.GenerationAttempt) + + require.NoError(t, buffer.AddPart(firstKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("live"))) + got = readStreamPartsTransportBatch(ctx, t, clientTransport) + require.Len(t, got, 1) + require.Equal(t, "live", got[0].MessagePart.Part.Text) + require.Equal(t, int64(2), got[0].MessagePart.Seq) + + require.NoError(t, clientTransport.WriteControl(ctx, streamPartsControl{HistoryVersion: 2, GenerationAttempt: 1})) + got = readStreamPartsTransportBatch(ctx, t, clientTransport) + require.Len(t, got, 1) + require.Equal(t, "second", got[0].MessagePart.Part.Text) + require.Equal(t, int64(2), got[0].MessagePart.HistoryVersion) + + require.NoError(t, buffer.AddPart(firstKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("ignored"))) + select { + case <-ctx.Done(): + t.Fatal("timed out waiting to verify previous episode was canceled") + default: + } + require.NoError(t, buffer.AddPart(secondKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("second-live"))) + got = readStreamPartsTransportBatch(ctx, t, clientTransport) + require.Equal(t, "second-live", got[0].MessagePart.Part.Text) +} + +func TestStreamPartsEndpointWaitsForMissingEpisode(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + chatID := uuid.New() + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + endpoint := streamPartsEndpoint{ + chatID: chatID, + buffer: buffer, + logger: slogtest.Make(t, nil), + } + serverTransport, clientTransport := newStreamPartsChannelTransportPair() + 
serveDone := serveStreamPartsEndpoint(ctx, t, endpoint, serverTransport) + defer func() { + require.NoError(t, clientTransport.Close()) + <-serveDone + }() + + key := messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 9, GenerationAttempt: 2} + require.NoError(t, clientTransport.WriteControl(ctx, streamPartsControl{HistoryVersion: key.HistoryVersion, GenerationAttempt: key.GenerationAttempt})) + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("eventual"))) + + got := readStreamPartsTransportBatch(ctx, t, clientTransport) + require.Len(t, got, 1) + require.Equal(t, "eventual", got[0].MessagePart.Part.Text) + require.Equal(t, int64(9), got[0].MessagePart.HistoryVersion) + require.Equal(t, int64(2), got[0].MessagePart.GenerationAttempt) +} + +func TestStreamPartsEndpointReselectsWhileEpisodeMissing(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + chatID := uuid.New() + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + endpoint := streamPartsEndpoint{ + chatID: chatID, + buffer: buffer, + logger: slogtest.Make(t, nil), + } + serverTransport, clientTransport := newStreamPartsChannelTransportPair() + serveDone := serveStreamPartsEndpoint(ctx, t, endpoint, serverTransport) + defer func() { + require.NoError(t, clientTransport.Close()) + <-serveDone + }() + + missingKey := messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 10, GenerationAttempt: 1} + selectedKey := messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 11, GenerationAttempt: 1} + require.NoError(t, clientTransport.WriteControl(ctx, streamPartsControl{HistoryVersion: missingKey.HistoryVersion, GenerationAttempt: missingKey.GenerationAttempt})) + require.NoError(t, clientTransport.WriteControl(ctx, streamPartsControl{HistoryVersion: selectedKey.HistoryVersion, GenerationAttempt: selectedKey.GenerationAttempt})) + require.NoError(t, 
buffer.CreateEpisode(selectedKey)) + require.NoError(t, buffer.AddPart(selectedKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("selected"))) + + got := readStreamPartsTransportBatch(ctx, t, clientTransport) + require.Len(t, got, 1) + require.Equal(t, "selected", got[0].MessagePart.Part.Text) + require.Equal(t, selectedKey.HistoryVersion, got[0].MessagePart.HistoryVersion) +} + +func TestStreamPartsEndpointClientDisconnectCancels(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + chatID := uuid.New() + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + endpoint := streamPartsEndpoint{ + chatID: chatID, + buffer: buffer, + logger: slogtest.Make(t, nil), + } + serverTransport, clientTransport := newStreamPartsChannelTransportPair() + serveDone := serveStreamPartsEndpoint(ctx, t, endpoint, serverTransport) + require.NoError(t, clientTransport.Close()) + + select { + case <-serveDone: + case <-ctx.Done(): + t.Fatal("stream parts endpoint did not exit after client disconnect") + } +} + +func TestStreamPartsEndpointWebSocket(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + chatID := uuid.New() + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + endpoint := streamPartsEndpoint{ + chatID: chatID, + buffer: buffer, + logger: slogtest.Make(t, nil), + } + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _ = endpoint.serveWebSocket(rw, r) + })) + t.Cleanup(server.Close) + + key := messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 1, GenerationAttempt: 1} + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("websocket"))) + + conn, resp, err := websocket.Dial(ctx, server.URL, nil) + require.NoError(t, err) + if resp != nil && resp.Body != nil { + require.NoError(t, resp.Body.Close()) + } + defer 
conn.Close(websocket.StatusNormalClosure, "") + + require.NoError(t, wsjson.Write(ctx, conn, streamPartsControl{HistoryVersion: 1, GenerationAttempt: 1})) + got := readStreamPartsWebSocketBatch(ctx, t, conn) + require.Len(t, got, 1) + require.Equal(t, "websocket", got[0].MessagePart.Part.Text) +} + +func TestStreamPartsWebSocketSession(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + chatID := uuid.New() + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + endpoint := streamPartsEndpoint{ + chatID: chatID, + buffer: buffer, + logger: slogtest.Make(t, nil), + } + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + _ = endpoint.serveWebSocket(rw, r) + })) + t.Cleanup(server.Close) + + key := messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 4, GenerationAttempt: 2} + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("session"))) + + conn, resp, err := websocket.Dial(ctx, server.URL, nil) + require.NoError(t, err) + if resp != nil && resp.Body != nil { + require.NoError(t, resp.Body.Close()) + } + session := NewStreamPartsJSONSession(ctx, conn) + defer session.Close() + + require.NoError(t, session.SelectEpisode(ctx, key.HistoryVersion, key.GenerationAttempt)) + part := readStreamPart(ctx, t, session.Parts()) + require.Equal(t, key.HistoryVersion, part.HistoryVersion) + require.Equal(t, key.GenerationAttempt, part.GenerationAttempt) + require.Equal(t, int64(1), part.Seq) + require.Equal(t, "session", part.Part.Text) +} + +func TestLocalStreamPartsDialerReplayLiveAndClose(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + chatID := uuid.New() + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + dialer := NewLocalStreamPartsDialer(LocalStreamPartsDialerConfig{ + Buffer: buffer, + Logger: slogtest.Make(t, nil), + }) + key := 
messagepartbuffer.Key{ChatID: chatID, HistoryVersion: 3, GenerationAttempt: 1} + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("replayed"))) + + session, err := dialer(ctx, StreamPartsDialInput{ChatID: chatID, WorkerID: uuid.New()}) + require.NoError(t, err) + require.NoError(t, session.SelectEpisode(ctx, key.HistoryVersion, key.GenerationAttempt)) + + part := readStreamPart(ctx, t, session.Parts()) + require.Equal(t, int64(1), part.Seq) + require.Equal(t, "replayed", part.Part.Text) + + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("live"))) + part = readStreamPart(ctx, t, session.Parts()) + require.Equal(t, int64(2), part.Seq) + require.Equal(t, "live", part.Part.Text) + + require.NoError(t, session.Close()) + select { + case _, ok := <-session.Parts(): + require.False(t, ok) + case <-ctx.Done(): + t.Fatal("stream parts session did not close") + } +} + +func TestStreamPartsDialerForServer(t *testing.T) { + t.Parallel() + + serverWorkerID := uuid.New() + remoteWorkerID := uuid.New() + + cases := []struct { + name string + remote bool + workerID uuid.UUID + want string + }{ + {name: "no remote uses local", workerID: remoteWorkerID, want: "local"}, + {name: "same worker uses local", remote: true, workerID: serverWorkerID, want: "local"}, + {name: "different worker uses remote", remote: true, workerID: remoteWorkerID, want: "remote"}, + {name: "nil worker uses local", remote: true, workerID: uuid.Nil, want: "local"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + called := make(chan string, 1) + local := func(context.Context, StreamPartsDialInput) (StreamPartsSession, error) { + called <- "local" + return nil, xerrors.New("local") + } + var remote StreamPartsDialer + if tc.remote { + remote = func(context.Context, 
StreamPartsDialInput) (StreamPartsSession, error) {
+					called <- "remote"
+					return nil, xerrors.New("remote")
+				}
+			}
+			dialer := streamPartsDialerForServer(serverWorkerID, local, remote)
+			_, _ = dialer(ctx, StreamPartsDialInput{WorkerID: tc.workerID})
+			require.Equal(t, tc.want, <-called)
+		})
+	}
+}
+
+func serveStreamPartsEndpoint(ctx context.Context, t *testing.T, endpoint streamPartsEndpoint, transport streamPartsServerTransport) <-chan struct{} {
+	t.Helper()
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		err := endpoint.serve(ctx, transport)
+		if err != nil && !streamPartsExpectedTransportClose(err) {
+			t.Errorf("stream parts endpoint exited with unexpected error: %v", err) // require.* calls FailNow, which is only valid on the test goroutine.
+		}
+	}()
+	return done
+}
+
+func readStreamPartsTransportBatch(ctx context.Context, t *testing.T, transport streamPartsClientTransport) []codersdk.ChatStreamEvent {
+	t.Helper()
+	got, err := transport.ReadEvents(ctx)
+	require.NoError(t, err)
+	assertStreamPartsBatch(t, got)
+	return got
+}
+
+func readStreamPartsWebSocketBatch(ctx context.Context, t *testing.T, conn *websocket.Conn) []codersdk.ChatStreamEvent {
+	t.Helper()
+	var got []codersdk.ChatStreamEvent
+	require.NoError(t, wsjson.Read(ctx, conn, &got))
+	assertStreamPartsBatch(t, got)
+	return got
+}
+
+func assertStreamPartsBatch(t *testing.T, got []codersdk.ChatStreamEvent) {
+	t.Helper()
+	for _, event := range got {
+		require.Equal(t, codersdk.ChatStreamEventTypeMessagePart, event.Type)
+		require.NotNil(t, event.MessagePart)
+	}
+}
+
+func readStreamPart(ctx context.Context, t *testing.T, parts <-chan StreamPart) StreamPart {
+	t.Helper()
+	select {
+	case part, ok := <-parts:
+		require.True(t, ok)
+		return part
+	case <-ctx.Done():
+		t.Fatal("timed out waiting for stream part")
+		return StreamPart{}
+	}
+}
diff --git a/coderd/x/chatd/stream_parts_transport.go b/coderd/x/chatd/stream_parts_transport.go
new file mode 100644
index 0000000000000..c43a57dc38f7e
--- /dev/null
+++ b/coderd/x/chatd/stream_parts_transport.go
@@ -0,0 +1,267 @@
+package chatd
+
+import ( + "context" + "errors" + "net" + "sync" + + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +var errStreamPartsTransportClosed = xerrors.New("stream parts transport closed") + +type streamPartsServerTransport interface { + ReadControl(context.Context) (streamPartsControl, error) + WriteEvents(context.Context, []codersdk.ChatStreamEvent) error + Close() error +} + +type streamPartsClientTransport interface { + WriteControl(context.Context, streamPartsControl) error + ReadEvents(context.Context) ([]codersdk.ChatStreamEvent, error) + Close() error +} + +type streamPartsWebSocketServerTransport struct { + conn *websocket.Conn +} + +func (t streamPartsWebSocketServerTransport) ReadControl(ctx context.Context) (streamPartsControl, error) { + var control streamPartsControl + if err := wsjson.Read(ctx, t.conn, &control); err != nil { + return streamPartsControl{}, err + } + return control, nil +} + +func (t streamPartsWebSocketServerTransport) WriteEvents(ctx context.Context, events []codersdk.ChatStreamEvent) error { + return wsjson.Write(ctx, t.conn, events) +} + +func (t streamPartsWebSocketServerTransport) Close() error { + return t.conn.Close(websocket.StatusNormalClosure, "") +} + +type streamPartsWebSocketClientTransport struct { + conn *websocket.Conn +} + +func (t streamPartsWebSocketClientTransport) WriteControl(ctx context.Context, control streamPartsControl) error { + return wsjson.Write(ctx, t.conn, control) +} + +func (t streamPartsWebSocketClientTransport) ReadEvents(ctx context.Context) ([]codersdk.ChatStreamEvent, error) { + var batch []codersdk.ChatStreamEvent + if err := wsjson.Read(ctx, t.conn, &batch); err != nil { + return nil, err + } + return batch, nil +} + +func (t streamPartsWebSocketClientTransport) Close() error { + return t.conn.Close(websocket.StatusNormalClosure, "") +} + +type streamPartsChannelPipe struct { + controlCh chan 
streamPartsControl + eventsCh chan []codersdk.ChatStreamEvent + done chan struct{} + closeOnce sync.Once +} + +type streamPartsChannelServerTransport struct { + pipe *streamPartsChannelPipe +} + +type streamPartsChannelClientTransport struct { + pipe *streamPartsChannelPipe +} + +func newStreamPartsChannelTransportPair() (streamPartsServerTransport, streamPartsClientTransport) { + pipe := &streamPartsChannelPipe{ + controlCh: make(chan streamPartsControl, 1), + eventsCh: make(chan []codersdk.ChatStreamEvent, 128), + done: make(chan struct{}), + } + return streamPartsChannelServerTransport{pipe: pipe}, streamPartsChannelClientTransport{pipe: pipe} +} + +func (t streamPartsChannelServerTransport) ReadControl(ctx context.Context) (streamPartsControl, error) { + select { + case <-ctxDone(ctx): + return streamPartsControl{}, ctx.Err() + case <-t.pipe.done: + return streamPartsControl{}, errStreamPartsTransportClosed + default: + } + select { + case control := <-t.pipe.controlCh: + return control, nil + case <-ctxDone(ctx): + return streamPartsControl{}, ctx.Err() + case <-t.pipe.done: + return streamPartsControl{}, errStreamPartsTransportClosed + } +} + +func (t streamPartsChannelServerTransport) WriteEvents(ctx context.Context, events []codersdk.ChatStreamEvent) error { + select { + case <-ctxDone(ctx): + return ctx.Err() + case <-t.pipe.done: + return errStreamPartsTransportClosed + default: + } + select { + case t.pipe.eventsCh <- events: + return nil + case <-ctxDone(ctx): + return ctx.Err() + case <-t.pipe.done: + return errStreamPartsTransportClosed + } +} + +func (t streamPartsChannelServerTransport) Close() error { + return t.pipe.close() +} + +func (t streamPartsChannelClientTransport) WriteControl(ctx context.Context, control streamPartsControl) error { + select { + case <-ctxDone(ctx): + return ctx.Err() + case <-t.pipe.done: + return errStreamPartsTransportClosed + default: + } + select { + case t.pipe.controlCh <- control: + return nil + case 
<-ctxDone(ctx): + return ctx.Err() + case <-t.pipe.done: + return errStreamPartsTransportClosed + } +} + +func (t streamPartsChannelClientTransport) ReadEvents(ctx context.Context) ([]codersdk.ChatStreamEvent, error) { + select { + case <-ctxDone(ctx): + return nil, ctx.Err() + case <-t.pipe.done: + return nil, errStreamPartsTransportClosed + default: + } + select { + case events := <-t.pipe.eventsCh: + return events, nil + case <-ctxDone(ctx): + return nil, ctx.Err() + case <-t.pipe.done: + return nil, errStreamPartsTransportClosed + } +} + +func (t streamPartsChannelClientTransport) Close() error { + return t.pipe.close() +} + +func (p *streamPartsChannelPipe) close() error { + p.closeOnce.Do(func() { + close(p.done) + }) + return nil +} + +type streamPartsTransportSession struct { + ctx context.Context + cancel context.CancelFunc + transport streamPartsClientTransport + parts chan StreamPart + closeOnce sync.Once + closeErr error +} + +func newStreamPartsTransportSession(ctx context.Context, transport streamPartsClientTransport) *streamPartsTransportSession { + sessionCtx, cancel := context.WithCancel(ctx) + session := &streamPartsTransportSession{ + ctx: sessionCtx, + cancel: cancel, + transport: transport, + parts: make(chan StreamPart, 128), + } + go session.readLoop() + return session +} + +func (s *streamPartsTransportSession) SelectEpisode(ctx context.Context, historyVersion, generationAttempt int64) error { + return s.transport.WriteControl(ctx, streamPartsControl{ + HistoryVersion: historyVersion, + GenerationAttempt: generationAttempt, + }) +} + +func (s *streamPartsTransportSession) Parts() <-chan StreamPart { + return s.parts +} + +func (s *streamPartsTransportSession) Close() error { + s.closeOnce.Do(func() { + s.cancel() + s.closeErr = s.transport.Close() + }) + return s.closeErr +} + +func (s *streamPartsTransportSession) readLoop() { + defer close(s.parts) + for { + batch, err := s.transport.ReadEvents(s.ctx) + if err != nil { + return + } + for 
_, event := range batch { + part, ok := StreamPartFromEvent(event) + if !ok { + continue + } + select { + case s.parts <- part: + case <-s.ctx.Done(): + return + } + } + } +} + +type StreamPartsJSONSession struct { + *streamPartsTransportSession +} + +func NewStreamPartsJSONSession(ctx context.Context, conn *websocket.Conn) *StreamPartsJSONSession { + return &StreamPartsJSONSession{ + streamPartsTransportSession: newStreamPartsTransportSession(ctx, streamPartsWebSocketClientTransport{conn: conn}), + } +} + +func streamPartsExpectedTransportClose(err error) bool { + if err == nil { + return true + } + if errors.Is(err, errStreamPartsTransportClosed) || + errors.Is(err, context.Canceled) || + errors.Is(err, net.ErrClosed) { + return true + } + switch websocket.CloseStatus(err) { + case websocket.StatusNormalClosure, websocket.StatusGoingAway: + return true + default: + return false + } +} diff --git a/coderd/x/chatd/stream_relay.go b/coderd/x/chatd/stream_relay.go new file mode 100644 index 0000000000000..da9a43104be92 --- /dev/null +++ b/coderd/x/chatd/stream_relay.go @@ -0,0 +1,246 @@ +package chatd + +import ( + "context" + "net/http" + "sync" + "time" + + "github.com/google/uuid" + + "cdr.dev/slog/v3" + "github.com/coder/quartz" +) + +const ( + streamRelayRetryInitialBackoff = 100 * time.Millisecond + streamRelayRetryMaxBackoff = 5 * time.Second +) + +type streamRelayForwarder struct { + chatID uuid.UUID + requestHeader http.Header + dialer StreamPartsDialer + clock quartz.Clock + logger slog.Logger + + parts chan StreamPart + + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + configure chan streamRelayTarget + closeOnce sync.Once +} + +func newStreamRelayForwarder( + chatID uuid.UUID, + requestHeader http.Header, + dialer StreamPartsDialer, + clock quartz.Clock, + logger slog.Logger, +) *streamRelayForwarder { + if clock == nil { + clock = quartz.NewReal() + } + ctx, cancel := context.WithCancel(context.Background()) + f := 
&streamRelayForwarder{ + chatID: chatID, + requestHeader: cloneHeader(requestHeader), + dialer: dialer, + clock: clock, + logger: logger, + parts: make(chan StreamPart, 128), + ctx: ctx, + cancel: cancel, + done: make(chan struct{}), + configure: make(chan streamRelayTarget, 1), + } + go f.loop() + return f +} + +func (f *streamRelayForwarder) Parts() <-chan StreamPart { + return f.parts +} + +func (f *streamRelayForwarder) Configure(ctx context.Context, target streamRelayTarget) { + if f == nil { + return + } + done := ctxDone(ctx) + select { + case f.configure <- target: + case <-f.ctx.Done(): + case <-done: + default: + select { + case <-f.configure: + default: + } + select { + case f.configure <- target: + case <-f.ctx.Done(): + case <-done: + } + } +} + +func (f *streamRelayForwarder) Close() { + if f == nil { + return + } + f.closeOnce.Do(func() { + f.cancel() + <-f.done + }) +} + +func (f *streamRelayForwarder) loop() { + defer close(f.done) + defer close(f.parts) + var ( + target streamRelayTarget + connected streamRelayTarget + session StreamPartsSession + sessionParts <-chan StreamPart + retryTimer *quartz.Timer + retryC <-chan time.Time + retryBackoff = streamRelayRetryInitialBackoff + ) + stopRetry := func() { + if retryTimer != nil { + retryTimer.Stop() + retryTimer = nil + retryC = nil + } + } + closeSession := func() { + if session != nil { + _ = session.Close() + } + session = nil + sessionParts = nil + connected = streamRelayTarget{} + } + scheduleRetry := func() { + if !target.needsRelay() || f.dialer == nil || retryTimer != nil { + return + } + retryTimer = f.clock.NewTimer(retryBackoff, "chatd", "stream-relay-retry") + retryC = retryTimer.C + if retryBackoff < streamRelayRetryMaxBackoff { + retryBackoff *= 2 + if retryBackoff > streamRelayRetryMaxBackoff { + retryBackoff = streamRelayRetryMaxBackoff + } + } + } + connect := func(ctx context.Context) { + stopRetry() + if !target.needsRelay() { + closeSession() + return + } + if f.dialer == nil { 
+ return + } + if session != nil && connected.workerID.Valid && sameNullUUID(connected.workerID, target.workerID) { + if err := session.SelectEpisode(ctx, target.historyVersion, target.generationAttempt); err != nil { + f.logger.Warn(ctx, "failed to select stream parts episode", + slog.F("chat_id", f.chatID), + slog.F("history_version", target.historyVersion), + slog.F("generation_attempt", target.generationAttempt), + slog.Error(err), + ) + closeSession() + scheduleRetry() + return + } + connected = target + retryBackoff = streamRelayRetryInitialBackoff + return + } + closeSession() + newSession, err := f.dialer(ctx, StreamPartsDialInput{ + ChatID: f.chatID, + WorkerID: target.workerID.UUID, + RequestHeader: cloneHeader(f.requestHeader), + }) + if err != nil { + f.logger.Warn(ctx, "failed to dial stream parts relay", + slog.F("chat_id", f.chatID), + slog.F("worker_id", target.workerID.UUID), + slog.Error(err), + ) + scheduleRetry() + return + } + session = newSession + sessionParts = newSession.Parts() + connected = streamRelayTarget{workerID: target.workerID} + if err := session.SelectEpisode(ctx, target.historyVersion, target.generationAttempt); err != nil { + f.logger.Warn(ctx, "failed to select stream parts episode", + slog.F("chat_id", f.chatID), + slog.F("history_version", target.historyVersion), + slog.F("generation_attempt", target.generationAttempt), + slog.Error(err), + ) + closeSession() + scheduleRetry() + return + } + connected = target + retryBackoff = streamRelayRetryInitialBackoff + } + + for { + select { + case <-f.ctx.Done(): + stopRetry() + closeSession() + return + case nextTarget := <-f.configure: + target = nextTarget + stopRetry() + if !target.needsRelay() { + closeSession() + continue + } + connect(f.ctx) + case <-retryC: + retryTimer = nil + retryC = nil + connect(f.ctx) + case part, ok := <-sessionParts: + if !ok { + closeSession() + scheduleRetry() + continue + } + if !connected.sameEpisode(target) || + part.HistoryVersion != 
target.historyVersion || + part.GenerationAttempt != target.generationAttempt { + continue + } + select { + case f.parts <- part: + case <-f.ctx.Done(): + stopRetry() + closeSession() + return + } + } + } +} + +func (t streamRelayTarget) needsRelay() bool { + return t.workerID.Valid && t.generationAttempt > 0 +} + +func (t streamRelayTarget) sameEpisode(other streamRelayTarget) bool { + return sameNullUUID(t.workerID, other.workerID) && + t.historyVersion == other.historyVersion && + t.generationAttempt == other.generationAttempt +} diff --git a/coderd/x/chatd/stream_subscribe.go b/coderd/x/chatd/stream_subscribe.go new file mode 100644 index 0000000000000..4be9ecb2de339 --- /dev/null +++ b/coderd/x/chatd/stream_subscribe.go @@ -0,0 +1,265 @@ +package chatd + +import ( + "context" + "net/http" + "time" + + "github.com/google/uuid" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/quartz" +) + +const ( + streamSyncRetryInitialBackoff = 100 * time.Millisecond + streamSyncRetryMaxBackoff = time.Second + streamSyncRetryMaxAttempts = 5 +) + +func (p *Server) subscribeStreamLoop( + ctx context.Context, + chat database.Chat, + requestHeader http.Header, + afterMessageID int64, +) ([]codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), bool) { + if p == nil || p.db == nil || p.pubsub == nil { + return nil, nil, nil, false + } + if p.messagePartBuffer == nil { + p.messagePartBuffer = messagepartbuffer.New(messagepartbuffer.Options{Clock: p.clock}) + } + chatID := chat.ID + streamCtx, streamCancel := context.WithCancel(ctx) + events := make(chan codersdk.ChatStreamEvent, 128) + logger := p.logger.With(slog.F("chat_id", chatID)) + + updateCh := make(chan streamSyncHint, 32) + pubsubCancel, err := 
p.pubsub.SubscribeWithErr( + coderdpubsub.ChatStateUpdateChannel(chatID), + coderdpubsub.HandleChatStateUpdate(func(_ context.Context, payload coderdpubsub.ChatStateUpdateMessage, err error) { + if err != nil { + logger.Warn(streamCtx, "chat stream pubsub error", slog.Error(err)) + return + } + select { + case updateCh <- streamSyncHintFromUpdate(payload): + case <-streamCtx.Done(): + } + }), + ) + if err != nil { + logger.Warn(ctx, "failed to subscribe to chat state updates", slog.Error(err)) + streamCancel() + return subscribeWithInitialError(chatID, "failed to subscribe to chat updates") + } + + pollerCh, unregisterPoller := p.registerStreamSyncPoller(chatID) + loop := newStreamLoop(chat, p.db, logger, afterMessageID) + //nolint:gocritic // The HTTP route authorizes the chat before subscribing; the stream loop needs chatd-scoped reads for one consistent snapshot. + initial, target, _, err := loop.syncDB(dbauthz.AsChatd(ctx)) + if err != nil { + logger.Error(ctx, "failed to load initial chat stream snapshot", slog.Error(err)) + unregisterPoller() + pubsubCancel() + streamCancel() + return subscribeWithInitialError(chatID, "failed to load initial snapshot") + } + + relay := newStreamRelayForwarder( + chatID, + requestHeader, + p.streamPartsDialer, + p.clock, + logger, + ) + relay.Configure(streamCtx, target) + + done := make(chan struct{}) + go func() { + defer close(done) + defer close(events) + defer relay.Close() + defer unregisterPoller() + for { + select { + case <-streamCtx.Done(): + return + case hint := <-updateCh: + if !p.runStreamSync(streamCtx, loop, relay, events, hint) { + return + } + case hint, ok := <-pollerCh: + if !ok { + return + } + if !p.runStreamSync(streamCtx, loop, relay, events, hint) { + return + } + case part, ok := <-relay.Parts(): + if !ok { + return + } + event, accepted, err := loop.part(part) + if err != nil { + logger.Error(streamCtx, "chat stream invariant violation", slog.Error(err)) + return + } + if accepted { + 
sendStreamEvent(streamCtx, events, event)
+				}
+			}
+		}
+	}()
+
+	cancel := func() {
+		streamCancel()
+		pubsubCancel()
+		<-done
+	}
+	return initial, events, cancel, true
+}
+
+func (p *Server) registerStreamSyncPoller(chatID uuid.UUID) (<-chan streamSyncHint, func()) {
+	if p.streamSyncPoller != nil {
+		return p.streamSyncPoller.Register(chatID)
+	}
+	// No poller configured: a nil channel is never ready, so the subscriber loop keeps running instead of seeing a closed channel and exiting.
+	var never chan streamSyncHint
+	return never, func() {}
+}
+
+func (p *Server) runStreamSync(
+	ctx context.Context,
+	loop *streamLoop,
+	relay *streamRelayForwarder,
+	events chan<- codersdk.ChatStreamEvent,
+	hint streamSyncHint,
+) bool {
+	syncEvents, target, changed, err := p.syncStreamWithRetry(ctx, loop, hint)
+	if err != nil {
+		p.logger.Error(ctx, "failed to sync chat stream after retries", slog.Error(err))
+		return false
+	}
+	for _, event := range syncEvents {
+		if !sendStreamEvent(ctx, events, event) {
+			return false
+		}
+	}
+	if changed {
+		relay.Configure(ctx, target)
+	}
+	return true
+}
+
+func (p *Server) syncStreamWithRetry(
+	ctx context.Context,
+	loop *streamLoop,
+	hint streamSyncHint,
+) ([]codersdk.ChatStreamEvent, streamRelayTarget, bool, error) {
+	var (
+		syncEvents []codersdk.ChatStreamEvent
+		target     streamRelayTarget
+		changed    bool
+		err        error
+	)
+	for attempt := 1; attempt <= streamSyncRetryMaxAttempts; attempt++ {
+		//nolint:gocritic // The subscriber was authorized before the loop started; follow-up syncs need chatd-scoped reads for consistency.
+ syncEvents, target, changed, err = loop.sync(dbauthz.AsChatd(ctx), hint) + if err == nil || ctx.Err() != nil { + return syncEvents, target, changed, err + } + p.logger.Warn(ctx, "failed to sync chat stream", + slog.F("attempt", attempt), + slog.Error(err), + ) + if attempt == streamSyncRetryMaxAttempts { + break + } + if !p.waitBeforeStreamSyncRetry(ctx, attempt) { + return nil, loop.currentRelayTarget(), false, ctx.Err() + } + } + return nil, loop.currentRelayTarget(), false, err +} + +func (p *Server) waitBeforeStreamSyncRetry(ctx context.Context, attempt int) bool { + clock := p.clock + if clock == nil { + clock = quartz.NewReal() + } + delay := streamSyncRetryInitialBackoff + for range attempt - 1 { + delay *= 2 + if delay >= streamSyncRetryMaxBackoff { + delay = streamSyncRetryMaxBackoff + break + } + } + timer := clock.NewTimer(delay, "chatd", "stream-sync-retry") + defer timer.Stop() + select { + case <-timer.C: + return true + case <-ctx.Done(): + return false + } +} + +func sendStreamEvent(ctx context.Context, ch chan<- codersdk.ChatStreamEvent, event codersdk.ChatStreamEvent) bool { + select { + case ch <- event: + return true + case <-ctx.Done(): + return false + } +} + +func (p *Server) Subscribe( + ctx context.Context, + chatID uuid.UUID, + requestHeader http.Header, + afterMessageID int64, +) ( + []codersdk.ChatStreamEvent, + <-chan codersdk.ChatStreamEvent, + func(), + bool, +) { + if p == nil { + return nil, nil, nil, false + } + + chat, err := p.db.GetChatByID(ctx, chatID) + if err != nil { + if dbauthz.IsNotAuthorizedError(err) { + return nil, nil, nil, false + } + p.logger.Warn(ctx, "failed to load chat for stream subscription", + slog.F("chat_id", chatID), + slog.Error(err), + ) + return subscribeWithInitialError(chatID, "failed to load initial snapshot") + } + return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID) +} + +// SubscribeAuthorized subscribes an already-authorized chat to stream updates. 
+func (p *Server) SubscribeAuthorized(
+	ctx context.Context,
+	chat database.Chat,
+	requestHeader http.Header,
+	afterMessageID int64,
+) (
+	[]codersdk.ChatStreamEvent,
+	<-chan codersdk.ChatStreamEvent,
+	func(),
+	bool,
+) {
+	return p.subscribeStreamLoop(ctx, chat, requestHeader, afterMessageID)
+}
diff --git a/coderd/x/chatd/stream_sync_poller.go b/coderd/x/chatd/stream_sync_poller.go
new file mode 100644
index 0000000000000..11e9171687e64
--- /dev/null
+++ b/coderd/x/chatd/stream_sync_poller.go
@@ -0,0 +1,200 @@
+package chatd
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/google/uuid"
+
+	"cdr.dev/slog/v3"
+	"github.com/coder/coder/v2/coderd/database"
+	"github.com/coder/coder/v2/coderd/database/dbauthz"
+	"github.com/coder/quartz"
+)
+
+// streamSyncInterval is the period between batch polls of subscribed chats.
+const streamSyncInterval = 10 * time.Second
+
+// streamSyncPoller periodically batch-queries sync rows for every chat that
+// has a registered subscriber and fans the resulting hints out to them.
+type streamSyncPoller struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+	db     database.Store
+	clock  quartz.Clock
+	logger slog.Logger
+
+	// mu guards subscribers and every send on or close of a subscriber's
+	// hints channel.
+	mu          sync.Mutex
+	subscribers map[uuid.UUID]map[*streamSyncPollerSubscriber]struct{}
+}
+
+// streamSyncPollerSubscriber is one registration for a chat's sync hints.
+type streamSyncPollerSubscriber struct {
+	chatID uuid.UUID
+	hints  chan streamSyncHint
+}
+
+// newStreamSyncPoller builds a poller; a nil clock falls back to the real
+// clock. Call Start to begin polling and Close to stop it.
+func newStreamSyncPoller(
+	ctx context.Context,
+	db database.Store,
+	clock quartz.Clock,
+	logger slog.Logger,
+) *streamSyncPoller {
+	if clock == nil {
+		clock = quartz.NewReal()
+	}
+	//nolint:gocritic // The poller is internal chatd infrastructure. Each
+	// registered stream was already authorized before subscription, and this
+	// batch query only fetches synchronization metadata for subscribed chats.
+	pollerCtx, cancel := context.WithCancel(dbauthz.AsChatd(ctx))
+	return &streamSyncPoller{
+		ctx:         pollerCtx,
+		cancel:      cancel,
+		db:          db,
+		clock:       clock,
+		logger:      logger,
+		subscribers: make(map[uuid.UUID]map[*streamSyncPollerSubscriber]struct{}),
+	}
+}
+
+// Start launches the polling goroutine. It is a no-op on a nil poller.
+func (p *streamSyncPoller) Start() {
+	if p == nil {
+		return
+	}
+	go p.loop()
+}
+
+// Close cancels the poller context, which stops the polling loop. It is a
+// no-op on a nil poller.
+func (p *streamSyncPoller) Close() {
+	if p == nil {
+		return
+	}
+	p.cancel()
+}
+
+// Register subscribes to sync hints for chatID. It returns the hint channel
+// and a function that removes the subscription and closes the channel; that
+// function is safe to call more than once. A nil poller returns an
+// already-closed channel and a no-op function.
+func (p *streamSyncPoller) Register(chatID uuid.UUID) (<-chan streamSyncHint, func()) {
+	if p == nil {
+		ch := make(chan streamSyncHint)
+		close(ch)
+		return ch, func() {}
+	}
+	subscriber := &streamSyncPollerSubscriber{
+		chatID: chatID,
+		hints:  make(chan streamSyncHint, 1),
+	}
+	p.mu.Lock()
+	if p.subscribers[chatID] == nil {
+		p.subscribers[chatID] = make(map[*streamSyncPollerSubscriber]struct{})
+	}
+	p.subscribers[chatID][subscriber] = struct{}{}
+	p.mu.Unlock()
+
+	return subscriber.hints, func() {
+		p.unregister(subscriber)
+	}
+}
+
+// unregister removes subscriber and closes its hints channel. Repeated calls
+// for the same subscriber are no-ops.
+func (p *streamSyncPoller) unregister(subscriber *streamSyncPollerSubscriber) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	chatSubscribers := p.subscribers[subscriber.chatID]
+	// Only a subscriber that is still registered owns an open channel.
+	// Checking membership (not just the per-chat map) keeps a second call
+	// from closing hints twice while other subscribers of the chat remain.
+	if _, registered := chatSubscribers[subscriber]; !registered {
+		return
+	}
+	delete(chatSubscribers, subscriber)
+	if len(chatSubscribers) == 0 {
+		delete(p.subscribers, subscriber.chatID)
+	}
+	close(subscriber.hints)
+}
+
+// loop polls on every tick until the poller context is canceled.
+func (p *streamSyncPoller) loop() {
+	ticker := p.clock.NewTicker(streamSyncInterval, "chatd", "stream-sync-poller")
+	defer ticker.Stop()
+	for {
+		select {
+		case <-p.ctx.Done():
+			return
+		case <-ticker.C:
+			p.pollOnce()
+		}
+	}
+}
+
+// pollOnce fetches sync rows for all subscribed chats and offers the derived
+// hint to each subscriber without blocking; a subscriber whose buffer is
+// full simply misses this hint.
+func (p *streamSyncPoller) pollOnce() {
+	chatIDs, _ := p.snapshotSubscribers()
+	if len(chatIDs) == 0 {
+		return
+	}
+	rows, err := p.db.GetChatStreamSyncRows(p.ctx, chatIDs)
+	if err != nil {
+		if p.ctx.Err() == nil {
+			p.logger.Warn(p.ctx, "failed to poll chat streams", slog.Error(err))
+		}
+		return
+	}
+	// Deliver under mu using the live subscriber set rather than a snapshot
+	// taken before the query: unregister closes hints under the same lock,
+	// so a subscriber found here cannot have a closed channel, and sending
+	// on a closed channel would panic. The sends never block.
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	for _, row := range rows {
+		hint := streamSyncHintFromPollRow(row)
+		for subscriber := range p.subscribers[row.ID] {
+			select {
+			case subscriber.hints <- hint:
+			default:
+			}
+		}
+	}
+}
+
+// snapshotSubscribers returns the subscribed chat IDs and a copy of each
+// chat's subscriber set as of the call.
+func (p *streamSyncPoller) snapshotSubscribers() ([]uuid.UUID, map[uuid.UUID][]*streamSyncPollerSubscriber) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	chatIDs := make([]uuid.UUID, 0, len(p.subscribers))
+	subscribers := make(map[uuid.UUID][]*streamSyncPollerSubscriber, len(p.subscribers))
+	for chatID, chatSubscribers := range p.subscribers {
+		chatIDs = append(chatIDs, chatID)
+		for subscriber := range chatSubscribers {
+			subscribers[chatID] = append(subscribers[chatID], subscriber)
+		}
+	}
+	return chatIDs, subscribers
+}
+
+// streamSyncHintFromPollRow copies a polled sync row into a streamSyncHint.
+func streamSyncHintFromPollRow(row database.GetChatStreamSyncRowsRow) streamSyncHint {
+	return streamSyncHint{
+		snapshotVersion:   row.SnapshotVersion,
+		historyVersion:    row.HistoryVersion,
+		queueVersion:      row.QueueVersion,
+		retryVersion:      row.RetryStateVersion,
+		status:            row.Status,
+		workerID:          row.WorkerID,
+		generationAttempt: row.GenerationAttempt,
+	}
+}
diff --git a/coderd/x/chatd/stream_types.go b/coderd/x/chatd/stream_types.go
new file mode 100644
index 0000000000000..d413079ff92f3
--- /dev/null
+++ b/coderd/x/chatd/stream_types.go
@@ -0,0 +1,44 @@
+package chatd
+
+import (
+	"context"
+	"net/http"
+
+	"github.com/google/uuid"
+
+	"github.com/coder/coder/v2/codersdk"
+)
+
+// StreamPartsDialer dials an episode-aware source of message parts.
+type StreamPartsDialer func(ctx context.Context, input StreamPartsDialInput) (StreamPartsSession, error)
+
+// StreamPartsDialInput carries the metadata needed to dial a parts source.
+type StreamPartsDialInput struct {
+	ChatID        uuid.UUID
+	WorkerID      uuid.UUID
+	RequestHeader http.Header
+}
+
+// StreamPartsSession streams message parts for selected episodes.
+type StreamPartsSession interface {
+	SelectEpisode(ctx context.Context, historyVersion, generationAttempt int64) error
+	Parts() <-chan StreamPart
+	Close() error
+}
+
+// StreamPart is a live preview part scoped to one chat history episode.
+type StreamPart struct { + HistoryVersion int64 + GenerationAttempt int64 + Seq int64 + Role codersdk.ChatMessageRole + Part codersdk.ChatMessagePart +} + +type streamPart = StreamPart + +type streamRelayTarget struct { + workerID uuid.NullUUID + historyVersion int64 + generationAttempt int64 +} diff --git a/coderd/x/chatd/streamcollector_internal_test.go b/coderd/x/chatd/streamcollector_internal_test.go deleted file mode 100644 index 089ad26290759..0000000000000 --- a/coderd/x/chatd/streamcollector_internal_test.go +++ /dev/null @@ -1,167 +0,0 @@ -package chatd - -import ( - "sync" - "testing" - "time" - - "github.com/google/uuid" - "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/testutil" -) - -// TestStreamStateCollector exercises the four gauges emitted by -// streamStateCollector against representative map states. -func TestStreamStateCollector(t *testing.T) { - t.Parallel() - - t.Run("EmptyMap", func(t *testing.T) { - t.Parallel() - - reg := prometheus.NewRegistry() - server := &Server{} - reg.MustRegister(&streamStateCollector{server: server}) - - assertGauges(t, reg, gaugeExpectations{ - active: 0, - bufferMax: 0, - bufferTotal: 0, - subscribers: 0, - }) - }) - - t.Run("PopulatedMap", func(t *testing.T) { - t.Parallel() - - reg := prometheus.NewRegistry() - server := &Server{} - - server.chatStreams.Store(uuid.New(), &chatStreamState{ - buffer: make([]bufferedStreamPart, 10), - subscribers: newSubscribers(t, 2), - }) - server.chatStreams.Store(uuid.New(), &chatStreamState{ - buffer: make([]bufferedStreamPart, 25), - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - }) - server.chatStreams.Store(uuid.New(), &chatStreamState{ - buffer: nil, - subscribers: newSubscribers(t, 1), - }) - - reg.MustRegister(&streamStateCollector{server: server}) - - assertGauges(t, reg, gaugeExpectations{ - active: 3, - 
bufferMax: 25, - bufferTotal: 35, - subscribers: 3, - }) - }) - - t.Run("SkipsWrongType", func(t *testing.T) { - t.Parallel() - - reg := prometheus.NewRegistry() - server := &Server{} - - server.chatStreams.Store(uuid.New(), "garbage") - server.chatStreams.Store(uuid.New(), &chatStreamState{ - buffer: make([]bufferedStreamPart, 5), - subscribers: newSubscribers(t, 1), - }) - - reg.MustRegister(&streamStateCollector{server: server}) - - // The non-matching entry is silently skipped. Only the - // valid chatStreamState counts. - assertGauges(t, reg, gaugeExpectations{ - active: 1, - bufferMax: 5, - bufferTotal: 5, - subscribers: 1, - }) - }) - - // Runs Collect concurrently with state.mu mutations; catches - // missing lock acquisition under `go test -race`. - t.Run("LockContentionSmoke", func(t *testing.T) { - t.Parallel() - - server := &Server{} - state := &chatStreamState{ - buffer: make([]bufferedStreamPart, 0, 100), - subscribers: newSubscribers(t, 1), - } - server.chatStreams.Store(uuid.New(), state) - collector := &streamStateCollector{server: server} - - const iterations = 100 - var wg sync.WaitGroup - - // Mutator: grows and shrinks the buffer under state.mu. - wg.Go(func() { - for range iterations { - state.mu.Lock() - state.buffer = append(state.buffer, bufferedStreamPart{}) - if len(state.buffer) > 50 { - state.buffer = state.buffer[10:] - } - state.mu.Unlock() - } - }) - - // Scraper: repeatedly invokes Collect into a discard - // channel. A panic or race here fails the test. - wg.Go(func() { - ctx := testutil.Context(t, 10*time.Second) - for range iterations { - ch := make(chan prometheus.Metric, 4) - collector.Collect(ch) - // Drain all metrics the collector wrote. 
- for range 4 { - testutil.SoftTryReceive(ctx, t, ch) - } - } - }) - - wg.Wait() - }) -} - -type gaugeExpectations struct { - active float64 - bufferMax float64 - bufferTotal float64 - subscribers float64 -} - -func assertGauges(t *testing.T, reg *prometheus.Registry, want gaugeExpectations) { - t.Helper() - families, err := reg.Gather() - require.NoError(t, err) - - got := map[string]float64{} - for _, f := range families { - require.Len(t, f.GetMetric(), 1, "metric %q should have exactly one sample", f.GetName()) - got[f.GetName()] = f.GetMetric()[0].GetGauge().GetValue() - } - - assert.Equal(t, want.active, got["coderd_chatd_streams_active"], "streams_active") - assert.Equal(t, want.bufferMax, got["coderd_chatd_stream_buffer_size_max"], "buffer_size_max") - assert.Equal(t, want.bufferTotal, got["coderd_chatd_stream_buffer_events"], "buffer_events") - assert.Equal(t, want.subscribers, got["coderd_chatd_stream_subscribers"], "subscribers") -} - -func newSubscribers(t *testing.T, n int) map[uuid.UUID]chan codersdk.ChatStreamEvent { - t.Helper() - subs := make(map[uuid.UUID]chan codersdk.ChatStreamEvent, n) - for range n { - subs[uuid.New()] = make(chan codersdk.ChatStreamEvent, 1) - } - return subs -} diff --git a/coderd/x/chatd/tasks.go b/coderd/x/chatd/tasks.go index 7563bd5dec946..f93f9e22c3d80 100644 --- a/coderd/x/chatd/tasks.go +++ b/coderd/x/chatd/tasks.go @@ -8,10 +8,10 @@ import ( "strings" "time" - "cdr.dev/slog/v3" "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" diff --git a/codersdk/chats.go b/codersdk/chats.go index b493e83f39fc8..519960ed97d1b 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -1500,6 +1500,8 @@ const ( ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update" ChatStreamEventTypeRetry ChatStreamEventType = "retry" ChatStreamEventTypeActionRequired 
ChatStreamEventType = "action_required" + ChatStreamEventTypePreviewReset ChatStreamEventType = "preview_reset" + ChatStreamEventTypeHistoryReset ChatStreamEventType = "history_reset" ) // ChatQueuedMessage represents a queued message waiting to be processed. @@ -1513,8 +1515,11 @@ type ChatQueuedMessage struct { // ChatStreamMessagePart is a streamed message part update. type ChatStreamMessagePart struct { - Role ChatMessageRole `json:"role,omitempty"` - Part ChatMessagePart `json:"part"` + Role ChatMessageRole `json:"role,omitempty"` + Part ChatMessagePart `json:"part"` + HistoryVersion int64 `json:"history_version,omitempty"` + GenerationAttempt int64 `json:"generation_attempt,omitempty"` + Seq int64 `json:"seq,omitempty"` } // ChatStreamStatus represents an updated chat status. diff --git a/codersdk/chats_test.go b/codersdk/chats_test.go index f169590050791..4fa9f3ec3b91a 100644 --- a/codersdk/chats_test.go +++ b/codersdk/chats_test.go @@ -166,6 +166,42 @@ func TestChatErrorKind_JSONRoundTrip(t *testing.T) { require.Equal(t, codersdk.ChatErrorKindUsageLimit, decodedRetry.Kind) } +func TestChatStreamEvent_JSONRoundTripIncludesResetTypesAndPartMetadata(t *testing.T) { + t.Parallel() + + chatID := uuid.New() + events := []codersdk.ChatStreamEvent{ + {Type: codersdk.ChatStreamEventTypePreviewReset, ChatID: chatID}, + {Type: codersdk.ChatStreamEventTypeHistoryReset, ChatID: chatID}, + { + Type: codersdk.ChatStreamEventTypeMessagePart, + ChatID: chatID, + MessagePart: &codersdk.ChatStreamMessagePart{ + Role: codersdk.ChatMessageRoleAssistant, + Part: codersdk.ChatMessageText("partial"), + HistoryVersion: 12, + GenerationAttempt: 3, + Seq: 4, + }, + }, + } + data, err := json.Marshal(events) + require.NoError(t, err) + require.Contains(t, string(data), `"type":"preview_reset"`) + require.Contains(t, string(data), `"type":"history_reset"`) + require.Contains(t, string(data), `"history_version":12`) + require.Contains(t, string(data), `"generation_attempt":3`) + 
require.Contains(t, string(data), `"seq":4`) + + var decoded []codersdk.ChatStreamEvent + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Equal(t, codersdk.ChatStreamEventTypePreviewReset, decoded[0].Type) + require.Equal(t, codersdk.ChatStreamEventTypeHistoryReset, decoded[1].Type) + require.Equal(t, int64(12), decoded[2].MessagePart.HistoryVersion) + require.Equal(t, int64(3), decoded[2].MessagePart.GenerationAttempt) + require.Equal(t, int64(4), decoded[2].MessagePart.Seq) +} + func TestChatMessagePart_StripInternal(t *testing.T) { t.Parallel() diff --git a/docs/admin/integrations/prometheus.md b/docs/admin/integrations/prometheus.md index 479c670bfd9ec..92fbc1d812d8a 100644 --- a/docs/admin/integrations/prometheus.md +++ b/docs/admin/integrations/prometheus.md @@ -212,11 +212,7 @@ deployment. They will always be available from the agent. | `coderd_chatd_prompt_size_bytes` | histogram | Estimated byte size of the prompt per LLM request. | `model` `provider` | | `coderd_chatd_steps_total` | counter | Total agentic loop steps across all chats. | `model` `provider` | | `coderd_chatd_stream_buffer_dropped_total` | counter | Number of chat stream buffer events dropped due to the per-chat buffer cap. | | -| `coderd_chatd_stream_buffer_events` | gauge | Sum of current buffer lengths across all chat streams. | | -| `coderd_chatd_stream_buffer_size_max` | gauge | Maximum current buffer length across all chat streams. | | | `coderd_chatd_stream_retries_total` | counter | Total LLM stream retries. | `chain_broken` `kind` `model` `provider` | -| `coderd_chatd_stream_subscribers` | gauge | Current number of chat stream subscribers across all chat streams. | | -| `coderd_chatd_streams_active` | gauge | Current number of chat stream state entries (in-flight plus retained). | | | `coderd_chatd_tool_errors_total` | counter | Total tool calls that returned an error result. 
| `model` `provider` `tool_name` | | `coderd_chatd_tool_result_size_bytes` | histogram | Size in bytes of each tool execution result. | `model` `provider` `tool_name` | | `coderd_chatd_ttft_seconds` | histogram | Time-to-first-token: wall time from LLM request to first streamed chunk. | `model` `provider` | diff --git a/docs/reference/api/chats.md b/docs/reference/api/chats.md index 02c43e6ed6abe..86ff25b839d80 100644 --- a/docs/reference/api/chats.md +++ b/docs/reference/api/chats.md @@ -2708,6 +2708,8 @@ Experimental: this endpoint is subject to change. } }, "message_part": { + "generation_attempt": 0, + "history_version": 0, "part": { "args": [ 0 @@ -2770,7 +2772,8 @@ Experimental: this endpoint is subject to change. "type": "text", "url": "string" }, - "role": "system" + "role": "system", + "seq": 0 }, "queued_messages": [ { diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 56535c866a3b1..4c23074eb4cf6 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -3652,6 +3652,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in } }, "message_part": { + "generation_attempt": 0, + "history_version": 0, "part": { "args": [ 0 @@ -3714,7 +3716,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "type": "text", "url": "string" }, - "role": "system" + "role": "system", + "seq": 0 }, "queued_messages": [ { @@ -3828,14 +3831,16 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in #### Enumerated Values -| Value(s) | -|------------------------------------------------------------------------------------------| -| `action_required`, `error`, `message`, `message_part`, `queue_update`, `retry`, `status` | +| Value(s) | +|----------------------------------------------------------------------------------------------------------------------------| +| `action_required`, `error`, `history_reset`, `message`, `message_part`, 
`preview_reset`, `queue_update`, `retry`, `status` | ## codersdk.ChatStreamMessagePart ```json { + "generation_attempt": 0, + "history_version": 0, "part": { "args": [ 0 @@ -3898,16 +3903,20 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "type": "text", "url": "string" }, - "role": "system" + "role": "system", + "seq": 0 } ``` ### Properties -| Name | Type | Required | Restrictions | Description | -|--------|------------------------------------------------------|----------|--------------|-------------| -| `part` | [codersdk.ChatMessagePart](#codersdkchatmessagepart) | false | | | -| `role` | [codersdk.ChatMessageRole](#codersdkchatmessagerole) | false | | | +| Name | Type | Required | Restrictions | Description | +|----------------------|------------------------------------------------------|----------|--------------|-------------| +| `generation_attempt` | integer | false | | | +| `history_version` | integer | false | | | +| `part` | [codersdk.ChatMessagePart](#codersdkchatmessagepart) | false | | | +| `role` | [codersdk.ChatMessageRole](#codersdkchatmessagerole) | false | | | +| `seq` | integer | false | | | ## codersdk.ChatStreamRetry diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 2df327f674aed..12fd7c5f99f85 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -159,10 +159,14 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } var replicaManagerPtr atomic.Pointer[replicasync.Manager] + var api *API resolveReplicaAddress := func( _ context.Context, replicaID uuid.UUID, ) (string, bool) { + if api != nil && api.AGPL != nil && replicaID == api.AGPL.ID && api.AGPL.AccessURL != nil { + return api.AGPL.AccessURL.String(), true + } manager := replicaManagerPtr.Load() if manager == nil { return "", false @@ -180,7 +184,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { return "", false } - api := &API{ + api = &API{ ctx: ctx, 
cancel: cancelFunc, Options: options, @@ -207,17 +211,13 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { replicaHTTPClient = http.DefaultClient } // Use a closure that captures api by reference so it can access - // api.AGPL.ID after coderd.New is called. The SubscribeFn is - // only invoked from Subscribe, which happens after init. - options.Options.ChatSubscribeFn = entchatd.NewMultiReplicaSubscribeFn(entchatd.MultiReplicaSubscribeConfig{ + // api.AGPL.ID after coderd.New is called. The parts dialer is + // only invoked from stream subscriptions, which happen after init. + options.Options.ChatStreamPartsDialer = entchatd.NewStreamPartsDialer(entchatd.StreamPartsDialerConfig{ ResolveReplicaAddress: resolveReplicaAddress, ReplicaHTTPClient: replicaHTTPClient, ReplicaIDFn: func() uuid.UUID { - id := api.AGPL.ID - if id == uuid.Nil { - return uuid.New() - } - return id + return api.AGPL.ID }, }) diff --git a/enterprise/coderd/exp_chats_test.go b/enterprise/coderd/exp_chats_test.go index 83ede0531f86a..d29240dd2ef4a 100644 --- a/enterprise/coderd/exp_chats_test.go +++ b/enterprise/coderd/exp_chats_test.go @@ -67,7 +67,6 @@ func createOpenAIModelConfigForTest( func TestChatStreamRelay(t *testing.T) { t.Parallel() - t.Skip("chatd refactor: remove in PR 4") t.Run("RelayMessagePartsAcrossReplicas", func(t *testing.T) { t.Parallel() diff --git a/enterprise/coderd/x/chatd/chatd.go b/enterprise/coderd/x/chatd/chatd.go index 8301e1d191286..d3b2d91fa398f 100644 --- a/enterprise/coderd/x/chatd/chatd.go +++ b/enterprise/coderd/x/chatd/chatd.go @@ -2,25 +2,16 @@ package chatd import ( "context" - "errors" - "fmt" "net/http" "net/url" - "strconv" "strings" - "time" "github.com/google/uuid" "golang.org/x/xerrors" - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/coderd/database" osschatd "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/codersdk" - "github.com/coder/quartz" - "github.com/coder/retry" "github.com/coder/websocket" - 
"github.com/coder/websocket/wsjson" ) // RelaySourceHeader marks replica-relayed stream requests. @@ -29,28 +20,10 @@ const RelaySourceHeader = "X-Coder-Relay-Source-Replica" const ( authorizationHeader = "Authorization" cookieHeader = "Cookie" - - // relayDrainTimeout is how long an established relay is - // kept open after the chat leaves running state, giving - // buffered snapshot events time to be forwarded before - // the relay is torn down. - relayDrainTimeout = 200 * time.Millisecond - - // Retry knobs for the cross-replica relay handshake. Uses the - // github.com/coder/retry defaults (φ-growth, no jitter) but drives - // the delay manually because retry.Retrier.Wait uses time.After, - // which isn't compatible with quartz.Clock determinism in tests. - relayRetryFloor = 500 * time.Millisecond // first retry matches old fixed delay - relayRetryCeil = 15 * time.Second // cap stall before tear-down - // After this many reconnect retries the relay leg is torn down. - // Total dial attempts = 1 initial dial + relayMaxRetries. - relayMaxRetries = 6 ) // RelayDialError wraps a failed relay handshake. HTTPStatus is 0 -// when the failure happened before a response (DNS, TCP, TLS, -// timeout, context cancel); otherwise it carries the peer's status -// code for the reconnect loop to classify. +// when the failure happened before a response. type RelayDialError struct { HTTPStatus int Err error @@ -60,661 +33,59 @@ func (e *RelayDialError) Error() string { return e.Err.Error() } func (e *RelayDialError) Unwrap() error { return e.Err } // IsUnrecoverable reports whether retrying with the same captured -// session token is futile. Only 401/403 qualify - the token is dead -// or the peer won't authorize it. 5xx, 429, network, and context -// errors fall through to backoff. +// session token is futile. 
func (e *RelayDialError) IsUnrecoverable() bool { return e.HTTPStatus == http.StatusUnauthorized || e.HTTPStatus == http.StatusForbidden } -// MultiReplicaSubscribeConfig holds the dependencies for multi-replica chat -// subscription. ReplicaIDFn is called lazily because the -// replica ID may not be known at construction time. -// -// DialerFn, when set, overrides the default WebSocket relay -// dialer. This is used in tests to inject mock relay behavior -// without requiring real HTTP servers. -type MultiReplicaSubscribeConfig struct { +// StreamPartsDialerConfig holds dependencies for multi-replica stream parts. +type StreamPartsDialerConfig struct { ResolveReplicaAddress func(context.Context, uuid.UUID) (string, bool) ReplicaHTTPClient *http.Client ReplicaIDFn func() uuid.UUID - DialerFn func( - ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, - ) ( - snapshot []codersdk.ChatStreamEvent, - parts <-chan codersdk.ChatStreamEvent, - cancel func(), - err error, - ) - // Clock is used for creating timers. In production use - // quartz.NewReal(); in tests use quartz.NewMock(t) to - // control reconnect timing deterministically. - Clock quartz.Clock -} - -// dial returns the configured dialer, preferring DialerFn (tests) -// over the real dialRelay. Returns nil when relay is not configured. 
-func (c MultiReplicaSubscribeConfig) dial() func( - ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, -) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, -) { - if c.DialerFn != nil { - return c.DialerFn - } - if c.ResolveReplicaAddress == nil { - return nil - } - return func( - ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, - ) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, - ) { - return dialRelay(ctx, chatID, workerID, requestHeader, c, c.clock()) - } + DialerFn func(context.Context, osschatd.StreamPartsDialInput) (osschatd.StreamPartsSession, error) } -// clock returns the quartz.Clock to use. Defaults to a real clock -// when not set. -func (c MultiReplicaSubscribeConfig) clock() quartz.Clock { - if c.Clock != nil { - return c.Clock - } - return quartz.NewReal() -} - -// NewMultiReplicaSubscribeFn returns a SubscribeFn that manages -// relay connections to remote replicas and returns relay -// message_part events only. OSS handles pubsub subscription, -// message catch-up, queue updates, status forwarding, and local -// parts merging. -// -//nolint:gocognit // Complexity is inherent to the multi-source merge loop. -func NewMultiReplicaSubscribeFn( - cfg MultiReplicaSubscribeConfig, -) osschatd.SubscribeFn { - return func(ctx context.Context, params osschatd.SubscribeFnParams) <-chan codersdk.ChatStreamEvent { - chatID := params.ChatID - requestHeader := params.RequestHeader - logger := params.Logger - - var relayCancel func() - var relayParts <-chan codersdk.ChatStreamEvent - - // If the chat is currently running on a different worker - // and we have a remote parts provider, open an initial - // relay synchronously so the caller gets in-flight - // message_part events right away. 
- var initialRelaySnapshot []codersdk.ChatStreamEvent - if params.Chat.Status == database.ChatStatusRunning && - params.Chat.WorkerID.Valid && - params.Chat.WorkerID.UUID != params.WorkerID && - cfg.dial() != nil { - snapshot, parts, cancel, err := cfg.dial()(ctx, chatID, params.Chat.WorkerID.UUID, requestHeader) - if err == nil { - relayCancel = cancel - relayParts = parts - // Collect relay message_parts to forward at the - // start of the merge goroutine. - for _, event := range snapshot { - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - initialRelaySnapshot = append(initialRelaySnapshot, event) - } - } - } else { - logger.Warn(ctx, "failed to open initial relay for chat stream", - slog.F("chat_id", chatID), - slog.Error(err), - ) - } - } - - // Merge all event sources. - mergedEvents := make(chan codersdk.ChatStreamEvent, 128) - // Channel for async relay establishment. - type relayResult struct { - parts <-chan codersdk.ChatStreamEvent - cancel func() - workerID uuid.UUID // the worker this dial targeted - // err and parts are mutually exclusive: success sets - // parts; failure sets err (unwrap to *RelayDialError - // for classification). - err error - } - relayReadyCh := make(chan relayResult, 4) - - // Reset on successful dial or when the relay target - // changes, so a fresh target starts at the floor delay. - retryState := newRelayRetryState() - // Per-dial context so in-flight dials can be canceled when - // a new dial is initiated or the relay is closed. - var dialCancel context.CancelFunc - - // expectedWorkerID tracks which replica we expect the next - // relay result to target. Stale results are discarded. - var expectedWorkerID uuid.UUID - - // Reconnect timer state. - var reconnectTimer *quartz.Timer - var reconnectCh <-chan time.Time - - // drainAndClose is set when the chat transitions away - // from running while a relay dial is still in progress. 
- // Instead of canceling the dial immediately, we let it - // complete so the snapshot of buffered message_parts - // can be forwarded to the subscriber. - var drainAndClose bool - - // Drain timer state. When the relay connects in - // drain-and-close mode, a short timer is started. - // During this window the normal relayPartsCh case - // forwards buffered snapshot events. When the timer - // fires the relay is torn down. - var drainTimer *quartz.Timer - var drainTimerCh <-chan time.Time - - // Helper to close relay and stop any pending reconnect - // timer. - closeRelay := func() { - // Cancel any in-flight dial goroutine first. - if dialCancel != nil { - dialCancel() - dialCancel = nil - } - // Drain all buffered relay results from canceled dials. - for { - select { - case result := <-relayReadyCh: - if result.cancel != nil { - result.cancel() - } - default: - goto drained - } - } - drained: - expectedWorkerID = uuid.Nil - if relayCancel != nil { - relayCancel() - relayCancel = nil - } - relayParts = nil - if reconnectTimer != nil { - reconnectTimer.Stop() - reconnectTimer = nil - reconnectCh = nil - } - if drainTimer != nil { - drainTimer.Stop() - drainTimer = nil - drainTimerCh = nil - } - drainAndClose = false - } - - // openRelayAsync dials the remote replica in a background - // goroutine and delivers the result on relayReadyCh so the - // main select loop is never blocked by network I/O. - openRelayAsync := func(workerID uuid.UUID) { - if cfg.dial() == nil { - return - } - // Scoped here (not in closeRelay) so repeated dials - // against the same worker keep the attempt counter and - // correctly trip the cap. - if workerID != expectedWorkerID { - retryState.reset() - } - closeRelay() - // Create a per-dial context so this goroutine is - // canceled if closeRelay() or openRelayAsync() is - // called again before the dial completes. 
- var dialCtx context.Context - dialCtx, dialCancel = context.WithCancel(ctx) - expectedWorkerID = workerID - go func() { - snapshot, parts, cancel, err := cfg.dial()(dialCtx, chatID, workerID, requestHeader) - if err != nil { - // Don't log context-canceled errors - // since they are expected when a dial is - // superseded by a newer one. - if dialCtx.Err() == nil { - fields := []slog.Field{ - slog.F("chat_id", chatID), - slog.F("worker_id", workerID), - slog.Error(err), - } - // Surface the peer's HTTP status (when we - // got one) as a structured field so - // operators can filter 401/403 spam - // separately from 5xx/network warnings. - var dialErr *RelayDialError - if errors.As(err, &dialErr) && dialErr.HTTPStatus != 0 { - fields = append(fields, slog.F("http_status", dialErr.HTTPStatus)) - } - logger.Warn(ctx, "failed to open relay for message parts", fields...) - } - // Hand the error to the merge loop, which will - // classify it and either back off or tear down. - select { - case relayReadyCh <- relayResult{workerID: workerID, err: err}: - case <-dialCtx.Done(): - } - return - } - // Discard stale dials so we don't start a - // wrappedParts goroutine on a canceled connection. - if dialCtx.Err() != nil { - cancel() - return - } - // Wrap the relay channel so snapshot parts - // are delivered through the same channel as - // live parts. This goroutine only forwards - // events - it does not own the relay - // lifecycle. When dialCtx is canceled it - // simply returns, closing wrappedParts via - // its defer. The cancel() is called by - // whoever canceled dialCtx (closeRelay or - // the send-fallback select below). 
- wrappedParts := make(chan codersdk.ChatStreamEvent, 128) - go func() { - defer close(wrappedParts) - for _, event := range snapshot { - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - select { - case wrappedParts <- event: - case <-dialCtx.Done(): - return - } - } - } - for { - select { - case event, ok := <-parts: - if !ok { - return - } - select { - case wrappedParts <- event: - case <-dialCtx.Done(): - return - } - case <-dialCtx.Done(): - return - } - } - }() - select { - case relayReadyCh <- relayResult{parts: wrappedParts, cancel: cancel, workerID: workerID}: - case <-dialCtx.Done(): - cancel() - } - }() +// NewStreamPartsDialer returns a dialer for the owning replica's parts endpoint. +func NewStreamPartsDialer(cfg StreamPartsDialerConfig) osschatd.StreamPartsDialer { + return func(ctx context.Context, input osschatd.StreamPartsDialInput) (osschatd.StreamPartsSession, error) { + if cfg.DialerFn != nil { + return cfg.DialerFn(ctx, input) } - - // scheduleRelayReconnect arms a timer so the select loop - // can re-check chat status and reopen the relay. Callers - // pass the delay from retryState so the failed-dial branch - // gets backoff while transient branches stay at the floor. - scheduleRelayReconnect := func(delay time.Duration) { - if cfg.dial() == nil { - return - } - if reconnectTimer != nil { - reconnectTimer.Stop() - } - reconnectTimer = cfg.clock().NewTimer(delay, "reconnect") - reconnectCh = reconnectTimer.C - } - - // sendRelayTerminalError enqueues one error event for the - // subscriber; callers return afterwards so the deferred - // close(mergedEvents) fires and the OSS merge loop tears - // the relay leg down while pubsub/local sources keep going. 
- sendRelayTerminalError := func(msg string) { - select { - case mergedEvents <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - ChatID: chatID, - Error: &codersdk.ChatError{Message: msg}, - }: - case <-ctx.Done(): - } - } - statusNotifications := params.StatusNotifications - go func() { - defer close(mergedEvents) - defer closeRelay() - - // Forward any initial relay snapshot parts - // collected synchronously above. - for _, event := range initialRelaySnapshot { - select { - case <-ctx.Done(): - return - case mergedEvents <- event: - } - } - - for { - relayPartsCh := relayParts - select { - case <-ctx.Done(): - return - case result := <-relayReadyCh: - // Discard stale relay results from a - // previous dial that was superseded. - if result.workerID != expectedWorkerID { - if result.cancel != nil { - result.cancel() - } - continue - } - // A nil parts channel signals the dial - // failed - classify the error to decide - // whether to schedule a backoff retry, emit a - // terminal error and tear the relay leg down - // (unrecoverable / cap reached), or simply - // drop the stale drain. - if result.parts == nil { - if drainAndClose { - // Dial failed and we were only - // waiting to drain - nothing to do. 
- drainAndClose = false - continue - } - var dialErr *RelayDialError - if errors.As(result.err, &dialErr) && dialErr.IsUnrecoverable() { - logger.Warn(ctx, "relay dial unrecoverable; tearing down relay leg", - slog.F("chat_id", chatID), - slog.F("worker_id", result.workerID), - slog.F("http_status", dialErr.HTTPStatus), - ) - sendRelayTerminalError(fmt.Sprintf( - "relay authentication failed (status %d)", - dialErr.HTTPStatus, - )) - return - } - delay, giveUp := retryState.next() - if giveUp { - logger.Warn(ctx, "relay dial retry cap reached; tearing down relay leg", - slog.F("chat_id", chatID), - slog.F("worker_id", result.workerID), - slog.F("max_retries", relayMaxRetries), - ) - sendRelayTerminalError(fmt.Sprintf( - "relay connection failed after %d retries", - relayMaxRetries, - )) - return - } - scheduleRelayReconnect(delay) - continue - } - // An async relay dial completed. Swap in the - // new relay channel. We deliberately do NOT - // reset the retry counter here: a peer that - // accepts the handshake and immediately drops - // the stream would otherwise keep reconnecting - // forever, since each success would zero the - // counter before the next drop re-incremented - // it. The counter only resets when the target - // worker changes (see openRelayAsync). - if relayCancel != nil { - relayCancel() - relayCancel = nil - } - relayParts = result.parts - relayCancel = result.cancel - if drainAndClose { - // The chat is no longer running on - // the remote worker, but the dial - // completed. Verify no new worker - // has claimed the chat before we - // drain stale parts. 
- currentChat, dbErr := params.DB.GetChatByID(ctx, chatID) - if dbErr != nil { - logger.Warn(ctx, "failed to check chat status for relay drain", - slog.F("chat_id", chatID), - slog.Error(dbErr), - ) - } - if dbErr == nil && currentChat.Status == database.ChatStatusRunning && - currentChat.WorkerID.Valid && - currentChat.WorkerID.UUID != params.WorkerID { - // A new worker picked up the chat; - // discard the stale relay and let - // openRelayAsync handle the new one. - closeRelay() - } else { - // Chat is still idle - drain the - // buffered snapshot before closing. - if drainTimer != nil { - drainTimer.Stop() - } - drainTimer = cfg.clock().NewTimer(relayDrainTimeout, "drain") - drainTimerCh = drainTimer.C - drainAndClose = false - } - } - case <-reconnectCh: - reconnectCh = nil - // Re-check whether the chat is still - // running on a remote worker before - // reconnecting. - currentChat, chatErr := params.DB.GetChatByID(ctx, chatID) - if chatErr != nil { - logger.Warn(ctx, "failed to get chat for relay reconnect", - slog.F("chat_id", chatID), - slog.Error(chatErr), - ) - // Retry on transient DB errors to - // avoid permanently stalling the - // stream. The same retry state - // bounds the DB-error loop too so a - // persistently broken DB eventually - // tears the relay down instead of - // spinning forever. 
- delay, giveUp := retryState.next() - if giveUp { - logger.Warn(ctx, "relay reconnect retry cap reached; tearing down relay leg", - slog.F("chat_id", chatID), - slog.F("max_retries", relayMaxRetries), - ) - sendRelayTerminalError(fmt.Sprintf( - "relay connection failed after %d retries", - relayMaxRetries, - )) - return - } - scheduleRelayReconnect(delay) - continue - } - if currentChat.Status == database.ChatStatusRunning && - currentChat.WorkerID.Valid && currentChat.WorkerID.UUID != params.WorkerID { - openRelayAsync(currentChat.WorkerID.UUID) - } - case sn, ok := <-statusNotifications: - if !ok { - statusNotifications = nil - continue - } - if sn.Status == database.ChatStatusRunning && sn.WorkerID != uuid.Nil && sn.WorkerID != params.WorkerID { - openRelayAsync(sn.WorkerID) - } else { - switch { - case dialCancel != nil && relayParts == nil: - // In-progress dial: let it complete - // so its snapshot can be forwarded. - drainAndClose = true - case relayParts != nil: - // Active relay: give it a short - // window to deliver any remaining - // buffered parts before closing. - if drainTimer != nil { - drainTimer.Stop() - } - drainTimer = cfg.clock().NewTimer(relayDrainTimeout, "drain") - drainTimerCh = drainTimer.C - default: - closeRelay() - } - } - case <-drainTimerCh: - drainTimerCh = nil - drainTimer = nil - closeRelay() - case event, ok := <-relayPartsCh: - if !ok { - if relayCancel != nil { - relayCancel() - relayCancel = nil - } - relayParts = nil - // Reuse the retry state so a relay that - // repeatedly drops eventually tears down. - delay, giveUp := retryState.next() - if giveUp { - logger.Warn(ctx, "relay drop retry cap reached; tearing down relay leg", - slog.F("chat_id", chatID), - slog.F("max_retries", relayMaxRetries), - ) - sendRelayTerminalError(fmt.Sprintf( - "relay connection failed after %d retries", - relayMaxRetries, - )) - return - } - scheduleRelayReconnect(delay) - continue - } - // Only forward message_part events from - // relay. 
- if event.Type == codersdk.ChatStreamEventTypeMessagePart { - select { - case <-ctx.Done(): - return - case mergedEvents <- event: - } - } - } - } - }() - - // Cleanup is driven by ctx cancellation: the merge - // goroutine owns all relay state (reconnectTimer, - // relayCancel, dialCancel, etc.) and tears it down - // via defer closeRelay() when ctx is done. - return mergedEvents + return dialRelayParts(ctx, input, cfg) } } -// relayRetryState drives the retry policy for the relay reconnect -// loop. Wraps github.com/coder/retry to reuse its φ-growth defaults -// but computes the delay without blocking so the merge loop can -// schedule its own quartz.Clock timer. -// -// Not safe for concurrent use. -type relayRetryState struct { - retrier *retry.Retrier - attempts int -} - -func newRelayRetryState() *relayRetryState { - return &relayRetryState{ - retrier: retry.New(relayRetryFloor, relayRetryCeil), - } -} - -// next returns the delay before the next dial and sets giveUp once -// attempts exceed relayMaxRetries. Adapts the math from -// retry.Retrier.Wait (github.com/coder/retry/retrier.go) without -// blocking: the library's Wait returns 0 on the first call and sets -// Delay to Floor only after the sleep, so we clamp to Floor up -// front. -func (s *relayRetryState) next() (delay time.Duration, giveUp bool) { - s.attempts++ - if s.attempts > relayMaxRetries { - return 0, true - } - r := s.retrier - d := time.Duration(float64(r.Delay) * r.Rate) - if d > r.Ceil { - d = r.Ceil - } - if d < r.Floor { - d = r.Floor - } - r.Delay = d - return d, false -} - -// reset returns the state to the floor delay and zero attempts. -// Called after a successful dial or a relay target change. -func (s *relayRetryState) reset() { - s.retrier.Reset() - s.attempts = 0 -} - -// dialRelay opens a WebSocket to the replica owning chatID and -// returns any buffered message_part snapshot plus a live channel of -// subsequent events. 
Handshake failures return an error unwrapping -// to *RelayDialError so callers can classify via IsUnrecoverable. -// -// websocket.Dial is called directly (not via the SDK wrapper) so we -// can read *http.Response.StatusCode for classification. -func dialRelay( +func dialRelayParts( ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, - cfg MultiReplicaSubscribeConfig, - clk quartz.Clock, -) ( - snapshot []codersdk.ChatStreamEvent, - parts <-chan codersdk.ChatStreamEvent, - cancel func(), - err error, -) { - address, ok := cfg.ResolveReplicaAddress(ctx, workerID) + input osschatd.StreamPartsDialInput, + cfg StreamPartsDialerConfig, +) (osschatd.StreamPartsSession, error) { + if cfg.ResolveReplicaAddress == nil { + return nil, &RelayDialError{Err: xerrors.New("dial relay stream parts: resolver not configured")} + } + address, ok := cfg.ResolveReplicaAddress(ctx, input.WorkerID) if !ok { - return nil, nil, nil, &RelayDialError{ - Err: xerrors.New("dial relay stream: worker replica not found"), - } + return nil, &RelayDialError{Err: xerrors.New("dial relay stream parts: worker replica not found")} } - - wsURL, err := buildRelayURL(address, chatID) + wsURL, err := buildRelayURL(address, input.ChatID) if err != nil { - return nil, nil, nil, &RelayDialError{ - Err: xerrors.Errorf("dial relay stream: %w", err), - } + return nil, &RelayDialError{Err: xerrors.Errorf("dial relay stream parts: %w", err)} } + if cfg.ReplicaIDFn == nil { + return nil, &RelayDialError{Err: xerrors.New("dial relay stream parts: replica ID function not configured")} + } replicaID := cfg.ReplicaIDFn() + if replicaID == uuid.Nil { + return nil, &RelayDialError{Err: xerrors.New("dial relay stream parts: replica ID is nil")} + } headers := make(http.Header, 2) - headers.Set(codersdk.SessionTokenHeader, extractSessionToken(requestHeader)) + headers.Set(codersdk.SessionTokenHeader, extractSessionToken(input.RequestHeader)) headers.Set(RelaySourceHeader, 
replicaID.String()) - relayCtx, relayCancel := context.WithCancel(ctx) - conn, resp, dialErr := websocket.Dial(relayCtx, wsURL, &websocket.DialOptions{ + conn, resp, dialErr := websocket.Dial(ctx, wsURL, &websocket.DialOptions{ HTTPClient: cfg.ReplicaHTTPClient, HTTPHeader: headers, CompressionMode: websocket.CompressionDisabled, @@ -722,118 +93,22 @@ func dialRelay( status := 0 if resp != nil { status = resp.StatusCode - // The websocket library closes resp.Body on success; on - // failure we close it ourselves so we don't leak the TCP - // connection. if dialErr != nil && resp.Body != nil { _ = resp.Body.Close() } } if dialErr != nil { - relayCancel() - return nil, nil, nil, &RelayDialError{ + return nil, &RelayDialError{ HTTPStatus: status, - Err: xerrors.Errorf("dial relay stream: %w", dialErr), + Err: xerrors.Errorf("dial relay stream parts: %w", dialErr), } } - // Match the server's 4 MiB read limit in codersdk.StreamChat so - // large message_part batches don't trip the default 32 KiB cap. conn.SetReadLimit(1 << 22) - - snapshot = make([]codersdk.ChatStreamEvent, 0, 100) - - // sourceEvents is the flattened batch→event channel. A small - // goroutine reads batches off the websocket and fans them out; - // callers see a single event stream identical to the shape the - // old SDK call produced. - sourceEvents := make(chan codersdk.ChatStreamEvent, 128) - go func() { - defer close(sourceEvents) - for { - var batch []codersdk.ChatStreamEvent - if readErr := wsjson.Read(relayCtx, conn, &batch); readErr != nil { - return - } - for _, event := range batch { - select { - case sourceEvents <- event: - case <-relayCtx.Done(): - return - } - } - } - }() - - closeSource := func() { - relayCancel() - _ = conn.Close(websocket.StatusNormalClosure, "") - } - - // Wait briefly for the first event to handle the common - // case where the remote side has buffered parts but hasn't - // flushed them to the WebSocket yet. 
- const drainTimeout = time.Second - drainTimer := clk.NewTimer(drainTimeout, "drain") - defer drainTimer.Stop() - -drainInitial: - for len(snapshot) < cap(snapshot) { - select { - case <-relayCtx.Done(): - closeSource() - return nil, nil, nil, &RelayDialError{ - Err: xerrors.Errorf("dial relay stream: %w", relayCtx.Err()), - } - case event, ok := <-sourceEvents: - if !ok { - break drainInitial - } - if event.Type != codersdk.ChatStreamEventTypeMessagePart { - continue - } - snapshot = append(snapshot, event) - // After getting the first event, switch to - // non-blocking drain for remaining buffered events. - drainTimer.Stop() - drainTimer.Reset(0) - case <-drainTimer.C: - break drainInitial - } - } - - events := make(chan codersdk.ChatStreamEvent, 128) - - go func() { - defer close(events) - defer closeSource() - - // No need to re-send snapshot events - they're - // returned to the caller directly. - for { - select { - case <-relayCtx.Done(): - return - case event, ok := <-sourceEvents: - if !ok { - return - } - if event.Type != codersdk.ChatStreamEventTypeMessagePart { - continue - } - select { - case events <- event: - case <-relayCtx.Done(): - return - } - } - } - }() - - return snapshot, events, closeSource, nil + return osschatd.NewStreamPartsJSONSession(ctx, conn), nil } -// buildRelayURL builds the websocket URL for the chat stream -// endpoint on a peer replica. It maps http(s) schemes to ws(s). +// buildRelayURL builds the websocket URL for the chat stream parts endpoint on +// a peer replica. It maps http(s) schemes to ws(s). func buildRelayURL(address string, chatID uuid.UUID) (string, error) { u, err := url.Parse(address) if err != nil { @@ -845,40 +120,30 @@ func buildRelayURL(address string, chatID uuid.UUID) (string, error) { case "https": u.Scheme = "wss" case "ws", "wss": - // already a websocket URL, leave as-is. 
default: return "", xerrors.Errorf("unsupported relay address scheme %q", u.Scheme) } - u.Path = fmt.Sprintf("/api/experimental/chats/%s/stream", chatID) - q := u.Query() - // Relays only need live message_part events, not the full - // history; pass the relay sentinel so the peer skips its - // durable DB snapshot and delivers in-flight parts only. - q.Set("after_id", strconv.FormatInt(osschatd.RelaySentinelAfterID, 10)) - u.RawQuery = q.Encode() + u.Path = "/api/experimental/chats/" + chatID.String() + "/stream/parts" + u.RawQuery = "" return u.String(), nil } -// extractSessionToken returns the session token carried by the -// given request headers. It mirrors the priority order used by -// apiKeyMiddleware: cookie, then Coder-Session-Token header, then -// Authorization: Bearer header. +// extractSessionToken returns the session token carried by the given request +// headers. It mirrors the priority order used by apiKeyMiddleware: cookie, +// then Coder-Session-Token header, then Authorization: Bearer header. func extractSessionToken(header http.Header) string { if header == nil { return "" } - // Cookie (browser WebSocket upgrade - most common relay case). if raw := header.Get(cookieHeader); raw != "" { r := &http.Request{Header: http.Header{cookieHeader: {raw}}} if c, err := r.Cookie(codersdk.SessionTokenCookie); err == nil && c.Value != "" { return c.Value } } - // Coder-Session-Token header (SDK / CLI callers). if v := header.Get(codersdk.SessionTokenHeader); v != "" { return v } - // Authorization: Bearer . 
if v := header.Get(authorizationHeader); len(v) > 7 && strings.EqualFold(v[:7], "bearer ") { return strings.TrimSpace(v[7:]) } diff --git a/enterprise/coderd/x/chatd/chatd_retry_test.go b/enterprise/coderd/x/chatd/chatd_retry_test.go deleted file mode 100644 index d21a15b9ba0de..0000000000000 --- a/enterprise/coderd/x/chatd/chatd_retry_test.go +++ /dev/null @@ -1,796 +0,0 @@ -package chatd_test - -import ( - "context" - "database/sql" - "encoding/json" - "io" - "math" - "net/http" - "net/http/httptest" - "regexp" - "sync/atomic" - "testing" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbtestutil" - dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" - coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" - osschatd "github.com/coder/coder/v2/coderd/x/chatd" - "github.com/coder/coder/v2/codersdk" - entchatd "github.com/coder/coder/v2/enterprise/coderd/x/chatd" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" -) - -// mulPhi multiplies a duration by math.Phi to compute the next -// step in retry.Retrier's φ-growth backoff sequence. If -// TestRelayReconnectUsesExponentialBackoff starts failing after a -// retry library bump, check whether the growth factor has changed. -func mulPhi(d time.Duration) time.Duration { - return time.Duration(float64(d) * math.Phi) -} - -// setChatRunningAndPublish marks the chat row as running on workerID -// and publishes a matching status notification. It keeps the DB row -// and pubsub notification in sync so the async reconnect loop -// re-dials on each timer fire (the reconnect branch re-checks DB -// status before calling openRelayAsync). 
-func setChatRunningAndPublish( - ctx context.Context, - t *testing.T, - db database.Store, - ps dbpubsub.Pubsub, - chatID, workerID uuid.UUID, -) { - t.Helper() - now := time.Now() - _, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chatID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, - StartedAt: sql.NullTime{Time: now, Valid: true}, - HeartbeatAt: sql.NullTime{Time: now, Valid: true}, - }) - require.NoError(t, err) - payload, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: workerID.String(), - }) - require.NoError(t, err) - require.NoError(t, ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), payload)) -} - -// TestRelayDialErrorIsUnrecoverable locks the classification policy. -// Adding a new HTTP status to the unrecoverable set should force a -// test edit too. -func TestRelayDialErrorIsUnrecoverable(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - status int - want bool - }{ - {"unauthorized", http.StatusUnauthorized, true}, - {"forbidden", http.StatusForbidden, true}, - {"internal_server", http.StatusInternalServerError, false}, - {"bad_gateway", http.StatusBadGateway, false}, - {"service_unavailable", http.StatusServiceUnavailable, false}, - {"too_many_requests", http.StatusTooManyRequests, false}, - {"pre_response", 0, false}, - {"bad_request", http.StatusBadRequest, false}, - {"not_found", http.StatusNotFound, false}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - e := &entchatd.RelayDialError{HTTPStatus: tc.status, Err: io.EOF} - require.Equal(t, tc.want, e.IsUnrecoverable(), - "status=%d", tc.status) - }) - } -} - -// TestRelayReconnectUsesExponentialBackoff asserts that the reconnect -// timer follows the φ-growth sequence produced by -// github.com/coder/retry's defaults, floored at relayRetryFloor. 
-func TestRelayReconnectUsesExponentialBackoff(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - var failCount atomic.Int32 - dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - failCount.Add(1) - return nil, nil, nil, &entchatd.RelayDialError{ - HTTPStatus: http.StatusBadGateway, - Err: io.EOF, - } - } - - mclk := quartz.NewMock(t) - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - chat := seedWaitingChat(t, db, org.ID, user, model, "relay-backoff") - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Kick the async relay loop and keep the DB row in sync so - // each reconnect timer fire triggers another dial. - setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID) - // Expected sequence from retry.Retrier math: - // attempt 1 → floor (500ms) - // attempt n → prev × φ (capped at ceil) - floor := 500 * time.Millisecond - expected := []time.Duration{ - floor, - mulPhi(floor), - mulPhi(mulPhi(floor)), - mulPhi(mulPhi(mulPhi(floor))), - mulPhi(mulPhi(mulPhi(mulPhi(floor)))), - } - - for i, want := range expected { - call := trapReconnect.MustWait(ctx) - require.Equal(t, want, call.Duration, - "attempt %d: want %v got %v", i+1, want, call.Duration) - call.MustRelease(ctx) - mclk.Advance(want).MustWait(ctx) - } - - // We expect 1 initial attempt + 5 reconnects fired by the - // trapped timer = 6 dials before the cap-check runs. Use - // Eventually so we don't race the final dial goroutine that - // the last Advance kicked off. 
- require.Eventually(t, func() bool { - return failCount.Load() >= 6 - }, testutil.WaitShort, testutil.IntervalFast, - "expected 6 dials, got %d", failCount.Load()) - - // The events channel must remain open - we're still under the - // cap. - select { - case ev, open := <-events: - if !open { - t.Fatalf("events channel closed prematurely; retries should continue below cap") - } - // Allow through events that might have been queued; just - // confirm it's not a terminal error. - if ev.Type == codersdk.ChatStreamEventTypeError { - t.Fatalf("unexpected terminal error: %v", ev.Error) - } - default: - } -} - -// TestRelayReconnectResetsOnSuccess exercises the path where a -// successful dial resets the retry state so the next failure starts -// over at the floor delay. -// TestRelayRepeatedDropsHitCap verifies the cap covers a peer that -// accepts the handshake and immediately drops it. Without a proper -// cap, such a peer would produce one reconnect per floor delay -// forever. The retry counter must accumulate across dial-success / -// parts-close cycles so the cap trips. 
-func TestRelayRepeatedDropsHitCap(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - opened := make(chan chan codersdk.ChatStreamEvent, 32) - var call atomic.Int32 - dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call.Add(1) - ch := make(chan codersdk.ChatStreamEvent, 1) - opened <- ch - return nil, ch, func() {}, nil - } - - mclk := quartz.NewMock(t) - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - chat := seedWaitingChat(t, db, org.ID, user, model, "relay-drops") - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Kick off the first async dial. - setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID) - - // Close the first dial's parts channel so the merge loop - // schedules a reconnect. Then advance 6 reconnect timers, - // closing the parts channel each time so the cycle is: - // dial -> success -> parts-close -> next() -> reconnect. - // 1 initial dial + 6 timer-driven dials = 7 total; the 7th - // parts-close trips the cap. - for i := 0; i < 7; i++ { - var ch chan codersdk.ChatStreamEvent - select { - case ch = <-opened: - case <-ctx.Done(): - t.Fatalf("timed out waiting for dial %d", i+1) - } - // Closing the parts channel triggers the relayPartsCh - // close branch, which calls retryState.next() and - // schedules the next reconnect. - close(ch) - if i == 6 { - // 7th parts-close should trip the cap; no more - // reconnect timers. 
- break - } - call := trapReconnect.MustWait(ctx) - call.MustRelease(ctx) - mclk.Advance(call.Duration).MustWait(ctx) - } - - // A terminal error event must arrive on the events channel. - var errEvent *codersdk.ChatStreamEvent - require.Eventually(t, func() bool { - select { - case ev, open := <-events: - if !open { - return errEvent != nil - } - if ev.Type == codersdk.ChatStreamEventTypeError { - errEvent = &ev - return true - } - return false - default: - return false - } - }, testutil.WaitShort, testutil.IntervalFast, - "expected a terminal error event after repeated drops hit cap") - require.NotNil(t, errEvent.Error) - require.Contains(t, errEvent.Error.Message, "relay connection failed") - - // We should have observed exactly 7 dials before tear-down. - require.Equal(t, int32(7), call.Load(), - "expected 7 dials (1 initial + 6 reconnect retries) before cap") -} - -// TestRelayStopsAfterIntermittentCap verifies the cap-reached -// tear-down path: after N intermittent failures the merge loop emits -// one error event, closes the events channel, and stops dialing. 
-func TestRelayStopsAfterIntermittentCap(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - callCount.Add(1) - return nil, nil, nil, &entchatd.RelayDialError{ - HTTPStatus: http.StatusBadGateway, - Err: io.EOF, - } - } - - mclk := quartz.NewMock(t) - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - chat := seedWaitingChat(t, db, org.ID, user, model, "relay-cap") - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID) - // Advance through N consecutive reconnect timers. Each one - // triggers a dial, which fails and schedules the next timer. - // After the Nth failure the retry state says giveUp=true on - // the next .next() call, so the merge loop tears down. - for i := 0; i < 6; i++ { - call := trapReconnect.MustWait(ctx) - call.MustRelease(ctx) - mclk.Advance(call.Duration).MustWait(ctx) - } - - // Wait for the terminal error event to arrive. mergedEvents - // closes inside the enterprise merge goroutine, but OSS only - // nil-outs relayEvents on close - the outer events channel - // stays open for pubsub/local, so we wait for the error event - // itself rather than channel closure. 
- var errEvent *codersdk.ChatStreamEvent - require.Eventually(t, func() bool { - select { - case ev, open := <-events: - if !open { - return errEvent != nil - } - if ev.Type == codersdk.ChatStreamEventTypeError { - errEvent = &ev - return true - } - return false - default: - return false - } - }, testutil.WaitShort, testutil.IntervalFast, - "expected a terminal error event") - require.NotNil(t, errEvent, "expected a terminal error event") - require.NotNil(t, errEvent.Error) - require.Contains(t, errEvent.Error.Message, "relay connection failed") - require.Contains(t, errEvent.Error.Message, "6") - - // Ensure the cap fires at attempt N+1 - the retry state allows - // relayMaxRetries successful next() calls before flipping - // giveUp. With one initial dial + 6 reconnect-timer fires the - // 7th .next() trips the cap and tears down, so we see 7 dials - // total and nothing further. - totalDials := callCount.Load() - require.Equal(t, int32(7), totalDials, - "expected exactly relayMaxRetries+1 dials before cap; got %d", totalDials) -} - -// chatByIDErrorStore wraps a database.Store and forces GetChatByID -// to return a caller-supplied error once after N successful calls. -// This lets the initial Subscribe call succeed (OSS's initial state -// load needs a real Chat to wire up the relay) while subsequent -// reconnect-branch calls exercise the DB-error retry path. -type chatByIDErrorStore struct { - database.Store - err error - okRemain atomic.Int32 // number of calls allowed to delegate before erroring. -} - -func (s *chatByIDErrorStore) GetChatByID(ctx context.Context, id uuid.UUID) (database.Chat, error) { - if s.okRemain.Add(-1) >= 0 { - return s.Store.GetChatByID(ctx, id) - } - return database.Chat{}, s.err -} - -// TestRelayReconnectStopsAfterDBErrorCap verifies the reconnect-timer -// branch's DB-error path shares the same retry budget as dial -// failures and trips the cap after enough consecutive DB errors. 
-func TestRelayReconnectStopsAfterDBErrorCap(t *testing.T) { - t.Parallel() - - realDB, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - callCount.Add(1) - return nil, nil, nil, &entchatd.RelayDialError{ - HTTPStatus: http.StatusBadGateway, - Err: io.EOF, - } - } - - mclk := quartz.NewMock(t) - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - // The server sees a DB whose GetChatByID always errors after - // the initial Subscribe snapshot load. Other methods delegate - // to the real DB, so seeding below still works. - failingDB := &chatByIDErrorStore{ - Store: realDB, - err: xerrors.New("mock: GetChatByID always fails"), - } - // Allow one successful GetChatByID (the Subscribe preamble's - // initial state load). All subsequent calls return the mock - // error, exercising the reconnect-branch DB-error path. - failingDB.okRemain.Store(1) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, realDB) - chat := seedWaitingChat(t, realDB, org.ID, user, model, "relay-db-error") - - subscriber := newTestServer(t, failingDB, ps, subscriberID, dialer, mclk) - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Flip to running so the merge loop starts an async dial. The - // dial fails (attempts=1, reconnect scheduled). From there each - // reconnect timer fires, the merge loop calls GetChatByID, the - // failing DB returns an error, and retryState.next() increments. - // - // Budget: 1 dial-failure + 6 DB-failures = 7 next() calls; the - // 7th trips the cap. 
- setChatRunningAndPublish(ctx, t, realDB, ps, chat.ID, workerID) - for i := 0; i < 6; i++ { - call := trapReconnect.MustWait(ctx) - call.MustRelease(ctx) - mclk.Advance(call.Duration).MustWait(ctx) - } - - var errEvent *codersdk.ChatStreamEvent - require.Eventually(t, func() bool { - select { - case ev, open := <-events: - if !open { - return errEvent != nil - } - if ev.Type == codersdk.ChatStreamEventTypeError { - errEvent = &ev - return true - } - return false - default: - return false - } - }, testutil.WaitShort, testutil.IntervalFast, - "expected terminal error event after DB-error cap") - require.NotNil(t, errEvent.Error) - require.Contains(t, errEvent.Error.Message, "relay connection failed") - require.Contains(t, errEvent.Error.Message, "6") - - // Exactly 1 dial fired: the one that triggered the initial - // reconnect schedule. All subsequent next() calls come from the - // DB-error branch without calling the dialer. - require.Equal(t, int32(1), callCount.Load(), - "expected exactly 1 dial; reconnects should short-circuit on DB error") -} - -// TestRelayStopsImmediatelyOnUnauthorized tests the unrecoverable -// branch and its table of status codes. 
-func TestRelayStopsImmediatelyOnUnauthorized(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - status int - wantUnrecoverable bool - wantMsgContains string - }{ - {"401", http.StatusUnauthorized, true, "401"}, - {"403", http.StatusForbidden, true, "403"}, - {"500_intermittent", http.StatusInternalServerError, false, ""}, - {"zero_intermittent", 0, false, ""}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - callCount.Add(1) - return nil, nil, nil, &entchatd.RelayDialError{ - HTTPStatus: tc.status, - Err: io.EOF, - } - } - - mclk := quartz.NewMock(t) - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - chat := seedWaitingChat(t, db, org.ID, user, model, - "relay-unrec-"+tc.name) - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID) - if tc.wantUnrecoverable { - // First dial should tear the relay down. 
- var errEvent *codersdk.ChatStreamEvent - require.Eventually(t, func() bool { - select { - case ev, open := <-events: - if !open { - return errEvent != nil - } - if ev.Type == codersdk.ChatStreamEventTypeError { - errEvent = &ev - return true - } - return false - default: - return false - } - }, testutil.WaitShort, testutil.IntervalFast, - "expected terminal error event") - require.NotNil(t, errEvent) - require.Contains(t, errEvent.Error.Message, "relay authentication failed") - require.Contains(t, errEvent.Error.Message, tc.wantMsgContains) - require.Equal(t, int32(1), callCount.Load(), - "unrecoverable errors must not retry; got %d dials", callCount.Load()) - } else { - // Intermittent: fire one reconnect timer - // and confirm the dialer is called again. - call := trapReconnect.MustWait(ctx) - call.MustRelease(ctx) - mclk.Advance(call.Duration).MustWait(ctx) - require.Eventually(t, func() bool { - return callCount.Load() >= 2 - }, testutil.WaitShort, testutil.IntervalFast, - "intermittent should retry at least once") - } - }) - } -} - -// TestRelayBackoffResetsOnStatusChange checks that closeRelay (driven -// by a status notification) resets the retry counter so subsequent -// dials against a new target start at the floor delay. 
-func TestRelayBackoffResetsOnStatusChange(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID1 := uuid.New() - workerID2 := uuid.New() - subscriberID := uuid.New() - - dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - return nil, nil, nil, &entchatd.RelayDialError{ - HTTPStatus: http.StatusBadGateway, - Err: io.EOF, - } - } - - mclk := quartz.NewMock(t) - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - chat := seedWaitingChat(t, db, org.ID, user, model, "relay-reset-on-status") - - _, _, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Drive the async openRelayAsync path with workerID1. - setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID1) - - // Drive 3 intermittent failures so attempts=3 and the delay - // has grown past the floor. After each loop iteration the 4th - // reconnect timer is queued - consume it too so our later - // assertion sees the reset's timer, not a stale one. - for i := 0; i < 3; i++ { - call := trapReconnect.MustWait(ctx) - call.MustRelease(ctx) - mclk.Advance(call.Duration).MustWait(ctx) - } - // Grab the next trapped timer (the grown one scheduled after - // the 3rd dial fails) but don't advance it - we want to see it - // replaced by a fresh floor-delay timer after the reset. - grown := trapReconnect.MustWait(ctx) - require.Greater(t, grown.Duration, 500*time.Millisecond, - "sanity: pre-reset delay should have grown past the floor") - grown.MustRelease(ctx) - - // Flip the chat to waiting; closeRelay runs (because the - // status notification no longer points at a running peer) and - // should reset the retry state. 
- _, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusWaiting, - }) - require.NoError(t, err) - waitingPayload, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusWaiting), - }) - require.NoError(t, err) - require.NoError(t, ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), waitingPayload)) - - // Flip back to running on a different worker. This triggers a - // fresh openRelayAsync which fails, arming a reconnect timer. - // That timer's delay must be the floor, proving the reset. - setChatRunningAndPublish(ctx, t, db, ps, chat.ID, workerID2) - - call := trapReconnect.MustWait(ctx) - require.Equal(t, 500*time.Millisecond, call.Duration, - "retry state must reset after status change; got grown delay %v", call.Duration) - call.MustRelease(ctx) -} - -// TestRelayBackoffRespectsContextCancel is a regression guard: the -// reconnect timer must respect ctx cancellation promptly. -func TestRelayBackoffRespectsContextCancel(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - dialer := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - return nil, nil, nil, &entchatd.RelayDialError{ - HTTPStatus: http.StatusBadGateway, - Err: io.EOF, - } - } - - mclk := quartz.NewMock(t) - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, dialer, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - chat := seedWaitingChat(t, db, org.ID, user, model, "relay-cancel") - - subCtx, subCancel := context.WithCancel(ctx) - _, events, cancel, ok := subscriber.Subscribe(subCtx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - setChatRunningAndPublish(ctx, t, db, ps, 
chat.ID, workerID) - - // Wait for the first reconnect timer to arm. - call := trapReconnect.MustWait(ctx) - call.MustRelease(ctx) - - // Cancel the subscriber context. The events channel should - // close promptly (the merge goroutine's select exits on - // ctx.Done). - subCancel() - - done := make(chan struct{}) - go func() { - defer close(done) - for { - if _, open := <-events; !open { - return - } - } - }() - select { - case <-done: - case <-time.After(testutil.WaitShort): - t.Fatal("events channel did not close after ctx cancel") - } -} - -// TestDialRelayReal401 exercises the real dialRelay path against an -// httptest server that returns 401 on the stream endpoint. It -// validates that the websocket library's handshake failure -// propagates through as *RelayDialError with HTTPStatus == 401. -// -// This is the one test that uses the real coder/websocket library -// on the failure path - a safety net against library upgrades -// silently breaking status capture. -func TestDialRelayReal401(t *testing.T) { - t.Parallel() - - // An httptest server that 401s every request on the stream - // endpoint. Any other path gets a 404. - srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - if !streamPathRE.MatchString(r.URL.Path) { - http.NotFound(rw, r) - return - } - rw.Header().Set("Content-Type", "application/json") - rw.WriteHeader(http.StatusUnauthorized) - _, _ = rw.Write([]byte(`{"message":"unauthorized"}`)) - })) - t.Cleanup(srv.Close) - - db, _ := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - // Wire real config (no DialerFn override) so dialRelay runs - // end-to-end against the httptest server. Seeding a waiting - // chat (below) keeps Subscribe's initial synchronous dial a - // no-op; we then push a running status notification to the - // merge loop so it invokes dialRelay via the async path, where - // the 401 tear-down logic lives. 
- cfg := entchatd.MultiReplicaSubscribeConfig{ - ResolveReplicaAddress: func(_ context.Context, _ uuid.UUID) (string, bool) { - return srv.URL, true - }, - ReplicaHTTPClient: srv.Client(), - ReplicaIDFn: func() uuid.UUID { return subscriberID }, - } - subscribeFn := entchatd.NewMultiReplicaSubscribeFn(cfg) - - ctx := testutil.Context(t, testutil.WaitMedium) - user, org, model := seedChatDependencies(t, db) - // Seed a waiting chat - no sync dial - then push a running - // status notification to trigger the async dial via the real - // dialRelay path. - chat := seedWaitingChat(t, db, org.ID, user, model, "relay-real-401") - - statusCh := make(chan osschatd.StatusNotification, 1) - evs := subscribeFn(ctx, osschatd.SubscribeFnParams{ - ChatID: chat.ID, - Chat: chat, - WorkerID: subscriberID, - StatusNotifications: statusCh, - RequestHeader: http.Header{codersdk.SessionTokenHeader: {"test-token"}}, - DB: db, - Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), - }) - - statusCh <- osschatd.StatusNotification{ - Status: database.ChatStatusRunning, - WorkerID: workerID, - } - - // Wait for a terminal error event. On a real 401 handshake, - // the classifier flags it unrecoverable → one dial, then - // error event, then channel close. - var errEvent *codersdk.ChatStreamEvent - deadline := time.After(testutil.WaitMedium) -waitErr: - for { - select { - case ev, open := <-evs: - if !open { - break waitErr - } - if ev.Type == codersdk.ChatStreamEventTypeError { - errEvent = &ev - } - case <-deadline: - break waitErr - } - } - - require.NotNil(t, errEvent, "expected terminal error event from real 401 dial") - require.NotNil(t, errEvent.Error) - require.Contains(t, errEvent.Error.Message, "relay authentication failed") - require.Contains(t, errEvent.Error.Message, "401") -} - -// streamPathRE matches the chat stream endpoint path built by -// buildRelayURL. Compiled at package scope so the httptest handler -// below doesn't pay regexp.Compile per request. 
-var streamPathRE = regexp.MustCompile( - `^/api/experimental/chats/[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}/stream$`, -) diff --git a/enterprise/coderd/x/chatd/chatd_test.go b/enterprise/coderd/x/chatd/chatd_test.go index a4c48ad4d98b8..9dfb2e361f054 100644 --- a/enterprise/coderd/x/chatd/chatd_test.go +++ b/enterprise/coderd/x/chatd/chatd_test.go @@ -2,1626 +2,151 @@ package chatd_test import ( "context" - "database/sql" - "encoding/json" - "fmt" - "math" "net/http" "net/http/httptest" - "sync/atomic" "testing" - "time" "github.com/google/uuid" - "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbgen" - "github.com/coder/coder/v2/coderd/database/dbtestutil" - dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" - coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" osschatd "github.com/coder/coder/v2/coderd/x/chatd" - "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/codersdk" entchatd "github.com/coder/coder/v2/enterprise/coderd/x/chatd" - "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" + "github.com/coder/websocket" ) -const skipLegacyChatStream = "chatd refactor: remove in PR 4" - -func chatLastErrorMessage(raw pqtype.NullRawMessage) string { - if !raw.Valid { - return "" - } - - var payload codersdk.ChatError - if err := json.Unmarshal(raw.RawMessage, &payload); err == nil && payload.Message != "" { - return payload.Message - } - return string(raw.RawMessage) -} - -func newTestServer( - t *testing.T, - db database.Store, - ps dbpubsub.Pubsub, - replicaID uuid.UUID, - dialer func( - ctx context.Context, - chatID uuid.UUID, - workerID uuid.UUID, - requestHeader http.Header, - ) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, - ), - clock quartz.Clock, -) 
*osschatd.Server { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := osschatd.New(osschatd.Config{ - Logger: logger, - Database: db, - ReplicaID: replicaID, - Pubsub: ps, - SubscribeFn: entchatd.NewMultiReplicaSubscribeFn(entchatd.MultiReplicaSubscribeConfig{DialerFn: dialer, Clock: clock}), - PendingChatAcquireInterval: testutil.WaitSuperLong, - }) - server.Start() - t.Cleanup(func() { - require.NoError(t, server.Close()) - }) - return server -} - -func newActiveWorkerServer( - t *testing.T, - db database.Store, - ps dbpubsub.Pubsub, - replicaID uuid.UUID, -) *osschatd.Server { - t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := osschatd.New(osschatd.Config{ - Logger: logger, - Database: db, - ReplicaID: replicaID, - Pubsub: ps, - PendingChatAcquireInterval: 10 * time.Millisecond, - InFlightChatStaleAfter: testutil.WaitSuperLong, - }) - server.Start() - t.Cleanup(func() { - require.NoError(t, server.Close()) - }) - return server -} - -// seedChatDependencies creates a user, organization, and chat model -// config in the database for use in relay tests. 
-func seedChatDependencies( - t *testing.T, - db database.Store, -) (database.User, database.Organization, database.ChatModelConfig) { - t.Helper() - - safetyNet := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.Header().Set("Content-Type", "application/json") - rw.WriteHeader(http.StatusInternalServerError) - _, _ = rw.Write([]byte(`{"error":{"message":"unexpected OpenAI request in chatd relay test safety net"}}`)) - })) - t.Cleanup(safetyNet.Close) - - user := dbgen.User(t, db, database.User{}) - org := dbgen.Organization(t, db, database.Organization{}) - dbgen.OrganizationMember(t, db, database.OrganizationMember{ - UserID: user.ID, - OrganizationID: org.ID, - }) - provider := dbgen.AIProvider(t, db, database.AIProvider{ - Type: database.AiProviderTypeOpenai, - Name: "test-" + uuid.NewString(), - BaseUrl: safetyNet.URL, - }) - dbgen.AIProviderKey(t, db, database.AIProviderKey{ - ProviderID: provider.ID, - }) - model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - Provider: "openai", - AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - IsDefault: true, - }) - return user, org, model +type fakePartsSession struct { + parts chan osschatd.StreamPart } -func seedWaitingChat( - t *testing.T, - db database.Store, - orgID uuid.UUID, - user database.User, - model database.ChatModelConfig, - title string, -) database.Chat { - t.Helper() - - chat := dbgen.Chat(t, db, database.Chat{ - OrganizationID: orgID, - OwnerID: user.ID, - LastModelConfigID: model.ID, - Title: title, - }) - return chat -} - -func seedRemoteRunningChat( - ctx context.Context, - t *testing.T, - db database.Store, - orgID uuid.UUID, - user database.User, - model database.ChatModelConfig, - workerID uuid.UUID, - title string, -) database.Chat { - t.Helper() - - chat := seedWaitingChat(t, db, orgID, user, model, title) - now := 
time.Now() - chat, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, - StartedAt: sql.NullTime{Time: now, Valid: true}, - HeartbeatAt: sql.NullTime{Time: now, Valid: true}, - }) - require.NoError(t, err) - return chat -} - -func setOpenAIProviderBaseURL( - ctx context.Context, - t *testing.T, - db database.Store, - baseURL string, -) { - t.Helper() - - providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) - require.NoError(t, err) - for _, provider := range providers { - if provider.Type != database.AiProviderTypeOpenai { - continue - } - _, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ - ID: provider.ID, - DisplayName: provider.DisplayName, - Enabled: provider.Enabled, - BaseUrl: baseURL, - Settings: provider.Settings, - SettingsKeyID: provider.SettingsKeyID, - }) - require.NoError(t, err) - return - } - require.Fail(t, "openai provider not found") +func newFakePartsSession() *fakePartsSession { + return &fakePartsSession{parts: make(chan osschatd.StreamPart)} } -func TestSubscribeRelayReconnectsOnDrop(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - - provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call := callCount.Add(1) - ch := make(chan codersdk.ChatStreamEvent, 10) - if call == 1 { - // First relay: send a part then close to simulate a drop. - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("first-relay"), - }, - } - close(ch) - } else { - // Second relay: send a different part, keep open. 
- ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("second-relay"), - }, - } - // Don't close — keep alive so the subscriber stays connected. - } - return nil, ch, func() {}, nil - } - - mclk := quartz.NewMock(t) - // Trap the reconnect timer so we can fire it deterministically - // instead of waiting real time. - trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - chat := seedRemoteRunningChat(ctx, t, db, org.ID, user, model, workerID, "relay-reconnect") - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Should get the first relay part. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "first-relay" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - // Wait for the reconnect timer to be created after the relay - // drop, then advance the mock clock to fire it immediately. - trapReconnect.MustWait(ctx).MustRelease(ctx) - mclk.Advance(500 * time.Millisecond).MustWait(ctx) - - // After the first relay closes, the reconnection should deliver - // the second relay part. 
- require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "second-relay" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - require.GreaterOrEqual(t, int(callCount.Load()), 2) +func (*fakePartsSession) SelectEpisode(context.Context, int64, int64) error { return nil } +func (s *fakePartsSession) Parts() <-chan osschatd.StreamPart { return s.parts } +func (s *fakePartsSession) Close() error { + close(s.parts) + return nil } -func TestSubscribeRelayAsyncDoesNotBlock(t *testing.T) { +func TestRelayDialErrorIsUnrecoverable(t *testing.T) { t.Parallel() - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - dialStarted := make(chan struct{}) - dialContinue := make(chan struct{}) - - provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - // Signal that the dial has started, then block until released. - select { - case <-dialStarted: - default: - close(dialStarted) - } - select { - case <-dialContinue: - case <-ctx.Done(): - return nil, nil, nil, ctx.Err() - } - ch := make(chan codersdk.ChatStreamEvent, 10) - return nil, ch, func() {}, nil - } - - subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - // Seed a waiting chat so Subscribe does not trigger a synchronous - // relay. - chat := seedWaitingChat(t, db, org.ID, user, model, "relay-async-nonblock") - - // Subscribe before the chat is marked running so the relay opens - // via pubsub notification (openRelayAsync path). 
- _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Now mark the chat as running on a remote worker. This publishes - // a status notification which triggers openRelayAsync on the - // subscriber. - notify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: workerID.String(), - } - payload, err := json.Marshal(notify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) - require.NoError(t, err) - - // Wait for the relay dial to actually start (blocking in the - // provider). - select { - case <-dialStarted: - case <-ctx.Done(): - t.Fatal("timed out waiting for relay dial to start") - } - - // While the relay is still dialing (provider is blocked), publish - // another status change. If openRelayAsync blocked the select loop - // this event would never arrive. - statusNotify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusWaiting), - } - statusPayload, err := json.Marshal(statusNotify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), statusPayload) - require.NoError(t, err) - - // The waiting status event should arrive promptly despite the - // relay still dialing. - require.Eventually(t, func() bool { - select { - case event := <-events: - return event.Type == codersdk.ChatStreamEventTypeStatus && - event.Status != nil && - event.Status.Status == codersdk.ChatStatusWaiting - default: - return false - } - }, testutil.WaitShort, testutil.IntervalFast) - - // Unblock the relay dial so the test can clean up. 
- close(dialContinue) -} - -func TestSubscribeRelaySnapshotDelivered(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - // Return a non-empty snapshot with two parts. - snapshot := []codersdk.ChatStreamEvent{ - { - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("snap-one"), - }, - }, - { - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("snap-two"), - }, - }, - } - ch := make(chan codersdk.ChatStreamEvent, 10) - // Also send a live part after the snapshot. - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("live-part"), - }, - } - return snapshot, ch, func() {}, nil - } - - subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - chat := seedRemoteRunningChat(ctx, t, db, org.ID, user, model, workerID, "relay-snapshot") - staleChat := chat - staleChat.Status = database.ChatStatusWaiting - staleChat.WorkerID = uuid.NullUUID{} - staleChat.StartedAt = sql.NullTime{} - staleChat.HeartbeatAt = sql.NullTime{} - - initialSnapshot, events, cancel, ok := subscriber.SubscribeAuthorized(ctx, staleChat, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // The relay snapshot parts are forwarded through the events - // channel by the enterprise SubscribeFn. Collect them along - // with the live part. 
- var receivedTexts []string - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil { - receivedTexts = append(receivedTexts, event.MessagePart.Part.Text) - } - // We expect snap-one, snap-two, and live-part. - return len(receivedTexts) >= 3 - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - require.Equal(t, []string{"snap-one", "snap-two", "live-part"}, receivedTexts) - - // The initial snapshot should contain the refreshed running status, - // not the stale waiting status passed into SubscribeAuthorized. - var snapshotStatus codersdk.ChatStatus - for _, event := range initialSnapshot { - if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil { - snapshotStatus = event.Status.Status - } + cases := []struct { + name string + status int + want bool + }{ + {"unauthorized", http.StatusUnauthorized, true}, + {"forbidden", http.StatusForbidden, true}, + {"internal_server", http.StatusInternalServerError, false}, + {"bad_gateway", http.StatusBadGateway, false}, + {"pre_response", 0, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := &entchatd.RelayDialError{HTTPStatus: tc.status, Err: context.Canceled} + require.Equal(t, tc.want, err.IsUnrecoverable()) + }) } - require.Equal(t, codersdk.ChatStatusRunning, snapshotStatus) } -func TestSubscribeRetryEventAcrossInstances(t *testing.T) { +func TestStreamPartsDialerUsesConfiguredDialer(t *testing.T) { t.Parallel() - t.Skip(skipLegacyChatStream) - db, ps := dbtestutil.NewDB(t) + chatID := uuid.New() workerID := uuid.New() - subscriberID := uuid.New() + headers := http.Header{codersdk.SessionTokenHeader: {"token-value"}} + wantSession := newFakePartsSession() - var streamCalls atomic.Int32 - firstStreamStarted := make(chan struct{}) - allowFirstFailure := make(chan struct{}) - openAIURL := chattest.NewOpenAI(t, func(req 
*chattest.OpenAIRequest) chattest.OpenAIResponse { - if !req.Stream { - return chattest.OpenAINonStreamingResponse("retry-across-instances") - } - if streamCalls.Add(1) == 1 { - select { - case <-firstStreamStarted: - default: - close(firstStreamStarted) - } - <-allowFirstFailure - return chattest.OpenAIRateLimitResponse() - } - return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("retry", " complete")...) + var gotInput osschatd.StreamPartsDialInput + dialer := entchatd.NewStreamPartsDialer(entchatd.StreamPartsDialerConfig{ + DialerFn: func(_ context.Context, input osschatd.StreamPartsDialInput) (osschatd.StreamPartsSession, error) { + gotInput = input + return wantSession, nil + }, }) - worker := newActiveWorkerServer(t, db, ps, workerID) - subscriber := newTestServer(t, db, ps, subscriberID, func( - ctx context.Context, - chatID uuid.UUID, - targetWorkerID uuid.UUID, - requestHeader http.Header, - ) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, - ) { - if targetWorkerID != workerID { - return nil, nil, nil, xerrors.Errorf("unexpected relay target %s", targetWorkerID) - } - snapshot, events, cancel, ok := worker.Subscribe(ctx, chatID, requestHeader, math.MaxInt64) - if !ok { - return nil, nil, nil, xerrors.New("worker subscribe failed") - } - return snapshot, events, cancel, nil - }, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - setOpenAIProviderBaseURL(ctx, t, db, openAIURL) - - chat, err := worker.CreateChat(ctx, osschatd.CreateOptions{ - OrganizationID: org.ID, - OwnerID: user.ID, - Title: "retry-across-instances", - ModelConfigID: model.ID, - InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, - }) - require.NoError(t, err) - - require.Eventually(t, func() bool { - fromDB, dbErr := db.GetChatByID(ctx, chat.ID) - if dbErr != nil { - return false - } - return fromDB.Status == database.ChatStatusRunning && - 
fromDB.WorkerID.Valid && fromDB.WorkerID.UUID == workerID - }, testutil.WaitMedium, testutil.IntervalFast) - - select { - case <-firstStreamStarted: - case <-ctx.Done(): - t.Fatal("timed out waiting for first streaming attempt") - } - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - defer cancel() - - close(allowFirstFailure) - - var retryEvent *codersdk.ChatStreamRetry - var waitingSeen bool - var waitingBeforeRetry bool - var assistantMessageBeforeRetry bool - require.Eventually(t, func() bool { - select { - case event, ok := <-events: - if !ok { - return false - } - switch event.Type { - case codersdk.ChatStreamEventTypeRetry: - retryEvent = event.Retry - case codersdk.ChatStreamEventTypeMessage: - if event.Message != nil && event.Message.Role == codersdk.ChatMessageRoleAssistant { - if retryEvent == nil { - assistantMessageBeforeRetry = true - } - } - case codersdk.ChatStreamEventTypeStatus: - if event.Status != nil && event.Status.Status == codersdk.ChatStatusWaiting { - if retryEvent == nil { - waitingBeforeRetry = true - } - waitingSeen = true - } - } - return retryEvent != nil && waitingSeen - default: - return false - } - }, testutil.WaitLong, testutil.IntervalFast) - - require.NotNil(t, retryEvent) - require.Equal(t, 1, retryEvent.Attempt) - require.Greater(t, retryEvent.DelayMs, int64(0)) - require.Equal(t, codersdk.ChatErrorKindRateLimit, retryEvent.Kind) - require.Equal(t, "openai", retryEvent.Provider) - require.Equal(t, 429, retryEvent.StatusCode) - require.Contains(t, retryEvent.Error, "rate limiting requests") - require.False(t, assistantMessageBeforeRetry) - require.False(t, waitingBeforeRetry) - require.GreaterOrEqual(t, streamCalls.Load(), int32(2)) -} - -// TestSubscribeRelayStaleDialDiscardedAfterInterrupt verifies that when a -// user interrupts a streaming chat and sends a new message (which gets -// picked up by a different replica), an in-flight relay dial to the -// OLD replica is 
canceled/discarded and the relay connects to the -// NEW replica correctly. -func TestSubscribeRelayStaleDialDiscardedAfterInterrupt(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - oldWorkerID := uuid.New() - newWorkerID := uuid.New() - subscriberID := uuid.New() - - // Gate to hold the first dial until we're ready. - firstDialStarted := make(chan struct{}) - releaseFirstDial := make(chan struct{}) - - var callCount atomic.Int32 - - provider := func(ctx context.Context, _ uuid.UUID, workerID uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call := callCount.Add(1) - ch := make(chan codersdk.ChatStreamEvent, 10) - if call == 1 { - // First dial (to old worker): signal that we started, - // then block until released or context canceled. - close(firstDialStarted) - select { - case <-releaseFirstDial: - case <-ctx.Done(): - return nil, nil, nil, ctx.Err() - } - // If we get here after being released (not canceled), - // return a stale part — this should be discarded. - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("stale-part"), - }, - } - close(ch) - return nil, ch, func() {}, nil - } - // Second dial (to new worker): return a valid part. - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("new-worker-part"), - }, - } - return nil, ch, func() {}, nil - } - - subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - // Seed the chat in waiting state so Subscribe does not try an initial - // relay. 
- chat := seedWaitingChat(t, db, org.ID, user, model, "stale-dial-test") - - // Subscribe while chat is in "waiting" state — no relay opened. - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Now simulate the chat being picked up by the OLD worker via pubsub. - // This triggers openRelayAsync in the merge loop. - _, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: oldWorkerID, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, - }) - require.NoError(t, err) - oldRunningNotify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: oldWorkerID.String(), - } - oldRunningPayload, err := json.Marshal(oldRunningNotify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), oldRunningPayload) - require.NoError(t, err) - - // Wait for the first dial goroutine to start (it's blocked in the provider). - select { - case <-firstDialStarted: - case <-ctx.Done(): - t.Fatal("timed out waiting for first dial to start") - } - - // Simulate interrupt: chat goes to "waiting". - _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusWaiting, - }) - require.NoError(t, err) - waitingNotify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusWaiting), - } - waitingPayload, err := json.Marshal(waitingNotify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), waitingPayload) - require.NoError(t, err) - - // Wait for the merge loop to process the waiting notification - // and emit the status event before publishing the new running - // notification. This avoids time.Sleep (banned by project - // policy) and provides a deterministic sync point. 
- require.Eventually(t, func() bool { - select { - case event := <-events: - return event.Type == codersdk.ChatStreamEventTypeStatus && - event.Status != nil && - event.Status.Status == codersdk.ChatStatusWaiting - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - // Now the chat transitions to running on the NEW worker. - _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: newWorkerID, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, - }) - require.NoError(t, err) - runningNotify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: newWorkerID.String(), - } - runningPayload, err := json.Marshal(runningNotify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), runningPayload) - require.NoError(t, err) - - // Now release the first dial (if it wasn't already canceled). - close(releaseFirstDial) - - // The subscriber should receive parts from the NEW worker, not the stale one. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "new-worker-part" { - return true - } - // If we get the stale part, the bug is present. - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "stale-part" { - t.Fatal("received stale part from old worker — relay did not cancel in-flight dial") - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - // Drain the events channel for a while to ensure no late-arriving - // stale part sneaks in after the require.Eventually above returned. 
- // This closes the timing gap where "stale-part" could arrive after - // "new-worker-part" was already consumed. - require.Never(t, func() bool { - select { - case event := <-events: - return event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "stale-part" - default: - return false - } - }, 2*time.Second, testutil.IntervalFast) -} - -// TestSubscribeCancelDuringInFlightDial verifies that calling the -// subscription's cancel function while a relay dial goroutine is -// still blocking in the provider causes the provider's context to -// be canceled and the goroutine to return cleanly. -func TestSubscribeCancelDuringInFlightDial(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - dialStarted := make(chan struct{}) - dialExited := make(chan struct{}) - - provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - // Signal the dial has started, then block until the context - // is canceled. - close(dialStarted) - <-ctx.Done() - close(dialExited) - return nil, nil, nil, ctx.Err() - } - - subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - // Seed the chat in waiting state so Subscribe does not open a - // synchronous relay. - chat := seedWaitingChat(t, db, org.ID, user, model, "cancel-inflight-dial") - - _, _, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - - // Publish a running notification to trigger openRelayAsync. 
- notify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: workerID.String(), - } - payload, err := json.Marshal(notify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) - require.NoError(t, err) - - // Wait for the dial goroutine to block inside the provider. - select { - case <-dialStarted: - case <-ctx.Done(): - t.Fatal("timed out waiting for dial to start") - } - - // Cancel the subscription while the dial is still in-flight. - cancel() - - // The provider context must be canceled, causing the goroutine - // to return cleanly. - require.Eventually(t, func() bool { - select { - case <-dialExited: - return true - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) -} - -// TestSubscribeRelayRunningToRunningSwitch verifies that when a chat -// transitions directly from running(workerA) to running(workerB) -// without an intermediate waiting state, the relay switches to the -// new worker and discards parts from the old one. -func TestSubscribeRelayRunningToRunningSwitch(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerA := uuid.New() - workerB := uuid.New() - subscriberID := uuid.New() - - // Gate to hold workerA's dial until we verify cancellation. - dialAStarted := make(chan struct{}) - dialAExited := make(chan struct{}) - - var callCount atomic.Int32 - - provider := func(ctx context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call := callCount.Add(1) - if call == 1 { - // First dial (to workerA): signal that we started, - // then block until the context is canceled. - close(dialAStarted) - <-ctx.Done() - close(dialAExited) - return nil, nil, nil, ctx.Err() - } - // Second dial (to workerB): return a valid part. 
- ch := make(chan codersdk.ChatStreamEvent, 10) - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("worker-b-part"), - }, - } - return nil, ch, func() {}, nil - } - - subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - // Seed the chat in waiting state so Subscribe does not open a relay. - chat := seedWaitingChat(t, db, org.ID, user, model, "running-to-running") - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Transition to running on workerA. - notifyA := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: workerA.String(), - } - payloadA, err := json.Marshal(notifyA) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payloadA) - require.NoError(t, err) - - // Wait for the workerA dial goroutine to block inside the - // provider before publishing the workerB notification. - select { - case <-dialAStarted: - case <-ctx.Done(): - t.Fatal("timed out waiting for workerA dial to start") - } - - // Immediately transition to running on workerB (no waiting in - // between). This should cancel workerA's in-flight dial. - notifyB := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: workerB.String(), - } - payloadB, err := json.Marshal(notifyB) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payloadB) - require.NoError(t, err) - - // Verify that the relay canceled workerA's stale dial. 
- require.Eventually(t, func() bool { - select { - case <-dialAExited: - return true - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - // We should receive the part from workerB. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "worker-b-part" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - require.Equal(t, 2, int(callCount.Load())) -} - -// TestSubscribeRelayFailedDialRetries verifies that when an async relay -// dial fails (returns an error), the merge loop schedules a reconnect -// timer and eventually re-dials successfully. This exercises the -// result.parts == nil path and the scheduleRelayReconnect() logic. -func TestSubscribeRelayFailedDialRetries(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - remoteWorkerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - - provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call := callCount.Add(1) - if call == 1 { - // First dial: fail with an error to trigger - // scheduleRelayReconnect via the result.parts == nil path. - return nil, nil, nil, xerrors.New("transient dial failure") - } - // Second dial: succeed and return a part. - ch := make(chan codersdk.ChatStreamEvent, 10) - ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("retry-success"), - }, - } - return nil, ch, func() {}, nil - } - - mclk := quartz.NewMock(t) - // Trap the reconnect timer so we can fire it deterministically. 
- trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - // Seed the chat in waiting state so Subscribe does not open a - // synchronous relay dial. - chat := seedWaitingChat(t, db, org.ID, user, model, "failed-dial-retry") - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Now mark the chat as running on the remote worker in the DB. - // The reconnect timer calls params.DB.GetChatByID to check if - // the chat is still running on a remote worker, so this must be - // set before we advance the clock. - _, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: remoteWorkerID, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + session, err := dialer(context.Background(), osschatd.StreamPartsDialInput{ + ChatID: chatID, + WorkerID: workerID, + RequestHeader: headers, }) require.NoError(t, err) - - // Publish a running notification with a remote workerID to - // trigger openRelayAsync. The first dial will fail, causing - // scheduleRelayReconnect to be called. - notify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: remoteWorkerID.String(), - } - payload, err := json.Marshal(notify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) - require.NoError(t, err) - - // Wait for the reconnect timer to be created (after the failed - // dial), then advance the mock clock to fire it. 
- trapReconnect.MustWait(ctx).MustRelease(ctx) - mclk.Advance(500 * time.Millisecond).MustWait(ctx) - - // The merge loop re-checks the DB, sees the chat is still - // running on the remote worker, and dials again. The second - // dial succeeds. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "retry-success" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - require.GreaterOrEqual(t, int(callCount.Load()), 2) + require.Same(t, wantSession, session) + require.Equal(t, chatID, gotInput.ChatID) + require.Equal(t, workerID, gotInput.WorkerID) + require.Equal(t, "token-value", gotInput.RequestHeader.Get(codersdk.SessionTokenHeader)) } -// TestSubscribeRunningLocalWorkerClosesRelay verifies that when a chat -// is running on a remote worker and a pubsub notification arrives -// saying the local worker (subscriberID) now owns the chat, the -// existing relay is closed and no new dial is started (the local -// worker serves directly without relaying). -func TestSubscribeRunningLocalWorkerClosesRelay(t *testing.T) { +func TestStreamPartsDialerDialsPartsEndpoint(t *testing.T) { t.Parallel() - db, ps := dbtestutil.NewDB(t) - remoteWorkerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - - provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call := callCount.Add(1) - ch := make(chan codersdk.ChatStreamEvent, 10) - if call == 1 { - // Initial synchronous dial to the remote worker. 
- ch <- codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessageText("remote-part"), - }, - } - // Keep channel open so the relay stays active. - } - return nil, ch, func() {}, nil - } - - subscriber := newTestServer(t, db, ps, subscriberID, provider, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - chat := seedRemoteRunningChat( - ctx, - t, - db, - org.ID, - user, - model, - remoteWorkerID, - "local-worker-closes-relay", - ) - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Consume the remote-part from the initial relay. - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == "remote-part" { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - // Notify that the LOCAL worker now owns the chat. This should - // close the relay without opening a new one. - notify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusRunning), - WorkerID: subscriberID.String(), - } - payload, err := json.Marshal(notify) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chat.ID), payload) - require.NoError(t, err) - - // Give the system time to process the notification. No additional - // dial should happen — only the initial synchronous one. 
- require.Never(t, func() bool { - return int(callCount.Load()) > 1 - }, 2*time.Second, testutil.IntervalFast) - - require.Equal(t, 1, int(callCount.Load()), - "only the initial synchronous dial should have happened") -} - -// TestSubscribeRelayMultipleReconnects verifies that the reconnect -// loop handles multiple consecutive relay drops, proving it is -// robust across repeated iterations — not just the single reconnect -// already covered by TestSubscribeRelayReconnectsOnDrop. -func TestSubscribeRelayMultipleReconnects(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - workerID := uuid.New() - subscriberID := uuid.New() - - var callCount atomic.Int32 - - provider := func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ http.Header) ( - []codersdk.ChatStreamEvent, <-chan codersdk.ChatStreamEvent, func(), error, - ) { - call := callCount.Add(1) - ch := make(chan codersdk.ChatStreamEvent, 10) - part := codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: "assistant", - Part: codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeText, - Text: fmt.Sprintf("relay-%d", call), - }, - }, - } - ch <- part - if call <= 2 { - // First two dials: close channel to simulate relay - // drop. This triggers scheduleRelayReconnect. - close(ch) - } - // Third dial: keep channel open. - return nil, ch, func() {}, nil - } - - mclk := quartz.NewMock(t) - // Trap the reconnect timer so we can fire both reconnects - // deterministically. 
- trapReconnect := mclk.Trap().NewTimer("reconnect") - defer trapReconnect.Close() - - subscriber := newTestServer(t, db, ps, subscriberID, provider, mclk) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - chat := seedRemoteRunningChat( - ctx, - t, - db, - org.ID, - user, - model, - workerID, - "multiple-reconnects", - ) - - _, events, cancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - - // Helper to consume a specific relay part. - consumePart := func(text string) { - t.Helper() - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessagePart && - event.MessagePart != nil && - event.MessagePart.Part.Text == text { - return true - } - return false - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - } - - // First relay: consumed immediately (synchronous dial). - consumePart("relay-1") - - // First relay drops → reconnect timer created. Advance clock - // to fire it. - trapReconnect.MustWait(ctx).MustRelease(ctx) - mclk.Advance(500 * time.Millisecond).MustWait(ctx) - - // Second relay part. - consumePart("relay-2") - - // Second relay drops → another reconnect timer. Advance again. - trapReconnect.MustWait(ctx).MustRelease(ctx) - mclk.Advance(500 * time.Millisecond).MustWait(ctx) - - // Third relay part (channel stays open). - consumePart("relay-3") - require.GreaterOrEqual(t, int(callCount.Load()), 3) -} - -// TestSubscribeRelayDialCanceledOnFastCompletion verifies that a -// subscriber on a remote replica still sees the committed assistant -// response when the worker completes faster than the relay dial. -// -// Scenario: -// 1. Subscriber subscribes to a chat while it's in waiting state (no relay). -// 2. User sends a message → chat becomes pending → worker picks it up. -// 3. Subscriber receives status=running via pubsub → enterprise opens relay async. -// 4. 
Worker completes quickly → publishes committed message + status=waiting. -// 5. Subscriber receives status=waiting → enterprise cancels the in-progress relay dial. -// 6. Even though the relay never delivered streaming parts, the -// committed assistant message arrives via pubsub so the user -// does not need to refresh to see the response. -// -// Streaming parts for committed turns are intentionally NOT replayed -// via the relay: they would duplicate the durable message on the -// user's screen. The buffer retains in-progress parts only; once an -// assistant turn commits, the parts that built it are claimed by -// the durable message ID and dropped from new buffer snapshots. -func TestSubscribeRelayDialCanceledOnFastCompletion(t *testing.T) { - t.Parallel() - t.Skip(skipLegacyChatStream) - - db, ps := dbtestutil.NewDB(t) + chatID := uuid.New() workerID := uuid.New() - subscriberID := uuid.New() - - var dialAttempted atomic.Bool - - // Gate: closed when the worker finishes processing. - workerDone := make(chan struct{}) - - openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { - if !req.Stream { - return chattest.OpenAINonStreamingResponse("fast-completion-relay-race") - } - return chattest.OpenAIStreamingResponse( - chattest.OpenAITextChunks("hello ", "world ", "from ", "the ", "worker")..., - ) - }) + replicaID := uuid.New() + received := make(chan http.Header, 1) + + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + require.Equal(t, "/api/experimental/chats/"+chatID.String()+"/stream/parts", r.URL.Path) + require.Empty(t, r.URL.RawQuery) + received <- r.Header.Clone() + conn, err := websocket.Accept(rw, r, nil) + require.NoError(t, err) + _ = conn.Close(websocket.StatusNormalClosure, "") + })) + t.Cleanup(server.Close) - // Worker server with a 1-hour acquire interval so it only processes - // when explicitly woken by SendMessage's signalWake. 
- workerLogger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - worker := osschatd.New(osschatd.Config{ - Logger: workerLogger, - Database: db, - ReplicaID: workerID, - Pubsub: ps, - PendingChatAcquireInterval: time.Hour, - InFlightChatStaleAfter: testutil.WaitSuperLong, - }) - worker.Start() - t.Cleanup(func() { - require.NoError(t, worker.Close()) + dialer := entchatd.NewStreamPartsDialer(entchatd.StreamPartsDialerConfig{ + ResolveReplicaAddress: func(_ context.Context, gotWorker uuid.UUID) (string, bool) { + require.Equal(t, workerID, gotWorker) + return server.URL, true + }, + ReplicaHTTPClient: server.Client(), + ReplicaIDFn: func() uuid.UUID { return replicaID }, }) - // Subscriber's relay dialer blocks until the worker finishes, - // simulating a slow relay dial (network latency between replicas). - // After the worker completes, the dialer connects to the worker - // to retrieve buffered parts from the retained buffer. - subscriber := newTestServer(t, db, ps, subscriberID, func( - ctx context.Context, - chatID uuid.UUID, - targetWorkerID uuid.UUID, - requestHeader http.Header, - ) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, - ) { - dialAttempted.Store(true) - // Block until the worker finishes processing, simulating - // a slow relay dial. - select { - case <-workerDone: - case <-ctx.Done(): - return nil, nil, nil, ctx.Err() - } - // Connect to the worker. The buffer is retained for a - // grace period after processing, so the relay session - // can complete (control events, status updates) even - // though every part has been claimed by its durable - // message and the snapshot is empty. 
- snapshot, relayEvents, cancel, ok := worker.Subscribe(ctx, chatID, requestHeader, math.MaxInt64) - if !ok { - return nil, nil, nil, xerrors.New("worker subscribe failed") - } - return snapshot, relayEvents, cancel, nil - }, nil) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - setOpenAIProviderBaseURL(ctx, t, db, openAIURL) - - // Create the chat in waiting state so the subscriber sees it - // before the worker picks it up (avoids the synchronous relay - // path in Subscribe). - chat := seedWaitingChat(t, db, org.ID, user, model, "fast-completion-relay-race") - - // Subscribe from the subscriber replica while the chat is idle. - // No relay is opened because the chat is in waiting state. - _, events, subCancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - defer subCancel() - - // Send a message via the worker server to transition the chat to - // pending and wake the worker's processing loop. - _, err := worker.SendMessage(ctx, osschatd.SendMessageOptions{ - ChatID: chat.ID, - CreatedBy: user.ID, - Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + session, err := dialer(context.Background(), osschatd.StreamPartsDialInput{ + ChatID: chatID, + WorkerID: workerID, + RequestHeader: http.Header{ + codersdk.SessionTokenHeader: {"session-token"}, + }, }) require.NoError(t, err) + require.NotNil(t, session) + require.NoError(t, session.Close()) - // Wait for the worker to fully process the chat. - require.Eventually(t, func() bool { - fromDB, dbErr := db.GetChatByID(ctx, chat.ID) - if dbErr != nil { - return false - } - return fromDB.Status == database.ChatStatusWaiting - }, testutil.WaitMedium, testutil.IntervalFast) - - // Release the relay dial now that the worker is done. - close(workerDone) - - // Collect events that arrived at the subscriber. 
The committed - // assistant message is guaranteed to arrive via pubsub even when - // the relay dial races worker completion; streaming parts are - // best-effort and are not asserted here because the buffer drops - // already-committed parts to prevent duplicate UI rendering. - var committedAssistantMsgs int - - require.Eventually(t, func() bool { - select { - case event := <-events: - if event.Type == codersdk.ChatStreamEventTypeMessage && - event.Message != nil && - event.Message.Role == codersdk.ChatMessageRoleAssistant { - committedAssistantMsgs++ - } - return committedAssistantMsgs > 0 - default: - return false - } - }, testutil.WaitLong, testutil.IntervalFast) - - // The committed assistant message arrives via pubsub → DB query - // (durable path). - require.Equal(t, 1, committedAssistantMsgs, - "committed assistant message should arrive via pubsub durable path") - - // The relay dial was attempted when status=running arrived. - require.True(t, dialAttempted.Load(), - "relay dial should have been attempted when status changed to running") + headers := <-received + require.Equal(t, "session-token", headers.Get(codersdk.SessionTokenHeader)) + require.Equal(t, replicaID.String(), headers.Get(entchatd.RelaySourceHeader)) } -// TestSubscribeRelayEstablishedMidStream demonstrates that when the -// relay is established while the worker is still streaming, the -// subscriber receives buffered parts via the relay snapshot and live -// parts through the relay channel. -// -// This is the complementary test to TestSubscribeRelayDialCanceledOnFastCompletion: -// it shows the relay mechanism works correctly when timing is favorable -// (relay connects before the worker finishes), contrasting with the race -// condition where the relay is too slow. 
-func TestSubscribeRelayEstablishedMidStream(t *testing.T) { +func TestStreamPartsDialerClassifiesHTTPFailures(t *testing.T) { t.Parallel() - // TODO(CODAGT-353): Re-enable this test after the chatd notification flow - // refactor gives workers enough causal information to distinguish stale - // control NOTIFY messages from real interrupts. The current design reuses - // the same status notification shape for wake-only and interrupt intents, - // so a stale NOTIFY can cancel a new processChat run. - t.Skip("skipped until chatd notification flow refactor handles stale control notifications") - db, ps := dbtestutil.NewDB(t) + chatID := uuid.New() workerID := uuid.New() - subscriberID := uuid.New() - - // Gate: worker blocks after first streaming request until we - // release it. This gives the relay time to establish. - firstChunkEmitted := make(chan struct{}) - continueStreaming := make(chan struct{}) - - openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { - if !req.Stream { - return chattest.OpenAINonStreamingResponse("mid-stream-relay") - } - // Signal that the first streaming request was received, - // then block until released. - select { - case <-firstChunkEmitted: - default: - close(firstChunkEmitted) - } - <-continueStreaming - return chattest.OpenAIStreamingResponse( - chattest.OpenAITextChunks("continued ", "response")..., - ) - }) + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + http.Error(rw, "nope", http.StatusUnauthorized) + })) + t.Cleanup(server.Close) - // Worker with a short fallback poll interval. The primary - // trigger is signalWake() from SendMessage, but under heavy - // CI load the wake goroutine may be delayed. A short poll - // ensures the worker always picks up the pending chat. 
- workerLogger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - worker := osschatd.New(osschatd.Config{ - Logger: workerLogger, - Database: db, - ReplicaID: workerID, - Pubsub: ps, - PendingChatAcquireInterval: time.Second, - InFlightChatStaleAfter: testutil.WaitSuperLong, - }) - worker.Start() - t.Cleanup(func() { - require.NoError(t, worker.Close()) + dialer := entchatd.NewStreamPartsDialer(entchatd.StreamPartsDialerConfig{ + ResolveReplicaAddress: func(context.Context, uuid.UUID) (string, bool) { return server.URL, true }, + ReplicaHTTPClient: server.Client(), + ReplicaIDFn: uuid.New, }) - // Subscriber's dialer connects to the worker with no delay. - // This simulates a relay that succeeds promptly. - subscriber := newTestServer(t, db, ps, subscriberID, func( - ctx context.Context, - chatID uuid.UUID, - targetWorkerID uuid.UUID, - requestHeader http.Header, - ) ( - []codersdk.ChatStreamEvent, - <-chan codersdk.ChatStreamEvent, - func(), - error, - ) { - if targetWorkerID != workerID { - return nil, nil, nil, xerrors.Errorf("unexpected relay target %s", targetWorkerID) - } - snapshot, relayEvents, cancel, ok := worker.Subscribe(ctx, chatID, requestHeader, math.MaxInt64) - if !ok { - return nil, nil, nil, xerrors.New("worker subscribe failed") - } - return snapshot, relayEvents, cancel, nil - }, nil) - - // Use WaitSuperLong so the test survives heavy CI contention. - // The worker pipeline (model resolution, message loading, LLM - // call) involves multiple DB round-trips that can be slow under - // load. - ctx := testutil.Context(t, testutil.WaitSuperLong) - user, org, model := seedChatDependencies(t, db) - setOpenAIProviderBaseURL(ctx, t, db, openAIURL) - - // Create the chat in waiting state. - chat := seedWaitingChat(t, db, org.ID, user, model, "mid-stream-relay") - - // Subscribe from the subscriber replica while the chat is idle. 
- _, events, subCancel, ok := subscriber.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - defer subCancel() - - // Send a message to make the chat pending and wake the worker. - _, err := worker.SendMessage(ctx, osschatd.SendMessageOptions{ - ChatID: chat.ID, - CreatedBy: user.ID, - Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + session, err := dialer(context.Background(), osschatd.StreamPartsDialInput{ + ChatID: chatID, + WorkerID: workerID, }) - require.NoError(t, err) - - // Wait for the worker to reach the LLM (first streaming - // request). Also poll the chat status so we fail fast with a - // clear message if the worker errors out instead of timing - // out silently. - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() -waitForStream: - for { - select { - case <-firstChunkEmitted: - break waitForStream - case <-ticker.C: - currentChat, dbErr := db.GetChatByID(ctx, chat.ID) - if dbErr == nil && currentChat.Status == database.ChatStatusError { - t.Fatalf("worker failed to process chat: status=%s last_error=%s", - currentChat.Status, chatLastErrorMessage(currentChat.LastError)) - } - case <-ctx.Done(): - // Dump the final chat status for debugging. - currentChat, dbErr := db.GetChatByID(context.Background(), chat.ID) - if dbErr == nil { - t.Fatalf("timed out waiting for worker to start streaming (chat status=%s, last_error=%q)", - currentChat.Status, chatLastErrorMessage(currentChat.LastError)) - } - t.Fatal("timed out waiting for worker to start streaming") - } - } - - // Wait for the subscriber to receive the running status, which - // triggers the relay. Because the dialer is non-blocking, the - // relay establishes promptly. 
- require.Eventually(t, func() bool { - select { - case event := <-events: - return event.Type == codersdk.ChatStreamEventTypeStatus && - event.Status != nil && - event.Status.Status == codersdk.ChatStatusRunning - default: - return false - } - }, testutil.WaitMedium, testutil.IntervalFast) - - // Now release the worker to continue streaming. - close(continueStreaming) - - // Wait for the worker to complete. - require.Eventually(t, func() bool { - fromDB, dbErr := db.GetChatByID(ctx, chat.ID) - if dbErr != nil { - return false - } - return fromDB.Status == database.ChatStatusWaiting - }, testutil.WaitMedium, testutil.IntervalFast) - - // Collect remaining events. - var messageParts []string - var hasCommittedMsg bool - - require.Eventually(t, func() bool { - select { - case event := <-events: - switch event.Type { - case codersdk.ChatStreamEventTypeMessagePart: - if event.MessagePart != nil { - messageParts = append(messageParts, event.MessagePart.Part.Text) - } - case codersdk.ChatStreamEventTypeMessage: - if event.Message != nil && event.Message.Role == codersdk.ChatMessageRoleAssistant { - hasCommittedMsg = true - } - } - return hasCommittedMsg - default: - return false - } - }, testutil.WaitLong, testutil.IntervalFast) - - // The committed message arrives via pubsub. - require.True(t, hasCommittedMsg, - "committed assistant message should arrive") - - // When the relay is established mid-stream, streaming parts - // SHOULD be received through the relay. This contrasts with - // TestSubscribeRelayDialCanceledOnFastCompletion where no parts - // arrive because the relay is never established. 
- require.NotEmpty(t, messageParts, - "streaming parts should be received when relay establishes while worker is still streaming") + require.Nil(t, session) + var dialErr *entchatd.RelayDialError + require.ErrorAs(t, err, &dialErr) + require.Equal(t, http.StatusUnauthorized, dialErr.HTTPStatus) + require.True(t, dialErr.IsUnrecoverable()) } diff --git a/scripts/metricsdocgen/generated_metrics b/scripts/metricsdocgen/generated_metrics index da019143dfc87..76d25ef341ade 100644 --- a/scripts/metricsdocgen/generated_metrics +++ b/scripts/metricsdocgen/generated_metrics @@ -250,21 +250,9 @@ coderd_chatd_steps_total{provider="",model=""} 0 # HELP coderd_chatd_stream_buffer_dropped_total Number of chat stream buffer events dropped due to the per-chat buffer cap. # TYPE coderd_chatd_stream_buffer_dropped_total counter coderd_chatd_stream_buffer_dropped_total 0 -# HELP coderd_chatd_stream_buffer_events Sum of current buffer lengths across all chat streams. -# TYPE coderd_chatd_stream_buffer_events gauge -coderd_chatd_stream_buffer_events 0 -# HELP coderd_chatd_stream_buffer_size_max Maximum current buffer length across all chat streams. -# TYPE coderd_chatd_stream_buffer_size_max gauge -coderd_chatd_stream_buffer_size_max 0 # HELP coderd_chatd_stream_retries_total Total LLM stream retries. # TYPE coderd_chatd_stream_retries_total counter coderd_chatd_stream_retries_total{provider="",model="",kind="",chain_broken=""} 0 -# HELP coderd_chatd_stream_subscribers Current number of chat stream subscribers across all chat streams. -# TYPE coderd_chatd_stream_subscribers gauge -coderd_chatd_stream_subscribers 0 -# HELP coderd_chatd_streams_active Current number of chat stream state entries (in-flight plus retained). -# TYPE coderd_chatd_streams_active gauge -coderd_chatd_streams_active 0 # HELP coderd_chatd_tool_errors_total Total tool calls that returned an error result. 
# TYPE coderd_chatd_tool_errors_total counter coderd_chatd_tool_errors_total{provider="",model="",tool_name=""} 0 diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 8f07b5563a5ae..30bd4f5afd6f3 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -2798,8 +2798,10 @@ export interface ChatStreamEvent { export type ChatStreamEventType = | "action_required" | "error" + | "history_reset" | "message" | "message_part" + | "preview_reset" | "queue_update" | "retry" | "status"; @@ -2807,8 +2809,10 @@ export type ChatStreamEventType = export const ChatStreamEventTypes: ChatStreamEventType[] = [ "action_required", "error", + "history_reset", "message", "message_part", + "preview_reset", "queue_update", "retry", "status", @@ -2821,6 +2825,9 @@ export const ChatStreamEventTypes: ChatStreamEventType[] = [ export interface ChatStreamMessagePart { readonly role?: ChatMessageRole; readonly part: ChatMessagePart; + readonly history_version?: number; + readonly generation_attempt?: number; + readonly seq?: number; } // From codersdk/chats.go