Coverage for src / taipanstack / utils / circuit_breaker.py: 100%

151 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 14:54 +0000

1""" 

2Circuit Breaker pattern implementation. 

3 

4Provides protection against cascading failures by temporarily 

5blocking calls to a failing service. Compatible with any 

6Python framework (sync and async). 

7""" 

8 

9import functools 

10import inspect 

11import logging 

12import threading 

13import time 

14from collections.abc import Callable, Coroutine 

15from dataclasses import dataclass, field 

16from enum import Enum 

17from typing import Any, ParamSpec, Protocol, TypeVar, cast, overload 

18 

19P = ParamSpec("P") 

20R = TypeVar("R") 

21 

22 

class CircuitBreakerDecorator(Protocol):
    """Structural type of a circuit-breaker decorator.

    Both overloads declare that decorating returns a callable with the
    same parameters and result as the one passed in: the first for
    plain functions, the second for coroutine functions.
    """

    @overload
    def __call__(self, func: Callable[P, R]) -> Callable[P, R]: ...  # pragma: no cover

    @overload
    def __call__(
        self, func: Callable[P, Coroutine[Any, Any, R]]
    ) -> Callable[P, Coroutine[Any, Any, R]]: ...  # pragma: no cover

33 

34 

# Standard-library logger used for the human-readable transition messages.
logger = logging.getLogger("taipanstack.utils.circuit_breaker")

# structlog is an optional dependency: when it is importable a structured
# logger is prepared, otherwise the module falls back to ``None``.
try:
    import structlog as _structlog
except ImportError:  # pragma: no cover — structlog is optional
    _structlog_logger = None
    _HAS_STRUCTLOG = False
else:
    _structlog_logger = _structlog.get_logger("taipanstack.utils.circuit_breaker")
    _HAS_STRUCTLOG = True

45 

46 

class CircuitState(Enum):
    """The three states a circuit breaker can be in."""

    # Normal operation, requests flow through
    CLOSED = "closed"
    # Circuit is tripped, requests are blocked
    OPEN = "open"
    # Testing if service has recovered
    HALF_OPEN = "half_open"

53 

54 

class CircuitBreakerError(Exception):
    """Raised when circuit breaker is open."""

    def __init__(self, message: str, state: CircuitState) -> None:
        """Build the error and remember the circuit state that caused it.

        Args:
            message: Error description, forwarded to ``Exception``.
            state: Circuit state at the time the error was created;
                stored on the instance as ``state``.

        """
        super().__init__(message)
        self.state = state

68 

69 

@dataclass
class CircuitBreakerConfig:
    """Tunable settings that control circuit breaker behavior.

    Attributes:
        failure_threshold: Number of failures before opening circuit.
        success_threshold: Successes needed in half-open to close.
        timeout: Seconds before trying half-open after open.
        excluded_exceptions: Exceptions that don't count as failures.
        failure_exceptions: Exceptions that count as failures.

    """

    # Thresholds for opening the circuit and for closing it again.
    failure_threshold: int = 5
    success_threshold: int = 2
    # Seconds the circuit stays open before a trial call is allowed.
    timeout: float = 30.0
    # Exception filters: ignored types first, counted types second.
    excluded_exceptions: tuple[type[Exception], ...] = ()
    failure_exceptions: tuple[type[Exception], ...] = (Exception,)

88 

89 

@dataclass
class CircuitBreakerState:
    """Mutable bookkeeping for one circuit breaker instance."""

    # Current position in the CLOSED / OPEN / HALF_OPEN state machine.
    state: CircuitState = CircuitState.CLOSED
    # Counters maintained by the breaker as calls succeed or fail.
    failure_count: int = 0
    success_count: int = 0
    half_open_attempts: int = 0
    # Timestamp of the most recent recorded failure.
    last_failure_time: float = 0.0
    # Each instance gets its own (non-reentrant) lock.
    lock: threading.Lock = field(default_factory=threading.Lock)

100 

101 

102class CircuitBreaker: 

103 """Circuit breaker implementation. 

104 

105 Monitors function calls and opens the circuit when too many 

106 failures occur, preventing further calls until the service 

107 recovers. Supports both sync and async functions. 

108 

109 Example: 

110 >>> breaker = CircuitBreaker(failure_threshold=3) 

111 >>> @breaker 

112 ... def call_external_api(): 

113 ... return requests.get("https://api.example.com", timeout=10) 

114 

115 """ 

116 

117 def __init__( 

118 self, 

119 *, 

120 failure_threshold: int = 5, 

121 success_threshold: int = 2, 

122 timeout: float = 30.0, 

123 excluded_exceptions: tuple[type[Exception], ...] = (), 

124 failure_exceptions: tuple[type[Exception], ...] = (Exception,), 

125 name: str = "default", 

126 on_state_change: Callable[[CircuitState, CircuitState], None] | None = None, 

127 ) -> None: 

