# Coverage for src/taipanstack/resilience/circuit_breaker.py: 100%
# 254 statements
# coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
"""
Circuit Breaker pattern implementation.

Provides protection against cascading failures by temporarily
blocking calls to a failing service. Compatible with any
Python framework (sync and async).
"""

import functools
import inspect
import logging
import math
import threading
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import ParamSpec, Protocol, TypeVar, cast, overload

from taipanstack.core.result import Err

22P = ParamSpec("P")
23R = TypeVar("R")
class CircuitBreakerDecorator(Protocol):
    """Protocol for the circuit breaker decorator.

    Describes a callable that accepts either a synchronous or an
    awaitable-returning function and returns a callable with the same
    parameters and return type.
    """

    @overload
    def __call__(self, func: Callable[P, R]) -> Callable[P, R]: ...

    @overload
    def __call__(
        self, func: Callable[P, Awaitable[R]]
    ) -> Callable[P, Awaitable[R]]: ...
logger = logging.getLogger("taipanstack.resilience.circuit_breaker")

# structlog is optional: when the import fails, the module falls back to
# the stdlib logger above and _HAS_STRUCTLOG records that it is unavailable.
try:
    import structlog as _structlog

    _structlog_logger = _structlog.get_logger("taipanstack.resilience.circuit_breaker")
    _HAS_STRUCTLOG = True
except ImportError:
    _structlog_logger = None
    _HAS_STRUCTLOG = False
class CircuitState(Enum):
    """States of the circuit breaker."""

    CLOSED = "closed"  # Normal operation, requests flow through
    OPEN = "open"  # Circuit is tripped, requests are blocked
    HALF_OPEN = "half_open"  # Testing if service has recovered
class CircuitBreakerError(Exception):
    """Raised when circuit breaker is open.

    Attributes:
        state: The circuit state supplied by the raiser.

    """

    def __init__(self, message: str, state: CircuitState) -> None:
        """Initialize CircuitBreakerError.

        Args:
            message: Error description.
            state: Current circuit state.

        """
        self.state = state
        super().__init__(message)
@dataclass
class CircuitBreakerConfig:
    """Configuration for circuit breaker behavior.

    Attributes:
        failure_threshold: Number of failures before opening circuit.
        success_threshold: Successes needed in half-open to close.
        timeout: Seconds before trying half-open after open.
        excluded_exceptions: Exceptions that don't count as failures.
        failure_exceptions: Exceptions that count as failures.

    Raises:
        ValueError: If any of the three numeric fields is NaN or infinite.

    """

    failure_threshold: int = 5
    success_threshold: int = 2
    timeout: float = 30.0
    excluded_exceptions: tuple[type[Exception], ...] = ()
    failure_exceptions: tuple[type[Exception], ...] = (Exception,)

    def __post_init__(self) -> None:
        """Validate configuration values.

        Only finiteness is checked here; no lower bound is enforced
        on any field by this method.
        """
        if not math.isfinite(self.failure_threshold):
            raise ValueError("failure_threshold must be finite")
        if not math.isfinite(self.success_threshold):
            raise ValueError("success_threshold must be finite")
        if not math.isfinite(self.timeout):
            raise ValueError("timeout must be finite")
@dataclass
class CircuitBreakerState:
    """Internal state tracking for circuit breaker."""

    # Current position in the CLOSED / OPEN / HALF_OPEN state machine.
    state: CircuitState = CircuitState.CLOSED
    # Counters maintained by CircuitBreaker under `lock`.
    failure_count: int = 0
    success_count: int = 0
    half_open_attempts: int = 0
    # time.monotonic() reading taken when the last failure was recorded.
    last_failure_time: float = 0.0
    # One lock per state object; default_factory avoids sharing a lock
    # between instances.
    lock: threading.Lock = field(default_factory=threading.Lock)
