Skip to content

Commit dbf8b7f

Browse files
committed
feat: track ai seat usage
1 parent fe47143 commit dbf8b7f

10 files changed

Lines changed: 286 additions & 7 deletions

File tree

coderd/aiseats/aiseats.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Package aiseats is the AGPL version the package.
2+
// The actual implementation is in `enterprise/aiseats`.
3+
package aiseats
4+
5+
import (
6+
"context"
7+
8+
"github.com/google/uuid"
9+
10+
"github.com/coder/coder/v2/coderd/database"
11+
)
12+
13+
// Reason describes what AI event consumed the seat.
14+
type Reason interface {
15+
isReason()
16+
}
17+
18+
type reason struct {
19+
eventType database.AiSeatUsageReason
20+
description string
21+
}
22+
23+
func (reason) isReason() {}
24+
25+
// ReasonValues extracts storage values from a Reason.
26+
func ReasonValues(r Reason) (database.AiSeatUsageReason, string, bool) {
27+
rr, ok := r.(reason)
28+
if !ok {
29+
return "", "", false
30+
}
31+
return rr.eventType, rr.description, true
32+
}
33+
34+
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
35+
func ReasonAIBridge(description string) Reason {
36+
return reason{eventType: database.AiSeatUsageReasonAibridge, description: description}
37+
}
38+
39+
// ReasonTask constructs a reason for usage originating from tasks.
40+
func ReasonTask(description string) Reason {
41+
return reason{eventType: database.AiSeatUsageReasonTask, description: description}
42+
}
43+
44+
// SeatTracker records AI seat consumption state.
45+
type SeatTracker interface {
46+
// RecordUsage does not return an error to prevent blocking the user from using
47+
// AI features. This method is used to record usage, not enforce it.
48+
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
49+
}
50+
51+
// Noop is an AGPL seat tracker that does nothing.
52+
type Noop struct{}
53+
54+
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}

