Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aws_lambda_powertools/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,6 @@

# Idempotency constants
IDEMPOTENCY_DISABLED_ENV: str = "POWERTOOLS_IDEMPOTENCY_DISABLED"

# Circuit breaker constants
CIRCUIT_BREAKER_DISABLED_ENV: str = "POWERTOOLS_CIRCUIT_BREAKER_DISABLED"
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Circuit Breaker utility for protecting unhealthy downstream dependencies.

!!! warning "Alpha / experimental"
This utility is published under the `_alpha` namespace while we collect
feedback. The public API may change in a backwards-incompatible way before it
is promoted to GA. Pin your version and follow the tracking discussion before
relying on it in production.
"""

from aws_lambda_powertools.utilities.circuit_breaker_alpha.circuit_breaker import circuit_breaker
from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.utilities.circuit_breaker_alpha.exceptions import (
CircuitBreakerConfigError,
CircuitBreakerError,
CircuitBreakerOpenError,
CircuitBreakerPersistenceError,
)
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import (
CircuitInfo,
CircuitState,
CircuitTransition,
)

__all__ = (
"circuit_breaker",
"CircuitBreakerConfig",
"CircuitInfo",
"CircuitState",
"CircuitTransition",
"CircuitBreakerError",
"CircuitBreakerOpenError",
"CircuitBreakerConfigError",
"CircuitBreakerPersistenceError",
)
248 changes: 248 additions & 0 deletions aws_lambda_powertools/utilities/circuit_breaker_alpha/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
"""
Orchestrator for the Circuit Breaker utility.

:class:`CircuitBreakerHandler` owns the state machine and the per-environment failure
counter; the persistence layer owns the shared truth. This split keeps the healthy
path write-free: failures are counted locally and only persisted on a state transition.
"""

from __future__ import annotations

import datetime
import logging
import uuid
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.utilities.circuit_breaker_alpha.exceptions import CircuitBreakerOpenError
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import CircuitState, CircuitTransition

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence.base import (
CircuitBreakerPersistenceLayer,
)
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import CircuitInfo

logger = logging.getLogger(__name__)

# Per-environment, per-circuit consecutive counters. Module-level so they survive across
# invocations within the same execution environment, the same way idempotency caches do.
_LOCAL_FAILURES: dict[str, int] = {}
_LOCAL_SUCCESSES: dict[str, int] = {}

# Tracks the last state this environment observed from the store, per circuit. Used to
# detect transitions back to CLOSED that happened externally (another env tripped and
# recovered), so stale local failure streaks can be invalidated.
_LAST_OBSERVED_STATE: dict[str, CircuitState] = {}

# Stable per-environment identifier used to claim the half-open probe lock.
_ENVIRONMENT_ID = uuid.uuid4().hex


class CircuitBreakerHandler:
"""
Drive a single protected call through the circuit breaker state machine.

A new handler is created per invocation by the decorator. It reads the shared state,
routes the call (run, short-circuit, or probe), and records the outcome.

Parameters
----------
function : Callable
The protected function.
name : str
Circuit name.
config : CircuitBreakerConfig
Circuit configuration.
persistence_store : CircuitBreakerPersistenceLayer
Shared state store.
on_circuit_open : Callable | None
Callback invoked with the protected call's own ``*args``/``**kwargs`` plus a
trailing ``circuit`` keyword argument when the circuit is open. If ``None``, an
open circuit raises :class:`CircuitBreakerOpenError`.
function_args : tuple
Positional arguments the protected function was called with.
function_kwargs : dict
Keyword arguments the protected function was called with.
"""

def __init__(
self,
function: Callable,
name: str,
config: CircuitBreakerConfig,
persistence_store: CircuitBreakerPersistenceLayer,
on_circuit_open: Callable | None = None,
on_transition: Callable | None = None,
function_args: tuple | None = None,
function_kwargs: dict | None = None,
):
self.function = function
self.name = name
self.config = config
self.on_circuit_open = on_circuit_open
self.on_transition = on_transition
self.fn_args = function_args or ()
self.fn_kwargs = function_kwargs or {}

persistence_store.configure(config=config, circuit_name=name)
self.persistence_store = persistence_store

def handle(self) -> Any:
"""
Evaluate the circuit and route the call.

Returns
-------
Any
The protected function's result when the call runs, or the
``on_circuit_open`` callback's return value when the circuit is open.