128 """Initialize CircuitBreaker. 

129 

130 Args: 

131 failure_threshold: Failures before opening circuit. 

132 success_threshold: Successes to close from half-open. 

133 timeout: Seconds before attempting half-open. 

134 excluded_exceptions: Exceptions that don't trip circuit. 

135 failure_exceptions: Exceptions that count as failures. 

136 name: Name for logging/identification. 

137 on_state_change: Optional callback invoked on state transitions 

138 with (old_state, new_state). Useful for custom monitoring. 

139 

140 """ 

141 self.config = CircuitBreakerConfig( 

142 failure_threshold=failure_threshold, 

143 success_threshold=success_threshold, 

144 timeout=timeout, 

145 excluded_exceptions=excluded_exceptions, 

146 failure_exceptions=failure_exceptions, 

147 ) 

148 self.name = name 

149 self._state = CircuitBreakerState() 

150 self._on_state_change = on_state_change 

151 

152 @property 

153 def state(self) -> CircuitState: 

154 """Get current circuit state.""" 

155 return self._state.state 

156 

157 @property 

158 def failure_count(self) -> int: 

159 """Get current failure count.""" 

160 return self._state.failure_count 

161 

162 def _notify_state_change( 

163 self, 

164 old_state: CircuitState, 

165 new_state: CircuitState, 

166 ) -> None: 

167 """Notify callback of state transition if registered. 

168 

169 Emit a structured log via structlog when no callback is provided 

170 and structlog is available. 

171 """ 

172 if self._on_state_change is not None: 

173 self._on_state_change(old_state, new_state) 

174 elif _HAS_STRUCTLOG and _structlog_logger is not None: # pragma: no branch 

175 _structlog_logger.warning( 

176 "circuit_state_changed", 

177 circuit=self.name, 

178 old_state=old_state.value, 

179 new_state=new_state.value, 

180 failure_count=self._state.failure_count, 

181 ) 

182 

183 def _should_attempt(self) -> bool: 

184 """Check if a call should be attempted.""" 

185 with self._state.lock: 

186 match self._state.state: 

187 case CircuitState.CLOSED: 

188 return True 

189 

190 case CircuitState.OPEN: 

191 # Check if timeout has passed 

192 elapsed = time.monotonic() - self._state.last_failure_time 

193 if elapsed >= self.config.timeout: 

194 # Before transitioning, verify if we can make an attempt 

195 # This happens in a lock, so it's thread-safe. However, once 

196 # the state changes to HALF_OPEN, subsequent threads in the 

197 # same lock block will hit the HALF_OPEN case. 

198 self._state.state = CircuitState.HALF_OPEN 

199 self._state.success_count = 0 

200 # Initialize half_open_attempts to 1 because this first call 

201 # that transitions the state is also an attempt. 

202 self._state.half_open_attempts = 1 

203 logger.info( 

204 "Circuit %s entering half-open state " 

205 "(was open for %.1fs, failures=%d)", 

206 self.name, 

207 elapsed, 

208 self._state.failure_count, 

209 ) 

210 self._notify_state_change( 

211 CircuitState.OPEN, 

212 CircuitState.HALF_OPEN, 

213 ) 

214 return True 

215 return False 

216 

217 case CircuitState.HALF_OPEN: 

218 # Allow limited attempts to prevent thundering herd 

219 if self._state.half_open_attempts < self.config.success_threshold: 

220 self._state.half_open_attempts += 1 

221 return True 

222 return False 

223 

224 return False # pragma: no cover — unreachable, satisfies type checker 

225 

226 def _record_success(self) -> None: 

227 """Record a successful call.""" 

228 with self._state.lock: 

229 match self._state.state: 

230 case CircuitState.HALF_OPEN: 

231 self._state.success_count += 1 

232 if self._state.success_count >= self.config.success_threshold: 

233 self._state.state = CircuitState.CLOSED 

234 self._state.failure_count = 0 

235 self._state.half_open_attempts = 0 

236 logger.info( 

237 "Circuit %s closed after recovery " 

238 "(%d consecutive successes)", 

239 self.name, 

240 self._state.success_count, 

241 ) 

242 self._notify_state_change( 

243 CircuitState.HALF_OPEN, 

244 CircuitState.CLOSED, 

245 ) 

246 

247 case CircuitState.CLOSED: 

248 # Reset failure count on success 

249 self._state.failure_count = 0 

250 

251 case CircuitState.OPEN: # pragma: no branch 

252 pass # Should not happen, but handle gracefully 

253 

254 def _record_failure(self, exc: Exception) -> None: 

255 """Record a failed call.""" 

256 # Check if exception should be excluded 

257 if isinstance(exc, self.config.excluded_exceptions): 

258 return 

259 

260 with self._state.lock: 

261 self._state.failure_count += 1 

262 self._state.last_failure_time = time.monotonic() 

263 

264 match self._state.state: 

265 case CircuitState.HALF_OPEN: 

266 # Any failure in half-open reopens circuit 

267 self._state.state = CircuitState.OPEN 

