Coverage for src / taipanstack / resilience / circuit_breaker.py: 100%

254 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

1""" 

2Circuit Breaker pattern implementation. 

3 

4Provides protection against cascading failures by temporarily 

5blocking calls to a failing service. Compatible with any 

6Python framework (sync and async). 

7""" 

8 

9import functools 

10import inspect 

11import logging 

12import math 

13import threading 

14import time 

15from collections.abc import Awaitable, Callable 

16from dataclasses import dataclass, field 

17from enum import Enum 

18from typing import ParamSpec, Protocol, TypeVar, cast, overload 

19 

20from taipanstack.core.result import Err 

21 

# Parameter specification of a decorated callable; wrappers annotate
# ``*args: P.args, **kwargs: P.kwargs`` so they keep the wrapped signature.
P = ParamSpec("P")

# Result type produced by a decorated callable (or awaited from it).
R = TypeVar("R")

24 

25 

class CircuitBreakerDecorator(Protocol):
    """Protocol for the circuit breaker decorator.

    Describes the callable returned by ``circuit_breaker(...)``: it takes
    the function to protect and returns a wrapper with the same parameters
    (``P``) and the same result type.
    """

    # NOTE(review): type checkers generally try overloads in declaration
    # order, so the plain-callable signature is consulted first; keep this
    # order unless the typing behaviour for async functions is re-verified.
    @overload
    def __call__(self, func: Callable[P, R]) -> Callable[P, R]: ...

    # Signature for functions returning an awaitable (e.g. ``async def``).
    @overload
    def __call__(
        self, func: Callable[P, Awaitable[R]]
    ) -> Callable[P, Awaitable[R]]: ...

36 

37 

logger = logging.getLogger("taipanstack.resilience.circuit_breaker")

# structlog is optional. Start from the "not available" defaults and only
# switch them on once both the import and get_logger() have succeeded.
_structlog_logger = None
_HAS_STRUCTLOG = False
try:
    import structlog as _structlog

    _structlog_logger = _structlog.get_logger("taipanstack.resilience.circuit_breaker")
    _HAS_STRUCTLOG = True
except ImportError:
    # structlog not installed: fall back to the stdlib logger above.
    pass

48 

49 

class CircuitState(Enum):
    """Lifecycle states a circuit breaker moves through.

    Members are declared in the order CLOSED, OPEN, HALF_OPEN and their
    values are the lowercase state names.
    """

    # Normal operation: requests flow through to the protected call.
    CLOSED = "closed"
    # Tripped: requests are blocked.
    OPEN = "open"
    # Trial mode: testing whether the service has recovered.
    HALF_OPEN = "half_open"

56 

57 

class CircuitBreakerError(Exception):
    """Signal that a call was rejected by the circuit breaker.

    The circuit state supplied by the raiser is kept on ``state`` so
    handlers can inspect it.
    """

    def __init__(self, message: str, state: CircuitState) -> None:
        """Build the error.

        Args:
            message: Human-readable description; becomes the exception text.
            state: Circuit state reported alongside the rejection.

        """
        super().__init__(message)
        self.state = state

71 

72 

@dataclass
class CircuitBreakerConfig:
    """Settings that govern when a circuit breaker trips and recovers.

    Attributes:
        failure_threshold: Number of failures before opening circuit.
        success_threshold: Successes needed in half-open to close.
        timeout: Seconds before trying half-open after open.
        excluded_exceptions: Exceptions that don't count as failures.
        failure_exceptions: Exceptions that count as failures.

    Only finiteness of the numeric fields is checked here; range checks
    (e.g. thresholds >= 1) are not performed by this class.

    """

    failure_threshold: int = 5
    success_threshold: int = 2
    timeout: float = 30.0
    excluded_exceptions: tuple[type[Exception], ...] = ()
    failure_exceptions: tuple[type[Exception], ...] = (Exception,)

    def __post_init__(self) -> None:
        """Reject NaN/infinite numeric settings, checking fields in declaration order.

        Raises:
            ValueError: ``"<field> must be finite"`` for the first offending field.

        """
        for attr_name in ("failure_threshold", "success_threshold", "timeout"):
            if not math.isfinite(getattr(self, attr_name)):
                raise ValueError(f"{attr_name} must be finite")

100 

101 

