Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 90 additions & 32 deletions coderd/x/chatd/messagepartbuffer/message_part_buffer.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
// Package messagepartbuffer stores the transient message-part stream that a
// chat worker emits before those parts are committed to durable chat history.
//
// Chat generation has two consumers with different timing. Stream endpoints
// need to forward parts immediately, while interruption handling may need to
// recover the partial assistant or tool message and commit it. Buffer groups
// parts by an episode key that includes the chat, history version, and
// generation attempt so stale workers and late subscribers do not mix parts
// from different generations.
//
// Episodes are intentionally in-memory. They are closed when a generation
// attempt ends, then retained briefly so stream subscribers and interruption
// cleanup can drain the final parts. The cleanup loop removes closed episodes
Comment thread
hugodutka marked this conversation as resolved.
// after the retention window. Never-created placeholders are removed during
// subscriber teardown, when the last early subscriber leaves.
package messagepartbuffer

import (
"container/heap"
"context"
"encoding/json"
"slices"
"sync"
"time"

Expand All @@ -18,7 +34,6 @@ const (
defaultMaxEpisodeBytes = int64(1024 * 1024)
defaultClosedEpisodeRetention = 15 * time.Second
defaultSubscriberSendTimeout = 10 * time.Second
defaultSubscriberChannelSize = 16
)

var (
Expand Down Expand Up @@ -67,7 +82,6 @@ type Options struct {
MaxEpisodeBytes int64
ClosedEpisodeRetention time.Duration
SubscriberSendTimeout time.Duration
SubscriberChannelSize int
Clock quartz.Clock
}

Expand Down Expand Up @@ -148,47 +162,51 @@ func New(options Options) *Buffer {
if options.SubscriberSendTimeout <= 0 {
options.SubscriberSendTimeout = defaultSubscriberSendTimeout
}
if options.SubscriberChannelSize <= 0 {
options.SubscriberChannelSize = defaultSubscriberChannelSize
}
if options.Clock == nil {
options.Clock = quartz.NewReal()
}
buffer := &Buffer{
opts: options,
episodes: make(map[Key]*episodeState),
done: make(chan struct{}),
// done is unbuffered because it's only ever closed - never sent on.
Comment thread
hugodutka marked this conversation as resolved.
done: make(chan struct{}),
}
buffer.startCleanupLoop()
return buffer
}

// CreateEpisode creates a new episode.
//
// Subscribers may attach before an episode is created. Creating the episode
Comment thread
hugodutka marked this conversation as resolved.
// makes it eligible to receive parts; the first AddPart wakes early subscribers.
func (b *Buffer) CreateEpisode(key Key) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return ErrMessagePartBufferClosed
}
b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "create"))
episode := b.episodeLocked(key)
episode := b.getOrCreateEpisodeLocked(key)
if episode.created {
return ErrEpisodeExists
}
episode.created = true
episode.markCreated()
return nil
}

// AddPart appends a part to an existing episode.
//
// Parts receive contiguous sequence numbers so stream endpoints can detect
// stale or broken episode subscriptions before forwarding data to clients.
Comment thread
hugodutka marked this conversation as resolved.
func (b *Buffer) AddPart(key Key, role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return ErrMessagePartBufferClosed
}
episode := b.episodes[key]
if episode == nil || !episode.created {
return ErrEpisodeNotFound
episode, err := b.getEpisodeLocked(key)
if err != nil {
return err
}
if episode.closed {
return ErrEpisodeClosed
Expand All @@ -207,60 +225,71 @@ func (b *Buffer) AddPart(key Key, role codersdk.ChatMessageRole, part codersdk.C
}
episode.parts = append(episode.parts, buffered)
episode.bytes += sizeBytes
for subscriber := range episode.subscribers {
notifySubscriber(subscriber)
}
episode.notifySubscribers()
return nil
}

// CloseEpisode marks an episode closed and closes its subscribers.
// CloseEpisode marks an episode closed.
//
// Closing creates the episode if it did not exist yet. This lets interruption
// cleanup converge when a worker exits before it publishes any parts.
func (b *Buffer) CloseEpisode(key Key) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return ErrMessagePartBufferClosed
}
episode := b.episodeLocked(key)
Comment thread
hugodutka marked this conversation as resolved.
episode.created = true
if episode.closed {
episode := b.getOrCreateEpisodeLocked(key)
if !episode.close(b.opts.Clock.Now("message-part-buffer", "close")) {
return nil
}
episode.closed = true
episode.closedAt = b.opts.Clock.Now("message-part-buffer", "close")
b.queueClosedEpisodeLocked(key, episode)
for subscriber := range episode.subscribers {
notifySubscriber(subscriber)
Comment thread
hugodutka marked this conversation as resolved.
}
episode.notifySubscribers()
return nil
}

// GetParts returns a snapshot of buffered parts for an episode.
//
// The returned slice is detached from the buffer so callers can process it
// without holding the buffer lock.
func (b *Buffer) GetParts(key Key) ([]Part, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return nil, ErrMessagePartBufferClosed
}
b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "get"))
episode := b.episodes[key]
if episode == nil || !episode.created {
return nil, ErrEpisodeNotFound
episode, err := b.getEpisodeLocked(key)
if err != nil {
return nil, err
}
return append([]Part(nil), episode.parts...), nil
return slices.Clone(episode.parts), nil
}

