From 4f2fc4ca6d0b7a0b13b236dafb0e4e3148c2ed58 Mon Sep 17 00:00:00 2001 From: Rob von Behren Date: Thu, 7 May 2026 20:48:28 -0700 Subject: [PATCH 1/2] feat: share HTTP connection pool across SDK instances; refactor polling (#797) Co-authored-by: Claude Opus 4.6 --- src/runloop_api_client/_base_client.py | 240 +++++++++++-- src/runloop_api_client/_client.py | 30 +- src/runloop_api_client/_constants.py | 2 +- src/runloop_api_client/lib/wait_for_status.py | 99 ++++++ .../resources/devboxes/devboxes.py | 134 +++----- .../resources/devboxes/executions.py | 73 ++-- tests/test_client.py | 6 +- tests/test_shared_pool.py | 319 ++++++++++++++++++ uv.lock | 2 +- 9 files changed, 736 insertions(+), 169 deletions(-) create mode 100644 src/runloop_api_client/lib/wait_for_status.py create mode 100644 tests/test_shared_pool.py diff --git a/src/runloop_api_client/_base_client.py b/src/runloop_api_client/_base_client.py index 410e78aab..88e0bbb3b 100644 --- a/src/runloop_api_client/_base_client.py +++ b/src/runloop_api_client/_base_client.py @@ -8,8 +8,10 @@ import asyncio import inspect import logging +import weakref import platform import warnings +import threading import email.utils from types import TracebackType from random import random @@ -90,6 +92,88 @@ log: logging.Logger = logging.getLogger(__name__) +# Shared HTTP transport state. We share transports (connection pools) rather +# than full httpx clients so each SDK instance keeps its own cookie jar and +# mutable client state. Refcounted wrappers close the real transport only +# when the last user releases it. +# The async transport is keyed by event loop because connections bind to the +# loop that created them and cannot be reused across asyncio.run() calls. +_pool_lock = threading.Lock() + + +class _SharedTransport(httpx.BaseTransport): + """Refcounted wrapper: delegates to a real transport, closes it when refcount hits 0.""" + + def __init__(self, transport: httpx.BaseTransport) -> None: + self._transport = transport + self._refcount = 1 + self._lock = threading.Lock() + + @property + def refcount(self) -> int: + return self._refcount + + def acquire(self) -> bool: + with self._lock: + if self._refcount <= 0: + return False + self._refcount += 1 + return True + + @override + def handle_request(self, request: httpx.Request) -> httpx.Response: + return self._transport.handle_request(request) + + @override + def close(self) -> None: + should_close = False + with self._lock: + self._refcount -= 1 + if self._refcount <= 0: + should_close = True + if should_close: + self._transport.close() + + +class _SharedAsyncTransport(httpx.AsyncBaseTransport): + """Async refcounted wrapper: delegates to a real async transport.""" + + def __init__(self, transport: httpx.AsyncBaseTransport) -> None: + self._transport = transport + self._refcount = 1 + self._lock = threading.Lock() + + @property + def refcount(self) -> int: + return self._refcount + + def acquire(self) -> bool: + with self._lock: + if self._refcount <= 0: + return False + self._refcount += 1 + return True + + @override + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + return await self._transport.handle_async_request(request) + + @override + async def aclose(self) -> None: + should_close = False + with self._lock: + self._refcount -= 1 + if self._refcount <= 0: + should_close = True + if should_close: + await self._transport.aclose() + + +_shared_sync_transport: _SharedTransport | None = None +_shared_async_transports: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _SharedAsyncTransport] = ( + weakref.WeakKeyDictionary() +) + # TODO: make base page type vars covariant SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]") AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]") @@ -816,6 +900,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs.setdefault("timeout", DEFAULT_TIMEOUT) kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) kwargs.setdefault("follow_redirects", True) + kwargs.setdefault("http2", True) super().__init__(**kwargs) @@ -845,6 +930,8 @@ def __del__(self) -> None: class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]): _client: httpx.Client _default_stream_cls: type[Stream[Any]] | None = None + _uses_shared_pool: bool + _closed: bool def __init__( self, @@ -857,6 +944,7 @@ def __init__( custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, _strict_response_validation: bool, + shared_http_pool: bool = True, ) -> None: if not is_given(timeout): # if the user passed in a custom http client with a non-default @@ -886,24 +974,46 @@ def __init__( custom_headers=custom_headers, _strict_response_validation=_strict_response_validation, ) - self._client = http_client or SyncHttpxClientWrapper( - base_url=base_url, - # cast to a valid type because mypy doesn't understand our type narrowing - timeout=cast(Timeout, timeout), - ) + + self._closed = False + + if http_client is not None: + self._client = http_client + self._uses_shared_pool = False + elif shared_http_pool: + global _shared_sync_transport + with _pool_lock: + if _shared_sync_transport is None or not _shared_sync_transport.acquire(): + _shared_sync_transport = _SharedTransport( + httpx.HTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True), + ) + self._client = SyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + transport=_shared_sync_transport, + ) + self._uses_shared_pool = True + else: + self._client = SyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + ) + self._uses_shared_pool = False def is_closed(self) -> bool: - return self._client.is_closed + return self._closed or self._client.is_closed def close(self) -> None: """Close the underlying HTTPX client. The client will *not* be usable after this. """ - # If an error is thrown while constructing a client, self._client - # may not be present - if hasattr(self, "_client"): - self._client.close() + if not hasattr(self, "_client"): + return + if self._closed: + return + self._closed = True + self._client.close() def __enter__(self: _T) -> _T: return self @@ -1018,6 +1128,7 @@ def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1032,6 +1143,7 @@ def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1083,7 +1195,13 @@ def request( ) def _sleep_for_retry( - self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + self, + *, + retries_taken: int, + max_retries: int, + options: FinalRequestOptions, + response: httpx.Response | None, + error: BaseException | None = None, ) -> None: remaining_retries = max_retries - retries_taken if remaining_retries == 1: @@ -1092,7 +1210,23 @@ def _sleep_for_retry( log.debug("%i retries left", remaining_retries) timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) - log.info("Retrying request to %s in %f seconds", options.url, timeout) + if response is not None: + log.info( + "Retrying request to %s in %f seconds (status %d)", + options.url, + timeout, + response.status_code, + ) + elif error is not None: + log.info( + "Retrying request to %s in %f seconds (%s: %s)", + options.url, + timeout, + type(error).__name__, + error, + ) + else: + log.info("Retrying request to %s in %f seconds", options.url, timeout) time.sleep(timeout) @@ -1428,6 +1562,8 @@ def __del__(self) -> None: class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]): _client: httpx.AsyncClient _default_stream_cls: type[AsyncStream[Any]] | None = None + _uses_shared_pool: bool + _closed: bool def __init__( self, @@ -1440,6 +1576,7 @@ def __init__( http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, + shared_http_pool: bool = True, ) -> None: if not is_given(timeout): # if the user passed in a custom http client with a non-default @@ -1469,20 +1606,59 @@ def __init__( custom_headers=custom_headers, _strict_response_validation=_strict_response_validation, ) - self._client = http_client or AsyncHttpxClientWrapper( - base_url=base_url, - # cast to a valid type because mypy doesn't understand our type narrowing - timeout=cast(Timeout, timeout), - ) + + self._closed = False + + if http_client is not None: + self._client = http_client + self._uses_shared_pool = False + elif shared_http_pool: + try: + loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop is not None: + with _pool_lock: + existing = _shared_async_transports.get(loop) + if existing is not None and existing.acquire(): + transport: _SharedAsyncTransport = existing + else: + transport = _SharedAsyncTransport( + httpx.AsyncHTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True), + ) + _shared_async_transports[loop] = transport + self._client = AsyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + transport=transport, + ) + self._uses_shared_pool = True + else: + self._client = AsyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + ) + self._uses_shared_pool = False + else: + self._client = AsyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + ) + self._uses_shared_pool = False def is_closed(self) -> bool: - return self._client.is_closed + return self._closed or self._client.is_closed async def close(self) -> None: """Close the underlying HTTPX client. The client will *not* be usable after this. """ + if not hasattr(self, "_client"): + return + if self._closed: + return + self._closed = True await self._client.aclose() async def __aenter__(self: _T) -> _T: @@ -1603,6 +1779,7 @@ async def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1617,6 +1794,7 @@ async def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1668,7 +1846,13 @@ async def request( ) async def _sleep_for_retry( - self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + self, + *, + retries_taken: int, + max_retries: int, + options: FinalRequestOptions, + response: httpx.Response | None, + error: BaseException | None = None, ) -> None: remaining_retries = max_retries - retries_taken if remaining_retries == 1: @@ -1677,7 +1861,23 @@ async def _sleep_for_retry( log.debug("%i retries left", remaining_retries) timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) - log.info("Retrying request to %s in %f seconds", options.url, timeout) + if response is not None: + log.info( + "Retrying request to %s in %f seconds (status %d)", + options.url, + timeout, + response.status_code, + ) + elif error is not None: + log.info( + "Retrying request to %s in %f seconds (%s: %s)", + options.url, + timeout, + type(error).__name__, + error, + ) + else: + log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) diff --git a/src/runloop_api_client/_client.py b/src/runloop_api_client/_client.py index 61db3a474..1867f9997 100644 --- a/src/runloop_api_client/_client.py +++ b/src/runloop_api_client/_client.py @@ -84,6 +84,10 @@ def __init__( # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: httpx.Client | None = None, + # Share a single httpx connection pool across all Runloop client instances. + # Enables HTTP/2 multiplexing and avoids ConnectTimeout storms under high concurrency. + # Set to False to create a private connection pool (old behavior). + shared_http_pool: bool = True, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -120,6 +124,7 @@ def __init__( custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, + shared_http_pool=shared_http_pool, ) self._idempotency_header = "x-request-id" @@ -249,6 +254,7 @@ def copy( base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, + shared_http_pool: bool | None = None, max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, @@ -277,12 +283,19 @@ def copy( elif set_default_query is not None: params = set_default_query - http_client = http_client or self._client + if http_client is not None: + resolved_shared = False + elif shared_http_pool is not None: + resolved_shared = shared_http_pool + else: + resolved_shared = self._uses_shared_pool + return self.__class__( bearer_token=bearer_token or self.bearer_token, base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, + shared_http_pool=resolved_shared, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, @@ -344,6 +357,10 @@ def __init__( # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. http_client: httpx.AsyncClient | None = None, + # Share a single httpx connection pool across all AsyncRunloop client instances. + # Enables HTTP/2 multiplexing and avoids ConnectTimeout storms under high concurrency. + # Set to False to create a private connection pool (old behavior). + shared_http_pool: bool = True, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -380,6 +397,7 @@ def __init__( custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, + shared_http_pool=shared_http_pool, ) self._idempotency_header = "x-request-id" @@ -509,6 +527,7 @@ def copy( base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, + shared_http_pool: bool | None = None, max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, @@ -537,12 +556,19 @@ def copy( elif set_default_query is not None: params = set_default_query - http_client = http_client or self._client + if http_client is not None: + resolved_shared = False + elif shared_http_pool is not None: + resolved_shared = shared_http_pool + else: + resolved_shared = self._uses_shared_pool + return self.__class__( bearer_token=bearer_token or self.bearer_token, base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, + shared_http_pool=resolved_shared, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, diff --git a/src/runloop_api_client/_constants.py b/src/runloop_api_client/_constants.py index d6361c8ad..88f944ce2 100644 --- a/src/runloop_api_client/_constants.py +++ b/src/runloop_api_client/_constants.py @@ -8,7 +8,7 @@ # default timeout is 30 seconds DEFAULT_TIMEOUT = httpx.Timeout(timeout=30, connect=5.0) DEFAULT_MAX_RETRIES = 5 -DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) +DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=20, max_keepalive_connections=10) INITIAL_RETRY_DELAY = 1.0 MAX_RETRY_DELAY = 60.0 diff --git a/src/runloop_api_client/lib/wait_for_status.py b/src/runloop_api_client/lib/wait_for_status.py new file mode 100644 index 000000000..73df2bf95 --- /dev/null +++ b/src/runloop_api_client/lib/wait_for_status.py @@ -0,0 +1,99 @@ +"""Helpers for polling wait_for_status long-poll endpoints. + +Each function wraps a server-side long-poll POST with a client-side retry +loop. On each iteration the remaining timeout is forwarded to the server +so the server can long-poll for up to that duration. 408 responses and +client-side timeouts are converted to a caller-supplied placeholder so the +loop can continue. No client-side sleep between iterations — the +server-side long-poll *is* the wait. +""" + +from __future__ import annotations + +import time +from typing import List, Type, TypeVar, Callable, Optional, Awaitable + +from .polling import PollingConfig, PollingTimeout +from .._exceptions import APIStatusError, APITimeoutError + +T = TypeVar("T") + + +def wait_for_status( + post_fn: Callable[..., T], + path: str, + statuses: List[str], + cast_to: Type[T], + placeholder: Callable[[], T], + is_terminal: Callable[[T], bool], + polling_config: Optional[PollingConfig] = None, +) -> T: + """Sync long-poll for a status change, retrying until *is_terminal* or timeout.""" + config = polling_config or PollingConfig() + timeout = config.interval_seconds * config.max_attempts + if config.timeout_seconds is not None and config.timeout_seconds > 0: + timeout = min(config.timeout_seconds, timeout) + + start_time = time.time() + last_result: T | None = None + + while True: + remaining = timeout - (time.time() - start_time) + if remaining <= 0: + raise PollingTimeout(f"Exceeded timeout of {timeout} seconds", last_result) + + try: + last_result = post_fn( + path, + body={"statuses": statuses, "timeout_seconds": remaining}, + cast_to=cast_to, + options={"max_retries": 0}, + ) + except (APITimeoutError, APIStatusError) as error: + if isinstance(error, APITimeoutError) or error.response.status_code == 408: + last_result = placeholder() + else: + raise + + if is_terminal(last_result): + return last_result + + +async def async_wait_for_status( + post_fn: Callable[..., Awaitable[T]], + path: str, + statuses: List[str], + cast_to: Type[T], + placeholder: Callable[[], T], + is_terminal: Callable[[T], bool], + polling_config: Optional[PollingConfig] = None, +) -> T: + """Async long-poll for a status change, retrying until *is_terminal* or timeout.""" + config = polling_config or PollingConfig() + timeout = config.interval_seconds * config.max_attempts + if config.timeout_seconds is not None and config.timeout_seconds > 0: + timeout = min(config.timeout_seconds, timeout) + + start_time = time.time() + last_result: T | None = None + + while True: + remaining = timeout - (time.time() - start_time) + if remaining <= 0: + raise PollingTimeout(f"Exceeded timeout of {timeout} seconds", last_result) + + try: + last_result = await post_fn( + path, + body={"statuses": statuses, "timeout_seconds": remaining}, + cast_to=cast_to, + options={"max_retries": 0}, + ) + except (APITimeoutError, APIStatusError) as error: + if isinstance(error, APITimeoutError) or error.response.status_code == 408: + last_result = placeholder() + else: + raise + + if is_terminal(last_result): + return last_result diff --git a/src/runloop_api_client/resources/devboxes/devboxes.py b/src/runloop_api_client/resources/devboxes/devboxes.py index 83459959b..888369e98 100644 --- a/src/runloop_api_client/resources/devboxes/devboxes.py +++ b/src/runloop_api_client/resources/devboxes/devboxes.py @@ -72,7 +72,7 @@ AsyncDiskSnapshotsCursorIDPage, ) from ..._exceptions import RunloopError, APIStatusError, APITimeoutError -from ...lib.polling import PollingConfig, poll_until, retry_server_poll_until as sync_retry_server_poll_until +from ...lib.polling import PollingConfig, poll_until from ..._base_client import AsyncPaginator, make_request_options from .disk_snapshots import ( DiskSnapshotsResource, @@ -82,9 +82,10 @@ DiskSnapshotsResourceWithStreamingResponse, AsyncDiskSnapshotsResourceWithStreamingResponse, ) -from ...lib.polling_async import async_poll_until, async_retry_server_poll_until +from ...lib.polling_async import async_poll_until from ...types.devbox_view import DevboxView from ...types.tunnel_view import TunnelView +from ...lib.wait_for_status import wait_for_status, async_wait_for_status from ...types.shared_params.mount import Mount from ...types.devbox_snapshot_view import DevboxSnapshotView from ...types.shared.launch_parameters import LaunchParameters as SharedLaunchParameters @@ -383,11 +384,7 @@ def await_running( Args: id: The ID of the devbox to wait for - config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds + polling_config: Optional polling configuration Returns: The devbox in running state @@ -397,31 +394,18 @@ def await_running( RunloopError: If devbox enters a non-running terminal state """ - def wait_for_devbox_status(remaining_timeout_seconds: float) -> DevboxView: - try: - return self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": ["running", "failure", "shutdown"], "timeout_seconds": remaining_timeout_seconds}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - except (APITimeoutError, APIStatusError) as error: - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - return placeholder_devbox_view(id) - raise - def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - config = polling_config - if not config: - config = PollingConfig() - - timeout = config.interval_seconds * config.max_attempts - if config.timeout_seconds is not None and config.timeout_seconds > 0: - timeout = min(config.timeout_seconds, timeout) - - devbox = sync_retry_server_poll_until(wait_for_devbox_status, is_done_booting, timeout) + devbox = wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + ["running", "failure", "shutdown"], + DevboxView, + lambda: placeholder_devbox_view(id), + is_done_booting, + polling_config, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -448,25 +432,18 @@ def await_suspended( RunloopError: If the devbox enters a non-suspended terminal state. """ - def wait_for_devbox_status() -> DevboxView: - return self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": list(DEVBOX_TERMINAL_STATES)}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - - def handle_timeout_error(error: Exception) -> DevboxView: - if isinstance(error, APITimeoutError) or ( - isinstance(error, APIStatusError) and error.response.status_code == 408 - ): - return placeholder_devbox_view(id) - raise error - def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = poll_until(wait_for_devbox_status, is_terminal_state, polling_config, handle_timeout_error) + devbox = wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + list(DEVBOX_TERMINAL_STATES), + DevboxView, + lambda: placeholder_devbox_view(id), + is_terminal_state, + polling_config, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") @@ -2045,9 +2022,6 @@ async def await_running( Args: id: The ID of the devbox to wait for polling_config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request Returns: The devbox in running state @@ -2057,41 +2031,18 @@ async def await_running( RunloopError: If devbox enters a non-running terminal state """ - async def wait_for_devbox_status(remaining_timeout_seconds: float) -> DevboxView: - # This wait_for_status endpoint polls the devbox status for 10 seconds until it reaches either running or failure. - # If it's neither, it will throw an error. - try: - return await self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": ["running", "failure", "shutdown"], "timeout_seconds": remaining_timeout_seconds}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - except (APITimeoutError, APIStatusError) as error: - # Handle timeout errors by returning current devbox state to continue polling - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - # Return a placeholder result to continue polling - return placeholder_devbox_view(id) - - # Re-raise other errors to stop polling - raise - def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - # calculate the timeout to use. The PollingConfig doesn't - # match the semantics for server-side polling well, so we - # instead convert interval*attempts to a total time, and take - # the minimum total. - config = polling_config - if not config: - config = PollingConfig() # use defaults - - timeout = config.interval_seconds * config.max_attempts - if config.timeout_seconds is not None and config.timeout_seconds > 0: - timeout = min(config.timeout_seconds, timeout) - - devbox = await async_retry_server_poll_until(wait_for_devbox_status, is_done_booting, timeout) + devbox = await async_wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + ["running", "failure", "shutdown"], + DevboxView, + lambda: placeholder_devbox_view(id), + is_done_booting, + polling_config, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -2118,23 +2069,18 @@ async def await_suspended( RunloopError: If the devbox enters a non-suspended terminal state. """ - async def wait_for_devbox_status() -> DevboxView: - try: - return await self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": list(DEVBOX_TERMINAL_STATES)}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - except (APITimeoutError, APIStatusError) as error: - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - return placeholder_devbox_view(id) - raise - def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = await async_poll_until(wait_for_devbox_status, is_terminal_state, polling_config) + devbox = await async_wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + list(DEVBOX_TERMINAL_STATES), + DevboxView, + lambda: placeholder_devbox_view(id), + is_terminal_state, + polling_config, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") diff --git a/src/runloop_api_client/resources/devboxes/executions.py b/src/runloop_api_client/resources/devboxes/executions.py index ff7638798..e5bfd7ed8 100755 --- a/src/runloop_api_client/resources/devboxes/executions.py +++ b/src/runloop_api_client/resources/devboxes/executions.py @@ -20,8 +20,7 @@ ) from ..._constants import DEFAULT_TIMEOUT, RAW_RESPONSE_HEADER from ..._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream -from ..._exceptions import APIStatusError, APITimeoutError -from ...lib.polling import PollingConfig, poll_until +from ...lib.polling import PollingConfig from ..._base_client import make_request_options from ...types.devboxes import ( execution_kill_params, @@ -32,7 +31,7 @@ execution_stream_stderr_updates_params, execution_stream_stdout_updates_params, ) -from ...lib.polling_async import async_poll_until +from ...lib.wait_for_status import wait_for_status, async_wait_for_status from ...types.devbox_send_std_in_result import DevboxSendStdInResult from ...types.devbox_execution_detail_view import DevboxExecutionDetailView from ...types.devboxes.execution_update_chunk import ExecutionUpdateChunk @@ -129,12 +128,8 @@ def await_completed( Args: execution_id: The ID of the execution to wait for - id: The ID of the devbox - config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds + devbox_id: The ID of the devbox + polling_config: Optional polling configuration Returns: The completed execution @@ -143,29 +138,18 @@ def await_completed( PollingTimeout: If polling times out before execution completes """ - def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: - # This wait_for_status endpoint polls the execution status for 60 seconds until it reaches either completed. - return self._post( - f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", - body={"statuses": ["completed"]}, - cast_to=DevboxAsyncExecutionDetailView, - ) - - def handle_timeout_error(error: Exception) -> DevboxAsyncExecutionDetailView: - # Handle timeout errors by returning current execution state to continue polling - if isinstance(error, APITimeoutError) or ( - isinstance(error, APIStatusError) and error.response.status_code == 408 - ): - # Return a placeholder result to continue polling - return placeholder_execution_detail_view(devbox_id, execution_id) - else: - # Re-raise other errors to stop polling - raise error - def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return poll_until(wait_for_execution_status, is_done, polling_config, handle_timeout_error) + return wait_for_status( + self._post, + f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", + ["completed"], + DevboxAsyncExecutionDetailView, + lambda: placeholder_execution_detail_view(devbox_id, execution_id), + is_done, + polling_config, + ) def execute_async( self, @@ -675,12 +659,8 @@ async def await_completed( Args: execution_id: The ID of the execution to wait for - id: The ID of the devbox + devbox_id: The ID of the devbox polling_config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds Returns: The completed execution @@ -689,25 +669,18 @@ async def await_completed( PollingTimeout: If polling times out before execution completes """ - async def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: - try: - return await self._post( - f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", - body={"statuses": ["completed"]}, - cast_to=DevboxAsyncExecutionDetailView, - ) - except (APITimeoutError, APIStatusError) as error: - # Handle timeout errors by returning placeholder to continue polling - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - return placeholder_execution_detail_view(devbox_id, execution_id) - - # Re-raise other errors to stop polling - raise - def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return await async_poll_until(wait_for_execution_status, is_done, polling_config) + return await async_wait_for_status( + self._post, + f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", + ["completed"], + DevboxAsyncExecutionDetailView, + lambda: placeholder_execution_detail_view(devbox_id, execution_id), + is_done, + polling_config, + ) async def execute_async( self, diff --git a/tests/test_client.py b/tests/test_client.py index 408c7cedd..7728bf5bb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -32,7 +32,9 @@ DefaultHttpxClient, DefaultAsyncHttpxClient, get_platform, + _SharedTransport, make_request_options, + _SharedAsyncTransport, ) from .utils import update_env @@ -105,7 +107,9 @@ async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] def _get_open_connections(client: Runloop | AsyncRunloop) -> int: transport = client._client._transport - assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) + if isinstance(transport, (_SharedTransport, _SharedAsyncTransport)): + transport = transport._transport + assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport)) pool = transport._pool return len(pool._requests) diff --git a/tests/test_shared_pool.py b/tests/test_shared_pool.py new file mode 100644 index 000000000..4220f8ba9 --- /dev/null +++ b/tests/test_shared_pool.py @@ -0,0 +1,319 @@ +"""Tests for shared HTTP transport pool behavior. + +Verifies that SDK clients share (or don't share) the underlying httpx +transport, and that refcounting correctly manages the transport lifecycle. +""" + +from __future__ import annotations + +import os +import asyncio +from typing import Any, Iterator + +import httpx +import pytest + +import runloop_api_client._base_client as _base_mod +from runloop_api_client import Runloop, AsyncRunloop + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +bearer_token = "My Bearer Token" + + +@pytest.fixture(autouse=True) +def _reset_shared_pool() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] + _clear_pool_state() + yield + _clear_pool_state() + + +def _clear_pool_state() -> None: + with _base_mod._pool_lock: + old_sync = _base_mod._shared_sync_transport + _base_mod._shared_sync_transport = None + _base_mod._shared_async_transports.clear() + if old_sync is not None: + try: + old_sync._transport.close() + except Exception: + pass + + +def _make_client(**kwargs: Any) -> Runloop: + kwargs.setdefault("base_url", base_url) + kwargs.setdefault("bearer_token", bearer_token) + return Runloop(**kwargs) + + +def _make_async_client(**kwargs: Any) -> AsyncRunloop: + kwargs.setdefault("base_url", base_url) + kwargs.setdefault("bearer_token", bearer_token) + return AsyncRunloop(**kwargs) + + +def _get_transport(client: Runloop | AsyncRunloop) -> Any: + return client._client._transport # type: ignore[union-attr] + + +# --------------------------------------------------------------------------- +# Sync: sharing behavior +# --------------------------------------------------------------------------- + + +class TestSyncSharedPool: + def test_shared_pool_uses_same_transport(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + + assert _get_transport(c1) is _get_transport(c2) + assert c1._client is not c2._client + assert c1._uses_shared_pool is True + assert c2._uses_shared_pool is True + + c1.close() + c2.close() + + def test_private_pool_uses_different_transports(self): + c1 = _make_client(shared_http_pool=False) + c2 = _make_client(shared_http_pool=False) + + assert _get_transport(c1) is not _get_transport(c2) + assert c1._uses_shared_pool is False + assert c2._uses_shared_pool is False + + c1.close() + c2.close() + + def test_custom_http_client_bypasses_sharing(self): + custom = httpx.Client() + c1 = _make_client(http_client=custom, shared_http_pool=True) + + assert c1._client is custom + assert c1._uses_shared_pool is False + + c1.close() + custom.close() + + def test_default_is_shared(self): + c1 = _make_client() + assert c1._uses_shared_pool is True + c1.close() + + def test_cookie_isolation(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + + c1._client.cookies.set("session", "secret-123") + assert "session" not in c2._client.cookies + + c1.close() + c2.close() + + +class TestSyncRefcounting: + def test_close_one_keeps_transport_alive(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + transport = _get_transport(c1) + + assert transport.refcount == 2 + + c1.close() + assert transport.refcount == 1 + assert not c2.is_closed() + + c2.close() + assert transport.refcount == 0 + + def test_double_close_is_safe(self): + c1 = _make_client(shared_http_pool=True) + transport = _get_transport(c1) + + c1.close() + c1.close() # should not raise or double-decrement + assert transport.refcount == 0 + + def test_three_clients_refcount(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + c3 = _make_client(shared_http_pool=True) + transport = _get_transport(c1) + + assert transport.refcount == 3 + + c1.close() + assert transport.refcount == 2 + + c2.close() + assert transport.refcount == 1 + + c3.close() + assert transport.refcount == 0 + + def test_transport_recreated_after_full_release(self): + c1 = _make_client(shared_http_pool=True) + t1 = _get_transport(c1) + c1.close() + + c2 = _make_client(shared_http_pool=True) + t2 = _get_transport(c2) + assert t2 is not t1 + assert t2.refcount == 1 + + c2.close() + + +class TestSyncCopy: + def test_copy_inherits_shared_pool(self): + c1 = _make_client(shared_http_pool=True) + c2 = c1.copy() + transport = _get_transport(c1) + + assert c2._uses_shared_pool is True + assert _get_transport(c2) is transport + assert transport.refcount == 2 + + c1.close() + c2.close() + + def test_copy_with_custom_client_disables_sharing(self): + c1 = _make_client(shared_http_pool=True) + custom = httpx.Client() + c2 = c1.copy(http_client=custom) + + assert c2._uses_shared_pool is False + assert c2._client is custom + + c1.close() + c2.close() + custom.close() + + def test_copy_of_non_shared_stays_non_shared(self): + c1 = _make_client(shared_http_pool=False) + c2 = c1.copy() + + assert c2._uses_shared_pool is False + assert _get_transport(c2) is not _get_transport(c1) + + c1.close() + c2.close() + + +# --------------------------------------------------------------------------- +# Async: sharing behavior +# --------------------------------------------------------------------------- + + +class TestAsyncSharedPool: + async def test_shared_pool_uses_same_transport(self): + c1 = _make_async_client(shared_http_pool=True) + c2 = _make_async_client(shared_http_pool=True) + + assert _get_transport(c1) is _get_transport(c2) + assert c1._client is not c2._client + assert c1._uses_shared_pool is True + assert c2._uses_shared_pool is True + + def test_private_pool_uses_different_transports(self): + c1 = _make_async_client(shared_http_pool=False) + c2 = _make_async_client(shared_http_pool=False) + + assert _get_transport(c1) is not _get_transport(c2) + assert c1._uses_shared_pool is False + + def test_custom_http_client_bypasses_sharing(self): + custom = httpx.AsyncClient() + c1 = _make_async_client(http_client=custom, shared_http_pool=True) + + assert c1._client is custom + assert c1._uses_shared_pool is False + + async def test_default_is_shared(self): + c1 = _make_async_client() + assert c1._uses_shared_pool is True + + def test_no_loop_creates_private_client(self): + c1 = _make_async_client(shared_http_pool=True) + assert c1._uses_shared_pool is False + + +class TestAsyncRefcounting: + async def test_close_one_keeps_transport_alive(self): + c1 = _make_async_client(shared_http_pool=True) + c2 = _make_async_client(shared_http_pool=True) + transport = _get_transport(c1) + + assert transport.refcount == 2 + + await c1.close() + assert transport.refcount == 1 + assert not c2.is_closed() + + await c2.close() + assert transport.refcount == 0 + + async def test_double_close_is_safe(self): + c1 = _make_async_client(shared_http_pool=True) + transport = _get_transport(c1) + + await c1.close() + await c1.close() # should not raise or double-decrement + assert transport.refcount == 0 + + def test_no_loop_client_closes_properly(self): + """Client created without a running loop should close without leaking.""" + c1 = _make_async_client(shared_http_pool=True) + assert c1._uses_shared_pool is False + + asyncio.run(c1.close()) + assert c1.is_closed() + + +class TestAsyncCopy: + async def test_copy_inherits_shared_pool(self): + c1 = _make_async_client(shared_http_pool=True) + c2 = c1.copy() + transport = _get_transport(c1) + + assert c2._uses_shared_pool is True + assert _get_transport(c2) is transport + assert transport.refcount == 2 + + async def test_copy_with_custom_client_disables_sharing(self): + c1 = _make_async_client(shared_http_pool=True) + custom = httpx.AsyncClient() + c2 = c1.copy(http_client=custom) + + assert c2._uses_shared_pool is False + assert c2._client is custom + + +class TestAsyncCrossLoop: + def test_separate_loops_get_separate_transports(self): + """Clients created in different asyncio.run() calls must not share a transport.""" + + async def create_client() -> Any: + c = _make_async_client(shared_http_pool=True) + transport = _get_transport(c) + await c.close() + return transport + + t1 = asyncio.run(create_client()) + t2 = asyncio.run(create_client()) + + assert t1 is not t2, "each loop should get its own transport" + + def test_same_loop_shares_transport(self): + """Clients created in the same asyncio.run() must share a transport.""" + + async def create_two() -> tuple[int, int]: + c1 = _make_async_client(shared_http_pool=True) + c2 = _make_async_client(shared_http_pool=True) + id1 = id(_get_transport(c1)) + id2 = id(_get_transport(c2)) + await c1.close() + await c2.close() + return id1, id2 + + id1, id2 = asyncio.run(create_two()) + assert id1 == id2 diff --git a/uv.lock b/uv.lock index 88dc754a1..a35165b2c 100644 --- a/uv.lock +++ b/uv.lock @@ -2422,7 +2422,7 @@ wheels = [ [[package]] name = "runloop-api-client" -version = "1.20.0" +version = "1.20.2" source = { editable = "." } dependencies = [ { name = "anyio" }, From 5ff16ecf82103a5f10a448b27603c4b6f9e2c2b1 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 20:50:56 -0700 Subject: [PATCH 2/2] release: 1.20.3 (#796) Co-authored-by: stainless-app[bot] <142633134+stainless-app[bot]@users.noreply.github.com> --- .release-please-manifest.json | 2 +- .stats.yml | 8 +- CHANGELOG.md | 22 ++++++ api.md | 1 + pyproject.toml | 4 +- scripts/bootstrap | 2 +- src/runloop_api_client/_client.py | 24 +++++- src/runloop_api_client/_qs.py | 8 +- src/runloop_api_client/_types.py | 3 + src/runloop_api_client/_utils/_utils.py | 42 ++++++++-- src/runloop_api_client/_version.py | 2 +- src/runloop_api_client/resources/agents.py | 20 +++-- src/runloop_api_client/resources/secrets.py | 22 +++--- .../types/agent_create_params.py | 11 ++- src/runloop_api_client/types/agent_view.py | 11 ++- tests/api_resources/test_agents.py | 10 +-- tests/api_resources/test_secrets.py | 76 +++++++++++++++++++ tests/test_extract_files.py | 28 +++++-- tests/test_files.py | 2 +- 19 files changed, 236 insertions(+), 62 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index d9a5d555d..3b5fc24d7 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.20.2" + ".": "1.20.3" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index c5abf7216..20e30a2d9 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ -configured_endpoints: 115 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/runloop-ai%2Frunloop-563a11030291b5dd44e1b1b917e3e7bb865d7c873bf49c82056bfade22166843.yml -openapi_spec_hash: 20770e5f6ed8370fc14ff0e1351ccffc -config_hash: 12de9459ff629b6a3072a75b236b7b70 +configured_endpoints: 116 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/runloop-ai/runloop-6cf4d9a6afac92d72787088b3aefa941f5240ee522d9e98e1160eea2e29f87f4.yml +openapi_spec_hash: e07fc8349cf507b083830b4e2b0caca0 +config_hash: 436c8d4e665915db22b5d98fe58382c1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c560e47a..19de81d83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## 1.20.3 (2026-05-08) + +Full Changelog: [v1.20.2...v1.20.3](https://github.com/runloopai/api-client-python/compare/v1.20.2...v1.20.3) + +### Features + +* make agent version optional in API ([#8858](https://github.com/runloopai/api-client-python/issues/8858)) ([7e11a9d](https://github.com/runloopai/api-client-python/commit/7e11a9db85aff6c28dcc04b8d391979027f38549)) +* share HTTP connection pool across SDK instances; refactor polling ([#797](https://github.com/runloopai/api-client-python/issues/797)) ([4f2fc4c](https://github.com/runloopai/api-client-python/commit/4f2fc4ca6d0b7a0b13b236dafb0e4e3148c2ed58)) +* support setting headers via env ([54ead49](https://github.com/runloopai/api-client-python/commit/54ead49fd28a61f60e18197d727fa57216c785fd)) + + +### Bug Fixes + +* use correct field name format for multipart file arrays ([c564da8](https://github.com/runloopai/api-client-python/commit/c564da85b7dfdbb77edf347f6b25ca4ca57e470e)) + + +### Chores + +* add get secret to stainless ([#7833](https://github.com/runloopai/api-client-python/issues/7833)) ([ce39778](https://github.com/runloopai/api-client-python/commit/ce39778de67907365c90f11ba3b3602cbc7daa2a)) +* **internal:** more robust bootstrap script ([115744e](https://github.com/runloopai/api-client-python/commit/115744e3c181822a1ec172e0526684839e278899)) +* **internal:** reformat pyproject.toml ([89e8401](https://github.com/runloopai/api-client-python/commit/89e8401b518f0ec15cb1e394dde66cb876bf0578)) + ## 1.20.2 (2026-05-01) Full Changelog: [v1.20.1...v1.20.2](https://github.com/runloopai/api-client-python/compare/v1.20.1...v1.20.2) diff --git a/api.md b/api.md index 1d97dc90c..555f0c4f8 100644 --- a/api.md +++ b/api.md @@ -381,6 +381,7 @@ from runloop_api_client.types import ( Methods: - client.secrets.create(\*\*params) -> SecretView +- client.secrets.retrieve(name) -> SecretView - client.secrets.update(name, \*\*params) -> SecretView - client.secrets.list(\*\*params) -> SecretListView - client.secrets.delete(name) -> SecretView diff --git a/pyproject.toml b/pyproject.toml index 705aaa40f..1178d263e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "runloop_api_client" -version = "1.20.2" +version = "1.20.3" description = "The official Python library for the runloop API" dynamic = ["readme"] license = "MIT" @@ -157,7 +157,7 @@ show_error_codes = true # # We also exclude our `tests` as mypy doesn't always infer # types correctly and Pyright will still catch any type errors. -exclude = ['src/runloop_api_client/_files.py', '_dev/.*.py', 'tests/.*'] +exclude = ["src/runloop_api_client/_files.py", "_dev/.*.py", "tests/.*"] strict_equality = true implicit_reexport = true diff --git a/scripts/bootstrap b/scripts/bootstrap index 76185f88c..ec7c87055 100755 --- a/scripts/bootstrap +++ b/scripts/bootstrap @@ -4,7 +4,7 @@ set -e cd "$(dirname "$0")/.." -if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "$SKIP_BREW" != "1" ] && [ -t 0 ]; then +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "${SKIP_BREW:-}" != "1" ] && [ -t 0 ]; then brew bundle check >/dev/null 2>&1 || { echo -n "==> Install Homebrew dependencies? (y/N): " read -r response diff --git a/src/runloop_api_client/_client.py b/src/runloop_api_client/_client.py index 1867f9997..3d032dfd6 100644 --- a/src/runloop_api_client/_client.py +++ b/src/runloop_api_client/_client.py @@ -19,7 +19,11 @@ RequestOptions, not_given, ) -from ._utils import is_given, get_async_library +from ._utils import ( + is_given, + is_mapping_t, + get_async_library, +) from ._compat import cached_property from ._version import __version__ from ._streaming import Stream as Stream, AsyncStream as AsyncStream @@ -115,6 +119,15 @@ def __init__( if base_url is None: base_url = f"https://api.runloop.ai" + custom_headers_env = os.environ.get("RUNLOOP_CUSTOM_HEADERS") + if custom_headers_env is not None: + parsed: dict[str, str] = {} + for line in custom_headers_env.split("\n"): + colon = line.find(":") + if colon >= 0: + parsed[line[:colon].strip()] = line[colon + 1 :].strip() + default_headers = {**parsed, **(default_headers if is_mapping_t(default_headers) else {})} + super().__init__( version=__version__, base_url=base_url, @@ -388,6 +401,15 @@ def __init__( if base_url is None: base_url = f"https://api.runloop.ai" + custom_headers_env = os.environ.get("RUNLOOP_CUSTOM_HEADERS") + if custom_headers_env is not None: + parsed: dict[str, str] = {} + for line in custom_headers_env.split("\n"): + colon = line.find(":") + if colon >= 0: + parsed[line[:colon].strip()] = line[colon + 1 :].strip() + default_headers = {**parsed, **(default_headers if is_mapping_t(default_headers) else {})} + super().__init__( version=__version__, base_url=base_url, diff --git a/src/runloop_api_client/_qs.py b/src/runloop_api_client/_qs.py index de8c99bc6..4127c19c6 100644 --- a/src/runloop_api_client/_qs.py +++ b/src/runloop_api_client/_qs.py @@ -2,17 +2,13 @@ from typing import Any, List, Tuple, Union, Mapping, TypeVar from urllib.parse import parse_qs, urlencode -from typing_extensions import Literal, get_args +from typing_extensions import get_args -from ._types import NotGiven, not_given +from ._types import NotGiven, ArrayFormat, NestedFormat, not_given from ._utils import flatten _T = TypeVar("_T") - -ArrayFormat = Literal["comma", "repeat", "indices", "brackets"] -NestedFormat = Literal["dots", "brackets"] - PrimitiveData = Union[str, int, float, bool, None] # this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"] # https://github.com/microsoft/pyright/issues/3555 diff --git a/src/runloop_api_client/_types.py b/src/runloop_api_client/_types.py index 9f7d6eb21..db47875ab 100644 --- a/src/runloop_api_client/_types.py +++ b/src/runloop_api_client/_types.py @@ -47,6 +47,9 @@ ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) _T = TypeVar("_T") +ArrayFormat = Literal["comma", "repeat", "indices", "brackets"] +NestedFormat = Literal["dots", "brackets"] + # Approximates httpx internal ProxiesTypes and RequestFiles types # while adding support for `PathLike` instances diff --git a/src/runloop_api_client/_utils/_utils.py b/src/runloop_api_client/_utils/_utils.py index 771859f5e..199cd231f 100644 --- a/src/runloop_api_client/_utils/_utils.py +++ b/src/runloop_api_client/_utils/_utils.py @@ -17,11 +17,11 @@ ) from pathlib import Path from datetime import date, datetime -from typing_extensions import TypeGuard +from typing_extensions import TypeGuard, get_args import sniffio -from .._types import Omit, NotGiven, FileTypes, HeadersLike +from .._types import Omit, NotGiven, FileTypes, ArrayFormat, HeadersLike _T = TypeVar("_T") _TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) @@ -40,25 +40,45 @@ def extract_files( query: Mapping[str, object], *, paths: Sequence[Sequence[str]], + array_format: ArrayFormat = "brackets", ) -> list[tuple[str, FileTypes]]: """Recursively extract files from the given dictionary based on specified paths. A path may look like this ['foo', 'files', '', 'data']. + ``array_format`` controls how ```` segments contribute to the emitted + field name. Supported values: ``"brackets"`` (``foo[]``), ``"repeat"`` and + ``"comma"`` (``foo``), ``"indices"`` (``foo[0]``, ``foo[1]``). + Note: this mutates the given dictionary. """ files: list[tuple[str, FileTypes]] = [] for path in paths: - files.extend(_extract_items(query, path, index=0, flattened_key=None)) + files.extend(_extract_items(query, path, index=0, flattened_key=None, array_format=array_format)) return files +def _array_suffix(array_format: ArrayFormat, array_index: int) -> str: + if array_format == "brackets": + return "[]" + if array_format == "indices": + return f"[{array_index}]" + if array_format == "repeat" or array_format == "comma": + # Both repeat the bare field name for each file part; there is no + # meaningful way to comma-join binary parts. + return "" + raise NotImplementedError( + f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}" + ) + + def _extract_items( obj: object, path: Sequence[str], *, index: int, flattened_key: str | None, + array_format: ArrayFormat, ) -> list[tuple[str, FileTypes]]: try: key = path[index] @@ -75,9 +95,11 @@ def _extract_items( if is_list(obj): files: list[tuple[str, FileTypes]] = [] - for entry in obj: - assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "") - files.append((flattened_key + "[]", cast(FileTypes, entry))) + for array_index, entry in enumerate(obj): + suffix = _array_suffix(array_format, array_index) + emitted_key = (flattened_key + suffix) if flattened_key else suffix + assert_is_file_content(entry, key=emitted_key) + files.append((emitted_key, cast(FileTypes, entry))) return files assert_is_file_content(obj, key=flattened_key) @@ -106,6 +128,7 @@ def _extract_items( path, index=index, flattened_key=flattened_key, + array_format=array_format, ) elif is_list(obj): if key != "": @@ -117,9 +140,12 @@ def _extract_items( item, path, index=index, - flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", + flattened_key=( + (flattened_key if flattened_key is not None else "") + _array_suffix(array_format, array_index) + ), + array_format=array_format, ) - for item in obj + for array_index, item in enumerate(obj) ] ) diff --git a/src/runloop_api_client/_version.py b/src/runloop_api_client/_version.py index 686bad05c..62ae17d31 100644 --- a/src/runloop_api_client/_version.py +++ b/src/runloop_api_client/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "runloop_api_client" -__version__ = "1.20.2" # x-release-please-version +__version__ = "1.20.3" # x-release-please-version diff --git a/src/runloop_api_client/resources/agents.py b/src/runloop_api_client/resources/agents.py index 6febe22f0..6ea0f6b73 100644 --- a/src/runloop_api_client/resources/agents.py +++ b/src/runloop_api_client/resources/agents.py @@ -50,8 +50,8 @@ def create( self, *, name: str, - version: str, source: Optional[AgentSource] | Omit = omit, + version: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -68,10 +68,12 @@ def create( Args: name: The name of the Agent. - version: The version of the Agent. Must be a semver string (e.g., '2.0.65') or a SHA. - source: The source configuration for the Agent. + version: Optional version identifier for the Agent. For npm/pip sources this is typically + a semver string (e.g. '2.0.65'). For git sources it can be a branch or tag. + Semantics are user-defined for object sources. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -87,8 +89,8 @@ def create( body=maybe_transform( { "name": name, - "version": version, "source": source, + "version": version, }, agent_create_params.AgentCreateParams, ), @@ -357,8 +359,8 @@ async def create( self, *, name: str, - version: str, source: Optional[AgentSource] | Omit = omit, + version: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -375,10 +377,12 @@ async def create( Args: name: The name of the Agent. - version: The version of the Agent. Must be a semver string (e.g., '2.0.65') or a SHA. - source: The source configuration for the Agent. + version: Optional version identifier for the Agent. For npm/pip sources this is typically + a semver string (e.g. '2.0.65'). For git sources it can be a branch or tag. + Semantics are user-defined for object sources. + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -394,8 +398,8 @@ async def create( body=await async_maybe_transform( { "name": name, - "version": version, "source": source, + "version": version, }, agent_create_params.AgentCreateParams, ), diff --git a/src/runloop_api_client/resources/secrets.py b/src/runloop_api_client/resources/secrets.py index 38a9d8fc0..0dab937c4 100644 --- a/src/runloop_api_client/resources/secrets.py +++ b/src/runloop_api_client/resources/secrets.py @@ -100,6 +100,8 @@ def retrieve( self, name: str, *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, @@ -107,6 +109,8 @@ def retrieve( ) -> SecretView: """Retrieve a Secret by name. + The secret value is not included for security. + Args: extra_headers: Send extra headers @@ -119,12 +123,9 @@ def retrieve( if not name: raise ValueError(f"Expected a non-empty value for `name` but received {name!r}") return self._get( - f"/v1/secrets/{name}", + path_template("/v1/secrets/{name}", name=name), options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=SecretView, ) @@ -336,6 +337,8 @@ async def retrieve( self, name: str, *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, @@ -343,6 +346,8 @@ async def retrieve( ) -> SecretView: """Retrieve a Secret by name. + The secret value is not included for security. + Args: extra_headers: Send extra headers @@ -355,12 +360,9 @@ async def retrieve( if not name: raise ValueError(f"Expected a non-empty value for `name` but received {name!r}") return await self._get( - f"/v1/secrets/{name}", + path_template("/v1/secrets/{name}", name=name), options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=SecretView, ) diff --git a/src/runloop_api_client/types/agent_create_params.py b/src/runloop_api_client/types/agent_create_params.py index 3c2deff2a..c1b70a046 100644 --- a/src/runloop_api_client/types/agent_create_params.py +++ b/src/runloop_api_client/types/agent_create_params.py @@ -14,8 +14,13 @@ class AgentCreateParams(TypedDict, total=False): name: Required[str] """The name of the Agent.""" - version: Required[str] - """The version of the Agent. Must be a semver string (e.g., '2.0.65') or a SHA.""" - source: Optional[AgentSource] """The source configuration for the Agent.""" + + version: Optional[str] + """Optional version identifier for the Agent. + + For npm/pip sources this is typically a semver string (e.g. '2.0.65'). For git + sources it can be a branch or tag. Semantics are user-defined for object + sources. + """ diff --git a/src/runloop_api_client/types/agent_view.py b/src/runloop_api_client/types/agent_view.py index 23b1f68ff..d77527731 100644 --- a/src/runloop_api_client/types/agent_view.py +++ b/src/runloop_api_client/types/agent_view.py @@ -23,8 +23,13 @@ class AgentView(BaseModel): name: str """The name of the Agent.""" - version: str - """The version of the Agent. A semver string (e.g., '2.0.65') or a SHA.""" - source: Optional[AgentSource] = None """The source configuration for the Agent.""" + + version: Optional[str] = None + """Optional version identifier for the Agent. + + For npm/pip sources this is typically a semver string (e.g. '2.0.65'). For git + sources it can be a branch or tag. Omitted for object sources or when not + provided. + """ diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py index fb602d148..d4a98d26e 100644 --- a/tests/api_resources/test_agents.py +++ b/tests/api_resources/test_agents.py @@ -25,7 +25,6 @@ class TestAgents: def test_method_create(self, client: Runloop) -> None: agent = client.agents.create( name="name", - version="version", ) assert_matches_type(AgentView, agent, path=["response"]) @@ -33,7 +32,6 @@ def test_method_create(self, client: Runloop) -> None: def test_method_create_with_all_params(self, client: Runloop) -> None: agent = client.agents.create( name="name", - version="version", source={ "type": "type", "git": { @@ -56,6 +54,7 @@ def test_method_create_with_all_params(self, client: Runloop) -> None: "registry_url": "registry_url", }, }, + version="version", ) assert_matches_type(AgentView, agent, path=["response"]) @@ -63,7 +62,6 @@ def test_method_create_with_all_params(self, client: Runloop) -> None: def test_raw_response_create(self, client: Runloop) -> None: response = client.agents.with_raw_response.create( name="name", - version="version", ) assert response.is_closed is True @@ -75,7 +73,6 @@ def test_raw_response_create(self, client: Runloop) -> None: def test_streaming_response_create(self, client: Runloop) -> None: with client.agents.with_streaming_response.create( name="name", - version="version", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -271,7 +268,6 @@ class TestAsyncAgents: async def test_method_create(self, async_client: AsyncRunloop) -> None: agent = await async_client.agents.create( name="name", - version="version", ) assert_matches_type(AgentView, agent, path=["response"]) @@ -279,7 +275,6 @@ async def test_method_create(self, async_client: AsyncRunloop) -> None: async def test_method_create_with_all_params(self, async_client: AsyncRunloop) -> None: agent = await async_client.agents.create( name="name", - version="version", source={ "type": "type", "git": { @@ -302,6 +297,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncRunloop) - "registry_url": "registry_url", }, }, + version="version", ) assert_matches_type(AgentView, agent, path=["response"]) @@ -309,7 +305,6 @@ async def test_method_create_with_all_params(self, async_client: AsyncRunloop) - async def test_raw_response_create(self, async_client: AsyncRunloop) -> None: response = await async_client.agents.with_raw_response.create( name="name", - version="version", ) assert response.is_closed is True @@ -321,7 +316,6 @@ async def test_raw_response_create(self, async_client: AsyncRunloop) -> None: async def test_streaming_response_create(self, async_client: AsyncRunloop) -> None: async with async_client.agents.with_streaming_response.create( name="name", - version="version", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_secrets.py b/tests/api_resources/test_secrets.py index 7f0ff8e21..8e0abff43 100644 --- a/tests/api_resources/test_secrets.py +++ b/tests/api_resources/test_secrets.py @@ -54,6 +54,44 @@ def test_streaming_response_create(self, client: Runloop) -> None: assert cast(Any, response.is_closed) is True + @parametrize + def test_method_retrieve(self, client: Runloop) -> None: + secret = client.secrets.retrieve( + "name", + ) + assert_matches_type(SecretView, secret, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: Runloop) -> None: + response = client.secrets.with_raw_response.retrieve( + "name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + secret = response.parse() + assert_matches_type(SecretView, secret, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: Runloop) -> None: + with client.secrets.with_streaming_response.retrieve( + "name", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + secret = response.parse() + assert_matches_type(SecretView, secret, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Runloop) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `name` but received ''"): + client.secrets.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_update(self, client: Runloop) -> None: secret = client.secrets.update( @@ -206,6 +244,44 @@ async def test_streaming_response_create(self, async_client: AsyncRunloop) -> No assert cast(Any, response.is_closed) is True + @parametrize + async def test_method_retrieve(self, async_client: AsyncRunloop) -> None: + secret = await async_client.secrets.retrieve( + "name", + ) + assert_matches_type(SecretView, secret, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncRunloop) -> None: + response = await async_client.secrets.with_raw_response.retrieve( + "name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + secret = await response.parse() + assert_matches_type(SecretView, secret, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncRunloop) -> None: + async with async_client.secrets.with_streaming_response.retrieve( + "name", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + secret = await response.parse() + assert_matches_type(SecretView, secret, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, async_client: AsyncRunloop) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `name` but received ''"): + await async_client.secrets.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_update(self, async_client: AsyncRunloop) -> None: secret = await async_client.secrets.update( diff --git a/tests/test_extract_files.py b/tests/test_extract_files.py index a76b07d19..2822c7028 100644 --- a/tests/test_extract_files.py +++ b/tests/test_extract_files.py @@ -4,7 +4,7 @@ import pytest -from runloop_api_client._types import FileTypes +from runloop_api_client._types import FileTypes, ArrayFormat from runloop_api_client._utils import extract_files @@ -37,10 +37,7 @@ def test_multiple_files() -> None: def test_top_level_file_array() -> None: query = {"files": [b"file one", b"file two"], "title": "hello"} - assert extract_files(query, paths=[["files", ""]]) == [ - ("files[]", b"file one"), - ("files[]", b"file two"), - ] + assert extract_files(query, paths=[["files", ""]]) == [("files[]", b"file one"), ("files[]", b"file two")] assert query == {"title": "hello"} @@ -71,3 +68,24 @@ def test_ignores_incorrect_paths( expected: list[tuple[str, FileTypes]], ) -> None: assert extract_files(query, paths=paths) == expected + + +@pytest.mark.parametrize( + "array_format,expected_top_level,expected_nested", + [ + ("brackets", [("files[]", b"a"), ("files[]", b"b")], [("items[][file]", b"a"), ("items[][file]", b"b")]), + ("repeat", [("files", b"a"), ("files", b"b")], [("items[file]", b"a"), ("items[file]", b"b")]), + ("comma", [("files", b"a"), ("files", b"b")], [("items[file]", b"a"), ("items[file]", b"b")]), + ("indices", [("files[0]", b"a"), ("files[1]", b"b")], [("items[0][file]", b"a"), ("items[1][file]", b"b")]), + ], +) +def test_array_format_controls_file_field_names( + array_format: ArrayFormat, + expected_top_level: list[tuple[str, FileTypes]], + expected_nested: list[tuple[str, FileTypes]], +) -> None: + top_level = {"files": [b"a", b"b"]} + assert extract_files(top_level, paths=[["files", ""]], array_format=array_format) == expected_top_level + + nested = {"items": [{"file": b"a"}, {"file": b"b"}]} + assert extract_files(nested, paths=[["items", "", "file"]], array_format=array_format) == expected_nested diff --git a/tests/test_files.py b/tests/test_files.py index f78bb843b..b5ee73ae2 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -131,7 +131,7 @@ def test_extract_files_does_not_mutate_original_nested_array_path(self) -> None: copied = deepcopy_with_paths(original, [["items", "", "file"]]) extracted = extract_files(copied, paths=[["items", "", "file"]]) - assert extracted == [("items[][file]", file1), ("items[][file]", file2)] + assert [entry for _, entry in extracted] == [file1, file2] assert original == { "items": [ {"file": file1, "extra": 1},