"""Circuit breaker pattern for fault tolerance."""
from __future__ import annotations
import functools
import inspect
import threading
import time
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
from kstlib.limits import (
HARD_MAX_CIRCUIT_FAILURES,
HARD_MAX_CIRCUIT_RESET_TIMEOUT,
HARD_MAX_HALF_OPEN_CALLS,
HARD_MIN_CIRCUIT_FAILURES,
HARD_MIN_CIRCUIT_RESET_TIMEOUT,
HARD_MIN_HALF_OPEN_CALLS,
clamp_with_limits,
get_resilience_limits,
)
from kstlib.resilience.exceptions import CircuitOpenError
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
P = ParamSpec("P")
R = TypeVar("R")
class CircuitState(Enum):
    """State of the circuit breaker.

    States:
        CLOSED: Normal operation, requests pass through.
        OPEN: Circuit tripped, requests fail immediately.
        HALF_OPEN: Testing if service recovered.

    Examples:
        >>> CircuitState.CLOSED
        <CircuitState.CLOSED: 1>
        >>> [state.name for state in CircuitState]
        ['CLOSED', 'OPEN', 'HALF_OPEN']
    """

    # Values come from ``auto()``, i.e. 1, 2, 3 in declaration order.
    CLOSED = auto()
    OPEN = auto()
    HALF_OPEN = auto()
@dataclass
class CircuitStats:
    """Statistics for circuit breaker monitoring.

    Each of ``record_success``, ``record_failure`` and ``record_rejection``
    increments its own counter and ``total_calls``. ``record_state_change``
    only increments ``state_changes``.

    Attributes:
        total_calls: Total number of calls attempted.
        successful_calls: Number of successful calls.
        failed_calls: Number of failed calls.
        rejected_calls: Number of calls rejected due to open circuit.
        state_changes: Number of state transitions.

    Examples:
        >>> stats = CircuitStats()
        >>> stats.record_success()
        >>> stats.record_failure()
        >>> stats.record_rejection()
        >>> (stats.successful_calls, stats.failed_calls, stats.rejected_calls)
        (1, 1, 1)
        >>> stats.total_calls
        3
    """

    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    rejected_calls: int = 0
    state_changes: int = 0

    def record_success(self) -> None:
        """Record a successful call."""
        self.total_calls += 1
        self.successful_calls += 1

    def record_failure(self) -> None:
        """Record a failed call."""
        self.total_calls += 1
        self.failed_calls += 1

    def record_rejection(self) -> None:
        """Record a rejected call (circuit open)."""
        self.total_calls += 1
        self.rejected_calls += 1

    def record_state_change(self) -> None:
        """Record a state transition (does not count as a call)."""
        self.state_changes += 1