@dataclass
class CircuitBreakerState:
    """Mutable bookkeeping for one circuit breaker instance.

    Every instance gets its own ``threading.Lock`` via ``default_factory``.
    """

    # Current position in the CLOSED / OPEN / HALF_OPEN state machine.
    state: CircuitState = field(default=CircuitState.CLOSED)
    # Number of recorded failures.
    failure_count: int = field(default=0)
    # Successes counted while half-open.
    success_count: int = field(default=0)
    # Half-open calls currently admitted.
    half_open_attempts: int = field(default=0)
    # Timestamp of the most recent recorded failure.
    last_failure_time: float = field(default=0.0)
    # Guards mutation of the fields above.
    lock: threading.Lock = field(default_factory=threading.Lock)

112 

113 

114class CircuitBreaker: 

115 """Circuit breaker implementation. 

116 

117 Monitors function calls and opens the circuit when too many 

118 failures occur, preventing further calls until the service 

119 recovers. Supports both sync and async functions. 

120 

121 Example: 

122 >>> breaker = CircuitBreaker(failure_threshold=3) 

123 >>> @breaker 

124 ... def call_external_api(): 

125 ... return requests.get("https://api.example.com", timeout=10) 

126 

127 """ 

128 

129 @staticmethod 

130 def _check_finite_val(value: float, min_val: float, err_msg: str) -> None: 

131 if not math.isfinite(value) or value < min_val: 

132 raise ValueError(err_msg) 

133 

134 @staticmethod 

135 def _validate_thresholds( 

136 timeout: float, failure_threshold: int, success_threshold: int 

137 ) -> None: 

138 CircuitBreaker._check_finite_val( 

139 timeout, 0, "timeout must be a finite non-negative number" 

140 ) 

141 CircuitBreaker._check_finite_val( 

142 failure_threshold, 1, "failure_threshold must be a finite number >= 1" 

143 ) 

144 CircuitBreaker._check_finite_val( 

145 success_threshold, 1, "success_threshold must be a finite number >= 1" 

146 ) 

147 

148 def __init__( 

149 self, 

150 *, 

151 failure_threshold: int = 5, 

152 success_threshold: int = 2, 

153 timeout: float = 30.0, 

154 excluded_exceptions: tuple[type[Exception], ...] = (), 

155 failure_exceptions: tuple[type[Exception], ...] = (Exception,), 

156 name: str = "default", 

157 on_state_change: Callable[[CircuitState, CircuitState], None] | None = None, 

158 ) -> None: 

159 """Initialize CircuitBreaker. 

160 

161 Args: 

162 failure_threshold: Failures before opening circuit. 

163 success_threshold: Successes to close from half-open. 

164 timeout: Seconds before attempting half-open. 

165 excluded_exceptions: Exceptions that don't trip circuit. 

166 failure_exceptions: Exceptions that count as failures. 

167 name: Name for logging/identification. 

168 on_state_change: Optional callback invoked on state transitions 

169 with (old_state, new_state). Useful for custom monitoring. 

170 

171 """ 

172 CircuitBreaker._validate_thresholds( 

173 timeout, failure_threshold, success_threshold 

174 ) 

175 

176 self.config = CircuitBreakerConfig( 

177 failure_threshold=failure_threshold, 

178 success_threshold=success_threshold, 

179 timeout=timeout, 

180 excluded_exceptions=excluded_exceptions, 

181 failure_exceptions=failure_exceptions, 

182 ) 

183 self.name = name 

184 self._state = CircuitBreakerState() 

185 self._on_state_change = on_state_change 

186 

187 @property 

188 def state(self) -> CircuitState: 

189 """Get current circuit state.""" 

190 return self._state.state 

191 

192 @property 

193 def failure_count(self) -> int: 

194 """Get current failure count.""" 

195 return self._state.failure_count 

196 

197 def _log_callback_failure( 

198 self, 

199 old_state: CircuitState, 

200 new_state: CircuitState, 

201 e: Exception, 

202 ) -> None: 

203 if _HAS_STRUCTLOG and _structlog_logger is not None: 

204 _structlog_logger.error( 

205 "circuit_state_change_callback_failed", 

206 circuit=self.name, 

207 old_state=old_state.value, 

208 new_state=new_state.value, 

209 error=str(e), 

210 ) 

211 else: 

212 logger.error( 

213 "Circuit %s state change callback failed: %s", 

214 self.name, 

215 str(e), 

216 ) 

217 

218 def _notify_state_change( 

219 self, 

220 old_state: CircuitState, 

221 new_state: CircuitState, 

222 ) -> None: 

