diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 8c6831445..be098aba9 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -85,6 +85,13 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private List? _modelsCache; private ServerRpc? _serverRpc; + /// + /// Client-global RPC handlers (e.g. the LLM inference provider adapter), + /// built once at construction when the corresponding option is configured and + /// registered on every connection. Null when no client-global API is enabled. + /// + private readonly ClientGlobalApiHandlers? _clientGlobalApis; + private sealed record LifecycleSubscription(Type EventType, Action Handler); /// @@ -165,6 +172,8 @@ public CopilotClient(CopilotClientOptions? options = null) _logger = _options.Logger ?? NullLogger.Instance; _onListModels = _options.OnListModels; + _clientGlobalApis = BuildClientGlobalApis(); + // Empty mode: validate at construction time that the app supplied a // per-session persistence location. The runtime is mode-agnostic, so // without this check it would silently fall back to ~/.copilot, which @@ -276,6 +285,8 @@ async Task StartCoreAsync(CancellationToken ct) sessionFsTimestamp); } + await ConfigureLlmInferenceAsync(ct); + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.StartAsync complete. Elapsed={Elapsed}", startTimestamp); @@ -1674,6 +1685,39 @@ await Rpc.SessionFs.SetProviderAsync( cancellationToken: cancellationToken); } + /// + /// Builds the client-global RPC handler bag at construction time. Currently + /// only the LLM inference provider adapter is registered; returns null when no + /// client-global API is configured so the registration is skipped entirely. + /// + private ClientGlobalApiHandlers? BuildClientGlobalApis() + { + var handler = _options.LlmInference?.Handler; + if (handler is null) + { + return null; + } + + return new ClientGlobalApiHandlers + { + LlmInference = new LlmInferenceAdapter(handler, () => _serverRpc), + }; + } + + /// + /// Tells the runtime to route its outbound model-layer requests through this + /// client's LLM inference provider. No-op when interception is not configured. + /// + private async Task ConfigureLlmInferenceAsync(CancellationToken cancellationToken) + { + if (_clientGlobalApis?.LlmInference is null) + { + return; + } + + await Rpc.LlmInference.SetProviderAsync(cancellationToken); + } + private void ConfigureSessionFsHandlers(CopilotSession session, Func? createSessionFsHandler) { if (_options.SessionFs is null) @@ -2067,6 +2111,10 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); return session.ClientSessionApis; }); + if (_clientGlobalApis is not null) + { + ClientGlobalApiRegistration.RegisterClientGlobalApiHandlers(rpc, _clientGlobalApis); + } rpc.StartListening(); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.ConnectToServerAsync transport setup complete. Elapsed={Elapsed}", diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index 89e863ad1..e9720253c 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -996,6 +996,92 @@ internal sealed class SessionFsSetProviderRequest public string SessionStatePath { get; set; } = string.Empty; } +/// Indicates whether the calling client was registered as the LLM inference provider. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceSetProviderResult +{ + /// Whether the provider was set successfully. + [JsonPropertyName("success")] + public bool Success { get; set; } +} + +/// Whether the start frame was accepted. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpResponseStartResult +{ + /// True when the response start was matched to a pending request; false when unknown. + [JsonPropertyName("accepted")] + public bool Accepted { get; set; } +} + +/// Response head. +[Experimental(Diagnostics.Experimental)] +internal sealed class LlmInferenceHttpResponseStartRequest +{ + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + [JsonPropertyName("headers")] + public IDictionary> Headers { get => field ??= new Dictionary>(); set; } + + /// Matches the requestId from the originating httpRequestStart frame. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; + + /// HTTP status code. + [JsonPropertyName("status")] + public long Status { get; set; } + + /// Optional HTTP status reason phrase. + [JsonPropertyName("statusText")] + public string? StatusText { get; set; } +} + +/// Whether the chunk was accepted. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpResponseChunkResult +{ + /// True when the chunk was matched to a pending request; false when unknown. + [JsonPropertyName("accepted")] + public bool Accepted { get; set; } +} + +/// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpResponseChunkError +{ + /// Optional machine-readable error code. + [JsonPropertyName("code")] + public string? Code { get; set; } + + /// Human-readable failure description. + [JsonPropertyName("message")] + public string Message { get; set; } = string.Empty; +} + +/// A response body chunk or terminal error. +[Experimental(Diagnostics.Experimental)] +internal sealed class LlmInferenceHttpResponseChunkRequest +{ + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + [JsonPropertyName("binary")] + public bool? Binary { get; set; } + + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk with empty data and end=true). + [JsonPropertyName("data")] + public string Data { get; set; } = string.Empty; + + /// When true, this is the final body chunk for the response. The runtime treats the response body as complete after receiving an end-marked chunk. + [JsonPropertyName("end")] + public bool? End { get; set; } + + /// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. + [JsonPropertyName("error")] + public LlmInferenceHttpResponseChunkError? Error { get; set; } + + /// Matches the requestId from the originating httpRequestStart frame. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; +} + /// Pre-resolved working-directory context for session startup. [Experimental(Diagnostics.Experimental)] public sealed class SessionContext @@ -10215,6 +10301,76 @@ public sealed class CanvasProviderInvokeActionRequest public string SessionId { get; set; } = string.Empty; } +/// Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestStartResult +{ +} + +/// The head of an outbound model-layer HTTP request. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestStartRequest +{ + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + [JsonPropertyName("headers")] + public IDictionary> Headers { get => field ??= new Dictionary>(); set; } + + /// HTTP method, e.g. GET, POST. + [JsonPropertyName("method")] + public string Method { get; set; } = string.Empty; + + /// Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies back to the runtime. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; + + /// Id of the runtime session that triggered this request, when one is in scope. Absent for requests issued outside any session (e.g. startup model-catalog or capability resolution). This is a payload field — not a dispatch key — because the client-global API is registered process-wide rather than per session. + [JsonPropertyName("sessionId")] + public string? SessionId { get; set; } + + /// Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. + [JsonPropertyName("transport")] + public LlmInferenceHttpRequestStartTransport? Transport { get; set; } + + /// Absolute request URL. + [JsonPropertyName("url")] + public string Url { get; set; } = string.Empty; +} + +/// Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestChunkResult +{ +} + +/// A request body chunk or cancellation signal. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestChunkRequest +{ + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + [JsonPropertyName("binary")] + public bool? Binary { get; set; } + + /// When true, the runtime is cancelling the in-flight request (e.g. upstream consumer aborted). `data` is ignored. Implies end-of-request. + [JsonPropertyName("cancel")] + public bool? Cancel { get; set; } + + /// Optional human-readable reason for the cancellation, propagated for logging. + [JsonPropertyName("cancelReason")] + public string? CancelReason { get; set; } + + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty. + [JsonPropertyName("data")] + public string Data { get; set; } = string.Empty; + + /// When true, this is the final body chunk for the request. The SDK may rely on having received an end-marked chunk before treating the request body as complete. + [JsonPropertyName("end")] + public bool? End { get; set; } + + /// Matches the requestId from the originating httpRequestStart frame. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; +} + /// Model capability category for grouping in the model picker. [JsonConverter(typeof(Converter))] [DebuggerDisplay("{Value,nq}")] @@ -15514,6 +15670,69 @@ public override void Write(Utf8JsonWriter writer, SessionFsSqliteQueryType value } +/// Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. +[Experimental(Diagnostics.Experimental)] +[JsonConverter(typeof(Converter))] +[DebuggerDisplay("{Value,nq}")] +public readonly struct LlmInferenceHttpRequestStartTransport : IEquatable +{ + private readonly string? _value; + + /// Initializes a new instance of the struct. + /// The value to associate with this . + [JsonConstructor] + public LlmInferenceHttpRequestStartTransport(string value) + { + ArgumentException.ThrowIfNullOrWhiteSpace(value); + _value = value; + } + + /// Gets the value associated with this . + public string Value => _value ?? string.Empty; + + /// Plain HTTP or SSE response. Each body chunk is an opaque byte range; the response is a status line, headers, and a (possibly streamed) body. + public static LlmInferenceHttpRequestStartTransport Http { get; } = new("http"); + + /// Full-duplex WebSocket channel. Each body chunk maps to exactly one WebSocket message and the `binary` flag distinguishes text from binary frames; request and response chunks flow concurrently. + public static LlmInferenceHttpRequestStartTransport Websocket { get; } = new("websocket"); + + /// Returns a value indicating whether two instances are equivalent. + public static bool operator ==(LlmInferenceHttpRequestStartTransport left, LlmInferenceHttpRequestStartTransport right) => left.Equals(right); + + /// Returns a value indicating whether two instances are not equivalent. + public static bool operator !=(LlmInferenceHttpRequestStartTransport left, LlmInferenceHttpRequestStartTransport right) => !(left == right); + + /// + public override bool Equals(object? obj) => obj is LlmInferenceHttpRequestStartTransport other && Equals(other); + + /// + public bool Equals(LlmInferenceHttpRequestStartTransport other) => string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase); + + /// + public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); + + /// + public override string ToString() => Value; + + /// Provides a for serializing instances. + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override LlmInferenceHttpRequestStartTransport Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return new(GeneratedStringEnumJson.ReadValue(ref reader, typeToConvert)); + } + + /// + public override void Write(Utf8JsonWriter writer, LlmInferenceHttpRequestStartTransport value, JsonSerializerOptions options) + { + GeneratedStringEnumJson.WriteValue(writer, value.Value, typeof(LlmInferenceHttpRequestStartTransport)); + } + } +} + + /// Provides server-scoped RPC methods (no session required). public sealed class ServerRpc { @@ -15616,6 +15835,12 @@ internal async Task ConnectAsync(string? token = null, Cancellati Interlocked.CompareExchange(ref field, new(_rpc), null) ?? field; + /// LlmInference APIs. + public ServerLlmInferenceApi LlmInference => + field ?? + Interlocked.CompareExchange(ref field, new(_rpc), null) ?? + field; + /// Sessions APIs. public ServerSessionsApi Sessions => field ?? @@ -16162,6 +16387,59 @@ public async Task SetProviderAsync(string initialCwd } } +/// Provides server-scoped LlmInference APIs. +[Experimental(Diagnostics.Experimental)] +public sealed class ServerLlmInferenceApi +{ + private readonly JsonRpc _rpc; + + internal ServerLlmInferenceApi(JsonRpc rpc) + { + _rpc = rpc; + } + + /// Registers an SDK client as the LLM inference callback provider. + /// The to monitor for cancellation requests. The default is . + /// Indicates whether the calling client was registered as the LLM inference provider. + public async Task SetProviderAsync(CancellationToken cancellationToken = default) + { + return await CopilotClient.InvokeRpcAsync(_rpc, "llmInference.setProvider", [], cancellationToken); + } + + /// Delivers the response head (status + headers) for an in-flight request, correlated by the requestId the runtime supplied in httpRequestStart. Must be called exactly once per request before any httpResponseChunk frames. + /// Matches the requestId from the originating httpRequestStart frame. + /// HTTP status code. + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + /// Optional HTTP status reason phrase. + /// The to monitor for cancellation requests. The default is . + /// Whether the start frame was accepted. + public async Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(requestId); + ArgumentNullException.ThrowIfNull(headers); + + var request = new LlmInferenceHttpResponseStartRequest { RequestId = requestId, Status = status, Headers = headers, StatusText = statusText }; + return await CopilotClient.InvokeRpcAsync(_rpc, "llmInference.httpResponseStart", [request], cancellationToken); + } + + /// Delivers a body byte range (or a terminal transport error) for an in-flight response, correlated by requestId. Set `end` true on the last chunk. When `error` is set the response terminates with a transport-level failure and the runtime raises an APIConnectionError. + /// Matches the requestId from the originating httpRequestStart frame. + /// Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk with empty data and end=true). + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + /// When true, this is the final body chunk for the response. The runtime treats the response body as complete after receiving an end-marked chunk. + /// Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. + /// The to monitor for cancellation requests. The default is . + /// Whether the chunk was accepted. + public async Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(requestId); + ArgumentNullException.ThrowIfNull(data); + + var request = new LlmInferenceHttpResponseChunkRequest { RequestId = requestId, Data = data, Binary = binary, End = end, Error = error }; + return await CopilotClient.InvokeRpcAsync(_rpc, "llmInference.httpResponseChunk", [request], cancellationToken); + } +} + /// Provides server-scoped Sessions APIs. [Experimental(Diagnostics.Experimental)] public sealed class ServerSessionsApi @@ -19573,6 +19851,53 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, FuncHandles `llmInference` client global API methods. +[Experimental(Diagnostics.Experimental)] +public interface ILlmInferenceHandler +{ + /// Announces an outbound model-layer HTTP request the runtime wants the SDK client to service. Carries the request head only; the body always follows as one or more httpRequestChunk frames keyed by the same requestId, even when the body is empty (a single chunk with end=true). + /// The head of an outbound model-layer HTTP request. + /// The to monitor for cancellation requests. The default is . + /// Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. + Task HttpRequestStartAsync(LlmInferenceHttpRequestStartRequest request, CancellationToken cancellationToken = default); + /// Delivers a body byte range (or a cancellation signal) for a request previously announced via httpRequestStart, correlated by requestId. The runtime fires at least one chunk per request — when there is no body, a single chunk with empty data and end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; the SDK then stops issuing httpResponseChunk frames and may emit a terminal httpResponseChunk with error set. + /// A request body chunk or cancellation signal. + /// The to monitor for cancellation requests. The default is . + /// Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. + Task HttpRequestChunkAsync(LlmInferenceHttpRequestChunkRequest request, CancellationToken cancellationToken = default); +} + +/// Provides all client global API handler groups for a connection. +public sealed class ClientGlobalApiHandlers +{ + /// Optional handler for LlmInference client global API methods. + public ILlmInferenceHandler? LlmInference { get; set; } +} + +/// Registers client global API handlers on a JSON-RPC connection. +internal static class ClientGlobalApiRegistration +{ + /// + /// Registers handlers for server-to-client global API calls. + /// Unlike client session APIs, these methods carry no implicit + /// sessionId dispatch key — a single set of handlers serves the + /// entire connection. + /// + public static void RegisterClientGlobalApiHandlers(JsonRpc rpc, ClientGlobalApiHandlers handlers) + { + rpc.SetLocalRpcMethod("llmInference.httpRequestStart", (Func>)(async (request, cancellationToken) => + { + var handler = handlers.LlmInference ?? throw new InvalidOperationException("No llmInference client-global handler registered"); + return await handler.HttpRequestStartAsync(request, cancellationToken); + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("llmInference.httpRequestChunk", (Func>)(async (request, cancellationToken) => + { + var handler = handlers.LlmInference ?? throw new InvalidOperationException("No llmInference client-global handler registered"); + return await handler.HttpRequestChunkAsync(request, cancellationToken); + }), singleObjectParam: true); + } +} + [JsonSourceGenerationOptions( JsonSerializerDefaults.Web, AllowOutOfOrderMetadataProperties = true, @@ -19924,6 +20249,16 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, FuncWhether the model response was blocked or truncated by content filtering (finish_reason === 'content_filter'). For Anthropic models this corresponds to a 'refusal' stop reason. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("contentFilterTriggered")] + public bool? ContentFilterTriggered { get; set; } + /// Per-request cost and usage data from the CAPI copilot_usage response field. [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonInclude] @@ -2362,6 +2367,11 @@ public sealed partial class AssistantUsageData [JsonPropertyName("duration")] public TimeSpan? Duration { get; set; } + /// Finish reason reported by the model for this API call (e.g. "stop", "length", "tool_calls", "content_filter"). Normalized to OpenAI vocabulary; for Anthropic models a "refusal" stop reason maps to "content_filter". + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("finishReason")] + public string? FinishReason { get; set; } + /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls. [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("initiator")] @@ -4621,6 +4631,11 @@ public sealed partial class ToolExecutionCompleteResult [JsonPropertyName("detailedContent")] public string? DetailedContent { get; set; } + /// Structured content (arbitrary JSON) returned verbatim by the MCP tool. + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + [JsonPropertyName("structuredContent")] + public JsonElement? StructuredContent { get; set; } + /// MCP Apps UI resource content for rendering in a sandboxed iframe. [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] [JsonPropertyName("uiResource")] diff --git a/dotnet/src/GitHub.Copilot.SDK.csproj b/dotnet/src/GitHub.Copilot.SDK.csproj index 7a9fa2bdc..f37982155 100644 --- a/dotnet/src/GitHub.Copilot.SDK.csproj +++ b/dotnet/src/GitHub.Copilot.SDK.csproj @@ -27,6 +27,10 @@ $(NoWarn);GHCP001 + + + + true diff --git a/dotnet/src/LlmInferenceProvider.cs b/dotnet/src/LlmInferenceProvider.cs new file mode 100644 index 000000000..73b121f17 --- /dev/null +++ b/dotnet/src/LlmInferenceProvider.cs @@ -0,0 +1,628 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Rpc; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Channels; + +namespace GitHub.Copilot; + +/// +/// Transport the runtime would otherwise use to issue an intercepted +/// model-layer request. +/// +[Experimental(Diagnostics.Experimental)] +public enum LlmInferenceTransport +{ + /// + /// Plain HTTP or a streamed SSE response. Each body chunk is an opaque + /// byte range. + /// + Http, + + /// + /// Full-duplex WebSocket channel. Each request-body chunk is one inbound + /// WebSocket message and each response-body write is one outbound message. + /// + WebSocket, +} + +/// +/// An outbound model-layer HTTP (or WebSocket) request the runtime is asking +/// the SDK consumer to service on its behalf. +/// +/// +/// This is a low-level shape: URL / method / headers verbatim, body bytes +/// delivered as an async sequence, and the response delivered through the +/// sink. The runtime does not classify the request +/// (no provider type, endpoint kind, or wire API); consumers that need that +/// information derive it from the URL / headers themselves. +/// +internal sealed class LlmInferenceRequest +{ + /// Opaque runtime-minted id, stable across the request lifecycle. + public required string RequestId { get; init; } + + /// + /// Id of the runtime session that triggered this request, when one is in + /// scope. for out-of-session requests (e.g. startup + /// model catalog). + /// + public string? SessionId { get; init; } + + /// HTTP method (GET, POST, ...). + public required string Method { get; init; } + + /// Absolute request URL. + public required string Url { get; init; } + + /// HTTP request headers, lowercased names mapped to multi-valued lists. + public required IReadOnlyDictionary> Headers { get; init; } + + /// + /// Transport the runtime would otherwise use. + /// covers plain HTTP and SSE responses; + /// indicates a full-duplex message channel. Consumers branch on this to + /// decide whether to service the request with an HTTP client or a WebSocket + /// client. + /// + public LlmInferenceTransport Transport { get; init; } + + /// + /// Request body bytes, yielded as they arrive from the runtime. Always + /// enumerable; an empty body yields zero chunks before completing. For + /// WebSocket transport each element is one inbound message. + /// + public required IAsyncEnumerable> RequestBody { get; init; } + + /// + /// Cancelled when the runtime aborts this in-flight request (e.g. the agent + /// turn was aborted upstream). Pass it straight to HttpClient.SendAsync + /// / your transport so the upstream call is torn down too. After it fires, + /// writes to are ignored. + /// + public CancellationToken CancellationToken { get; init; } + + /// + /// Sink the consumer writes the upstream response into. Call + /// exactly once before + /// writing body chunks, then zero or more + /// + /// calls, and finish with or + /// . + /// + public required LlmInferenceResponseSink ResponseBody { get; init; } +} + +/// Response head passed to . +internal sealed class LlmInferenceResponseInit +{ + /// HTTP status code (101 acknowledges a WebSocket upgrade). + public int Status { get; init; } + + /// Optional HTTP status reason phrase. + public string? StatusText { get; init; } + + /// Response headers, lowercased names mapped to multi-valued lists. + public IReadOnlyDictionary>? Headers { get; init; } +} + +/// +/// Sink the consumer writes the upstream response into. The state machine is +/// strict: once → zero or more WriteAsync → +/// exactly one of or . Calling +/// out of order throws. +/// +internal abstract class LlmInferenceResponseSink +{ + /// Sends the response head (status + headers) back to the runtime. + public abstract Task StartAsync(LlmInferenceResponseInit init); + + /// Sends a binary body chunk (base64-encoded on the wire). + public abstract Task WriteAsync(ReadOnlyMemory data); + + /// Sends a UTF-8 text body chunk. + public abstract Task WriteAsync(string text); + + /// Marks end-of-stream cleanly. + public abstract Task EndAsync(); + + /// Marks end-of-stream with a transport-level failure. + public abstract Task ErrorAsync(string message, string? code = null); +} + +/// +/// Internal seam implemented by and consumed by +/// . The single callback handles both buffered +/// and streaming responses — the implementer calls +/// zero +/// or more times before . +/// +/// +/// Not part of the public API: consumers subclass +/// rather than implementing this directly. It exists so the adapter can drive any +/// handler through one uniform entry point. +/// +internal interface ILlmInferenceProvider +{ + /// + /// Invoked by the adapter once per outbound LLM request. The implementer is + /// responsible for eventually calling either + /// or + /// ; failing to do so leaks + /// runtime state. Throwing surfaces a transport-level failure to the runtime + /// (equivalent to ResponseBody.ErrorAsync(...) when + /// has not yet been called). + /// + Task OnLlmRequestAsync(LlmInferenceRequest request); +} + +/// +/// Adapts an into the generated +/// shape consumed by the SDK's RPC +/// dispatcher. +/// +/// +/// Maintains a per-requestId state table: each httpRequestStart +/// allocates a body channel + response sink and fires +/// in the background. +/// Subsequent httpRequestChunk frames are routed into the channel. The +/// sink translates Start / Write / End / Error calls +/// into outbound llmInference.httpResponseStart / +/// llmInference.httpResponseChunk calls. +/// +internal sealed class LlmInferenceAdapter : ILlmInferenceHandler +{ + private readonly ILlmInferenceProvider _provider; + private readonly Func _getChannel; + private readonly ConcurrentDictionary _pending = new(StringComparer.Ordinal); + + // Defense-in-depth backstop: chunks that arrive before their start frame + // (a reordering the runtime's single ordered dispatch should make + // impossible) are staged here and drained the moment httpRequestStart + // registers the matching state, so a body byte is never silently dropped. + private readonly ConcurrentDictionary> _staged = new(StringComparer.Ordinal); + + internal LlmInferenceAdapter(ILlmInferenceProvider provider, Func getServerRpc) + : this(provider, WrapServerRpc(getServerRpc ?? throw new ArgumentNullException(nameof(getServerRpc)))) + { + } + + internal LlmInferenceAdapter(ILlmInferenceProvider provider, Func getChannel) + { + _provider = provider ?? throw new ArgumentNullException(nameof(provider)); + _getChannel = getChannel ?? throw new ArgumentNullException(nameof(getChannel)); + } + + /// + /// Adapts a getter into a response-channel getter, + /// caching the wrapper so a new one is allocated only when the underlying + /// connection changes (e.g. reconnect). + /// + private static Func WrapServerRpc(Func getServerRpc) + { + ServerRpc? cachedRpc = null; + ILlmInferenceResponseChannel? cachedChannel = null; + return () => + { + var rpc = getServerRpc(); + if (rpc is null) + { + return null; + } + + if (!ReferenceEquals(rpc, cachedRpc)) + { + cachedRpc = rpc; + cachedChannel = new ServerRpcResponseChannel(rpc); + } + + return cachedChannel; + }; + } + + public Task HttpRequestStartAsync(LlmInferenceHttpRequestStartRequest request, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + var state = new PendingState(); + _pending[request.RequestId] = state; + + if (_staged.TryRemove(request.RequestId, out var stagedChunks)) + { + foreach (var chunk in stagedChunks) + { + RouteChunk(state, chunk); + } + } + + var sink = new AdapterResponseSink(request.RequestId, state, _getChannel, _pending); + state.Sink = sink; + + var transport = request.Transport == LlmInferenceHttpRequestStartTransport.Websocket + ? LlmInferenceTransport.WebSocket + : LlmInferenceTransport.Http; + + var llmRequest = new LlmInferenceRequest + { + RequestId = request.RequestId, + SessionId = request.SessionId, + Method = request.Method, + Url = request.Url, + Headers = ToReadOnlyHeaders(request.Headers), + Transport = transport, + RequestBody = state.Body.ReadAllAsync(state.Abort.Token), + CancellationToken = state.Abort.Token, + ResponseBody = sink, + }; + + // Return from httpRequestStart immediately (after registering state) so + // the runtime's RPC reply is not gated on the consumer's I/O. The actual + // provider work runs asynchronously. + _ = RunProviderAsync(llmRequest, state, sink); + + return Task.FromResult(new LlmInferenceHttpRequestStartResult()); + } + + public Task HttpRequestChunkAsync(LlmInferenceHttpRequestChunkRequest request, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + if (_pending.TryGetValue(request.RequestId, out var state)) + { + RouteChunk(state, request); + } + else + { + _staged.AddOrUpdate( + request.RequestId, + _ => [request], + (_, list) => + { + list.Add(request); + return list; + }); + } + + return Task.FromResult(new LlmInferenceHttpRequestChunkResult()); + } + + private async Task RunProviderAsync(LlmInferenceRequest request, PendingState state, AdapterResponseSink sink) + { + try + { + await _provider.OnLlmRequestAsync(request).ConfigureAwait(false); + if (!state.Finished) + { + await FailViaSink( + sink, + state, + "LLM inference provider returned without finalising the response (call ResponseBody.EndAsync() or .ErrorAsync()).").ConfigureAwait(false); + } + } + catch (Exception ex) + { + if (state.Cancelled || state.Abort.IsCancellationRequested) + { + // The runtime already cancelled this request; the provider's + // throw is just the abort propagating out of its upstream call. + await FinishCancelled(sink, state).ConfigureAwait(false); + return; + } + + await FailViaSink(sink, state, ex.Message).ConfigureAwait(false); + } + } + + private static async Task FailViaSink(AdapterResponseSink sink, PendingState state, string message) + { + if (state.Finished) + { + return; + } + + try + { + if (!state.Started) + { + await sink.StartAsync(new LlmInferenceResponseInit { Status = 502 }).ConfigureAwait(false); + } + + await sink.ErrorAsync(message).ConfigureAwait(false); + } + catch + { + // Best-effort — the connection may already be dead. + } + } + + private static async Task FinishCancelled(AdapterResponseSink sink, PendingState state) + { + if (state.Finished) + { + return; + } + + try + { + if (!state.Started) + { + await sink.StartAsync(new LlmInferenceResponseInit { Status = 499 }).ConfigureAwait(false); + } + + await sink.ErrorAsync("Request cancelled by runtime", "cancelled").ConfigureAwait(false); + } + catch + { + // Best-effort — the runtime already dropped the request on cancel. + } + } + + private static void RouteChunk(PendingState state, LlmInferenceHttpRequestChunkRequest chunk) + { + if (chunk.Cancel == true) + { + state.Cancelled = true; + state.Abort.Cancel(); + state.Body.PushCancel(chunk.CancelReason); + return; + } + + if (!string.IsNullOrEmpty(chunk.Data)) + { + state.Body.PushChunk(DecodeChunkData(chunk.Data, chunk.Binary == true)); + } + + if (chunk.End == true) + { + state.Body.PushEnd(); + } + } + + private static byte[] DecodeChunkData(string data, bool binary) => + binary ? Convert.FromBase64String(data) : Encoding.UTF8.GetBytes(data); + + private static Dictionary> ToReadOnlyHeaders(IDictionary> headers) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var (name, values) in headers) + { + result[name] = values as IReadOnlyList ?? [.. values]; + } + + return result; + } + + private sealed class PendingState + { + public BodyChannel Body { get; } = new(); + + public CancellationTokenSource Abort { get; } = new(); + + public bool Started { get; set; } + + public bool Finished { get; set; } + + public bool Cancelled { get; set; } + + public AdapterResponseSink? Sink { get; set; } + } + + /// + /// An unbounded channel of request-body items exposed as an + /// of byte chunks. A cancel item surfaces + /// as an out of the enumerator so + /// the consumer's upstream call is torn down. + /// + private sealed class BodyChannel + { + private readonly Channel _channel = Channel.CreateUnbounded( + new UnboundedChannelOptions { SingleReader = true, SingleWriter = true }); + + public void PushChunk(byte[] data) => _channel.Writer.TryWrite(new Item { Chunk = data }); + + public void PushEnd() => _channel.Writer.TryWrite(new Item { End = true }); + + public void PushCancel(string? reason) => _channel.Writer.TryWrite(new Item { Cancel = true, CancelReason = reason }); + + public async IAsyncEnumerable> ReadAllAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (await _channel.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (_channel.Reader.TryRead(out var item)) + { + if (item.Cancel) + { + _channel.Writer.TryComplete(); + throw new OperationCanceledException( + item.CancelReason is null + ? "Request cancelled by runtime" + : $"Request cancelled by runtime: {item.CancelReason}"); + } + + if (item.End) + { + _channel.Writer.TryComplete(); + yield break; + } + + if (item.Chunk is { Length: > 0 }) + { + yield return item.Chunk; + } + } + } + } + + private struct Item + { + public byte[]? Chunk; + public bool End; + public bool Cancel; + public string? CancelReason; + } + } + + private sealed class AdapterResponseSink( + string requestId, + PendingState state, + Func getChannel, + ConcurrentDictionary pending) : LlmInferenceResponseSink + { + public override async Task StartAsync(LlmInferenceResponseInit init) + { + ArgumentNullException.ThrowIfNull(init); + + if (state.Started) + { + throw new InvalidOperationException("LLM inference response sink StartAsync() called twice."); + } + + if (state.Finished) + { + throw new InvalidOperationException("LLM inference response sink already finished."); + } + + state.Started = true; + var result = await Channel() + .HttpResponseStartAsync(requestId, init.Status, ToWireHeaders(init.Headers), init.StatusText) + .ConfigureAwait(false); + if (!result.Accepted) + { + RejectedByRuntime(); + } + } + + public override Task WriteAsync(ReadOnlyMemory data) => + WriteChunk(Convert.ToBase64String(data.ToArray()), binary: true); + + public override Task WriteAsync(string text) + { + ArgumentNullException.ThrowIfNull(text); + return WriteChunk(text, binary: false); + } + + public override async Task EndAsync() + { + if (state.Finished) + { + return; + } + + state.Finished = true; + pending.TryRemove(requestId, out _); + await Channel().HttpResponseChunkAsync(requestId, string.Empty, end: true).ConfigureAwait(false); + } + + public override async Task ErrorAsync(string message, string? code = null) + { + ArgumentNullException.ThrowIfNull(message); + + if (state.Finished) + { + return; + } + + state.Finished = true; + pending.TryRemove(requestId, out _); + await Channel() + .HttpResponseChunkAsync( + requestId, + string.Empty, + end: true, + error: new LlmInferenceHttpResponseChunkError { Message = message, Code = code }) + .ConfigureAwait(false); + } + + private async Task WriteChunk(string data, bool binary) + { + if (state.Cancelled) + { + throw new InvalidOperationException("LLM inference request was cancelled by the runtime."); + } + + if (!state.Started) + { + throw new InvalidOperationException("LLM inference response sink WriteAsync() called before StartAsync()."); + } + + if (state.Finished) + { + throw new InvalidOperationException("LLM inference response sink WriteAsync() called after EndAsync()/ErrorAsync()."); + } + + var result = await Channel() + .HttpResponseChunkAsync(requestId, data, binary: binary, end: false) + .ConfigureAwait(false); + if (!result.Accepted) + { + RejectedByRuntime(); + } + } + + private ILlmInferenceResponseChannel Channel() => + getChannel() ?? throw new InvalidOperationException("LLM inference response sink used after RPC connection closed."); + + // The runtime acknowledges every response frame with accepted; accepted: + // false means it has dropped the request (e.g. it cancelled), so we abort + // the provider's upstream work and stop emitting. + private void RejectedByRuntime() + { + if (!state.Cancelled) + { + state.Cancelled = true; + state.Abort.Cancel(); + } + + state.Finished = true; + pending.TryRemove(requestId, out _); + throw new InvalidOperationException("LLM inference response was rejected by the runtime (request no longer active)."); + } + + private static Dictionary> ToWireHeaders(IReadOnlyDictionary>? headers) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + if (headers is null) + { + return result; + } + + foreach (var (name, values) in headers) + { + result[name] = values as IList ?? [.. values]; + } + + return result; + } + } +} + +/// +/// Minimal seam over the runtime-bound llmInference server API the +/// adapter uses to push response frames back to the runtime. Extracted as an +/// interface so the adapter can be unit-tested without a live JSON-RPC +/// connection. +/// +internal interface ILlmInferenceResponseChannel +{ + Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null); + + Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null); +} + +/// +/// Production backed by the generated +/// client. +/// +internal sealed class ServerRpcResponseChannel(ServerRpc serverRpc) : ILlmInferenceResponseChannel +{ + public Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null) => + serverRpc.LlmInference.HttpResponseStartAsync(requestId, status, headers, statusText); + + public Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null) => + serverRpc.LlmInference.HttpResponseChunkAsync(requestId, data, binary, end, error); +} diff --git a/dotnet/src/LlmRequestHandler.cs b/dotnet/src/LlmRequestHandler.cs new file mode 100644 index 000000000..b44cb9130 --- /dev/null +++ b/dotnet/src/LlmRequestHandler.cs @@ -0,0 +1,747 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Text; + +namespace GitHub.Copilot; + +/// +/// Per-request context handed to every hook. +/// Mirrors the subset of fields that are +/// stable across the request lifetime, letting overrides observe routing / +/// cancellation without re-plumbing the underlying request. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class LlmRequestContext +{ + /// Opaque runtime-minted id, stable across the request lifecycle. + public required string RequestId { get; init; } + + /// Runtime session id that triggered the request, if any. + public string? SessionId { get; init; } + + /// Transport the runtime would otherwise use. + public LlmInferenceTransport Transport { get; init; } + + /// Original request URL. + public required string Url { get; init; } + + /// Original request headers. + public required IReadOnlyDictionary> Headers { get; init; } + + /// + /// Cancelled when the runtime aborts this in-flight request. Subclasses that + /// issue their own I/O should pass this through so the upstream call is torn + /// down too. + /// + public CancellationToken CancellationToken { get; init; } + + internal LlmWebSocketResponseBridge? WebSocketResponse { get; set; } +} + +/// A single WebSocket message exchanged through a hook. +[Experimental(Diagnostics.Experimental)] +public readonly struct LlmWebSocketMessage(ReadOnlyMemory data, bool isBinary) +{ + /// The message payload bytes. + public ReadOnlyMemory Data { get; } = data; + + /// True for a binary frame; false for a UTF-8 text frame. + public bool IsBinary { get; } = isBinary; + + /// Decodes the payload as UTF-8 text. + public string GetText() => Encoding.UTF8.GetString(Data.ToArray()); + + /// Creates a text message from a UTF-8 string. + public static LlmWebSocketMessage Text(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false); + + /// Creates a binary message from raw bytes. + public static LlmWebSocketMessage Binary(ReadOnlyMemory data) => new(data, isBinary: true); +} + +/// +/// Terminal status for a callback-owned WebSocket connection. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class LlmWebSocketCloseStatus +{ + /// The close description, if any. + public string? Description { get; init; } + + /// + /// Optional error code surfaced to the runtime when the close is a failure + /// rather than a clean end-of-stream. + /// + public string? ErrorCode { get; init; } + + /// The error that terminated the connection, if any. + public Exception? Error { get; init; } + + /// Shared normal-closure instance. + public static LlmWebSocketCloseStatus NormalClosure { get; } = new(); +} + +/// +/// Per-connection WebSocket handler returned by +/// . +/// +[Experimental(Diagnostics.Experimental)] +public abstract class CopilotWebSocketHandler : IAsyncDisposable +{ + private readonly TaskCompletionSource _completion = + new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _closed; + private bool _suppressCloseOnDispose; + + /// Request context for this WebSocket connection. + protected LlmRequestContext Context { get; } + + internal Task Completion => _completion.Task; + + /// + /// Initializes a per-connection handler for the supplied request context. + /// + protected CopilotWebSocketHandler(LlmRequestContext context) + { + Context = context; + _ = context.WebSocketResponse ?? throw new InvalidOperationException("WebSocket response bridge is not attached."); + } + + /// + /// Send a message from the runtime to the upstream connection. + /// + public abstract Task SendRequestMessageAsync(LlmWebSocketMessage message); + + /// + /// Send a message from the upstream connection back to the runtime. + /// Override to mutate or duplicate messages; call base to emit. + /// + public virtual Task SendResponseMessageAsync(LlmWebSocketMessage message) => + Context.WebSocketResponse!.WriteAsync(message); + + /// + /// Close the connection and finalise the runtime-facing response. + /// + public virtual async Task CloseAsync(LlmWebSocketCloseStatus status) + { + if (Interlocked.Exchange(ref _closed, 1) != 0) + { + return; + } + + if (status.Error is not null) + { + await Context.WebSocketResponse! + .ErrorAsync(status.Description ?? status.Error.Message, status.ErrorCode) + .ConfigureAwait(false); + } + else + { + await Context.WebSocketResponse!.EndAsync().ConfigureAwait(false); + } + + _completion.TrySetResult(status); + } + + internal void SuppressCloseOnDispose() => _suppressCloseOnDispose = true; + + internal virtual Task OpenAsync() => Task.CompletedTask; + + /// + public virtual async ValueTask DisposeAsync() + { + GC.SuppressFinalize(this); + if (!_suppressCloseOnDispose && Volatile.Read(ref _closed) == 0) + { + await CloseAsync(LlmWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + } + } +} + +/// +/// Default pass-through WebSocket handler. Opens the real upstream socket and +/// relays messages unchanged unless a subclass overrides the send methods. +/// +[Experimental(Diagnostics.Experimental)] +public class ForwardingWebSocketHandler : CopilotWebSocketHandler +{ + private readonly string _url; + private readonly IReadOnlyDictionary> _headers; + private WebSocket? _upstream; + private CancellationTokenSource? _pumpCts; + private Task? _responsePump; + + /// + /// Initializes a forwarding handler that will open the upstream socket on + /// demand using the supplied URL/headers (or the values from + /// when omitted). + /// + public ForwardingWebSocketHandler( + LlmRequestContext context, + string? url = null, + IReadOnlyDictionary>? headers = null) + : base(context) + { + _url = url ?? context.Url; + _headers = headers ?? context.Headers; + } + + /// + /// Opens the upstream socket and starts the built-in response pump. + /// + internal override async Task OpenAsync() + { + if (_upstream is not null) + { + return; + } + + var socket = new ClientWebSocket(); + foreach (var (name, values) in _headers) + { + if (s_forbiddenRequestHeaders.Contains(name)) + { + continue; + } + + try + { + socket.Options.SetRequestHeader(name, string.Join(", ", values)); + } + catch + { + // Some headers are managed by the handshake; ignore rejections. + } + } + + await socket.ConnectAsync(LlmWebSocketHelpers.ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false); + _upstream = socket; + _pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken); + _responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token), _pumpCts.Token); + } + + /// + /// Sends a message from the runtime to the upstream connection. Subclasses may override to mutate messages. + /// + /// The message to send. + /// A representing the asynchronous operation. + public override Task SendRequestMessageAsync(LlmWebSocketMessage message) + { + if (_upstream?.State != WebSocketState.Open) + { + return Task.CompletedTask; + } + + var type = message.IsBinary ? WebSocketMessageType.Binary : WebSocketMessageType.Text; + return _upstream.SendAsync( + new ArraySegment(message.Data.ToArray()), + type, + endOfMessage: true, + Context.CancellationToken); + } + + /// + public override async Task CloseAsync(LlmWebSocketCloseStatus status) + { + _pumpCts?.Cancel(); + if (_upstream is not null) + { + await LlmWebSocketHelpers.CloseWebSocketQuietlyAsync(_upstream).ConfigureAwait(false); + } + await base.CloseAsync(status).ConfigureAwait(false); + } + + /// + public override async ValueTask DisposeAsync() + { + GC.SuppressFinalize(this); + try + { + await base.DisposeAsync().ConfigureAwait(false); + } + finally + { + _pumpCts?.Cancel(); + _pumpCts?.Dispose(); + _upstream?.Dispose(); + if (_responsePump is not null) + { + await LlmWebSocketHelpers.ObserveQuietlyAsync(_responsePump).ConfigureAwait(false); + } + } + } + + private async Task PumpResponsesAsync(CancellationToken cancellationToken) + { + if (_upstream is null) + { + return; + } + + try + { + while (_upstream.State == WebSocketState.Open) + { + var message = await LlmWebSocketHelpers.ReceiveMessageAsync(_upstream, cancellationToken).ConfigureAwait(false); + if (message is null) + { + break; + } + + await SendResponseMessageAsync(message.Value).ConfigureAwait(false); + } + + await CloseAsync(LlmWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + } + catch (OperationCanceledException) when (Context.CancellationToken.IsCancellationRequested) + { + // Runtime-side cancellation aborts the request pump; the outer + // handler rethrows that cancellation rather than finalising here. + } + catch (Exception ex) + { + await CloseAsync(new LlmWebSocketCloseStatus + { + Description = ex.Message, + Error = ex, + }).ConfigureAwait(false); + } + } + + // Computed/managed by the HTTP/WS stack; forwarding them verbatim either + // throws or corrupts the request. + private static readonly HashSet s_forbiddenRequestHeaders = new(StringComparer.OrdinalIgnoreCase) + { + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", + }; +} + +/// +/// Base class for SDK consumers who want to observe or mutate the LLM inference +/// requests the runtime issues. +/// +[Experimental(Diagnostics.Experimental)] +public class LlmRequestHandler : ILlmInferenceProvider +{ + private static readonly HttpClient s_sharedHttpClient = new(); + + // Computed/managed by the HTTP stack; forwarding them verbatim either throws + // or corrupts the request. + private static readonly HashSet s_forbiddenRequestHeaders = new(StringComparer.OrdinalIgnoreCase) + { + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", + }; + + /// + async Task ILlmInferenceProvider.OnLlmRequestAsync(LlmInferenceRequest request) + { + ArgumentNullException.ThrowIfNull(request); + + var wsResponse = new LlmWebSocketResponseBridge(request.ResponseBody); + var ctx = new LlmRequestContext + { + RequestId = request.RequestId, + SessionId = request.SessionId, + Transport = request.Transport, + Url = request.Url, + Headers = request.Headers, + CancellationToken = request.CancellationToken, + }; + ctx.WebSocketResponse = wsResponse; + + if (request.Transport == LlmInferenceTransport.WebSocket) + { + await HandleWebSocketAsync(request, ctx).ConfigureAwait(false); + } + else + { + await HandleHttpAsync(request, ctx).ConfigureAwait(false); + } + } + + /// + /// Issue the upstream HTTP request. Override to mutate the request before + /// calling base, mutate the returned response after, or replace the + /// call entirely. + /// + protected virtual Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) => + s_sharedHttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken); + + /// + /// Open the upstream WebSocket connection. Override to return a custom + /// or to construct a + /// against a rewritten URL. + /// + protected virtual Task OpenWebSocketAsync(LlmRequestContext ctx) => + Task.FromResult(new ForwardingWebSocketHandler(ctx)); + + private async Task HandleHttpAsync(LlmInferenceRequest req, LlmRequestContext ctx) + { + using var request = await BuildHttpRequestAsync(req).ConfigureAwait(false); + using var response = await SendRequestAsync(request, ctx).ConfigureAwait(false); + await StreamResponseToSinkAsync(response, req, ctx).ConfigureAwait(false); + } + + private static async Task BuildHttpRequestAsync(LlmInferenceRequest req) + { + var method = new HttpMethod(req.Method.ToUpperInvariant()); + var message = new HttpRequestMessage(method, req.Url); + + var hasBody = method != HttpMethod.Get && method != HttpMethod.Head; + var body = await DrainAsync(req.RequestBody).ConfigureAwait(false); + if (hasBody && body.Length > 0) + { + message.Content = new ByteArrayContent(body); + } + + foreach (var (name, values) in req.Headers) + { + if (s_forbiddenRequestHeaders.Contains(name)) + { + continue; + } + + if (!message.Headers.TryAddWithoutValidation(name, values)) + { + message.Content ??= new ByteArrayContent([]); + message.Content.Headers.TryAddWithoutValidation(name, values); + } + } + + return message; + } + + private static async Task StreamResponseToSinkAsync(HttpResponseMessage response, LlmInferenceRequest req, LlmRequestContext ctx) + { + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit + { + Status = (int)response.StatusCode, + StatusText = response.ReasonPhrase, + Headers = HeadersToMultiMap(response), + }).ConfigureAwait(false); + +#if NETSTANDARD2_0 + using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); +#else + using var stream = await response.Content.ReadAsStreamAsync(ctx.CancellationToken).ConfigureAwait(false); +#endif + var buffer = new byte[16 * 1024]; + int read; +#if NETSTANDARD2_0 + while ((read = await stream.ReadAsync(buffer, 0, buffer.Length, ctx.CancellationToken).ConfigureAwait(false)) > 0) + { + await req.ResponseBody.WriteAsync(new ReadOnlyMemory(buffer, 0, read)).ConfigureAwait(false); + } +#else + while ((read = await stream.ReadAsync(buffer.AsMemory(), ctx.CancellationToken).ConfigureAwait(false)) > 0) + { + await req.ResponseBody.WriteAsync(new ReadOnlyMemory(buffer, 0, read)).ConfigureAwait(false); + } +#endif + + await req.ResponseBody.EndAsync().ConfigureAwait(false); + } + + private async Task HandleWebSocketAsync(LlmInferenceRequest req, LlmRequestContext ctx) + { + var handler = await OpenWebSocketAsync(ctx).ConfigureAwait(false); + try + { + await handler.OpenAsync().ConfigureAwait(false); + await ctx.WebSocketResponse!.StartAsync().ConfigureAwait(false); + + var clientPump = Task.Run(async () => + { + await foreach (var chunk in req.RequestBody.WithCancellation(ctx.CancellationToken).ConfigureAwait(false)) + { + await handler.SendRequestMessageAsync(new LlmWebSocketMessage(chunk, isBinary: false)).ConfigureAwait(false); + } + }, ctx.CancellationToken); + + var first = await Task.WhenAny(clientPump, handler.Completion).ConfigureAwait(false); + if (first == clientPump) + { + if (clientPump.IsFaulted || clientPump.IsCanceled) + { + handler.SuppressCloseOnDispose(); + await clientPump.ConfigureAwait(false); + } + + await handler.CloseAsync(LlmWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + await handler.Completion.ConfigureAwait(false); + return; + } + + var closeStatus = await handler.Completion.ConfigureAwait(false); + if (closeStatus.Error is not null) + { + throw closeStatus.Error; + } + } + finally + { + await handler.DisposeAsync().ConfigureAwait(false); + } + } + + private static async Task DrainAsync(IAsyncEnumerable> stream) + { + using var buffer = new MemoryStream(); + await foreach (var chunk in stream.ConfigureAwait(false)) + { + if (chunk.Length > 0) + { + buffer.Write(chunk.ToArray(), 0, chunk.Length); + } + } + + return buffer.ToArray(); + } + + private static Dictionary> HeadersToMultiMap(HttpResponseMessage response) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var header in response.Headers) + { + result[header.Key] = [.. header.Value]; + } + + if (response.Content is not null) + { + foreach (var header in response.Content.Headers) + { + result[header.Key] = [.. header.Value]; + } + } + + return result; + } + +} + +internal static class LlmWebSocketHelpers +{ + internal static async Task ReceiveMessageAsync(WebSocket socket, CancellationToken cancellationToken) + { + var buffer = new byte[16 * 1024]; + using var assembled = new MemoryStream(); + WebSocketReceiveResult result; + do + { + try + { + result = await socket.ReceiveAsync(new ArraySegment(buffer), cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + return null; + } + catch (WebSocketException) + { + return null; + } + + if (result.MessageType == WebSocketMessageType.Close) + { + return null; + } + + assembled.Write(buffer, 0, result.Count); + } + while (!result.EndOfMessage); + + return new LlmWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary); + } + + internal static async Task CloseWebSocketQuietlyAsync(WebSocket socket) + { + try + { + if (socket.State is WebSocketState.Open or WebSocketState.CloseReceived) + { + await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, statusDescription: null, CancellationToken.None).ConfigureAwait(false); + } + } + catch + { + // Best-effort; the socket may already be closed. + } + } + + [SuppressMessage("Usage", "CA1031:Do not catch general exception types", Justification = "Best-effort teardown of the losing pump.")] + internal static async Task ObserveQuietlyAsync(Task task) + { + try + { + await task.ConfigureAwait(false); + } + catch + { + // Best-effort teardown only. + } + } + + internal static Uri ToWebSocketUri(string url) + { + var builder = new UriBuilder(url); + if (builder.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + { + builder.Scheme = "wss"; + } + else if (builder.Scheme.Equals("http", StringComparison.OrdinalIgnoreCase)) + { + builder.Scheme = "ws"; + } + + return builder.Uri; + } +} + +internal sealed class LlmWebSocketResponseBridge +{ + private readonly LlmInferenceResponseSink _sink; + private readonly SemaphoreSlim _gate = new(1, 1); + private readonly Queue _pending = new(); + private bool _started; + private bool _completed; + + internal LlmWebSocketResponseBridge(LlmInferenceResponseSink sink) + { + _sink = sink; + } + + internal async Task StartAsync() + { + await _gate.WaitAsync().ConfigureAwait(false); + try + { + if (_started) + { + return; + } + + _started = true; + await _sink.StartAsync(new LlmInferenceResponseInit { Status = 101 }).ConfigureAwait(false); + while (_pending.Count > 0) + { + await ApplyAsync(_pending.Dequeue()).ConfigureAwait(false); + } + } + finally + { + _gate.Release(); + } + } + + internal Task WriteAsync(LlmWebSocketMessage message) => EnqueueOrApplyAsync(PendingAction.Write(message)); + + internal Task EndAsync() => EnqueueOrApplyAsync(PendingAction.End()); + + internal Task ErrorAsync(string message, string? code) => EnqueueOrApplyAsync(PendingAction.Error(message, code)); + + private async Task EnqueueOrApplyAsync(PendingAction action) + { + await _gate.WaitAsync().ConfigureAwait(false); + try + { + if (_completed && action.Kind == PendingActionKind.Write) + { + return; + } + + if (!_started) + { + _pending.Enqueue(action); + if (action.Kind is PendingActionKind.End or PendingActionKind.Error) + { + _completed = true; + } + + return; + } + + await ApplyAsync(action).ConfigureAwait(false); + } + finally + { + _gate.Release(); + } + } + + private async Task ApplyAsync(PendingAction action) + { + if (_completed && action.Kind == PendingActionKind.Write) + { + return; + } + + switch (action.Kind) + { + case PendingActionKind.Write: + if (action.Message!.Value.IsBinary) + { + await _sink.WriteAsync(action.Message.Value.Data).ConfigureAwait(false); + } + else + { + await _sink.WriteAsync(action.Message.Value.GetText()).ConfigureAwait(false); + } + break; + case PendingActionKind.End: + if (_completed) + { + return; + } + + _completed = true; + await _sink.EndAsync().ConfigureAwait(false); + break; + case PendingActionKind.Error: + if (_completed) + { + return; + } + + _completed = true; + await _sink.ErrorAsync(action.ErrorMessage!, action.ErrorCode).ConfigureAwait(false); + break; + } + } + + private readonly record struct PendingAction( + PendingActionKind Kind, + LlmWebSocketMessage? Message = null, + string? ErrorMessage = null, + string? ErrorCode = null) + { + internal static PendingAction Write(LlmWebSocketMessage message) => new(PendingActionKind.Write, message); + internal static PendingAction End() => new(PendingActionKind.End); + internal static PendingAction Error(string message, string? code) => new(PendingActionKind.Error, null, message, code); + } + + private enum PendingActionKind + { + Write, + End, + Error, + } +} diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 08e1dbbfa..786d38b03 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -278,6 +278,7 @@ private CopilotClientOptions(CopilotClientOptions? other) UseLoggedInUser = other.UseLoggedInUser; OnListModels = other.OnListModels; SessionFs = other.SessionFs; + LlmInference = other.LlmInference; SessionIdleTimeoutSeconds = other.SessionIdleTimeoutSeconds; EnableRemoteSessions = other.EnableRemoteSessions; Mode = other.Mode; @@ -364,6 +365,17 @@ private CopilotClientOptions(CopilotClientOptions? other) /// public SessionFsConfig? SessionFs { get; set; } + /// + /// Configures interception of the LLM inference requests the runtime would + /// otherwise issue itself (for both CAPI and BYOK providers). When set, the + /// client registers a client-global LLM inference provider on connect, so + /// every model-layer HTTP / WebSocket request is routed to the consumer's + /// (or + /// subclass) instead of the runtime's own outbound call. + /// + [Experimental(Diagnostics.Experimental)] + public LlmInferenceConfig? LlmInference { get; set; } + /// /// OpenTelemetry configuration for the runtime. /// When set to a non- instance, the runtime is started with OpenTelemetry instrumentation enabled. @@ -476,6 +488,21 @@ public sealed class SessionFsConfig public SessionFsSetProviderCapabilities? Capabilities { get; init; } } +/// +/// Configuration for intercepting the LLM inference requests the runtime issues. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceConfig +{ + /// + /// Handler that services every intercepted model-layer request for the + /// lifetime of the client connection. Subclass + /// and override its hooks to observe, mutate, or fully replace each + /// request/response. + /// + public LlmRequestHandler? Handler { get; set; } +} + /// /// Represents a binary result returned by a tool invocation. /// diff --git a/dotnet/test/E2E/LlmInferenceE2EProvider.cs b/dotnet/test/E2E/LlmInferenceE2EProvider.cs new file mode 100644 index 000000000..25fdadd76 --- /dev/null +++ b/dotnet/test/E2E/LlmInferenceE2EProvider.cs @@ -0,0 +1,167 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Collections.Concurrent; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.RegularExpressions; + +namespace GitHub.Copilot.Test.E2E; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +/// +/// A subclass for e2e tests that records every +/// intercepted request (url + threaded session id) and fully replaces the +/// upstream call with a fabricated, well-formed response for every model-layer +/// endpoint, so an agent turn completes entirely off-network — no upstream +/// server and no CAPI proxy acting as the inference endpoint. +/// +/// +/// +/// This exercises the public extension surface end to end: a consumer subclasses +/// and overrides to +/// short-circuit the upstream HTTP call with any +/// it likes. The base class streams that response back to the runtime. +/// +/// +/// All response bodies are emitted as raw JSON string literals rather than via +/// JsonSerializer: the test project disables reflection-based STJ on +/// net8.0 (JsonSerializerIsReflectionEnabledByDefault=false), so +/// serializing anonymous types would throw at runtime. +/// +/// +internal sealed class RecordingInferenceProvider : LlmRequestHandler +{ + internal const string SyntheticText = "OK from the synthetic stream."; + + private static readonly Regex WantsStreamRegex = new("\"stream\"\\s*:\\s*true", RegexOptions.Compiled); + + private readonly ConcurrentQueue _records = new(); + + public IReadOnlyCollection Records => _records; + + public IReadOnlyList InferenceRequests => + [.. _records.Where(r => IsInferenceUrl(r.Url))]; + + protected override async Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) + { + var url = request.RequestUri!.ToString(); + _records.Enqueue(new InterceptedRequest(url, ctx.SessionId)); + + var bodyText = request.Content is null + ? string.Empty +#if NET8_0_OR_GREATER + : await request.Content.ReadAsStringAsync(ctx.CancellationToken).ConfigureAwait(false); +#else + : await request.Content.ReadAsStringAsync().ConfigureAwait(false); +#endif + + return IsInferenceUrl(url) + ? BuildInferenceResponse(url, bodyText) + : BuildNonInferenceResponse(url); + } + + internal static bool IsInferenceUrl(string url) + { + var u = url.ToLowerInvariant(); + return u.EndsWith("/chat/completions", StringComparison.Ordinal) + || u.EndsWith("/responses", StringComparison.Ordinal) + || u.EndsWith("/v1/messages", StringComparison.Ordinal) + || u.EndsWith("/messages", StringComparison.Ordinal); + } + + /// + /// Synthesizes a well-formed inference response so the agent turn completes. + /// The runtime selects /responses for both the CAPI and BYOK sessions + /// here; /chat/completions is handled too for robustness. + /// + private static HttpResponseMessage BuildInferenceResponse(string url, string bodyText) + { + var wantsStream = WantsStreamRegex.IsMatch(bodyText); + var u = url.ToLowerInvariant(); + + if (u.Contains("/responses", StringComparison.Ordinal)) + { + return wantsStream + ? Sse(string.Concat(ResponsesStreamEvents)) + : Json(BufferedResponseJson); + } + + if (u.Contains("/chat/completions", StringComparison.Ordinal) && wantsStream) + { + return Sse(string.Concat(ChatCompletionStreamEvents)); + } + + // /chat/completions non-streaming (and any other inference url) — buffered JSON. + return Json(BufferedChatCompletionJson); + } + + /// + /// Serves the non-inference model-layer GETs/POSTs the runtime issues + /// (catalog, model session, policy). These flow through the same callback + /// but carry no session id (they happen outside an agent turn). + /// + private static HttpResponseMessage BuildNonInferenceResponse(string url) + { + var u = url.ToLowerInvariant(); + if (u.EndsWith("/models", StringComparison.Ordinal)) + { + return Json(ModelCatalogJson); + } + + if (u.Contains("/models/session", StringComparison.Ordinal)) + { + return Json("{}"); + } + + if (u.Contains("/policy", StringComparison.Ordinal)) + { + return Json("{\"state\":\"enabled\"}"); + } + + return Json("{}"); + } + + private static HttpResponseMessage Json(string body) => new(HttpStatusCode.OK) + { + Content = new StringContent(body, Encoding.UTF8, "application/json"), + }; + + private static HttpResponseMessage Sse(string body) => new(HttpStatusCode.OK) + { + Content = new StringContent(body, Encoding.UTF8, "text/event-stream"), + }; + + private static readonly string[] ResponsesStreamEvents = + [ + "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"in_progress\",\"output\":[]}}\n\n", + "event: response.output_item.added\ndata: {\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[]}}\n\n", + "event: response.content_part.added\ndata: {\"type\":\"response.content_part.added\",\"output_index\":0,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"text\":\"\"}}\n\n", + "event: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"output_index\":0,\"content_index\":0,\"delta\":\"" + SyntheticText + "\"}\n\n", + "event: response.output_text.done\ndata: {\"type\":\"response.output_text.done\",\"output_index\":0,\"content_index\":0,\"text\":\"" + SyntheticText + "\"}\n\n", + "event: response.completed\ndata: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"completed\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"" + SyntheticText + "\"}]}],\"usage\":{\"input_tokens\":5,\"output_tokens\":7,\"total_tokens\":12}}}\n\n", + ]; + + private static readonly string[] ChatCompletionStreamEvents = + [ + "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"" + SyntheticText + "\"},\"finish_reason\":null}]}\n\n", + "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":7,\"total_tokens\":12}}\n\n", + "data: [DONE]\n\n", + ]; + + private static readonly string BufferedResponseJson = + "{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"completed\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"" + SyntheticText + "\"}]}],\"usage\":{\"input_tokens\":5,\"output_tokens\":7,\"total_tokens\":12}}"; + + private static readonly string BufferedChatCompletionJson = + "{\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"" + SyntheticText + "\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":7,\"total_tokens\":12}}"; + + private const string ModelCatalogJson = + "{\"data\":[{\"id\":\"claude-sonnet-4.5\",\"name\":\"Claude Sonnet 4.5\",\"object\":\"model\",\"vendor\":\"Anthropic\",\"version\":\"1\",\"preview\":false,\"model_picker_enabled\":true,\"capabilities\":{\"type\":\"chat\",\"family\":\"claude-sonnet-4.5\",\"tokenizer\":\"o200k_base\",\"limits\":{\"max_context_window_tokens\":200000,\"max_output_tokens\":8192},\"supports\":{\"streaming\":true,\"tool_calls\":true,\"parallel_tool_calls\":true,\"vision\":true}}}]}"; +} + +/// A single request the callback intercepted. +internal sealed record InterceptedRequest(string Url, string? SessionId); diff --git a/dotnet/test/E2E/LlmInferenceSessionIdE2ETests.cs b/dotnet/test/E2E/LlmInferenceSessionIdE2ETests.cs new file mode 100644 index 000000000..be1db1de9 --- /dev/null +++ b/dotnet/test/E2E/LlmInferenceSessionIdE2ETests.cs @@ -0,0 +1,107 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.Test.E2E; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +/// +/// Asserts the runtime threads its session id into the LLM inference callback +/// for BOTH a CAPI session and a BYOK session. The callback alone services +/// every model-layer request — no upstream server, no CAPI proxy acting as the +/// inference endpoint — so the only source of req.SessionId is the +/// runtime's own per-client threading. +/// +public class LlmInferenceSessionIdE2ETests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "llm_inference_session_id", output) +{ + private CopilotClient CreateClientWith(RecordingInferenceProvider provider) => + Ctx.CreateClient(options: new CopilotClientOptions + { + Connection = RuntimeConnection.ForStdio(), + LlmInference = new LlmInferenceConfig + { + Handler = provider, + }, + }); + + [Fact] + public async Task Threads_The_Session_Id_Into_A_Capi_Session_Inference_Request() + { + var provider = new RecordingInferenceProvider(); + await using var client = CreateClientWith(provider); + await client.StartAsync(); + + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + var capiSessionId = session.SessionId; + + string content; + try + { + var msg = await session.SendAndWaitAsync(new MessageOptions { Prompt = "Say OK." }); + content = msg?.Data.Content ?? string.Empty; + } + finally + { + await session.DisposeAsync(); + } + + var inference = provider.InferenceRequests; + Assert.NotEmpty(inference); + Assert.All(inference, r => Assert.Equal(capiSessionId, r.SessionId)); + + // Validate the final assistant response arrived (guards against truncated captures) + Assert.Contains("OK from the synthetic", content); + } + + [Fact] + public async Task Threads_The_Session_Id_Into_A_Byok_Session_Inference_Request() + { + var provider = new RecordingInferenceProvider(); + await using var client = CreateClientWith(provider); + await client.StartAsync(); + + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + // BYOK providers require an explicit model id. + Model = "claude-sonnet-4.5", + Provider = new ProviderConfig + { + Type = "openai", + WireApi = "responses", + BaseUrl = "https://byok.invalid/v1", + ApiKey = "byok-secret", + ModelId = "claude-sonnet-4.5", + WireModel = "claude-sonnet-4.5", + }, + }); + var byokSessionId = session.SessionId; + + string content; + try + { + var msg = await session.SendAndWaitAsync(new MessageOptions { Prompt = "Say OK." }); + content = msg?.Data.Content ?? string.Empty; + } + finally + { + await session.DisposeAsync(); + } + + var inference = provider.InferenceRequests; + Assert.NotEmpty(inference); + Assert.All(inference, r => Assert.Equal(byokSessionId, r.SessionId)); + + // Validate the final assistant response arrived (guards against truncated captures) + Assert.Contains("OK from the synthetic", content); + } +} diff --git a/dotnet/test/GitHub.Copilot.SDK.Test.csproj b/dotnet/test/GitHub.Copilot.SDK.Test.csproj index 4b27df57c..49e117d83 100644 --- a/dotnet/test/GitHub.Copilot.SDK.Test.csproj +++ b/dotnet/test/GitHub.Copilot.SDK.Test.csproj @@ -7,6 +7,13 @@ false true $(NoWarn);GHCP001 + + $(NoWarn);CS0436 @@ -35,7 +42,11 @@ - + diff --git a/dotnet/test/Unit/LlmInference/LlmInferenceAdapterTests.cs b/dotnet/test/Unit/LlmInference/LlmInferenceAdapterTests.cs new file mode 100644 index 000000000..94d50f378 --- /dev/null +++ b/dotnet/test/Unit/LlmInference/LlmInferenceAdapterTests.cs @@ -0,0 +1,197 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Text; +using Xunit; + +namespace GitHub.Copilot.Test.Unit.LlmInference; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +public class LlmInferenceAdapterTests +{ + private static readonly TimeSpan Timeout = TimeSpan.FromSeconds(10); + + private static LlmInferenceAdapter CreateAdapter(ILlmInferenceProvider provider, RecordingResponseChannel channel) + { + ILlmInferenceResponseChannel current = channel; + return new LlmInferenceAdapter(provider, () => current); + } + + [Fact] + public async Task Stages_request_chunks_that_arrive_before_their_start_frame_and_replays_them_in_order() + { + var received = new List(); + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + await foreach (var chunk in req.RequestBody) + { + received.Add(Encoding.UTF8.GetString(chunk.ToArray())); + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit { Status = 200 }); + await req.ResponseBody.EndAsync(); + done.SetResult(); + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + // Chunks arrive BEFORE the start frame (a reordering the runtime should + // never produce). They must be staged and replayed once start registers. + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r1", "hello ", end: false)); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r1", "world", end: false)); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r1", "", end: true)); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r1")); + + await done.Task.WaitAsync(Timeout); + Assert.Equal("hello world", string.Concat(received)); + } + + [Fact] + public async Task Emits_a_buffered_response_as_start_then_body_then_terminal_end() + { + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + await foreach (var _ in req.RequestBody) + { + // drain + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit + { + Status = 200, + Headers = new Dictionary> { ["content-type"] = ["application/json"] }, + }); + await req.ResponseBody.WriteAsync("OK"); + await req.ResponseBody.EndAsync(); + done.SetResult(); + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r2")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r2", "", end: true)); + + await done.Task.WaitAsync(Timeout); + + var start = Assert.Single(channel.Starts); + Assert.Equal(200, start.Status); + Assert.Equal("OK", channel.DecodeTextBody()); + + var terminal = Assert.Single(channel.Chunks, c => c.End == true); + Assert.Null(terminal.Error); + } + + [Fact] + public async Task Aborts_the_provider_and_throws_from_write_when_the_runtime_rejects_a_response_frame() + { + var aborted = false; + var writeThrew = false; + var settled = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + req.CancellationToken.Register(() => aborted = true); + await foreach (var _ in req.RequestBody) + { + // drain + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit { Status = 200 }); + try + { + await req.ResponseBody.WriteAsync("rejected-chunk"); + } + catch (InvalidOperationException) + { + writeThrew = true; + } + + settled.SetResult(); + }); + + // The runtime accepts the start frame but rejects the body chunk. + var channel = new RecordingResponseChannel(acceptStart: true, acceptChunk: false); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r3")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r3", "", end: true)); + + await settled.Task.WaitAsync(Timeout); + Assert.True(writeThrew, "write should throw after the runtime rejects the chunk"); + Assert.True(aborted, "the provider's cancellation token should fire on rejection"); + } + + [Fact] + public async Task Surfaces_a_runtime_cancel_chunk_as_a_cancelled_terminal_error() + { + var observedCancellation = false; + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + try + { + await foreach (var _ in req.RequestBody) + { + // The cancel frame surfaces as an OperationCanceledException here. + } + } + catch (OperationCanceledException) + { + observedCancellation = true; + throw; + } + finally + { + done.TrySetResult(); + } + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r4")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r4", cancel: true, cancelReason: "turn aborted")); + + await done.Task.WaitAsync(Timeout); + await channel.Terminal.WaitAsync(Timeout); + Assert.True(observedCancellation, "the request body iterator should throw on a cancel frame"); + + // The adapter finalises a cancelled request as a 499 + error{code:cancelled}. + var terminal = Assert.Single(channel.Chunks, c => c.Error is not null); + Assert.Equal("cancelled", terminal.Error!.Code); + } + + [Fact] + public async Task Threads_the_runtime_session_id_into_the_request() + { + string? observedSessionId = null; + var done = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var provider = new InlineProvider(async req => + { + observedSessionId = req.SessionId; + await foreach (var _ in req.RequestBody) + { + // drain + } + + await req.ResponseBody.StartAsync(new LlmInferenceResponseInit { Status = 200 }); + await req.ResponseBody.EndAsync(); + done.SetResult(); + }); + + var channel = new RecordingResponseChannel(); + var adapter = CreateAdapter(provider, channel); + + await adapter.HttpRequestStartAsync(LlmFrames.Start("r5", sessionId: "session-123")); + await adapter.HttpRequestChunkAsync(LlmFrames.Chunk("r5", "", end: true)); + + await done.Task.WaitAsync(Timeout); + Assert.Equal("session-123", observedSessionId); + } +} diff --git a/dotnet/test/Unit/LlmInference/LlmInferenceHandlerTests.cs b/dotnet/test/Unit/LlmInference/LlmInferenceHandlerTests.cs new file mode 100644 index 000000000..663884781 --- /dev/null +++ b/dotnet/test/Unit/LlmInference/LlmInferenceHandlerTests.cs @@ -0,0 +1,159 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Net; +using System.Net.Http; +using System.Text; +using Xunit; + +namespace GitHub.Copilot.Test.Unit.LlmInference; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +public class LlmInferenceHandlerTests +{ + private static readonly TimeSpan Timeout = TimeSpan.FromSeconds(10); + + private static Task Dispatch(LlmRequestHandler handler, LlmInferenceRequest request) => + ((ILlmInferenceProvider)handler).OnLlmRequestAsync(request); + + private static async IAsyncEnumerable> AsyncBytes(params string[] chunks) + { + foreach (var chunk in chunks) + { + await Task.Yield(); + yield return Encoding.UTF8.GetBytes(chunk); + } + } + + private static LlmInferenceRequest HttpRequest( + RecordingSink sink, + IAsyncEnumerable> body, + string method = "POST", + string url = "https://upstream.test/v1/chat/completions", + IReadOnlyDictionary>? headers = null) => + new() + { + RequestId = "req-1", + SessionId = "session-1", + Method = method, + Url = url, + Headers = headers ?? new Dictionary>(), + Transport = LlmInferenceTransport.Http, + RequestBody = body, + ResponseBody = sink, + }; + + /// A handler whose upstream call is a canned delegate (no network). + private sealed class StubHandler(Func send) : LlmRequestHandler + { + protected override Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) => + Task.FromResult(send(request)); + } + + /// A handler that adds a header before calling base.SendRequestAsync. + private sealed class HeaderMutatingHandler(Func send) : LlmRequestHandler + { + protected override Task SendRequestAsync(HttpRequestMessage request, LlmRequestContext ctx) + { + request.Headers.TryAddWithoutValidation("authorization", "Bearer swapped-token"); + return Task.FromResult(send(request)); + } + } + + [Fact] + public async Task Forwards_request_body_and_streams_response_back_to_the_sink() + { + string? forwardedBody = null; + var handler = new StubHandler(req => + { + forwardedBody = req.Content!.ReadAsStringAsync().GetAwaiter().GetResult(); + return new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("RESPONSE-BODY", Encoding.UTF8, "application/json"), + }; + }); + + var sink = new RecordingSink(); + var request = HttpRequest(sink, AsyncBytes("{\"hello\":", "\"world\"}")); + + await Dispatch(handler, request).WaitAsync(Timeout); + + Assert.Equal("{\"hello\":\"world\"}", forwardedBody); + + var start = Assert.Single(sink.Starts); + Assert.Equal(200, start.Status); + Assert.Equal("RESPONSE-BODY", sink.DecodeBinaryBody()); + Assert.True(sink.Ended); + Assert.Null(sink.Errored); + } + + [Fact] + public async Task Strips_forbidden_request_headers_before_forwarding() + { + var forwarded = new Dictionary(StringComparer.OrdinalIgnoreCase); + var handler = new StubHandler(req => + { + foreach (var header in req.Headers) + { + forwarded[header.Key] = string.Join(",", header.Value); + } + + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent("ok") }; + }); + + var sink = new RecordingSink(); + var headers = new Dictionary> + { + ["host"] = ["should-be-stripped.test"], + ["x-tenant"] = ["acme"], + }; + var request = HttpRequest(sink, AsyncBytes("body"), headers: headers); + + await Dispatch(handler, request).WaitAsync(Timeout); + + Assert.False(forwarded.ContainsKey("host"), "the forbidden host header must be stripped"); + Assert.Equal("acme", forwarded["x-tenant"]); + } + + [Fact] + public async Task Lets_a_subclass_mutate_the_outbound_request_headers() + { + string? observedAuth = null; + var handler = new HeaderMutatingHandler(req => + { + observedAuth = req.Headers.TryGetValues("authorization", out var values) + ? string.Join(",", values) + : null; + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent("ok") }; + }); + + var sink = new RecordingSink(); + var request = HttpRequest(sink, AsyncBytes("body")); + + await Dispatch(handler, request).WaitAsync(Timeout); + + Assert.Equal("Bearer swapped-token", observedAuth); + } + + [Fact] + public async Task Propagates_a_non_2xx_status_verbatim_to_the_runtime() + { + var handler = new StubHandler(_ => + new HttpResponseMessage((HttpStatusCode)429) + { + Content = new StringContent("slow down"), + }); + + var sink = new RecordingSink(); + var request = HttpRequest(sink, AsyncBytes()); + + await Dispatch(handler, request).WaitAsync(Timeout); + + var start = Assert.Single(sink.Starts); + Assert.Equal(429, start.Status); + Assert.Equal("slow down", sink.DecodeBinaryBody()); + Assert.True(sink.Ended); + } +} diff --git a/dotnet/test/Unit/LlmInference/LlmInferenceTestInfra.cs b/dotnet/test/Unit/LlmInference/LlmInferenceTestInfra.cs new file mode 100644 index 000000000..65339732a --- /dev/null +++ b/dotnet/test/Unit/LlmInference/LlmInferenceTestInfra.cs @@ -0,0 +1,157 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Rpc; +using System.Text; + +namespace GitHub.Copilot.Test.Unit.LlmInference; + +#pragma warning disable GHCP001 // The LLM inference surface is intentionally experimental. + +/// +/// In-memory that records every +/// response frame the adapter emits and lets a test choose what +/// accepted value the runtime returns. +/// +internal sealed class RecordingResponseChannel(bool acceptStart = true, bool acceptChunk = true) : ILlmInferenceResponseChannel +{ + public sealed record StartFrame(long Status, string? StatusText, IDictionary> Headers); + + public sealed record ChunkFrame(string Data, bool? Binary, bool? End, LlmInferenceHttpResponseChunkError? Error); + + public List Starts { get; } = []; + + public List Chunks { get; } = []; + + private readonly TaskCompletionSource _terminal = new(TaskCreationOptions.RunContinuationsAsynchronously); + + /// Completes once a terminal response chunk (end or error) is recorded. + public Task Terminal => _terminal.Task; + + public Task HttpResponseStartAsync(string requestId, long status, IDictionary> headers, string? statusText = null) + { + Starts.Add(new StartFrame(status, statusText, headers)); + return Task.FromResult(new LlmInferenceHttpResponseStartResult { Accepted = acceptStart }); + } + + public Task HttpResponseChunkAsync(string requestId, string data, bool? binary = null, bool? end = null, LlmInferenceHttpResponseChunkError? error = null) + { + Chunks.Add(new ChunkFrame(data, binary, end, error)); + if (end == true || error is not null) + { + _terminal.TrySetResult(); + } + + return Task.FromResult(new LlmInferenceHttpResponseChunkResult { Accepted = acceptChunk }); + } + + /// Concatenates the UTF-8 text of all non-terminal body chunks. + public string DecodeTextBody() + { + var sb = new StringBuilder(); + foreach (var chunk in Chunks) + { + if (chunk.Error is not null || chunk.Data.Length == 0) + { + continue; + } + + sb.Append(chunk.Binary == true + ? Encoding.UTF8.GetString(Convert.FromBase64String(chunk.Data)) + : chunk.Data); + } + + return sb.ToString(); + } +} + +/// An driven by an inline delegate. +internal sealed class InlineProvider(Func handler) : ILlmInferenceProvider +{ + public Task OnLlmRequestAsync(LlmInferenceRequest request) => handler(request); +} + +/// Records everything written to a . +internal sealed class RecordingSink : LlmInferenceResponseSink +{ + public List Starts { get; } = []; + + public List TextWrites { get; } = []; + + public List BinaryWrites { get; } = []; + + public bool Ended { get; private set; } + + public (string Message, string? Code)? Errored { get; private set; } + + /// Concatenates all binary body writes and decodes them as UTF-8. + public string DecodeBinaryBody() => Encoding.UTF8.GetString(BinaryWrites.SelectMany(b => b).ToArray()); + + public override Task StartAsync(LlmInferenceResponseInit init) + { + Starts.Add(init); + return Task.CompletedTask; + } + + public override Task WriteAsync(ReadOnlyMemory data) + { + BinaryWrites.Add(data.ToArray()); + return Task.CompletedTask; + } + + public override Task WriteAsync(string text) + { + TextWrites.Add(text); + return Task.CompletedTask; + } + + public override Task EndAsync() + { + Ended = true; + return Task.CompletedTask; + } + + public override Task ErrorAsync(string message, string? code = null) + { + Errored = (message, code); + return Task.CompletedTask; + } +} + +/// Convenience builders for the generated request frames. +internal static class LlmFrames +{ + public static LlmInferenceHttpRequestStartRequest Start( + string requestId, + string url = "https://example.test/v1/chat", + string method = "POST", + string? sessionId = null, + LlmInferenceHttpRequestStartTransport? transport = null) => + new() + { + RequestId = requestId, + Url = url, + Method = method, + SessionId = sessionId, + Headers = new Dictionary>(), + Transport = transport, + }; + + public static LlmInferenceHttpRequestChunkRequest Chunk( + string requestId, + string data = "", + bool? end = null, + bool? binary = null, + bool? cancel = null, + string? cancelReason = null) => + new() + { + RequestId = requestId, + Data = data, + End = end, + Binary = binary, + Cancel = cancel, + CancelReason = cancelReason, + }; +} diff --git a/go/rpc/zrpc.go b/go/rpc/zrpc.go index 583ed3c58..6d4d45c46 100644 --- a/go/rpc/zrpc.go +++ b/go/rpc/zrpc.go @@ -1837,6 +1837,176 @@ type InstructionSource struct { Type InstructionSourceType `json:"type"` } +// HTTP headers as a map from lowercased header name to a list of values. Multi-valued +// headers (e.g. Set-Cookie) preserve all values. +type LlmInferenceHeaders map[string][]string + +// Set when the SDK client could not produce a response (transport-level failure). Causes +// the runtime to raise an APIConnectionError; status/headers/body are ignored when error is +// set. +type LlmInferenceHTTPRequestError struct { + // Optional machine-readable error code. + Code *string `json:"code,omitempty"` + // Human-readable failure description. + Message string `json:"message"` +} + +// An outbound model-layer HTTP request the runtime would otherwise have issued itself. +type LlmInferenceHTTPRequestRequest struct { + // Request body as base64-encoded bytes. Set instead of bodyText when the body is binary. + BodyBase64 *string `json:"bodyBase64,omitempty"` + // Request body as a UTF-8 string. Set when binaryBody is absent or false. + BodyText *string `json:"bodyText,omitempty"` + // HTTP headers as a map from lowercased header name to a list of values. Multi-valued + // headers (e.g. Set-Cookie) preserve all values. + Headers map[string][]string `json:"headers"` + // Metadata describing an intercepted LLM HTTP request. + Metadata LlmInferenceRequestMetadata `json:"metadata"` + // HTTP method, e.g. GET, POST. + Method string `json:"method"` + // Opaque runtime-minted id, unique per request. Useful for client-side logging. + RequestID string `json:"requestId"` + // Id of the runtime session that triggered this request, when one is in scope. Absent for + // requests issued outside any session (e.g. startup model-catalog or capability + // resolution). This is a payload field — not a dispatch key — because the client-global API + // is registered process-wide rather than per session. + SessionID *string `json:"sessionId,omitempty"` + // Absolute request URL. + URL string `json:"url"` +} + +// The HTTP response the runtime should treat as if it had issued the request itself. +type LlmInferenceHTTPRequestResult struct { + // Response body as base64-encoded bytes. Set instead of bodyText for binary responses. + BodyBase64 *string `json:"bodyBase64,omitempty"` + // Response body as a UTF-8 string. Set when bodyBase64 is absent. + BodyText *string `json:"bodyText,omitempty"` + // Set when the SDK client could not produce a response (transport-level failure). Causes + // the runtime to raise an APIConnectionError; status/headers/body are ignored when error is + // set. + Error *LlmInferenceHTTPRequestError `json:"error,omitempty"` + // HTTP headers as a map from lowercased header name to a list of values. Multi-valued + // headers (e.g. Set-Cookie) preserve all values. + Headers map[string][]string `json:"headers"` + // HTTP status code returned to the runtime. + Status int64 `json:"status"` + // Optional HTTP status text. + StatusText *string `json:"statusText,omitempty"` +} + +// Set when the SDK client could not even begin the stream (transport-level failure). When +// error is set the runtime raises an APIConnectionError and ignores status/headers. +type LlmInferenceHTTPStreamStartError struct { + Code *string `json:"code,omitempty"` + Message string `json:"message"` +} + +// An outbound streaming model-layer HTTP request. +type LlmInferenceHTTPStreamStartRequest struct { + BodyBase64 *string `json:"bodyBase64,omitempty"` + BodyText *string `json:"bodyText,omitempty"` + // HTTP headers as a map from lowercased header name to a list of values. Multi-valued + // headers (e.g. Set-Cookie) preserve all values. + Headers map[string][]string `json:"headers"` + // Metadata describing an intercepted LLM HTTP request. + Metadata LlmInferenceRequestMetadata `json:"metadata"` + // HTTP method. + Method string `json:"method"` + // Opaque runtime-minted id, unique per request. + RequestID string `json:"requestId"` + // Originating session id, when known. + SessionID *string `json:"sessionId,omitempty"` + // Stream identifier. The SDK client passes this exact value back on every + // llmInference.streamChunk / streamEnd call to correlate pushed chunks with this request. + StreamToken int64 `json:"streamToken"` + // Absolute request URL. + URL string `json:"url"` +} + +// The response head. After returning, the SDK client pushes body chunks via +// llmInference.streamChunk and signals completion (or transport error) via +// llmInference.streamEnd. +type LlmInferenceHTTPStreamStartResult struct { + // Set when the SDK client could not even begin the stream (transport-level failure). When + // error is set the runtime raises an APIConnectionError and ignores status/headers. + Error *LlmInferenceHTTPStreamStartError `json:"error,omitempty"` + // HTTP headers as a map from lowercased header name to a list of values. Multi-valued + // headers (e.g. Set-Cookie) preserve all values. + Headers map[string][]string `json:"headers"` + // HTTP status code. + Status int64 `json:"status"` + StatusText *string `json:"statusText,omitempty"` +} + +// Metadata describing an intercepted LLM HTTP request. +type LlmInferenceRequestMetadata struct { + // What kind of model-layer endpoint this is. + EndpointKind LlmInferenceRequestMetadataEndpointKind `json:"endpointKind"` + // Model identifier, when known. + ModelID *string `json:"modelId,omitempty"` + // Logical model provider this request targets. + ProviderType LlmInferenceRequestMetadataProviderType `json:"providerType"` + // Transport kind. v1 implements http only. + Transport LlmInferenceRequestMetadataTransport `json:"transport"` + // Wire API shape, when this is an inference request. + WireAPI *LlmInferenceRequestMetadataWireAPI `json:"wireApi,omitempty"` +} + +// No parameters. The calling connection is registered as the runtime's LLM inference +// provider; all subsequent model-layer HTTP requests are dispatched back to it via the +// llmInference client API. +// Experimental: LlmInferenceSetProviderRequest is part of an experimental API and may +// change or be removed. +type LlmInferenceSetProviderRequest struct { +} + +// Indicates whether the calling client was registered as the LLM inference provider. +// Experimental: LlmInferenceSetProviderResult is part of an experimental API and may change +// or be removed. +type LlmInferenceSetProviderResult struct { + // Whether the provider was set successfully + Success bool `json:"success"` +} + +// A streamed response body chunk. +// Experimental: LlmInferenceStreamChunkRequest is part of an experimental API and may +// change or be removed. +type LlmInferenceStreamChunkRequest struct { + // One body chunk as base64-encoded bytes. Chunks are appended to the runtime's view of the + // response body in the order received. + DataBase64 string `json:"dataBase64"` + // The same streamToken the runtime supplied in the originating llmInference.httpStreamStart + // call. + StreamToken int64 `json:"streamToken"` +} + +// Whether the chunk was accepted. +// Experimental: LlmInferenceStreamChunkResult is part of an experimental API and may change +// or be removed. +type LlmInferenceStreamChunkResult struct { + // True when the chunk was queued for the stream; false when the stream is unknown. + Accepted bool `json:"accepted"` +} + +// End-of-stream signal. +// Experimental: LlmInferenceStreamEndRequest is part of an experimental API and may change +// or be removed. +type LlmInferenceStreamEndRequest struct { + // When set, marks the stream as ending with a transport-level error of this description. + // When absent the stream ends normally. + Error *string `json:"error,omitempty"` + // The originating streamToken. + StreamToken int64 `json:"streamToken"` +} + +// Whether the end signal was accepted. +// Experimental: LlmInferenceStreamEndResult is part of an experimental API and may change +// or be removed. +type LlmInferenceStreamEndResult struct { + // True when the stream was found and ended; false when unknown. + Accepted bool `json:"accepted"` +} + // Schema for the `LocalSessionMetadataValue` type. // Experimental: LocalSessionMetadataValue is part of an experimental API and may change or // be removed. @@ -2566,6 +2736,9 @@ func (RawMCPServerConfigData) mcpServerConfig() {} type MCPServerConfigHTTP struct { // Set to `true` to use defaults, or provide an object with additional auth or OIDC settings. Auth MCPServerAuthConfig `json:"auth,omitempty"` + // Controls if tools provided by this server can be loaded on demand via tool search (auto) + // or always included in the initial tool list (never) + DeferTools *MCPServerConfigDeferTools `json:"deferTools,omitempty"` // Content filtering mode to apply to all tools, or a map of tool name to content filtering // mode. FilterMapping FilterMapping `json:"filterMapping,omitempty"` @@ -2604,6 +2777,9 @@ type MCPServerConfigStdio struct { Command string `json:"command"` // Working directory for the Stdio MCP server process. Cwd *string `json:"cwd,omitempty"` + // Controls if tools provided by this server can be loaded on demand via tool search (auto) + // or always included in the initial tool list (never) + DeferTools *MCPServerConfigDeferTools `json:"deferTools,omitempty"` // Environment variables to pass to the Stdio MCP server process. Env map[string]string `json:"env,omitzero"` // Content filtering mode to apply to all tools, or a map of tool name to content filtering @@ -9023,6 +9199,64 @@ const ( InstructionSourceTypeVscode InstructionSourceType = "vscode" ) +// What kind of model-layer endpoint this is. +type LlmInferenceRequestMetadataEndpointKind string + +const ( + // An embeddings request. + LlmInferenceRequestMetadataEndpointKindEmbeddings LlmInferenceRequestMetadataEndpointKind = "embeddings" + // An inference request (chat/completions, responses, messages). + LlmInferenceRequestMetadataEndpointKindInference LlmInferenceRequestMetadataEndpointKind = "inference" + // Listing of available models. + LlmInferenceRequestMetadataEndpointKindModelsCatalog LlmInferenceRequestMetadataEndpointKind = "models-catalog" + // Per-model policy lookup. + LlmInferenceRequestMetadataEndpointKindModelsPolicy LlmInferenceRequestMetadataEndpointKind = "models-policy" + // Per-model session/auth bootstrap. + LlmInferenceRequestMetadataEndpointKindModelsSession LlmInferenceRequestMetadataEndpointKind = "models-session" + // Model-layer endpoint not specifically categorized. + LlmInferenceRequestMetadataEndpointKindOther LlmInferenceRequestMetadataEndpointKind = "other" +) + +// Logical model provider this request targets. +type LlmInferenceRequestMetadataProviderType string + +const ( + // Anthropic. + LlmInferenceRequestMetadataProviderTypeAnthropic LlmInferenceRequestMetadataProviderType = "anthropic" + // Azure OpenAI. + LlmInferenceRequestMetadataProviderTypeAzure LlmInferenceRequestMetadataProviderType = "azure" + // GitHub Copilot CAPI. + LlmInferenceRequestMetadataProviderTypeCopilot LlmInferenceRequestMetadataProviderType = "copilot" + // Google Gemini / Vertex. + LlmInferenceRequestMetadataProviderTypeGoogle LlmInferenceRequestMetadataProviderType = "google" + // OpenAI. + LlmInferenceRequestMetadataProviderTypeOpenai LlmInferenceRequestMetadataProviderType = "openai" + // Provider not recognised by the runtime's URL heuristics. + LlmInferenceRequestMetadataProviderTypeOther LlmInferenceRequestMetadataProviderType = "other" +) + +// Transport kind. v1 implements http only. +type LlmInferenceRequestMetadataTransport string + +const ( + // Plain HTTP request/response, possibly with an SSE-encoded streamed body. + LlmInferenceRequestMetadataTransportHTTP LlmInferenceRequestMetadataTransport = "http" + // WebSocket connection. Not implemented in v1 of the callback wire. + LlmInferenceRequestMetadataTransportWebsocket LlmInferenceRequestMetadataTransport = "websocket" +) + +// Wire API shape, when this is an inference request. +type LlmInferenceRequestMetadataWireAPI string + +const ( + // OpenAI chat completions API. + LlmInferenceRequestMetadataWireAPICompletions LlmInferenceRequestMetadataWireAPI = "completions" + // Anthropic messages API. + LlmInferenceRequestMetadataWireAPIMessages LlmInferenceRequestMetadataWireAPI = "messages" + // OpenAI responses API. + LlmInferenceRequestMetadataWireAPIResponses LlmInferenceRequestMetadataWireAPI = "responses" +) + // Allowed values for the `McpAppsHostContextDetailsAvailableDisplayMode` enumeration. // Experimental: MCPAppsHostContextDetailsAvailableDisplayMode is part of an experimental // API and may change or be removed. @@ -9147,6 +9381,17 @@ const ( MCPSamplingExecutionActionSuccess MCPSamplingExecutionAction = "success" ) +// Controls if tools provided by this server can be loaded on demand via tool search (auto) +// or always included in the initial tool list (never) +type MCPServerConfigDeferTools string + +const ( + // Tools may be deferred under certain conditions + MCPServerConfigDeferToolsAuto MCPServerConfigDeferTools = "auto" + // Tools are always included in the initial tool list, even when tool search is enabled. + MCPServerConfigDeferToolsNever MCPServerConfigDeferTools = "never" +) + // OAuth grant type to use when authenticating to the remote MCP server. type MCPServerConfigHTTPOauthGrantType string @@ -10348,6 +10593,68 @@ func (a *ServerInstructionsAPI) Discover(ctx context.Context, params *Instructio return &result, nil } +// Experimental: ServerLlmInferenceAPI contains experimental APIs that may change or be +// removed. +type ServerLlmInferenceAPI serverAPI + +// SetProvider registers an SDK client as the LLM inference callback provider. +// +// RPC method: llmInference.setProvider. +// +// Returns: Indicates whether the calling client was registered as the LLM inference +// provider. +func (a *ServerLlmInferenceAPI) SetProvider(ctx context.Context) (*LlmInferenceSetProviderResult, error) { + raw, err := a.client.Request(ctx, "llmInference.setProvider", nil) + if err != nil { + return nil, err + } + var result LlmInferenceSetProviderResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + +// StreamChunk pushes a streamed response body chunk back to the runtime, correlated by the +// streamToken the runtime previously handed out in llmInference.httpStreamStart. +// +// RPC method: llmInference.streamChunk. +// +// Parameters: A streamed response body chunk. +// +// Returns: Whether the chunk was accepted. +func (a *ServerLlmInferenceAPI) StreamChunk(ctx context.Context, params *LlmInferenceStreamChunkRequest) (*LlmInferenceStreamChunkResult, error) { + raw, err := a.client.Request(ctx, "llmInference.streamChunk", params) + if err != nil { + return nil, err + } + var result LlmInferenceStreamChunkResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + +// StreamEnd signals end-of-stream for an inference response stream the SDK client started +// via llmInference.httpStreamStart. +// +// RPC method: llmInference.streamEnd. +// +// Parameters: End-of-stream signal. +// +// Returns: Whether the end signal was accepted. +func (a *ServerLlmInferenceAPI) StreamEnd(ctx context.Context, params *LlmInferenceStreamEndRequest) (*LlmInferenceStreamEndResult, error) { + raw, err := a.client.Request(ctx, "llmInference.streamEnd", params) + if err != nil { + return nil, err + } + var result LlmInferenceStreamEndResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + type ServerMCPAPI serverAPI // Discovers MCP servers from user, workspace, plugin, and builtin sources. @@ -11390,6 +11697,7 @@ type ServerRPC struct { AgentRegistry *ServerAgentRegistryAPI Agents *ServerAgentsAPI Instructions *ServerInstructionsAPI + LlmInference *ServerLlmInferenceAPI MCP *ServerMCPAPI Models *ServerModelsAPI Plugins *ServerPluginsAPI @@ -11429,6 +11737,7 @@ func NewServerRPC(client *jsonrpc2.Client) *ServerRPC { r.AgentRegistry = (*ServerAgentRegistryAPI)(&r.common) r.Agents = (*ServerAgentsAPI)(&r.common) r.Instructions = (*ServerInstructionsAPI)(&r.common) + r.LlmInference = (*ServerLlmInferenceAPI)(&r.common) r.MCP = (*ServerMCPAPI)(&r.common) r.Models = (*ServerModelsAPI)(&r.common) r.Plugins = (*ServerPluginsAPI)(&r.common) diff --git a/go/rpc/zrpc_encoding.go b/go/rpc/zrpc_encoding.go index 573528b5a..1e081ab7d 100644 --- a/go/rpc/zrpc_encoding.go +++ b/go/rpc/zrpc_encoding.go @@ -918,6 +918,7 @@ func unmarshalMCPServerAuthConfig(data []byte) (MCPServerAuthConfig, error) { func (r *MCPServerConfigHTTP) UnmarshalJSON(data []byte) error { type rawMCPServerConfigHTTP struct { Auth json.RawMessage `json:"auth,omitempty"` + DeferTools *MCPServerConfigDeferTools `json:"deferTools,omitempty"` FilterMapping json.RawMessage `json:"filterMapping,omitempty"` Headers map[string]string `json:"headers,omitzero"` IsDefaultServer *bool `json:"isDefaultServer,omitempty"` @@ -941,6 +942,7 @@ func (r *MCPServerConfigHTTP) UnmarshalJSON(data []byte) error { } r.Auth = value } + r.DeferTools = raw.DeferTools if raw.FilterMapping != nil { value, err := unmarshalFilterMapping(raw.FilterMapping) if err != nil { @@ -969,16 +971,17 @@ func (r *MCPServerConfigHTTP) UnmarshalJSON(data []byte) error { func (r *MCPServerConfigStdio) UnmarshalJSON(data []byte) error { type rawMCPServerConfigStdio struct { - Args []string `json:"args,omitzero"` - Auth json.RawMessage `json:"auth,omitempty"` - Command string `json:"command"` - Cwd *string `json:"cwd,omitempty"` - Env map[string]string `json:"env,omitzero"` - FilterMapping json.RawMessage `json:"filterMapping,omitempty"` - IsDefaultServer *bool `json:"isDefaultServer,omitempty"` - Oidc json.RawMessage `json:"oidc,omitempty"` - Timeout *int64 `json:"timeout,omitempty"` - Tools []string `json:"tools,omitzero"` + Args []string `json:"args,omitzero"` + Auth json.RawMessage `json:"auth,omitempty"` + Command string `json:"command"` + Cwd *string `json:"cwd,omitempty"` + DeferTools *MCPServerConfigDeferTools `json:"deferTools,omitempty"` + Env map[string]string `json:"env,omitzero"` + FilterMapping json.RawMessage `json:"filterMapping,omitempty"` + IsDefaultServer *bool `json:"isDefaultServer,omitempty"` + Oidc json.RawMessage `json:"oidc,omitempty"` + Timeout *int64 `json:"timeout,omitempty"` + Tools []string `json:"tools,omitzero"` } var raw rawMCPServerConfigStdio if err := json.Unmarshal(data, &raw); err != nil { @@ -994,6 +997,7 @@ func (r *MCPServerConfigStdio) UnmarshalJSON(data []byte) error { } r.Command = raw.Command r.Cwd = raw.Cwd + r.DeferTools = raw.DeferTools r.Env = raw.Env if raw.FilterMapping != nil { value, err := unmarshalFilterMapping(raw.FilterMapping) diff --git a/go/rpc/zsession_encoding.go b/go/rpc/zsession_encoding.go index 066fd854a..a86ca38ca 100644 --- a/go/rpc/zsession_encoding.go +++ b/go/rpc/zsession_encoding.go @@ -844,10 +844,11 @@ func (r ToolExecutionCompleteContentText) MarshalJSON() ([]byte, error) { func (r *ToolExecutionCompleteResult) UnmarshalJSON(data []byte) error { type rawToolExecutionCompleteResult struct { - Content string `json:"content"` - Contents []json.RawMessage `json:"contents,omitzero"` - DetailedContent *string `json:"detailedContent,omitempty"` - UIResource *ToolExecutionCompleteUIResource `json:"uiResource,omitempty"` + Content string `json:"content"` + Contents []json.RawMessage `json:"contents,omitzero"` + DetailedContent *string `json:"detailedContent,omitempty"` + StructuredContent any `json:"structuredContent,omitempty"` + UIResource *ToolExecutionCompleteUIResource `json:"uiResource,omitempty"` } var raw rawToolExecutionCompleteResult if err := json.Unmarshal(data, &raw); err != nil { @@ -865,6 +866,7 @@ func (r *ToolExecutionCompleteResult) UnmarshalJSON(data []byte) error { } } r.DetailedContent = raw.DetailedContent + r.StructuredContent = raw.StructuredContent r.UIResource = raw.UIResource return nil } diff --git a/go/rpc/zsession_events.go b/go/rpc/zsession_events.go index fdad4a1ad..fa04e4d2f 100644 --- a/go/rpc/zsession_events.go +++ b/go/rpc/zsession_events.go @@ -576,6 +576,8 @@ type AssistantUsageData struct { CacheReadTokens *int64 `json:"cacheReadTokens,omitempty"` // Number of tokens written to prompt cache CacheWriteTokens *int64 `json:"cacheWriteTokens,omitempty"` + // Whether the model response was blocked or truncated by content filtering (finish_reason === 'content_filter'). For Anthropic models this corresponds to a 'refusal' stop reason. + ContentFilterTriggered *bool `json:"contentFilterTriggered,omitempty"` // Per-request cost and usage data from the CAPI copilot_usage response field // Internal: CopilotUsage is part of the SDK's internal API surface and is not intended for external use. CopilotUsage *AssistantUsageCopilotUsage `json:"copilotUsage,omitempty"` @@ -584,6 +586,8 @@ type AssistantUsageData struct { Cost *float64 `json:"cost,omitempty"` // Duration of the API call in milliseconds Duration *int64 `json:"duration,omitempty"` + // Finish reason reported by the model for this API call (e.g. "stop", "length", "tool_calls", "content_filter"). Normalized to OpenAI vocabulary; for Anthropic models a "refusal" stop reason maps to "content_filter". + FinishReason *string `json:"finishReason,omitempty"` // What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls Initiator *string `json:"initiator,omitempty"` // Number of input tokens consumed @@ -2814,6 +2818,8 @@ type ToolExecutionCompleteResult struct { Contents []ToolExecutionCompleteContent `json:"contents,omitzero"` // Full detailed tool result for UI/timeline display, preserving complete content such as diffs. Falls back to content when absent. DetailedContent *string `json:"detailedContent,omitempty"` + // Structured content (arbitrary JSON) returned verbatim by the MCP tool + StructuredContent any `json:"structuredContent,omitempty"` // MCP Apps UI resource content for rendering in a sandboxed iframe UIResource *ToolExecutionCompleteUIResource `json:"uiResource,omitempty"` } diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index cd500d88b..4f816a9ae 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -16,6 +16,7 @@ "devDependencies": { "@platformatic/vfs": "^0.3.0", "@types/node": "^25.2.0", + "@types/ws": "^8.18.1", "@typescript-eslint/eslint-plugin": "^8.54.0", "@typescript-eslint/parser": "^8.54.0", "esbuild": "^0.27.2", @@ -29,7 +30,8 @@ "semver": "^7.7.3", "tsx": "^4.20.6", "typescript": "^5.0.0", - "vitest": "^4.0.18" + "vitest": "^4.0.18", + "ws": "^8.21.0" }, "engines": { "node": "^20.19.0 || >=22.12.0" @@ -1329,6 +1331,16 @@ "undici-types": "~7.18.0" } }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.56.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.56.1.tgz", @@ -3973,6 +3985,28 @@ "dev": true, "license": "MIT" }, + "node_modules/ws": { + "version": "8.21.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.21.0.tgz", + "integrity": "sha512-Vsp28b7DRcimFQvrqu2Wek3z1iYxDCWqHYB8Qsnk/S4RfaCQzPGPyBNuVjJV3cd6UiKtUtp6sNM77gWvzcCH+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, "node_modules/yaml": { "version": "2.9.0", "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.9.0.tgz", diff --git a/nodejs/package.json b/nodejs/package.json index 7d86b9620..ee2e08c3d 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -63,6 +63,7 @@ "devDependencies": { "@platformatic/vfs": "^0.3.0", "@types/node": "^25.2.0", + "@types/ws": "^8.18.1", "@typescript-eslint/eslint-plugin": "^8.54.0", "@typescript-eslint/parser": "^8.54.0", "esbuild": "^0.27.2", @@ -76,7 +77,8 @@ "semver": "^7.7.3", "tsx": "^4.20.6", "typescript": "^5.0.0", - "vitest": "^4.0.18" + "vitest": "^4.0.18", + "ws": "^8.21.0" }, "engines": { "node": "^20.19.0 || >=22.12.0" diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index c1b94b072..f1eeeaade 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -29,12 +29,14 @@ import { import { createServerRpc, createInternalServerRpc, + registerClientGlobalApiHandlers, registerClientSessionApiHandlers, } from "./generated/rpc.js"; import type { OpenCanvasInstance, SessionUpdateOptionsParams } from "./generated/rpc.js"; import { getSdkProtocolVersion } from "./sdkProtocolVersion.js"; import { CopilotSession } from "./session.js"; import { createSessionFsAdapter, type SessionFsProvider } from "./sessionFsProvider.js"; +import { createLlmInferenceAdapter } from "./llmInferenceProvider.js"; import { getTraceContext } from "./telemetry.js"; import { ToolSet } from "./toolSet.js"; import type { @@ -60,6 +62,7 @@ import type { SessionCapabilities, SessionEvent, SessionFsConfig, + LlmInferenceConfig, SessionLifecycleEvent, SessionLifecycleEventType, SessionLifecycleHandler, @@ -389,6 +392,8 @@ export class CopilotClient { private negotiatedProtocolVersion: number | null = null; /** Connection-level session filesystem config, set via constructor option. */ private sessionFsConfig: SessionFsConfig | null = null; + private llmInferenceConfig: LlmInferenceConfig | null = null; + private llmInferenceHandlers: import("./generated/rpc.js").ClientGlobalApiHandlers = {}; /** * Typed server-scoped RPC methods. @@ -500,6 +505,8 @@ export class CopilotClient { this.onListModels = options.onListModels; this.onGetTraceContext = options.onGetTraceContext; this.sessionFsConfig = options.sessionFs ?? null; + this.llmInferenceConfig = options.llmInference ?? null; + this.setupLlmInference(); const effectiveEnv = options.env ?? process.env; this.resolvedEnv = effectiveEnv; @@ -616,6 +623,27 @@ export class CopilotClient { session.clientSessionApis.sessionFs = createSessionFsAdapter(provider); } + private setupLlmInference(): void { + if (!this.llmInferenceConfig) { + return; + } + const provider = this.llmInferenceConfig.handler; + if (!provider) { + throw new Error( + "handler is required on client options.llmInference when llmInference is enabled." + ); + } + this.llmInferenceHandlers = { + llmInference: createLlmInferenceAdapter(provider, () => { + if (!this.connection) { + return undefined; + } + this._rpc ??= createServerRpc(this.connection); + return this._rpc; + }), + }; + } + /** * Starts the CLI server and establishes a connection. * @@ -663,6 +691,13 @@ export class CopilotClient { }); } + // If an LLM inference provider was configured, register it. + // The runtime will then route outbound model HTTP requests + // through the registered handler for the duration of each session. + if (this.llmInferenceConfig) { + await this.connection!.sendRequest("llmInference.setProvider", {}); + } + this.state = "connected"; } catch (error) { this.state = "error"; @@ -2327,6 +2362,11 @@ export class CopilotClient { return session.clientSessionApis; }); + // Register client *global* API handlers (e.g. LLM inference) on the + // same connection. These methods carry no implicit sessionId dispatch + // — the runtime calls into a single handler for the whole connection. + registerClientGlobalApiHandlers(this.connection, this.llmInferenceHandlers); + this.connection.onClose(() => { this.state = "disconnected"; }); diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index 1ef280abf..6643df193 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -461,6 +461,18 @@ export type InstructionSourceLocation = | "working-directory" /** Instructions live in plugin-provided configuration. */ | "plugin"; +/** + * Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestStartTransport". + */ +/** @experimental */ +export type LlmInferenceHttpRequestStartTransport = + /** Plain HTTP or SSE response. Each body chunk is an opaque byte range; the response is a status line, headers, and a (possibly streamed) body. */ + | "http" + /** Full-duplex WebSocket channel. Each body chunk maps to exactly one WebSocket message and the `binary` flag distinguishes text from binary frames; request and response chunks flow concurrently. */ + | "websocket"; /** * Repository host type * @@ -609,6 +621,17 @@ export type McpServerConfig = McpServerConfigStdio | McpServerConfigHttp; * via the `definition` "McpServerAuthConfig". */ export type McpServerAuthConfig = boolean | McpServerAuthConfigRedirectPort; +/** + * Controls if tools provided by this server can be loaded on demand via tool search (auto) or always included in the initial tool list (never) + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "McpServerConfigDeferTools". + */ +export type McpServerConfigDeferTools = + /** Tools may be deferred under certain conditions */ + | "auto" + /** Tools are always included in the initial tool list, even when tool search is enabled. */ + | "never"; /** * Remote transport type. Defaults to "http" when omitted. * @@ -4121,6 +4144,204 @@ export interface InstructionSource { */ projectPath?: string; } +/** + * HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHeaders". + */ +/** @experimental */ +export interface LlmInferenceHeaders { + [k: string]: string[] | undefined; +} +/** + * A request body chunk or cancellation signal. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestChunkRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestChunkRequest { + /** + * Matches the requestId from the originating httpRequestStart frame. + */ + requestId: string; + /** + * Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty. + */ + data: string; + /** + * When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + */ + binary?: boolean; + /** + * When true, this is the final body chunk for the request. The SDK may rely on having received an end-marked chunk before treating the request body as complete. + */ + end?: boolean; + /** + * When true, the runtime is cancelling the in-flight request (e.g. upstream consumer aborted). `data` is ignored. Implies end-of-request. + */ + cancel?: boolean; + /** + * Optional human-readable reason for the cancellation, propagated for logging. + */ + cancelReason?: string; +} +/** + * Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestChunkResult". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestChunkResult {} +/** + * The head of an outbound model-layer HTTP request. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestStartRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestStartRequest { + /** + * Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies back to the runtime. + */ + requestId: string; + /** + * Id of the runtime session that triggered this request, when one is in scope. Absent for requests issued outside any session (e.g. startup model-catalog or capability resolution). This is a payload field — not a dispatch key — because the client-global API is registered process-wide rather than per session. + */ + sessionId?: string; + /** + * HTTP method, e.g. GET, POST. + */ + method: string; + /** + * Absolute request URL. + */ + url: string; + headers: LlmInferenceHeaders; + transport?: LlmInferenceHttpRequestStartTransport; +} +/** + * Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpRequestStartResult". + */ +/** @experimental */ +export interface LlmInferenceHttpRequestStartResult {} +/** + * Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseChunkError". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseChunkError { + /** + * Human-readable failure description. + */ + message: string; + /** + * Optional machine-readable error code. + */ + code?: string; +} +/** + * A response body chunk or terminal error. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseChunkRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseChunkRequest { + /** + * Matches the requestId from the originating httpRequestStart frame. + */ + requestId: string; + /** + * Body byte range. UTF-8 text when `binary` is absent or false; base64-encoded bytes when `binary` is true. May be empty (e.g. when the response body is empty: send a single chunk with empty data and end=true). + */ + data: string; + /** + * When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + */ + binary?: boolean; + /** + * When true, this is the final body chunk for the response. The runtime treats the response body as complete after receiving an end-marked chunk. + */ + end?: boolean; + error?: LlmInferenceHttpResponseChunkError; +} +/** + * Whether the chunk was accepted. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseChunkResult". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseChunkResult { + /** + * True when the chunk was matched to a pending request; false when unknown. + */ + accepted: boolean; +} +/** + * Response head. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseStartRequest". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseStartRequest { + /** + * Matches the requestId from the originating httpRequestStart frame. + */ + requestId: string; + /** + * HTTP status code. + */ + status: number; + /** + * Optional HTTP status reason phrase. + */ + statusText?: string; + headers: LlmInferenceHeaders; +} +/** + * Whether the start frame was accepted. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceHttpResponseStartResult". + */ +/** @experimental */ +export interface LlmInferenceHttpResponseStartResult { + /** + * True when the response start was matched to a pending request; false when unknown. + */ + accepted: boolean; +} +/** + * No parameters. The calling connection is registered as the runtime's LLM inference provider; all subsequent model-layer HTTP requests are dispatched back to it via the llmInference client API. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceSetProviderRequest". + */ +/** @experimental */ +export interface LlmInferenceSetProviderRequest {} +/** + * Indicates whether the calling client was registered as the LLM inference provider. + * + * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema + * via the `definition` "LlmInferenceSetProviderResult". + */ +/** @experimental */ +export interface LlmInferenceSetProviderResult { + /** + * Whether the provider was set successfully + */ + success: boolean; +} /** * Schema for the `LocalSessionMetadataValue` type. * @@ -4731,6 +4952,7 @@ export interface McpServerConfigStdio { timeout?: number; oidc?: McpServerAuthConfig; auth?: McpServerAuthConfig; + deferTools?: McpServerConfigDeferTools; /** * Executable command used to start the Stdio MCP server process. */ @@ -4786,6 +5008,7 @@ export interface McpServerConfigHttp { timeout?: number; oidc?: McpServerAuthConfig; auth?: McpServerAuthConfig; + deferTools?: McpServerConfigDeferTools; /** * URL of the remote MCP server endpoint. */ @@ -13196,6 +13419,34 @@ export function createServerRpc(connection: MessageConnection) { connection.sendRequest("sessionFs.setProvider", params), }, /** @experimental */ + llmInference: { + /** + * Registers an SDK client as the LLM inference callback provider. + * + * @returns Indicates whether the calling client was registered as the LLM inference provider. + */ + setProvider: async (): Promise => + connection.sendRequest("llmInference.setProvider", {}), + /** + * Delivers the response head (status + headers) for an in-flight request, correlated by the requestId the runtime supplied in httpRequestStart. Must be called exactly once per request before any httpResponseChunk frames. + * + * @param params Response head. + * + * @returns Whether the start frame was accepted. + */ + httpResponseStart: async (params: LlmInferenceHttpResponseStartRequest): Promise => + connection.sendRequest("llmInference.httpResponseStart", params), + /** + * Delivers a body byte range (or a terminal transport error) for an in-flight response, correlated by requestId. Set `end` true on the last chunk. When `error` is set the response terminates with a transport-level failure and the runtime raises an APIConnectionError. + * + * @param params A response body chunk or terminal error. + * + * @returns Whether the chunk was accepted. + */ + httpResponseChunk: async (params: LlmInferenceHttpResponseChunkRequest): Promise => + connection.sendRequest("llmInference.httpResponseChunk", params), + }, + /** @experimental */ sessions: { /** * Creates or resumes a local session and returns the opened session ID. @@ -15160,3 +15411,52 @@ export function registerClientSessionApiHandlers( return handler.invoke(params); }); } + +/** Handler for `llmInference` client global API methods. */ +/** @experimental */ +export interface LlmInferenceHandler { + /** + * Announces an outbound model-layer HTTP request the runtime wants the SDK client to service. Carries the request head only; the body always follows as one or more httpRequestChunk frames keyed by the same requestId, even when the body is empty (a single chunk with end=true). + * + * @param params The head of an outbound model-layer HTTP request. + * + * @returns Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. + */ + httpRequestStart(params: LlmInferenceHttpRequestStartRequest): Promise; + /** + * Delivers a body byte range (or a cancellation signal) for a request previously announced via httpRequestStart, correlated by requestId. The runtime fires at least one chunk per request — when there is no body, a single chunk with empty data and end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; the SDK then stops issuing httpResponseChunk frames and may emit a terminal httpResponseChunk with error set. + * + * @param params A request body chunk or cancellation signal. + * + * @returns Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. + */ + httpRequestChunk(params: LlmInferenceHttpRequestChunkRequest): Promise; +} + +/** All client global API handler groups. */ +export interface ClientGlobalApiHandlers { + llmInference?: LlmInferenceHandler; +} + +/** + * Register client global API handlers on a JSON-RPC connection. + * The server calls these methods to delegate work to the client. + * Unlike session-scoped client APIs, these methods carry no implicit + * `sessionId` dispatch key — a single set of handlers serves the entire + * connection. + */ +export function registerClientGlobalApiHandlers( + connection: MessageConnection, + handlers: ClientGlobalApiHandlers, +): void { + connection.onRequest("llmInference.httpRequestStart", async (params: LlmInferenceHttpRequestStartRequest) => { + const handler = handlers.llmInference; + if (!handler) throw new Error("No llmInference client-global handler registered"); + return handler.httpRequestStart(params); + }); + connection.onRequest("llmInference.httpRequestChunk", async (params: LlmInferenceHttpRequestChunkRequest) => { + const handler = handlers.llmInference; + if (!handler) throw new Error("No llmInference client-global handler registered"); + return handler.httpRequestChunk(params); + }); +} diff --git a/nodejs/src/generated/session-events.ts b/nodejs/src/generated/session-events.ts index b17901504..3230e4c5a 100644 --- a/nodejs/src/generated/session-events.ts +++ b/nodejs/src/generated/session-events.ts @@ -2989,6 +2989,10 @@ export interface AssistantUsageData { * Number of tokens written to prompt cache */ cacheWriteTokens?: number; + /** + * Whether the model response was blocked or truncated by content filtering (finish_reason === 'content_filter'). For Anthropic models this corresponds to a 'refusal' stop reason. + */ + contentFilterTriggered?: boolean; /** * Per-request cost and usage data from the CAPI copilot_usage response field * @@ -3005,6 +3009,10 @@ export interface AssistantUsageData { * Duration of the API call in milliseconds */ duration?: number; + /** + * Finish reason reported by the model for this API call (e.g. "stop", "length", "tool_calls", "content_filter"). Normalized to OpenAI vocabulary; for Anthropic models a "refusal" stop reason maps to "content_filter". + */ + finishReason?: string; /** * What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls */ @@ -3601,6 +3609,12 @@ export interface ToolExecutionCompleteResult { * Full detailed tool result for UI/timeline display, preserving complete content such as diffs. Falls back to content when absent. */ detailedContent?: string; + /** + * Structured content (arbitrary JSON) returned verbatim by the MCP tool + */ + structuredContent?: { + [k: string]: unknown | undefined; + }; uiResource?: ToolExecutionCompleteUIResource; } /** diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 007d808fa..855f5ca1e 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -28,6 +28,10 @@ export { approveAll, convertMcpCallToolResult, createSessionFsAdapter, + CopilotWebSocketHandler, + ForwardingWebSocketHandler, + LlmRequestHandler, + LlmWebSocketCloseStatus, SYSTEM_MESSAGE_SECTIONS, } from "./types.js"; // Re-export the generated session-event types (every *Event interface and @@ -121,6 +125,11 @@ export type { SessionFsSqliteQueryResult, SessionFsSqliteQueryType, SessionFsSqliteProvider, + LlmInferenceConfig, + LlmInferenceRequest, + LlmInferenceResponseInit, + LlmInferenceResponseSink, + LlmRequestContext, SystemMessageAppendConfig, SystemMessageConfig, SystemMessageCustomizeConfig, diff --git a/nodejs/src/llmInferenceProvider.ts b/nodejs/src/llmInferenceProvider.ts new file mode 100644 index 000000000..4e43900b2 --- /dev/null +++ b/nodejs/src/llmInferenceProvider.ts @@ -0,0 +1,437 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import type { + LlmInferenceHandler, + LlmInferenceHeaders, + LlmInferenceHttpRequestChunkRequest, + LlmInferenceHttpRequestChunkResult, + LlmInferenceHttpRequestStartRequest, + LlmInferenceHttpRequestStartResult, +} from "./generated/rpc.js"; +import type { createServerRpc } from "./generated/rpc.js"; + +type ServerRpc = ReturnType; + +/** + * An outbound model-layer HTTP request the runtime is asking the SDK + * consumer to handle on its behalf. + * + * This is a low-level shape: URL / method / headers verbatim, body bytes + * delivered as an async iterable, response delivered through the + * {@link LlmInferenceResponseSink}. The runtime does not classify the + * request (no provider type, endpoint kind, wire API). Consumers that + * need that information derive it themselves from the URL / headers. + */ +export interface LlmInferenceRequest { + /** Opaque runtime-minted id, stable across the request lifecycle. */ + requestId: string; + /** + * Id of the runtime session that triggered this request, when one is + * in scope. Absent for out-of-session requests (e.g. startup model + * catalog). + */ + sessionId?: string; + /** HTTP method (`GET`, `POST`, ...). */ + method: string; + /** Absolute URL. */ + url: string; + /** HTTP request headers, multi-valued. */ + headers: LlmInferenceHeaders; + /** + * Transport the runtime would otherwise use for this request. + * `"http"` (the default) covers plain HTTP and SSE responses; + * `"websocket"` indicates a full-duplex message channel where each + * {@link requestBody} chunk is one inbound WebSocket message and each + * {@link responseBody} write is one outbound message. Consumers branch + * on this to decide whether to service the request with an HTTP client + * or a WebSocket client. + */ + transport: "http" | "websocket"; + /** + * Request body bytes, yielded as they arrive from the runtime. + * Always iterable; an empty body yields zero chunks before completing. + */ + requestBody: AsyncIterable; + /** + * Aborts when the runtime cancels this in-flight request (e.g. the + * agent turn was aborted upstream). Pass it straight to `fetch` / + * `HttpClient.SendAsync` / your transport so the upstream call is torn + * down too. After it fires, writes to {@link responseBody} are ignored. + */ + signal: AbortSignal; + /** + * Sink the consumer writes the upstream response into. Call + * {@link LlmInferenceResponseSink.start} exactly once before writing + * body chunks, then one or more {@link LlmInferenceResponseSink.write} + * calls, and finish with {@link LlmInferenceResponseSink.end} or + * {@link LlmInferenceResponseSink.error}. + */ + responseBody: LlmInferenceResponseSink; +} + +/** Response head passed to {@link LlmInferenceResponseSink.start}. */ +export interface LlmInferenceResponseInit { + status: number; + statusText?: string; + headers?: LlmInferenceHeaders; +} + +/** + * Sink the consumer writes the upstream response into. The state machine + * is strict: `start` once → 0..N `write` → exactly one of `end` or + * `error`. Calling out of order throws. + */ +export interface LlmInferenceResponseSink { + /** Send the response head (status + headers) back to the runtime. */ + start(init: LlmInferenceResponseInit): Promise; + /** + * Send a body chunk. `string` is encoded as UTF-8; `Uint8Array` is sent + * as binary (base64 on the wire). + */ + write(data: string | Uint8Array): Promise; + /** Mark end-of-stream cleanly. */ + end(): Promise; + /** Mark end-of-stream with a transport-level failure. */ + error(error: { message: string; code?: string }): Promise; +} + +/** + * Interface for an LLM inference provider. The SDK consumer implements + * `onLlmRequest`. The same callback handles both buffered and streaming + * responses — the consumer just calls `responseBody.write` zero or more + * times before `end`. + * + * Use {@link createLlmInferenceAdapter} to convert an + * {@link LlmInferenceProvider} into the {@link LlmInferenceHandler} the + * SDK's RPC layer registers. + */ +export interface LlmInferenceProvider { + /** + * Called by the runtime once per outbound LLM HTTP request the + * consumer has opted to handle. The consumer is responsible for + * eventually calling either `responseBody.end()` or + * `responseBody.error(...)`; failing to do so leaks runtime state. + * Throwing surfaces a transport-level failure to the runtime + * (equivalent to `responseBody.error({ message: err.message })` + * provided `start` has not yet been called). + */ + onLlmRequest(request: LlmInferenceRequest): Promise | void; +} + +interface BodyQueueItem { + chunk?: Uint8Array; + end?: boolean; + cancel?: { reason?: string }; +} + +interface BodyQueue { + push(item: BodyQueueItem): void; + iterable: AsyncIterable; +} + +function makeBodyQueue(): BodyQueue { + const buffer: BodyQueueItem[] = []; + let waker: (() => void) | null = null; + let done = false; + const wake = (): void => { + const w = waker; + waker = null; + w?.(); + }; + return { + push(item: BodyQueueItem): void { + buffer.push(item); + wake(); + }, + iterable: { + [Symbol.asyncIterator](): AsyncIterator { + return { + async next(): Promise> { + if (done) { + return { value: undefined, done: true }; + } + while (buffer.length === 0) { + await new Promise((resolve) => { + waker = resolve; + }); + } + const item = buffer.shift()!; + if (item.cancel) { + done = true; + const reason = item.cancel.reason + ? `Request cancelled by runtime: ${item.cancel.reason}` + : "Request cancelled by runtime"; + throw new Error(reason); + } + if (item.end) { + done = true; + return { value: undefined, done: true }; + } + return { value: item.chunk ?? new Uint8Array(), done: false }; + }, + }; + }, + }, + }; +} + +const sharedTextEncoder = new TextEncoder(); + +function decodeChunkData(data: string, binary: boolean): Uint8Array { + if (binary) { + return new Uint8Array(Buffer.from(data, "base64")); + } + return sharedTextEncoder.encode(data); +} + +interface PendingState { + queue: BodyQueue; + started: boolean; + finished: boolean; + abort: AbortController; + cancelled: boolean; +} + +/** + * Adapt an {@link LlmInferenceProvider} into the generated + * {@link LlmInferenceHandler} shape consumed by the SDK's RPC dispatcher. + * + * Maintains a per-`requestId` state table: each `httpRequestStart` + * allocates a body queue + response sink and fires + * `provider.onLlmRequest` in the background. Subsequent `httpRequestChunk` + * frames are routed into the queue. The sink translates `start` / + * `write` / `end` / `error` calls into outbound + * `serverRpc.llmInference.httpResponseStart` / `httpResponseChunk` calls. + * + * The handler returns from `httpRequestStart` immediately (synchronously + * after registering state) so the runtime's RPC reply is not gated on the + * consumer's I/O. The actual provider work runs asynchronously. + */ +export function createLlmInferenceAdapter( + provider: LlmInferenceProvider, + getServerRpc: () => ServerRpc | undefined +): LlmInferenceHandler { + const pending = new Map(); + // Defense-in-depth backstop: chunks that arrive before their `start` + // frame (a reordering the runtime's single ordered dispatch should make + // impossible) are staged here keyed by requestId and drained the moment + // `httpRequestStart` registers the matching state, so a body byte is + // never silently dropped. + const staged = new Map(); + + function routeChunk(state: PendingState, params: LlmInferenceHttpRequestChunkRequest): void { + if (params.cancel) { + state.cancelled = true; + state.abort.abort(); + state.queue.push({ cancel: { reason: params.cancelReason } }); + return; + } + if (params.data && params.data.length > 0) { + state.queue.push({ chunk: decodeChunkData(params.data, !!params.binary) }); + } + if (params.end) { + state.queue.push({ end: true }); + } + } + + function makeSink(requestId: string, state: PendingState): LlmInferenceResponseSink { + const rpc = (): ServerRpc => { + const r = getServerRpc(); + if (!r) { + throw new Error("LLM inference response sink used after RPC connection closed."); + } + return r; + }; + // The runtime acknowledges every response frame with `accepted`. + // `accepted: false` means it has dropped the request (e.g. it + // cancelled), so we abort the provider's upstream work and stop + // emitting — there is no consumer for further frames. + const rejectedByRuntime = (): never => { + if (!state.cancelled) { + state.cancelled = true; + state.abort.abort(); + } + state.finished = true; + pending.delete(requestId); + throw new Error( + "LLM inference response was rejected by the runtime (request no longer active)." + ); + }; + return { + async start(init: LlmInferenceResponseInit): Promise { + if (state.started) { + throw new Error("LLM inference response sink.start() called twice."); + } + if (state.finished) { + throw new Error("LLM inference response sink already finished."); + } + state.started = true; + const result = await rpc().llmInference.httpResponseStart({ + requestId, + status: init.status, + statusText: init.statusText, + headers: init.headers ?? {}, + }); + if (!result.accepted) { + rejectedByRuntime(); + } + }, + async write(data: string | Uint8Array): Promise { + if (state.cancelled) { + throw new Error("LLM inference request was cancelled by the runtime."); + } + if (!state.started) { + throw new Error("LLM inference response sink.write() called before start()."); + } + if (state.finished) { + throw new Error( + "LLM inference response sink.write() called after end()/error()." + ); + } + const isString = typeof data === "string"; + const result = await rpc().llmInference.httpResponseChunk({ + requestId, + data: isString ? data : Buffer.from(data).toString("base64"), + binary: !isString, + end: false, + }); + if (!result.accepted) { + rejectedByRuntime(); + } + }, + async end(): Promise { + if (state.finished) { + return; + } + state.finished = true; + pending.delete(requestId); + await rpc().llmInference.httpResponseChunk({ + requestId, + data: "", + end: true, + }); + }, + async error(err: { message: string; code?: string }): Promise { + if (state.finished) { + return; + } + state.finished = true; + pending.delete(requestId); + await rpc().llmInference.httpResponseChunk({ + requestId, + data: "", + end: true, + error: { message: err.message, code: err.code }, + }); + }, + }; + } + + async function failViaSink( + sink: LlmInferenceResponseSink, + state: PendingState, + message: string + ): Promise { + if (state.finished) { + return; + } + try { + if (!state.started) { + await sink.start({ status: 502, headers: {} }); + } + await sink.error({ message }); + } catch { + // Best-effort — the connection may already be dead. + } + } + + async function finishCancelled( + sink: LlmInferenceResponseSink, + state: PendingState + ): Promise { + if (state.finished) { + return; + } + try { + if (!state.started) { + await sink.start({ status: 499, headers: {} }); + } + await sink.error({ message: "Request cancelled by runtime", code: "cancelled" }); + } catch { + // Best-effort — the runtime already dropped the request on cancel. + } + } + + return { + async httpRequestStart( + params: LlmInferenceHttpRequestStartRequest + ): Promise { + const state: PendingState = { + queue: makeBodyQueue(), + started: false, + finished: false, + abort: new AbortController(), + cancelled: false, + }; + pending.set(params.requestId, state); + const stagedChunks = staged.get(params.requestId); + if (stagedChunks) { + staged.delete(params.requestId); + for (const chunk of stagedChunks) { + routeChunk(state, chunk); + } + } + const sink = makeSink(params.requestId, state); + const request: LlmInferenceRequest = { + requestId: params.requestId, + sessionId: params.sessionId, + method: params.method, + url: params.url, + headers: params.headers, + transport: params.transport ?? "http", + requestBody: state.queue.iterable, + signal: state.abort.signal, + responseBody: sink, + }; + void (async () => { + try { + await provider.onLlmRequest(request); + if (!state.finished) { + await failViaSink( + sink, + state, + "LLM inference provider returned without finalising the response (call responseBody.end() or .error())." + ); + } + } catch (err) { + if (state.cancelled || state.abort.signal.aborted) { + // The runtime already cancelled this request; the + // provider's throw is just the abort propagating + // out of its upstream call. Acknowledge with a + // terminal cancelled error if we still can. + await finishCancelled(sink, state); + return; + } + const message = err instanceof Error ? err.message : String(err); + await failViaSink(sink, state, message); + } + })(); + return {}; + }, + async httpRequestChunk( + params: LlmInferenceHttpRequestChunkRequest + ): Promise { + const state = pending.get(params.requestId); + if (!state) { + const buffered = staged.get(params.requestId) ?? []; + buffered.push(params); + staged.set(params.requestId, buffered); + return {}; + } + routeChunk(state, params); + return {}; + }, + }; +} diff --git a/nodejs/src/llmRequestHandler.ts b/nodejs/src/llmRequestHandler.ts new file mode 100644 index 000000000..1640183b3 --- /dev/null +++ b/nodejs/src/llmRequestHandler.ts @@ -0,0 +1,469 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import type { LlmInferenceHeaders } from "./generated/rpc.js"; +import type { LlmInferenceProvider, LlmInferenceRequest, LlmInferenceResponseSink } from "./llmInferenceProvider.js"; + +const sharedTextDecoder = new TextDecoder("utf-8", { fatal: false }); +const kBridge = Symbol("llmWebSocketResponseBridge"); +const kCompletion = Symbol("llmWebSocketCompletion"); +const kOpen = Symbol("llmWebSocketOpen"); +const kSuppressCloseOnDispose = Symbol("llmWebSocketSuppressCloseOnDispose"); + +type InternalContext = LlmRequestContext & { [kBridge]: LlmWebSocketResponseBridge }; + +/** + * Per-request context handed to every {@link LlmRequestHandler} hook. + * + * @experimental + */ +export interface LlmRequestContext { + readonly requestId: string; + readonly sessionId?: string; + readonly transport: "http" | "websocket"; + readonly url: string; + readonly headers: LlmInferenceHeaders; + readonly signal: AbortSignal; +} + +/** + * Terminal status for a callback-owned WebSocket connection. + * + * @experimental + */ +export class LlmWebSocketCloseStatus { + static readonly normalClosure = new LlmWebSocketCloseStatus(); + + constructor( + readonly description?: string, + readonly errorCode?: string, + readonly error?: Error + ) {} +} + +/** + * Per-connection WebSocket handler returned by {@link LlmRequestHandler.openWebSocket}. + * + * @experimental + */ +export abstract class CopilotWebSocketHandler implements AsyncDisposable { + readonly #response: LlmWebSocketResponseBridge; + readonly #completion: Promise; + #resolveCompletion!: (status: LlmWebSocketCloseStatus) => void; + #closed = false; + [kSuppressCloseOnDispose] = false; + + protected readonly context: LlmRequestContext; + + protected constructor(context: LlmRequestContext) { + this.context = context; + const bridge = (context as Partial)[kBridge]; + if (!bridge) { + throw new Error("WebSocket response bridge is not attached"); + } + this.#response = bridge; + this.#completion = new Promise((resolve) => { + this.#resolveCompletion = resolve; + }); + } + + async sendResponseMessage(data: string | Uint8Array): Promise { + await this.#response.write(data); + } + + async close(status: LlmWebSocketCloseStatus = LlmWebSocketCloseStatus.normalClosure): Promise { + if (this.#closed) { + return; + } + this.#closed = true; + if (status.error) { + await this.#response.error({ + message: status.description ?? status.error.message, + code: status.errorCode, + }); + } else { + await this.#response.end(); + } + this.#resolveCompletion(status); + } + + abstract sendRequestMessage(data: string | Uint8Array): Promise | void; + + async [Symbol.asyncDispose](): Promise { + if (!this[kSuppressCloseOnDispose] && !this.#closed) { + await this.close(LlmWebSocketCloseStatus.normalClosure); + } + } + + /** @internal */ + get [kCompletion](): Promise { + return this.#completion; + } + + /** @internal */ + async [kOpen](): Promise {} +} + +/** + * Default pass-through WebSocket handler backed by the WHATWG `WebSocket`. + * + * @experimental + */ +export class ForwardingWebSocketHandler extends CopilotWebSocketHandler { + readonly #url: string; + #upstream: WebSocket | null = null; + + constructor(context: LlmRequestContext, url = context.url) { + super(context); + this.#url = url; + } + + override sendRequestMessage(data: string | Uint8Array): void { + if (this.#upstream?.readyState !== WebSocket.OPEN) { + return; + } + this.#upstream.send(data); + } + + /** @internal */ + override async [kOpen](): Promise { + if (this.#upstream) { + return; + } + const upstream = new WebSocket(this.#url); + upstream.binaryType = "arraybuffer"; + this.#upstream = upstream; + upstream.addEventListener("message", (event) => { + void this.sendResponseMessage(normalizeWsData(event.data)).catch(async (err: unknown) => { + await this.close( + new LlmWebSocketCloseStatus( + err instanceof Error ? err.message : String(err), + undefined, + err instanceof Error ? err : new Error(String(err)) + ) + ); + }); + }); + upstream.addEventListener("close", () => { + void this.close(LlmWebSocketCloseStatus.normalClosure); + }); + upstream.addEventListener("error", () => { + void this.close(new LlmWebSocketCloseStatus("WebSocket error", undefined, new Error("WebSocket error"))); + }); + await new Promise((resolve, reject) => { + if (upstream.readyState === WebSocket.OPEN) { + resolve(); + return; + } + upstream.addEventListener("open", () => resolve(), { once: true }); + upstream.addEventListener("error", () => reject(new Error("WebSocket error")), { once: true }); + }); + } + + override async close( + status: LlmWebSocketCloseStatus = LlmWebSocketCloseStatus.normalClosure + ): Promise { + try { + if ( + this.#upstream?.readyState === WebSocket.OPEN || + this.#upstream?.readyState === WebSocket.CONNECTING + ) { + this.#upstream?.close(); + } + } catch { + // Best-effort; the socket may already be closed. + } + await super.close(status); + } + + override async [Symbol.asyncDispose](): Promise { + try { + await super[Symbol.asyncDispose](); + } finally { + try { + this.#upstream?.close(); + } catch { + // Best-effort. + } + } + } +} + +/** + * Base class for SDK consumers who want to observe or mutate the LLM + * inference requests the runtime issues. + * + * @experimental + */ +export class LlmRequestHandler implements LlmInferenceProvider { + async onLlmRequest(req: LlmInferenceRequest): Promise { + const bridge = new LlmWebSocketResponseBridge(req.responseBody); + const ctx: InternalContext = { + requestId: req.requestId, + sessionId: req.sessionId, + transport: req.transport, + url: req.url, + headers: req.headers, + signal: req.signal, + [kBridge]: bridge, + }; + + if (req.transport === "websocket") { + await this.#handleWebSocket(req, ctx); + } else { + await this.#handleHttp(req, ctx); + } + } + + protected sendRequest(request: Request, ctx: LlmRequestContext): Promise { + return fetch(request, { signal: ctx.signal }); + } + + protected openWebSocket(ctx: LlmRequestContext): Promise { + return Promise.resolve(new ForwardingWebSocketHandler(ctx)); + } + + async #handleHttp(req: LlmInferenceRequest, ctx: LlmRequestContext): Promise { + const request = await buildFetchRequest(req); + const response = await this.sendRequest(request, ctx); + await streamResponseToSink(response, req); + } + + async #handleWebSocket(req: LlmInferenceRequest, ctx: InternalContext): Promise { + const handler = await this.openWebSocket(ctx); + try { + await handler[kOpen](); + await ctx[kBridge].start(); + + let cancelled: unknown; + const clientSettled = (async () => { + for await (const chunk of req.requestBody) { + await handler.sendRequestMessage(decodeFrame(chunk)); + } + return "client-complete" as const; + })().catch((err) => { + cancelled = err; + return "client-error" as const; + }); + + const first = await Promise.race([ + clientSettled, + handler[kCompletion].then(() => "server-done" as const), + ]); + + if (first === "client-error") { + handler[kSuppressCloseOnDispose] = true; + throw cancelled instanceof Error ? cancelled : new Error(String(cancelled)); + } + + if (first === "client-complete") { + await handler.close(LlmWebSocketCloseStatus.normalClosure); + await handler[kCompletion]; + return; + } + + const status = await handler[kCompletion]; + if (status.error) { + throw status.error; + } + } finally { + await handler[Symbol.asyncDispose](); + } + } +} + +const FORBIDDEN_REQUEST_HEADERS = new Set([ + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", +]); + +async function buildFetchRequest(req: LlmInferenceRequest): Promise { + const headers = new Headers(); + for (const [name, values] of Object.entries(req.headers)) { + if (!values) { + continue; + } + if (FORBIDDEN_REQUEST_HEADERS.has(name.toLowerCase())) { + continue; + } + for (const value of values) { + headers.append(name, value); + } + } + + const method = req.method.toUpperCase(); + const hasBody = method !== "GET" && method !== "HEAD"; + + let body: Uint8Array | undefined; + if (hasBody) { + const buffered = await drainAsync(req.requestBody); + if (buffered.length > 0) { + body = buffered; + } + } else { + await drainAsync(req.requestBody); + } + + return new Request(req.url, { method, headers, body }); +} + +async function drainAsync(stream: AsyncIterable): Promise { + const parts: Uint8Array[] = []; + let total = 0; + for await (const chunk of stream) { + parts.push(chunk); + total += chunk.byteLength; + } + if (parts.length === 0) { + return new Uint8Array(0); + } + if (parts.length === 1) { + return parts[0]; + } + const out = new Uint8Array(total); + let off = 0; + for (const part of parts) { + out.set(part, off); + off += part.byteLength; + } + return out; +} + +async function streamResponseToSink(response: Response, req: LlmInferenceRequest): Promise { + const headers = headersToMultiMap(response.headers); + await req.responseBody.start({ + status: response.status, + statusText: response.statusText || undefined, + headers, + }); + + const body = response.body; + if (!body) { + await req.responseBody.end(); + return; + } + + const reader = body.getReader(); + try { + for (;;) { + const { value, done } = await reader.read(); + if (done) { + break; + } + if (value && value.byteLength > 0) { + await req.responseBody.write(value); + } + } + await req.responseBody.end(); + } finally { + reader.releaseLock(); + } +} + +function headersToMultiMap(headers: Headers): LlmInferenceHeaders { + const out: Record = {}; + headers.forEach((value, name) => { + if (name.toLowerCase() === "set-cookie") { + return; + } + const list = out[name] ?? (out[name] = []); + list.push(value); + }); + const setCookies = headers.getSetCookie(); + if (setCookies.length > 0) { + out["set-cookie"] = setCookies; + } + return out; +} + +function decodeFrame(chunk: Uint8Array): string { + return sharedTextDecoder.decode(chunk); +} + +function normalizeWsData(data: unknown): string | Uint8Array { + if (typeof data === "string") { + return data; + } + if (data instanceof Uint8Array) { + return data; + } + if (data instanceof ArrayBuffer) { + return new Uint8Array(data); + } + return new Uint8Array(); +} + +class LlmWebSocketResponseBridge { + readonly #sink: LlmInferenceResponseSink; + readonly #pending: Array<() => Promise> = []; + #started = false; + #completed = false; + #serial: Promise = Promise.resolve(); + + constructor(sink: LlmInferenceResponseSink) { + this.#sink = sink; + } + + async start(): Promise { + await this.#enqueue(async () => { + if (this.#started) { + return; + } + this.#started = true; + await this.#sink.start({ status: 101, headers: {} }); + while (this.#pending.length > 0) { + await this.#pending.shift()!(); + } + }); + } + + async write(data: string | Uint8Array): Promise { + await this.#enqueueOrBuffer(async () => { + if (!this.#completed) { + await this.#sink.write(data); + } + }); + } + + async end(): Promise { + await this.#enqueueOrBuffer(async () => { + if (this.#completed) { + return; + } + this.#completed = true; + await this.#sink.end(); + }); + } + + async error(error: { message: string; code?: string }): Promise { + await this.#enqueueOrBuffer(async () => { + if (this.#completed) { + return; + } + this.#completed = true; + await this.#sink.error(error); + }); + } + + async #enqueueOrBuffer(action: () => Promise): Promise { + if (!this.#started) { + this.#pending.push(action); + return; + } + await this.#enqueue(action); + } + + async #enqueue(action: () => Promise): Promise { + const run = this.#serial.then(action, action); + this.#serial = run.catch(() => {}); + await run; + } +} diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index ac9fb829b..9ed0f61c8 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -9,6 +9,7 @@ // Import and re-export generated session event types import type { Canvas } from "./canvas.js"; import type { SessionFsProvider } from "./sessionFsProvider.js"; +import type { LlmRequestHandler } from "./llmRequestHandler.js"; import type { ReasoningSummary, SessionEvent as GeneratedSessionEvent, @@ -26,6 +27,19 @@ export type { SessionFsFileInfo } from "./sessionFsProvider.js"; export type { SessionFsSqliteQueryResult } from "./sessionFsProvider.js"; export type { SessionFsSqliteQueryType } from "./sessionFsProvider.js"; export type { SessionFsSqliteProvider } from "./sessionFsProvider.js"; +export type { + LlmInferenceRequest, + LlmInferenceResponseInit, + LlmInferenceResponseSink, +} from "./llmInferenceProvider.js"; +export type { LlmInferenceHeaders } from "./generated/rpc.js"; +export type { LlmRequestContext } from "./llmRequestHandler.js"; +export { + CopilotWebSocketHandler, + ForwardingWebSocketHandler, + LlmRequestHandler, + LlmWebSocketCloseStatus, +} from "./llmRequestHandler.js"; /** * Options for creating a CopilotClient @@ -296,6 +310,27 @@ export interface CopilotClientOptions { */ sessionFs?: SessionFsConfig; + /** + * Custom LLM inference callback provider (experimental). + * + * When provided, the client registers as the runtime's LLM inference + * provider on connection: every outbound model-layer request the runtime + * would otherwise have issued itself — plain HTTP, streaming SSE, and + * WebSocket — is dispatched back to the callback over JSON-RPC. The + * callback returns the response verbatim, exactly as if the runtime had + * issued the request itself. + * + * v1 notes: + * - HTTP (buffered and streaming SSE) and WebSocket transports are all + * intercepted. The callback receives a `transport` discriminator and a + * symmetric request-body stream / response-body sink for both. + * - The callback is set process-globally on the runtime; the same + * provider is invoked for every session created on this client. + * + * @experimental + */ + llmInference?: LlmInferenceConfig; + /** * Server-wide idle timeout for sessions in seconds. * Sessions without activity for this duration are automatically cleaned up. @@ -2305,6 +2340,28 @@ export interface SessionFsConfig { }; } +/** + * Configuration for a custom LLM inference callback provider + * (experimental). + * + * @experimental + */ +export interface LlmInferenceConfig { + /** + * The handler that services LLM inference requests. The runtime routes + * all outbound model HTTP and WebSocket requests through this handler + * for the lifetime of the client, regardless of which session triggered + * them. + * + * Subclass {@link LlmRequestHandler} and override the hooks you need; + * an instance that overrides nothing is a transparent pass-through. + * + * Per-request session correlation is available on + * {@link LlmInferenceRequest.sessionId}. + */ + handler?: LlmRequestHandler; +} + /** * Filter options for listing sessions */ diff --git a/nodejs/test/e2e/llm_inference.e2e.test.ts b/nodejs/test/e2e/llm_inference.e2e.test.ts new file mode 100644 index 000000000..0d4898b92 --- /dev/null +++ b/nodejs/test/e2e/llm_inference.e2e.test.ts @@ -0,0 +1,131 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +/** + * Drain the request body and reply with a single buffered response. The + * unified callback supports both buffered and streaming uniformly — for + * non-streaming responses, the consumer writes the whole body once and + * calls `end`. + */ +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + for await (const _chunk of req.requestBody) { + // discard — the runtime always sends at least one chunk (with end:true). + } + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function handleNonStreaming(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + return respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + } + if (url.includes("/models/session")) { + return respondBuffered(req, { status: 200, headers: {} }, "{}"); + } + if (url.includes("/policy")) { + return respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + } + return respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); +} + +describe("LLM inference callback", async () => { + const received: LlmInferenceRequest[] = []; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req): Promise { + received.push(req); + await handleNonStreaming(req); + } + })(), + }, + }, + }); + + it("registers the provider on connect without erroring", async () => { + await client.start(); + expect(client).toBeDefined(); + }); + + it( + "invokes the callback for non-streaming model-layer requests and threads sessionId through", + async () => { + const baselineLength = received.length; + const session = await client.createSession({ onPermissionRequest: approveAll }); + try { + // Drive a turn so model-layer traffic (catalog, + // session-intent, inference) flows through the callback. + // We swallow errors here — the buffered handler returns + // empty JSON for inference, which is not a valid model + // response; the agent will surface a transport error. + // What we care about is that the runtime *attempted* to + // call the callback for the model-layer endpoints. + try { + await session.sendAndWait({ prompt: "Say OK." }); + } catch { + // expected — see comment above + } + } finally { + await session.disconnect(); + } + + expect(received.length).toBeGreaterThan(baselineLength); + const newRequests = received.slice(baselineLength); + for (const r of newRequests) { + expect(r.url).toMatch(/^https?:\/\//); + expect(typeof r.method).toBe("string"); + } + + const catalog = newRequests.find((r) => r.url.toLowerCase().endsWith("/models")); + expect(catalog, "expected to intercept the /models catalog request").toBeDefined(); + + const inSession = newRequests.find((r) => typeof r.sessionId === "string"); + if (inSession) { + expect(inSession.sessionId).toMatch(/[a-zA-Z0-9-]+/); + } + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_cancel.e2e.test.ts b/nodejs/test/e2e/llm_inference_cancel.e2e.test.ts new file mode 100644 index 000000000..72f1471c0 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_cancel.e2e.test.ts @@ -0,0 +1,164 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + for await (const _chunk of req.requestBody) { + // discard + } +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function serviceNonInference(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return true; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return true; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return true; + } + return false; +} + +async function waitFor(predicate: () => boolean, timeoutMs: number): Promise { + const start = Date.now(); + while (!predicate()) { + if (Date.now() - start > timeoutMs) { + throw new Error("waitFor timed out"); + } + await new Promise((resolve) => setTimeout(resolve, 50)); + } +} + +/** + * Verifies the runtime → consumer cancellation path: when an in-flight + * turn is aborted via `session.abort()`, the runtime cancels the + * callback-served inference request and the consumer observes + * `req.signal.aborted` so it can tear down its upstream call. + */ +describe("LLM inference callback — cancellation", async () => { + let inferenceEntered = false; + let sawAbort = false; + let resolveAbortSeen: (() => void) | undefined; + const abortSeen = new Promise((resolve) => { + resolveAbortSeen = resolve; + }); + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + if (await serviceNonInference(req)) { + return; + } + const url = req.url.toLowerCase(); + const isInference = + url.includes("/chat/completions") || + url.includes("/responses") || + url.endsWith("/messages") || + url.endsWith("/v1/messages"); + if (!isInference) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); + return; + } + + // Inference: never produce a response. Wait for the + // runtime to cancel us, recording the abort. + await drainRequest(req); + inferenceEntered = true; + await new Promise((resolve) => { + if (req.signal.aborted) { + resolve(); + return; + } + req.signal.addEventListener("abort", () => resolve(), { once: true }); + }); + sawAbort = true; + resolveAbortSeen?.(); + try { + await req.responseBody.error({ message: "cancelled by upstream", code: "cancelled" }); + } catch { + // Runtime already dropped the request on cancel. + } + } + })(), + }, + }, + }); + + it( + "propagates runtime cancellation to the consumer's req.signal", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + try { + await session.send({ prompt: "Say OK." }); + await waitFor(() => inferenceEntered, 60_000); + await session.abort(); + await Promise.race([ + abortSeen, + new Promise((_resolve, reject) => + setTimeout(() => reject(new Error("timed out waiting for abort")), 30_000), + ), + ]); + } finally { + await session.disconnect(); + } + + // The consumer observed the runtime-driven cancellation. + expect(inferenceEntered).toBe(true); + expect(sawAbort).toBe(true); + }, + 120_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_consumer_cancel.e2e.test.ts b/nodejs/test/e2e/llm_inference_consumer_cancel.e2e.test.ts new file mode 100644 index 000000000..c504bdd2b --- /dev/null +++ b/nodejs/test/e2e/llm_inference_consumer_cancel.e2e.test.ts @@ -0,0 +1,147 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + for await (const _chunk of req.requestBody) { + // discard + } +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function serviceNonInference(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return true; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return true; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return true; + } + return false; +} + +function isInferenceUrl(url: string): boolean { + const u = url.toLowerCase(); + return ( + u.includes("/chat/completions") || + u.includes("/responses") || + u.endsWith("/messages") || + u.endsWith("/v1/messages") + ); +} + +/** + * Verifies the consumer → runtime cancellation path: when the consumer + * itself decides to abort the upstream call (e.g. its own + * `AbortController` fired, or the upstream socket dropped), it signals the + * runtime via `responseBody.error({ code: "cancelled" })`. The runtime + * must surface that faithfully as a request failure rather than hanging + * waiting for a response head/body. + */ +describe("LLM inference callback — consumer-initiated cancellation", async () => { + let inferenceAttempts = 0; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + if (await serviceNonInference(req)) { + return; + } + if (!isInferenceUrl(req.url)) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); + return; + } + + // Consumer-initiated cancellation: the consumer's own + // upstream call was aborted, so it tells the runtime to + // give up on this request. No response head is ever + // produced; the runtime should see a transport failure. + await drainRequest(req); + inferenceAttempts += 1; + await req.responseBody.error({ + message: "upstream call aborted by consumer", + code: "cancelled", + }); + } + })(), + }, + }, + }); + + it( + "surfaces a consumer-signalled cancellation to the runtime", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + + let caught: unknown; + try { + await session.sendAndWait({ prompt: "Say OK." }); + } catch (err) { + caught = err; + } finally { + await session.disconnect(); + } + + // The runtime reached the inference step and the consumer's + // cancellation terminated it (rather than the runtime hanging). + expect(inferenceAttempts).toBeGreaterThan(0); + if (caught) { + const message = caught instanceof Error ? caught.message : String(caught); + expect(message.length).toBeGreaterThan(0); + } + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_errors.e2e.test.ts b/nodejs/test/e2e/llm_inference_errors.e2e.test.ts new file mode 100644 index 000000000..4d8c84643 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_errors.e2e.test.ts @@ -0,0 +1,147 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + for await (const _chunk of req.requestBody) { + // discard + } +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +/** + * Verifies that errors thrown (or signalled via `responseBody.error`) by + * the LLM inference callback surface to the SDK consumer as transport + * failures, so the runtime's existing retry / error-reporting machinery + * handles them uniformly. + */ +describe("LLM inference callback — error mapping", async () => { + let callsBeforeError = 0; + let totalCalls = 0; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + totalCalls += 1; + const url = req.url.toLowerCase(); + + // Service models / session / policy normally so the + // agent can reach the inference step. + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { + max_context_window_tokens: 200000, + max_output_tokens: 8192, + }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], + }), + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered( + req, + { status: 200, headers: {} }, + JSON.stringify({ state: "enabled" }), + ); + return; + } + + // Inference: throw a transport-level error from the + // callback. The adapter converts this into a + // terminal `httpResponseChunk` with `error` set, so + // the runtime surfaces it as `APIConnectionError`. + if (url.includes("/chat/completions") || url.includes("/responses")) { + await drainRequest(req); + callsBeforeError += 1; + throw new Error("synthetic-callback-transport-failure"); + } + + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + "{}", + ); + } + })(), + }, + }, + }); + + it( + "surfaces a callback-thrown error to the SDK consumer", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + + let caught: unknown; + try { + await session.sendAndWait({ prompt: "Say OK." }); + } catch (err) { + caught = err; + } finally { + await session.disconnect(); + } + + // The agent layer typically wraps inference failures in its + // own error type and may convert them to an event rather than + // a thrown exception, so the assertion is loose: either we + // caught an error referencing the callback failure, or the + // inference call was attempted at least once and the runtime + // did NOT hang waiting for a response. + expect(totalCalls).toBeGreaterThan(0); + expect(callsBeforeError).toBeGreaterThan(0); + if (caught) { + const message = caught instanceof Error ? caught.message : String(caught); + expect(message.length).toBeGreaterThan(0); + } + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_handler.e2e.test.ts b/nodejs/test/e2e/llm_inference_handler.e2e.test.ts new file mode 100644 index 000000000..e8fcc7529 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_handler.e2e.test.ts @@ -0,0 +1,390 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { createServer, IncomingMessage, Server as HttpServer, ServerResponse } from "http"; +import { AddressInfo } from "net"; +import { afterAll, describe, expect, it } from "vitest"; +import { WebSocket as WsClient, WebSocketServer } from "ws"; +import { + approveAll, + CopilotWebSocketHandler, + LlmRequestHandler, + LlmWebSocketCloseStatus, + type LlmRequestContext, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const HTTP_TEXT = "OK from synthetic HTTP upstream."; +const WS_TEXT = "OK from synthetic WS upstream."; + +/** + * Stand up an in-process upstream that speaks the real CAPI shapes the + * runtime needs: model catalog, policy, `/responses` SSE for HTTP + * inference, and a WebSocket endpoint at `/responses` that answers each + * inbound `response.create` with the ordered `/responses` events the + * reducer expects. + * + * Returned `url` is what the handler subclass rewrites every + * intercepted request to point at — the runtime never talks to this + * server directly; the handler does, on the runtime's behalf. + */ +async function startFakeUpstream(): Promise<{ + url: string; + server: HttpServer; + wsRequestCount: () => number; + close: () => Promise; +}> { + let wsRequests = 0; + + const httpServer = createServer((req, res) => { + const url = new URL(req.url ?? "/", `http://${req.headers.host ?? "localhost"}`); + if (url.pathname === "/models" && req.method === "GET") { + sendJson(res, 200, { + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + supported_endpoints: ["/responses", "ws:/responses"], + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { + max_context_window_tokens: 200000, + max_output_tokens: 8192, + }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], + }); + return; + } + if (url.pathname.endsWith("/models/session")) { + sendJson(res, 200, {}); + return; + } + if (url.pathname.includes("/policy")) { + sendJson(res, 200, { state: "enabled" }); + return; + } + if (url.pathname.endsWith("/responses") && req.method === "POST") { + // Single-shot HTTP inference (e.g. title generation). SSE + // events the `responses-client.ts` reducer accepts. + drainBody(req) + .then(() => { + res.writeHead(200, { + "content-type": "text/event-stream", + "cache-control": "no-cache", + }); + for (const event of buildResponsesEvents(HTTP_TEXT, "resp_stub_http")) { + res.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`); + } + res.end(); + }) + .catch(() => { + res.writeHead(500).end(); + }); + return; + } + // Anything else: not found. + res.writeHead(404, { "content-type": "application/json" }); + res.end(JSON.stringify({ error: "not_found", path: url.pathname })); + }); + + const wss = new WebSocketServer({ server: httpServer, path: "/responses" }); + wss.on("connection", (socket) => { + socket.on("message", (raw) => { + wsRequests++; + // For each `response.create` request the runtime sends, + // answer with the ordered `/responses` event objects — one + // event per outbound WS message, raw JSON (NOT SSE-framed). + for (const event of buildResponsesEvents(WS_TEXT, "resp_stub_ws")) { + socket.send(JSON.stringify(event)); + } + void raw; + }); + }); + + await new Promise((resolve) => httpServer.listen(0, "127.0.0.1", resolve)); + const port = (httpServer.address() as AddressInfo).port; + const url = `http://127.0.0.1:${port}`; + + return { + url, + server: httpServer, + wsRequestCount: () => wsRequests, + async close() { + wss.clients.forEach((c) => c.terminate()); + await new Promise((resolve) => wss.close(() => resolve())); + await new Promise((resolve) => httpServer.close(() => resolve())); + }, + }; +} + +function sendJson(res: ServerResponse, status: number, body: unknown): void { + res.writeHead(status, { "content-type": "application/json" }); + res.end(JSON.stringify(body)); +} + +async function drainBody(req: IncomingMessage): Promise { + const parts: Buffer[] = []; + for await (const chunk of req) { + parts.push(chunk as Buffer); + } + return Buffer.concat(parts); +} + +function buildResponsesEvents(text: string, id: string): Array> { + return [ + { + type: "response.created", + response: { id, object: "response", status: "in_progress", output: [] }, + }, + { + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + }, + { + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + }, + { type: "response.output_text.delta", output_index: 0, content_index: 0, delta: text }, + { type: "response.output_text.done", output_index: 0, content_index: 0, text }, + { + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + }, + ]; +} + +/** +interface Counters { + httpRequests: number; + httpResponses: number; + wsRequestMessages: number; + wsResponseMessages: number; +} + +/** + * Single handler subclass that services BOTH transports against the + * per-test fake upstream. Demonstrates mutation in each direction: + * + * - HTTP: rewrites the URL to point at the test server, adds an + * `X-Test-Mutated` header to the outbound request, and adds an + * `X-Test-Response-Mutated` header on the way back. The test server + * echoes the request header into a counter so we can assert it + * actually arrived upstream. + * - WebSocket: rewrites the WS URL similarly, opens with the `ws` + * package inside a custom per-connection handler, and observes + * message counts in both directions. + */ +class TestHandler extends LlmRequestHandler { + constructor( + private readonly upstreamUrl: string, + private readonly counters: Counters + ) { + super(); + } + + private rewriteUrl(originalUrl: string): string { + const parsed = new URL(originalUrl); + const upstream = new URL(this.upstreamUrl); + parsed.protocol = upstream.protocol; + parsed.host = upstream.host; + return parsed.toString(); + } + + private rewriteWsUrl(originalUrl: string): string { + const parsed = new URL(originalUrl); + const upstream = new URL(this.upstreamUrl); + // The upstream URL is http(s); flip to ws(s) for the WS open. + parsed.protocol = upstream.protocol === "https:" ? "wss:" : "ws:"; + parsed.host = upstream.host; + return parsed.toString(); + } + + protected override async sendRequest(request: Request, _ctx: LlmRequestContext): Promise { + this.counters.httpRequests++; + const rewritten = this.rewriteUrl(request.url); + const requestHeaders = new Headers(request.headers); + requestHeaders.set("x-test-mutated", "1"); + const rewrittenRequest = new Request(rewritten, { + method: request.method, + headers: requestHeaders, + body: request.body, + // @ts-expect-error duplex is required by undici when streaming a body + duplex: "half", + }); + const response = await fetch(rewrittenRequest, { signal: _ctx.signal }); + this.counters.httpResponses++; + const responseHeaders = new Headers(response.headers); + responseHeaders.set("x-test-response-mutated", "1"); + return new Response(response.body, { + status: response.status, + statusText: response.statusText, + headers: responseHeaders, + }); + } + + protected override async openWebSocket(ctx: LlmRequestContext): Promise { + return TestSocketHandler.connect(this.rewriteWsUrl(ctx.url), ctx, this.counters); + } +} + +class TestSocketHandler extends CopilotWebSocketHandler { + static async connect( + url: string, + ctx: LlmRequestContext, + counters: Counters + ): Promise { + const client = new WsClient(url); + await new Promise((resolve, reject) => { + client.once("open", () => resolve()); + client.once("error", (err) => reject(err)); + }); + return new TestSocketHandler(client, ctx, counters); + } + + private constructor( + private readonly client: WsClient, + ctx: LlmRequestContext, + private readonly counters: Counters + ) { + super(ctx); + this.client.on("message", (data, isBinary) => { + this.counters.wsResponseMessages++; + void this.sendResponseMessage(isBinary ? (data as Buffer) : data.toString("utf-8")); + }); + this.client.once("close", () => { + void this.close(); + }); + this.client.once("error", (err) => { + void this.close(new LlmWebSocketCloseStatus(err.message, undefined, err as Error)); + }); + const onAbort = (): void => { + try { + this.client.close(); + } catch { + /* best-effort */ + } + }; + ctx.signal.addEventListener("abort", onAbort, { once: true }); + this.client.once("close", () => ctx.signal.removeEventListener("abort", onAbort)); + } + + override sendRequestMessage(data: string | Uint8Array): void { + this.counters.wsRequestMessages++; + if (this.client.readyState !== WsClient.OPEN) { + return; + } + this.client.send(data); + } + + override async [Symbol.asyncDispose](): Promise { + try { + await super[Symbol.asyncDispose](); + } finally { + try { + this.client.close(); + } catch { + /* best-effort */ + } + } + } +} + +describe("LlmRequestHandler — single subclass handles HTTP + WebSocket", async () => { + const upstream = await startFakeUpstream(); + const counters: Counters = { + httpRequests: 0, + httpResponses: 0, + wsRequestMessages: 0, + wsResponseMessages: 0, + }; + + const { copilotClient: client, env } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new TestHandler(upstream.url, counters), + }, + }, + }); + + // Enable the WebSocket Responses transport in the spawned runtime so + // the main agent turn picks the WS path; single-shot calls (title + // generation) still go over HTTP through the same subclass. + env.COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES = "true"; + + afterAll(async () => { + await upstream.close(); + }); + + it("services both an HTTP turn and a WebSocket turn end-to-end via one handler", async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + // The HTTP hooks fired — the runtime issued model-layer GETs + // (catalog, policy) and possibly a single-shot inference. + expect(counters.httpRequests, "expected sendRequest to fire").toBeGreaterThan(0); + expect(counters.httpResponses, "expected sendRequest response mutation to fire").toBeGreaterThan( + 0 + ); + + // The WebSocket hooks fired — the main agent turn went over + // the WS path and we observed messages in both directions. + expect( + counters.wsRequestMessages, + "expected sendRequestMessage (runtime → upstream) to fire" + ).toBeGreaterThan(0); + expect( + counters.wsResponseMessages, + "expected sendResponseMessage (upstream → runtime) to fire" + ).toBeGreaterThan(0); + expect( + upstream.wsRequestCount(), + "expected upstream WS to receive request messages" + ).toBeGreaterThan(0); + + // The synthetic content from the upstream surfaced in the + // assistant turn — proves the full chain (runtime → handler + // → upstream → handler → runtime) is intact for the + // transport the main agent turn used. + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from synthetic (HTTP|WS) upstream/); + }, 90_000); +}); diff --git a/nodejs/test/e2e/llm_inference_session_id.e2e.test.ts b/nodejs/test/e2e/llm_inference_session_id.e2e.test.ts new file mode 100644 index 000000000..8637f7b6e --- /dev/null +++ b/nodejs/test/e2e/llm_inference_session_id.e2e.test.ts @@ -0,0 +1,335 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const SYNTHETIC_TEXT = "OK from the synthetic stream."; + +async function drainRequest(req: LlmInferenceRequest): Promise { + const parts: Buffer[] = []; + for await (const chunk of req.requestBody) { + parts.push(Buffer.from(chunk)); + } + return Buffer.concat(parts).toString("utf-8"); +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +/** + * Serve the model-layer GETs/POSTs the runtime issues that are not + * inference (catalog, model session, policy). These flow through the same + * callback but carry no session id (they happen outside an agent turn). + */ +async function handleNonInferenceModelTraffic(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], + }) + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return; + } + await respondBuffered(req, { status: 200, headers: { "content-type": ["application/json"] } }, "{}"); +} + +/** + * Synthesize a well-formed inference response so the agent turn completes. + * The runtime selects `/responses` for both the CAPI and BYOK sessions + * here; `/chat/completions` is handled too for robustness. The consumer + * fabricates the response directly — there is no upstream server and the + * CAPI record/replay proxy is never the inference endpoint. + */ +async function handleInference(req: LlmInferenceRequest): Promise { + const bodyText = await drainRequest(req); + const wantsStream = /"stream"\s*:\s*true/.test(bodyText); + const url = req.url.toLowerCase(); + + // `/responses` streams via SSE only when the request asked for it + // (`stream: true`). BYOK turns whose config-derived model doesn't + // advertise streaming issue a buffered request expecting a single + // JSON `response` object, so branch on the flag exactly as a real + // upstream would. + if (url.includes("/responses")) { + if (!wantsStream) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["application/json"] }, + }); + await req.responseBody.write( + JSON.stringify({ + id: "resp_stub_1", + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: SYNTHETIC_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }) + ); + await req.responseBody.end(); + return; + } + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const id = "resp_stub_1"; + const events: string[] = [ + `event: response.created\ndata: ${JSON.stringify({ + type: "response.created", + response: { id, object: "response", status: "in_progress", output: [] }, + })}\n\n`, + `event: response.output_item.added\ndata: ${JSON.stringify({ + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + })}\n\n`, + `event: response.content_part.added\ndata: ${JSON.stringify({ + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + })}\n\n`, + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + output_index: 0, + content_index: 0, + delta: SYNTHETIC_TEXT, + })}\n\n`, + `event: response.output_text.done\ndata: ${JSON.stringify({ + type: "response.output_text.done", + output_index: 0, + content_index: 0, + text: SYNTHETIC_TEXT, + })}\n\n`, + `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: SYNTHETIC_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + })}\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + if (url.includes("/chat/completions") && wantsStream) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const base = { id: "chatcmpl-stub-1", object: "chat.completion.chunk", created: 1, model: "claude-sonnet-4.5" }; + const events: string[] = [ + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { role: "assistant", content: "" }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { content: SYNTHETIC_TEXT }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + })}\n\n`, + `data: [DONE]\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + // /chat/completions non-streaming — buffered JSON. + await req.responseBody.start({ status: 200, headers: { "content-type": ["application/json"] } }); + await req.responseBody.write( + JSON.stringify({ + id: "chatcmpl-stub-1", + object: "chat.completion", + created: 1, + model: "claude-sonnet-4.5", + choices: [ + { index: 0, message: { role: "assistant", content: SYNTHETIC_TEXT }, finish_reason: "stop" }, + ], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + }) + ); + await req.responseBody.end(); +} + +interface InterceptedRequest { + url: string; + sessionId?: string; +} + +function isInferenceUrl(url: string): boolean { + const u = url.toLowerCase(); + return ( + u.endsWith("/chat/completions") || + u.endsWith("/responses") || + u.endsWith("/v1/messages") || + u.endsWith("/messages") + ); +} + +/** + * Asserts the runtime threads its session id into the LLM inference + * callback for BOTH a CAPI session and a BYOK session. The callback alone + * services every model-layer request — no upstream server, no CAPI proxy + * acting as the inference endpoint — so the only source of `req.sessionId` + * is the runtime's own per-client threading. + */ +describe("LLM inference callback threads the runtime session id (CAPI + BYOK)", async () => { + const records: InterceptedRequest[] = []; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + records.push({ url: req.url, sessionId: req.sessionId }); + if (isInferenceUrl(req.url)) { + await handleInference(req); + } else { + await handleNonInferenceModelTraffic(req); + } + } + })(), + }, + }, + }); + + let capiSessionId: string | undefined; + + it("threads the session id into a CAPI session's inference request", async () => { + await client.start(); + const baseline = records.length; + const session = await client.createSession({ onPermissionRequest: approveAll }); + capiSessionId = session.sessionId; + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + const inference = records.slice(baseline).filter((r) => isInferenceUrl(r.url)); + expect(inference.length, "expected at least one intercepted inference request").toBeGreaterThan(0); + for (const r of inference) { + expect(r.sessionId, "CAPI inference request must carry the runtime session id").toBe( + session.sessionId + ); + } + + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from the synthetic/); + }, 90_000); + + it("threads the session id into a BYOK session's inference request", async () => { + await client.start(); + const baseline = records.length; + const session = await client.createSession({ + onPermissionRequest: approveAll, + // BYOK providers require an explicit model id. + model: "claude-sonnet-4.5", + provider: { + type: "openai", + wireApi: "responses", + baseUrl: "https://byok.invalid/v1", + apiKey: "byok-secret", + modelId: "claude-sonnet-4.5", + wireModel: "claude-sonnet-4.5", + }, + }); + const byokSessionId = session.sessionId; + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + const inference = records.slice(baseline).filter((r) => isInferenceUrl(r.url)); + expect(inference.length, "expected at least one intercepted BYOK inference request").toBeGreaterThan(0); + for (const r of inference) { + expect(r.sessionId, "BYOK inference request must carry the runtime session id").toBe(byokSessionId); + } + + // Session ids are per-session, so the two turns must differ — proves + // we assert against a real, request-specific id, not a constant. + expect(byokSessionId).not.toBe(capiSessionId); + + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from the synthetic/); + }, 90_000); +}); diff --git a/nodejs/test/e2e/llm_inference_stream.e2e.test.ts b/nodejs/test/e2e/llm_inference_stream.e2e.test.ts new file mode 100644 index 000000000..db25cf41f --- /dev/null +++ b/nodejs/test/e2e/llm_inference_stream.e2e.test.ts @@ -0,0 +1,260 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +async function drainRequest(req: LlmInferenceRequest): Promise { + const parts: Buffer[] = []; + for await (const chunk of req.requestBody) { + parts.push(Buffer.from(chunk)); + } + return Buffer.concat(parts).toString("utf-8"); +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +async function handleNonInferenceModelTraffic(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return; + } + await respondBuffered(req, { status: 200, headers: { "content-type": ["application/json"] } }, "{}"); +} + +/** + * Synthesizes a minimal but well-formed response for the runtime's + * inference request. The runtime calls the buffered code path for + * `/chat/completions` and the streaming code path for `/responses`, but + * the unified callback has no field telling the consumer which — the + * consumer dispatches by URL. + */ +async function handleInference(req: LlmInferenceRequest): Promise { + const bodyText = await drainRequest(req); + const wantsStream = /"stream"\s*:\s*true/.test(bodyText); + const url = req.url.toLowerCase(); + + if (url.includes("/responses")) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const id = "resp_stub_1"; + const events: string[] = [ + `event: response.created\ndata: ${JSON.stringify({ + type: "response.created", + response: { id, object: "response", status: "in_progress", output: [] }, + })}\n\n`, + `event: response.output_item.added\ndata: ${JSON.stringify({ + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + })}\n\n`, + `event: response.content_part.added\ndata: ${JSON.stringify({ + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + })}\n\n`, + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + output_index: 0, + content_index: 0, + delta: "OK from the synthetic stream.", + })}\n\n`, + `event: response.output_text.done\ndata: ${JSON.stringify({ + type: "response.output_text.done", + output_index: 0, + content_index: 0, + text: "OK from the synthetic stream.", + })}\n\n`, + `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "OK from the synthetic stream." }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + })}\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + if (url.includes("/chat/completions") && wantsStream) { + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + const base = { + id: "chatcmpl-stub-1", + object: "chat.completion.chunk", + created: 1, + model: "claude-sonnet-4.5", + }; + const events: string[] = [ + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { role: "assistant", content: "" }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [ + { + index: 0, + delta: { content: "OK from the synthetic stream." }, + finish_reason: null, + }, + ], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + })}\n\n`, + `data: [DONE]\n\n`, + ]; + for (const event of events) { + await req.responseBody.write(event); + } + await req.responseBody.end(); + return; + } + + // /chat/completions non-streaming — buffered JSON. (body already drained above) + await req.responseBody.start({ status: 200, headers: { "content-type": ["application/json"] } }); + await req.responseBody.write( + JSON.stringify({ + id: "chatcmpl-stub-1", + object: "chat.completion", + created: 1, + model: "claude-sonnet-4.5", + choices: [ + { + index: 0, + message: { role: "assistant", content: "OK from the synthetic stream." }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + }), + ); + await req.responseBody.end(); +} + +describe("LLM inference callback — fully mocked streaming", async () => { + const received: LlmInferenceRequest[] = []; + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + received.push(req); + const url = req.url.toLowerCase(); + const isInference = + url.includes("/chat/completions") || + url.endsWith("/responses") || + url.endsWith("/v1/messages") || + url.endsWith("/messages"); + if (isInference) { + await handleInference(req); + } else { + await handleNonInferenceModelTraffic(req); + } + } + })(), + }, + }, + }); + + it( + "completes a full user→assistant turn entirely via the callback (chunked SSE response)", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + // At least one inference request flowed through the callback. + const inferenceReqs = received.filter((r) => { + const u = r.url.toLowerCase(); + return ( + u.endsWith("/chat/completions") || + u.endsWith("/responses") || + u.endsWith("/v1/messages") || + u.endsWith("/messages") + ); + }); + expect(inferenceReqs.length, "expected at least one inference request via the callback").toBeGreaterThan( + 0, + ); + + // The synthetic content surfaced in the assistant response. + expect(resultJson).toMatch(/OK from the synthetic/); + }, + 90_000, + ); +}); diff --git a/nodejs/test/e2e/llm_inference_websocket.e2e.test.ts b/nodejs/test/e2e/llm_inference_websocket.e2e.test.ts new file mode 100644 index 000000000..440124784 --- /dev/null +++ b/nodejs/test/e2e/llm_inference_websocket.e2e.test.ts @@ -0,0 +1,226 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, LlmRequestHandler, type LlmInferenceRequest } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const WS_TEXT = "OK from the synthetic ws."; + +async function drainRequest(req: LlmInferenceRequest): Promise { + const parts: Buffer[] = []; + for await (const chunk of req.requestBody) { + parts.push(Buffer.from(chunk)); + } + return Buffer.concat(parts).toString("utf-8"); +} + +async function respondBuffered( + req: LlmInferenceRequest, + init: { status: number; headers?: Record }, + body: string, +): Promise { + await drainRequest(req); + await req.responseBody.start(init); + if (body.length > 0) { + await req.responseBody.write(body); + } + await req.responseBody.end(); +} + +/** + * The fake model catalog advertises both `/responses` and `ws:/responses` + * so `pickModelProtocol` selects the Responses wire API and `ai-client.ts` + * is allowed to pick the WebSocket transport (the feature flag is enabled + * via the env var below). No `/v1/messages`, otherwise the model would be + * routed to the Anthropic Messages API instead. + */ +async function handleNonInferenceModelTraffic(req: LlmInferenceRequest): Promise { + const url = req.url.toLowerCase(); + if (url.endsWith("/models")) { + await respondBuffered( + req, + { status: 200, headers: { "content-type": ["application/json"] } }, + JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + supported_endpoints: ["/responses", "ws:/responses"], + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { streaming: true, tool_calls: true, parallel_tool_calls: true, vision: true }, + }, + }, + ], + }), + ); + return; + } + if (url.includes("/models/session")) { + await respondBuffered(req, { status: 200, headers: {} }, "{}"); + return; + } + if (url.includes("/policy")) { + await respondBuffered(req, { status: 200, headers: {} }, JSON.stringify({ state: "enabled" })); + return; + } + await respondBuffered(req, { status: 200, headers: { "content-type": ["application/json"] } }, "{}"); +} + +/** + * Synthesizes the `/responses` SSE event stream for the HTTP code path + * (single-shot inference requests — e.g. title generation — that don't + * pick the WebSocket transport). + */ +async function handleHttpInference(req: LlmInferenceRequest): Promise { + await drainRequest(req); + await req.responseBody.start({ + status: 200, + headers: { "content-type": ["text/event-stream"] }, + }); + for (const event of buildResponsesEvents()) { + await req.responseBody.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`); + } + await req.responseBody.end(); +} + +/** + * Builds the ordered `/responses` event objects the reducer expects. + * Used raw (one object = one WS message) for the WebSocket path and + * SSE-framed for the HTTP path. + */ +function buildResponsesEvents(): Array> { + const id = "resp_stub_ws_1"; + return [ + { type: "response.created", response: { id, object: "response", status: "in_progress", output: [] } }, + { + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + }, + { + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + }, + { type: "response.output_text.delta", output_index: 0, content_index: 0, delta: WS_TEXT }, + { type: "response.output_text.done", output_index: 0, content_index: 0, text: WS_TEXT }, + { + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: WS_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + }, + ]; +} + +/** + * Full-duplex WebSocket handler. The runtime opens the channel + * (`transport === "websocket"`), the consumer acks the upgrade, then + * pumps bidirectionally: every inbound `response.create` request the + * runtime sends is answered with the ordered `/responses` event objects, + * one event per outbound WS message (raw JSON, *not* SSE-framed). The + * connection is reused across turns; it stays open until the runtime + * closes it, at which point `req.requestBody` throws and we stop. + */ +async function handleWebSocket(req: LlmInferenceRequest, onRequest: () => void): Promise { + // Ack the upgrade (status 101-equivalent) before any message flows. + await req.responseBody.start({ status: 101, headers: {} }); + try { + for await (const _outbound of req.requestBody) { + onRequest(); + for (const event of buildResponsesEvents()) { + await req.responseBody.write(JSON.stringify(event)); + } + } + } catch { + // Expected: the runtime cancels the request body when it closes the + // socket at session teardown. Nothing more to do. + } +} + +describe("LLM inference callback — full-duplex WebSocket transport", async () => { + const received: LlmInferenceRequest[] = []; + let wsRequestCount = 0; + + const { copilotClient: client, env } = await createSdkTestContext({ + copilotClientOptions: { + llmInference: { + handler: new (class extends LlmRequestHandler { + override async onLlmRequest(req: LlmInferenceRequest): Promise { + received.push(req); + if (req.transport === "websocket") { + await handleWebSocket(req, () => { + wsRequestCount++; + }); + return; + } + const url = req.url.toLowerCase(); + const isInference = + url.includes("/chat/completions") || + url.endsWith("/responses") || + url.endsWith("/v1/messages") || + url.endsWith("/messages"); + if (isInference) { + await handleHttpInference(req); + } else { + await handleNonInferenceModelTraffic(req); + } + } + })(), + }, + }, + }); + + // Enable the WebSocket Responses transport in the spawned runtime. The + // harness env object is the same one passed to the CLI subprocess, so + // mutating it here flips the ExP flag for this test file's client. + env.COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES = "true"; + + it( + "completes a user→assistant turn over the WebSocket transport via the callback", + async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + // The main agent turn (tools present, not single-shot) selected the + // WebSocket transport and drove it through the callback. + const wsReqs = received.filter((r) => r.transport === "websocket"); + expect(wsReqs.length, "expected at least one websocket request via the callback").toBeGreaterThan(0); + expect(wsRequestCount, "expected the runtime to send at least one ws message").toBeGreaterThan(0); + + // The synthetic content surfaced in the assistant response. + expect(resultJson).toMatch(/OK from the synthetic ws/); + }, + 90_000, + ); +}); diff --git a/nodejs/test/llm_inference_callbacks.test.ts b/nodejs/test/llm_inference_callbacks.test.ts new file mode 100644 index 000000000..061082ca6 --- /dev/null +++ b/nodejs/test/llm_inference_callbacks.test.ts @@ -0,0 +1,294 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { + CopilotWebSocketHandler, + LlmRequestHandler, + type LlmInferenceRequest, + type LlmInferenceResponseInit, + type LlmInferenceResponseSink, + type LlmRequestContext, + LlmWebSocketCloseStatus, +} from "../src/index.js"; +import { + createLlmInferenceAdapter, + type LlmInferenceProvider, +} from "../src/llmInferenceProvider.js"; + +/** + * Minimal fake of the server RPC surface the adapter uses to send response + * frames back to the runtime. Records every frame and lets the test decide + * what `accepted` value the runtime returns. + */ +function makeFakeServerRpc(accepted: { start?: boolean; chunk?: boolean } = {}): { + rpc: () => ReturnType; + starts: LlmInferenceResponseInit[]; + chunks: { data: string; binary?: boolean; end?: boolean; error?: unknown }[]; +} { + const starts: LlmInferenceResponseInit[] = []; + const chunks: { data: string; binary?: boolean; end?: boolean; error?: unknown }[] = []; + function createFakeRpc() { + return { + llmInference: { + async httpResponseStart(p: { + status: number; + statusText?: string; + headers: Record; + }) { + starts.push({ status: p.status, statusText: p.statusText, headers: p.headers }); + return { accepted: accepted.start ?? true }; + }, + async httpResponseChunk(p: { + data: string; + binary?: boolean; + end?: boolean; + error?: unknown; + }) { + chunks.push({ data: p.data, binary: p.binary, end: p.end, error: p.error }); + return { accepted: accepted.chunk ?? true }; + }, + }, + }; + } + const single = createFakeRpc(); + return { rpc: () => single, starts, chunks }; +} + +describe("createLlmInferenceAdapter", () => { + it("stages body chunks that arrive before their start frame and replays them in order", async () => { + const received: string[] = []; + let resolveDone: () => void; + const done = new Promise((r) => { + resolveDone = r; + }); + const provider: LlmInferenceProvider = { + async onLlmRequest(req: LlmInferenceRequest) { + const decoder = new TextDecoder(); + for await (const chunk of req.requestBody) { + received.push(decoder.decode(chunk)); + } + await req.responseBody.start({ status: 200, headers: {} }); + await req.responseBody.end(); + resolveDone(); + }, + }; + const fake = makeFakeServerRpc(); + const handler = createLlmInferenceAdapter(provider, () => fake.rpc() as never); + + // Chunks arrive BEFORE the start frame (simulating a reordering the + // runtime should never actually produce). They must be staged and + // delivered once the start frame registers the request. + await handler.httpRequestChunk({ + requestId: "r1", + data: "hello ", + binary: false, + end: false, + }); + await handler.httpRequestChunk({ + requestId: "r1", + data: "world", + binary: false, + end: false, + }); + await handler.httpRequestChunk({ requestId: "r1", data: "", end: true }); + + await handler.httpRequestStart({ + requestId: "r1", + method: "POST", + url: "https://example.test/v1/chat", + headers: {}, + transport: "http", + }); + + await done; + expect(received.join("")).toBe("hello world"); + }); + + it("aborts the provider when the runtime rejects a response frame (accepted=false)", async () => { + let aborted = false; + let writeThrew = false; + let finished: () => void; + const settled = new Promise((r) => { + finished = r; + }); + const provider: LlmInferenceProvider = { + async onLlmRequest(req: LlmInferenceRequest) { + req.signal.addEventListener("abort", () => { + aborted = true; + }); + for await (const _ of req.requestBody) { + // drain + } + await req.responseBody.start({ status: 200, headers: {} }); + try { + await req.responseBody.write("rejected-chunk"); + } catch { + writeThrew = true; + } + finished(); + }, + }; + const fake = makeFakeServerRpc({ start: true, chunk: false }); + const handler = createLlmInferenceAdapter(provider, () => fake.rpc() as never); + + await handler.httpRequestStart({ + requestId: "r2", + method: "POST", + url: "https://example.test/v1/chat", + headers: {}, + transport: "http", + }); + await handler.httpRequestChunk({ requestId: "r2", data: "", end: true }); + + await settled; + expect(writeThrew).toBe(true); + expect(aborted).toBe(true); + }); +}); + +/** + * Controllable fake of a callback-owned WebSocket connection. The test drives + * messages, close, and error explicitly. + */ +class FakeSocketHandler extends CopilotWebSocketHandler { + sent: (string | Uint8Array)[] = []; + + override sendRequestMessage(data: string | Uint8Array): void { + this.sent.push(data); + } + + async emitMessage(data: string | Uint8Array): Promise { + await this.sendResponseMessage(data); + } + + async closeFromUpstream(): Promise { + await this.close(); + } + + async failFromUpstream(error: Error): Promise { + await this.close(new LlmWebSocketCloseStatus(error.message, undefined, error)); + } +} + +interface RecordingSink extends LlmInferenceResponseSink { + starts: LlmInferenceResponseInit[]; + writes: (string | Uint8Array)[]; + ended: boolean; + errored?: { message: string; code?: string }; +} + +function makeRecordingSink(): RecordingSink { + const sink: RecordingSink = { + starts: [], + writes: [], + ended: false, + async start(init) { + sink.starts.push(init); + }, + async write(data) { + sink.writes.push(data); + }, + async end() { + sink.ended = true; + }, + async error(err) { + sink.errored = err; + }, + }; + return sink; +} + +/** Async-iterable request body that yields nothing until the test releases it. */ +function gatedRequestBody(): { body: AsyncIterable; release: () => void } { + let release!: () => void; + const gate = new Promise((r) => { + release = r; + }); + return { + release, + body: { + async *[Symbol.asyncIterator]() { + await gate; + }, + }, + }; +} + +describe("LlmRequestHandler WebSocket dispatch", () => { + it("finalises the response when the upstream closes while the request stream is still open", async () => { + let upstream!: FakeSocketHandler; + class Handler extends LlmRequestHandler { + protected override openWebSocket(ctx: LlmRequestContext): CopilotWebSocketHandler { + upstream = new FakeSocketHandler(ctx); + return upstream; + } + } + const handler = new Handler(); + const sink = makeRecordingSink(); + const gated = gatedRequestBody(); + const abort = new AbortController(); + const req: LlmInferenceRequest = { + requestId: "ws1", + method: "GET", + url: "wss://example.test/responses", + headers: {}, + transport: "websocket", + requestBody: gated.body, + signal: abort.signal, + responseBody: sink, + }; + + const turn = handler.onLlmRequest(req); + + // Let the handler register its listeners and ack the upgrade, then + // deliver an upstream message and close the socket — all while the + // request body is still parked (no runtime → upstream frames yet). + await new Promise((r) => setTimeout(r, 10)); + await upstream.emitMessage("server-event-1"); + await upstream.closeFromUpstream(); + + // The turn must resolve (not hang) because the upstream terminated. + await turn; + + expect(sink.starts).toEqual([{ status: 101, headers: {} }]); + expect(sink.writes).toContain("server-event-1"); + expect(sink.ended).toBe(true); + + gated.release(); + }); + + it("surfaces an upstream error as a thrown failure", async () => { + let upstream!: FakeSocketHandler; + class Handler extends LlmRequestHandler { + protected override openWebSocket(ctx: LlmRequestContext): CopilotWebSocketHandler { + upstream = new FakeSocketHandler(ctx); + return upstream; + } + } + const handler = new Handler(); + const sink = makeRecordingSink(); + const gated = gatedRequestBody(); + const abort = new AbortController(); + const req: LlmInferenceRequest = { + requestId: "ws2", + method: "GET", + url: "wss://example.test/responses", + headers: {}, + transport: "websocket", + requestBody: gated.body, + signal: abort.signal, + responseBody: sink, + }; + + const turn = handler.onLlmRequest(req); + await new Promise((r) => setTimeout(r, 10)); + await upstream.failFromUpstream(new Error("upstream exploded")); + + await expect(turn).rejects.toThrow("upstream exploded"); + expect(sink.ended).toBe(false); + + gated.release(); + }); +}); diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index ec00eefb5..304cf683c 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -1788,6 +1788,214 @@ def to_dict(self) -> dict: result["projectPaths"] = from_union([lambda x: from_list(from_str, x), from_none], self.project_paths) return result +@dataclass +class LlmInferenceHTTPRequestError: + """Set when the SDK client could not produce a response (transport-level failure). Causes + the runtime to raise an APIConnectionError; status/headers/body are ignored when error is + set. + """ + message: str + """Human-readable failure description.""" + + code: str | None = None + """Optional machine-readable error code.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPRequestError': + assert isinstance(obj, dict) + message = from_str(obj.get("message")) + code = from_union([from_str, from_none], obj.get("code")) + return LlmInferenceHTTPRequestError(message, code) + + def to_dict(self) -> dict: + result: dict = {} + result["message"] = from_str(self.message) + if self.code is not None: + result["code"] = from_union([from_str, from_none], self.code) + return result + +class LlmInferenceRequestMetadataEndpointKind(Enum): + """What kind of model-layer endpoint this is.""" + + EMBEDDINGS = "embeddings" + INFERENCE = "inference" + MODELS_CATALOG = "models-catalog" + MODELS_POLICY = "models-policy" + MODELS_SESSION = "models-session" + OTHER = "other" + +class LlmInferenceRequestMetadataProviderType(Enum): + """Logical model provider this request targets.""" + + ANTHROPIC = "anthropic" + AZURE = "azure" + COPILOT = "copilot" + GOOGLE = "google" + OPENAI = "openai" + OTHER = "other" + +class LlmInferenceRequestMetadataTransport(Enum): + """Transport kind. v1 implements http only.""" + + HTTP = "http" + WEBSOCKET = "websocket" + +class LlmInferenceRequestMetadataWireAPI(Enum): + """Wire API shape, when this is an inference request.""" + + COMPLETIONS = "completions" + MESSAGES = "messages" + RESPONSES = "responses" + +@dataclass +class LlmInferenceHTTPStreamStartError: + """Set when the SDK client could not even begin the stream (transport-level failure). When + error is set the runtime raises an APIConnectionError and ignores status/headers. + """ + message: str + code: str | None = None + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPStreamStartError': + assert isinstance(obj, dict) + message = from_str(obj.get("message")) + code = from_union([from_str, from_none], obj.get("code")) + return LlmInferenceHTTPStreamStartError(message, code) + + def to_dict(self) -> dict: + result: dict = {} + result["message"] = from_str(self.message) + if self.code is not None: + result["code"] = from_union([from_str, from_none], self.code) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceSetProviderRequest: + """No parameters. The calling connection is registered as the runtime's LLM inference + provider; all subsequent model-layer HTTP requests are dispatched back to it via the + llmInference client API. + """ + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceSetProviderRequest': + assert isinstance(obj, dict) + return LlmInferenceSetProviderRequest() + + def to_dict(self) -> dict: + result: dict = {} + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceSetProviderResult: + """Indicates whether the calling client was registered as the LLM inference provider.""" + + success: bool + """Whether the provider was set successfully""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceSetProviderResult': + assert isinstance(obj, dict) + success = from_bool(obj.get("success")) + return LlmInferenceSetProviderResult(success) + + def to_dict(self) -> dict: + result: dict = {} + result["success"] = from_bool(self.success) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceStreamChunkRequest: + """A streamed response body chunk.""" + + data_base64: str + """One body chunk as base64-encoded bytes. Chunks are appended to the runtime's view of the + response body in the order received. + """ + stream_token: int + """The same streamToken the runtime supplied in the originating llmInference.httpStreamStart + call. + """ + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceStreamChunkRequest': + assert isinstance(obj, dict) + data_base64 = from_str(obj.get("dataBase64")) + stream_token = from_int(obj.get("streamToken")) + return LlmInferenceStreamChunkRequest(data_base64, stream_token) + + def to_dict(self) -> dict: + result: dict = {} + result["dataBase64"] = from_str(self.data_base64) + result["streamToken"] = from_int(self.stream_token) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceStreamChunkResult: + """Whether the chunk was accepted.""" + + accepted: bool + """True when the chunk was queued for the stream; false when the stream is unknown.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceStreamChunkResult': + assert isinstance(obj, dict) + accepted = from_bool(obj.get("accepted")) + return LlmInferenceStreamChunkResult(accepted) + + def to_dict(self) -> dict: + result: dict = {} + result["accepted"] = from_bool(self.accepted) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceStreamEndRequest: + """End-of-stream signal.""" + + stream_token: int + """The originating streamToken.""" + + error: str | None = None + """When set, marks the stream as ending with a transport-level error of this description. + When absent the stream ends normally. + """ + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceStreamEndRequest': + assert isinstance(obj, dict) + stream_token = from_int(obj.get("streamToken")) + error = from_union([from_str, from_none], obj.get("error")) + return LlmInferenceStreamEndRequest(stream_token, error) + + def to_dict(self) -> dict: + result: dict = {} + result["streamToken"] = from_int(self.stream_token) + if self.error is not None: + result["error"] = from_union([from_str, from_none], self.error) + return result + +# Experimental: this type is part of an experimental API and may change or be removed. +@dataclass +class LlmInferenceStreamEndResult: + """Whether the end signal was accepted.""" + + accepted: bool + """True when the stream was found and ended; false when unknown.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceStreamEndResult': + assert isinstance(obj, dict) + accepted = from_bool(obj.get("accepted")) + return LlmInferenceStreamEndResult(accepted) + + def to_dict(self) -> dict: + result: dict = {} + result["accepted"] = from_bool(self.accepted) + return result + # Experimental: this type is part of an experimental API and may change or be removed. class HostType(Enum): """Repository host type @@ -2273,6 +2481,13 @@ def to_dict(self) -> dict: result["redirectPort"] = from_union([from_int, from_none], self.redirect_port) return result +class MCPServerConfigDeferTools(Enum): + """Controls if tools provided by this server can be loaded on demand via tool search (auto) + or always included in the initial tool list (never) + """ + AUTO = "auto" + NEVER = "never" + class MCPServerConfigHTTPOauthGrantType(Enum): """OAuth grant type to use when authenticating to the remote MCP server.""" @@ -9848,6 +10063,94 @@ def to_dict(self) -> dict: result["projectPath"] = from_union([from_str, from_none], self.project_path) return result +@dataclass +class LlmInferenceHTTPRequestResult: + """The HTTP response the runtime should treat as if it had issued the request itself.""" + + headers: dict[str, list[str]] + """HTTP headers as a map from lowercased header name to a list of values. Multi-valued + headers (e.g. Set-Cookie) preserve all values. + """ + status: int + """HTTP status code returned to the runtime.""" + + body_base64: str | None = None + """Response body as base64-encoded bytes. Set instead of bodyText for binary responses.""" + + body_text: str | None = None + """Response body as a UTF-8 string. Set when bodyBase64 is absent.""" + + error: LlmInferenceHTTPRequestError | None = None + """Set when the SDK client could not produce a response (transport-level failure). Causes + the runtime to raise an APIConnectionError; status/headers/body are ignored when error is + set. + """ + status_text: str | None = None + """Optional HTTP status text.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPRequestResult': + assert isinstance(obj, dict) + headers = from_dict(lambda x: from_list(from_str, x), obj.get("headers")) + status = from_int(obj.get("status")) + body_base64 = from_union([from_str, from_none], obj.get("bodyBase64")) + body_text = from_union([from_str, from_none], obj.get("bodyText")) + error = from_union([LlmInferenceHTTPRequestError.from_dict, from_none], obj.get("error")) + status_text = from_union([from_str, from_none], obj.get("statusText")) + return LlmInferenceHTTPRequestResult(headers, status, body_base64, body_text, error, status_text) + + def to_dict(self) -> dict: + result: dict = {} + result["headers"] = from_dict(lambda x: from_list(from_str, x), self.headers) + result["status"] = from_int(self.status) + if self.body_base64 is not None: + result["bodyBase64"] = from_union([from_str, from_none], self.body_base64) + if self.body_text is not None: + result["bodyText"] = from_union([from_str, from_none], self.body_text) + if self.error is not None: + result["error"] = from_union([lambda x: to_class(LlmInferenceHTTPRequestError, x), from_none], self.error) + if self.status_text is not None: + result["statusText"] = from_union([from_str, from_none], self.status_text) + return result + +@dataclass +class LlmInferenceHTTPStreamStartResult: + """The response head. After returning, the SDK client pushes body chunks via + llmInference.streamChunk and signals completion (or transport error) via + llmInference.streamEnd. + """ + headers: dict[str, list[str]] + """HTTP headers as a map from lowercased header name to a list of values. Multi-valued + headers (e.g. Set-Cookie) preserve all values. + """ + status: int + """HTTP status code.""" + + error: LlmInferenceHTTPStreamStartError | None = None + """Set when the SDK client could not even begin the stream (transport-level failure). When + error is set the runtime raises an APIConnectionError and ignores status/headers. + """ + status_text: str | None = None + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPStreamStartResult': + assert isinstance(obj, dict) + headers = from_dict(lambda x: from_list(from_str, x), obj.get("headers")) + status = from_int(obj.get("status")) + error = from_union([LlmInferenceHTTPStreamStartError.from_dict, from_none], obj.get("error")) + status_text = from_union([from_str, from_none], obj.get("statusText")) + return LlmInferenceHTTPStreamStartResult(headers, status, error, status_text) + + def to_dict(self) -> dict: + result: dict = {} + result["headers"] = from_dict(lambda x: from_list(from_str, x), self.headers) + result["status"] = from_int(self.status) + if self.error is not None: + result["error"] = from_union([lambda x: to_class(LlmInferenceHTTPStreamStartError, x), from_none], self.error) + if self.status_text is not None: + result["statusText"] = from_union([from_str, from_none], self.status_text) + return result + # Experimental: this type is part of an experimental API and may change or be removed. @dataclass class SessionContext: @@ -10303,6 +10606,10 @@ class MCPServerConfigStdio: cwd: str | None = None """Working directory for the Stdio MCP server process.""" + defer_tools: MCPServerConfigDeferTools | None = None + """Controls if tools provided by this server can be loaded on demand via tool search (auto) + or always included in the initial tool list (never) + """ env: dict[str, str] | None = None """Environment variables to pass to the Stdio MCP server process.""" @@ -10330,13 +10637,14 @@ def from_dict(obj: Any) -> 'MCPServerConfigStdio': args = from_union([lambda x: from_list(from_str, x), from_none], obj.get("args")) auth = from_union([from_bool, MCPServerAuthConfigRedirectPort.from_dict, from_none], obj.get("auth")) cwd = from_union([from_str, from_none], obj.get("cwd")) + defer_tools = from_union([MCPServerConfigDeferTools, from_none], obj.get("deferTools")) env = from_union([lambda x: from_dict(from_str, x), from_none], obj.get("env")) filter_mapping = from_union([lambda x: from_dict(ContentFilterMode, x), ContentFilterMode, from_none], obj.get("filterMapping")) is_default_server = from_union([from_bool, from_none], obj.get("isDefaultServer")) oidc = from_union([from_bool, MCPServerAuthConfigRedirectPort.from_dict, from_none], obj.get("oidc")) timeout = from_union([from_int, from_none], obj.get("timeout")) tools = from_union([lambda x: from_list(from_str, x), from_none], obj.get("tools")) - return MCPServerConfigStdio(command, args, auth, cwd, env, filter_mapping, is_default_server, oidc, timeout, tools) + return MCPServerConfigStdio(command, args, auth, cwd, defer_tools, env, filter_mapping, is_default_server, oidc, timeout, tools) def to_dict(self) -> dict: result: dict = {} @@ -10347,6 +10655,8 @@ def to_dict(self) -> dict: result["auth"] = from_union([from_bool, lambda x: to_class(MCPServerAuthConfigRedirectPort, x), from_none], self.auth) if self.cwd is not None: result["cwd"] = from_union([from_str, from_none], self.cwd) + if self.defer_tools is not None: + result["deferTools"] = from_union([lambda x: to_enum(MCPServerConfigDeferTools, x), from_none], self.defer_tools) if self.env is not None: result["env"] = from_union([lambda x: from_dict(from_str, x), from_none], self.env) if self.filter_mapping is not None: @@ -10381,6 +10691,10 @@ class MCPServerConfig: cwd: str | None = None """Working directory for the Stdio MCP server process.""" + defer_tools: MCPServerConfigDeferTools | None = None + """Controls if tools provided by this server can be loaded on demand via tool search (auto) + or always included in the initial tool list (never) + """ env: dict[str, str] | None = None """Environment variables to pass to the Stdio MCP server process.""" @@ -10426,6 +10740,7 @@ def from_dict(obj: Any) -> 'MCPServerConfig': auth = from_union([from_bool, MCPServerAuthConfigRedirectPort.from_dict, from_none], obj.get("auth")) command = from_union([from_str, from_none], obj.get("command")) cwd = from_union([from_str, from_none], obj.get("cwd")) + defer_tools = from_union([MCPServerConfigDeferTools, from_none], obj.get("deferTools")) env = from_union([lambda x: from_dict(from_str, x), from_none], obj.get("env")) filter_mapping = from_union([lambda x: from_dict(ContentFilterMode, x), ContentFilterMode, from_none], obj.get("filterMapping")) is_default_server = from_union([from_bool, from_none], obj.get("isDefaultServer")) @@ -10438,7 +10753,7 @@ def from_dict(obj: Any) -> 'MCPServerConfig': oauth_public_client = from_union([from_bool, from_none], obj.get("oauthPublicClient")) type = from_union([MCPServerConfigHTTPType, from_none], obj.get("type")) url = from_union([from_str, from_none], obj.get("url")) - return MCPServerConfig(args, auth, command, cwd, env, filter_mapping, is_default_server, oidc, timeout, tools, headers, oauth_client_id, oauth_grant_type, oauth_public_client, type, url) + return MCPServerConfig(args, auth, command, cwd, defer_tools, env, filter_mapping, is_default_server, oidc, timeout, tools, headers, oauth_client_id, oauth_grant_type, oauth_public_client, type, url) def to_dict(self) -> dict: result: dict = {} @@ -10450,6 +10765,8 @@ def to_dict(self) -> dict: result["command"] = from_union([from_str, from_none], self.command) if self.cwd is not None: result["cwd"] = from_union([from_str, from_none], self.cwd) + if self.defer_tools is not None: + result["deferTools"] = from_union([lambda x: to_enum(MCPServerConfigDeferTools, x), from_none], self.defer_tools) if self.env is not None: result["env"] = from_union([lambda x: from_dict(from_str, x), from_none], self.env) if self.filter_mapping is not None: @@ -10486,6 +10803,10 @@ class MCPServerConfigHTTP: auth: bool | MCPServerAuthConfigRedirectPort | None = None """Set to `true` to use defaults, or provide an object with additional auth or OIDC settings.""" + defer_tools: MCPServerConfigDeferTools | None = None + """Controls if tools provided by this server can be loaded on demand via tool search (auto) + or always included in the initial tool list (never) + """ filter_mapping: dict[str, ContentFilterMode] | ContentFilterMode | None = None """Content filtering mode to apply to all tools, or a map of tool name to content filtering mode. @@ -10523,6 +10844,7 @@ def from_dict(obj: Any) -> 'MCPServerConfigHTTP': assert isinstance(obj, dict) url = from_str(obj.get("url")) auth = from_union([from_bool, MCPServerAuthConfigRedirectPort.from_dict, from_none], obj.get("auth")) + defer_tools = from_union([MCPServerConfigDeferTools, from_none], obj.get("deferTools")) filter_mapping = from_union([lambda x: from_dict(ContentFilterMode, x), ContentFilterMode, from_none], obj.get("filterMapping")) headers = from_union([lambda x: from_dict(from_str, x), from_none], obj.get("headers")) is_default_server = from_union([from_bool, from_none], obj.get("isDefaultServer")) @@ -10533,13 +10855,15 @@ def from_dict(obj: Any) -> 'MCPServerConfigHTTP': timeout = from_union([from_int, from_none], obj.get("timeout")) tools = from_union([lambda x: from_list(from_str, x), from_none], obj.get("tools")) type = from_union([MCPServerConfigHTTPType, from_none], obj.get("type")) - return MCPServerConfigHTTP(url, auth, filter_mapping, headers, is_default_server, oauth_client_id, oauth_grant_type, oauth_public_client, oidc, timeout, tools, type) + return MCPServerConfigHTTP(url, auth, defer_tools, filter_mapping, headers, is_default_server, oauth_client_id, oauth_grant_type, oauth_public_client, oidc, timeout, tools, type) def to_dict(self) -> dict: result: dict = {} result["url"] = from_str(self.url) if self.auth is not None: result["auth"] = from_union([from_bool, lambda x: to_class(MCPServerAuthConfigRedirectPort, x), from_none], self.auth) + if self.defer_tools is not None: + result["deferTools"] = from_union([lambda x: to_enum(MCPServerConfigDeferTools, x), from_none], self.defer_tools) if self.filter_mapping is not None: result["filterMapping"] = from_union([lambda x: from_dict(lambda x: to_enum(ContentFilterMode, x), x), lambda x: to_enum(ContentFilterMode, x), from_none], self.filter_mapping) if self.headers is not None: @@ -18940,6 +19264,166 @@ def to_dict(self) -> dict: result["namespacedName"] = from_union([from_str, from_none], self.namespaced_name) return result +@dataclass +class LlmInferenceRequestMetadata: + """Metadata describing an intercepted LLM HTTP request.""" + + endpoint_kind: LlmInferenceRequestMetadataEndpointKind + """What kind of model-layer endpoint this is.""" + + provider_type: LlmInferenceRequestMetadataProviderType + """Logical model provider this request targets.""" + + transport: LlmInferenceRequestMetadataTransport + """Transport kind. v1 implements http only.""" + + model_id: str | None = None + """Model identifier, when known.""" + + wire_api: LlmInferenceRequestMetadataWireAPI | None = None + """Wire API shape, when this is an inference request.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceRequestMetadata': + assert isinstance(obj, dict) + endpoint_kind = LlmInferenceRequestMetadataEndpointKind(obj.get("endpointKind")) + provider_type = LlmInferenceRequestMetadataProviderType(obj.get("providerType")) + transport = LlmInferenceRequestMetadataTransport(obj.get("transport")) + model_id = from_union([from_str, from_none], obj.get("modelId")) + wire_api = from_union([LlmInferenceRequestMetadataWireAPI, from_none], obj.get("wireApi")) + return LlmInferenceRequestMetadata(endpoint_kind, provider_type, transport, model_id, wire_api) + + def to_dict(self) -> dict: + result: dict = {} + result["endpointKind"] = to_enum(LlmInferenceRequestMetadataEndpointKind, self.endpoint_kind) + result["providerType"] = to_enum(LlmInferenceRequestMetadataProviderType, self.provider_type) + result["transport"] = to_enum(LlmInferenceRequestMetadataTransport, self.transport) + if self.model_id is not None: + result["modelId"] = from_union([from_str, from_none], self.model_id) + if self.wire_api is not None: + result["wireApi"] = from_union([lambda x: to_enum(LlmInferenceRequestMetadataWireAPI, x), from_none], self.wire_api) + return result + +@dataclass +class LlmInferenceHTTPRequestRequest: + """An outbound model-layer HTTP request the runtime would otherwise have issued itself.""" + + headers: dict[str, list[str]] + """HTTP headers as a map from lowercased header name to a list of values. Multi-valued + headers (e.g. Set-Cookie) preserve all values. + """ + metadata: LlmInferenceRequestMetadata + """Metadata describing an intercepted LLM HTTP request.""" + + method: str + """HTTP method, e.g. GET, POST.""" + + request_id: str + """Opaque runtime-minted id, unique per request. Useful for client-side logging.""" + + url: str + """Absolute request URL.""" + + body_base64: str | None = None + """Request body as base64-encoded bytes. Set instead of bodyText when the body is binary.""" + + body_text: str | None = None + """Request body as a UTF-8 string. Set when binaryBody is absent or false.""" + + session_id: str | None = None + """Id of the runtime session that triggered this request, when one is in scope. Absent for + requests issued outside any session (e.g. startup model-catalog or capability + resolution). This is a payload field — not a dispatch key — because the client-global API + is registered process-wide rather than per session. + """ + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPRequestRequest': + assert isinstance(obj, dict) + headers = from_dict(lambda x: from_list(from_str, x), obj.get("headers")) + metadata = LlmInferenceRequestMetadata.from_dict(obj.get("metadata")) + method = from_str(obj.get("method")) + request_id = from_str(obj.get("requestId")) + url = from_str(obj.get("url")) + body_base64 = from_union([from_str, from_none], obj.get("bodyBase64")) + body_text = from_union([from_str, from_none], obj.get("bodyText")) + session_id = from_union([from_str, from_none], obj.get("sessionId")) + return LlmInferenceHTTPRequestRequest(headers, metadata, method, request_id, url, body_base64, body_text, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["headers"] = from_dict(lambda x: from_list(from_str, x), self.headers) + result["metadata"] = to_class(LlmInferenceRequestMetadata, self.metadata) + result["method"] = from_str(self.method) + result["requestId"] = from_str(self.request_id) + result["url"] = from_str(self.url) + if self.body_base64 is not None: + result["bodyBase64"] = from_union([from_str, from_none], self.body_base64) + if self.body_text is not None: + result["bodyText"] = from_union([from_str, from_none], self.body_text) + if self.session_id is not None: + result["sessionId"] = from_union([from_str, from_none], self.session_id) + return result + +@dataclass +class LlmInferenceHTTPStreamStartRequest: + """An outbound streaming model-layer HTTP request.""" + + headers: dict[str, list[str]] + """HTTP headers as a map from lowercased header name to a list of values. Multi-valued + headers (e.g. Set-Cookie) preserve all values. + """ + metadata: LlmInferenceRequestMetadata + """Metadata describing an intercepted LLM HTTP request.""" + + method: str + """HTTP method.""" + + request_id: str + """Opaque runtime-minted id, unique per request.""" + + stream_token: int + """Stream identifier. The SDK client passes this exact value back on every + llmInference.streamChunk / streamEnd call to correlate pushed chunks with this request. + """ + url: str + """Absolute request URL.""" + + body_base64: str | None = None + body_text: str | None = None + session_id: str | None = None + """Originating session id, when known.""" + + @staticmethod + def from_dict(obj: Any) -> 'LlmInferenceHTTPStreamStartRequest': + assert isinstance(obj, dict) + headers = from_dict(lambda x: from_list(from_str, x), obj.get("headers")) + metadata = LlmInferenceRequestMetadata.from_dict(obj.get("metadata")) + method = from_str(obj.get("method")) + request_id = from_str(obj.get("requestId")) + stream_token = from_int(obj.get("streamToken")) + url = from_str(obj.get("url")) + body_base64 = from_union([from_str, from_none], obj.get("bodyBase64")) + body_text = from_union([from_str, from_none], obj.get("bodyText")) + session_id = from_union([from_str, from_none], obj.get("sessionId")) + return LlmInferenceHTTPStreamStartRequest(headers, metadata, method, request_id, stream_token, url, body_base64, body_text, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["headers"] = from_dict(lambda x: from_list(from_str, x), self.headers) + result["metadata"] = to_class(LlmInferenceRequestMetadata, self.metadata) + result["method"] = from_str(self.method) + result["requestId"] = from_str(self.request_id) + result["streamToken"] = from_int(self.stream_token) + result["url"] = from_str(self.url) + if self.body_base64 is not None: + result["bodyBase64"] = from_union([from_str, from_none], self.body_base64) + if self.body_text is not None: + result["bodyText"] = from_union([from_str, from_none], self.body_text) + if self.session_id is not None: + result["sessionId"] = from_union([from_str, from_none], self.session_id) + return result + # Experimental: this type is part of an experimental API and may change or be removed. @dataclass class MCPExecuteSamplingParams: @@ -19738,6 +20222,24 @@ class RPC: instruction_source: InstructionSource instruction_source_location: InstructionSourceLocation instruction_source_type: InstructionSourceType + llm_inference_headers: dict[str, list[str]] + llm_inference_http_request_error: LlmInferenceHTTPRequestError + llm_inference_http_request_request: LlmInferenceHTTPRequestRequest + llm_inference_http_request_result: LlmInferenceHTTPRequestResult + llm_inference_http_stream_start_error: LlmInferenceHTTPStreamStartError + llm_inference_http_stream_start_request: LlmInferenceHTTPStreamStartRequest + llm_inference_http_stream_start_result: LlmInferenceHTTPStreamStartResult + llm_inference_request_metadata: LlmInferenceRequestMetadata + llm_inference_request_metadata_endpoint_kind: LlmInferenceRequestMetadataEndpointKind + llm_inference_request_metadata_provider_type: LlmInferenceRequestMetadataProviderType + llm_inference_request_metadata_transport: LlmInferenceRequestMetadataTransport + llm_inference_request_metadata_wire_api: LlmInferenceRequestMetadataWireAPI + llm_inference_set_provider_request: LlmInferenceSetProviderRequest + llm_inference_set_provider_result: LlmInferenceSetProviderResult + llm_inference_stream_chunk_request: LlmInferenceStreamChunkRequest + llm_inference_stream_chunk_result: LlmInferenceStreamChunkResult + llm_inference_stream_end_request: LlmInferenceStreamEndRequest + llm_inference_stream_end_result: LlmInferenceStreamEndResult local_session_metadata_value: LocalSessionMetadataValue log_request: LogRequest log_result: LogResult @@ -19810,6 +20312,7 @@ class RPC: mcp_server_auth_config: bool | MCPServerAuthConfigRedirectPort mcp_server_auth_config_redirect_port: MCPServerAuthConfigRedirectPort mcp_server_config: MCPServerConfig + mcp_server_config_defer_tools: MCPServerConfigDeferTools mcp_server_config_http: MCPServerConfigHTTP mcp_server_config_http_oauth_grant_type: MCPServerConfigHTTPOauthGrantType mcp_server_config_http_type: MCPServerConfigHTTPType @@ -20464,6 +20967,24 @@ def from_dict(obj: Any) -> 'RPC': instruction_source = InstructionSource.from_dict(obj.get("InstructionSource")) instruction_source_location = InstructionSourceLocation(obj.get("InstructionSourceLocation")) instruction_source_type = InstructionSourceType(obj.get("InstructionSourceType")) + llm_inference_headers = from_dict(lambda x: from_list(from_str, x), obj.get("LlmInferenceHeaders")) + llm_inference_http_request_error = LlmInferenceHTTPRequestError.from_dict(obj.get("LlmInferenceHttpRequestError")) + llm_inference_http_request_request = LlmInferenceHTTPRequestRequest.from_dict(obj.get("LlmInferenceHttpRequestRequest")) + llm_inference_http_request_result = LlmInferenceHTTPRequestResult.from_dict(obj.get("LlmInferenceHttpRequestResult")) + llm_inference_http_stream_start_error = LlmInferenceHTTPStreamStartError.from_dict(obj.get("LlmInferenceHttpStreamStartError")) + llm_inference_http_stream_start_request = LlmInferenceHTTPStreamStartRequest.from_dict(obj.get("LlmInferenceHttpStreamStartRequest")) + llm_inference_http_stream_start_result = LlmInferenceHTTPStreamStartResult.from_dict(obj.get("LlmInferenceHttpStreamStartResult")) + llm_inference_request_metadata = LlmInferenceRequestMetadata.from_dict(obj.get("LlmInferenceRequestMetadata")) + llm_inference_request_metadata_endpoint_kind = LlmInferenceRequestMetadataEndpointKind(obj.get("LlmInferenceRequestMetadataEndpointKind")) + llm_inference_request_metadata_provider_type = LlmInferenceRequestMetadataProviderType(obj.get("LlmInferenceRequestMetadataProviderType")) + llm_inference_request_metadata_transport = LlmInferenceRequestMetadataTransport(obj.get("LlmInferenceRequestMetadataTransport")) + llm_inference_request_metadata_wire_api = LlmInferenceRequestMetadataWireAPI(obj.get("LlmInferenceRequestMetadataWireApi")) + llm_inference_set_provider_request = LlmInferenceSetProviderRequest.from_dict(obj.get("LlmInferenceSetProviderRequest")) + llm_inference_set_provider_result = LlmInferenceSetProviderResult.from_dict(obj.get("LlmInferenceSetProviderResult")) + llm_inference_stream_chunk_request = LlmInferenceStreamChunkRequest.from_dict(obj.get("LlmInferenceStreamChunkRequest")) + llm_inference_stream_chunk_result = LlmInferenceStreamChunkResult.from_dict(obj.get("LlmInferenceStreamChunkResult")) + llm_inference_stream_end_request = LlmInferenceStreamEndRequest.from_dict(obj.get("LlmInferenceStreamEndRequest")) + llm_inference_stream_end_result = LlmInferenceStreamEndResult.from_dict(obj.get("LlmInferenceStreamEndResult")) local_session_metadata_value = LocalSessionMetadataValue.from_dict(obj.get("LocalSessionMetadataValue")) log_request = LogRequest.from_dict(obj.get("LogRequest")) log_result = LogResult.from_dict(obj.get("LogResult")) @@ -20536,6 +21057,7 @@ def from_dict(obj: Any) -> 'RPC': mcp_server_auth_config = from_union([from_bool, MCPServerAuthConfigRedirectPort.from_dict], obj.get("McpServerAuthConfig")) mcp_server_auth_config_redirect_port = MCPServerAuthConfigRedirectPort.from_dict(obj.get("McpServerAuthConfigRedirectPort")) mcp_server_config = MCPServerConfig.from_dict(obj.get("McpServerConfig")) + mcp_server_config_defer_tools = MCPServerConfigDeferTools(obj.get("McpServerConfigDeferTools")) mcp_server_config_http = MCPServerConfigHTTP.from_dict(obj.get("McpServerConfigHttp")) mcp_server_config_http_oauth_grant_type = MCPServerConfigHTTPOauthGrantType(obj.get("McpServerConfigHttpOauthGrantType")) mcp_server_config_http_type = MCPServerConfigHTTPType(obj.get("McpServerConfigHttpType")) @@ -21046,7 +21568,7 @@ def from_dict(obj: Any) -> 'RPC': subagent_settings = from_union([SubagentSettings.from_dict, from_none], obj.get("SubagentSettings")) task_progress = from_union([TaskProgress.from_dict, from_none], obj.get("TaskProgress")) workspace_summary = from_union([WorkspaceSummary.from_dict, from_none], obj.get("WorkspaceSummary")) - return RPC(abort_request, abort_result, account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_get_current_result, agent_info, agent_info_source, agent_list, agent_registry_live_target_entry, agent_registry_live_target_entry_attention_kind, agent_registry_live_target_entry_kind, agent_registry_live_target_entry_last_terminal_event, agent_registry_live_target_entry_status, agent_registry_log_capture, agent_registry_log_capture_open_error_reason, agent_registry_spawn_error, agent_registry_spawn_permission_mode, agent_registry_spawn_registry_timeout, agent_registry_spawn_request, agent_registry_spawn_result, agent_registry_spawn_spawned, agent_registry_spawn_validation_error, agent_registry_spawn_validation_error_field, agent_registry_spawn_validation_error_reason, agent_reload_result, agents_discover_request, agent_select_request, agent_select_result, allow_all_permission_set_result, allow_all_permission_state, api_key_auth_info, auth_info, auth_info_type, cancel_user_requested_shell_command_result, canvas_action, canvas_action_invoke_request, canvas_action_invoke_result, canvas_close_request, canvas_host_context, canvas_host_context_capabilities, canvas_instance_availability, canvas_json_schema, canvas_list, canvas_list_open_result, canvas_open_request, canvas_provider_close_request, canvas_provider_invoke_action_request, canvas_provider_open_request, canvas_provider_open_result, canvas_session_context, command_list, commands_handle_pending_command_request, commands_handle_pending_command_result, commands_invoke_request, commands_list_request, commands_respond_to_queued_command_request, commands_respond_to_queued_command_result, configure_session_extensions_params, connected_remote_session_metadata, connected_remote_session_metadata_kind, connected_remote_session_metadata_repository, connect_remote_session_params, connect_request, connect_result, content_filter_mode, copilot_api_token_auth_info, copilot_user_response, copilot_user_response_endpoints, copilot_user_response_quota_snapshots, copilot_user_response_quota_snapshots_chat, copilot_user_response_quota_snapshots_completions, copilot_user_response_quota_snapshots_premium_interactions, current_model, current_tool_metadata, discovered_canvas, discovered_mcp_server, discovered_mcp_server_type, enqueue_command_params, enqueue_command_result, env_auth_info, event_log_read_request, event_log_release_interest_result, event_log_tail_result, event_log_types, events_agent_scope, events_cursor_status, events_read_result, execute_command_params, execute_command_result, extension, extension_context_push_input, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_binary_results_for_llm, external_tool_text_result_for_llm_binary_results_for_llm_type, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, fleet_start_request, fleet_start_result, folder_trust_add_params, folder_trust_check_params, folder_trust_check_result, gh_cli_auth_info, handle_pending_tool_call_request, handle_pending_tool_call_result, history_abort_manual_compaction_result, history_cancel_background_compaction_result, history_compact_context_window, history_compact_request, history_compact_result, history_summarize_for_handoff_result, history_truncate_request, history_truncate_result, hmac_auth_info, installed_plugin, installed_plugin_info, installed_plugin_source, installed_plugin_source_git_hub, installed_plugin_source_local, installed_plugin_source_url, instructions_discover_request, instructions_get_sources_result, instruction_source, instruction_source_location, instruction_source_type, local_session_metadata_value, log_request, log_result, lsp_initialize_request, marketplace_add_result, marketplace_browse_result, marketplace_info, marketplace_list_result, marketplace_plugin_info, marketplace_refresh_entry, marketplace_refresh_result, marketplace_remove_result, mcp_allowed_server, mcp_apps_call_tool_request, mcp_apps_diagnose_capability, mcp_apps_diagnose_request, mcp_apps_diagnose_result, mcp_apps_diagnose_server, mcp_apps_host_context, mcp_apps_host_context_details, mcp_apps_host_context_details_available_display_mode, mcp_apps_host_context_details_display_mode, mcp_apps_host_context_details_platform, mcp_apps_host_context_details_theme, mcp_apps_list_tools_request, mcp_apps_list_tools_result, mcp_apps_read_resource_request, mcp_apps_read_resource_result, mcp_apps_resource_content, mcp_apps_set_host_context_details, mcp_apps_set_host_context_details_available_display_mode, mcp_apps_set_host_context_details_display_mode, mcp_apps_set_host_context_details_platform, mcp_apps_set_host_context_details_theme, mcp_apps_set_host_context_request, mcp_cancel_sampling_execution_params, mcp_cancel_sampling_execution_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_configure_git_hub_request, mcp_configure_git_hub_result, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_execute_sampling_params, mcp_execute_sampling_request, mcp_execute_sampling_result, mcp_filtered_server, mcp_host_state, mcp_is_server_running_request, mcp_is_server_running_result, mcp_list_tools_request, mcp_list_tools_result, mcp_oauth_login_request, mcp_oauth_login_result, mcp_oauth_respond_request, mcp_oauth_respond_result, mcp_register_external_client_request, mcp_reload_with_config_request, mcp_remove_git_hub_result, mcp_restart_server_request, mcp_sampling_execution_action, mcp_sampling_execution_result, mcp_server, mcp_server_auth_config, mcp_server_auth_config_redirect_port, mcp_server_config, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_stdio, mcp_server_failure_info, mcp_server_list, mcp_server_needs_auth_info, mcp_set_env_value_mode_details, mcp_set_env_value_mode_params, mcp_set_env_value_mode_result, mcp_start_server_request, mcp_start_servers_result, mcp_stop_server_request, mcp_tools, mcp_unregister_external_client_request, memory_configuration, metadata_context_info_request, metadata_context_info_result, metadata_is_processing_result, metadata_recompute_context_tokens_request, metadata_recompute_context_tokens_result, metadata_record_context_change_request, metadata_record_context_change_result, metadata_set_working_directory_request, metadata_set_working_directory_result, metadata_snapshot_current_mode, metadata_snapshot_remote_metadata, metadata_snapshot_remote_metadata_repository, metadata_snapshot_remote_metadata_task_type, model, model_billing, model_billing_token_prices, model_billing_token_prices_long_context, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_list_request, model_picker_category, model_picker_price_category, model_policy, model_policy_state, model_set_reasoning_effort_request, model_set_reasoning_effort_result, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, name_get_result, name_set_auto_request, name_set_auto_result, name_set_request, open_canvas_instance, options_update_additional_content_exclusion_policy, options_update_additional_content_exclusion_policy_rule, options_update_additional_content_exclusion_policy_rule_source, options_update_additional_content_exclusion_policy_scope, options_update_context_tier, options_update_env_value_mode, options_update_reasoning_summary, options_update_tool_filter_precedence, pending_permission_request, pending_permission_request_list, permission_decision, permission_decision_approved, permission_decision_approved_for_location, permission_decision_approved_for_session, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_extension_management, permission_decision_approve_for_location_approval_extension_permission_access, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_extension_management, permission_decision_approve_for_session_approval_extension_permission_access, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_cancelled, permission_decision_denied_by_content_exclusion_policy, permission_decision_denied_by_permission_request_hook, permission_decision_denied_by_rules, permission_decision_denied_interactively_by_user, permission_decision_denied_no_approval_rule_and_could_not_request_from_user, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_location_add_tool_approval_params, permission_location_apply_params, permission_location_apply_result, permission_location_resolve_params, permission_location_resolve_result, permission_location_type, permission_paths_add_params, permission_paths_allowed_check_params, permission_paths_allowed_check_result, permission_paths_config, permission_paths_list, permission_paths_update_primary_params, permission_paths_workspace_check_params, permission_paths_workspace_check_result, permission_prompt_shown_notification, permission_request_result, permission_rules_set, permissions_configure_additional_content_exclusion_policy, permissions_configure_additional_content_exclusion_policy_rule, permissions_configure_additional_content_exclusion_policy_rule_source, permissions_configure_additional_content_exclusion_policy_scope, permissions_configure_params, permissions_configure_result, permissions_folder_trust_add_trusted_result, permissions_get_allow_all_request, permissions_locations_add_tool_approval_details, permissions_locations_add_tool_approval_details_commands, permissions_locations_add_tool_approval_details_custom_tool, permissions_locations_add_tool_approval_details_extension_management, permissions_locations_add_tool_approval_details_extension_permission_access, permissions_locations_add_tool_approval_details_mcp, permissions_locations_add_tool_approval_details_mcp_sampling, permissions_locations_add_tool_approval_details_memory, permissions_locations_add_tool_approval_details_read, permissions_locations_add_tool_approval_details_write, permissions_locations_add_tool_approval_result, permissions_modify_rules_params, permissions_modify_rules_result, permissions_modify_rules_scope, permissions_notify_prompt_shown_result, permissions_paths_add_result, permissions_paths_list_request, permissions_paths_update_primary_result, permissions_pending_requests_request, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_allow_all_request, permissions_set_allow_all_source, permissions_set_approve_all_request, permissions_set_approve_all_result, permissions_set_approve_all_source, permissions_set_required_request, permissions_set_required_result, permissions_urls_set_unrestricted_mode_result, permission_urls_config, permission_urls_set_unrestricted_mode_params, ping_request, ping_result, plan_read_result, plan_read_sql_todos_result, plan_read_sql_todos_with_dependencies_result, plan_sql_todo_dependency, plan_sql_todos_row, plan_update_request, plugin, plugin_install_result, plugin_list, plugin_list_result, plugins_disable_request, plugins_enable_request, plugins_install_request, plugins_marketplaces_add_request, plugins_marketplaces_browse_request, plugins_marketplaces_refresh_request, plugins_marketplaces_remove_request, plugins_reload_request, plugins_uninstall_request, plugins_update_request, plugin_update_all_entry, plugin_update_all_result, plugin_update_result, poll_spawned_sessions_result, provider_config, provider_config_azure, provider_config_type, provider_config_wire_api, provider_endpoint, provider_endpoint_type, provider_endpoint_wire_api, provider_get_endpoint_request, provider_session_token, push_attachment, push_attachment_blob, push_attachment_directory, push_attachment_file, push_attachment_file_line_range, push_attachment_git_hub_reference, push_attachment_git_hub_reference_type, push_attachment_selection, push_attachment_selection_details, push_attachment_selection_details_end, push_attachment_selection_details_start, queued_command_handled, queued_command_not_handled, queued_command_result, queue_pending_items, queue_pending_items_kind, queue_pending_items_result, queue_remove_most_recent_result, register_event_interest_params, register_event_interest_result, register_extension_tools_params, register_extension_tools_result, release_event_interest_params, remote_control_config, remote_control_config_existing_mc_session, remote_control_status, remote_control_status_active, remote_control_status_connecting, remote_control_status_error, remote_control_status_off, remote_control_status_result, remote_control_stop_result, remote_control_transfer_result, remote_enable_request, remote_enable_result, remote_notify_steerable_changed_request, remote_notify_steerable_changed_result, remote_session_connection_result, remote_session_metadata_repository, remote_session_metadata_task_type, remote_session_metadata_value, remote_session_mode, remote_session_repository, sandbox_config, sandbox_config_user_policy, sandbox_config_user_policy_experimental, sandbox_config_user_policy_experimental_seatbelt, sandbox_config_user_policy_filesystem, sandbox_config_user_policy_network, schedule_entry, schedule_list, schedule_stop_request, schedule_stop_result, secrets_add_filter_values_request, secrets_add_filter_values_result, send_agent_mode, send_attachments_to_message_params, send_mode, send_request, send_result, server_agent_list, server_instruction_source_list, server_skill, server_skill_list, session_activity, session_auth_status, session_bulk_delete_result, session_capability, session_context, session_context_host_type, session_enrich_metadata_result, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_capabilities, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_sqlite_exists_request, session_fs_sqlite_exists_result, session_fs_sqlite_query_request, session_fs_sqlite_query_result, session_fs_sqlite_query_type, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_installed_plugin, session_installed_plugin_source, session_installed_plugin_source_git_hub, session_installed_plugin_source_local, session_installed_plugin_source_url, session_list, session_list_entry, session_list_filter, session_load_deferred_repo_hooks_result, session_log_level, session_mcp_apps_call_tool_result, session_metadata_snapshot, session_mode, session_model_list, session_open_options, session_open_options_additional_content_exclusion_policy, session_open_options_additional_content_exclusion_policy_rule, session_open_options_additional_content_exclusion_policy_rule_source, session_open_options_additional_content_exclusion_policy_scope, session_open_options_env_value_mode, session_open_options_reasoning_summary, session_open_params, session_open_result, session_prune_result, sessions_bulk_delete_request, sessions_check_in_use_request, sessions_check_in_use_result, sessions_close_request, sessions_close_result, sessions_enrich_metadata_request, session_set_credentials_params, session_set_credentials_result, sessions_find_by_prefix_request, sessions_find_by_prefix_result, sessions_find_by_task_id_request, sessions_find_by_task_id_result, sessions_fork_request, sessions_fork_result, sessions_get_board_entry_count_request, sessions_get_board_entry_count_result, sessions_get_event_file_path_request, sessions_get_event_file_path_result, sessions_get_last_for_context_request, sessions_get_last_for_context_result, sessions_get_persisted_remote_steerable_request, sessions_get_persisted_remote_steerable_result, session_sizes, sessions_list_request, sessions_load_deferred_repo_hooks_request, sessions_open_attach, sessions_open_cloud, sessions_open_create, sessions_open_handoff, sessions_open_handoff_task_type, sessions_open_progress, sessions_open_progress_status, sessions_open_progress_step, sessions_open_remote, sessions_open_resume, sessions_open_resume_last, sessions_open_status, session_source, sessions_poll_spawned_sessions_event, sessions_poll_spawned_sessions_request, sessions_prune_old_request, sessions_register_extension_tools_on_session_options, sessions_release_lock_request, sessions_release_lock_result, sessions_reload_plugin_hooks_request, sessions_reload_plugin_hooks_result, sessions_save_request, sessions_save_result, sessions_set_additional_plugins_request, sessions_set_additional_plugins_result, sessions_set_remote_control_steering_request, sessions_start_remote_control_request, sessions_stop_remote_control_request, sessions_transfer_remote_control_request, session_telemetry_engagement, session_update_options_params, session_update_options_result, session_working_directory_context, session_working_directory_context_host_type, shell_cancel_user_requested_request, shell_exec_request, shell_exec_result, shell_execute_user_requested_request, shell_kill_request, shell_kill_result, shell_kill_signal, shutdown_request, skill, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, skills_get_invoked_result, skills_invoked_skill, skills_load_diagnostics, slash_command_agent_prompt_result, slash_command_completed_result, slash_command_info, slash_command_input, slash_command_input_completion, slash_command_invocation_result, slash_command_kind, slash_command_select_subcommand_option, slash_command_select_subcommand_result, slash_command_text_result, subagent_settings_entry, subagent_settings_entry_context_tier, task_agent_info, task_agent_progress, task_execution_mode, task_info, task_list, task_progress_line, tasks_cancel_request, tasks_cancel_result, tasks_get_current_promotable_result, tasks_get_progress_request, tasks_get_progress_result, task_shell_info, task_shell_info_attachment_mode, task_shell_progress, tasks_promote_current_to_background_result, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_refresh_result, tasks_remove_request, tasks_remove_result, tasks_send_message_request, tasks_send_message_result, tasks_start_agent_request, tasks_start_agent_result, task_status, tasks_wait_for_pending_result, telemetry_set_feature_overrides_request, token_auth_info, tool, tool_list, tools_get_current_metadata_result, tools_initialize_and_validate_result, tools_list_request, tools_update_subagent_settings_result, ui_auto_mode_switch_response, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_ephemeral_query_request, ui_ephemeral_query_result, ui_exit_plan_mode_action, ui_exit_plan_mode_response, ui_handle_pending_auto_mode_switch_request, ui_handle_pending_elicitation_request, ui_handle_pending_exit_plan_mode_request, ui_handle_pending_result, ui_handle_pending_sampling_request, ui_handle_pending_sampling_response, ui_handle_pending_user_input_request, ui_register_direct_auto_mode_switch_handler_result, ui_unregister_direct_auto_mode_switch_handler_request, ui_unregister_direct_auto_mode_switch_handler_result, ui_user_input_response, update_subagent_settings_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, user_auth_info, user_requested_shell_command_result, workspace_diff_file_change, workspace_diff_file_change_type, workspace_diff_mode, workspace_diff_result, workspaces_checkpoints, workspaces_create_file_request, workspaces_diff_request, workspaces_get_workspace_result, workspaces_list_checkpoints_result, workspaces_list_files_result, workspaces_read_checkpoint_request, workspaces_read_checkpoint_result, workspaces_read_file_request, workspaces_read_file_result, workspaces_save_large_paste_request, workspaces_save_large_paste_result, workspace_summary_host_type, workspaces_workspace_details_host_type, session_context_info, subagent_settings, task_progress, workspace_summary) + return RPC(abort_request, abort_result, account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_get_current_result, agent_info, agent_info_source, agent_list, agent_registry_live_target_entry, agent_registry_live_target_entry_attention_kind, agent_registry_live_target_entry_kind, agent_registry_live_target_entry_last_terminal_event, agent_registry_live_target_entry_status, agent_registry_log_capture, agent_registry_log_capture_open_error_reason, agent_registry_spawn_error, agent_registry_spawn_permission_mode, agent_registry_spawn_registry_timeout, agent_registry_spawn_request, agent_registry_spawn_result, agent_registry_spawn_spawned, agent_registry_spawn_validation_error, agent_registry_spawn_validation_error_field, agent_registry_spawn_validation_error_reason, agent_reload_result, agents_discover_request, agent_select_request, agent_select_result, allow_all_permission_set_result, allow_all_permission_state, api_key_auth_info, auth_info, auth_info_type, cancel_user_requested_shell_command_result, canvas_action, canvas_action_invoke_request, canvas_action_invoke_result, canvas_close_request, canvas_host_context, canvas_host_context_capabilities, canvas_instance_availability, canvas_json_schema, canvas_list, canvas_list_open_result, canvas_open_request, canvas_provider_close_request, canvas_provider_invoke_action_request, canvas_provider_open_request, canvas_provider_open_result, canvas_session_context, command_list, commands_handle_pending_command_request, commands_handle_pending_command_result, commands_invoke_request, commands_list_request, commands_respond_to_queued_command_request, commands_respond_to_queued_command_result, configure_session_extensions_params, connected_remote_session_metadata, connected_remote_session_metadata_kind, connected_remote_session_metadata_repository, connect_remote_session_params, connect_request, connect_result, content_filter_mode, copilot_api_token_auth_info, copilot_user_response, copilot_user_response_endpoints, copilot_user_response_quota_snapshots, copilot_user_response_quota_snapshots_chat, copilot_user_response_quota_snapshots_completions, copilot_user_response_quota_snapshots_premium_interactions, current_model, current_tool_metadata, discovered_canvas, discovered_mcp_server, discovered_mcp_server_type, enqueue_command_params, enqueue_command_result, env_auth_info, event_log_read_request, event_log_release_interest_result, event_log_tail_result, event_log_types, events_agent_scope, events_cursor_status, events_read_result, execute_command_params, execute_command_result, extension, extension_context_push_input, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_binary_results_for_llm, external_tool_text_result_for_llm_binary_results_for_llm_type, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, fleet_start_request, fleet_start_result, folder_trust_add_params, folder_trust_check_params, folder_trust_check_result, gh_cli_auth_info, handle_pending_tool_call_request, handle_pending_tool_call_result, history_abort_manual_compaction_result, history_cancel_background_compaction_result, history_compact_context_window, history_compact_request, history_compact_result, history_summarize_for_handoff_result, history_truncate_request, history_truncate_result, hmac_auth_info, installed_plugin, installed_plugin_info, installed_plugin_source, installed_plugin_source_git_hub, installed_plugin_source_local, installed_plugin_source_url, instructions_discover_request, instructions_get_sources_result, instruction_source, instruction_source_location, instruction_source_type, llm_inference_headers, llm_inference_http_request_error, llm_inference_http_request_request, llm_inference_http_request_result, llm_inference_http_stream_start_error, llm_inference_http_stream_start_request, llm_inference_http_stream_start_result, llm_inference_request_metadata, llm_inference_request_metadata_endpoint_kind, llm_inference_request_metadata_provider_type, llm_inference_request_metadata_transport, llm_inference_request_metadata_wire_api, llm_inference_set_provider_request, llm_inference_set_provider_result, llm_inference_stream_chunk_request, llm_inference_stream_chunk_result, llm_inference_stream_end_request, llm_inference_stream_end_result, local_session_metadata_value, log_request, log_result, lsp_initialize_request, marketplace_add_result, marketplace_browse_result, marketplace_info, marketplace_list_result, marketplace_plugin_info, marketplace_refresh_entry, marketplace_refresh_result, marketplace_remove_result, mcp_allowed_server, mcp_apps_call_tool_request, mcp_apps_diagnose_capability, mcp_apps_diagnose_request, mcp_apps_diagnose_result, mcp_apps_diagnose_server, mcp_apps_host_context, mcp_apps_host_context_details, mcp_apps_host_context_details_available_display_mode, mcp_apps_host_context_details_display_mode, mcp_apps_host_context_details_platform, mcp_apps_host_context_details_theme, mcp_apps_list_tools_request, mcp_apps_list_tools_result, mcp_apps_read_resource_request, mcp_apps_read_resource_result, mcp_apps_resource_content, mcp_apps_set_host_context_details, mcp_apps_set_host_context_details_available_display_mode, mcp_apps_set_host_context_details_display_mode, mcp_apps_set_host_context_details_platform, mcp_apps_set_host_context_details_theme, mcp_apps_set_host_context_request, mcp_cancel_sampling_execution_params, mcp_cancel_sampling_execution_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_configure_git_hub_request, mcp_configure_git_hub_result, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_execute_sampling_params, mcp_execute_sampling_request, mcp_execute_sampling_result, mcp_filtered_server, mcp_host_state, mcp_is_server_running_request, mcp_is_server_running_result, mcp_list_tools_request, mcp_list_tools_result, mcp_oauth_login_request, mcp_oauth_login_result, mcp_oauth_respond_request, mcp_oauth_respond_result, mcp_register_external_client_request, mcp_reload_with_config_request, mcp_remove_git_hub_result, mcp_restart_server_request, mcp_sampling_execution_action, mcp_sampling_execution_result, mcp_server, mcp_server_auth_config, mcp_server_auth_config_redirect_port, mcp_server_config, mcp_server_config_defer_tools, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_stdio, mcp_server_failure_info, mcp_server_list, mcp_server_needs_auth_info, mcp_set_env_value_mode_details, mcp_set_env_value_mode_params, mcp_set_env_value_mode_result, mcp_start_server_request, mcp_start_servers_result, mcp_stop_server_request, mcp_tools, mcp_unregister_external_client_request, memory_configuration, metadata_context_info_request, metadata_context_info_result, metadata_is_processing_result, metadata_recompute_context_tokens_request, metadata_recompute_context_tokens_result, metadata_record_context_change_request, metadata_record_context_change_result, metadata_set_working_directory_request, metadata_set_working_directory_result, metadata_snapshot_current_mode, metadata_snapshot_remote_metadata, metadata_snapshot_remote_metadata_repository, metadata_snapshot_remote_metadata_task_type, model, model_billing, model_billing_token_prices, model_billing_token_prices_long_context, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_list_request, model_picker_category, model_picker_price_category, model_policy, model_policy_state, model_set_reasoning_effort_request, model_set_reasoning_effort_result, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, name_get_result, name_set_auto_request, name_set_auto_result, name_set_request, open_canvas_instance, options_update_additional_content_exclusion_policy, options_update_additional_content_exclusion_policy_rule, options_update_additional_content_exclusion_policy_rule_source, options_update_additional_content_exclusion_policy_scope, options_update_context_tier, options_update_env_value_mode, options_update_reasoning_summary, options_update_tool_filter_precedence, pending_permission_request, pending_permission_request_list, permission_decision, permission_decision_approved, permission_decision_approved_for_location, permission_decision_approved_for_session, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_extension_management, permission_decision_approve_for_location_approval_extension_permission_access, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_extension_management, permission_decision_approve_for_session_approval_extension_permission_access, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_cancelled, permission_decision_denied_by_content_exclusion_policy, permission_decision_denied_by_permission_request_hook, permission_decision_denied_by_rules, permission_decision_denied_interactively_by_user, permission_decision_denied_no_approval_rule_and_could_not_request_from_user, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_location_add_tool_approval_params, permission_location_apply_params, permission_location_apply_result, permission_location_resolve_params, permission_location_resolve_result, permission_location_type, permission_paths_add_params, permission_paths_allowed_check_params, permission_paths_allowed_check_result, permission_paths_config, permission_paths_list, permission_paths_update_primary_params, permission_paths_workspace_check_params, permission_paths_workspace_check_result, permission_prompt_shown_notification, permission_request_result, permission_rules_set, permissions_configure_additional_content_exclusion_policy, permissions_configure_additional_content_exclusion_policy_rule, permissions_configure_additional_content_exclusion_policy_rule_source, permissions_configure_additional_content_exclusion_policy_scope, permissions_configure_params, permissions_configure_result, permissions_folder_trust_add_trusted_result, permissions_get_allow_all_request, permissions_locations_add_tool_approval_details, permissions_locations_add_tool_approval_details_commands, permissions_locations_add_tool_approval_details_custom_tool, permissions_locations_add_tool_approval_details_extension_management, permissions_locations_add_tool_approval_details_extension_permission_access, permissions_locations_add_tool_approval_details_mcp, permissions_locations_add_tool_approval_details_mcp_sampling, permissions_locations_add_tool_approval_details_memory, permissions_locations_add_tool_approval_details_read, permissions_locations_add_tool_approval_details_write, permissions_locations_add_tool_approval_result, permissions_modify_rules_params, permissions_modify_rules_result, permissions_modify_rules_scope, permissions_notify_prompt_shown_result, permissions_paths_add_result, permissions_paths_list_request, permissions_paths_update_primary_result, permissions_pending_requests_request, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_allow_all_request, permissions_set_allow_all_source, permissions_set_approve_all_request, permissions_set_approve_all_result, permissions_set_approve_all_source, permissions_set_required_request, permissions_set_required_result, permissions_urls_set_unrestricted_mode_result, permission_urls_config, permission_urls_set_unrestricted_mode_params, ping_request, ping_result, plan_read_result, plan_read_sql_todos_result, plan_read_sql_todos_with_dependencies_result, plan_sql_todo_dependency, plan_sql_todos_row, plan_update_request, plugin, plugin_install_result, plugin_list, plugin_list_result, plugins_disable_request, plugins_enable_request, plugins_install_request, plugins_marketplaces_add_request, plugins_marketplaces_browse_request, plugins_marketplaces_refresh_request, plugins_marketplaces_remove_request, plugins_reload_request, plugins_uninstall_request, plugins_update_request, plugin_update_all_entry, plugin_update_all_result, plugin_update_result, poll_spawned_sessions_result, provider_config, provider_config_azure, provider_config_type, provider_config_wire_api, provider_endpoint, provider_endpoint_type, provider_endpoint_wire_api, provider_get_endpoint_request, provider_session_token, push_attachment, push_attachment_blob, push_attachment_directory, push_attachment_file, push_attachment_file_line_range, push_attachment_git_hub_reference, push_attachment_git_hub_reference_type, push_attachment_selection, push_attachment_selection_details, push_attachment_selection_details_end, push_attachment_selection_details_start, queued_command_handled, queued_command_not_handled, queued_command_result, queue_pending_items, queue_pending_items_kind, queue_pending_items_result, queue_remove_most_recent_result, register_event_interest_params, register_event_interest_result, register_extension_tools_params, register_extension_tools_result, release_event_interest_params, remote_control_config, remote_control_config_existing_mc_session, remote_control_status, remote_control_status_active, remote_control_status_connecting, remote_control_status_error, remote_control_status_off, remote_control_status_result, remote_control_stop_result, remote_control_transfer_result, remote_enable_request, remote_enable_result, remote_notify_steerable_changed_request, remote_notify_steerable_changed_result, remote_session_connection_result, remote_session_metadata_repository, remote_session_metadata_task_type, remote_session_metadata_value, remote_session_mode, remote_session_repository, sandbox_config, sandbox_config_user_policy, sandbox_config_user_policy_experimental, sandbox_config_user_policy_experimental_seatbelt, sandbox_config_user_policy_filesystem, sandbox_config_user_policy_network, schedule_entry, schedule_list, schedule_stop_request, schedule_stop_result, secrets_add_filter_values_request, secrets_add_filter_values_result, send_agent_mode, send_attachments_to_message_params, send_mode, send_request, send_result, server_agent_list, server_instruction_source_list, server_skill, server_skill_list, session_activity, session_auth_status, session_bulk_delete_result, session_capability, session_context, session_context_host_type, session_enrich_metadata_result, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_capabilities, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_sqlite_exists_request, session_fs_sqlite_exists_result, session_fs_sqlite_query_request, session_fs_sqlite_query_result, session_fs_sqlite_query_type, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_installed_plugin, session_installed_plugin_source, session_installed_plugin_source_git_hub, session_installed_plugin_source_local, session_installed_plugin_source_url, session_list, session_list_entry, session_list_filter, session_load_deferred_repo_hooks_result, session_log_level, session_mcp_apps_call_tool_result, session_metadata_snapshot, session_mode, session_model_list, session_open_options, session_open_options_additional_content_exclusion_policy, session_open_options_additional_content_exclusion_policy_rule, session_open_options_additional_content_exclusion_policy_rule_source, session_open_options_additional_content_exclusion_policy_scope, session_open_options_env_value_mode, session_open_options_reasoning_summary, session_open_params, session_open_result, session_prune_result, sessions_bulk_delete_request, sessions_check_in_use_request, sessions_check_in_use_result, sessions_close_request, sessions_close_result, sessions_enrich_metadata_request, session_set_credentials_params, session_set_credentials_result, sessions_find_by_prefix_request, sessions_find_by_prefix_result, sessions_find_by_task_id_request, sessions_find_by_task_id_result, sessions_fork_request, sessions_fork_result, sessions_get_board_entry_count_request, sessions_get_board_entry_count_result, sessions_get_event_file_path_request, sessions_get_event_file_path_result, sessions_get_last_for_context_request, sessions_get_last_for_context_result, sessions_get_persisted_remote_steerable_request, sessions_get_persisted_remote_steerable_result, session_sizes, sessions_list_request, sessions_load_deferred_repo_hooks_request, sessions_open_attach, sessions_open_cloud, sessions_open_create, sessions_open_handoff, sessions_open_handoff_task_type, sessions_open_progress, sessions_open_progress_status, sessions_open_progress_step, sessions_open_remote, sessions_open_resume, sessions_open_resume_last, sessions_open_status, session_source, sessions_poll_spawned_sessions_event, sessions_poll_spawned_sessions_request, sessions_prune_old_request, sessions_register_extension_tools_on_session_options, sessions_release_lock_request, sessions_release_lock_result, sessions_reload_plugin_hooks_request, sessions_reload_plugin_hooks_result, sessions_save_request, sessions_save_result, sessions_set_additional_plugins_request, sessions_set_additional_plugins_result, sessions_set_remote_control_steering_request, sessions_start_remote_control_request, sessions_stop_remote_control_request, sessions_transfer_remote_control_request, session_telemetry_engagement, session_update_options_params, session_update_options_result, session_working_directory_context, session_working_directory_context_host_type, shell_cancel_user_requested_request, shell_exec_request, shell_exec_result, shell_execute_user_requested_request, shell_kill_request, shell_kill_result, shell_kill_signal, shutdown_request, skill, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, skills_get_invoked_result, skills_invoked_skill, skills_load_diagnostics, slash_command_agent_prompt_result, slash_command_completed_result, slash_command_info, slash_command_input, slash_command_input_completion, slash_command_invocation_result, slash_command_kind, slash_command_select_subcommand_option, slash_command_select_subcommand_result, slash_command_text_result, subagent_settings_entry, subagent_settings_entry_context_tier, task_agent_info, task_agent_progress, task_execution_mode, task_info, task_list, task_progress_line, tasks_cancel_request, tasks_cancel_result, tasks_get_current_promotable_result, tasks_get_progress_request, tasks_get_progress_result, task_shell_info, task_shell_info_attachment_mode, task_shell_progress, tasks_promote_current_to_background_result, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_refresh_result, tasks_remove_request, tasks_remove_result, tasks_send_message_request, tasks_send_message_result, tasks_start_agent_request, tasks_start_agent_result, task_status, tasks_wait_for_pending_result, telemetry_set_feature_overrides_request, token_auth_info, tool, tool_list, tools_get_current_metadata_result, tools_initialize_and_validate_result, tools_list_request, tools_update_subagent_settings_result, ui_auto_mode_switch_response, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_ephemeral_query_request, ui_ephemeral_query_result, ui_exit_plan_mode_action, ui_exit_plan_mode_response, ui_handle_pending_auto_mode_switch_request, ui_handle_pending_elicitation_request, ui_handle_pending_exit_plan_mode_request, ui_handle_pending_result, ui_handle_pending_sampling_request, ui_handle_pending_sampling_response, ui_handle_pending_user_input_request, ui_register_direct_auto_mode_switch_handler_result, ui_unregister_direct_auto_mode_switch_handler_request, ui_unregister_direct_auto_mode_switch_handler_result, ui_user_input_response, update_subagent_settings_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, user_auth_info, user_requested_shell_command_result, workspace_diff_file_change, workspace_diff_file_change_type, workspace_diff_mode, workspace_diff_result, workspaces_checkpoints, workspaces_create_file_request, workspaces_diff_request, workspaces_get_workspace_result, workspaces_list_checkpoints_result, workspaces_list_files_result, workspaces_read_checkpoint_request, workspaces_read_checkpoint_result, workspaces_read_file_request, workspaces_read_file_result, workspaces_save_large_paste_request, workspaces_save_large_paste_result, workspace_summary_host_type, workspaces_workspace_details_host_type, session_context_info, subagent_settings, task_progress, workspace_summary) def to_dict(self) -> dict: result: dict = {} @@ -21190,6 +21712,24 @@ def to_dict(self) -> dict: result["InstructionSource"] = to_class(InstructionSource, self.instruction_source) result["InstructionSourceLocation"] = to_enum(InstructionSourceLocation, self.instruction_source_location) result["InstructionSourceType"] = to_enum(InstructionSourceType, self.instruction_source_type) + result["LlmInferenceHeaders"] = from_dict(lambda x: from_list(from_str, x), self.llm_inference_headers) + result["LlmInferenceHttpRequestError"] = to_class(LlmInferenceHTTPRequestError, self.llm_inference_http_request_error) + result["LlmInferenceHttpRequestRequest"] = to_class(LlmInferenceHTTPRequestRequest, self.llm_inference_http_request_request) + result["LlmInferenceHttpRequestResult"] = to_class(LlmInferenceHTTPRequestResult, self.llm_inference_http_request_result) + result["LlmInferenceHttpStreamStartError"] = to_class(LlmInferenceHTTPStreamStartError, self.llm_inference_http_stream_start_error) + result["LlmInferenceHttpStreamStartRequest"] = to_class(LlmInferenceHTTPStreamStartRequest, self.llm_inference_http_stream_start_request) + result["LlmInferenceHttpStreamStartResult"] = to_class(LlmInferenceHTTPStreamStartResult, self.llm_inference_http_stream_start_result) + result["LlmInferenceRequestMetadata"] = to_class(LlmInferenceRequestMetadata, self.llm_inference_request_metadata) + result["LlmInferenceRequestMetadataEndpointKind"] = to_enum(LlmInferenceRequestMetadataEndpointKind, self.llm_inference_request_metadata_endpoint_kind) + result["LlmInferenceRequestMetadataProviderType"] = to_enum(LlmInferenceRequestMetadataProviderType, self.llm_inference_request_metadata_provider_type) + result["LlmInferenceRequestMetadataTransport"] = to_enum(LlmInferenceRequestMetadataTransport, self.llm_inference_request_metadata_transport) + result["LlmInferenceRequestMetadataWireApi"] = to_enum(LlmInferenceRequestMetadataWireAPI, self.llm_inference_request_metadata_wire_api) + result["LlmInferenceSetProviderRequest"] = to_class(LlmInferenceSetProviderRequest, self.llm_inference_set_provider_request) + result["LlmInferenceSetProviderResult"] = to_class(LlmInferenceSetProviderResult, self.llm_inference_set_provider_result) + result["LlmInferenceStreamChunkRequest"] = to_class(LlmInferenceStreamChunkRequest, self.llm_inference_stream_chunk_request) + result["LlmInferenceStreamChunkResult"] = to_class(LlmInferenceStreamChunkResult, self.llm_inference_stream_chunk_result) + result["LlmInferenceStreamEndRequest"] = to_class(LlmInferenceStreamEndRequest, self.llm_inference_stream_end_request) + result["LlmInferenceStreamEndResult"] = to_class(LlmInferenceStreamEndResult, self.llm_inference_stream_end_result) result["LocalSessionMetadataValue"] = to_class(LocalSessionMetadataValue, self.local_session_metadata_value) result["LogRequest"] = to_class(LogRequest, self.log_request) result["LogResult"] = to_class(LogResult, self.log_result) @@ -21262,6 +21802,7 @@ def to_dict(self) -> dict: result["McpServerAuthConfig"] = from_union([from_bool, lambda x: to_class(MCPServerAuthConfigRedirectPort, x)], self.mcp_server_auth_config) result["McpServerAuthConfigRedirectPort"] = to_class(MCPServerAuthConfigRedirectPort, self.mcp_server_auth_config_redirect_port) result["McpServerConfig"] = to_class(MCPServerConfig, self.mcp_server_config) + result["McpServerConfigDeferTools"] = to_enum(MCPServerConfigDeferTools, self.mcp_server_config_defer_tools) result["McpServerConfigHttp"] = to_class(MCPServerConfigHTTP, self.mcp_server_config_http) result["McpServerConfigHttpOauthGrantType"] = to_enum(MCPServerConfigHTTPOauthGrantType, self.mcp_server_config_http_oauth_grant_type) result["McpServerConfigHttpType"] = to_enum(MCPServerConfigHTTPType, self.mcp_server_config_http_type) @@ -21998,6 +22539,7 @@ def _load_TaskInfo(obj: Any) -> "TaskInfo": ExternalToolResult = ExternalToolTextResultForLlm ExternalToolTextResultForLlmContentResourceLinkIconTheme = Theme FilterMapping = dict +LlmInferenceHeaders = dict McpAppsHostContextDetailsAvailableDisplayMode = MCPAppsDisplayMode McpAppsHostContextDetailsDisplayMode = MCPAppsDisplayMode McpAppsHostContextDetailsTheme = Theme @@ -22296,6 +22838,26 @@ async def set_provider(self, params: SessionFSSetProviderRequest, *, timeout: fl return SessionFSSetProviderResult.from_dict(await self._client.request("sessionFs.setProvider", params_dict, **_timeout_kwargs(timeout))) +# Experimental: this API group is experimental and may change or be removed. +class ServerLlmInferenceApi: + def __init__(self, client: "JsonRpcClient"): + self._client = client + + async def set_provider(self, *, timeout: float | None = None) -> LlmInferenceSetProviderResult: + "Registers an SDK client as the LLM inference callback provider.\n\nReturns:\n Indicates whether the calling client was registered as the LLM inference provider." + return LlmInferenceSetProviderResult.from_dict(await self._client.request("llmInference.setProvider", {}, **_timeout_kwargs(timeout))) + + async def stream_chunk(self, params: LlmInferenceStreamChunkRequest, *, timeout: float | None = None) -> LlmInferenceStreamChunkResult: + "Pushes a streamed response body chunk back to the runtime, correlated by the streamToken the runtime previously handed out in llmInference.httpStreamStart.\n\nArgs:\n params: A streamed response body chunk.\n\nReturns:\n Whether the chunk was accepted." + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + return LlmInferenceStreamChunkResult.from_dict(await self._client.request("llmInference.streamChunk", params_dict, **_timeout_kwargs(timeout))) + + async def stream_end(self, params: LlmInferenceStreamEndRequest, *, timeout: float | None = None) -> LlmInferenceStreamEndResult: + "Signals end-of-stream for an inference response stream the SDK client started via llmInference.httpStreamStart.\n\nArgs:\n params: End-of-stream signal.\n\nReturns:\n Whether the end signal was accepted." + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + return LlmInferenceStreamEndResult.from_dict(await self._client.request("llmInference.streamEnd", params_dict, **_timeout_kwargs(timeout))) + + # Experimental: this API group is experimental and may change or be removed. class ServerSessionsApi: def __init__(self, client: "JsonRpcClient"): @@ -22442,6 +23004,7 @@ def __init__(self, client: "JsonRpcClient"): self.user = ServerUserApi(client) self.runtime = ServerRuntimeApi(client) self.session_fs = ServerSessionFsApi(client) + self.llm_inference = ServerLlmInferenceApi(client) self.sessions = ServerSessionsApi(client) self.agent_registry = ServerAgentRegistryApi(client) @@ -24045,6 +24608,24 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "InstructionsDiscoverRequest", "InstructionsGetSourcesResult", "KindEnum", + "LlmInferenceHTTPRequestError", + "LlmInferenceHTTPRequestRequest", + "LlmInferenceHTTPRequestResult", + "LlmInferenceHTTPStreamStartError", + "LlmInferenceHTTPStreamStartRequest", + "LlmInferenceHTTPStreamStartResult", + "LlmInferenceHeaders", + "LlmInferenceRequestMetadata", + "LlmInferenceRequestMetadataEndpointKind", + "LlmInferenceRequestMetadataProviderType", + "LlmInferenceRequestMetadataTransport", + "LlmInferenceRequestMetadataWireAPI", + "LlmInferenceSetProviderRequest", + "LlmInferenceSetProviderResult", + "LlmInferenceStreamChunkRequest", + "LlmInferenceStreamChunkResult", + "LlmInferenceStreamEndRequest", + "LlmInferenceStreamEndResult", "LocalSessionMetadataValue", "LogRequest", "LogResult", @@ -24101,6 +24682,7 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "MCPServer", "MCPServerAuthConfigRedirectPort", "MCPServerConfig", + "MCPServerConfigDeferTools", "MCPServerConfigHTTP", "MCPServerConfigHTTPOauthGrantType", "MCPServerConfigHTTPType", @@ -24440,6 +25022,7 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "ServerAgentsApi", "ServerInstructionSourceList", "ServerInstructionsApi", + "ServerLlmInferenceApi", "ServerMcpApi", "ServerMcpConfigApi", "ServerModelsApi", diff --git a/python/copilot/generated/session_events.py b/python/copilot/generated/session_events.py index 697968181..829f61469 100644 --- a/python/copilot/generated/session_events.py +++ b/python/copilot/generated/session_events.py @@ -734,11 +734,13 @@ class AssistantUsageData: api_endpoint: AssistantUsageApiEndpoint | None = None cache_read_tokens: int | None = None cache_write_tokens: int | None = None + content_filter_triggered: bool | None = None # Internal: this field is an internal SDK API and is not part of the public surface. _copilot_usage: _AssistantUsageCopilotUsage | None = None # Experimental: this field is part of an experimental API and may change or be removed. cost: float | None = None duration: timedelta | None = None + finish_reason: str | None = None initiator: str | None = None input_tokens: int | None = None inter_token_latency: timedelta | None = None @@ -761,9 +763,11 @@ def from_dict(obj: Any) -> "AssistantUsageData": api_endpoint = from_union([from_none, lambda x: parse_enum(AssistantUsageApiEndpoint, x)], obj.get("apiEndpoint")) cache_read_tokens = from_union([from_none, from_int], obj.get("cacheReadTokens")) cache_write_tokens = from_union([from_none, from_int], obj.get("cacheWriteTokens")) + content_filter_triggered = from_union([from_none, from_bool], obj.get("contentFilterTriggered")) _copilot_usage = from_union([from_none, _AssistantUsageCopilotUsage.from_dict], obj.get("copilotUsage")) cost = from_union([from_none, from_float], obj.get("cost")) duration = from_union([from_none, from_timedelta], obj.get("duration")) + finish_reason = from_union([from_none, from_str], obj.get("finishReason")) initiator = from_union([from_none, from_str], obj.get("initiator")) input_tokens = from_union([from_none, from_int], obj.get("inputTokens")) inter_token_latency = from_union([from_none, from_timedelta], obj.get("interTokenLatencyMs")) @@ -781,9 +785,11 @@ def from_dict(obj: Any) -> "AssistantUsageData": api_endpoint=api_endpoint, cache_read_tokens=cache_read_tokens, cache_write_tokens=cache_write_tokens, + content_filter_triggered=content_filter_triggered, _copilot_usage=_copilot_usage, cost=cost, duration=duration, + finish_reason=finish_reason, initiator=initiator, input_tokens=input_tokens, inter_token_latency=inter_token_latency, @@ -808,12 +814,16 @@ def to_dict(self) -> dict: result["cacheReadTokens"] = from_union([from_none, to_int], self.cache_read_tokens) if self.cache_write_tokens is not None: result["cacheWriteTokens"] = from_union([from_none, to_int], self.cache_write_tokens) + if self.content_filter_triggered is not None: + result["contentFilterTriggered"] = from_union([from_none, from_bool], self.content_filter_triggered) if self._copilot_usage is not None: result["copilotUsage"] = from_union([from_none, lambda x: to_class(_AssistantUsageCopilotUsage, x)], self._copilot_usage) if self.cost is not None: result["cost"] = from_union([from_none, to_float], self.cost) if self.duration is not None: result["duration"] = from_union([from_none, to_timedelta_int], self.duration) + if self.finish_reason is not None: + result["finishReason"] = from_union([from_none, from_str], self.finish_reason) if self.initiator is not None: result["initiator"] = from_union([from_none, from_str], self.initiator) if self.input_tokens is not None: @@ -5946,6 +5956,7 @@ class ToolExecutionCompleteResult: content: str contents: list[ToolExecutionCompleteContent] | None = None detailed_content: str | None = None + structured_content: Any = None ui_resource: ToolExecutionCompleteUIResource | None = None @staticmethod @@ -5954,11 +5965,13 @@ def from_dict(obj: Any) -> "ToolExecutionCompleteResult": content = from_str(obj.get("content")) contents = from_union([from_none, lambda x: from_list(_load_ToolExecutionCompleteContent, x)], obj.get("contents")) detailed_content = from_union([from_none, from_str], obj.get("detailedContent")) + structured_content = obj.get("structuredContent") ui_resource = from_union([from_none, ToolExecutionCompleteUIResource.from_dict], obj.get("uiResource")) return ToolExecutionCompleteResult( content=content, contents=contents, detailed_content=detailed_content, + structured_content=structured_content, ui_resource=ui_resource, ) @@ -5969,6 +5982,8 @@ def to_dict(self) -> dict: result["contents"] = from_union([from_none, lambda x: from_list(lambda x: x.to_dict(), x)], self.contents) if self.detailed_content is not None: result["detailedContent"] = from_union([from_none, from_str], self.detailed_content) + if self.structured_content is not None: + result["structuredContent"] = self.structured_content if self.ui_resource is not None: result["uiResource"] = from_union([from_none, lambda x: to_class(ToolExecutionCompleteUIResource, x)], self.ui_resource) return result diff --git a/rust/src/generated/api_types.rs b/rust/src/generated/api_types.rs index 6dc971717..a52c50ca6 100644 --- a/rust/src/generated/api_types.rs +++ b/rust/src/generated/api_types.rs @@ -83,6 +83,12 @@ pub mod rpc_methods { pub const RUNTIME_SHUTDOWN: &str = "runtime.shutdown"; /// `sessionFs.setProvider` pub const SESSIONFS_SETPROVIDER: &str = "sessionFs.setProvider"; + /// `llmInference.setProvider` + pub const LLMINFERENCE_SETPROVIDER: &str = "llmInference.setProvider"; + /// `llmInference.streamChunk` + pub const LLMINFERENCE_STREAMCHUNK: &str = "llmInference.streamChunk"; + /// `llmInference.streamEnd` + pub const LLMINFERENCE_STREAMEND: &str = "llmInference.streamEnd"; /// `sessions.open` pub const SESSIONS_OPEN: &str = "sessions.open"; /// `sessions.fork` @@ -3165,6 +3171,223 @@ pub struct InstructionsGetSourcesResult { pub sources: Vec, } +/// Set when the SDK client could not produce a response (transport-level failure). Causes the runtime to raise an APIConnectionError; status/headers/body are ignored when error is set. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpRequestError { + /// Optional machine-readable error code. + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + /// Human-readable failure description. + pub message: String, +} + +/// Metadata describing an intercepted LLM HTTP request. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceRequestMetadata { + /// What kind of model-layer endpoint this is. + pub endpoint_kind: LlmInferenceRequestMetadataEndpointKind, + /// Model identifier, when known. + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, + /// Logical model provider this request targets. + pub provider_type: LlmInferenceRequestMetadataProviderType, + /// Transport kind. v1 implements http only. + pub transport: LlmInferenceRequestMetadataTransport, + /// Wire API shape, when this is an inference request. + #[serde(skip_serializing_if = "Option::is_none")] + pub wire_api: Option, +} + +/// An outbound model-layer HTTP request the runtime would otherwise have issued itself. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpRequestRequest { + /// Request body as base64-encoded bytes. Set instead of bodyText when the body is binary. + #[serde(skip_serializing_if = "Option::is_none")] + pub body_base64: Option, + /// Request body as a UTF-8 string. Set when binaryBody is absent or false. + #[serde(skip_serializing_if = "Option::is_none")] + pub body_text: Option, + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + pub headers: HashMap>, + /// Metadata describing an intercepted LLM HTTP request. + pub metadata: LlmInferenceRequestMetadata, + /// HTTP method, e.g. GET, POST. + pub method: String, + /// Opaque runtime-minted id, unique per request. Useful for client-side logging. + pub request_id: RequestId, + /// Id of the runtime session that triggered this request, when one is in scope. Absent for requests issued outside any session (e.g. startup model-catalog or capability resolution). This is a payload field — not a dispatch key — because the client-global API is registered process-wide rather than per session. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, + /// Absolute request URL. + pub url: String, +} + +/// The HTTP response the runtime should treat as if it had issued the request itself. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpRequestResult { + /// Response body as base64-encoded bytes. Set instead of bodyText for binary responses. + #[serde(skip_serializing_if = "Option::is_none")] + pub body_base64: Option, + /// Response body as a UTF-8 string. Set when bodyBase64 is absent. + #[serde(skip_serializing_if = "Option::is_none")] + pub body_text: Option, + /// Set when the SDK client could not produce a response (transport-level failure). Causes the runtime to raise an APIConnectionError; status/headers/body are ignored when error is set. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + pub headers: HashMap>, + /// HTTP status code returned to the runtime. + pub status: i64, + /// Optional HTTP status text. + #[serde(skip_serializing_if = "Option::is_none")] + pub status_text: Option, +} + +/// Set when the SDK client could not even begin the stream (transport-level failure). When error is set the runtime raises an APIConnectionError and ignores status/headers. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpStreamStartError { + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + pub message: String, +} + +/// An outbound streaming model-layer HTTP request. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpStreamStartRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub body_base64: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub body_text: Option, + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + pub headers: HashMap>, + /// Metadata describing an intercepted LLM HTTP request. + pub metadata: LlmInferenceRequestMetadata, + /// HTTP method. + pub method: String, + /// Opaque runtime-minted id, unique per request. + pub request_id: RequestId, + /// Originating session id, when known. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, + /// Stream identifier. The SDK client passes this exact value back on every llmInference.streamChunk / streamEnd call to correlate pushed chunks with this request. + pub stream_token: i64, + /// Absolute request URL. + pub url: String, +} + +/// The response head. After returning, the SDK client pushes body chunks via llmInference.streamChunk and signals completion (or transport error) via llmInference.streamEnd. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceHttpStreamStartResult { + /// Set when the SDK client could not even begin the stream (transport-level failure). When error is set the runtime raises an APIConnectionError and ignores status/headers. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. + pub headers: HashMap>, + /// HTTP status code. + pub status: i64, + #[serde(skip_serializing_if = "Option::is_none")] + pub status_text: Option, +} + +/// No parameters. The calling connection is registered as the runtime's LLM inference provider; all subsequent model-layer HTTP requests are dispatched back to it via the llmInference client API. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceSetProviderRequest {} + +/// Indicates whether the calling client was registered as the LLM inference provider. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceSetProviderResult { + /// Whether the provider was set successfully + pub success: bool, +} + +/// A streamed response body chunk. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceStreamChunkRequest { + /// One body chunk as base64-encoded bytes. Chunks are appended to the runtime's view of the response body in the order received. + pub data_base64: String, + /// The same streamToken the runtime supplied in the originating llmInference.httpStreamStart call. + pub stream_token: i64, +} + +/// Whether the chunk was accepted. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceStreamChunkResult { + /// True when the chunk was queued for the stream; false when the stream is unknown. + pub accepted: bool, +} + +/// End-of-stream signal. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceStreamEndRequest { + /// When set, marks the stream as ending with a transport-level error of this description. When absent the stream ends normally. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// The originating streamToken. + pub stream_token: i64, +} + +/// Whether the end signal was accepted. +/// +///
+/// +/// **Experimental.** This type is part of an experimental wire-protocol surface +/// and may change or be removed in future SDK or CLI releases. +/// +///
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LlmInferenceStreamEndResult { + /// True when the stream was found and ended; false when unknown. + pub accepted: bool, +} + /// Pre-resolved working-directory context for session startup. /// ///
@@ -4292,6 +4515,9 @@ pub struct McpServerConfigHttp { /// Set to `true` to use defaults, or provide an object with additional auth or OIDC settings. #[serde(skip_serializing_if = "Option::is_none")] pub auth: Option, + /// Controls if tools provided by this server can be loaded on demand via tool search (auto) or always included in the initial tool list (never) + #[serde(skip_serializing_if = "Option::is_none")] + pub defer_tools: Option, /// Content filtering mode to apply to all tools, or a map of tool name to content filtering mode. #[serde(skip_serializing_if = "Option::is_none")] pub filter_mapping: Option, @@ -4341,6 +4567,9 @@ pub struct McpServerConfigStdio { /// Working directory for the Stdio MCP server process. #[serde(skip_serializing_if = "Option::is_none")] pub cwd: Option, + /// Controls if tools provided by this server can be loaded on demand via tool search (auto) or always included in the initial tool list (never) + #[serde(skip_serializing_if = "Option::is_none")] + pub defer_tools: Option, /// Environment variables to pass to the Stdio MCP server process. #[serde(skip_serializing_if = "Option::is_none")] pub env: Option>, @@ -16083,6 +16312,9 @@ pub struct CanvasOpenResult { pub url: Option, } +/// HTTP headers as a map from lowercased header name to a list of values. Multi-valued headers (e.g. Set-Cookie) preserve all values. +pub type LlmInferenceHeaders = HashMap>; + /// MCP CreateMessageResult payload (with optional 'tools' extension), present when action='success'. Treated as opaque at the schema layer; consumers should construct/consume it per the MCP CreateMessageResult shape. /// ///
@@ -16985,6 +17217,93 @@ pub enum InstructionSourceType { Unknown, } +/// What kind of model-layer endpoint this is. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum LlmInferenceRequestMetadataEndpointKind { + /// An inference request (chat/completions, responses, messages). + #[serde(rename = "inference")] + Inference, + /// Listing of available models. + #[serde(rename = "models-catalog")] + ModelsCatalog, + /// Per-model session/auth bootstrap. + #[serde(rename = "models-session")] + ModelsSession, + /// Per-model policy lookup. + #[serde(rename = "models-policy")] + ModelsPolicy, + /// An embeddings request. + #[serde(rename = "embeddings")] + Embeddings, + /// Model-layer endpoint not specifically categorized. + #[serde(rename = "other")] + Other, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + +/// Logical model provider this request targets. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum LlmInferenceRequestMetadataProviderType { + /// GitHub Copilot CAPI. + #[serde(rename = "copilot")] + Copilot, + /// OpenAI. + #[serde(rename = "openai")] + Openai, + /// Azure OpenAI. + #[serde(rename = "azure")] + Azure, + /// Anthropic. + #[serde(rename = "anthropic")] + Anthropic, + /// Google Gemini / Vertex. + #[serde(rename = "google")] + Google, + /// Provider not recognised by the runtime's URL heuristics. + #[serde(rename = "other")] + Other, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + +/// Transport kind. v1 implements http only. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum LlmInferenceRequestMetadataTransport { + /// Plain HTTP request/response, possibly with an SSE-encoded streamed body. + #[serde(rename = "http")] + Http, + /// WebSocket connection. Not implemented in v1 of the callback wire. + #[serde(rename = "websocket")] + Websocket, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + +/// Wire API shape, when this is an inference request. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum LlmInferenceRequestMetadataWireApi { + /// OpenAI chat completions API. + #[serde(rename = "completions")] + Completions, + /// OpenAI responses API. + #[serde(rename = "responses")] + Responses, + /// Anthropic messages API. + #[serde(rename = "messages")] + Messages, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + /// Repository host type /// ///
@@ -17251,6 +17570,21 @@ pub enum McpSamplingExecutionAction { Unknown, } +/// Controls if tools provided by this server can be loaded on demand via tool search (auto) or always included in the initial tool list (never) +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerConfigDeferTools { + /// Tools may be deferred under certain conditions + #[serde(rename = "auto")] + Auto, + /// Tools are always included in the initial tool list, even when tool search is enabled. + #[serde(rename = "never")] + Never, + /// Unknown variant for forward compatibility. + #[default] + #[serde(other)] + Unknown, +} + /// OAuth grant type to use when authenticating to the remote MCP server. #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum McpServerConfigHttpOauthGrantType { diff --git a/rust/src/generated/rpc.rs b/rust/src/generated/rpc.rs index 22fd08904..5cd5f30f4 100644 --- a/rust/src/generated/rpc.rs +++ b/rust/src/generated/rpc.rs @@ -49,6 +49,13 @@ impl<'a> ClientRpc<'a> { } } + /// `llmInference.*` sub-namespace. + pub fn llm_inference(&self) -> ClientRpcLlmInference<'a> { + ClientRpcLlmInference { + client: self.client, + } + } + /// `mcp.*` sub-namespace. pub fn mcp(&self) -> ClientRpcMcp<'a> { ClientRpcMcp { @@ -321,6 +328,100 @@ impl<'a> ClientRpcInstructions<'a> { } } +/// `llmInference.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcLlmInference<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcLlmInference<'a> { + /// Registers an SDK client as the LLM inference callback provider. + /// + /// Wire method: `llmInference.setProvider`. + /// + /// # Returns + /// + /// Indicates whether the calling client was registered as the LLM inference provider. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn set_provider(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::LLMINFERENCE_SETPROVIDER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Pushes a streamed response body chunk back to the runtime, correlated by the streamToken the runtime previously handed out in llmInference.httpStreamStart. + /// + /// Wire method: `llmInference.streamChunk`. + /// + /// # Parameters + /// + /// * `params` - A streamed response body chunk. + /// + /// # Returns + /// + /// Whether the chunk was accepted. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn stream_chunk( + &self, + params: LlmInferenceStreamChunkRequest, + ) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::LLMINFERENCE_STREAMCHUNK, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Signals end-of-stream for an inference response stream the SDK client started via llmInference.httpStreamStart. + /// + /// Wire method: `llmInference.streamEnd`. + /// + /// # Parameters + /// + /// * `params` - End-of-stream signal. + /// + /// # Returns + /// + /// Whether the end signal was accepted. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn stream_end( + &self, + params: LlmInferenceStreamEndRequest, + ) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::LLMINFERENCE_STREAMEND, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + /// `mcp.*` RPCs. #[derive(Clone, Copy)] pub struct ClientRpcMcp<'a> { diff --git a/rust/src/generated/session_events.rs b/rust/src/generated/session_events.rs index 92d7fa133..040859462 100644 --- a/rust/src/generated/session_events.rs +++ b/rust/src/generated/session_events.rs @@ -1428,6 +1428,9 @@ pub struct AssistantUsageData { /// Number of tokens written to prompt cache #[serde(skip_serializing_if = "Option::is_none")] pub cache_write_tokens: Option, + /// Whether the model response was blocked or truncated by content filtering (finish_reason === 'content_filter'). For Anthropic models this corresponds to a 'refusal' stop reason. + #[serde(skip_serializing_if = "Option::is_none")] + pub content_filter_triggered: Option, /// Per-request cost and usage data from the CAPI copilot_usage response field #[doc(hidden)] #[serde(skip_serializing_if = "Option::is_none")] @@ -1445,6 +1448,9 @@ pub struct AssistantUsageData { /// Duration of the API call in milliseconds #[serde(skip_serializing_if = "Option::is_none")] pub duration: Option, + /// Finish reason reported by the model for this API call (e.g. "stop", "length", "tool_calls", "content_filter"). Normalized to OpenAI vocabulary; for Anthropic models a "refusal" stop reason maps to "content_filter". + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls #[serde(skip_serializing_if = "Option::is_none")] pub initiator: Option, @@ -1878,6 +1884,9 @@ pub struct ToolExecutionCompleteResult { /// Full detailed tool result for UI/timeline display, preserving complete content such as diffs. Falls back to content when absent. #[serde(skip_serializing_if = "Option::is_none")] pub detailed_content: Option, + /// Structured content (arbitrary JSON) returned verbatim by the MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub structured_content: Option, /// MCP Apps UI resource content for rendering in a sandboxed iframe #[serde(skip_serializing_if = "Option::is_none")] pub ui_resource: Option, diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index 7b0f64a77..e1ceea5b1 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -2297,6 +2297,142 @@ function emitClientSessionApiRegistration(clientSchema: Record, return lines; } +/** + * Emit C# handler interfaces + a process-wide registration for client + * *global* API groups. + * + * Unlike client-session APIs, these methods carry no implicit `sessionId` + * dispatch key. The SDK consumer registers a single process-wide handler set + * via `RegisterClientGlobalApiHandlers`; the runtime dispatcher routes each + * incoming call to the registered handler regardless of which (if any) + * runtime session triggered it. + */ +function emitClientGlobalApiRegistration(clientSchema: Record, classes: string[]): string[] { + const lines: string[] = []; + const groups = collectClientGroups(clientSchema); + + for (const { methods } of groups) { + for (const method of methods) { + const resultSchema = getMethodResultSchema(method); + if (!isVoidSchema(resultSchema) && !isOpaqueJson(resultSchema)) { + emitRpcResultType(resultTypeName(method), resultSchema!, "public", classes); + } + + const effectiveParams = resolveMethodParamsSchema(method); + if (effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0) { + const paramsClass = emitRpcClass(paramsTypeName(method), effectiveParams, "public", classes); + if (paramsClass) classes.push(paramsClass); + } + } + } + + for (const { groupName, groupNode, methods } of groups) { + const interfaceName = clientHandlerInterfaceName(groupName); + const groupExperimental = isNodeFullyExperimental(groupNode); + const groupDeprecated = isNodeFullyDeprecated(groupNode); + lines.push(`/// Handles \`${groupName}\` client global API methods.`); + if (groupExperimental) { + pushExperimentalAttribute(lines); + } + if (groupDeprecated) { + pushObsoleteAttributes(lines); + } + lines.push(`public interface ${interfaceName}`); + lines.push(`{`); + for (const method of methods) { + const effectiveParams = resolveMethodParamsSchema(method); + const hasParams = !!effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0; + const resultSchema = getMethodResultSchema(method); + const taskType = resultTaskType(method); + pushRpcMethodXmlDocs( + lines, + method, + " ", + [ + ...(hasParams ? [{ name: "request", description: rpcParamsDescription(method, effectiveParams) }] : []), + { name: "cancellationToken", description: CANCELLATION_TOKEN_DESCRIPTION, escapeDescription: false }, + ], + resultSchema, + `Handles "${method.rpcMethod}".` + ); + if (method.stability === "experimental" && !groupExperimental) { + pushExperimentalAttribute(lines, " "); + } + if (method.deprecated && !groupDeprecated) { + pushObsoleteAttributes(lines, " "); + } + if (hasParams) { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(${paramsTypeName(method)} request, CancellationToken cancellationToken = default);`); + } else { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(CancellationToken cancellationToken = default);`); + } + } + lines.push(`}`); + lines.push(""); + } + + lines.push(`/// Provides all client global API handler groups for a connection.`); + lines.push(`public sealed class ClientGlobalApiHandlers`); + lines.push(`{`); + for (const { groupName } of groups) { + lines.push(` /// Optional handler for ${toPascalCase(groupName)} client global API methods.`); + lines.push(` public ${clientHandlerInterfaceName(groupName)}? ${toPascalCase(groupName)} { get; set; }`); + lines.push(""); + } + if (lines[lines.length - 1] === "") lines.pop(); + lines.push(`}`); + lines.push(""); + + lines.push(`/// Registers client global API handlers on a JSON-RPC connection.`); + lines.push(`internal static class ClientGlobalApiRegistration`); + lines.push(`{`); + lines.push(` /// `); + lines.push(` /// Registers handlers for server-to-client global API calls.`); + lines.push(` /// Unlike client session APIs, these methods carry no implicit`); + lines.push(` /// sessionId dispatch key — a single set of handlers serves the`); + lines.push(` /// entire connection.`); + lines.push(` /// `); + lines.push(` public static void RegisterClientGlobalApiHandlers(JsonRpc rpc, ClientGlobalApiHandlers handlers)`); + lines.push(` {`); + for (const { groupName, methods } of groups) { + for (const method of methods) { + const handlerProperty = toPascalCase(groupName); + const handlerMethod = clientHandlerMethodName(method.rpcMethod); + const effectiveParams = resolveMethodParamsSchema(method); + const hasParams = !!effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0; + const resultSchema = getMethodResultSchema(method); + const paramsClass = paramsTypeName(method); + const taskType = handlerTaskType(method); + + if (hasParams) { + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func<${paramsClass}, CancellationToken, ${taskType}>)(async (request, cancellationToken) =>`); + lines.push(` {`); + lines.push(` var handler = handlers.${handlerProperty} ?? throw new InvalidOperationException("No ${groupName} client-global handler registered");`); + if (!isVoidSchema(resultSchema)) { + lines.push(` return await handler.${handlerMethod}(request, cancellationToken);`); + } else { + lines.push(` await handler.${handlerMethod}(request, cancellationToken);`); + } + lines.push(` }), singleObjectParam: true);`); + } else { + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func)(async cancellationToken =>`); + lines.push(` {`); + lines.push(` var handler = handlers.${handlerProperty} ?? throw new InvalidOperationException("No ${groupName} client-global handler registered");`); + if (!isVoidSchema(resultSchema)) { + lines.push(` return await handler.${handlerMethod}(cancellationToken);`); + } else { + lines.push(` await handler.${handlerMethod}(cancellationToken);`); + } + lines.push(` }));`); + } + } + } + lines.push(` }`); + lines.push(`}`); + + return lines; +} + function generateRpcCode( schema: ApiSchema, externalJsonSerializableRefs: Map> = new Map(), @@ -2315,6 +2451,7 @@ function generateRpcCode( ...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {}), ...collectRpcMethods(schema.clientSession || {}), + ...collectRpcMethods(schema.clientGlobal || {}), ]; for (const name of collectRpcMethodReferencedDefinitionNames( allMethods.filter((method) => method.stability !== "experimental"), @@ -2343,6 +2480,9 @@ function generateRpcCode( let clientSessionParts: string[] = []; if (schema.clientSession) clientSessionParts = emitClientSessionApiRegistration(schema.clientSession, classes); + let clientGlobalParts: string[] = []; + if (schema.clientGlobal) clientGlobalParts = emitClientGlobalApiRegistration(schema.clientGlobal, classes); + const lines: string[] = []; lines.push(`${COPYRIGHT} @@ -2368,6 +2508,7 @@ namespace GitHub.Copilot.Rpc; for (const part of serverRpcParts) lines.push(part, ""); for (const part of sessionRpcParts) lines.push(part, ""); if (clientSessionParts.length > 0) lines.push(...clientSessionParts, ""); + if (clientGlobalParts.length > 0) lines.push(...clientGlobalParts, ""); // Add JsonSerializerContext for AOT/trimming support const typeNames = [...emittedRpcClassSchemas.keys(), ...emittedRpcEnumResultTypes].sort(); diff --git a/scripts/codegen/typescript.ts b/scripts/codegen/typescript.ts index bba360b47..1303a4979 100644 --- a/scripts/codegen/typescript.ts +++ b/scripts/codegen/typescript.ts @@ -516,7 +516,8 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; const allMethods = [...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {})]; const clientSessionMethods = collectRpcMethods(schema.clientSession || {}); - const rpcMethods = [...allMethods, ...clientSessionMethods]; + const clientGlobalMethods = collectRpcMethods(schema.clientGlobal || {}); + const rpcMethods = [...allMethods, ...clientSessionMethods, ...clientGlobalMethods]; const seenBlocks = new Map(); // Build a single combined schema with shared definitions and all method types. @@ -717,6 +718,13 @@ function hasInternalMethods(node: Record): boolean { lines.push(...emitClientSessionApiRegistration(schema.clientSession)); } + // Generate client *global* API handler interfaces and registration function. + // Unlike client-session APIs, these methods do not carry a `sessionId` dispatch + // key — the SDK consumer registers a single process-wide handler per group. + if (schema.clientGlobal) { + lines.push(...emitClientGlobalApiRegistration(schema.clientGlobal)); + } + const outPath = await writeGeneratedFile("nodejs/src/generated/rpc.ts", lines.join("\n")); console.log(` ✓ ${outPath}`); } @@ -926,6 +934,105 @@ function emitClientSessionApiRegistration(clientSchema: Record) return lines; } +/** + * Generate handler interfaces and a registration function for client *global* + * API groups. + * + * Unlike client-session APIs, these methods carry no implicit `sessionId` + * dispatch key. The SDK consumer registers a single process-wide handler set + * via `registerClientGlobalApiHandlers`; the runtime dispatcher routes each + * incoming call to the registered handler regardless of which (if any) + * runtime session triggered it. + */ +function emitClientGlobalApiRegistration(clientSchema: Record): string[] { + const lines: string[] = []; + const groups = collectClientGroups(clientSchema); + + for (const [groupName, methods] of groups) { + const interfaceName = toPascalCase(groupName) + "Handler"; + const groupDeprecated = isNodeFullyDeprecated(clientSchema[groupName] as Record); + const groupExperimental = isNodeFullyExperimental(clientSchema[groupName] as Record); + if (groupDeprecated) { + lines.push(`/** @deprecated Handler for \`${groupName}\` client global API methods. */`); + } else if (groupExperimental) { + lines.push(`/** Handler for \`${groupName}\` client global API methods. */`); + lines.push(TS_EXPERIMENTAL_JSDOC); + } else { + lines.push(`/** Handler for \`${groupName}\` client global API methods. */`); + } + lines.push(`export interface ${interfaceName} {`); + for (const method of methods) { + const name = handlerMethodName(method.rpcMethod); + const hasParams = hasSchemaPayload(getMethodParamsSchema(method)); + const pType = hasParams ? paramsTypeName(method) : ""; + const rType = tsResultType(method); + + pushTsRpcMethodJsDoc(lines, " ", method, { + summaryFallback: `Handles \`${method.rpcMethod}\`.`, + paramsName: hasParams ? "params" : undefined, + paramsDescription: rpcParamsDescription(method, getMethodParamsSchema(method)), + includeDeprecated: method.deprecated && !groupDeprecated, + includeExperimental: method.stability === "experimental" && !groupExperimental, + }); + if (hasParams) { + lines.push(` ${name}(params: ${pType}): Promise<${rType}>;`); + } else { + lines.push(` ${name}(): Promise<${rType}>;`); + } + } + lines.push(`}`); + lines.push(""); + } + + lines.push(`/** All client global API handler groups. */`); + lines.push(`export interface ClientGlobalApiHandlers {`); + for (const [groupName] of groups) { + const interfaceName = toPascalCase(groupName) + "Handler"; + lines.push(` ${groupName}?: ${interfaceName};`); + } + lines.push(`}`); + lines.push(""); + + lines.push(`/**`); + lines.push(` * Register client global API handlers on a JSON-RPC connection.`); + lines.push(` * The server calls these methods to delegate work to the client.`); + lines.push(` * Unlike session-scoped client APIs, these methods carry no implicit`); + lines.push(` * \`sessionId\` dispatch key — a single set of handlers serves the entire`); + lines.push(` * connection.`); + lines.push(` */`); + lines.push(`export function registerClientGlobalApiHandlers(`); + lines.push(` connection: MessageConnection,`); + lines.push(` handlers: ClientGlobalApiHandlers,`); + lines.push(`): void {`); + + for (const [groupName, methods] of groups) { + for (const method of methods) { + const name = handlerMethodName(method.rpcMethod); + const pType = paramsTypeName(method); + const hasParams = hasSchemaPayload(getMethodParamsSchema(method)); + + if (hasParams) { + lines.push(` connection.onRequest("${method.rpcMethod}", async (params: ${pType}) => {`); + lines.push(` const handler = handlers.${groupName};`); + lines.push(` if (!handler) throw new Error("No ${groupName} client-global handler registered");`); + lines.push(` return handler.${name}(params);`); + lines.push(` });`); + } else { + lines.push(` connection.onRequest("${method.rpcMethod}", async () => {`); + lines.push(` const handler = handlers.${groupName};`); + lines.push(` if (!handler) throw new Error("No ${groupName} client-global handler registered");`); + lines.push(` return handler.${name}();`); + lines.push(` });`); + } + } + } + + lines.push(`}`); + lines.push(""); + + return lines; +} + // ── Main ──────────────────────────────────────────────────────────────────── async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { diff --git a/scripts/codegen/utils.ts b/scripts/codegen/utils.ts index 3917dad44..726485c93 100644 --- a/scripts/codegen/utils.ts +++ b/scripts/codegen/utils.ts @@ -470,6 +470,7 @@ export interface ApiSchema { server?: Record; session?: Record; clientSession?: Record; + clientGlobal?: Record; } export function isRpcMethod(node: unknown): node is RpcMethod { @@ -519,6 +520,7 @@ export function fixNullableRequiredRefsInApiSchema(schema: ApiSchema): ApiSchema server: walkApiNode(schema.server), session: walkApiNode(schema.session), clientSession: walkApiNode(schema.clientSession), + clientGlobal: walkApiNode(schema.clientGlobal), }; }