114class CircuitBreaker:
115 """Circuit breaker implementation.
117 Monitors function calls and opens the circuit when too many
118 failures occur, preventing further calls until the service
119 recovers. Supports both sync and async functions.
121 Example:
122 >>> breaker = CircuitBreaker(failure_threshold=3)
123 >>> @breaker
124 ... def call_external_api():
125 ... return requests.get("https://api.example.com", timeout=10)
127 """
129 @staticmethod
130 def _check_finite_val(value: float, min_val: float, err_msg: str) -> None:
131 if not math.isfinite(value) or value < min_val:
132 raise ValueError(err_msg)
134 @staticmethod
135 def _validate_thresholds(
136 timeout: float, failure_threshold: int, success_threshold: int
137 ) -> None:
138 CircuitBreaker._check_finite_val(
139 timeout, 0, "timeout must be a finite non-negative number"
140 )
141 CircuitBreaker._check_finite_val(
142 failure_threshold, 1, "failure_threshold must be a finite number >= 1"
143 )
144 CircuitBreaker._check_finite_val(
145 success_threshold, 1, "success_threshold must be a finite number >= 1"
146 )
148 def __init__(
149 self,
150 *,
151 failure_threshold: int = 5,
152 success_threshold: int = 2,
153 timeout: float = 30.0,
154 excluded_exceptions: tuple[type[Exception], ...] = (),
155 failure_exceptions: tuple[type[Exception], ...] = (Exception,),
156 name: str = "default",
157 on_state_change: Callable[[CircuitState, CircuitState], None] | None = None,
158 ) -> None:
159 """Initialize CircuitBreaker.
161 Args:
162 failure_threshold: Failures before opening circuit.
163 success_threshold: Successes to close from half-open.
164 timeout: Seconds before attempting half-open.
165 excluded_exceptions: Exceptions that don't trip circuit.
166 failure_exceptions: Exceptions that count as failures.
167 name: Name for logging/identification.
168 on_state_change: Optional callback invoked on state transitions
169 with (old_state, new_state). Useful for custom monitoring.
171 """
172 CircuitBreaker._validate_thresholds(
173 timeout, failure_threshold, success_threshold
174 )
176 self.config = CircuitBreakerConfig(
177 failure_threshold=failure_threshold,
178 success_threshold=success_threshold,
179 timeout=timeout,
180 excluded_exceptions=excluded_exceptions,
181 failure_exceptions=failure_exceptions,
182 )
183 self.name = name
184 self._state = CircuitBreakerState()
185 self._on_state_change = on_state_change
187 @property
188 def state(self) -> CircuitState:
189 """Get current circuit state."""
190 return self._state.state
192 @property
193 def failure_count(self) -> int:
194 """Get current failure count."""
195 return self._state.failure_count
197 def _log_callback_failure(
198 self,
199 old_state: CircuitState,
200 new_state: CircuitState,
201 e: Exception,
202 ) -> None:
203 if _HAS_STRUCTLOG and _structlog_logger is not None:
204 _structlog_logger.error(
205 "circuit_state_change_callback_failed",
206 circuit=self.name,
207 old_state=old_state.value,
208 new_state=new_state.value,
209 error=str(e),
210 )
211 else:
212 logger.error(
213 "Circuit %s state change callback failed: %s",
214 self.name,
215 str(e),
216 )
218 def _notify_state_change(
219 self,
220 old_state: CircuitState,
221 new_state: CircuitState,
222 ) -> None:
223 """Notify callback of state transition if registered.
225 Emit a structured log via structlog when no callback is provided
226 and structlog is available.
227 """
228 if self._on_state_change is not None:
229 try:
230 self._on_state_change(old_state, new_state)
231 except Exception as e:
232 self._log_callback_failure(old_state, new_state, e)
233 elif _HAS_STRUCTLOG and _structlog_logger is not None: # pragma: no branch
234 _structlog_logger.warning(
235 "circuit_state_changed",
236 circuit=self.name,
237 old_state=old_state.value,
238 new_state=new_state.value,
239 failure_count=self._state.failure_count,
240 )
242 def _handle_open_state(
243 self,
244 ) -> tuple[bool, tuple[CircuitState, CircuitState] | None]:
245 """Handle logic for OPEN state in _should_attempt."""
246 now = time.monotonic()
247 try:
248 elapsed = now - self._state.last_failure_time
249 except TypeError:
250 # Type corruption detected (e.g. last_failure_time is string)
251 return False, None
253 # Safe check against NaN and Inf time corruption
254 # If elapsed < 0, a backward clock jump occurred. We should
255 # allow a transition to prevent permanent lockout.
256 if elapsed < 0:
257 elapsed = self.config.timeout
259 if math.isfinite(now) and elapsed >= self.config.timeout:
260 # Before transitioning, verify if we can make an attempt
261 # This happens in a lock, so it's thread-safe. However, once
262 # the state changes to HALF_OPEN, subsequent threads in the
263 # same lock block will hit the HALF_OPEN case.
264 self._state.state = CircuitState.HALF_OPEN
265 self._state.success_count = 0
266 # Initialize half_open_attempts to 1 because this first call
267 # that transitions the state is also an attempt.
268 self._state.half_open_attempts = 1
269 logger.info(
270 "Circuit %s entering half-open state (was open for %.1fs, failures=%d)",
271 self.name,
272 elapsed,
273 self._state.failure_count,
274 )
275 return True, (CircuitState.OPEN, CircuitState.HALF_OPEN)
276 return False, None
278 def _handle_attempt_half_open(self) -> bool:
279 try:
280 if not math.isfinite(self._state.half_open_attempts):
281 return False
282 except TypeError:
283 # Type corruption detected, deny attempt to be safe
284 return False
286 if self._state.half_open_attempts < self.config.success_threshold:
287 self._state.half_open_attempts += 1
288 return True
289 return False
291 def _should_attempt(self) -> bool:
292 """Check if a call should be attempted."""
293 state_change: tuple[CircuitState, CircuitState] | None = None
294 should_attempt = False
296 with self._state.lock:
297 match self._state.state:
298 case CircuitState.CLOSED:
299 should_attempt = True
300 case CircuitState.OPEN:
301 should_attempt, state_change = self._handle_open_state()
302 case CircuitState.HALF_OPEN:
303 should_attempt = self._handle_attempt_half_open()
305 if state_change:
306 self._notify_state_change(*state_change)
308 return should_attempt
310 def _handle_success_half_open(self) -> tuple[CircuitState, CircuitState] | None:
311 try:
312 if not math.isfinite(self._state.success_count):
313 self._state.success_count = 0
314 self._state.success_count += 1
315 except TypeError:
316 # Type corruption detected, reset and increment
317 self._state.success_count = 1
319 if self._state.success_count >= self.config.success_threshold:
320 self._state.state = CircuitState.CLOSED
321 self._state.failure_count = 0
322 self._state.half_open_attempts = 0
323 logger.info(
324 "Circuit %s closed after recovery (%d consecutive successes)",
325 self.name,
326 self._state.success_count,
327 )
328 return (CircuitState.HALF_OPEN, CircuitState.CLOSED)
329 return None
331 def _record_success(self) -> None:
332 """Record a successful call."""
333 state_change: tuple[CircuitState, CircuitState] | None = None
335 with self._state.lock:
336 match self._state.state:
337 case CircuitState.HALF_OPEN:
338 state_change = self._handle_success_half_open()
339 case CircuitState.CLOSED:
340 # Reset failure count on success
341 self._state.failure_count = 0
342 case CircuitState.OPEN: # pragma: no branch
343 pass # Should not happen, but handle gracefully
345 if state_change:
346 self._notify_state_change(*state_change)
348 def _handle_failure_half_open(self) -> tuple[CircuitState, CircuitState] | None:
349 """Handle failure when in HALF_OPEN state."""
350 self._state.state = CircuitState.OPEN
351 self._state.half_open_attempts = 0
352 logger.warning(
353 "Circuit %s reopened after failure in half-open",
354 self.name,
355 )
356 return (CircuitState.HALF_OPEN, CircuitState.OPEN)
358 def _handle_failure_closed(self) -> tuple[CircuitState, CircuitState] | None:
359 """Handle failure when in CLOSED state."""
360 # Check against corrupted NaN/Inf failure_count
361 try:
362 if not math.isfinite(self._state.failure_count):
363 self._state.state = CircuitState.OPEN
364 logger.warning(
365 "Circuit %s opened due to state corruption (NaN/Inf failures)",
366 self.name,
367 )
368 return (CircuitState.CLOSED, CircuitState.OPEN)
369 except TypeError:
370 self._state.state = CircuitState.OPEN
371 logger.warning(
372 "Circuit %s opened due to type corruption in failure_count",
373 self.name,
374 )
375 return (CircuitState.CLOSED, CircuitState.OPEN)
377 if self._state.failure_count >= self.config.failure_threshold:
378 self._state.state = CircuitState.OPEN
379 logger.warning(
380 "Circuit %s opened after %d failures (threshold=%d)",
381 self.name,
382 self._state.failure_count,
383 self.config.failure_threshold,
384 )
385 return (CircuitState.CLOSED, CircuitState.OPEN)
387 return None
389 def _update_failure_metrics(self) -> None:
390 try:
391 if math.isfinite(self._state.failure_count):
392 self._state.failure_count += 1
393 except TypeError:
394 # Handle type mutation (e.g. failure_count became string)
395 # Safe degradation: reset to max so it opens immediately
396 self._state.failure_count = self.config.failure_threshold
398 now = time.monotonic()
399 if math.isfinite(now):
400 self._state.last_failure_time = now
402 def _record_failure(self, exc: Exception) -> None:
403 """Record a failed call."""
404 # Check if exception should be excluded
405 if isinstance(exc, self.config.excluded_exceptions):
406 return
408 state_change: tuple[CircuitState, CircuitState] | None = None
410 with self._state.lock:
411 self._update_failure_metrics()
413 match self._state.state:
414 case CircuitState.HALF_OPEN:
415 state_change = self._handle_failure_half_open()
416 case CircuitState.CLOSED:
417 state_change = self._handle_failure_closed()
418 case CircuitState.OPEN: # pragma: no branch
419 pass # Already open, nothing to do
421 if state_change:
422 self._notify_state_change(*state_change)
424 def reset(self) -> None:
425 """Reset circuit breaker to closed state."""
426 with self._state.lock:
427 self._state.state = CircuitState.CLOSED
428 self._state.failure_count = 0
429 self._state.success_count = 0
430 self._state.half_open_attempts = 0
431 logger.info("Circuit %s manually reset", self.name)
433 def _process_result(self, result: R) -> R:
434 """Process Result outcome and record success/failure.
436 Args:
437 result: The result to process.
439 Returns:
440 The original result.
442 """
443 if isinstance(result, Err):
444 err_val = result.unwrap_err()
445 if isinstance(err_val, self.config.failure_exceptions):
446 self._record_failure(err_val)
447 return result
448 # Ignored exception in Result monad
449 return result
450 self._record_success()
451 return result
453 def _decrement_half_open(self, is_half_open: bool) -> None:
454 """Decrement half-open attempt count if applicable.
456 Args:
457 is_half_open: Whether the circuit was half-open before attempt.
459 """
460 if is_half_open:
461 with self._state.lock:
462 try:
463 if (
464 self._state.state == CircuitState.HALF_OPEN
465 and math.isfinite(self._state.half_open_attempts)
466 and self._state.half_open_attempts > 0
467 ):
468 self._state.half_open_attempts -= 1
469 except TypeError:
470 # Reset if state is corrupted to prevent crash
471 self._state.half_open_attempts = 0
473 def __call__(
474 self, func: Callable[P, R] | Callable[P, Awaitable[R]]
475 ) -> Callable[P, R] | Callable[P, Awaitable[R]]:
476 """Decorate a sync or async function with circuit breaker protection."""
477 if inspect.iscoroutinefunction(func):
478 func_coro = cast(Callable[P, Awaitable[R]], func)
480 @functools.wraps(func_coro)
481 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
482 if not self._should_attempt():
483 raise CircuitBreakerError(
484 f"Circuit {self.name} is open",
485 state=self._state.state,
486 )
488 is_half_open = self._state.state == CircuitState.HALF_OPEN
490 try:
491 result = await func_coro(*args, **kwargs)
492 return self._process_result(result)
493 except self.config.failure_exceptions as e:
494 self._record_failure(e)
495 raise
496 finally:
497 self._decrement_half_open(is_half_open)
499 return async_wrapper
501 func_sync = cast(Callable[P, R], func)
503 @functools.wraps(func_sync)
504 def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
505 if not self._should_attempt():
506 raise CircuitBreakerError(
507 f"Circuit {self.name} is open",
508 state=self._state.state,
509 )
511 is_half_open = self._state.state == CircuitState.HALF_OPEN
513 try:
514 result = func_sync(*args, **kwargs)
515 return self._process_result(result)
516 except self.config.failure_exceptions as e:
517 self._record_failure(e)
518 raise
519 finally:
520 self._decrement_half_open(is_half_open)
522 return wrapper
def circuit_breaker(
    *,
    failure_threshold: int = 5,
    success_threshold: int = 2,
    timeout: float = 30.0,
    excluded_exceptions: tuple[type[Exception], ...] = (),
    failure_exceptions: tuple[type[Exception], ...] = (Exception,),
    name: str | None = None,
    on_state_change: Callable[[CircuitState, CircuitState], None] | None = None,
) -> CircuitBreakerDecorator:
    """Decorate a sync or async function with circuit breaker pattern.

    Args:
        failure_threshold: Failures before opening circuit.
        success_threshold: Successes to close from half-open.
        timeout: Seconds before attempting half-open.
        excluded_exceptions: Exceptions that don't trip circuit.
        failure_exceptions: Exceptions that count as failures.
        name: Optional name for the circuit.
        on_state_change: Optional callback invoked on state transitions
            with (old_state, new_state).

    Returns:
        Decorated function with circuit breaker protection.

    Example:
        >>> @circuit_breaker(failure_threshold=3, timeout=60)
        ... def call_api(endpoint: str) -> dict:
        ...     return requests.get(endpoint, timeout=10).json()

        >>> @circuit_breaker(
        ...     failure_threshold=3,
        ...     on_state_change=lambda old, new: print(f"{old} -> {new}"),
        ... )
        ... def monitored_call() -> str:
        ...     return service.call()

    """

    def decorator(
        func: Callable[P, R] | Callable[P, Awaitable[R]],
    ) -> Callable[P, R] | Callable[P, Awaitable[R]]:
        # A new CircuitBreaker is built each time the decorator is applied;
        # when no name is given the wrapped function's __name__ is used.
        breaker = CircuitBreaker(
            failure_threshold=failure_threshold,
            success_threshold=success_threshold,
            timeout=timeout,
            excluded_exceptions=excluded_exceptions,
            failure_exceptions=failure_exceptions,
            name=name or func.__name__,
            on_state_change=on_state_change,
        )
        return breaker(func)

    return cast(CircuitBreakerDecorator, decorator)