coderd/coderd.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import (
4444
"github.com/coder/coder/v2/buildinfo"
4545
"github.com/coder/coder/v2/coderd/agentapi"
4646
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
47+
"github.com/coder/coder/v2/coderd/aiseats"
4748
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
4849
"github.com/coder/coder/v2/coderd/appearance"
4950
"github.com/coder/coder/v2/coderd/audit"
@@ -629,6 +630,8 @@ func New(options *Options) *API {
629630
),
630631
dbRolluper: options.DatabaseRolluper,
631632
}
633+
api.AISeatTracker = aiseats.Noop{}
634+
632635
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
633636
ctx,
634637
options.Logger.Named("workspaceapps"),
@@ -2015,6 +2018,8 @@ type API struct {
20152018
dbRolluper *dbrollup.Rolluper
20162019
// chatDaemon handles background processing of pending chats.
20172020
chatDaemon *chatd.Server
2021+
// AISeatTracker records AI seat usage.
2022+
AISeatTracker aiseats.SeatTracker
20182023
// gitSyncWorker refreshes stale chat diff statuses in the
20192024
// background.
20202025
gitSyncWorker *gitsync.Worker
@@ -2218,6 +2223,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
22182223
provisionerdserver.Options{
22192224
OIDCConfig: api.OIDCConfig,
22202225
ExternalAuthConfigs: api.ExternalAuthConfigs,
2226+
AISeatTracker: api.AISeatTracker,
22212227
Clock: api.Clock,
22222228
HeartbeatFn: options.heartbeatFn,
22232229
},

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
protobuf "google.golang.org/protobuf/proto"
2929

3030
"cdr.dev/slog/v3"
31+
"github.com/coder/coder/v2/coderd/aiseats"
3132
"github.com/coder/coder/v2/coderd/apikey"
3233
"github.com/coder/coder/v2/coderd/audit"
3334
"github.com/coder/coder/v2/coderd/database"
@@ -76,6 +77,7 @@ const (
7677
type Options struct {
7778
OIDCConfig promoauth.OAuth2Config
7879
ExternalAuthConfigs []*externalauth.Config
80+
AISeatTracker aiseats.SeatTracker
7981

8082
// Clock for testing
8183
Clock quartz.Clock
@@ -120,6 +122,7 @@ type server struct {
120122
NotificationsEnqueuer notifications.Enqueuer
121123
PrebuildsOrchestrator *atomic.Pointer[prebuilds.ReconciliationOrchestrator]
122124
UsageInserter *atomic.Pointer[usage.Inserter]
125+
AISeatTracker aiseats.SeatTracker
123126
Experiments codersdk.Experiments
124127

125128
OIDCConfig promoauth.OAuth2Config
@@ -215,6 +218,9 @@ func NewServer(
215218
if err := tags.Valid(); err != nil {
216219
return nil, xerrors.Errorf("invalid tags: %w", err)
217220
}
221+
if options.AISeatTracker == nil {
222+
options.AISeatTracker = aiseats.Noop{}
223+
}
218224
if options.AcquireJobLongPollDur == 0 {
219225
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
220226
}
@@ -253,6 +259,7 @@ func NewServer(
253259
heartbeatFn: options.HeartbeatFn,
254260
PrebuildsOrchestrator: prebuildsOrchestrator,
255261
UsageInserter: usageInserter,
262+
AISeatTracker: options.AISeatTracker,
256263
metrics: metrics,
257264
Experiments: experiments,
258265
}
@@ -2417,6 +2424,12 @@ func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.Pro
24172424
})
24182425
}
24192426

2427+
// Record AI seat usage for successful task workspace builds.
2428+
if workspaceBuild.Transition == database.WorkspaceTransitionStart && workspace.TaskID.Valid {
2429+
s.AISeatTracker.RecordUsage(ctx, workspace.OwnerID,
2430+
aiseats.ReasonTask("task workspace build succeeded"))
2431+
}
2432+
24202433
if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
24212434
// Track resource replacements, if there are any.
24222435
orchestrator := s.PrebuildsOrchestrator.Load()

enterprise/aibridgedserver/aibridgedserver.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"google.golang.org/protobuf/types/known/structpb"
1616

1717
"cdr.dev/slog/v3"
18+
"github.com/coder/coder/v2/coderd/aiseats"
1819
"github.com/coder/coder/v2/coderd/apikey"
1920
"github.com/coder/coder/v2/coderd/database"
2021
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -81,10 +82,12 @@ type Server struct {
8182

8283
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
8384
structuredLogging bool
85+
aiSeatTracker aiseats.SeatTracker
8486
}
8587

8688
func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, accessURL string,
8789
bridgeCfg codersdk.AIBridgeConfig, externalAuthConfigs []*externalauth.Config, experiments codersdk.Experiments,
90+
aiSeatTracker aiseats.SeatTracker,
8891
) (*Server, error) {
8992
eac := make(map[string]*externalauth.Config, len(externalAuthConfigs))
9093

@@ -102,6 +105,7 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac
102105
logger: logger,
103106
externalAuthConfigs: eac,
104107
structuredLogging: bridgeCfg.StructuredLogging.Value(),
108+
aiSeatTracker: aiSeatTracker,
105109
}
106110

107111
if bridgeCfg.InjectCoderMCPTools {
@@ -183,6 +187,9 @@ func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterce
183187
return nil, xerrors.Errorf("start interception: %w", err)
184188
}
185189

190+
// Make the reason something human-readable.
191+
reason := aiseats.ReasonAIBridge("provider=" + in.Provider + ", model=" + in.Model)
192+
s.aiSeatTracker.RecordUsage(ctx, initID, reason)
186193
return &proto.RecordInterceptionResponse{}, nil
187194
}
188195

enterprise/aibridgedserver/aibridgedserver_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"cdr.dev/slog/v3"
2626
"cdr.dev/slog/v3/sloggers/slogjson"
27+
"github.com/coder/coder/v2/coderd/aiseats"
2728
"github.com/coder/coder/v2/coderd/apikey"
2829
"github.com/coder/coder/v2/coderd/database"
2930
"github.com/coder/coder/v2/coderd/database/dbgen"
@@ -176,7 +177,7 @@ func TestAuthorization(t *testing.T) {
176177
tc.mocksFn(db, apiKey, user)
177178
}
178179

179-
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
180+
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
180181
require.NoError(t, err)
181182
require.NotNil(t, srv)
182183

@@ -268,7 +269,7 @@ func TestGetMCPServerConfigs(t *testing.T) {
268269
accessURL := "https://my-cool-deployment.com"
269270
srv, err := aibridgedserver.NewServer(t.Context(), db, logger, accessURL, codersdk.AIBridgeConfig{
270271
InjectCoderMCPTools: serpent.Bool(!tc.disableCoderMCPInjection),
271-
}, tc.externalAuthConfigs, tc.experiments)
272+
}, tc.externalAuthConfigs, tc.experiments, aiseats.Noop{})
272273
require.NoError(t, err)
273274
require.NotNil(t, srv)
274275

@@ -318,7 +319,7 @@ func TestGetMCPServerAccessTokensBatch(t *testing.T) {
318319
{
319320
ID: "3",
320321
},
321-
}, requiredExperiments)
322+
}, requiredExperiments, aiseats.Noop{})
322323
require.NoError(t, err)
323324
require.NotNil(t, srv)
324325

@@ -1014,7 +1015,7 @@ func testRecordMethod[Req any, Resp any](
10141015
}
10151016

10161017
ctx := testutil.Context(t, testutil.WaitLong)
1017-
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
1018+
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
10181019
require.NoError(t, err)
10191020

