Coverage for src / taipanstack / utils / circuit_breaker.py: 100%
151 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
1"""
2Circuit Breaker pattern implementation.
4Provides protection against cascading failures by temporarily
5blocking calls to a failing service. Compatible with any
6Python framework (sync and async).
7"""
9import functools
10import inspect
11import logging
12import threading
13import time
14from collections.abc import Callable, Coroutine
15from dataclasses import dataclass, field
16from enum import Enum
17from typing import Any, ParamSpec, Protocol, TypeVar, cast, overload
19P = ParamSpec("P")
20R = TypeVar("R")
class CircuitBreakerDecorator(Protocol):
    """Structural type of a circuit breaker decorator.

    Both overloads map a callable to a callable with the same
    parameters and result: the first covers plain callables, the
    second covers callables that return a coroutine.
    """

    @overload
    def __call__(
        self,
        func: Callable[P, R],
    ) -> Callable[P, R]: ...  # pragma: no cover

    @overload
    def __call__(
        self,
        func: Callable[P, Coroutine[Any, Any, R]],
    ) -> Callable[P, Coroutine[Any, Any, R]]: ...  # pragma: no cover
# Standard-library logger used for the human-readable transition messages.
logger = logging.getLogger("taipanstack.utils.circuit_breaker")

# structlog is an optional dependency. When it cannot be imported the
# structured logger is None and the availability flag is False.
try:
    import structlog as _structlog

    _structlog_logger = _structlog.get_logger("taipanstack.utils.circuit_breaker")
    _HAS_STRUCTLOG = True
except ImportError:  # pragma: no cover — structlog is optional
    _structlog_logger = None
    _HAS_STRUCTLOG = False
class CircuitState(Enum):
    """The three states a circuit breaker can be in."""

    # Normal operation: requests flow through.
    CLOSED = "closed"
    # Circuit is tripped: requests are blocked.
    OPEN = "open"
    # Trial phase: testing whether the service has recovered.
    HALF_OPEN = "half_open"
class CircuitBreakerError(Exception):
    """Raised when circuit breaker is open.

    Attributes:
        state: Circuit state passed in by the code raising the error.
    """

    def __init__(self, message: str, state: CircuitState) -> None:
        """Create the error with a message and the circuit state.

        Args:
            message: Error description, used as the exception message.
            state: Current circuit state, stored on ``self.state``.

        """
        super().__init__(message)
        self.state = state
@dataclass
class CircuitBreakerConfig:
    """Tunable settings that control circuit breaker behavior.

    Attributes:
        failure_threshold: Number of failures before opening circuit.
        success_threshold: Successes needed in half-open to close.
        timeout: Seconds before trying half-open after open.
        excluded_exceptions: Exceptions that don't count as failures.
        failure_exceptions: Exceptions that count as failures.

    """

    # Thresholds for opening (failures) and re-closing (successes).
    failure_threshold: int = 5
    success_threshold: int = 2

    # Time, in seconds, spent open before a half-open trial.
    timeout: float = 30.0

    # Exception classification; by default every Exception is a failure.
    excluded_exceptions: tuple[type[Exception], ...] = ()
    failure_exceptions: tuple[type[Exception], ...] = (Exception,)
@dataclass
class CircuitBreakerState:
    """Mutable bookkeeping record owned by a circuit breaker."""

    # Current position in the state machine; starts closed.
    state: CircuitState = CircuitState.CLOSED

    # Counters for failures, half-open successes and half-open attempts.
    failure_count: int = 0
    success_count: int = 0
    half_open_attempts: int = 0

    # time.monotonic() reading taken when the latest failure was recorded.
    last_failure_time: float = 0.0

    # One lock per instance, guarding updates to the fields above.
    lock: threading.Lock = field(default_factory=threading.Lock)
