From 6410011425527e716cb45eba370f92bb380aea39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Wed, 17 Jun 2026 16:58:15 +0000 Subject: [PATCH] feat(enterprise/coderd): add /api/v2/aibridge/serve endpoint Add the DRPC-over-WebSocket endpoint that standalone AI Gateway replicas connect to. The handler validates the requested API version, upgrades to a WebSocket, sets up a yamux session, and serves the Recorder, MCPConfigurator, and Authorizer services, mirroring the embedded CreateInMemoryAIBridgeServer and the external provisioner daemon serve transport. While a session is active it records last_used_at for the authenticating gateway key so operators can safely rotate keys. The route is registered outside the /aibridge catch-all so its overload middleware does not apply, and is gated by FeatureAIBridge plus gateway key authentication. A shared aibridgedserver.Register helper now backs both the embedded and standalone service registration. Part of AIGOV-308. Generated with Coder Agents. --- coderd/aibridged.go | 14 +- coderd/aibridgedserver/register.go | 25 ++++ coderd/apidoc/docs.go | 19 +++ coderd/apidoc/swagger.json | 17 +++ docs/reference/api/enterprise.md | 20 +++ enterprise/coderd/aibridgeserve.go | 187 ++++++++++++++++++++++++ enterprise/coderd/aibridgeserve_test.go | 174 ++++++++++++++++++++++ enterprise/coderd/coderd.go | 16 ++ 8 files changed, 460 insertions(+), 12 deletions(-) create mode 100644 coderd/aibridgedserver/register.go create mode 100644 enterprise/coderd/aibridgeserve.go create mode 100644 enterprise/coderd/aibridgeserve_test.go diff --git a/coderd/aibridged.go b/coderd/aibridged.go index f448be39d07ed..6f469f09f71bd 100644 --- a/coderd/aibridged.go +++ b/coderd/aibridged.go @@ -6,7 +6,6 @@ import ( "io" "net/http" - "golang.org/x/xerrors" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" @@ -72,17 +71,8 @@ func (api *API) CreateInMemoryAIBridgeServer(dialCtx context.Context) (client ai if err != nil { return nil, err } - err = aibridgedproto.DRPCRegisterRecorder(mux, srv) - if err != nil { - return nil, xerrors.Errorf("register recorder service: %w", err) - } - err = aibridgedproto.DRPCRegisterMCPConfigurator(mux, srv) - if err != nil { - return nil, xerrors.Errorf("register MCP configurator service: %w", err) - } - err = aibridgedproto.DRPCRegisterAuthorizer(mux, srv) - if err != nil { - return nil, xerrors.Errorf("register key validator service: %w", err) + if err := aibridgedserver.Register(mux, srv); err != nil { + return nil, err } server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{ diff --git a/coderd/aibridgedserver/register.go b/coderd/aibridgedserver/register.go new file mode 100644 index 0000000000000..31292ec1fe238 --- /dev/null +++ b/coderd/aibridgedserver/register.go @@ -0,0 +1,25 @@ +package aibridgedserver + +import ( + "golang.org/x/xerrors" + "storj.io/drpc/drpcmux" + + "github.com/coder/coder/v2/coderd/aibridged/proto" +) + +// Register registers the Recorder, MCPConfigurator, and Authorizer DRPC +// services backed by srv onto mux. It is shared by the embedded in-memory +// server and the standalone /api/v2/aibridge/serve WebSocket handler so both +// expose an identical service set. +func Register(mux *drpcmux.Mux, srv *Server) error { + if err := proto.DRPCRegisterRecorder(mux, srv); err != nil { + return xerrors.Errorf("register recorder service: %w", err) + } + if err := proto.DRPCRegisterMCPConfigurator(mux, srv); err != nil { + return xerrors.Errorf("register MCP configurator service: %w", err) + } + if err := proto.DRPCRegisterAuthorizer(mux, srv); err != nil { + return xerrors.Errorf("register key validator service: %w", err) + } + return nil +} diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index a017de9c2406d..2bfd0833f5dd3 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -64,6 +64,25 @@ const docTemplate = `{ } } }, + "/aibridge/serve": { + "get": { + "tags": [ + "Enterprise" + ], + "summary": "AI Gateway serve", + "operationId": "ai-gateway-serve", + "responses": { + "101": { + "description": "Switching Protocols" + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/api/experimental/chats": { "get": { "description": "Experimental: this endpoint is subject to change.", diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index bcdee7377d5a3..8721dbdd0f03c 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -49,6 +49,23 @@ } } }, + "/aibridge/serve": { + "get": { + "tags": ["Enterprise"], + "summary": "AI Gateway serve", + "operationId": "ai-gateway-serve", + "responses": { + "101": { + "description": "Switching Protocols" + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/api/experimental/chats": { "get": { "description": "Experimental: this endpoint is subject to change.", diff --git a/docs/reference/api/enterprise.md b/docs/reference/api/enterprise.md index c2d193aa326e7..49216d5b787df 100644 --- a/docs/reference/api/enterprise.md +++ b/docs/reference/api/enterprise.md @@ -84,6 +84,26 @@ curl -X GET http://coder-server:8080/.well-known/oauth-protected-resource \ |--------|---------------------------------------------------------|-------------|------------------------------------------------------------------------------------------------| | 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.OAuth2ProtectedResourceMetadata](schemas.md#codersdkoauth2protectedresourcemetadata) | +## AI Gateway serve + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/aibridge/serve \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /aibridge/serve` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|--------------------------------------------------------------------------|---------------------|--------| +| 101 | [Switching Protocols](https://tools.ietf.org/html/rfc7231#section-6.2.2) | Switching Protocols | | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). + ## List AI Gateway keys ### Code samples diff --git a/enterprise/coderd/aibridgeserve.go b/enterprise/coderd/aibridgeserve.go new file mode 100644 index 0000000000000..294430b869c1c --- /dev/null +++ b/enterprise/coderd/aibridgeserve.go @@ -0,0 +1,187 @@ +package coderd + +import ( + "context" + "database/sql" + "io" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "golang.org/x/xerrors" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog/v3" + aibridgedproto "github.com/coder/coder/v2/coderd/aibridged/proto" + "github.com/coder/coder/v2/coderd/aibridgedserver" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/tracing" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/drpcsdk" + "github.com/coder/websocket" +) + +// aiGatewayKeyLastUsedInterval is how often an active DRPC session refreshes +// last_used_at for its authenticating key. A key is considered active in the UI +// if last_used_at is within the last few minutes, so operators can wait for a +// key to fall out of the active window before deleting it during rotation. +const aiGatewayKeyLastUsedInterval = 60 * time.Second + +// aiBridgeServe upgrades the connection to a WebSocket and serves the aibridged +// DRPC services (Recorder, MCPConfigurator, Authorizer) to a remote standalone +// AI Gateway replica, mirroring CreateInMemoryAIBridgeServer for the embedded +// case and provisionerDaemonServe for the transport. Authentication and license +// entitlement are enforced by middleware on the route. +// +// @Summary AI Gateway serve +// @ID ai-gateway-serve +// @Security CoderSessionToken +// @Tags Enterprise +// @Success 101 +// @Router /aibridge/serve [get] +func (api *API) aiBridgeServe(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + apiVersion := "1.0" + if qv := r.URL.Query().Get("version"); qv != "" { + apiVersion = qv + } + if err := aibridgedproto.CurrentVersion.Validate(apiVersion); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Incompatible or unparsable version", + Validations: []codersdk.ValidationError{ + {Field: "version", Detail: err.Error()}, + }, + }) + return + } + + // X-Coder-Build-Version is used for observability only, not compatibility. + buildVersion := r.Header.Get(codersdk.BuildVersionHeader) + logger := api.Logger.Named("aibridge-serve").With( + slog.F("gateway_api_version", apiVersion), + slog.F("gateway_build_version", buildVersion), + ) + + // Track the websocket so API shutdown waits for it to close. + api.AGPL.WebsocketWaitMutex.Lock() + api.AGPL.WebsocketWaitGroup.Add(1) + api.AGPL.WebsocketWaitMutex.Unlock() + defer api.AGPL.WebsocketWaitGroup.Done() + + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + if !xerrors.Is(err, context.Canceled) { + logger.Error(ctx, "accept aibridge websocket conn", slog.Error(err)) + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Internal error accepting websocket connection.", + Detail: err.Error(), + }) + return + } + // Align with the frame size of yamux. + conn.SetReadLimit(256 * 1024) + + // Multiplexes the incoming connection using yamux, allowing multiple DRPC + // calls to occur over the same connection. + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary) + defer wsNetConn.Close() + session, err := yamux.Server(wsNetConn, config) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err)) + return + } + + srvCtx, srvCancel := context.WithCancel(ctx) + defer srvCancel() + + // Record liveness for the authenticating key while the session is open. + if key, ok := httpmw.AIGatewayKeyAuthOptional(r); ok { + go api.trackAIGatewayKeyUsage(srvCtx, key.ID) + } + + mux := drpcmux.New() + srv, err := aibridgedserver.NewServer( + srvCtx, + api.Database, + logger.Named("aibridgedserver"), + api.AccessURL.String(), + api.DeploymentValues.AI.BridgeConfig, + api.ExternalAuthConfigs, + api.AGPL.Experiments, + api.AGPL.AISeatTracker, + ) + if err != nil { + if !xerrors.Is(err, context.Canceled) { + logger.Error(ctx, "create aibridge server", slog.Error(err)) + } + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("create aibridge server: %s", err)) + return + } + if err := aibridgedserver.Register(mux, srv); err != nil { + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("register aibridge services: %s", err)) + return + } + + server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, + drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + logger.Debug(srvCtx, "drpc server error", slog.Error(err)) + }, + }, + ) + + logger.Info(ctx, "standalone aibridge connected") + err = server.Serve(srvCtx, session) + srvCancel() + logger.Info(ctx, "standalone aibridge disconnected", slog.Error(err)) + if err != nil && !xerrors.Is(err, io.EOF) { + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) + return + } + _ = conn.Close(websocket.StatusGoingAway, "") +} + +// trackAIGatewayKeyUsage refreshes last_used_at for keyID until ctx is +// canceled. It records usage immediately on connect, then on a fixed interval. +func (api *API) trackAIGatewayKeyUsage(ctx context.Context, keyID uuid.UUID) { + update := func() { + // nolint:gocritic // Recording AI Gateway key liveness is an internal system write. + err := api.Database.UpdateAIGatewayKeyLastUsedAt(dbauthz.AsSystemRestricted(ctx), database.UpdateAIGatewayKeyLastUsedAtParams{ + ID: keyID, + LastUsedAt: sql.NullTime{Time: dbtime.Now(), Valid: true}, + }) + if err != nil && !xerrors.Is(err, context.Canceled) { + api.Logger.Debug(ctx, "update aibridge gateway key last used", slog.Error(err), slog.F("key_id", keyID)) + } + } + + update() + + ticker := time.NewTicker(aiGatewayKeyLastUsedInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + update() + } + } +} diff --git a/enterprise/coderd/aibridgeserve_test.go b/enterprise/coderd/aibridgeserve_test.go new file mode 100644 index 0000000000000..1286a34f75fbb --- /dev/null +++ b/enterprise/coderd/aibridgeserve_test.go @@ -0,0 +1,174 @@ +package coderd_test + +import ( + "context" + "io" + "net/http" + "testing" + + "github.com/hashicorp/yamux" + "github.com/stretchr/testify/require" + + aibridgedproto "github.com/coder/coder/v2/coderd/aibridged/proto" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/drpcsdk" + "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/testutil" + "github.com/coder/serpent" + "github.com/coder/websocket" +) + +// dialAIBridgeServe dials /api/v2/aibridge/serve, authenticating with the given +// gateway key and API version. On a successful WebSocket upgrade it returns a +// yamux session and http.StatusSwitchingProtocols. Otherwise it returns a nil +// session and the HTTP status code coderd responded with. +func dialAIBridgeServe(ctx context.Context, t *testing.T, client *codersdk.Client, key, version string) (*yamux.Session, int) { + t.Helper() + + serverURL, err := client.URL.Parse("/api/v2/aibridge/serve") + require.NoError(t, err) + query := serverURL.Query() + if version != "" { + query.Set("version", version) + } + serverURL.RawQuery = query.Encode() + + headers := http.Header{} + if key != "" { + headers.Set(codersdk.AIGatewayKeyHeader, key) + } + + conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ + HTTPClient: &http.Client{Transport: client.HTTPClient.Transport}, + CompressionMode: websocket.CompressionDisabled, + HTTPHeader: headers, + }) + if err != nil { + statusCode := 0 + if res != nil { + statusCode = res.StatusCode + _ = res.Body.Close() + } + return nil, statusCode + } + conn.SetReadLimit(256 * 1024) + + cfg := yamux.DefaultConfig() + cfg.LogOutput = io.Discard + _, wsNetConn := codersdk.WebsocketNetConn(context.Background(), conn, websocket.MessageBinary) + session, err := yamux.Client(wsNetConn, cfg) + require.NoError(t, err) + t.Cleanup(func() { + _ = session.Close() + _ = wsNetConn.Close() + _ = conn.Close(websocket.StatusNormalClosure, "") + }) + return session, http.StatusSwitchingProtocols +} + +func TestAIBridgeServe(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + client, firstUser := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + //nolint:gocritic // Owner role is irrelevant for gateway key management here. + created, err := client.CreateAIGatewayKey(ctx, codersdk.CreateAIGatewayKeyRequest{Name: "serve-success"}) + require.NoError(t, err) + + session, status := dialAIBridgeServe(ctx, t, client, created.Key, aibridgedproto.CurrentVersion.String()) + require.Equal(t, http.StatusSwitchingProtocols, status) + require.NotNil(t, session) + + // The Authorizer service should be served and authorize the owner's + // session token, exercising a full DRPC round trip over the WebSocket. + authorizer := aibridgedproto.NewDRPCAuthorizerClient(drpcsdk.MultiplexedConn(session)) + resp, err := authorizer.IsAuthorized(ctx, &aibridgedproto.IsAuthorizedRequest{ + Key: client.SessionToken(), + }) + require.NoError(t, err) + require.Equal(t, firstUser.UserID.String(), resp.GetOwnerId()) + + // The session records liveness for the authenticating key. + require.Eventually(t, func() bool { + //nolint:gocritic // Owner role is irrelevant for gateway key management here. + keys, err := client.ListAIGatewayKeys(ctx) + if err != nil { + return false + } + for _, k := range keys { + if k.ID == created.ID { + return k.LastUsedAt != nil + } + } + return false + }, testutil.WaitMedium, testutil.IntervalFast) + }) + + t.Run("MissingKey", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + _, status := dialAIBridgeServe(ctx, t, client, "", aibridgedproto.CurrentVersion.String()) + require.Equal(t, http.StatusUnauthorized, status) + }) + + t.Run("InvalidKey", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + _, status := dialAIBridgeServe(ctx, t, client, "not-a-real-key", aibridgedproto.CurrentVersion.String()) + require.Equal(t, http.StatusUnauthorized, status) + }) + + t.Run("RevokedKey", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + //nolint:gocritic // Owner role is irrelevant for gateway key management here. + created, err := client.CreateAIGatewayKey(ctx, codersdk.CreateAIGatewayKeyRequest{Name: "serve-revoked"}) + require.NoError(t, err) + //nolint:gocritic // Owner role is irrelevant for gateway key management here. + require.NoError(t, client.DeleteAIGatewayKey(ctx, created.ID)) + + _, status := dialAIBridgeServe(ctx, t, client, created.Key, aibridgedproto.CurrentVersion.String()) + require.Equal(t, http.StatusUnauthorized, status) + }) + + t.Run("IncompatibleVersion", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + //nolint:gocritic // Owner role is irrelevant for gateway key management here. + created, err := client.CreateAIGatewayKey(ctx, codersdk.CreateAIGatewayKeyRequest{Name: "serve-badversion"}) + require.NoError(t, err) + + _, status := dialAIBridgeServe(ctx, t, client, created.Key, "999.0") + require.Equal(t, http.StatusBadRequest, status) + }) + + t.Run("MissingEntitlement", func(t *testing.T) { + t.Parallel() + // Enable the bridge config but do not grant the FeatureAIBridge license. + dv := coderdtest.DeploymentValues(t) + dv.AI.BridgeConfig.Enabled = serpent.Bool(true) + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{DeploymentValues: dv}, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{}, + }, + }) + ctx := testutil.Context(t, testutil.WaitLong) + + _, status := dialAIBridgeServe(ctx, t, client, "any-key", aibridgedproto.CurrentVersion.String()) + require.Equal(t, http.StatusForbidden, status) + }) +} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 40d1e7f0979d8..c4656d39c7d58 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -311,6 +311,22 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { }) }) + // /aibridge/serve is the DRPC-over-WebSocket endpoint that standalone AI + // Gateway replicas connect to. It authenticates with a gateway key instead + // of a user session, and deliberately sits outside the /aibridge catch-all + // so the catch-all's overload middleware does not apply to it. + api.AGPL.APIHandler.Group(func(r chi.Router) { + r.Route("/aibridge/serve", func(r chi.Router) { + r.Use( + api.RequireFeatureMW(codersdk.FeatureAIBridge), + httpmw.ExtractAIGatewayKeyAuthenticated(httpmw.ExtractAIGatewayKeyConfig{ + DB: api.Database, + }), + ) + r.Get("/", api.aiBridgeServe) + }) + }) + api.AGPL.APIHandler.Group(func(r chi.Router) { r.Get("/entitlements", api.serveEntitlements) // /regions overrides the AGPL /regions endpoint