223 """Notify callback of state transition if registered. 

224 

225 Emit a structured log via structlog when no callback is provided 

226 and structlog is available. 

227 """ 

228 if self._on_state_change is not None: 

229 try: 

230 self._on_state_change(old_state, new_state) 

231 except Exception as e: 

232 self._log_callback_failure(old_state, new_state, e) 

233 elif _HAS_STRUCTLOG and _structlog_logger is not None: # pragma: no branch 

234 _structlog_logger.warning( 

235 "circuit_state_changed", 

236 circuit=self.name, 

237 old_state=old_state.value, 

238 new_state=new_state.value, 

239 failure_count=self._state.failure_count, 

240 ) 

241 

242 def _handle_open_state( 

243 self, 

244 ) -> tuple[bool, tuple[CircuitState, CircuitState] | None]: 

245 """Handle logic for OPEN state in _should_attempt.""" 

246 now = time.monotonic() 

247 try: 

248 elapsed = now - self._state.last_failure_time 

249 except TypeError: 

250 # Type corruption detected (e.g. last_failure_time is string) 

251 return False, None 

252 

253 # Safe check against NaN and Inf time corruption 

254 # If elapsed < 0, a backward clock jump occurred. We should 

255 # allow a transition to prevent permanent lockout. 

256 if elapsed < 0: 

257 elapsed = self.config.timeout 

258 

259 if math.isfinite(now) and elapsed >= self.config.timeout: 

260 # Before transitioning, verify if we can make an attempt 

261 # This happens in a lock, so it's thread-safe. However, once 

262 # the state changes to HALF_OPEN, subsequent threads in the 

263 # same lock block will hit the HALF_OPEN case. 

264 self._state.state = CircuitState.HALF_OPEN 

265 self._state.success_count = 0 

266 # Initialize half_open_attempts to 1 because this first call 

267 # that transitions the state is also an attempt. 

268 self._state.half_open_attempts = 1 

269 logger.info( 

270 "Circuit %s entering half-open state (was open for %.1fs, failures=%d)", 

271 self.name, 

272 elapsed, 

273 self._state.failure_count, 

274 ) 

275 return True, (CircuitState.OPEN, CircuitState.HALF_OPEN) 

276 return False, None 

277 

278 def _handle_attempt_half_open(self) -> bool: 

279 try: 

280 if not math.isfinite(self._state.half_open_attempts): 

281 return False 

282 except TypeError: 

283 # Type corruption detected, deny attempt to be safe 

284 return False 

285 

286 if self._state.half_open_attempts < self.config.success_threshold: 

287 self._state.half_open_attempts += 1 

288 return True 

289 return False 

290 

291 def _should_attempt(self) -> bool: 

292 """Check if a call should be attempted.""" 

293 state_change: tuple[CircuitState, CircuitState] | None = None 

294 should_attempt = False 

295 

296 with self._state.lock: 

297 match self._state.state: 

298 case CircuitState.CLOSED: 

299 should_attempt = True 

300 case CircuitState.OPEN: 

301 should_attempt, state_change = self._handle_open_state() 

302 case CircuitState.HALF_OPEN: 

303 should_attempt = self._handle_attempt_half_open() 

304 

305 if state_change: 

306 self._notify_state_change(*state_change) 

307 

308 return should_attempt 

309 

310 def _handle_success_half_open(self) -> tuple[CircuitState, CircuitState] | None: 

311 try: 

312 if not math.isfinite(self._state.success_count): 

313 self._state.success_count = 0 

314 self._state.success_count += 1 

315 except TypeError: 

316 # Type corruption detected, reset and increment 

317 self._state.success_count = 1 

318 

319 if self._state.success_count >= self.config.success_threshold: 

320 self._state.state = CircuitState.CLOSED 

321 self._state.failure_count = 0 

322 self._state.half_open_attempts = 0 

323 logger.info( 

324 "Circuit %s closed after recovery (%d consecutive successes)", 

325 self.name, 

326 self._state.success_count, 

327 ) 

328 return (CircuitState.HALF_OPEN, CircuitState.CLOSED) 

329 return None 

330 

331 def _record_success(self) -> None: 

332 """Record a successful call.""" 

333 state_change: tuple[CircuitState, CircuitState] | None = None 

334 

335 with self._state.lock: 

336 match self._state.state: 

337 case CircuitState.HALF_OPEN: 

338 state_change = self._handle_success_half_open() 