10201021
resp, err := callMethod(srv, ctx, tc.request)
@@ -1309,7 +1310,7 @@ func TestStructuredLogging(t *testing.T) {
13091310
ctx := testutil.Context(t, testutil.WaitLong)
13101311
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{
13111312
StructuredLogging: serpent.Bool(tc.structuredLogging),
1312-
}, nil, requiredExperiments)
1313+
}, nil, requiredExperiments, aiseats.Noop{})
13131314
require.NoError(t, err)
13141315

13151316
err = tc.recordFn(srv, ctx, interceptionID)
@@ -1351,7 +1352,7 @@ func TestInferredThreadsByToolCalls(t *testing.T) {
13511352

13521353
user := dbgen.User(t, db, database.User{})
13531354

1354-
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments)
1355+
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, aiseats.Noop{})
13551356
require.NoError(t, err)
13561357

13571358
aID := uuid.New()

enterprise/aiseats/tracker.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package aiseats
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
9+
"github.com/google/uuid"
10+
11+
"cdr.dev/slog/v3"
12+
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
13+
"github.com/coder/coder/v2/coderd/database"
14+
"github.com/coder/quartz"
15+
)
16+
17+
type store interface {
18+
UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) error
19+
}
20+
21+
// throttleInterval is the minimum time between DB writes for the same user. This
22+
// is to prevent ai seat tracking from consuming more db resources.
23+
//
24+
// These events are not critical to be recorded in real time, so we can afford to
25+
// skip almost all of them. The first write is the most important, as it
26+
// indicates a seat is consumed. Subsequent writes are purely informative and has
27+
// no functional impact.
28+
const (
29+
throttleInterval = 6 * time.Hour
30+
// failedRetryInterval exists to prevent a transient error from causing no
31+
// usage to be recorded. Still debounce.
32+
failedRetryInterval = 30 * time.Minute
33+
)
34+
35+
// SeatTracker records current AI seat state for users.
36+
type SeatTracker struct {
37+
db store
38+
logger slog.Logger
39+
clock quartz.Clock
40+
41+
mu sync.RWMutex
42+
retryAfter map[uuid.UUID]time.Time
43+
}
44+
45+
func New(db store, logger slog.Logger, clock quartz.Clock) *SeatTracker {
46+
if clock == nil {
47+
clock = quartz.NewReal()
48+
}
49+
return &SeatTracker{db: db, logger: logger, clock: clock, retryAfter: make(map[uuid.UUID]time.Time)}
50+
}
51+
52+
// skipRecord returns true when the user is still in the retry cooldown
53+
// window and we should skip a DB write attempt.
54+
func (t *SeatTracker) skipRecord(userID uuid.UUID, now time.Time) bool {
55+
t.mu.RLock()
56+
defer t.mu.RUnlock()
57+
58+
retryAfter, ok := t.retryAfter[userID]
59+
return ok && now.Before(retryAfter)
60+
}
61+
62+
// recordThrottle sets the next time when DB writes for this user are allowed.
63+
func (t *SeatTracker) recordThrottle(userID uuid.UUID, now time.Time, d time.Duration) {
64+
t.mu.Lock()
65+
defer t.mu.Unlock()
66+
t.retryAfter[userID] = now.Add(d)
67+
}
68+
69+
// RecordUsage will record the AI seat usage for the user. There is a race condition between
70+
// checking if the user should be recorded or throttled and actually recording. This is fine, as
71+
// it just means we record the usage twice.
72+
// The throttle just exists to prevent excessive database queries.
73+
func (t *SeatTracker) RecordUsage(ctx context.Context, userID uuid.UUID, reason agplaiseats.Reason) {
74+
now := t.clock.Now()
75+
if t.skipRecord(userID, now) {
76+
return
77+
}
78+
79+
eventType, description, ok := agplaiseats.ReasonValues(reason)
80+
if !ok {
81+
t.logger.Warn(ctx, "invalid AI seat usage reason", slog.F("user_id", userID), slog.F("reason_type", fmt.Sprintf("%T", reason)))
82+
return
83+
}
84+
85+
err := t.db.UpsertAISeatState(ctx, database.UpsertAISeatStateParams{
86+
UserID: userID,
87+
FirstUsedAt: now,
88+
LastEventType: eventType,
89+
LastEventDescription: description,
90+
})
91+
if err != nil {
92+
t.logger.Warn(ctx, "upsert AI seat state", slog.Error(err), slog.F("user_id", userID), slog.F("event_type", eventType))
93+
t.recordThrottle(userID, now, failedRetryInterval)
94+
return
95+
}
96+
97+
t.recordThrottle(userID, now, throttleInterval)
98+
}

0 commit comments

Comments
 (0)