268 self._state.half_open_attempts = 0 

269 logger.warning( 

270 "Circuit %s reopened after failure in half-open " 

271 "(total failures=%d)", 

272 self.name, 

273 self._state.failure_count, 

274 ) 

275 self._notify_state_change( 

276 CircuitState.HALF_OPEN, 

277 CircuitState.OPEN, 

278 ) 

279 

280 case CircuitState.CLOSED: 

281 if self._state.failure_count >= self.config.failure_threshold: 

282 self._state.state = CircuitState.OPEN 

283 logger.warning( 

284 "Circuit %s opened after %d failures (threshold=%d)", 

285 self.name, 

286 self._state.failure_count, 

287 self.config.failure_threshold, 

288 ) 

289 self._notify_state_change( 

290 CircuitState.CLOSED, 

291 CircuitState.OPEN, 

292 ) 

293 

294 case CircuitState.OPEN: # pragma: no branch 

295 pass # Already open, nothing to do 

296 

297 def reset(self) -> None: 

298 """Reset circuit breaker to closed state.""" 

299 with self._state.lock: 

300 self._state.state = CircuitState.CLOSED 

301 self._state.failure_count = 0 

302 self._state.success_count = 0 

303 self._state.half_open_attempts = 0 

304 logger.info("Circuit %s manually reset", self.name) 

305 

306 def __call__( 

307 self, func: Callable[P, R] | Callable[P, Coroutine[Any, Any, R]] 

308 ) -> Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]: 

309 """Decorate a sync or async function with circuit breaker protection.""" 

310 if inspect.iscoroutinefunction(func): 

311 func_coro = cast(Callable[P, Coroutine[Any, Any, R]], func) 

312 

313 @functools.wraps(func_coro) 

314 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 

315 if not self._should_attempt(): 

316 raise CircuitBreakerError( 

317 f"Circuit {self.name} is open", 

318 state=self._state.state, 

319 ) 

320 

321 try: 

322 result = await func_coro(*args, **kwargs) 

323 self._record_success() 

324 return result 

325 except self.config.failure_exceptions as e: 

326 self._record_failure(e) 

327 raise 

328 

329 return async_wrapper 

330 

331 func_sync = cast(Callable[P, R], func) 

332 

333 @functools.wraps(func_sync) 

334 def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 

335 if not self._should_attempt(): 

336 raise CircuitBreakerError( 

337 f"Circuit {self.name} is open", 

338 state=self._state.state, 

339 ) 

340 

341 try: 

342 result = func_sync(*args, **kwargs) 

343 self._record_success() 

344 return result 

345 except self.config.failure_exceptions as e: 

346 self._record_failure(e) 

347 raise 

348 

349 return wrapper 

350 

351 

def circuit_breaker(
    *,
    failure_threshold: int = 5,
    success_threshold: int = 2,
    timeout: float = 30.0,
    excluded_exceptions: tuple[type[Exception], ...] = (),
    failure_exceptions: tuple[type[Exception], ...] = (Exception,),
    name: str | None = None,
    on_state_change: Callable[[CircuitState, CircuitState], None] | None = None,
) -> CircuitBreakerDecorator:
    """Decorate a sync or async function with circuit breaker pattern.

    Each decorated function gets its own ``CircuitBreaker`` instance,
    created when the decorator is applied.

    Args:
        failure_threshold: Failures before opening circuit.
        success_threshold: Successes to close from half-open.
        timeout: Seconds before attempting half-open.
        excluded_exceptions: Exceptions that don't trip circuit.
        failure_exceptions: Exceptions that count as failures.
        name: Optional name for the circuit. When omitted (or empty) the
            wrapped callable's ``__name__`` is used, falling back to the
            name of its type for callables without ``__name__``.
        on_state_change: Optional callback invoked on state transitions
            with (old_state, new_state).

    Returns:
        Decorated function with circuit breaker protection.

    Example:
        >>> @circuit_breaker(failure_threshold=3, timeout=60)
        ... def call_api(endpoint: str) -> dict:
        ...     return requests.get(endpoint, timeout=10).json()

        >>> @circuit_breaker(
        ...     failure_threshold=3,
        ...     on_state_change=lambda old, new: print(f"{old} -> {new}"),
        ... )
        ... def monitored_call() -> str:
        ...     return service.call()

    """

    def decorator(
        func: Callable[P, R] | Callable[P, Coroutine[Any, Any, R]],
    ) -> Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]:
        # functools.partial objects and callable instances have no
        # __name__; fall back to the type name instead of raising
        # AttributeError.
        circuit_name = name or getattr(func, "__name__", None) or type(func).__name__
        breaker = CircuitBreaker(
            failure_threshold=failure_threshold,
            success_threshold=success_threshold,
            timeout=timeout,
            excluded_exceptions=excluded_exceptions,
            failure_exceptions=failure_exceptions,
            name=circuit_name,
            on_state_change=on_state_change,
        )
        return breaker(func)

    return decorator  # type: ignore[return-value]