339 case CircuitState.CLOSED: 

340 # Reset failure count on success 

341 self._state.failure_count = 0 

342 case CircuitState.OPEN: # pragma: no branch 

343 pass # Should not happen, but handle gracefully 

344 

345 if state_change: 

346 self._notify_state_change(*state_change) 

347 

348 def _handle_failure_half_open(self) -> tuple[CircuitState, CircuitState] | None: 

349 """Handle failure when in HALF_OPEN state.""" 

350 self._state.state = CircuitState.OPEN 

351 self._state.half_open_attempts = 0 

352 logger.warning( 

353 "Circuit %s reopened after failure in half-open", 

354 self.name, 

355 ) 

356 return (CircuitState.HALF_OPEN, CircuitState.OPEN) 

357 

358 def _handle_failure_closed(self) -> tuple[CircuitState, CircuitState] | None: 

359 """Handle failure when in CLOSED state.""" 

360 # Check against corrupted NaN/Inf failure_count 

361 try: 

362 if not math.isfinite(self._state.failure_count): 

363 self._state.state = CircuitState.OPEN 

364 logger.warning( 

365 "Circuit %s opened due to state corruption (NaN/Inf failures)", 

366 self.name, 

367 ) 

368 return (CircuitState.CLOSED, CircuitState.OPEN) 

369 except TypeError: 

370 self._state.state = CircuitState.OPEN 

371 logger.warning( 

372 "Circuit %s opened due to type corruption in failure_count", 

373 self.name, 

374 ) 

375 return (CircuitState.CLOSED, CircuitState.OPEN) 

376 

377 if self._state.failure_count >= self.config.failure_threshold: 

378 self._state.state = CircuitState.OPEN 

379 logger.warning( 

380 "Circuit %s opened after %d failures (threshold=%d)", 

381 self.name, 

382 self._state.failure_count, 

383 self.config.failure_threshold, 

384 ) 

385 return (CircuitState.CLOSED, CircuitState.OPEN) 

386 

387 return None 

388 

389 def _update_failure_metrics(self) -> None: 

390 try: 

391 if math.isfinite(self._state.failure_count): 

392 self._state.failure_count += 1 

393 except TypeError: 

394 # Handle type mutation (e.g. failure_count became string) 

395 # Safe degradation: reset to max so it opens immediately 

396 self._state.failure_count = self.config.failure_threshold 

397 

398 now = time.monotonic() 

399 if math.isfinite(now): 

400 self._state.last_failure_time = now 

401 

402 def _record_failure(self, exc: Exception) -> None: 

403 """Record a failed call.""" 

404 # Check if exception should be excluded 

405 if isinstance(exc, self.config.excluded_exceptions): 

406 return 

407 

408 state_change: tuple[CircuitState, CircuitState] | None = None 

409 

410 with self._state.lock: 

411 self._update_failure_metrics() 

412 

413 match self._state.state: 

414 case CircuitState.HALF_OPEN: 

415 state_change = self._handle_failure_half_open() 

416 case CircuitState.CLOSED: 

417 state_change = self._handle_failure_closed() 

418 case CircuitState.OPEN: # pragma: no branch 

419 pass # Already open, nothing to do 

420 

421 if state_change: 

422 self._notify_state_change(*state_change) 

423 

424 def reset(self) -> None: 

425 """Reset circuit breaker to closed state.""" 

426 with self._state.lock: 

427 self._state.state = CircuitState.CLOSED 

428 self._state.failure_count = 0 

429 self._state.success_count = 0 

430 self._state.half_open_attempts = 0 

431 logger.info("Circuit %s manually reset", self.name) 

432 

433 def _process_result(self, result: R) -> R: 

434 """Process Result outcome and record success/failure. 

435 

436 Args: 

437 result: The result to process. 

438 

439 Returns: 

440 The original result. 

441 

442 """ 

443 if isinstance(result, Err): 

444 err_val = result.unwrap_err() 

445 if isinstance(err_val, self.config.failure_exceptions): 

446 self._record_failure(err_val) 

447 return result 

448 # Ignored exception in Result monad 

449 return result 

450 self._record_success() 

451 return result 

452 

453 def _decrement_half_open(self, is_half_open: bool) -> None: 

454 """Decrement half-open attempt count if applicable. 

455 

456 Args: 

457 is_half_open: Whether the circuit was half-open before attempt. 

458 

459 """ 

460 if is_half_open: 

461 with self._state.lock: 

