From b8494ff84cff5fc8b0229460835f97f5409f6934 Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:05:04 +0530 Subject: [PATCH 1/6] Change default use_hybrid_disposition to False (#714) This changes the default value of use_hybrid_disposition from True to False in the SEA backend, disabling hybrid disposition by default. --- src/databricks/sql/backend/sea/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 75d2c665c..1427226d2 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -157,7 +157,7 @@ def __init__( "_use_arrow_native_complex_types", True ) - self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) + self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", False) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) # Extract warehouse ID from http_path From 73580fe298168f553d99f99b8628b0771a91abf5 Mon Sep 17 00:00:00 2001 From: nikhilsuri-db Date: Wed, 26 Nov 2025 18:33:12 +0530 Subject: [PATCH 2/6] Circuit breaker changes using pybreaker (#705) * Added driver connection params Signed-off-by: Nikhil Suri * Added model fields for chunk/result latency Signed-off-by: Nikhil Suri * fixed linting issues Signed-off-by: Nikhil Suri * lint issue fixing Signed-off-by: Nikhil Suri * circuit breaker changes using pybreaker Signed-off-by: Nikhil Suri * Added interface layer top of http client to use circuit rbeaker Signed-off-by: Nikhil Suri * Added test cases to validate ciruit breaker Signed-off-by: Nikhil Suri * fixing broken tests Signed-off-by: Nikhil Suri * fixed linting issues Signed-off-by: Nikhil Suri * fixed failing test cases Signed-off-by: Nikhil Suri * fixed urllib3 issue Signed-off-by: Nikhil Suri * added more test cases for telemetry Signed-off-by: Nikhil Suri * simplified CB config Signed-off-by: Nikhil Suri * poetry lock Signed-off-by: Nikhil Suri * fix minor issues & improvement Signed-off-by: Nikhil Suri * improved circuit breaker for handling only 429/503 Signed-off-by: Nikhil Suri * linting issue fixed Signed-off-by: Nikhil Suri * raise CB only for 429/503 Signed-off-by: Nikhil Suri * fix broken test cases Signed-off-by: Nikhil Suri * fixed untyped references Signed-off-by: Nikhil Suri * added more test to verify the changes Signed-off-by: Nikhil Suri * description changed Signed-off-by: Nikhil Suri * remove cb congig class to constants Signed-off-by: Nikhil Suri * removed mocked reponse and use a new exlucded exception in CB Signed-off-by: Nikhil Suri * fixed broken test Signed-off-by: Nikhil Suri * added e2e test to verify circuit breaker Signed-off-by: Nikhil Suri * lower log level for telemetry Signed-off-by: Nikhil Suri * fixed broken test, removed tests on log assertions Signed-off-by: Nikhil Suri * modified unit to reduce the noise and follow dry principle Signed-off-by: Nikhil Suri --------- Signed-off-by: Nikhil Suri --- poetry.lock | 36 ++- pyproject.toml | 1 + src/databricks/sql/auth/common.py | 2 + .../sql/common/unified_http_client.py | 47 +++- src/databricks/sql/exc.py | 21 ++ .../sql/telemetry/circuit_breaker_manager.py | 112 +++++++++ .../sql/telemetry/telemetry_client.py | 32 ++- .../sql/telemetry/telemetry_push_client.py | 201 +++++++++++++++ src/databricks/sql/utils.py | 3 + tests/e2e/test_circuit_breaker.py | 232 ++++++++++++++++++ .../unit/test_circuit_breaker_http_client.py | 208 ++++++++++++++++ tests/unit/test_circuit_breaker_manager.py | 160 ++++++++++++ tests/unit/test_telemetry.py | 32 ++- tests/unit/test_telemetry_push_client.py | 213 ++++++++++++++++ .../test_telemetry_request_error_handling.py | 96 ++++++++ tests/unit/test_unified_http_client.py | 136 ++++++++++ 16 files changed, 1512 insertions(+), 20 deletions(-) create mode 100644 src/databricks/sql/telemetry/circuit_breaker_manager.py create mode 100644 src/databricks/sql/telemetry/telemetry_push_client.py create mode 100644 tests/e2e/test_circuit_breaker.py create mode 100644 tests/unit/test_circuit_breaker_http_client.py create mode 100644 tests/unit/test_circuit_breaker_manager.py create mode 100644 tests/unit/test_telemetry_push_client.py create mode 100644 tests/unit/test_telemetry_request_error_handling.py create mode 100644 tests/unit/test_unified_http_client.py diff --git a/poetry.lock b/poetry.lock index 1a8074c2a..193efa109 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "astroid" @@ -1348,6 +1348,38 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pybreaker" +version = "1.2.0" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "pybreaker-1.2.0-py3-none-any.whl", hash = "sha256:c3e7683e29ecb3d4421265aaea55504f1186a2fdc1f17b6b091d80d1e1eb5ede"}, + {file = "pybreaker-1.2.0.tar.gz", hash = "sha256:18707776316f93a30c1be0e4fec1f8aa5ed19d7e395a218eb2f050c8524fb2dc"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + +[[package]] +name = "pybreaker" +version = "1.4.1" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0"}, + {file = "pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + [[package]] name = "pycparser" version = "2.22" @@ -1858,4 +1890,4 @@ pyarrow = ["pyarrow", "pyarrow"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0a3f611ef8747376f018c1df0a1ea7873368851873cc4bd3a4d51bba0bba847c" +content-hash = "56b62e3543644c91cc316b11d89025423a66daba5f36609c45bcb3eeb3ce3f54" diff --git a/pyproject.toml b/pyproject.toml index d26a71667..61c248e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] pyjwt = "^2.0.0" +pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 3e0be0d2b..a764b036d 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -51,6 +51,7 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, + telemetry_circuit_breaker_enabled: Optional[bool] = None, ): self.hostname = hostname self.access_token = access_token @@ -83,6 +84,7 @@ def __init__( self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent + self.telemetry_circuit_breaker_enabled = bool(telemetry_circuit_breaker_enabled) def get_effective_azure_login_app_id(hostname) -> str: diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 96fb9cbb9..d5f7d3c8d 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -28,6 +28,42 @@ logger = logging.getLogger(__name__) +def _extract_http_status_from_max_retry_error(e: MaxRetryError) -> Optional[int]: + """ + Extract HTTP status code from MaxRetryError if available. + + urllib3 structures MaxRetryError in different ways depending on the failure scenario: + - e.reason.response.status: Most common case when retries are exhausted + - e.response.status: Alternate structure in some scenarios + + Args: + e: MaxRetryError exception from urllib3 + + Returns: + HTTP status code as int if found, None otherwise + """ + # Try primary structure: e.reason.response.status + if ( + hasattr(e, "reason") + and e.reason is not None + and hasattr(e.reason, "response") + and e.reason.response is not None + ): + http_code = getattr(e.reason.response, "status", None) + if http_code is not None: + return http_code + + # Try alternate structure: e.response.status + if ( + hasattr(e, "response") + and e.response is not None + and hasattr(e.response, "status") + ): + return e.response.status + + return None + + class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. @@ -264,7 +300,16 @@ def request_context( yield response except MaxRetryError as e: logger.error("HTTP request failed after retries: %s", e) - raise RequestError(f"HTTP request failed: {e}") + + # Extract HTTP status code from MaxRetryError if available + http_code = _extract_http_status_from_max_retry_error(e) + + context = {} + if http_code is not None: + context["http-code"] = http_code + logger.error("HTTP request failed with status code: %d", http_code) + + raise RequestError(f"HTTP request failed: {e}", context=context) except Exception as e: logger.error("HTTP request error: %s", e) raise RequestError(f"HTTP request error: {e}") diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 3a3a6b3c5..24844d573 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -143,3 +143,24 @@ class SessionAlreadyClosedError(RequestError): class CursorAlreadyClosedError(RequestError): """Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected.""" + + +class TelemetryRateLimitError(Exception): + """Raised when telemetry endpoint returns 429 or 503, indicating rate limiting or service unavailable. + This exception is used exclusively by the circuit breaker to track telemetry rate limiting events.""" + + +class TelemetryNonRateLimitError(Exception): + """Wrapper for telemetry errors that should NOT trigger circuit breaker. + + This exception wraps non-rate-limiting errors (network errors, timeouts, server errors, etc.) + and is excluded from circuit breaker failure counting. Only TelemetryRateLimitError should + open the circuit breaker. + + Attributes: + original_exception: The actual exception that occurred + """ + + def __init__(self, original_exception: Exception): + self.original_exception = original_exception + super().__init__(f"Non-rate-limit telemetry error: {original_exception}") diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py new file mode 100644 index 000000000..852f0d916 --- /dev/null +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -0,0 +1,112 @@ +""" +Circuit breaker implementation for telemetry requests. + +This module provides circuit breaker functionality to prevent telemetry failures +from impacting the main SQL operations. It uses pybreaker library to implement +the circuit breaker pattern. +""" + +import logging +import threading +from typing import Dict + +import pybreaker +from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener + +from databricks.sql.exc import TelemetryNonRateLimitError + +logger = logging.getLogger(__name__) + +# Circuit Breaker Constants +MINIMUM_CALLS = 20 # Number of failures before circuit opens +RESET_TIMEOUT = 30 # Seconds to wait before trying to close circuit +NAME_PREFIX = "telemetry-circuit-breaker" + +# Circuit Breaker State Constants (used in logging) +CIRCUIT_BREAKER_STATE_OPEN = "open" +CIRCUIT_BREAKER_STATE_CLOSED = "closed" +CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open" + +# Logging Message Constants +LOG_CIRCUIT_BREAKER_STATE_CHANGED = "Circuit breaker state changed from %s to %s for %s" +LOG_CIRCUIT_BREAKER_OPENED = ( + "Circuit breaker opened for %s - telemetry requests will be blocked" +) +LOG_CIRCUIT_BREAKER_CLOSED = ( + "Circuit breaker closed for %s - telemetry requests will be allowed" +) +LOG_CIRCUIT_BREAKER_HALF_OPEN = ( + "Circuit breaker half-open for %s - testing telemetry requests" +) + + +class CircuitBreakerStateListener(CircuitBreakerListener): + """Listener for circuit breaker state changes.""" + + def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None: + """Called before the circuit breaker calls a function.""" + pass + + def failure(self, cb: CircuitBreaker, exc: BaseException) -> None: + """Called when a function called by the circuit breaker fails.""" + pass + + def success(self, cb: CircuitBreaker) -> None: + """Called when a function called by the circuit breaker succeeds.""" + pass + + def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: + """Called when the circuit breaker state changes.""" + old_state_name = old_state.name if old_state else "None" + new_state_name = new_state.name if new_state else "None" + + logger.info( + LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name + ) + + if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: + logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: + logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: + logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + + +class CircuitBreakerManager: + """ + Manages circuit breaker instances for telemetry requests. + + Creates and caches circuit breaker instances per host to ensure telemetry + failures don't impact main SQL operations. + """ + + _instances: Dict[str, CircuitBreaker] = {} + _lock = threading.RLock() + + @classmethod + def get_circuit_breaker(cls, host: str) -> CircuitBreaker: + """ + Get or create a circuit breaker instance for the specified host. + + Args: + host: The hostname for which to get the circuit breaker + + Returns: + CircuitBreaker instance for the host + """ + with cls._lock: + if host not in cls._instances: + breaker = CircuitBreaker( + fail_max=MINIMUM_CALLS, + reset_timeout=RESET_TIMEOUT, + name=f"{NAME_PREFIX}-{host}", + exclude=[ + TelemetryNonRateLimitError + ], # Don't count these as failures + ) + # Add state change listener for logging + breaker.add_listener(CircuitBreakerStateListener()) + cls._instances[host] = breaker + logger.debug("Created circuit breaker for host: %s", host) + + return cls._instances[host] diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 134757fe5..177d5445c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -41,6 +41,11 @@ from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.common.http import HttpMethod +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) if TYPE_CHECKING: from databricks.sql.client import Connection @@ -166,21 +171,21 @@ class TelemetryClient(BaseTelemetryClient): def __init__( self, - telemetry_enabled, - session_id_hex, + telemetry_enabled: bool, + session_id_hex: str, auth_provider, - host_url, + host_url: str, executor, - batch_size, + batch_size: int, client_context, - ): + ) -> None: logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled self._batch_size = batch_size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None - self._events_batch = [] + self._events_batch: list = [] self._lock = threading.RLock() self._driver_connection_params = None self._host_url = host_url @@ -189,6 +194,19 @@ def __init__( # Create own HTTP client from client context self._http_client = UnifiedHttpClient(client_context) + # Create telemetry push client based on circuit breaker enabled flag + if client_context.telemetry_circuit_breaker_enabled: + # Create circuit breaker telemetry push client (circuit breakers created on-demand) + self._telemetry_push_client: ITelemetryPushClient = ( + CircuitBreakerTelemetryPushClient( + TelemetryPushClient(self._http_client), + host_url, + ) + ) + else: + # Circuit breaker disabled - use direct telemetry push client + self._telemetry_push_client = TelemetryPushClient(self._http_client) + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) @@ -254,7 +272,7 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers, timeout=900): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._http_client.request( + response = self._telemetry_push_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) return response diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py new file mode 100644 index 000000000..461a57738 --- /dev/null +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -0,0 +1,201 @@ +""" +Telemetry push client interface and implementations. + +This module provides an interface for telemetry push clients with two implementations: +1. TelemetryPushClient - Direct HTTP client implementation +2. CircuitBreakerTelemetryPushClient - Circuit breaker wrapper implementation +""" + +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional + +try: + from urllib3 import BaseHTTPResponse +except ImportError: + from urllib3 import HTTPResponse as BaseHTTPResponse +from pybreaker import CircuitBreakerError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import ( + TelemetryRateLimitError, + TelemetryNonRateLimitError, + RequestError, +) +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + +logger = logging.getLogger(__name__) + + +class ITelemetryPushClient(ABC): + """Interface for telemetry push clients.""" + + @abstractmethod + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request.""" + pass + + +class TelemetryPushClient(ITelemetryPushClient): + """Direct HTTP client implementation for telemetry requests.""" + + def __init__(self, http_client: UnifiedHttpClient): + """ + Initialize the telemetry push client. + + Args: + http_client: The underlying HTTP client + """ + self._http_client = http_client + logger.debug("TelemetryPushClient initialized") + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """Make an HTTP request using the underlying HTTP client.""" + return self._http_client.request(method, url, headers, **kwargs) + + +class CircuitBreakerTelemetryPushClient(ITelemetryPushClient): + """Circuit breaker wrapper implementation for telemetry requests.""" + + def __init__(self, delegate: ITelemetryPushClient, host: str): + """ + Initialize the circuit breaker telemetry push client. + + Args: + delegate: The underlying telemetry push client to wrap + host: The hostname for circuit breaker identification + """ + self._delegate = delegate + self._host = host + + # Get circuit breaker for this host (creates if doesn't exist) + self._circuit_breaker = CircuitBreakerManager.get_circuit_breaker(host) + + logger.debug( + "CircuitBreakerTelemetryPushClient initialized for host %s", + host, + ) + + def _make_request_and_check_status( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]], + **kwargs, + ) -> BaseHTTPResponse: + """ + Make the request and check response status. + + Raises TelemetryRateLimitError for 429/503 (circuit breaker counts these). + Wraps other errors in TelemetryNonRateLimitError (circuit breaker excludes these). + + Args: + method: HTTP method + url: Request URL + headers: Request headers + **kwargs: Additional request parameters + + Returns: + HTTP response + + Raises: + TelemetryRateLimitError: For 429/503 status codes (circuit breaker counts) + TelemetryNonRateLimitError: For other errors (circuit breaker excludes) + """ + try: + response = self._delegate.request(method, url, headers, **kwargs) + + # Check for rate limiting or service unavailable + if response.status in [429, 503]: + logger.warning( + "Telemetry endpoint returned %d for host %s, triggering circuit breaker", + response.status, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry endpoint rate limited or unavailable: {response.status}" + ) + + return response + + except Exception as e: + # Don't catch TelemetryRateLimitError - let it propagate to circuit breaker + if isinstance(e, TelemetryRateLimitError): + raise + + # Check if it's a RequestError with rate limiting status code (exhausted retries) + if isinstance(e, RequestError): + http_code = ( + e.context.get("http-code") + if hasattr(e, "context") and e.context + else None + ) + + if http_code in [429, 503]: + logger.debug( + "Telemetry retries exhausted with status %d for host %s, triggering circuit breaker", + http_code, + self._host, + ) + raise TelemetryRateLimitError( + f"Telemetry rate limited after retries: {http_code}" + ) + + # NOT rate limiting (500 errors, network errors, timeouts, etc.) + # Wrap in TelemetryNonRateLimitError so circuit breaker excludes it + logger.debug( + "Non-rate-limit telemetry error for host %s: %s, wrapping to exclude from circuit breaker", + self._host, + e, + ) + raise TelemetryNonRateLimitError(e) from e + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> BaseHTTPResponse: + """ + Make an HTTP request with circuit breaker protection. + + Circuit breaker only opens for TelemetryRateLimitError (429/503 responses). + Other errors are wrapped in TelemetryNonRateLimitError and excluded from circuit breaker. + All exceptions propagate to caller (TelemetryClient callback handles them). + """ + try: + # Use circuit breaker to protect the request + # TelemetryRateLimitError will trigger circuit breaker + # TelemetryNonRateLimitError is excluded from circuit breaker + return self._circuit_breaker.call( + self._make_request_and_check_status, + method, + url, + headers, + **kwargs, + ) + + except TelemetryNonRateLimitError as e: + # Unwrap and re-raise original exception + # Circuit breaker didn't count this, but caller should handle it + logger.debug( + "Non-rate-limit telemetry error for host %s, re-raising original: %s", + self._host, + e.original_exception, + ) + raise e.original_exception from e + # All other exceptions (TelemetryRateLimitError, CircuitBreakerError) propagate as-is diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9f96e8743..b46784b10 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -922,4 +922,7 @@ def build_client_context(server_hostname: str, version: str, **kwargs): proxy_auth_method=kwargs.get("_proxy_auth_method"), pool_connections=kwargs.get("_pool_connections"), pool_maxsize=kwargs.get("_pool_maxsize"), + telemetry_circuit_breaker_enabled=kwargs.get( + "_telemetry_circuit_breaker_enabled" + ), ) diff --git a/tests/e2e/test_circuit_breaker.py b/tests/e2e/test_circuit_breaker.py new file mode 100644 index 000000000..45c494d19 --- /dev/null +++ b/tests/e2e/test_circuit_breaker.py @@ -0,0 +1,232 @@ +""" +E2E tests for circuit breaker functionality in telemetry. + +This test suite verifies: +1. Circuit breaker opens after rate limit failures (429/503) +2. Circuit breaker blocks subsequent calls while open +3. Circuit breaker does not trigger for non-rate-limit errors +4. Circuit breaker can be disabled via configuration flag +5. Circuit breaker closes after reset timeout + +Run with: + pytest tests/e2e/test_circuit_breaker.py -v -s +""" + +import time +from unittest.mock import patch, MagicMock + +import pytest +from pybreaker import STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN +from urllib3 import HTTPResponse + +import databricks.sql as sql +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + + +@pytest.fixture(autouse=True) +def aggressive_circuit_breaker_config(): + """ + Configure circuit breaker to be aggressive for faster testing. + Opens after 2 failures instead of 20, with 5 second timeout. + """ + from databricks.sql.telemetry import circuit_breaker_manager + + original_minimum_calls = circuit_breaker_manager.MINIMUM_CALLS + original_reset_timeout = circuit_breaker_manager.RESET_TIMEOUT + + circuit_breaker_manager.MINIMUM_CALLS = 2 + circuit_breaker_manager.RESET_TIMEOUT = 5 + + CircuitBreakerManager._instances.clear() + + yield + + circuit_breaker_manager.MINIMUM_CALLS = original_minimum_calls + circuit_breaker_manager.RESET_TIMEOUT = original_reset_timeout + CircuitBreakerManager._instances.clear() + + +class TestCircuitBreakerTelemetry: + """Tests for circuit breaker functionality with telemetry""" + + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + """Get connection details from pytest fixture""" + self.arguments = connection_details.copy() + + def create_mock_response(self, status_code): + """Helper to create mock HTTP response.""" + response = MagicMock(spec=HTTPResponse) + response.status = status_code + response.data = { + 429: b"Too Many Requests", + 503: b"Service Unavailable", + 500: b"Internal Server Error", + }.get(status_code, b"Response") + return response + + @pytest.mark.parametrize("status_code,should_trigger", [ + (429, True), + (503, True), + (500, False), + ]) + def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger): + """ + Verify circuit breaker opens for rate-limit codes (429/503) but not others (500). + """ + request_count = {"count": 0} + + def mock_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(status_code) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + assert circuit_breaker.current_state == STATE_CLOSED + + cursor = conn.cursor() + + # Execute queries to trigger telemetry + for i in range(1, 6): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.5) + + if should_trigger: + # Circuit should be OPEN after 2 rate-limit failures + assert circuit_breaker.current_state == STATE_OPEN + assert circuit_breaker.fail_counter == 2 + + # Track requests before another query + requests_before = request_count["count"] + cursor.execute("SELECT 99") + cursor.fetchone() + time.sleep(1) + + # No new telemetry requests (circuit is open) + assert request_count["count"] == requests_before + else: + # Circuit should remain CLOSED for non-rate-limit errors + assert circuit_breaker.current_state == STATE_CLOSED + assert circuit_breaker.fail_counter == 0 + assert request_count["count"] >= 5 + + def test_circuit_breaker_disabled_allows_all_calls(self): + """ + Verify that when circuit breaker is disabled, all calls go through + even with rate limit errors. + """ + request_count = {"count": 0} + + def mock_rate_limited_request(*args, **kwargs): + request_count["count"] += 1 + return self.create_mock_response(429) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_rate_limited_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=False, # Disabled + ) as conn: + cursor = conn.cursor() + + for i in range(5): + cursor.execute(f"SELECT {i}") + cursor.fetchone() + time.sleep(0.3) + + assert request_count["count"] >= 5 + + def test_circuit_breaker_recovers_after_reset_timeout(self): + """ + Verify circuit breaker transitions to HALF_OPEN after reset timeout + and eventually CLOSES if requests succeed. + """ + request_count = {"count": 0} + fail_requests = {"enabled": True} + + def mock_conditional_request(*args, **kwargs): + request_count["count"] += 1 + status = 429 if fail_requests["enabled"] else 200 + return self.create_mock_response(status) + + with patch( + "databricks.sql.telemetry.telemetry_push_client.TelemetryPushClient.request", + side_effect=mock_conditional_request, + ): + with sql.connect( + server_hostname=self.arguments["host"], + http_path=self.arguments["http_path"], + access_token=self.arguments.get("access_token"), + force_enable_telemetry=True, + telemetry_batch_size=1, + _telemetry_circuit_breaker_enabled=True, + ) as conn: + circuit_breaker = CircuitBreakerManager.get_circuit_breaker( + self.arguments["host"] + ) + + cursor = conn.cursor() + + # Trigger failures to open circuit + cursor.execute("SELECT 1") + cursor.fetchone() + time.sleep(1) + + cursor.execute("SELECT 2") + cursor.fetchone() + time.sleep(2) + + assert circuit_breaker.current_state == STATE_OPEN + + # Wait for reset timeout (5 seconds in test) + time.sleep(6) + + # Now make requests succeed + fail_requests["enabled"] = False + + # Execute query to trigger HALF_OPEN state + cursor.execute("SELECT 3") + cursor.fetchone() + time.sleep(1) + + # Circuit should be recovering + assert circuit_breaker.current_state in [ + STATE_HALF_OPEN, + STATE_CLOSED, + ], f"Circuit should be recovering, but is {circuit_breaker.current_state}" + + # Execute more queries to fully recover + cursor.execute("SELECT 4") + cursor.fetchone() + time.sleep(1) + + current_state = circuit_breaker.current_state + assert current_state in [ + STATE_CLOSED, + STATE_HALF_OPEN, + ], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unit/test_circuit_breaker_http_client.py b/tests/unit/test_circuit_breaker_http_client.py new file mode 100644 index 000000000..432ca1be3 --- /dev/null +++ b/tests/unit/test_circuit_breaker_http_client.py @@ -0,0 +1,208 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + def test_direct_client_has_no_circuit_breaker(self): + """Test that direct client does not have circuit breaker functionality.""" + # Direct client should work without circuit breaker + assert isinstance(self.client, TelemetryPushClient) + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_enabled_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_enabled_circuit_breaker_error(self): + """Test request when circuit breaker is open - should raise CircuitBreakerError.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Circuit breaker open should raise (caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_enabled_other_error(self): + """Test request when other error occurs - should raise original exception.""" + # Mock delegate to raise a different error (not rate limiting) + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Non-rate-limit errors are unwrapped and raised + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_is_circuit_breaker_enabled(self): + """Test checking if circuit breaker is enabled.""" + assert self.client._circuit_breaker is not None + + def test_circuit_breaker_state_logging(self): + """Test that circuit breaker errors are raised (no longer silent).""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + # Should raise CircuitBreakerError (caller handles it) + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_other_error_logging(self): + """Test that other errors are wrapped, logged, then unwrapped and raised.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + # Should raise the original ValueError + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that debug was logged (for wrapping and/or unwrapping) + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + from databricks.sql.exc import TelemetryRateLimitError + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures (429) + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + # All calls should raise TelemetryRateLimitError + # After MINIMUM_CALLS failures, circuit breaker opens + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + # Should have some rate limit errors before circuit opens, then circuit breaker errors + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + import time + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate rate limit failures first (429) + from databricks.sql.exc import TelemetryRateLimitError + from pybreaker import CircuitBreakerError + + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + # Trigger enough rate limit failures to open circuit + for i in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass # Expected - circuit breaker opens after MINIMUM_CALLS failures + + # Circuit should be open now - raises CircuitBreakerError + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate successful calls (200 response) + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + # Should work again with actual success response + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py new file mode 100644 index 000000000..e8ed4e809 --- /dev/null +++ b/tests/unit/test_circuit_breaker_manager.py @@ -0,0 +1,160 @@ +""" +Unit tests for circuit breaker manager functionality. +""" + +import pytest +import threading +import time +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + NAME_PREFIX as CIRCUIT_BREAKER_NAME, +) +from pybreaker import CircuitBreakerError + + +class TestCircuitBreakerManager: + """Test cases for CircuitBreakerManager.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_get_circuit_breaker_creates_instance(self): + """Test getting circuit breaker creates instance with correct config.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.name == "telemetry-circuit-breaker-test-host" + assert breaker.fail_max == MINIMUM_CALLS + + def test_get_circuit_breaker_same_host_returns_same_instance(self): + """Test that same host returns same circuit breaker instance.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("test-host") + breaker2 = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker1 is breaker2 + + def test_get_circuit_breaker_different_hosts_return_different_instances(self): + """Test that different hosts return different circuit breaker instances.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("host1") + breaker2 = CircuitBreakerManager.get_circuit_breaker("host2") + + assert breaker1 is not breaker2 + assert breaker1.name != breaker2.name + + def test_thread_safety(self): + """Test thread safety of circuit breaker manager.""" + results = [] + + def get_breaker(host): + breaker = CircuitBreakerManager.get_circuit_breaker(host) + results.append(breaker) + + threads = [] + for i in range(10): + thread = threading.Thread(target=get_breaker, args=(f"host{i % 3}",)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(results) == 10 + + # All breakers for same host should be same instance + host0_breakers = [b for b in results if b.name.endswith("host0")] + assert all(b is host0_breakers[0] for b in host0_breakers) + + +class TestCircuitBreakerIntegration: + """Integration tests for circuit breaker functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_state_transitions(self): + """Test circuit breaker state transitions from closed to open.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.current_state == "closed" + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold (MINIMUM_CALLS = 20) + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + # Next call should fail with CircuitBreakerError (circuit is now open) + with pytest.raises(CircuitBreakerError): + breaker.call(failing_func) + + assert breaker.current_state == "open" + + def test_circuit_breaker_recovery(self): + """Test circuit breaker recovery after failures.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold + for _ in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + assert breaker.current_state == "open" + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Try successful call to close circuit breaker + def successful_func(): + return "success" + + try: + result = breaker.call(successful_func) + assert result == "success" + except CircuitBreakerError: + pass # Circuit might still be open, acceptable + + assert breaker.current_state in ["closed", "half-open", "open"] + + @pytest.mark.parametrize("old_state,new_state", [ + ("closed", "open"), + ("open", "half-open"), + ("half-open", "closed"), + ("closed", "half-open"), + ]) + def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): + """Test circuit breaker state listener logs all state transitions.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerStateListener, + ) + + listener = CircuitBreakerStateListener() + mock_cb = Mock() + mock_cb.name = "test-breaker" + + mock_old_state = Mock() + mock_old_state.name = old_state + + mock_new_state = Mock() + mock_new_state.name = new_state + + with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: + listener.state_change(mock_cb, mock_old_state, mock_new_state) + mock_logger.info.assert_called() diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 36141ee2b..6f5a01c7b 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -37,7 +37,9 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -95,7 +97,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -231,7 +233,9 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -299,7 +303,9 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -382,8 +388,10 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -410,8 +418,10 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -438,8 +448,10 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py new file mode 100644 index 000000000..0e9455e1f --- /dev/null +++ b/tests/unit/test_telemetry_push_client.py @@ -0,0 +1,213 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import TelemetryRateLimitError +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_circuit_breaker_open(self): + """Test request when circuit breaker is open raises CircuitBreakerError.""" + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_other_error(self): + """Test request when other error occurs raises original exception.""" + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("status_code,expected_error", [ + (429, TelemetryRateLimitError), + (503, TelemetryRateLimitError), + ]) + def test_request_rate_limit_codes(self, status_code, expected_error): + """Test that rate-limit status codes raise TelemetryRateLimitError.""" + mock_response = Mock() + mock_response.status = status_code + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(expected_error): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_non_rate_limit_code(self): + """Test that non-rate-limit status codes return response.""" + mock_response = Mock() + mock_response.status = 500 + mock_response.data = b'Server error' + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 500 + + def test_rate_limit_error_logging(self): + """Test that rate limit errors are logged with circuit breaker context.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + with pytest.raises(TelemetryRateLimitError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "429" in str(warning_args) + assert "circuit breaker" in warning_args[0] + + def test_other_error_logging(self): + """Test that other errors are logged during wrapping/unwrapping.""" + with patch("databricks.sql.telemetry.telemetry_push_client.logger") as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError, match="Network error"): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert mock_logger.debug.call_count >= 1 + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures (429/503 errors).""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + mock_response = Mock() + mock_response.status = 429 + self.mock_delegate.request.return_value = mock_response + + rate_limit_error_count = 0 + circuit_breaker_error_count = 0 + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except TelemetryRateLimitError: + rate_limit_error_count += 1 + except CircuitBreakerError: + circuit_breaker_error_count += 1 + + assert rate_limit_error_count >= MINIMUM_CALLS - 1 + assert circuit_breaker_error_count > 0 + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + import time + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + + CircuitBreakerManager._instances.clear() + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Trigger failures + mock_rate_limit_response = Mock() + mock_rate_limit_response.status = 429 + self.mock_delegate.request.return_value = mock_rate_limit_response + + for _ in range(MINIMUM_CALLS + 5): + try: + client.request(HttpMethod.POST, "https://test.com", {}) + except (TelemetryRateLimitError, CircuitBreakerError): + pass + + # Circuit should be open + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 1.0) + + # Simulate success + mock_success_response = Mock() + mock_success_response.status = 200 + self.mock_delegate.request.return_value = mock_success_response + + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + assert response.status == 200 + + def test_urllib3_import_fallback(self): + """Test that the urllib3 import fallback works correctly.""" + from databricks.sql.telemetry.telemetry_push_client import BaseHTTPResponse + assert BaseHTTPResponse is not None diff --git a/tests/unit/test_telemetry_request_error_handling.py b/tests/unit/test_telemetry_request_error_handling.py new file mode 100644 index 000000000..aa31f6628 --- /dev/null +++ b/tests/unit/test_telemetry_request_error_handling.py @@ -0,0 +1,96 @@ +""" +Unit tests specifically for telemetry_push_client RequestError handling +with http-code context extraction for rate limiting detection. +""" + +import pytest +from unittest.mock import Mock + +from databricks.sql.telemetry.telemetry_push_client import ( + CircuitBreakerTelemetryPushClient, + TelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import RequestError, TelemetryRateLimitError +from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager + + +class TestTelemetryPushClientRequestErrorHandling: + """Test RequestError handling and http-code context extraction.""" + + @pytest.fixture + def setup_circuit_breaker(self): + """Setup circuit breaker for testing.""" + CircuitBreakerManager._instances.clear() + yield + CircuitBreakerManager._instances.clear() + + @pytest.fixture + def mock_delegate(self): + """Create mock delegate client.""" + return Mock(spec=TelemetryPushClient) + + @pytest.fixture + def client(self, mock_delegate, setup_circuit_breaker): + """Create CircuitBreakerTelemetryPushClient instance.""" + return CircuitBreakerTelemetryPushClient(mock_delegate, "test-host.example.com") + + @pytest.mark.parametrize("status_code", [429, 503]) + def test_request_error_with_rate_limit_codes(self, client, mock_delegate, status_code): + """Test that RequestError with rate-limit codes raises TelemetryRateLimitError.""" + request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + mock_delegate.request.side_effect = request_error + + with pytest.raises(TelemetryRateLimitError): + client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("status_code", [500, 400, 404]) + def test_request_error_with_non_rate_limit_codes(self, client, mock_delegate, status_code): + """Test that RequestError with non-rate-limit codes raises original RequestError.""" + request_error = RequestError("HTTP request failed", context={"http-code": status_code}) + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + @pytest.mark.parametrize("context", [{}, None, "429"]) + def test_request_error_with_invalid_context(self, client, mock_delegate, context): + """Test RequestError with invalid/missing context raises original error.""" + request_error = RequestError("HTTP request failed") + if context == "429": + # Edge case: http-code as string instead of int + request_error.context = {"http-code": context} + else: + request_error.context = context + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_error_missing_context_attribute(self, client, mock_delegate): + """Test RequestError without context attribute raises original error.""" + request_error = RequestError("HTTP request failed") + if hasattr(request_error, "context"): + delattr(request_error, "context") + mock_delegate.request.side_effect = request_error + + with pytest.raises(RequestError, match="HTTP request failed"): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_http_code_extraction_prioritization(self, client, mock_delegate): + """Test that http-code from RequestError context is correctly extracted.""" + request_error = RequestError( + "HTTP request failed after retries", context={"http-code": 503} + ) + mock_delegate.request.side_effect = request_error + + with pytest.raises(TelemetryRateLimitError): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_non_request_error_exceptions_raised(self, client, mock_delegate): + """Test that non-RequestError exceptions are wrapped then unwrapped.""" + generic_error = ValueError("Network timeout") + mock_delegate.request.side_effect = generic_error + + with pytest.raises(ValueError, match="Network timeout"): + client.request(HttpMethod.POST, "https://test.com", {}) diff --git a/tests/unit/test_unified_http_client.py b/tests/unit/test_unified_http_client.py new file mode 100644 index 000000000..4e9ce1bbf --- /dev/null +++ b/tests/unit/test_unified_http_client.py @@ -0,0 +1,136 @@ +""" +Unit tests for UnifiedHttpClient, specifically testing MaxRetryError handling +and HTTP status code extraction. +""" + +import pytest +from unittest.mock import Mock, patch +from urllib3.exceptions import MaxRetryError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.exc import RequestError +from databricks.sql.auth.common import ClientContext +from databricks.sql.types import SSLOptions + + +class TestUnifiedHttpClientMaxRetryError: + """Test MaxRetryError handling and HTTP status code extraction.""" + + @pytest.fixture + def client_context(self): + """Create a minimal ClientContext for testing.""" + context = Mock(spec=ClientContext) + context.hostname = "https://test.databricks.com" + context.ssl_options = SSLOptions( + tls_verify=True, + tls_verify_hostname=True, + tls_trusted_ca_file=None, + tls_client_cert_file=None, + tls_client_cert_key_file=None, + tls_client_cert_key_password=None, + ) + context.socket_timeout = 30 + context.retry_stop_after_attempts_count = 3 + context.retry_delay_min = 1.0 + context.retry_delay_max = 10.0 + context.retry_stop_after_attempts_duration = 300.0 + context.retry_delay_default = 5.0 + context.retry_dangerous_codes = [] + context.proxy_auth_method = None + context.pool_connections = 10 + context.pool_maxsize = 20 + context.user_agent = "test-agent" + return context + + @pytest.fixture + def http_client(self, client_context): + """Create UnifiedHttpClient instance.""" + return UnifiedHttpClient(client_context) + + @pytest.mark.parametrize("status_code,path", [ + (429, "reason.response"), + (503, "reason.response"), + (500, "direct_response"), + ]) + def test_max_retry_error_with_status_codes(self, http_client, status_code, path): + """Test MaxRetryError with various status codes and response paths.""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + + if path == "reason.response": + max_retry_error.reason = Mock() + max_retry_error.reason.response = Mock() + max_retry_error.reason.response.status = status_code + else: # direct_response + max_retry_error.response = Mock() + max_retry_error.response.status = status_code + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request( + HttpMethod.POST, "http://test.com", headers={"test": "header"} + ) + + error = exc_info.value + assert hasattr(error, "context") + assert "http-code" in error.context + assert error.context["http-code"] == status_code + + @pytest.mark.parametrize("setup_func", [ + lambda e: None, # No setup - error with no attributes + lambda e: setattr(e, "reason", None), # reason=None + lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", None)), # reason.response=None + lambda e: (setattr(e, "reason", Mock()), setattr(e.reason, "response", Mock(spec=[]))), # No status attr + ]) + def test_max_retry_error_missing_status(self, http_client, setup_func): + """Test MaxRetryError without status code (no crash, empty context).""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + setup_func(max_retry_error) + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.GET, "http://test.com") + + error = exc_info.value + assert error.context == {} + + def test_max_retry_error_prefers_reason_response(self, http_client): + """Test that e.reason.response.status is preferred over e.response.status.""" + mock_pool = Mock() + max_retry_error = MaxRetryError(pool=mock_pool, url="http://test.com") + + # Set both structures with different status codes + max_retry_error.reason = Mock() + max_retry_error.reason.response = Mock() + max_retry_error.reason.response.status = 429 # Should use this + + max_retry_error.response = Mock() + max_retry_error.response.status = 500 # Should be ignored + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=max_retry_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.GET, "http://test.com") + + error = exc_info.value + assert error.context["http-code"] == 429 + + def test_generic_exception_no_crash(self, http_client): + """Test that generic exceptions don't crash when checking for status code.""" + generic_error = Exception("Network error") + + with patch.object( + http_client._direct_pool_manager, "request", side_effect=generic_error + ): + with pytest.raises(RequestError) as exc_info: + http_client.request(HttpMethod.POST, "http://test.com") + + error = exc_info.value + assert "HTTP request error" in str(error) From 8d5e1552531092f7892f424b1efcfb313e7db8ae Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Thu, 27 Nov 2025 12:41:07 +0530 Subject: [PATCH 3/6] perf: Optimize telemetry latency logging to reduce overhead (#715) perf: Optimize telemetry latency logging to reduce overhead Optimizations implemented: 1. Eliminated extractor pattern - replaced wrapper classes with direct attribute access functions, removing object creation overhead 2. Added feature flag early exit - checks cached telemetry_enabled flag to skip heavy work when telemetry is disabled 3. Simplified code structure with early returns for better readability Signed-off-by: Samikshya Chand --- src/databricks/sql/common/feature_flag.py | 8 +- .../sql/telemetry/latency_logger.py | 289 +++++++++--------- .../sql/telemetry/telemetry_client.py | 62 +++- tests/unit/test_telemetry.py | 70 ++++- 4 files changed, 264 insertions(+), 165 deletions(-) diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 8a1cf5bd5..032701f63 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -165,8 +165,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: cls._initialize() assert cls._executor is not None - # Use the unique session ID as the key - key = connection.get_session_id_hex() + # Cache at HOST level - share feature flags across connections to same host + # Feature flags are per-host, not per-session + key = connection.session.host if key not in cls._context_map: cls._context_map[key] = FeatureFlagsContext( connection, cls._executor, connection.session.http_client @@ -177,7 +178,8 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: def remove_instance(cls, connection: "Connection"): """Removes the context for a given connection and shuts down the executor if no clients remain.""" with cls._lock: - key = connection.get_session_id_hex() + # Use host as key to match get_instance + key = connection.session.host if key in cls._context_map: cls._context_map.pop(key, None) diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 12cacd851..36ebee2b8 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -1,6 +1,6 @@ import time import functools -from typing import Optional +from typing import Optional, Dict, Any import logging from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.telemetry.models.event import ( @@ -11,127 +11,141 @@ logger = logging.getLogger(__name__) -class TelemetryExtractor: +def _extract_cursor_data(cursor) -> Dict[str, Any]: """ - Base class for extracting telemetry information from various object types. + Extract telemetry data directly from a Cursor object. - This class serves as a proxy that delegates attribute access to the wrapped object - while providing a common interface for extracting telemetry-related data. - """ - - def __init__(self, obj): - self._obj = obj - - def __getattr__(self, name): - return getattr(self._obj, name) - - def get_session_id_hex(self): - pass - - def get_statement_id(self): - pass - - def get_is_compressed(self): - pass - - def get_execution_result_format(self): - pass - - def get_retry_count(self): - pass - - def get_chunk_id(self): - pass + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + This eliminates object creation overhead and method call indirection. + Args: + cursor: The Cursor object to extract data from -class CursorExtractor(TelemetryExtractor): + Returns: + Dict with telemetry data (values may be None if extraction fails) """ - Telemetry extractor specialized for Cursor objects. - - Extracts telemetry information from database cursor objects, including - statement IDs, session information, compression settings, and result formats. + data = {} + + # Extract statement_id (query_id) - direct attribute access + try: + data["statement_id"] = cursor.query_id + except (AttributeError, Exception): + data["statement_id"] = None + + # Extract session_id_hex - direct method call + try: + data["session_id_hex"] = cursor.connection.get_session_id_hex() + except (AttributeError, Exception): + data["session_id_hex"] = None + + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = cursor.connection.lz4_compression + except (AttributeError, Exception): + data["is_compressed"] = False + + # Extract execution_result_format - inline logic + try: + if cursor.active_result_set is None: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + else: + from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue + + results = cursor.active_result_set.results + if isinstance(results, ColumnQueue): + data["execution_result"] = ExecutionResultFormat.COLUMNAR_INLINE + elif isinstance(results, CloudFetchQueue): + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS + elif isinstance(results, ArrowQueue): + data["execution_result"] = ExecutionResultFormat.INLINE_ARROW + else: + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + except (AttributeError, Exception): + data["execution_result"] = ExecutionResultFormat.FORMAT_UNSPECIFIED + + # Extract retry_count - direct attribute access + try: + if hasattr(cursor.backend, "retry_policy") and cursor.backend.retry_policy: + data["retry_count"] = len(cursor.backend.retry_policy.history) + else: + data["retry_count"] = 0 + except (AttributeError, Exception): + data["retry_count"] = 0 + + # chunk_id is always None for Cursor + data["chunk_id"] = None + + return data + + +def _extract_result_set_handler_data(handler) -> Dict[str, Any]: """ + Extract telemetry data directly from a ResultSetDownloadHandler object. - def get_statement_id(self) -> Optional[str]: - return self.query_id - - def get_session_id_hex(self) -> Optional[str]: - return self.connection.get_session_id_hex() - - def get_is_compressed(self) -> bool: - return self.connection.lz4_compression - - def get_execution_result_format(self) -> ExecutionResultFormat: - if self.active_result_set is None: - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue - - if isinstance(self.active_result_set.results, ColumnQueue): - return ExecutionResultFormat.COLUMNAR_INLINE - elif isinstance(self.active_result_set.results, CloudFetchQueue): - return ExecutionResultFormat.EXTERNAL_LINKS - elif isinstance(self.active_result_set.results, ArrowQueue): - return ExecutionResultFormat.INLINE_ARROW - return ExecutionResultFormat.FORMAT_UNSPECIFIED - - def get_retry_count(self) -> int: - if hasattr(self.backend, "retry_policy") and self.backend.retry_policy: - return len(self.backend.retry_policy.history) - return 0 - - def get_chunk_id(self): - return None + OPTIMIZATION: Uses direct attribute access instead of wrapper objects. + Args: + handler: The ResultSetDownloadHandler object to extract data from -class ResultSetDownloadHandlerExtractor(TelemetryExtractor): - """ - Telemetry extractor specialized for ResultSetDownloadHandler objects. + Returns: + Dict with telemetry data (values may be None if extraction fails) """ + data = {} - def get_session_id_hex(self) -> Optional[str]: - return self._obj.session_id_hex + # Extract session_id_hex - direct attribute access + try: + data["session_id_hex"] = handler.session_id_hex + except (AttributeError, Exception): + data["session_id_hex"] = None - def get_statement_id(self) -> Optional[str]: - return self._obj.statement_id + # Extract statement_id - direct attribute access + try: + data["statement_id"] = handler.statement_id + except (AttributeError, Exception): + data["statement_id"] = None - def get_is_compressed(self) -> bool: - return self._obj.settings.is_lz4_compressed + # Extract is_compressed - direct attribute access + try: + data["is_compressed"] = handler.settings.is_lz4_compressed + except (AttributeError, Exception): + data["is_compressed"] = False - def get_execution_result_format(self) -> ExecutionResultFormat: - return ExecutionResultFormat.EXTERNAL_LINKS + # execution_result is always EXTERNAL_LINKS for result set handlers + data["execution_result"] = ExecutionResultFormat.EXTERNAL_LINKS - def get_retry_count(self) -> Optional[int]: - # standard requests and urllib3 libraries don't expose retry count - return None + # retry_count is not available for result set handlers + data["retry_count"] = None + + # Extract chunk_id - direct attribute access + try: + data["chunk_id"] = handler.chunk_id + except (AttributeError, Exception): + data["chunk_id"] = None - def get_chunk_id(self) -> Optional[int]: - return self._obj.chunk_id + return data -def get_extractor(obj): +def _extract_telemetry_data(obj) -> Optional[Dict[str, Any]]: """ - Factory function to create the appropriate telemetry extractor for an object. + Extract telemetry data from an object based on its type. - Determines the object type and returns the corresponding specialized extractor - that can extract telemetry information from that object type. + OPTIMIZATION: Returns a simple dict instead of creating wrapper objects. + This dict will be used to create the SqlExecutionEvent in the background thread. Args: - obj: The object to create an extractor for. Can be a Cursor, - ResultSetDownloadHandler, or any other object. + obj: The object to extract data from (Cursor, ResultSetDownloadHandler, etc.) Returns: - TelemetryExtractor: A specialized extractor instance: - - CursorExtractor for Cursor objects - - ResultSetDownloadHandlerExtractor for ResultSetDownloadHandler objects - - None for all other objects + Dict with telemetry data, or None if object type is not supported """ - if obj.__class__.__name__ == "Cursor": - return CursorExtractor(obj) - elif obj.__class__.__name__ == "ResultSetDownloadHandler": - return ResultSetDownloadHandlerExtractor(obj) + obj_type = obj.__class__.__name__ + + if obj_type == "Cursor": + return _extract_cursor_data(obj) + elif obj_type == "ResultSetDownloadHandler": + return _extract_result_set_handler_data(obj) else: - logger.debug("No extractor found for %s", obj.__class__.__name__) + logger.debug("No telemetry extraction available for %s", obj_type) return None @@ -143,12 +157,6 @@ def log_latency(statement_type: StatementType = StatementType.NONE): data about the operation, including latency, statement information, and execution context. - The decorator automatically: - - Measures execution time using high-precision performance counters - - Extracts telemetry information from the method's object (self) - - Creates a SqlExecutionEvent with execution details - - Sends the telemetry data asynchronously via TelemetryClient - Args: statement_type (StatementType): The type of SQL statement being executed. @@ -162,54 +170,49 @@ def execute(self, query): function: A decorator that wraps methods to add latency logging. Note: - The wrapped method's object (self) must be compatible with the - telemetry extractor system (e.g., Cursor or ResultSet objects). + The wrapped method's object (self) must be a Cursor or + ResultSetDownloadHandler for telemetry data extraction. """ def decorator(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - start_time = time.perf_counter() - result = None + start_time = time.monotonic() try: - result = func(self, *args, **kwargs) - return result + return func(self, *args, **kwargs) finally: - - def _safe_call(func_to_call): - """Calls a function and returns a default value on any exception.""" - try: - return func_to_call() - except Exception: - return None - - end_time = time.perf_counter() - duration_ms = int((end_time - start_time) * 1000) - - extractor = get_extractor(self) - - if extractor is not None: - session_id_hex = _safe_call(extractor.get_session_id_hex) - statement_id = _safe_call(extractor.get_statement_id) - - sql_exec_event = SqlExecutionEvent( - statement_type=statement_type, - is_compressed=_safe_call(extractor.get_is_compressed), - execution_result=_safe_call( - extractor.get_execution_result_format - ), - retry_count=_safe_call(extractor.get_retry_count), - chunk_id=_safe_call(extractor.get_chunk_id), - ) - - telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex - ) - telemetry_client.export_latency_log( - latency_ms=duration_ms, - sql_execution_event=sql_exec_event, - sql_statement_id=statement_id, - ) + duration_ms = int((time.monotonic() - start_time) * 1000) + + # Always log for debugging + logger.debug("%s completed in %dms", func.__name__, duration_ms) + + # Fast check: use cached telemetry_enabled flag from connection + # Avoids dictionary lookup + instance check on every operation + connection = getattr(self, "connection", None) + if connection and getattr(connection, "telemetry_enabled", False): + session_id_hex = connection.get_session_id_hex() + if session_id_hex: + # Telemetry enabled - extract and send + telemetry_data = _extract_telemetry_data(self) + if telemetry_data: + sql_exec_event = SqlExecutionEvent( + statement_type=statement_type, + is_compressed=telemetry_data.get("is_compressed"), + execution_result=telemetry_data.get("execution_result"), + retry_count=telemetry_data.get("retry_count"), + chunk_id=telemetry_data.get("chunk_id"), + ) + + telemetry_client = ( + TelemetryClientFactory.get_telemetry_client( + session_id_hex + ) + ) + telemetry_client.export_latency_log( + latency_ms=duration_ms, + sql_execution_event=sql_exec_event, + sql_statement_id=telemetry_data.get("statement_id"), + ) return wrapper diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 177d5445c..d5f5b575c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -2,6 +2,7 @@ import time import logging import json +from queue import Queue, Full from concurrent.futures import ThreadPoolExecutor from concurrent.futures import Future from datetime import datetime, timezone @@ -114,18 +115,21 @@ def get_auth_flow(auth_provider): @staticmethod def is_telemetry_enabled(connection: "Connection") -> bool: + # Fast path: force enabled - skip feature flag fetch entirely if connection.force_enable_telemetry: return True - if connection.enable_telemetry: - context = FeatureFlagsContextFactory.get_instance(connection) - flag_value = context.get_flag_value( - TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False - ) - return str(flag_value).lower() == "true" - else: + # Fast path: disabled - no need to check feature flag + if not connection.enable_telemetry: return False + # Only fetch feature flags when enable_telemetry=True and not forced + context = FeatureFlagsContextFactory.get_instance(connection) + flag_value = context.get_flag_value( + TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False + ) + return str(flag_value).lower() == "true" + class NoopTelemetryClient(BaseTelemetryClient): """ @@ -185,8 +189,11 @@ def __init__( self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None - self._events_batch: list = [] - self._lock = threading.RLock() + + # OPTIMIZATION: Use lock-free Queue instead of list + lock + # Queue is thread-safe internally and has better performance under concurrency + self._events_queue: Queue[TelemetryFrontendLog] = Queue(maxsize=batch_size * 2) + self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -196,7 +203,8 @@ def __init__( # Create telemetry push client based on circuit breaker enabled flag if client_context.telemetry_circuit_breaker_enabled: - # Create circuit breaker telemetry push client (circuit breakers created on-demand) + # Create circuit breaker telemetry push client + # (circuit breakers created on-demand) self._telemetry_push_client: ITelemetryPushClient = ( CircuitBreakerTelemetryPushClient( TelemetryPushClient(self._http_client), @@ -210,9 +218,24 @@ def __init__( def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) - with self._lock: - self._events_batch.append(event) - if len(self._events_batch) >= self._batch_size: + + # OPTIMIZATION: Use non-blocking put with queue + # No explicit lock needed - Queue is thread-safe internally + try: + self._events_queue.put_nowait(event) + except Full: + # Queue is full, trigger immediate flush + logger.debug("Event queue full, triggering flush") + self._flush() + # Try again after flush + try: + self._events_queue.put_nowait(event) + except Full: + # Still full, drop event (acceptable for telemetry) + logger.debug("Dropped telemetry event - queue still full") + + # Check if we should flush based on queue size + if self._events_queue.qsize() >= self._batch_size: logger.debug( "Batch size limit reached (%s), flushing events", self._batch_size ) @@ -220,9 +243,16 @@ def _export_event(self, event): def _flush(self): """Flush the current batch of events to the server""" - with self._lock: - events_to_flush = self._events_batch.copy() - self._events_batch = [] + # OPTIMIZATION: Drain queue without locks + # Collect all events currently in the queue + events_to_flush = [] + while not self._events_queue.empty(): + try: + event = self._events_queue.get_nowait() + events_to_flush.append(event) + except: + # Queue is empty + break if events_to_flush: logger.debug("Flushing %s telemetry events to server", len(events_to_flush)) diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 6f5a01c7b..96a2f87d8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -10,6 +10,10 @@ TelemetryClientFactory, TelemetryHelper, ) +from databricks.sql.common.feature_flag import ( + FeatureFlagsContextFactory, + FeatureFlagsContext, +) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType from databricks.sql.telemetry.models.event import ( TelemetryEvent, @@ -82,12 +86,12 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): client._export_event("event1") client._export_event("event2") mock_send.assert_not_called() - assert len(client._events_batch) == 2 + assert client._events_queue.qsize() == 2 # Third event should trigger flush client._export_event("event3") mock_send.assert_called_once() - assert len(client._events_batch) == 0 # Batch cleared after flush + assert client._events_queue.qsize() == 0 # Queue cleared after flush @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_network_request_flow(self, mock_http_request, mock_telemetry_client): @@ -817,7 +821,67 @@ def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_sess mock_export.assert_called_once() driver_params = mock_export.call_args.kwargs.get("driver_connection_params") - + # CF proxy not yet supported - should be False/None assert driver_params.use_cf_proxy is False assert driver_params.cf_proxy_host_info is None + + +class TestFeatureFlagsContextFactory: + """Tests for FeatureFlagsContextFactory host-level caching.""" + + @pytest.fixture(autouse=True) + def reset_factory(self): + """Reset factory state before/after each test.""" + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + yield + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + + @pytest.mark.parametrize( + "hosts,expected_contexts", + [ + (["host1.com", "host1.com"], 1), # Same host shares context + (["host1.com", "host2.com"], 2), # Different hosts get separate contexts + (["host1.com", "host1.com", "host2.com"], 2), # Mixed scenario + ], + ) + def test_host_level_caching(self, hosts, expected_contexts): + """Test that contexts are cached by host correctly.""" + contexts = [] + for host in hosts: + conn = MagicMock() + conn.session.host = host + conn.session.http_client = MagicMock() + contexts.append(FeatureFlagsContextFactory.get_instance(conn)) + + assert len(FeatureFlagsContextFactory._context_map) == expected_contexts + if expected_contexts == 1: + assert all(ctx is contexts[0] for ctx in contexts) + + def test_remove_instance_and_executor_cleanup(self): + """Test removal uses host key and cleans up executor when empty.""" + conn1 = MagicMock() + conn1.session.host = "host1.com" + conn1.session.http_client = MagicMock() + + conn2 = MagicMock() + conn2.session.host = "host2.com" + conn2.session.http_client = MagicMock() + + FeatureFlagsContextFactory.get_instance(conn1) + FeatureFlagsContextFactory.get_instance(conn2) + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn1) + assert len(FeatureFlagsContextFactory._context_map) == 1 + assert FeatureFlagsContextFactory._executor is not None + + FeatureFlagsContextFactory.remove_instance(conn2) + assert len(FeatureFlagsContextFactory._context_map) == 0 + assert FeatureFlagsContextFactory._executor is None From d524f0e4bd730ec7397f3c0088c6e80761ebf8c9 Mon Sep 17 00:00:00 2001 From: nikhilsuri-db Date: Fri, 28 Nov 2025 13:33:02 +0530 Subject: [PATCH 4/6] basic e2e test for force telemetry verification (#708) * basic e2e test for force telemetry verification Signed-off-by: Nikhil Suri * Added more integration test scenarios Signed-off-by: Nikhil Suri * default on telemetry + logs to investigate failing test Signed-off-by: Nikhil Suri * fixed linting issue Signed-off-by: Nikhil Suri * added more logs to identify server side flag evaluation Signed-off-by: Nikhil Suri * remove unused logs Signed-off-by: Nikhil Suri * fix broken test case for default enable telemetry Signed-off-by: Nikhil Suri * redcude test length and made more reusable code Signed-off-by: Nikhil Suri * removed telemetry e2e to daily single run Signed-off-by: Nikhil Suri --------- Signed-off-by: Nikhil Suri --- .github/workflows/daily-telemetry-e2e.yml | 87 ++++++ .github/workflows/integration.yml | 8 +- tests/e2e/test_telemetry_e2e.py | 343 ++++++++++++++++++++++ 3 files changed, 436 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/daily-telemetry-e2e.yml create mode 100644 tests/e2e/test_telemetry_e2e.py diff --git a/.github/workflows/daily-telemetry-e2e.yml b/.github/workflows/daily-telemetry-e2e.yml new file mode 100644 index 000000000..3d61cf177 --- /dev/null +++ b/.github/workflows/daily-telemetry-e2e.yml @@ -0,0 +1,87 @@ +name: Daily Telemetry E2E Tests + +on: + schedule: + - cron: '0 0 * * 0' # Run every Sunday at midnight UTC + + workflow_dispatch: # Allow manual triggering + inputs: + test_pattern: + description: 'Test pattern to run (default: tests/e2e/test_telemetry_e2e.py)' + required: false + default: 'tests/e2e/test_telemetry_e2e.py' + type: string + +jobs: + telemetry-e2e-tests: + runs-on: ubuntu-latest + environment: azure-prod + + env: + DATABRICKS_SERVER_HOSTNAME: ${{ secrets.DATABRICKS_HOST }} + DATABRICKS_HTTP_PATH: ${{ secrets.TEST_PECO_WAREHOUSE_HTTP_PATH }} + DATABRICKS_TOKEN: ${{ secrets.DATABRICKS_TOKEN }} + DATABRICKS_CATALOG: peco + DATABRICKS_USER: ${{ secrets.TEST_PECO_SP_ID }} + + steps: + #---------------------------------------------- + # check-out repo and set-up python + #---------------------------------------------- + - name: Check out repository + uses: actions/checkout@v4 + + - name: Set up python + id: setup-python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + #---------------------------------------------- + # ----- install & configure poetry ----- + #---------------------------------------------- + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + #---------------------------------------------- + # load cached venv if cache exists + #---------------------------------------------- + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v4 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + + #---------------------------------------------- + # install dependencies if cache does not exist + #---------------------------------------------- + - name: Install dependencies + run: poetry install --no-interaction --all-extras + + #---------------------------------------------- + # run telemetry E2E tests + #---------------------------------------------- + - name: Run telemetry E2E tests + run: | + TEST_PATTERN="${{ github.event.inputs.test_pattern || 'tests/e2e/test_telemetry_e2e.py' }}" + echo "Running tests: $TEST_PATTERN" + poetry run python -m pytest $TEST_PATTERN -v -s + + #---------------------------------------------- + # upload test results on failure + #---------------------------------------------- + - name: Upload test results on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: telemetry-test-results + path: | + .pytest_cache/ + tests-unsafe.log + retention-days: 7 + diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 9c9e30a24..ad5369997 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -54,5 +54,9 @@ jobs: #---------------------------------------------- # run test suite #---------------------------------------------- - - name: Run e2e tests - run: poetry run python -m pytest tests/e2e -n auto \ No newline at end of file + - name: Run e2e tests (excluding daily-only tests) + run: | + # Exclude telemetry E2E tests from PR runs (run daily instead) + poetry run python -m pytest tests/e2e \ + --ignore=tests/e2e/test_telemetry_e2e.py \ + -n auto \ No newline at end of file diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py new file mode 100644 index 000000000..917c8e5eb --- /dev/null +++ b/tests/e2e/test_telemetry_e2e.py @@ -0,0 +1,343 @@ +""" +E2E test for telemetry - verifies telemetry behavior with different scenarios +""" +import time +import threading +import logging +from contextlib import contextmanager +from unittest.mock import patch +import pytest +from concurrent.futures import wait + +import databricks.sql as sql +from databricks.sql.telemetry.telemetry_client import ( + TelemetryClient, + TelemetryClientFactory, +) + +log = logging.getLogger(__name__) + + +class TelemetryTestBase: + """Simplified test base class for telemetry e2e tests""" + + @pytest.fixture(autouse=True) + def get_details(self, connection_details): + self.arguments = connection_details.copy() + + def connection_params(self): + return { + "server_hostname": self.arguments["host"], + "http_path": self.arguments["http_path"], + "access_token": self.arguments.get("access_token"), + } + + @contextmanager + def connection(self, extra_params=()): + connection_params = dict(self.connection_params(), **dict(extra_params)) + log.info("Connecting with args: {}".format(connection_params)) + conn = sql.connect(**connection_params) + try: + yield conn + finally: + conn.close() + + +class TestTelemetryE2E(TelemetryTestBase): + """E2E tests for telemetry scenarios""" + + @pytest.fixture(autouse=True) + def telemetry_setup_teardown(self): + """Clean up telemetry client state before and after each test""" + try: + yield + finally: + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._initialized = False + + @pytest.fixture + def telemetry_interceptors(self): + """Setup reusable telemetry interceptors as a fixture""" + capture_lock = threading.Lock() + captured_events = [] + captured_futures = [] + + original_export = TelemetryClient._export_event + original_callback = TelemetryClient._telemetry_request_callback + + def export_wrapper(self_client, event): + with capture_lock: + captured_events.append(event) + return original_export(self_client, event) + + def callback_wrapper(self_client, future, sent_count): + with capture_lock: + captured_futures.append(future) + original_callback(self_client, future, sent_count) + + return captured_events, captured_futures, export_wrapper, callback_wrapper + + # ==================== ASSERTION HELPERS ==================== + + def assert_system_config(self, event): + """Assert system configuration fields""" + sys_config = event.entry.sql_driver_log.system_configuration + assert sys_config is not None + + # Check all required fields are non-empty + for field in ['driver_name', 'driver_version', 'os_name', 'os_version', + 'os_arch', 'runtime_name', 'runtime_version', 'runtime_vendor', + 'locale_name', 'char_set_encoding']: + value = getattr(sys_config, field) + assert value and len(value) > 0, f"{field} should not be None or empty" + + assert sys_config.driver_name == "Databricks SQL Python Connector" + + def assert_connection_params(self, event, expected_http_path=None): + """Assert connection parameters""" + conn_params = event.entry.sql_driver_log.driver_connection_params + assert conn_params is not None + assert conn_params.http_path + assert conn_params.host_info is not None + assert conn_params.auth_mech is not None + + if expected_http_path: + assert conn_params.http_path == expected_http_path + + if conn_params.socket_timeout is not None: + assert conn_params.socket_timeout > 0 + + def assert_statement_execution(self, event): + """Assert statement execution details""" + sql_op = event.entry.sql_driver_log.sql_operation + assert sql_op is not None + assert sql_op.statement_type is not None + assert sql_op.execution_result is not None + assert hasattr(sql_op, "retry_count") + + if sql_op.retry_count is not None: + assert sql_op.retry_count >= 0 + + latency = event.entry.sql_driver_log.operation_latency_ms + assert latency is not None and latency >= 0 + + def assert_error_info(self, event, expected_error_name=None): + """Assert error information""" + error_info = event.entry.sql_driver_log.error_info + assert error_info is not None + assert error_info.error_name and len(error_info.error_name) > 0 + assert error_info.stack_trace and len(error_info.stack_trace) > 0 + + if expected_error_name: + assert error_info.error_name == expected_error_name + + def verify_events(self, captured_events, captured_futures, expected_count): + """Common verification for event count and HTTP responses""" + if expected_count == 0: + assert len(captured_events) == 0, f"Expected 0 events, got {len(captured_events)}" + assert len(captured_futures) == 0, f"Expected 0 responses, got {len(captured_futures)}" + else: + assert len(captured_events) == expected_count, \ + f"Expected {expected_count} events, got {len(captured_events)}" + + time.sleep(2) + done, _ = wait(captured_futures, timeout=10) + assert len(done) == expected_count, \ + f"Expected {expected_count} responses, got {len(done)}" + + for future in done: + response = future.result() + assert 200 <= response.status < 300 + + # Assert common fields for all events + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + # ==================== PARAMETERIZED TESTS ==================== + + @pytest.mark.parametrize("enable_telemetry,force_enable,expected_count,test_id", [ + (True, False, 2, "enable_on_force_off"), + (False, True, 2, "enable_off_force_on"), + (False, False, 0, "both_off"), + (None, None, 0, "default_behavior"), + ]) + def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, + force_enable, expected_count, test_id): + """Test telemetry behavior with different flag combinations""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + extra_params = {"telemetry_batch_size": 1} + if enable_telemetry is not None: + extra_params["enable_telemetry"] = enable_telemetry + if force_enable is not None: + extra_params["force_enable_telemetry"] = force_enable + + with self.connection(extra_params=extra_params) as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + cursor.fetchone() + + self.verify_events(captured_events, captured_futures, expected_count) + + # Assert statement execution on latency event (if events exist) + if expected_count > 0: + self.assert_statement_execution(captured_events[-1]) + + @pytest.mark.parametrize("query,expected_error", [ + ("SELECT * FROM WHERE INVALID SYNTAX 12345", "ServerOperationError"), + ("SELECT * FROM non_existent_table_xyz_12345", None), + ]) + def test_sql_errors(self, telemetry_interceptors, query, expected_error): + """Test telemetry captures error information for different SQL errors""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + }) as conn: + with conn.cursor() as cursor: + with pytest.raises(Exception): + cursor.execute(query) + cursor.fetchone() + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 1 + + # Find event with error_info + error_event = next((e for e in captured_events + if e.entry.sql_driver_log.error_info), None) + assert error_event is not None + + self.assert_system_config(error_event) + self.assert_connection_params(error_event, self.arguments["http_path"]) + self.assert_error_info(error_event, expected_error) + + def test_metadata_operation(self, telemetry_interceptors): + """Test telemetry for metadata operations (getCatalogs)""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + }) as conn: + with conn.cursor() as cursor: + catalogs = cursor.catalogs() + catalogs.fetchall() + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 1 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + def test_direct_results(self, telemetry_interceptors): + """Test telemetry with direct results (use_cloud_fetch=False)""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": False, + }) as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 100") + result = cursor.fetchall() + assert len(result) == 1 and result[0][0] == 100 + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 2 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + self.assert_statement_execution(captured_events[-1]) + + @pytest.mark.parametrize("close_type", [ + "context_manager", + "explicit_cursor", + "explicit_connection", + "implicit_fetchall", + ]) + def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors, + close_type): + """Test telemetry with cloud fetch using different resource closing patterns""" + captured_events, captured_futures, export_wrapper, callback_wrapper = \ + telemetry_interceptors + + with patch.object(TelemetryClient, "_export_event", export_wrapper), \ + patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): + + if close_type == "explicit_connection": + # Test explicit connection close + conn = sql.connect( + **self.connection_params(), + force_enable_telemetry=True, + telemetry_batch_size=1, + use_cloud_fetch=True, + ) + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + conn.close() + else: + # Other patterns use connection context manager + with self.connection(extra_params={ + "force_enable_telemetry": True, + "telemetry_batch_size": 1, + "use_cloud_fetch": True, + }) as conn: + if close_type == "context_manager": + with conn.cursor() as cursor: + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + + elif close_type == "explicit_cursor": + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + cursor.close() + + elif close_type == "implicit_fetchall": + cursor = conn.cursor() + cursor.execute("SELECT * FROM range(1000)") + result = cursor.fetchall() + assert len(result) == 1000 + + time.sleep(2) + wait(captured_futures, timeout=10) + + assert len(captured_events) >= 2 + for event in captured_events: + self.assert_system_config(event) + self.assert_connection_params(event, self.arguments["http_path"]) + + self.assert_statement_execution(captured_events[-1]) From ebe4b07a83499bb725338a10cf18444b31e3ec44 Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Wed, 3 Dec 2025 20:22:39 +0530 Subject: [PATCH 5/6] feat: Implement host-level telemetry batching to reduce rate limiting (#718) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Implement host-level telemetry batching to reduce rate limiting Changes telemetry client architecture from per-session to per-host batching, matching the JDBC driver implementation. This reduces the number of HTTP requests to the telemetry endpoint and prevents rate limiting in test environments. Key changes: - Add _TelemetryClientHolder with reference counting for shared clients - Change TelemetryClientFactory to key clients by host_url instead of session_id - Add getHostUrlSafely() helper for defensive null handling - Update all callers (client.py, exc.py, latency_logger.py) to pass host_url Before: 100 connections to same host = 100 separate TelemetryClients After: 100 connections to same host = 1 shared TelemetryClient (refcount=100) This fixes rate limiting issues seen in e2e tests where 300+ parallel connections were overwhelming the telemetry endpoint with 429 errors. * chore: Change all telemetry logging to DEBUG level Reduces log noise by changing all telemetry-related log statements (info, warning, error) to debug level. Telemetry operations are background tasks and should not clutter logs with operational messages. Changes: - Circuit breaker state changes: info/warning -> debug - Telemetry send failures: error -> debug - All telemetry operations now consistently use debug level * chore: Fix remaining telemetry warning log to debug Changes remaining logger.warning in telemetry_push_client.py to debug level for consistency with other telemetry logging. * fix: Update tests to use host_url instead of session_id_hex - Update circuit breaker test to check logger.debug instead of logger.info - Replace all session_id_hex test parameters with host_url - Apply Black formatting to exc.py and telemetry_client.py This fixes test failures caused by the signature change from session_id_hex to host_url in the Error class and TelemetryClientFactory. * fix: Revert session_id_hex in tests for functions that still use it Only Error classes changed from session_id_hex to host_url. Other classes (TelemetryClient, ResultSetDownloadHandler, etc.) still use session_id_hex. Reverted: - test_telemetry.py: TelemetryClient and initialize_telemetry_client - test_downloader.py: ResultSetDownloadHandler - test_download_manager.py: ResultFileDownloadManager Kept as host_url: - test_client.py: Error class instantiation * fix: Update all Error raises and test calls to use host_url Changes: 1. client.py: Changed all error raises from session_id_hex to host_url - Connection class: session_id_hex=self.get_session_id_hex() -> host_url=self.session.host - Cursor class: session_id_hex=self.connection.get_session_id_hex() -> host_url=self.connection.session.host 2. test_telemetry.py: Updated get_telemetry_client() and close() calls - get_telemetry_client(session_id) -> get_telemetry_client(host_url) - close(session_id) -> close(host_url=host_url) 3. test_telemetry_push_client.py: Changed logger.warning to logger.debug - Updated test assertion to match debug logging level These changes complete the migration from session-level to host-level telemetry client management. * fix: Update thrift_backend.py to use host_url instead of session_id_hex Changes: 1. Added self._host attribute to store server_hostname 2. Updated all error raises to use host_url=self._host 3. Changed method signatures from session_id_hex to host_url: - _check_response_for_error - _hive_schema_to_arrow_schema - _col_to_description - _hive_schema_to_description - _check_direct_results_for_error 4. Updated all method calls to pass self._host instead of self._session_id_hex This completes the migration from session-level to host-level error reporting. * Fix Black formatting by adjusting fmt directive placement Moved the `# fmt: on` directive to the except block level instead of inside the if statement to resolve Black parsing confusion. * Fix telemetry feature flag tests to set mock session host The tests were failing because they called get_telemetry_client("test") but the mock session didn't have .host set, so the telemetry client was registered under a different key (likely None or MagicMock). This caused the factory to return NoopTelemetryClient instead of the expected client. Fixed by setting mock_session_instance.host = "test" in all three tests. * Add teardown_method to clear telemetry factory state between tests Without this cleanup, tests were sharing telemetry clients because they all used the same host key ("test"), causing test pollution. The first test would create an enabled client, and subsequent tests would reuse it even when they expected a disabled client. * Clear feature flag context cache in teardown to fix test pollution The FeatureFlagsContextFactory caches feature flag contexts per session, causing tests to share the same feature flag state. This resulted in the first test creating a context with telemetry enabled, and subsequent tests incorrectly reusing that enabled state even when they expected disabled. * fix: Access actual client from holder in flush worker The flush worker was calling _flush() on _TelemetryClientHolder objects instead of the actual TelemetryClient. Fixed by accessing holder.client before calling _flush(). Fixes AttributeError in e2e tests: '_TelemetryClientHolder' object has no attribute '_flush' * Clear telemetry client cache in e2e test teardown Added _clients.clear() to the teardown fixture to prevent telemetry clients from persisting across e2e tests, which was causing session ID pollution in test_concurrent_queries_sends_telemetry. * Pass session_id parameter to telemetry export methods With host-level telemetry batching, multiple connections share one TelemetryClient. Each client stores session_id_hex from the first connection that created it. This caused all subsequent connections' telemetry events to use the wrong session ID. Changes: - Modified telemetry export method signatures to accept optional session_id - Updated Connection.export_initial_telemetry_log() to pass session_id - Updated latency_logger.py export_latency_log() to pass session_id - Updated Error.__init__() to accept optional session_id_hex and pass it - Updated all error raises in Connection and Cursor to pass session_id_hex 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * Fix Black formatting in telemetry_client.py * Use 'test-host' instead of 'test' for mock host in telemetry tests * Replace test-session-id with test-host in test_client.py * Fix telemetry client lookup to use test-host in tests * Make session_id_hex keyword-only parameter in Error.__init__ --------- Co-authored-by: Claude --- src/databricks/sql/backend/thrift_backend.py | 72 ++++--- src/databricks/sql/client.py | 36 +++- src/databricks/sql/exc.py | 16 +- .../sql/telemetry/circuit_breaker_manager.py | 8 +- .../sql/telemetry/latency_logger.py | 3 +- .../sql/telemetry/telemetry_client.py | 186 ++++++++++++++---- .../sql/telemetry/telemetry_push_client.py | 2 +- tests/e2e/test_concurrent_telemetry.py | 1 + tests/unit/test_circuit_breaker_manager.py | 2 +- tests/unit/test_client.py | 8 +- tests/unit/test_telemetry.py | 28 ++- tests/unit/test_telemetry_push_client.py | 8 +- 12 files changed, 262 insertions(+), 108 deletions(-) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d2b10e718..edee02bfa 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -163,6 +163,7 @@ def __init__( else: raise ValueError("No valid connection settings.") + self._host = server_hostname self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -279,14 +280,14 @@ def _initialize_retry_args(self, kwargs): ) @staticmethod - def _check_response_for_error(response, session_id_hex=None): + def _check_response_for_error(response, host_url=None): if response.status and response.status.statusCode in [ ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS, ]: raise DatabaseError( response.status.errorMessage, - session_id_hex=session_id_hex, + host_url=host_url, ) @staticmethod @@ -340,7 +341,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): network_request_error = RequestError( user_friendly_error_message, full_error_info_context, - self._session_id_hex, + self._host, error_info.error, ) logger.info(network_request_error.message_with_context()) @@ -461,13 +462,12 @@ def attempt_request(attempt): errno.ECONNRESET, # | 104 | 54 | errno.ETIMEDOUT, # | 110 | 60 | ] + # fmt: on gos_name = TCLIServiceClient.GetOperationStatus.__name__ # retry on timeout. Happens a lot in Azure and it is safe as data has not been sent to server yet if method.__name__ == gos_name or err.errno == errno.ETIMEDOUT: retry_delay = bound_retry_delay(attempt, self._retry_delay_default) - - # fmt: on log_string = f"{gos_name} failed with code {err.errno} and will attempt to retry" if err.errno in info_errs: logger.info(log_string) @@ -516,9 +516,7 @@ def attempt_request(attempt): if not isinstance(response_or_error_info, RequestErrorInfo): # log nothing here, presume that main request logging covers response = response_or_error_info - ThriftDatabricksClient._check_response_for_error( - response, self._session_id_hex - ) + ThriftDatabricksClient._check_response_for_error(response, self._host) return response error_info = response_or_error_info @@ -533,7 +531,7 @@ def _check_protocol_version(self, t_open_session_resp): "Error: expected server to use a protocol version >= " "SPARK_CLI_SERVICE_PROTOCOL_V2, " "instead got: {}".format(protocol_version), - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _check_initial_namespace(self, catalog, schema, response): @@ -547,7 +545,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Setting initial namespace not supported by the DBR version, " "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.", - session_id_hex=self._session_id_hex, + host_url=self._host, ) if catalog: @@ -555,7 +553,7 @@ def _check_initial_namespace(self, catalog, schema, response): raise InvalidServerResponseError( "Unexpected response from server: Trying to set initial catalog to {}, " + "but server does not support multiple catalogs.".format(catalog), # type: ignore - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _check_session_configuration(self, session_configuration): @@ -570,7 +568,7 @@ def _check_session_configuration(self, session_configuration): TIMESTAMP_AS_STRING_CONFIG, session_configuration[TIMESTAMP_AS_STRING_CONFIG], ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) def open_session(self, session_configuration, catalog, schema) -> SessionId: @@ -639,7 +637,7 @@ def _check_command_not_in_error_or_closed_state( and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": get_operations_resp.diagnosticInfo, }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) else: raise ServerOperationError( @@ -649,7 +647,7 @@ def _check_command_not_in_error_or_closed_state( and guid_to_hex_id(op_handle.operationId.guid), "diagnostic-info": None, }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE: raise DatabaseError( @@ -660,7 +658,7 @@ def _check_command_not_in_error_or_closed_state( "operation-id": op_handle and guid_to_hex_id(op_handle.operationId.guid) }, - session_id_hex=self._session_id_hex, + host_url=self._host, ) def _poll_for_status(self, op_handle): @@ -683,7 +681,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti else: raise OperationalError( "Unsupported TRowSet instance {}".format(t_row_set), - session_id_hex=self._session_id_hex, + host_url=self._host, ) return convert_decimals_in_arrow_table(arrow_table, description), num_rows @@ -692,7 +690,7 @@ def _get_metadata_resp(self, op_handle): return self.make_request(self._client.GetResultSetMetadata, req) @staticmethod - def _hive_schema_to_arrow_schema(t_table_schema, session_id_hex=None): + def _hive_schema_to_arrow_schema(t_table_schema, host_url=None): def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -724,7 +722,7 @@ def map_type(t_type_entry): # even for complex types raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - session_id_hex=session_id_hex, + host_url=host_url, ) def convert_col(t_column_desc): @@ -735,7 +733,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, field=None, session_id_hex=None): + def _col_to_description(col, field=None, host_url=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -745,7 +743,7 @@ def _col_to_description(col, field=None, session_id_hex=None): else: raise OperationalError( "Thrift protocol error: t_type_entry not a primitiveEntry", - session_id_hex=session_id_hex, + host_url=host_url, ) if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE: @@ -759,7 +757,7 @@ def _col_to_description(col, field=None, session_id_hex=None): raise OperationalError( "Decimal type did not provide typeQualifier precision, scale in " "primitiveEntry {}".format(type_entry.primitiveEntry), - session_id_hex=session_id_hex, + host_url=host_url, ) else: precision, scale = None, None @@ -778,9 +776,7 @@ def _col_to_description(col, field=None, session_id_hex=None): return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description( - t_table_schema, schema_bytes=None, session_id_hex=None - ): + def _hive_schema_to_description(t_table_schema, schema_bytes=None, host_url=None): field_dict = {} if pyarrow and schema_bytes: try: @@ -795,7 +791,7 @@ def _hive_schema_to_description( ThriftDatabricksClient._col_to_description( col, field_dict.get(col.columnName) if field_dict else None, - session_id_hex, + host_url, ) for col in t_table_schema.columns ] @@ -818,7 +814,7 @@ def _results_message_to_execute_response(self, resp, operation_state): t_result_set_metadata_resp.resultFormat ] ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) direct_results = resp.directResults has_been_closed_server_side = direct_results and direct_results.closeOperation @@ -833,7 +829,7 @@ def _results_message_to_execute_response(self, resp, operation_state): schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._session_id_hex + t_result_set_metadata_resp.schema, self._host ) .serialize() .to_pybytes() @@ -844,7 +840,7 @@ def _results_message_to_execute_response(self, resp, operation_state): description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, schema_bytes, - self._session_id_hex, + self._host, ) lz4_compressed = t_result_set_metadata_resp.lz4Compressed @@ -895,7 +891,7 @@ def get_execution_result( schema_bytes = ( t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema( - t_result_set_metadata_resp.schema, self._session_id_hex + t_result_set_metadata_resp.schema, self._host ) .serialize() .to_pybytes() @@ -906,7 +902,7 @@ def get_execution_result( description = self._hive_schema_to_description( t_result_set_metadata_resp.schema, schema_bytes, - self._session_id_hex, + self._host, ) lz4_compressed = t_result_set_metadata_resp.lz4Compressed @@ -971,27 +967,27 @@ def get_query_state(self, command_id: CommandId) -> CommandState: return state @staticmethod - def _check_direct_results_for_error(t_spark_direct_results, session_id_hex=None): + def _check_direct_results_for_error(t_spark_direct_results, host_url=None): if t_spark_direct_results: if t_spark_direct_results.operationStatus: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.operationStatus, - session_id_hex, + host_url, ) if t_spark_direct_results.resultSetMetadata: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSetMetadata, - session_id_hex, + host_url, ) if t_spark_direct_results.resultSet: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.resultSet, - session_id_hex, + host_url, ) if t_spark_direct_results.closeOperation: ThriftDatabricksClient._check_response_for_error( t_spark_direct_results.closeOperation, - session_id_hex, + host_url, ) def execute_command( @@ -1260,7 +1256,7 @@ def _handle_execute_response(self, resp, cursor): raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") cursor.active_command_id = command_id - self._check_direct_results_for_error(resp.directResults, self._session_id_hex) + self._check_direct_results_for_error(resp.directResults, self._host) final_operation_state = self._wait_until_command_done( resp.operationHandle, @@ -1275,7 +1271,7 @@ def _handle_execute_response_async(self, resp, cursor): raise ValueError(f"Invalid Thrift handle: {resp.operationHandle}") cursor.active_command_id = command_id - self._check_direct_results_for_error(resp.directResults, self._session_id_hex) + self._check_direct_results_for_error(resp.directResults, self._host) def fetch_results( self, @@ -1313,7 +1309,7 @@ def fetch_results( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( expected_row_start_offset, resp.results.startRowOffset ), - session_id_hex=self._session_id_hex, + host_url=self._host, ) queue = ThriftResultSetQueueFactory.build_queue( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a7f802dcd..c873700bc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -341,7 +341,7 @@ def read(self) -> Optional[OAuthToken]: ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex=self.get_session_id_hex() + host_url=self.session.host ) # Determine proxy usage @@ -391,6 +391,7 @@ def read(self) -> Optional[OAuthToken]: self._telemetry_client.export_initial_telemetry_log( driver_connection_params=driver_connection_params, user_agent=self.session.useragent_header, + session_id=self.get_session_id_hex(), ) def _set_use_inline_params_with_warning(self, value: Union[bool, str]): @@ -494,6 +495,7 @@ def cursor( if not self.open: raise InterfaceError( "Cannot create cursor from closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -521,7 +523,7 @@ def _close(self, close_cursors=True) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - TelemetryClientFactory.close(self.get_session_id_hex()) + TelemetryClientFactory.close(host_url=self.session.host) # Close HTTP client that was created by this connection if self.http_client: @@ -546,6 +548,7 @@ def autocommit(self) -> bool: if not self.open: raise InterfaceError( "Cannot get autocommit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -578,6 +581,7 @@ def autocommit(self, value: bool) -> None: if not self.open: raise InterfaceError( "Cannot set autocommit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -600,6 +604,7 @@ def autocommit(self, value: bool) -> None: "operation": "set_autocommit", "autocommit_value": value, }, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -627,6 +632,7 @@ def _fetch_autocommit_state_from_server(self) -> bool: raise TransactionError( "No result returned from SET AUTOCOMMIT query", context={"operation": "fetch_autocommit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -647,6 +653,7 @@ def _fetch_autocommit_state_from_server(self) -> bool: raise TransactionError( f"Failed to fetch autocommit state from server: {e.message}", context={**e.context, "operation": "fetch_autocommit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -680,6 +687,7 @@ def commit(self) -> None: if not self.open: raise InterfaceError( "Cannot commit on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -692,6 +700,7 @@ def commit(self) -> None: raise TransactionError( f"Failed to commit transaction: {e.message}", context={**e.context, "operation": "commit"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -725,12 +734,14 @@ def rollback(self) -> None: if self.ignore_transactions: raise NotSupportedError( "Transactions are not supported on Databricks", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) if not self.open: raise InterfaceError( "Cannot rollback on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -743,6 +754,7 @@ def rollback(self) -> None: raise TransactionError( f"Failed to rollback transaction: {e.message}", context={**e.context, "operation": "rollback"}, + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) from e finally: @@ -767,6 +779,7 @@ def get_transaction_isolation(self) -> str: if not self.open: raise InterfaceError( "Cannot get transaction isolation on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -793,6 +806,7 @@ def set_transaction_isolation(self, level: str) -> None: if not self.open: raise InterfaceError( "Cannot set transaction isolation on closed connection", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -805,6 +819,7 @@ def set_transaction_isolation(self, level: str) -> None: raise NotSupportedError( f"Setting transaction isolation level '{level}' is not supported. " f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", + host_url=self.session.host, session_id_hex=self.get_session_id_hex(), ) @@ -857,6 +872,7 @@ def __iter__(self): else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -997,6 +1013,7 @@ def _check_not_closed(self): if not self.open: raise InterfaceError( "Attempting operation on closed cursor", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1041,6 +1058,7 @@ def _handle_staging_operation( else: raise ProgrammingError( "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1067,6 +1085,7 @@ def _handle_staging_operation( if not allow_operation: raise ProgrammingError( "Local file operations are restricted to paths within the configured staging_allowed_local_path", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1095,6 +1114,7 @@ def _handle_staging_operation( raise ProgrammingError( f"Operation {row.operation} is not supported. " + "Supported operations are GET, PUT, and REMOVE", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1110,6 +1130,7 @@ def _handle_staging_put( if local_file is None: raise ProgrammingError( "Cannot perform PUT without specifying a local_file", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1135,6 +1156,7 @@ def _handle_staging_http_response(self, r): error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1166,6 +1188,7 @@ def _handle_staging_put_stream( if not stream: raise ProgrammingError( "No input stream provided for streaming operation", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1187,6 +1210,7 @@ def _handle_staging_get( if local_file is None: raise ProgrammingError( "Cannot perform GET without specifying a local_file", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1201,6 +1225,7 @@ def _handle_staging_get( error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1222,6 +1247,7 @@ def _handle_staging_remove( error_text = r.data.decode() if r.data else "" raise OperationalError( f"Staging operation over HTTP was unsuccessful: {r.status}-{error_text}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1413,6 +1439,7 @@ def get_async_execution_result(self): else: raise OperationalError( f"get_execution_result failed with Operation status {operation_state}", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1541,6 +1568,7 @@ def fetchall(self) -> List[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1558,6 +1586,7 @@ def fetchone(self) -> Optional[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1583,6 +1612,7 @@ def fetchmany(self, size: int) -> List[Row]: else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1593,6 +1623,7 @@ def fetchall_arrow(self) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) @@ -1603,6 +1634,7 @@ def fetchmany_arrow(self, size) -> "pyarrow.Table": else: raise ProgrammingError( "There is no active result set", + host_url=self.connection.session.host, session_id_hex=self.connection.get_session_id_hex(), ) diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 24844d573..f4770f3c4 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -12,20 +12,28 @@ class Error(Exception): """ def __init__( - self, message=None, context=None, session_id_hex=None, *args, **kwargs + self, + message=None, + context=None, + host_url=None, + *args, + session_id_hex=None, + **kwargs, ): super().__init__(message, *args, **kwargs) self.message = message self.context = context or {} error_name = self.__class__.__name__ - if session_id_hex: + if host_url: from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory telemetry_client = TelemetryClientFactory.get_telemetry_client( - session_id_hex + host_url=host_url + ) + telemetry_client.export_failure_log( + error_name, self.message, session_id=session_id_hex ) - telemetry_client.export_failure_log(error_name, self.message) def __str__(self): return self.message diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py index 852f0d916..a5df7371e 100644 --- a/src/databricks/sql/telemetry/circuit_breaker_manager.py +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -60,16 +60,16 @@ def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: old_state_name = old_state.name if old_state else "None" new_state_name = new_state.name if new_state else "None" - logger.info( + logger.debug( LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name ) if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: - logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_OPENED, cb.name) elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: - logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: - logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + logger.debug(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) class CircuitBreakerManager: diff --git a/src/databricks/sql/telemetry/latency_logger.py b/src/databricks/sql/telemetry/latency_logger.py index 36ebee2b8..2445c25c2 100644 --- a/src/databricks/sql/telemetry/latency_logger.py +++ b/src/databricks/sql/telemetry/latency_logger.py @@ -205,13 +205,14 @@ def wrapper(self, *args, **kwargs): telemetry_client = ( TelemetryClientFactory.get_telemetry_client( - session_id_hex + host_url=connection.session.host ) ) telemetry_client.export_latency_log( latency_ms=duration_ms, sql_execution_event=sql_exec_event, sql_statement_id=telemetry_data.get("statement_id"), + session_id=session_id_hex, ) return wrapper diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index d5f5b575c..77d1a2f9c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -147,13 +147,17 @@ def __new__(cls): cls._instance = super(NoopTelemetryClient, cls).__new__(cls) return cls._instance - def export_initial_telemetry_log(self, driver_connection_params, user_agent): + def export_initial_telemetry_log( + self, driver_connection_params, user_agent, session_id=None + ): pass - def export_failure_log(self, error_name, error_message): + def export_failure_log(self, error_name, error_message, session_id=None): pass - def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id, session_id=None + ): pass def close(self): @@ -307,7 +311,7 @@ def _send_with_unified_client(self, url, data, headers, timeout=900): ) return response except Exception as e: - logger.error("Failed to send telemetry with unified client: %s", e) + logger.debug("Failed to send telemetry with unified client: %s", e) raise def _telemetry_request_callback(self, future, sent_count: int): @@ -352,19 +356,22 @@ def _telemetry_request_callback(self, future, sent_count: int): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) - def _export_telemetry_log(self, **telemetry_event_kwargs): + def _export_telemetry_log(self, session_id=None, **telemetry_event_kwargs): """ Common helper method for exporting telemetry logs. Args: + session_id: Optional session ID for this event. If not provided, uses the client's session ID. **telemetry_event_kwargs: Keyword arguments to pass to TelemetryEvent constructor """ - logger.debug("Exporting telemetry log for connection %s", self._session_id_hex) + # Use provided session_id or fall back to client's session_id + actual_session_id = session_id or self._session_id_hex + logger.debug("Exporting telemetry log for connection %s", actual_session_id) try: # Set common fields for all telemetry events event_kwargs = { - "session_id": self._session_id_hex, + "session_id": actual_session_id, "system_configuration": TelemetryHelper.get_driver_system_configuration(), "driver_connection_params": self._driver_connection_params, } @@ -387,17 +394,22 @@ def _export_telemetry_log(self, **telemetry_event_kwargs): except Exception as e: logger.debug("Failed to export telemetry log: %s", e) - def export_initial_telemetry_log(self, driver_connection_params, user_agent): + def export_initial_telemetry_log( + self, driver_connection_params, user_agent, session_id=None + ): self._driver_connection_params = driver_connection_params self._user_agent = user_agent - self._export_telemetry_log() + self._export_telemetry_log(session_id=session_id) - def export_failure_log(self, error_name, error_message): + def export_failure_log(self, error_name, error_message, session_id=None): error_info = DriverErrorInfo(error_name=error_name, stack_trace=error_message) - self._export_telemetry_log(error_info=error_info) + self._export_telemetry_log(session_id=session_id, error_info=error_info) - def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): + def export_latency_log( + self, latency_ms, sql_execution_event, sql_statement_id, session_id=None + ): self._export_telemetry_log( + session_id=session_id, sql_statement_id=sql_statement_id, sql_operation=sql_execution_event, operation_latency_ms=latency_ms, @@ -409,15 +421,39 @@ def close(self): self._flush() +class _TelemetryClientHolder: + """ + Holds a telemetry client with reference counting. + Multiple connections to the same host share one client. + """ + + def __init__(self, client: BaseTelemetryClient): + self.client = client + self.refcount = 1 + + def increment(self): + """Increment reference count when a new connection uses this client""" + self.refcount += 1 + + def decrement(self): + """Decrement reference count when a connection closes""" + self.refcount -= 1 + return self.refcount + + class TelemetryClientFactory: """ Static factory class for creating and managing telemetry clients. It uses a thread pool to handle asynchronous operations and a single flush thread for all clients. + + Clients are shared at the HOST level - multiple connections to the same host + share a single TelemetryClient to enable efficient batching and reduce load + on the telemetry endpoint. """ _clients: Dict[ - str, BaseTelemetryClient - ] = {} # Map of session_id_hex -> BaseTelemetryClient + str, _TelemetryClientHolder + ] = {} # Map of host_url -> TelemetryClientHolder _executor: Optional[ThreadPoolExecutor] = None _initialized: bool = False _lock = threading.RLock() # Thread safety for factory operations @@ -431,6 +467,22 @@ class TelemetryClientFactory: _flush_interval_seconds = 300 # 5 minutes DEFAULT_BATCH_SIZE = 100 + UNKNOWN_HOST = "unknown-host" + + @staticmethod + def getHostUrlSafely(host_url): + """ + Safely get host URL with fallback to UNKNOWN_HOST. + + Args: + host_url: The host URL to validate + + Returns: + The host_url if valid, otherwise UNKNOWN_HOST + """ + if not host_url or not isinstance(host_url, str) or not host_url.strip(): + return TelemetryClientFactory.UNKNOWN_HOST + return host_url @classmethod def _initialize(cls): @@ -464,8 +516,8 @@ def _flush_worker(cls): with cls._lock: clients_to_flush = list(cls._clients.values()) - for client in clients_to_flush: - client._flush() + for holder in clients_to_flush: + holder.client._flush() @classmethod def _stop_flush_thread(cls): @@ -506,21 +558,38 @@ def initialize_telemetry_client( batch_size, client_context, ): - """Initialize a telemetry client for a specific connection if telemetry is enabled""" + """ + Initialize a telemetry client for a specific connection if telemetry is enabled. + + Clients are shared at the HOST level - multiple connections to the same host + will share a single TelemetryClient with reference counting. + """ try: + # Safely get host_url with fallback to UNKNOWN_HOST + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) with TelemetryClientFactory._lock: TelemetryClientFactory._initialize() - if session_id_hex not in TelemetryClientFactory._clients: + if host_url in TelemetryClientFactory._clients: + # Reuse existing client for this host + holder = TelemetryClientFactory._clients[host_url] + holder.increment() logger.debug( - "Creating new TelemetryClient for connection %s", + "Reusing TelemetryClient for host %s (session %s, refcount=%d)", + host_url, + session_id_hex, + holder.refcount, + ) + else: + # Create new client for this host + logger.debug( + "Creating new TelemetryClient for host %s (session %s)", + host_url, session_id_hex, ) if telemetry_enabled: - TelemetryClientFactory._clients[ - session_id_hex - ] = TelemetryClient( + client = TelemetryClient( telemetry_enabled=telemetry_enabled, session_id_hex=session_id_hex, auth_provider=auth_provider, @@ -529,36 +598,73 @@ def initialize_telemetry_client( batch_size=batch_size, client_context=client_context, ) + TelemetryClientFactory._clients[ + host_url + ] = _TelemetryClientHolder(client) else: TelemetryClientFactory._clients[ - session_id_hex - ] = NoopTelemetryClient() + host_url + ] = _TelemetryClientHolder(NoopTelemetryClient()) except Exception as e: logger.debug("Failed to initialize telemetry client: %s", e) # Fallback to NoopTelemetryClient to ensure connection doesn't fail - TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient() + TelemetryClientFactory._clients[host_url] = _TelemetryClientHolder( + NoopTelemetryClient() + ) @staticmethod - def get_telemetry_client(session_id_hex): - """Get the telemetry client for a specific connection""" - return TelemetryClientFactory._clients.get( - session_id_hex, NoopTelemetryClient() - ) + def get_telemetry_client(host_url): + """ + Get the shared telemetry client for a specific host. + + Args: + host_url: The host URL to look up the client. If None/empty, uses UNKNOWN_HOST. + + Returns: + The shared TelemetryClient for this host, or NoopTelemetryClient if not found + """ + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) + + if host_url in TelemetryClientFactory._clients: + return TelemetryClientFactory._clients[host_url].client + return NoopTelemetryClient() @staticmethod - def close(session_id_hex): - """Close and remove the telemetry client for a specific connection""" + def close(host_url): + """ + Close the telemetry client for a specific host. + + Decrements the reference count for the host's client. Only actually closes + the client when the reference count reaches zero (all connections to this host closed). + + Args: + host_url: The host URL whose client to close. If None/empty, uses UNKNOWN_HOST. + """ + host_url = TelemetryClientFactory.getHostUrlSafely(host_url) with TelemetryClientFactory._lock: - if ( - telemetry_client := TelemetryClientFactory._clients.pop( - session_id_hex, None - ) - ) is not None: + # Get the holder for this host + holder = TelemetryClientFactory._clients.get(host_url) + if holder is None: + logger.debug("No telemetry client found for host %s", host_url) + return + + # Decrement refcount + remaining_refs = holder.decrement() + logger.debug( + "Decremented refcount for host %s (refcount=%d)", + host_url, + remaining_refs, + ) + + # Only close if no more references + if remaining_refs <= 0: logger.debug( - "Removing telemetry client for connection %s", session_id_hex + "Closing telemetry client for host %s (no more references)", + host_url, ) - telemetry_client.close() + TelemetryClientFactory._clients.pop(host_url, None) + holder.client.close() # Shutdown executor if no more clients if not TelemetryClientFactory._clients and TelemetryClientFactory._executor: @@ -597,7 +703,7 @@ def connection_failure_log( ) telemetry_client = TelemetryClientFactory.get_telemetry_client( - UNAUTH_DUMMY_SESSION_ID + host_url=host_url ) telemetry_client._driver_connection_params = DriverConnectionParameters( http_path=http_path, diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py index 461a57738..e77910007 100644 --- a/src/databricks/sql/telemetry/telemetry_push_client.py +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -120,7 +120,7 @@ def _make_request_and_check_status( # Check for rate limiting or service unavailable if response.status in [429, 503]: - logger.warning( + logger.debug( "Telemetry endpoint returned %d for host %s, triggering circuit breaker", response.status, self._host, diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index d2ac4227d..546a2b8b2 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -41,6 +41,7 @@ def telemetry_setup_teardown(self): TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None TelemetryClientFactory._stop_flush_thread() + TelemetryClientFactory._clients.clear() TelemetryClientFactory._initialized = False def test_concurrent_queries_sends_telemetry(self): diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py index e8ed4e809..1e02556d9 100644 --- a/tests/unit/test_circuit_breaker_manager.py +++ b/tests/unit/test_circuit_breaker_manager.py @@ -157,4 +157,4 @@ def test_circuit_breaker_state_listener_transitions(self, old_state, new_state): with patch("databricks.sql.telemetry.circuit_breaker_manager.logger") as mock_logger: listener.state_change(mock_cb, mock_old_state, mock_new_state) - mock_logger.info.assert_called() + mock_logger.debug.assert_called() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index b515756e8..8f8a97eae 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -714,7 +714,7 @@ def test_autocommit_setter_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", context={"sql_state": "25000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error @@ -737,7 +737,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): mock_cursor = Mock() original_error = DatabaseError( - "Original error", session_id_hex="test-session-id" + "Original error", host_url="test-host" ) mock_cursor.execute.side_effect = original_error @@ -772,7 +772,7 @@ def test_commit_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", context={"sql_state": "25000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error @@ -822,7 +822,7 @@ def test_rollback_wraps_database_error(self, mock_session_class): server_error = DatabaseError( "Unexpected rollback error", context={"sql_state": "HY000"}, - session_id_hex="test-session-id", + host_url="test-host", ) mock_cursor.execute.side_effect = server_error diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 96a2f87d8..e9fa16649 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -249,13 +249,13 @@ def test_client_lifecycle_flow(self): client_context=client_context, ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, TelemetryClient) assert client._session_id_hex == session_id_hex # Close client with patch.object(client, "close") as mock_close: - TelemetryClientFactory.close(session_id_hex) + TelemetryClientFactory.close(host_url="test-host.com") mock_close.assert_called_once() # Should get NoopTelemetryClient after close @@ -274,7 +274,7 @@ def test_disabled_telemetry_creates_noop_client(self): client_context=client_context, ) - client = TelemetryClientFactory.get_telemetry_client(session_id_hex) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, NoopTelemetryClient) def test_factory_error_handling(self): @@ -297,7 +297,7 @@ def test_factory_error_handling(self): ) # Should fall back to NoopTelemetryClient - client = TelemetryClientFactory.get_telemetry_client(session_id) + client = TelemetryClientFactory.get_telemetry_client("test-host.com") assert isinstance(client, NoopTelemetryClient) def test_factory_shutdown_flow(self): @@ -325,11 +325,11 @@ def test_factory_shutdown_flow(self): assert TelemetryClientFactory._executor is not None # Close first client - factory should stay initialized - TelemetryClientFactory.close(session1) + TelemetryClientFactory.close(host_url="test-host.com") assert TelemetryClientFactory._initialized is True # Close second client - factory should shut down - TelemetryClientFactory.close(session2) + TelemetryClientFactory.close(host_url="test-host.com") assert TelemetryClientFactory._initialized is False assert TelemetryClientFactory._executor is None @@ -367,6 +367,13 @@ def test_connection_failure_sends_correct_telemetry_payload( class TestTelemetryFeatureFlag: """Tests the interaction between the telemetry feature flag and connection parameters.""" + def teardown_method(self): + """Clean up telemetry factory state after each test to prevent test pollution.""" + from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + + TelemetryClientFactory._clients.clear() + FeatureFlagsContextFactory._context_map.clear() + def _mock_ff_response(self, mock_http_request, enabled: bool): """Helper method to mock feature flag response for unified HTTP client.""" mock_response = MagicMock() @@ -391,6 +398,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio self._mock_ff_response(mock_http_request, enabled=True) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -410,7 +418,7 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio assert conn.telemetry_enabled is True mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-true") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, TelemetryClient) @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") @@ -421,6 +429,7 @@ def test_telemetry_disabled_when_flag_is_false( self._mock_ff_response(mock_http_request, enabled=False) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -440,7 +449,7 @@ def test_telemetry_disabled_when_flag_is_false( assert conn.telemetry_enabled is False mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, NoopTelemetryClient) @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") @@ -451,6 +460,7 @@ def test_telemetry_disabled_when_flag_request_fails( mock_http_request.side_effect = Exception("Network is down") mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" + mock_session_instance.host = "test-host" # Set host for telemetry client lookup mock_session_instance.auth_provider = AccessTokenAuthProvider("token") mock_session_instance.is_open = ( False # Connection starts closed for test cleanup @@ -470,7 +480,7 @@ def test_telemetry_disabled_when_flag_request_fails( assert conn.telemetry_enabled is False mock_http_request.assert_called_once() - client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") + client = TelemetryClientFactory.get_telemetry_client("test-host") assert isinstance(client, NoopTelemetryClient) diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py index 0e9455e1f..6555f1d02 100644 --- a/tests/unit/test_telemetry_push_client.py +++ b/tests/unit/test_telemetry_push_client.py @@ -114,10 +114,10 @@ def test_rate_limit_error_logging(self): with pytest.raises(TelemetryRateLimitError): self.client.request(HttpMethod.POST, "https://test.com", {}) - mock_logger.warning.assert_called() - warning_args = mock_logger.warning.call_args[0] - assert "429" in str(warning_args) - assert "circuit breaker" in warning_args[0] + mock_logger.debug.assert_called() + debug_args = mock_logger.debug.call_args[0] + assert "429" in str(debug_args) + assert "circuit breaker" in debug_args[0] def test_other_error_logging(self): """Test that other errors are logged during wrapping/unwrapping.""" From d2ae1e8fe6e1f4ad5dba13751cd0f91763aea9bf Mon Sep 17 00:00:00 2001 From: Samikshya Chand <148681192+samikshya-db@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:21:20 +0530 Subject: [PATCH 6/6] Prepare for a release with telemetry on by default (#717) * Prepare for a release with telemetry on by default Signed-off-by: samikshya-chand_data * Make edits Signed-off-by: samikshya-chand_data * Update version Signed-off-by: samikshya-chand_data * Fix CHANGELOG formatting to match previous style Signed-off-by: samikshya-chand_data * Fix telemetry e2e tests for default-enabled behavior - Update test expectations to reflect telemetry being enabled by default - Add feature flags cache cleanup in teardown to prevent state leakage between tests - This ensures each test runs with fresh feature flag state * Add wait after connection close for async telemetry submission * Remove debug logging from telemetry tests * Mark telemetry e2e tests as serial - must not run in parallel Root cause: Telemetry tests share host-level client across pytest-xdist workers, causing test isolation issues with patches. Tests pass serially but fail with -n auto. Solution: Add @pytest.mark.serial marker. CI needs to run these separately without -n auto. * Split test execution to run serial tests separately Telemetry e2e tests must run serially due to shared host-level telemetry client across pytest-xdist workers. Running with -n auto causes test isolation issues where futures aren't properly captured. Changes: - Run parallel tests with -m 'not serial' -n auto - Run serial tests with -m 'serial' without parallelization - Use --cov-append for serial tests to combine coverage - Mark telemetry e2e tests with @pytest.mark.serial - Update test expectations for default telemetry behavior - Add feature flags cache cleanup in test teardown * Mark telemetry e2e tests as serial - must not run in parallel The concurrent telemetry e2e test globally patches telemetry methods to capture events. When run in parallel with other tests via pytest-xdist, it captures telemetry events from other concurrent tests, causing assertion failures (expected 60 events, got 88). All telemetry e2e tests must run serially to avoid cross-test interference with the shared host-level telemetry client. --------- Signed-off-by: samikshya-chand_data --- .github/workflows/code-coverage.yml | 19 +++++++++++++++++-- CHANGELOG.md | 7 +++++++ README.md | 2 +- pyproject.toml | 7 +++++-- src/databricks/sql/__init__.py | 2 +- src/databricks/sql/auth/common.py | 2 +- src/databricks/sql/client.py | 2 +- tests/e2e/test_concurrent_telemetry.py | 1 + tests/e2e/test_telemetry_e2e.py | 17 ++++++++++++++--- 9 files changed, 48 insertions(+), 11 deletions(-) diff --git a/.github/workflows/code-coverage.yml b/.github/workflows/code-coverage.yml index d9954d051..3c76be728 100644 --- a/.github/workflows/code-coverage.yml +++ b/.github/workflows/code-coverage.yml @@ -61,17 +61,32 @@ jobs: - name: Install library run: poetry install --no-interaction --all-extras #---------------------------------------------- - # run all tests with coverage + # run parallel tests with coverage #---------------------------------------------- - - name: Run all tests with coverage + - name: Run parallel tests with coverage continue-on-error: false run: | poetry run pytest tests/unit tests/e2e \ + -m "not serial" \ -n auto \ --cov=src \ --cov-report=xml \ --cov-report=term \ -v + + #---------------------------------------------- + # run serial tests with coverage + #---------------------------------------------- + - name: Run serial tests with coverage + continue-on-error: false + run: | + poetry run pytest tests/e2e \ + -m "serial" \ + --cov=src \ + --cov-append \ + --cov-report=xml \ + --cov-report=term \ + -v #---------------------------------------------- # check for coverage override diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b902e976..6be2dacaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Release History +# 4.2.2 (2025-12-01) +- Change default use_hybrid_disposition to False (databricks/databricks-sql-python#714 by @samikshya-db) +- Circuit breaker changes using pybreaker (databricks/databricks-sql-python#705 by @nikhilsuri-db) +- perf: Optimize telemetry latency logging to reduce overhead (databricks/databricks-sql-python#715 by @samikshya-db) +- basic e2e test for force telemetry verification (databricks/databricks-sql-python#708 by @nikhilsuri-db) +- Telemetry is ON by default to track connection stats. (Note : This strictly excludes PII, query text, and results) (databricks/databricks-sql-python#717 by @samikshya-db) + # 4.2.1 (2025-11-20) - Ignore transactions by default (databricks/databricks-sql-python#711 by @jayantsing-db) diff --git a/README.md b/README.md index ec82a3637..047515ba4 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ You are welcome to file an issue here for general use cases. You can also contac ## Requirements -Python 3.8 or above is required. +Python 3.9 or above is required. ## Documentation diff --git a/pyproject.toml b/pyproject.toml index 61c248e98..d2739c7d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.2.1" +version = "4.2.2" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" @@ -62,7 +62,10 @@ exclude = ['ttypes\.py$', 'TCLIService\.py$'] exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist|thrift_api)/' [tool.pytest.ini_options] -markers = {"reviewed" = "Test case has been reviewed by Databricks"} +markers = [ + "reviewed: Test case has been reviewed by Databricks", + "serial: Tests that must run serially (not parallelized)" +] minversion = "6.0" log_cli = "false" log_cli_level = "INFO" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index cd37e6ce1..7cf631e83 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -71,7 +71,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.2.1" +__version__ = "4.2.2" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index a764b036d..0e3a01918 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -51,7 +51,7 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, - telemetry_circuit_breaker_enabled: Optional[bool] = None, + telemetry_circuit_breaker_enabled: Optional[bool] = True, ): self.hostname = hostname self.access_token = access_token diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c873700bc..1f17d54f2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -328,7 +328,7 @@ def read(self) -> Optional[OAuthToken]: self.ignore_transactions = ignore_transactions self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) - self.enable_telemetry = kwargs.get("enable_telemetry", False) + self.enable_telemetry = kwargs.get("enable_telemetry", True) self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self) TelemetryClientFactory.initialize_telemetry_client( diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index 546a2b8b2..bed348c2c 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -26,6 +26,7 @@ def run_in_threads(target, num_threads, pass_index=False): t.join() +@pytest.mark.serial class TestE2ETelemetry(PySQLPytestTestCase): @pytest.fixture(autouse=True) def telemetry_setup_teardown(self): diff --git a/tests/e2e/test_telemetry_e2e.py b/tests/e2e/test_telemetry_e2e.py index 917c8e5eb..0a57edd3c 100644 --- a/tests/e2e/test_telemetry_e2e.py +++ b/tests/e2e/test_telemetry_e2e.py @@ -43,8 +43,9 @@ def connection(self, extra_params=()): conn.close() +@pytest.mark.serial class TestTelemetryE2E(TelemetryTestBase): - """E2E tests for telemetry scenarios""" + """E2E tests for telemetry scenarios - must run serially due to shared host-level telemetry client""" @pytest.fixture(autouse=True) def telemetry_setup_teardown(self): @@ -58,6 +59,14 @@ def telemetry_setup_teardown(self): TelemetryClientFactory._stop_flush_thread() TelemetryClientFactory._initialized = False + # Clear feature flags cache to prevent state leakage between tests + from databricks.sql.common.feature_flag import FeatureFlagsContextFactory + with FeatureFlagsContextFactory._lock: + FeatureFlagsContextFactory._context_map.clear() + if FeatureFlagsContextFactory._executor: + FeatureFlagsContextFactory._executor.shutdown(wait=False) + FeatureFlagsContextFactory._executor = None + @pytest.fixture def telemetry_interceptors(self): """Setup reusable telemetry interceptors as a fixture""" @@ -142,7 +151,7 @@ def verify_events(self, captured_events, captured_futures, expected_count): else: assert len(captured_events) == expected_count, \ f"Expected {expected_count} events, got {len(captured_events)}" - + time.sleep(2) done, _ = wait(captured_futures, timeout=10) assert len(done) == expected_count, \ @@ -163,7 +172,7 @@ def verify_events(self, captured_events, captured_futures, expected_count): (True, False, 2, "enable_on_force_off"), (False, True, 2, "enable_off_force_on"), (False, False, 0, "both_off"), - (None, None, 0, "default_behavior"), + (None, None, 2, "default_behavior"), ]) def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, force_enable, expected_count, test_id): @@ -185,6 +194,8 @@ def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry, cursor.execute("SELECT 1") cursor.fetchone() + # Give time for async telemetry submission after connection closes + time.sleep(0.5) self.verify_events(captured_events, captured_futures, expected_count) # Assert statement execution on latency event (if events exist)