// SubscribeToEpisode replays existing parts and streams new parts.
//
// Subscribers may attach before CreateEpisode is called. In that case the
Comment thread
hugodutka marked this conversation as resolved.
// subscription stays idle until the first part is added or until closure, cancellation,
// or buffer shutdown. The returned cancel function is idempotent.
func (b *Buffer) SubscribeToEpisode(ctx context.Context, key Key) (<-chan Part, func(), error) {
b.mu.Lock()
if b.closed {
b.mu.Unlock()
return nil, nil, ErrMessagePartBufferClosed
}
episode := b.episodeLocked(key)
episode := b.getOrCreateEpisodeLocked(key)
subscriber := &episodeSubscriber{
out: make(chan Part),
// out is unbuffered so the delivery goroutine only advances once the
Comment thread
hugodutka marked this conversation as resolved.
// subscriber has accepted each part. The send timeout bounds how long
// an unresponsive subscriber can keep its episode retained.
out: make(chan Part),
// notifyCh is a one-slot wakeup channel. Additional wakeups can be
// coalesced because the delivery goroutine copies all available parts
// each time it wakes.
notifyCh: make(chan struct{}, 1),
stopCh: make(chan struct{}),
// stopCh is unbuffered because stop only closes it. Closing does not
// block and every select that observes it treats it as cancellation.
stopCh: make(chan struct{}),
}
if episode.subscribers == nil {
episode.subscribers = make(map[*episodeSubscriber]struct{})
Expand Down Expand Up @@ -358,7 +387,7 @@ func (b *Buffer) queueClosedEpisodeLocked(key Key, episode *episodeState) {
heap.Push(&b.closedEpisodes, item)
}

func (b *Buffer) episodeLocked(key Key) *episodeState {
func (b *Buffer) getOrCreateEpisodeLocked(key Key) *episodeState {
episode := b.episodes[key]
if episode != nil {
return episode
Expand All @@ -368,6 +397,14 @@ func (b *Buffer) episodeLocked(key Key) *episodeState {
return episode
}

// getEpisodeLocked looks up the episode stored under key and returns it only
// if it has been marked created. A missing entry, or an entry that exists but
// was never created, yields ErrEpisodeNotFound.
//
// The visible callers (AddPart, GetParts) invoke it while holding b.mu.
func (b *Buffer) getEpisodeLocked(key Key) (*episodeState, error) {
	if found := b.episodes[key]; found != nil && found.created {
		return found, nil
	}
	return nil, ErrEpisodeNotFound
}

func (b *Buffer) subscriberParts(key Key, subscriber *episodeSubscriber) (parts []Part, closed bool, ok bool) {
b.mu.Lock()
defer b.mu.Unlock()
Expand All @@ -384,7 +421,7 @@ func (b *Buffer) subscriberParts(key Key, subscriber *episodeSubscriber) (parts
if subscriber.next > len(episode.parts) {
return nil, false, false
}
parts = append([]Part(nil), episode.parts[subscriber.next:]...)
parts = slices.Clone(episode.parts[subscriber.next:])
subscriber.next = len(episode.parts)
return parts, episode.closed && subscriber.next == len(episode.parts), true
}
Expand Down Expand Up @@ -481,6 +518,27 @@ func notifySubscriber(subscriber *episodeSubscriber) {
}
}

// markCreated sets the episode's created flag. Calling it on an episode that
// is already created leaves the flag set.
func (e *episodeState) markCreated() {
	e.created = true
}

// close marks the episode closed and returns false if it was already closed.
func (e *episodeState) close(now time.Time) bool {
Comment thread
hugodutka marked this conversation as resolved.
Comment thread
hugodutka marked this conversation as resolved.
e.markCreated()
if e.closed {
return false
}
e.closed = true
e.closedAt = now
return true
}

// notifySubscribers calls notifySubscriber once for every subscriber currently
// registered on the episode. With no subscribers registered it does nothing.
func (e *episodeState) notifySubscribers() {
	for sub := range e.subscribers {
		notifySubscriber(sub)
	}
}

// stop closes the subscriber's stopCh. The close runs through stopOnce, so
// repeated calls do not close the channel a second time (assuming stopOnce is
// a sync.Once — its declaration is not visible here).
func (s *episodeSubscriber) stop() {
	s.stopOnce.Do(func() {
		close(s.stopCh)
	})
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func TestBuffer_SlowSubscriberClosed(t *testing.T) {
func TestBuffer_BurstyOutputDoesNotCloseSubscriberBeforeSendTimeout(t *testing.T) {
t.Parallel()

buffer := messagepartbuffer.New(messagepartbuffer.Options{SubscriberChannelSize: 1})
buffer := messagepartbuffer.New(messagepartbuffer.Options{})
defer buffer.Close()
key := testEpisodeKey()
require.NoError(t, buffer.CreateEpisode(key))
Expand Down
Loading