class CircuitBreaker:
    """Circuit breaker implementation.

    Monitors function calls and opens the circuit when too many
    failures occur, preventing further calls until the service
    recovers. Supports both sync and async functions.

    State machine:

    * CLOSED: every call is attempted. Each counted failure increments
      the failure counter and each success resets it; reaching
      ``failure_threshold`` opens the circuit.
    * OPEN: calls are rejected with ``CircuitBreakerError`` until
      ``timeout`` seconds have passed since the last recorded failure.
      The first call after that moves the circuit to HALF_OPEN and is
      attempted.
    * HALF_OPEN: at most ``success_threshold`` probe calls are admitted.
      ``success_threshold`` successes close the circuit; any counted
      failure reopens it. A probe that ends without a recorded outcome
      (excluded exception, exception outside ``failure_exceptions``,
      cancellation) gives its slot back so the circuit cannot get stuck.

    State-change notifications are emitted after the internal lock has
    been released, so an ``on_state_change`` callback may call back into
    the breaker (for example ``reset()``) without deadlocking.

    Example:
        >>> breaker = CircuitBreaker(failure_threshold=3)
        >>> @breaker
        ... def call_external_api():
        ...     return requests.get("https://api.example.com", timeout=10)

    """

    def __init__(
        self,
        *,
        failure_threshold: int = 5,
        success_threshold: int = 2,
        timeout: float = 30.0,
        excluded_exceptions: tuple[type[Exception], ...] = (),
        failure_exceptions: tuple[type[Exception], ...] = (Exception,),
        name: str = "default",
        on_state_change: Callable[[CircuitState, CircuitState], None] | None = None,
    ) -> None:
        """Initialize CircuitBreaker.

        Args:
            failure_threshold: Failures before opening circuit.
            success_threshold: Successes to close from half-open.
            timeout: Seconds before attempting half-open.
            excluded_exceptions: Exceptions that don't trip circuit.
            failure_exceptions: Exceptions that count as failures.
            name: Name for logging/identification.
            on_state_change: Optional callback invoked on state transitions
                with (old_state, new_state). Useful for custom monitoring.

        """
        self.config = CircuitBreakerConfig(
            failure_threshold=failure_threshold,
            success_threshold=success_threshold,
            timeout=timeout,
            excluded_exceptions=excluded_exceptions,
            failure_exceptions=failure_exceptions,
        )
        self.name = name
        self._state = CircuitBreakerState()
        self._on_state_change = on_state_change

    @property
    def state(self) -> CircuitState:
        """Get current circuit state."""
        return self._state.state

    @property
    def failure_count(self) -> int:
        """Get current failure count."""
        return self._state.failure_count

    def _notify_state_change(
        self,
        old_state: CircuitState,
        new_state: CircuitState,
    ) -> None:
        """Notify callback of state transition if registered.

        Emit a structured log via structlog when no callback is provided
        and structlog is available.

        Callers in this class invoke this method only after releasing
        ``self._state.lock``; the lock is not reentrant, so a callback
        that touched the breaker while it was held would deadlock.
        """
        if self._on_state_change is not None:
            self._on_state_change(old_state, new_state)
        elif _HAS_STRUCTLOG and _structlog_logger is not None:  # pragma: no branch
            _structlog_logger.warning(
                "circuit_state_changed",
                circuit=self.name,
                old_state=old_state.value,
                new_state=new_state.value,
                failure_count=self._state.failure_count,
            )

    def _should_attempt(self) -> bool:
        """Check if a call should be attempted.

        Returns:
            True when the call may proceed, False when it must be rejected.

        """
        tracker = self._state
        with tracker.lock:
            if tracker.state is CircuitState.CLOSED:
                return True

            if tracker.state is CircuitState.HALF_OPEN:
                # Allow limited attempts to prevent thundering herd.
                if tracker.half_open_attempts < self.config.success_threshold:
                    tracker.half_open_attempts += 1
                    return True
                return False

            # OPEN: stay closed to traffic until the timeout has passed.
            elapsed = time.monotonic() - tracker.last_failure_time
            if elapsed < self.config.timeout:
                return False

            tracker.state = CircuitState.HALF_OPEN
            tracker.success_count = 0
            # The call that performs the transition is itself the first
            # half-open attempt.
            tracker.half_open_attempts = 1
            failures = tracker.failure_count

        # Lock released: logging and the callback can no longer deadlock.
        logger.info(
            "Circuit %s entering half-open state "
            "(was open for %.1fs, failures=%d)",
            self.name,
            elapsed,
            failures,
        )
        try:
            self._notify_state_change(CircuitState.OPEN, CircuitState.HALF_OPEN)
        except BaseException:
            # The protected call will not run, so give the probe slot back.
            self._release_half_open_attempt()
            raise
        return True

    def _release_half_open_attempt(self) -> None:
        """Return a half-open probe slot that produced no recorded outcome.

        Without this, a probe ending in an excluded exception, an
        exception outside ``failure_exceptions`` or a cancellation would
        leave ``half_open_attempts`` at its cap while ``success_count``
        can no longer reach ``success_threshold``, so the circuit would
        stay half-open and reject every call from then on.

        A call admitted while the circuit was closed may also end up here
        after a later transition to half-open; in that case one extra
        probe is admitted.
        """
        tracker = self._state
        with tracker.lock:
            if (
                tracker.state is CircuitState.HALF_OPEN
                and tracker.half_open_attempts > 0
            ):
                tracker.half_open_attempts -= 1

    def _record_success(self) -> None:
        """Record a successful call."""
        tracker = self._state
        with tracker.lock:
            if tracker.state is CircuitState.CLOSED:
                # Reset failure count on success.
                tracker.failure_count = 0
                return

            if tracker.state is not CircuitState.HALF_OPEN:
                # OPEN: should not happen, but handle gracefully.
                return

            tracker.success_count += 1
            if tracker.success_count < self.config.success_threshold:
                return

            tracker.state = CircuitState.CLOSED
            tracker.failure_count = 0
            tracker.half_open_attempts = 0
            successes = tracker.success_count

        # Lock released before logging and notifying.
        logger.info(
            "Circuit %s closed after recovery "
            "(%d consecutive successes)",
            self.name,
            successes,
        )
        self._notify_state_change(CircuitState.HALF_OPEN, CircuitState.CLOSED)

    def _record_failure(self, exc: Exception) -> None:
        """Record a failed call.

        Args:
            exc: The exception raised by the protected call. Instances of
                ``config.excluded_exceptions`` are not counted.

        """
        if isinstance(exc, self.config.excluded_exceptions):
            # Not a failure and not a success: free the probe slot, if any.
            self._release_half_open_attempt()
            return

        tracker = self._state
        previous: CircuitState | None = None
        with tracker.lock:
            tracker.failure_count += 1
            tracker.last_failure_time = time.monotonic()
            failures = tracker.failure_count

            if tracker.state is CircuitState.HALF_OPEN:
                # Any failure in half-open reopens circuit.
                previous = CircuitState.HALF_OPEN
                tracker.state = CircuitState.OPEN
                tracker.half_open_attempts = 0
            elif (
                tracker.state is CircuitState.CLOSED
                and failures >= self.config.failure_threshold
            ):
                previous = CircuitState.CLOSED
                tracker.state = CircuitState.OPEN
            # Already OPEN: only the counters above are updated.

        if previous is None:
            return

        # Lock released before logging and notifying.
        if previous is CircuitState.HALF_OPEN:
            logger.warning(
                "Circuit %s reopened after failure in half-open "
                "(total failures=%d)",
                self.name,
                failures,
            )
        else:
            logger.warning(
                "Circuit %s opened after %d failures (threshold=%d)",
                self.name,
                failures,
                self.config.failure_threshold,
            )
        self._notify_state_change(previous, CircuitState.OPEN)

    def reset(self) -> None:
        """Reset circuit breaker to closed state.

        Clears all counters. No state-change notification is emitted.
        """
        tracker = self._state
        with tracker.lock:
            tracker.state = CircuitState.CLOSED
            tracker.failure_count = 0
            tracker.success_count = 0
            tracker.half_open_attempts = 0
        logger.info("Circuit %s manually reset", self.name)

    def _rejection_error(self) -> CircuitBreakerError:
        """Build the error raised when a call is not admitted."""
        return CircuitBreakerError(
            f"Circuit {self.name} is open",
            state=self._state.state,
        )

    def __call__(
        self, func: Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]
    ) -> Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]:
        """Decorate a sync or async function with circuit breaker protection.

        Args:
            func: Function or coroutine function to protect.

        Returns:
            A wrapper of the same kind (sync or async) as ``func``. The
            wrapper raises ``CircuitBreakerError`` when the call is not
            admitted and otherwise re-raises whatever ``func`` raises.

        """
        if inspect.iscoroutinefunction(func):
            func_coro = cast(Callable[P, Coroutine[Any, Any, R]], func)

            @functools.wraps(func_coro)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                if not self._should_attempt():
                    raise self._rejection_error()

                try:
                    result = await func_coro(*args, **kwargs)
                except self.config.failure_exceptions as e:
                    self._record_failure(e)
                    raise
                except BaseException:
                    # Neither success nor counted failure (for example
                    # asyncio.CancelledError): free the probe slot.
                    self._release_half_open_attempt()
                    raise
                # Outside the try so that an error raised by the
                # state-change callback is not counted as a call failure.
                self._record_success()
                return result

            return async_wrapper

        func_sync = cast(Callable[P, R], func)

        @functools.wraps(func_sync)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            if not self._should_attempt():
                raise self._rejection_error()

            try:
                result = func_sync(*args, **kwargs)
            except self.config.failure_exceptions as e:
                self._record_failure(e)
                raise
            except BaseException:
                # Neither success nor counted failure: free the probe slot.
                self._release_half_open_attempt()
                raise
            # Outside the try so that an error raised by the state-change
            # callback is not counted as a call failure.
            self._record_success()
            return result

        return wrapper
