|
| 1 | +""" |
| 2 | +Circuit breaker implementation for telemetry requests. |
| 3 | +
|
| 4 | +This module provides circuit breaker functionality to prevent telemetry failures |
| 5 | +from impacting the main SQL operations. It uses the pybreaker library to implement |
| 6 | +the circuit breaker pattern. |
| 7 | +""" |
| 8 | + |
| 9 | +import logging |
| 10 | +import threading |
| 11 | +from typing import Dict |
| 12 | + |
| 13 | +import pybreaker |
| 14 | +from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener |
| 15 | + |
| 16 | +from databricks.sql.exc import TelemetryNonRateLimitError |
| 17 | + |
| 18 | +logger = logging.getLogger(__name__) |
| 19 | + |
# Circuit breaker tuning.
# MINIMUM_CALLS is handed to CircuitBreaker as ``fail_max`` and RESET_TIMEOUT
# (in seconds) as ``reset_timeout``; NAME_PREFIX is joined with the host to
# build each breaker's name.
MINIMUM_CALLS = 20
RESET_TIMEOUT = 30
NAME_PREFIX = "telemetry-circuit-breaker"

# State names the listener compares against when choosing a log message.
CIRCUIT_BREAKER_STATE_OPEN = "open"
CIRCUIT_BREAKER_STATE_CLOSED = "closed"
CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open"

# Lazy %-style logging templates; the arguments are supplied at the log call.
LOG_CIRCUIT_BREAKER_STATE_CHANGED = (
    "Circuit breaker state changed from %s to %s for %s"
)
LOG_CIRCUIT_BREAKER_OPENED = (
    "Circuit breaker opened for %s"
    " - telemetry requests will be blocked"
)
LOG_CIRCUIT_BREAKER_CLOSED = (
    "Circuit breaker closed for %s"
    " - telemetry requests will be allowed"
)
LOG_CIRCUIT_BREAKER_HALF_OPEN = (
    "Circuit breaker half-open for %s"
    " - testing telemetry requests"
)
| 41 | + |
| 42 | + |
class CircuitBreakerStateListener(CircuitBreakerListener):
    """Log circuit breaker state transitions; every other hook is a no-op."""

    def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None:
        """Hook invoked before the breaker calls a function; does nothing."""

    def failure(self, cb: CircuitBreaker, exc: BaseException) -> None:
        """Hook invoked when a guarded call fails; does nothing."""

    def success(self, cb: CircuitBreaker) -> None:
        """Hook invoked when a guarded call succeeds; does nothing."""

    def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None:
        """
        Log a state transition of the circuit breaker.

        Always emits one generic "state changed" record at INFO level. If the
        new state's name is one of the known state constants, a second,
        state-specific record follows (WARNING for open, INFO otherwise).

        Args:
            cb: The circuit breaker whose state changed
            old_state: The previous state object; a falsy value is logged as "None"
            new_state: The new state object; a falsy value is logged as "None"
        """
        # Falsy states (e.g. None) have no usable ``name``; log a placeholder.
        if old_state:
            old_state_name = old_state.name
        else:
            old_state_name = "None"

        if new_state:
            new_state_name = new_state.name
        else:
            new_state_name = "None"

        logger.info(
            LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name
        )

        # State name -> (logging method, message template).
        announcements = {
            CIRCUIT_BREAKER_STATE_OPEN: (logger.warning, LOG_CIRCUIT_BREAKER_OPENED),
            CIRCUIT_BREAKER_STATE_CLOSED: (logger.info, LOG_CIRCUIT_BREAKER_CLOSED),
            CIRCUIT_BREAKER_STATE_HALF_OPEN: (
                logger.info,
                LOG_CIRCUIT_BREAKER_HALF_OPEN,
            ),
        }
        announcement = announcements.get(new_state_name)
        if announcement is not None:
            emit, template = announcement
            emit(template, cb.name)
| 73 | + |
| 74 | + |
class CircuitBreakerManager:
    """
    Registry of circuit breakers used for telemetry requests.

    Keeps one CircuitBreaker per host in a class-level cache, creating it on
    first request, so that every caller asking for the same host receives the
    same breaker instance.
    """

    # Host -> breaker cache, shared by all callers of the class.
    _instances: Dict[str, CircuitBreaker] = {}
    # Guards lookups and insertions into ``_instances``.
    _lock = threading.RLock()

    @classmethod
    def get_circuit_breaker(cls, host: str) -> CircuitBreaker:
        """
        Return the circuit breaker for ``host``, creating it if necessary.

        Args:
            host: The hostname for which to get the circuit breaker

        Returns:
            CircuitBreaker instance for the host
        """
        with cls._lock:
            cached = cls._instances.get(host)
            if cached is not None:
                return cached

            # TelemetryNonRateLimitError is excluded so that it is not
            # counted as a failure by the breaker.
            created = CircuitBreaker(
                fail_max=MINIMUM_CALLS,
                reset_timeout=RESET_TIMEOUT,
                name=f"{NAME_PREFIX}-{host}",
                exclude=[TelemetryNonRateLimitError],
            )
            # The listener logs every state transition of this breaker.
            created.add_listener(CircuitBreakerStateListener())
            cls._instances[host] = created
            logger.debug("Created circuit breaker for host: %s", host)
            return created
0 commit comments