From 903f4be6594bfae32eb9837f06660d23ff251d56 Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Thu, 18 Jun 2026 10:11:10 +0000 Subject: [PATCH] fix: address message part buffer review items --- .../messagepartbuffer/message_part_buffer.go | 122 +++++++++++++----- .../message_part_buffer_test.go | 2 +- 2 files changed, 91 insertions(+), 33 deletions(-) diff --git a/coderd/x/chatd/messagepartbuffer/message_part_buffer.go b/coderd/x/chatd/messagepartbuffer/message_part_buffer.go index a41c91c293c75..9b14c3287e400 100644 --- a/coderd/x/chatd/messagepartbuffer/message_part_buffer.go +++ b/coderd/x/chatd/messagepartbuffer/message_part_buffer.go @@ -1,9 +1,25 @@ +// Package messagepartbuffer stores the transient message-part stream that a +// chat worker emits before those parts are committed to durable chat history. +// +// Chat generation has two consumers with different timing. Stream endpoints +// need to forward parts immediately, while interruption handling may need to +// recover the partial assistant or tool message and commit it. Buffer groups +// parts by an episode key that includes the chat, history version, and +// generation attempt so stale workers and late subscribers do not mix parts +// from different generations. +// +// Episodes are intentionally in-memory. They are closed when a generation +// attempt ends, then retained briefly so stream subscribers and interruption +// cleanup can drain the final parts. The cleanup loop removes closed episodes +// after the retention window. Never-created placeholders are removed during +// subscriber teardown, when the last early subscriber leaves. 
package messagepartbuffer import ( "container/heap" "context" "encoding/json" + "slices" "sync" "time" @@ -18,7 +34,6 @@ const ( defaultMaxEpisodeBytes = int64(1024 * 1024) defaultClosedEpisodeRetention = 15 * time.Second defaultSubscriberSendTimeout = 10 * time.Second - defaultSubscriberChannelSize = 16 ) var ( @@ -67,7 +82,6 @@ type Options struct { MaxEpisodeBytes int64 ClosedEpisodeRetention time.Duration SubscriberSendTimeout time.Duration - SubscriberChannelSize int Clock quartz.Clock } @@ -148,22 +162,23 @@ func New(options Options) *Buffer { if options.SubscriberSendTimeout <= 0 { options.SubscriberSendTimeout = defaultSubscriberSendTimeout } - if options.SubscriberChannelSize <= 0 { - options.SubscriberChannelSize = defaultSubscriberChannelSize - } if options.Clock == nil { options.Clock = quartz.NewReal() } buffer := &Buffer{ opts: options, episodes: make(map[Key]*episodeState), - done: make(chan struct{}), + // done is unbuffered because it's only ever closed - never sent on. + done: make(chan struct{}), } buffer.startCleanupLoop() return buffer } // CreateEpisode creates a new episode. +// +// Subscribers may attach before an episode is created. Creating the episode +// makes it eligible to receive parts; the first AddPart wakes early subscribers. func (b *Buffer) CreateEpisode(key Key) error { b.mu.Lock() defer b.mu.Unlock() @@ -171,24 +186,27 @@ func (b *Buffer) CreateEpisode(key Key) error { return ErrMessagePartBufferClosed } b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "create")) - episode := b.episodeLocked(key) + episode := b.getOrCreateEpisodeLocked(key) if episode.created { return ErrEpisodeExists } - episode.created = true + episode.markCreated() return nil } // AddPart appends a part to an existing episode. +// +// Parts receive contiguous sequence numbers so stream endpoints can detect +// stale or broken episode subscriptions before forwarding data to clients. 
func (b *Buffer) AddPart(key Key, role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) error { b.mu.Lock() defer b.mu.Unlock() if b.closed { return ErrMessagePartBufferClosed } - episode := b.episodes[key] - if episode == nil || !episode.created { - return ErrEpisodeNotFound + episode, err := b.getEpisodeLocked(key) + if err != nil { + return err } if episode.closed { return ErrEpisodeClosed @@ -207,34 +225,33 @@ func (b *Buffer) AddPart(key Key, role codersdk.ChatMessageRole, part codersdk.C } episode.parts = append(episode.parts, buffered) episode.bytes += sizeBytes - for subscriber := range episode.subscribers { - notifySubscriber(subscriber) - } + episode.notifySubscribers() return nil } -// CloseEpisode marks an episode closed and closes its subscribers. +// CloseEpisode marks an episode closed. +// +// Closing creates the episode if it did not exist yet. This lets interruption +// cleanup converge when a worker exits before it publishes any parts. func (b *Buffer) CloseEpisode(key Key) error { b.mu.Lock() defer b.mu.Unlock() if b.closed { return ErrMessagePartBufferClosed } - episode := b.episodeLocked(key) - episode.created = true - if episode.closed { + episode := b.getOrCreateEpisodeLocked(key) + if !episode.close(b.opts.Clock.Now("message-part-buffer", "close")) { return nil } - episode.closed = true - episode.closedAt = b.opts.Clock.Now("message-part-buffer", "close") b.queueClosedEpisodeLocked(key, episode) - for subscriber := range episode.subscribers { - notifySubscriber(subscriber) - } + episode.notifySubscribers() return nil } // GetParts returns a snapshot of buffered parts for an episode. +// +// The returned slice is detached from the buffer so callers can process it +// without holding the buffer lock. 
func (b *Buffer) GetParts(key Key) ([]Part, error) { b.mu.Lock() defer b.mu.Unlock() @@ -242,25 +259,37 @@ func (b *Buffer) GetParts(key Key) ([]Part, error) { return nil, ErrMessagePartBufferClosed } b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "get")) - episode := b.episodes[key] - if episode == nil || !episode.created { - return nil, ErrEpisodeNotFound + episode, err := b.getEpisodeLocked(key) + if err != nil { + return nil, err } - return append([]Part(nil), episode.parts...), nil + return slices.Clone(episode.parts), nil } // SubscribeToEpisode replays existing parts and streams new parts. +// +// Subscribers may attach before CreateEpisode is called. In that case the +// subscription stays idle until the first part is added, the episode is closed, +// the context is canceled, or the buffer shuts down. The returned cancel function is idempotent. func (b *Buffer) SubscribeToEpisode(ctx context.Context, key Key) (<-chan Part, func(), error) { b.mu.Lock() if b.closed { b.mu.Unlock() return nil, nil, ErrMessagePartBufferClosed } - episode := b.episodeLocked(key) + episode := b.getOrCreateEpisodeLocked(key) subscriber := &episodeSubscriber{ - out: make(chan Part), + // out is unbuffered so the delivery goroutine only advances once the + // subscriber has accepted each part. The send timeout bounds how long + // an unresponsive subscriber can keep its episode retained. + out: make(chan Part), + // notifyCh is a one-slot wakeup channel. Additional wakeups can be + // coalesced because the delivery goroutine copies all available parts + // each time it wakes. notifyCh: make(chan struct{}, 1), - stopCh: make(chan struct{}), + // stopCh is unbuffered because stop only closes it. Closing does not + // block and every select that observes it treats it as cancellation. 
+ stopCh: make(chan struct{}), } if episode.subscribers == nil { episode.subscribers = make(map[*episodeSubscriber]struct{}) @@ -358,7 +387,7 @@ func (b *Buffer) queueClosedEpisodeLocked(key Key, episode *episodeState) { heap.Push(&b.closedEpisodes, item) } -func (b *Buffer) episodeLocked(key Key) *episodeState { +func (b *Buffer) getOrCreateEpisodeLocked(key Key) *episodeState { episode := b.episodes[key] if episode != nil { return episode @@ -368,6 +397,14 @@ func (b *Buffer) episodeLocked(key Key) *episodeState { return episode } +func (b *Buffer) getEpisodeLocked(key Key) (*episodeState, error) { + episode := b.episodes[key] + if episode == nil || !episode.created { + return nil, ErrEpisodeNotFound + } + return episode, nil +} + func (b *Buffer) subscriberParts(key Key, subscriber *episodeSubscriber) (parts []Part, closed bool, ok bool) { b.mu.Lock() defer b.mu.Unlock() @@ -384,7 +421,7 @@ func (b *Buffer) subscriberParts(key Key, subscriber *episodeSubscriber) (parts if subscriber.next > len(episode.parts) { return nil, false, false } - parts = append([]Part(nil), episode.parts[subscriber.next:]...) + parts = slices.Clone(episode.parts[subscriber.next:]) subscriber.next = len(episode.parts) return parts, episode.closed && subscriber.next == len(episode.parts), true } @@ -481,6 +518,27 @@ func notifySubscriber(subscriber *episodeSubscriber) { } } +func (e *episodeState) markCreated() { + e.created = true +} + +// close marks the episode closed and returns false if it was already closed. 
+func (e *episodeState) close(now time.Time) bool { + e.markCreated() + if e.closed { + return false + } + e.closed = true + e.closedAt = now + return true +} + +func (e *episodeState) notifySubscribers() { + for subscriber := range e.subscribers { + notifySubscriber(subscriber) + } +} + func (s *episodeSubscriber) stop() { s.stopOnce.Do(func() { close(s.stopCh) }) } diff --git a/coderd/x/chatd/messagepartbuffer/message_part_buffer_test.go b/coderd/x/chatd/messagepartbuffer/message_part_buffer_test.go index 475f2dfa1b118..46ff27e5f3565 100644 --- a/coderd/x/chatd/messagepartbuffer/message_part_buffer_test.go +++ b/coderd/x/chatd/messagepartbuffer/message_part_buffer_test.go @@ -291,7 +291,7 @@ func TestBuffer_SlowSubscriberClosed(t *testing.T) { func TestBuffer_BurstyOutputDoesNotCloseSubscriberBeforeSendTimeout(t *testing.T) { t.Parallel() - buffer := messagepartbuffer.New(messagepartbuffer.Options{SubscriberChannelSize: 1}) + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) defer buffer.Close() key := testEpisodeKey() require.NoError(t, buffer.CreateEpisode(key))