def circuit_breaker(
    *,
    failure_threshold: int = 5,
    success_threshold: int = 2,
    timeout: float = 30.0,
    excluded_exceptions: tuple[type[Exception], ...] = (),
    failure_exceptions: tuple[type[Exception], ...] = (Exception,),
    name: str | None = None,
    on_state_change: Callable[[CircuitState, CircuitState], None] | None = None,
) -> CircuitBreakerDecorator:
    """Decorate a sync or async function with circuit breaker pattern.

    A new ``CircuitBreaker`` is created for each function the returned
    decorator is applied to.

    Args:
        failure_threshold: Failures before opening circuit.
        success_threshold: Successes to close from half-open.
        timeout: Seconds before attempting half-open.
        excluded_exceptions: Exceptions that don't trip circuit.
        failure_exceptions: Exceptions that count as failures.
        name: Optional name for the circuit. When omitted or empty, the
            decorated callable's ``__name__`` is used, falling back to
            its type name for callables without one.
        on_state_change: Optional callback invoked on state transitions
            with (old_state, new_state).

    Returns:
        Decorated function with circuit breaker protection.

    Example:
        >>> @circuit_breaker(failure_threshold=3, timeout=60)
        ... def call_api(endpoint: str) -> dict:
        ...     return requests.get(endpoint, timeout=10).json()

        >>> @circuit_breaker(
        ...     failure_threshold=3,
        ...     on_state_change=lambda old, new: print(f"{old} -> {new}"),
        ... )
        ... def monitored_call() -> str:
        ...     return service.call()

    """

    def decorator(
        func: Callable[P, R] | Callable[P, Coroutine[Any, Any, R]],
    ) -> Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]:
        # functools.partial objects and callable instances have no
        # __name__; use the type name instead of raising AttributeError.
        breaker_name = name or getattr(func, "__name__", type(func).__name__)
        breaker = CircuitBreaker(
            failure_threshold=failure_threshold,
            success_threshold=success_threshold,
            timeout=timeout,
            excluded_exceptions=excluded_exceptions,
            failure_exceptions=failure_exceptions,
            name=breaker_name,
            on_state_change=on_state_change,
        )
        return breaker(func)

    return decorator  # type: ignore[return-value]