class CircuitBreaker:
    """Circuit breaker for protecting against cascading failures.

    Implements the circuit breaker pattern to prevent repeated calls
    to a failing service and allow recovery time.

    Transitions:
        CLOSED -> OPEN: the failure count reaches ``max_failures`` (the
            count is reset by every success recorded while closed).
        OPEN -> HALF_OPEN: evaluated lazily the next time the state is
            checked, once ``reset_timeout`` seconds have elapsed since the
            last recorded failure.
        HALF_OPEN -> CLOSED: ``half_open_max_calls`` calls have succeeded.
        HALF_OPEN -> OPEN: any counted failure.

    Args:
        max_failures: Failures before opening circuit (default from config).
        reset_timeout: Seconds before attempting recovery (default from config).
        half_open_max_calls: Calls allowed in half-open state (default from config).
        excluded_exceptions: Exceptions that don't count as failures.
        name: Optional name for the circuit breaker.

    Examples:
        As a decorator:

        >>> @circuit_breaker
        ... def call_api():  # doctest: +SKIP
        ...     return requests.get("http://api.example.com")

        With custom settings:

        >>> @circuit_breaker(max_failures=3, reset_timeout=30)
        ... def risky_call():  # doctest: +SKIP
        ...     pass

        Direct instantiation:

        >>> cb = CircuitBreaker(max_failures=5)
        >>> cb.state
        <CircuitState.CLOSED: 1>
    """

    def __init__(
        self,
        *,
        max_failures: int | None = None,
        reset_timeout: float | None = None,
        half_open_max_calls: int | None = None,
        excluded_exceptions: tuple[type[Exception], ...] = (),
        name: str | None = None,
    ) -> None:
        """Initialize circuit breaker.

        Numeric settings given explicitly are passed through
        ``clamp_with_limits`` with the matching ``HARD_MIN_*`` /
        ``HARD_MAX_*`` bounds; ``None`` selects the corresponding value
        from ``get_resilience_limits()``.

        Args:
            max_failures: Failures before opening circuit. Uses config if None.
            reset_timeout: Seconds before attempting recovery. Uses config if None.
            half_open_max_calls: Calls allowed in half-open state. Uses config if None.
            excluded_exceptions: Exceptions that don't count as failures.
            name: Optional name for the circuit breaker.
        """
        limits = get_resilience_limits()

        # Max failures (use config default or clamp provided value)
        self._max_failures = (
            limits.circuit_max_failures
            if max_failures is None
            else int(
                clamp_with_limits(
                    max_failures,
                    HARD_MIN_CIRCUIT_FAILURES,
                    HARD_MAX_CIRCUIT_FAILURES,
                )
            )
        )

        # Reset timeout (use config default or clamp provided value)
        self._reset_timeout = (
            limits.circuit_reset_timeout
            if reset_timeout is None
            else clamp_with_limits(
                reset_timeout,
                HARD_MIN_CIRCUIT_RESET_TIMEOUT,
                HARD_MAX_CIRCUIT_RESET_TIMEOUT,
            )
        )

        # Half-open max calls (use config default or clamp provided value)
        self._half_open_max_calls = (
            limits.circuit_half_open_calls
            if half_open_max_calls is None
            else int(
                clamp_with_limits(
                    half_open_max_calls,
                    HARD_MIN_HALF_OPEN_CALLS,
                    HARD_MAX_HALF_OPEN_CALLS,
                )
            )
        )

        self._excluded_exceptions = excluded_exceptions
        self._name = name

        # State
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        # time.monotonic() timestamp of the most recent counted failure.
        self._last_failure_time: float | None = None
        # Successful calls seen since the breaker last entered HALF_OPEN.
        self._half_open_calls = 0

        # Thread safety
        self._lock = threading.Lock()

        # Statistics
        self._stats = CircuitStats()

    @property
    def state(self) -> CircuitState:
        """Return the current circuit state.

        Reading the state first applies any pending OPEN -> HALF_OPEN
        transition, so the access itself may change the state.
        """
        with self._lock:
            self._check_state_transition()
            return self._state

    @property
    def name(self) -> str | None:
        """Return the circuit breaker name."""
        return self._name

    @property
    def stats(self) -> CircuitStats:
        """Return circuit breaker statistics (the live object, not a copy)."""
        return self._stats

    @property
    def failure_count(self) -> int:
        """Return current failure count."""
        return self._failure_count

    def _check_state_transition(self) -> None:
        """Move from OPEN to HALF_OPEN once the reset timeout has elapsed.

        Does not acquire ``self._lock`` itself; the callers in this class
        invoke it while already holding the lock.
        """
        if self._state == CircuitState.OPEN and self._last_failure_time is not None:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self._reset_timeout:
                self._state = CircuitState.HALF_OPEN
                self._half_open_calls = 0
                self._stats.record_state_change()

    def _record_success(self) -> None:
        """Record a successful call and update state."""
        with self._lock:
            self._stats.record_success()
            if self._state == CircuitState.HALF_OPEN:
                self._half_open_calls += 1
                if self._half_open_calls >= self._half_open_max_calls:
                    # Recovery successful, close circuit
                    self._state = CircuitState.CLOSED
                    self._failure_count = 0
                    self._last_failure_time = None
                    self._stats.record_state_change()
            elif self._state == CircuitState.CLOSED:
                # Reset failure count on success
                self._failure_count = 0

    def _record_failure(self, exc: Exception) -> None:
        """Record a failed call and update state.

        Args:
            exc: The exception raised by the wrapped call. Instances of
                ``excluded_exceptions`` are ignored entirely: neither the
                statistics nor the breaker state are touched.
        """
        # Check if exception is excluded
        if isinstance(exc, self._excluded_exceptions):
            return

        with self._lock:
            self._stats.record_failure()
            # Count and timestamp are updated in every state, including OPEN.
            self._failure_count += 1
            self._last_failure_time = time.monotonic()
            if self._state == CircuitState.HALF_OPEN:
                # Failed during recovery, reopen circuit
                self._state = CircuitState.OPEN
                self._stats.record_state_change()
            elif self._state == CircuitState.CLOSED:
                if self._failure_count >= self._max_failures:
                    self._state = CircuitState.OPEN
                    self._stats.record_state_change()

    def _check_open(self) -> None:
        """Check if circuit is open and raise if so.

        Applies any pending OPEN -> HALF_OPEN transition first, so a call
        arriving after the reset timeout is let through.

        Raises:
            CircuitOpenError: If the circuit is still open. The rejection
                is counted in the statistics and ``remaining_seconds``
                carries the time left until the reset timeout expires.
        """
        with self._lock:
            self._check_state_transition()
            if self._state == CircuitState.OPEN:
                remaining = 0.0
                if self._last_failure_time is not None:
                    elapsed = time.monotonic() - self._last_failure_time
                    remaining = max(0.0, self._reset_timeout - elapsed)
                self._stats.record_rejection()
                raise CircuitOpenError(
                    f"Circuit breaker '{self._name or 'unnamed'}' is open",
                    remaining_seconds=remaining,
                )

    def call(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
        """Execute a function through the circuit breaker.

        Args:
            func: Function to execute.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Returns:
            Function result.

        Raises:
            CircuitOpenError: If circuit is open.
            Exception: Whatever ``func`` raises is recorded and re-raised.

        Examples:
            >>> cb = CircuitBreaker()
            >>> result = cb.call(lambda x: x * 2, 5)
            >>> result
            10
        """
        self._check_open()
        try:
            result = func(*args, **kwargs)
        except Exception as exc:
            self._record_failure(exc)
            raise
        else:
            # Bookkeeping sits outside the try so that only errors from the
            # wrapped call itself are treated as call failures.
            self._record_success()
            return result

    async def acall(
        self,
        func: Callable[P, Awaitable[R]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> R:
        """Execute an async function through the circuit breaker.

        Args:
            func: Async function to execute.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Returns:
            Function result.

        Raises:
            CircuitOpenError: If circuit is open.
            Exception: Whatever ``func`` raises is recorded and re-raised.

        Examples:
            >>> import asyncio
            >>> cb = CircuitBreaker()
            >>> async def double(x): return x * 2
            >>> asyncio.run(cb.acall(double, 5))
            10
        """
        self._check_open()
        try:
            result = await func(*args, **kwargs)
        except Exception as exc:
            self._record_failure(exc)
            raise
        else:
            # Bookkeeping sits outside the try so that only errors from the
            # awaited call itself are treated as call failures.
            self._record_success()
            return result

    def reset(self) -> None:
        """Manually reset the circuit breaker to closed state.

        Clears the failure count, the last-failure timestamp and the
        half-open counter. The statistics are left untouched; in
        particular no state change is recorded for the reset.

        Examples:
            >>> cb = CircuitBreaker(max_failures=1)
            >>> try:
            ...     cb.call(lambda: 1/0)
            ... except ZeroDivisionError:
            ...     pass
            >>> cb.state.name
            'OPEN'
            >>> cb.reset()
            >>> cb.state.name
            'CLOSED'
        """
        with self._lock:
            self._state = CircuitState.CLOSED
            self._failure_count = 0
            self._last_failure_time = None
            self._half_open_calls = 0

    def __call__(self, func: Callable[P, R]) -> Callable[P, R] | Callable[P, Awaitable[R]]:
        """Use circuit breaker as a decorator.

        Coroutine functions get an ``async`` wrapper that goes through
        ``acall``; everything else gets a plain wrapper around ``call``.

        Args:
            func: Function to wrap.

        Returns:
            Wrapped function with circuit breaker protection.
        """
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                return await self.acall(func, *args, **kwargs)

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            return self.call(func, *args, **kwargs)

        return sync_wrapper
# Decorator factory
@overload
def circuit_breaker(func: Callable[P, R]) -> Callable[P, R]: ...


@overload
def circuit_breaker(
    *,
    max_failures: int | None = None,
    reset_timeout: float | None = None,
    half_open_max_calls: int | None = None,
    excluded_exceptions: tuple[type[Exception], ...] = (),
    name: str | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...


def circuit_breaker(
    func: Callable[P, R] | None = None,
    *,
    max_failures: int | None = None,
    reset_timeout: float | None = None,
    half_open_max_calls: int | None = None,
    excluded_exceptions: tuple[type[Exception], ...] = (),
    name: str | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
    """Circuit breaker decorator for functions.

    Can be used with or without arguments. Every invocation builds a new
    ``CircuitBreaker`` from the keyword arguments.

    Args:
        func: The function to wrap when used as a bare ``@circuit_breaker``;
            ``None`` when called with keyword arguments.
        max_failures: Forwarded to ``CircuitBreaker``.
        reset_timeout: Forwarded to ``CircuitBreaker``.
        half_open_max_calls: Forwarded to ``CircuitBreaker``.
        excluded_exceptions: Forwarded to ``CircuitBreaker``.
        name: Forwarded to ``CircuitBreaker``.

    Returns:
        The wrapped function when ``func`` is given; otherwise the new
        ``CircuitBreaker`` instance, which is itself usable as a decorator.

    Examples:
        Without arguments (uses config defaults):

        >>> @circuit_breaker
        ... def api_call():  # doctest: +SKIP
        ...     pass

        With arguments:

        >>> @circuit_breaker(max_failures=3, reset_timeout=30)
        ... def api_call():  # doctest: +SKIP
        ...     pass

        Exclude specific exceptions:

        >>> @circuit_breaker(excluded_exceptions=(ValueError,))
        ... def validate():  # doctest: +SKIP
        ...     pass
    """
    cb = CircuitBreaker(
        max_failures=max_failures,
        reset_timeout=reset_timeout,
        half_open_max_calls=half_open_max_calls,
        excluded_exceptions=excluded_exceptions,
        name=name,
    )

    if func is not None:
        # @circuit_breaker without parentheses
        return cb(func)  # type: ignore[return-value]

    # @circuit_breaker(...) with arguments
    return cb  # type: ignore[return-value]
# Public names of this module (what ``import *`` exposes).
__all__ = ["CircuitBreaker", "CircuitState", "CircuitStats", "circuit_breaker"]