462 try: 

463 if ( 

464 self._state.state == CircuitState.HALF_OPEN 

465 and math.isfinite(self._state.half_open_attempts) 

466 and self._state.half_open_attempts > 0 

467 ): 

468 self._state.half_open_attempts -= 1 

469 except TypeError: 

470 # Reset if state is corrupted to prevent crash 

471 self._state.half_open_attempts = 0 

472 

473 def __call__( 

474 self, func: Callable[P, R] | Callable[P, Awaitable[R]] 

475 ) -> Callable[P, R] | Callable[P, Awaitable[R]]: 

476 """Decorate a sync or async function with circuit breaker protection.""" 

477 if inspect.iscoroutinefunction(func): 

478 func_coro = cast(Callable[P, Awaitable[R]], func) 

479 

480 @functools.wraps(func_coro) 

481 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 

482 if not self._should_attempt(): 

483 raise CircuitBreakerError( 

484 f"Circuit {self.name} is open", 

485 state=self._state.state, 

486 ) 

487 

488 is_half_open = self._state.state == CircuitState.HALF_OPEN 

489 

490 try: 

491 result = await func_coro(*args, **kwargs) 

492 return self._process_result(result) 

493 except self.config.failure_exceptions as e: 

494 self._record_failure(e) 

495 raise 

496 finally: 

497 self._decrement_half_open(is_half_open) 

498 

499 return async_wrapper 

500 

501 func_sync = cast(Callable[P, R], func) 

502 

503 @functools.wraps(func_sync) 

504 def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 

505 if not self._should_attempt(): 

506 raise CircuitBreakerError( 

507 f"Circuit {self.name} is open", 

508 state=self._state.state, 

509 ) 

510 

511 is_half_open = self._state.state == CircuitState.HALF_OPEN 

512 

513 try: 

514 result = func_sync(*args, **kwargs) 

515 return self._process_result(result) 

516 except self.config.failure_exceptions as e: 

517 self._record_failure(e) 

518 raise 

519 finally: 

520 self._decrement_half_open(is_half_open) 

521 

522 return wrapper 

523 

524 

def circuit_breaker(
    *,
    failure_threshold: int = 5,
    success_threshold: int = 2,
    timeout: float = 30.0,
    excluded_exceptions: tuple[type[Exception], ...] = (),
    failure_exceptions: tuple[type[Exception], ...] = (Exception,),
    name: str | None = None,
    on_state_change: Callable[[CircuitState, CircuitState], None] | None = None,
) -> CircuitBreakerDecorator:
    """Decorate a sync or async function with circuit breaker pattern.

    Each decorated function gets its own ``CircuitBreaker`` instance,
    created (and validated) when the decorator is applied.

    Args:
        failure_threshold: Failures before opening circuit.
        success_threshold: Successes to close from half-open.
        timeout: Seconds before attempting half-open.
        excluded_exceptions: Exceptions that don't trip circuit.
        failure_exceptions: Exceptions that count as failures.
        name: Optional name for the circuit. Defaults to the function's
            ``__name__``, or its type name for callables without one
            (e.g. ``functools.partial`` objects).
        on_state_change: Optional callback invoked on state transitions
            with (old_state, new_state).

    Returns:
        Decorated function with circuit breaker protection.

    Example:
        >>> @circuit_breaker(failure_threshold=3, timeout=60)
        ... def call_api(endpoint: str) -> dict:
        ...     return requests.get(endpoint, timeout=10).json()

        >>> @circuit_breaker(
        ...     failure_threshold=3,
        ...     on_state_change=lambda old, new: print(f"{old} -> {new}"),
        ... )
        ... def monitored_call() -> str:
        ...     return service.call()

    """

    def decorator(
        func: Callable[P, R] | Callable[P, Awaitable[R]],
    ) -> Callable[P, R] | Callable[P, Awaitable[R]]:
        # Partials and callable instances have no __name__; fall back to the
        # type name instead of failing with AttributeError at decoration time.
        breaker_name = (
            name or getattr(func, "__name__", None) or type(func).__name__
        )
        breaker = CircuitBreaker(
            failure_threshold=failure_threshold,
            success_threshold=success_threshold,
            timeout=timeout,
            excluded_exceptions=excluded_exceptions,
            failure_exceptions=failure_exceptions,
            name=breaker_name,
            on_state_change=on_state_change,
        )
        return breaker(func)

    return cast(CircuitBreakerDecorator, decorator)