Raises
------
CircuitBreakerOpenError
If the circuit is open and no callback is registered.
"""
record = self.persistence_store.get_state(self.name)

if record.state == CircuitState.CLOSED:
# If we previously observed a non-CLOSED state and the circuit is now back to
# CLOSED, another environment completed the recovery cycle. Reset local counters
# so a stale partial failure streak doesn't immediately re-trip the circuit.
prev = _LAST_OBSERVED_STATE.get(self.name)
if prev is not None and prev != CircuitState.CLOSED:
_LOCAL_FAILURES[self.name] = 0
_LAST_OBSERVED_STATE[self.name] = CircuitState.CLOSED
return self._call_closed()

if record.state == CircuitState.OPEN:
_LAST_OBSERVED_STATE[self.name] = CircuitState.OPEN
# ``opened_at`` may legitimately be 0 (epoch); treat only None as missing.
opened_at = record.opened_at if record.opened_at is not None else self._now()
if self._now() >= opened_at + self.config.recovery_timeout:
# Recovery window elapsed: try to become the single prober.
if self.persistence_store.try_acquire_half_open(self.name, _ENVIRONMENT_ID, opened_at):
self._notify(CircuitState.OPEN, CircuitState.HALF_OPEN, opened_at=opened_at)
return self._call_probe()
return self._open_response(record.to_circuit_info())

# HALF_OPEN: only the environment that owns the probe lock runs.
_LAST_OBSERVED_STATE[self.name] = CircuitState.HALF_OPEN
if record.half_open_owner == _ENVIRONMENT_ID:
return self._call_probe()

# If the probe lease has expired (owner recycled mid-probe), take over.
if record.probe_lease_expiry is not None and self._now() >= record.probe_lease_expiry:
logger.debug("Circuit '%s' probe lease expired; attempting takeover.", self.name)
if self.persistence_store.try_acquire_half_open(self.name, _ENVIRONMENT_ID, record.opened_at or 0):
return self._call_probe()

return self._open_response(record.to_circuit_info())

def _call_closed(self) -> Any:
"""Run the protected call while the circuit is closed, tracking failures."""
try:
result = self.function(*self.fn_args, **self.fn_kwargs)
except Exception as exc:
if not self.config.counts_as_failure(exc):
raise
failures = _LOCAL_FAILURES.get(self.name, 0) + 1
_LOCAL_FAILURES[self.name] = failures
if failures >= self.config.failure_threshold:
logger.debug("Circuit '%s' tripping CLOSED to OPEN after %d failures.", self.name, failures)
opened_at = self._now()
self._safe_persist(
self.persistence_store.save_open,
self.name,
failure_count=failures,
opened_at=opened_at,
)
_LOCAL_FAILURES[self.name] = 0
self._notify(CircuitState.CLOSED, CircuitState.OPEN, opened_at=opened_at)
raise
else:
_LOCAL_FAILURES[self.name] = 0
return result

def _call_probe(self) -> Any:
"""Run a probe during half-open, closing or reopening based on the outcome."""
try:
result = self.function(*self.fn_args, **self.fn_kwargs)
except Exception as exc:
if not self.config.counts_as_failure(exc):
raise
logger.debug("Circuit '%s' probe failed; reopening.", self.name)
opened_at = self._now()
self._safe_persist(self.persistence_store.save_reopen, self.name, opened_at=opened_at)
_LOCAL_SUCCESSES[self.name] = 0
self._notify(CircuitState.HALF_OPEN, CircuitState.OPEN, opened_at=opened_at)
raise
else:
successes = _LOCAL_SUCCESSES.get(self.name, 0) + 1
_LOCAL_SUCCESSES[self.name] = successes
if successes >= self.config.success_threshold:
logger.debug("Circuit '%s' closing after %d probe successes.", self.name, successes)
self._safe_persist(self.persistence_store.save_closed, self.name)
_LOCAL_SUCCESSES[self.name] = 0
_LOCAL_FAILURES[self.name] = 0
self._notify(CircuitState.HALF_OPEN, CircuitState.CLOSED)
return result

def _safe_persist(self, fn: Callable, *args: Any, **kwargs: Any) -> None:
"""
Call a persistence write, swallowing and logging failures.

State-transition writes must never mask the downstream's real result or replace
the downstream's real exception. This mirrors the fail-open read policy in the
persistence layer.
"""
try:
fn(*args, **kwargs)
except Exception:
logger.warning(
"Circuit '%s': persistence write (%s) failed; the transition may be delayed but the "
"downstream result is preserved.",
self.name,
getattr(fn, "__name__", repr(fn)),
exc_info=True,
)

def _open_response(self, circuit: CircuitInfo) -> Any:
"""Produce the response for an open circuit: callback result or raise."""
if self.on_circuit_open is not None:
# Forward the protected call's arguments unchanged: positional stay positional,
# keyword stay keyword. The circuit snapshot is passed as a keyword argument so
# it never collides with positionalized kwargs nor depends on dict ordering.
return self.on_circuit_open(*self.fn_args, **self.fn_kwargs, circuit=circuit)
raise CircuitBreakerOpenError(
f"Circuit '{self.name}' is open.",
circuit=circuit,
)

def _notify(self, from_state: CircuitState, to_state: CircuitState, opened_at: int | None = None) -> None:
"""
Fire the ``on_transition`` hook for a state change.

Called only on real transitions, never on the hot path. Any exception the hook
raises is swallowed and logged: observability must never break the protected call.
"""
if self.on_transition is None:
return
try:
self.on_transition(
CircuitTransition(
circuit_name=self.name,
from_state=from_state,
to_state=to_state,
opened_at=opened_at,
),
)
except Exception:
logger.warning("on_transition hook for circuit '%s' raised; ignoring.", self.name, exc_info=True)

@staticmethod
def _now() -> int:
"""Current unix timestamp in seconds."""
return int(datetime.datetime.now().timestamp())
Loading