Coverage for src / taipanstack / utils / retry.py: 100%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 14:54 +0000

1""" 

2Retry logic with exponential backoff. 

3 

4Provides decorators for automatic retry of failing operations 

5with configurable backoff strategies. Compatible with any 

6Python framework (sync and async). 

7""" 

8 

9import asyncio 

10import functools 

11import inspect 

12import logging 

13import secrets 

14import time 

15from collections.abc import Callable, Coroutine 

16from dataclasses import dataclass 

17from types import TracebackType 

18from typing import Any, NoReturn, ParamSpec, Protocol, TypeVar, cast, overload 

19 

# Generic aliases used to preserve a decorated function's signature:
# P captures the parameter specification, R the return type.
P = ParamSpec("P")
R = TypeVar("R")

22 

23 

class RetryDecorator(Protocol):
    """Protocol describing the decorator returned by :func:`retry`.

    Overloaded so type checkers know the decorator preserves the wrapped
    callable's flavor: a sync function stays sync and a coroutine
    function stays a coroutine function, with the signature intact.
    """

    @overload
    def __call__(self, func: Callable[P, R]) -> Callable[P, R]: ...  # pragma: no cover

    @overload
    def __call__(
        self, func: Callable[P, Coroutine[Any, Any, R]]
    ) -> Callable[P, Coroutine[Any, Any, R]]: ...  # pragma: no cover

34 

35 

# Standard-library logger used for the plain-text retry messages.
logger = logging.getLogger("taipanstack.utils.retry")

# structlog is an optional dependency: when importable, retry attempts
# are additionally emitted as structured "retry_attempted" events
# (see _log_retry_attempt); otherwise that path is disabled.
try:
    import structlog as _structlog

    _structlog_logger = _structlog.get_logger("taipanstack.utils.retry")
    _HAS_STRUCTLOG = True
except ImportError:  # pragma: no cover — structlog is optional
    _structlog_logger = None
    _HAS_STRUCTLOG = False

46 

47 

@dataclass(frozen=True)
class RetryConfig:
    """Configuration for retry behavior.

    Frozen (immutable) so one instance can be shared by the sync and
    async wrappers without risk of mid-flight mutation.

    Attributes:
        max_attempts: Maximum number of retry attempts.
        initial_delay: Initial delay between retries in seconds.
        max_delay: Maximum delay between retries.
        exponential_base: Base for exponential backoff (2 = double each time).
        jitter: Whether to add random jitter to delays.
        jitter_factor: Maximum jitter as fraction of delay (0.1 = 10%).
        log_retries: Whether to emit standard log messages.
        on_retry: Optional callback invoked on each retry with
            (attempt, max_attempts, exception, delay).

    """

    max_attempts: int = 3
    initial_delay: float = 1.0
    max_delay: float = 60.0
    exponential_base: float = 2.0
    jitter: bool = True
    jitter_factor: float = 0.1
    log_retries: bool = True
    on_retry: Callable[[int, int, Exception, float], None] | None = None

72 

73 

74class RetryError(Exception): 

75 """Raised when all retry attempts have failed.""" 

76 

77 def __init__( 

78 self, 

79 message: str, 

80 attempts: int, 

81 last_exception: Exception | None = None, 

82 ) -> None: 

83 """Initialize RetryError. 

84 

85 Args: 

86 message: Description of the retry failure. 

87 attempts: Number of attempts made. 

88 last_exception: The last exception that was raised. 

89 

90 """ 

91 self.attempts = attempts 

92 self.last_exception = last_exception 

93 super().__init__(message) 

94 

95 

def calculate_delay(
    attempt: int,
    config: RetryConfig,
) -> float:
    """Calculate delay before the next retry.

    Args:
        attempt: Current attempt number (1-indexed). Values below 1 are
            clamped to 1 so misuse cannot produce a negative exponent.
        config: Retry configuration.

    Returns:
        Delay in seconds before the next retry. Never negative.

    """
    safe_attempt = max(1, attempt)

    # Exponential backoff: initial_delay * base^(attempt - 1), then cap.
    delay = config.initial_delay * (config.exponential_base ** (safe_attempt - 1))
    delay = min(delay, config.max_delay)

    # Add jitter if enabled.
    # Note: jitter does not require cryptographic randomness, but we use
    # secrets.SystemRandom() to keep a clean security baseline and
    # satisfy Bandit's checks on the `random` module.
    if config.jitter:
        jitter_amount = delay * config.jitter_factor
        delay += secrets.SystemRandom().uniform(-jitter_amount, jitter_amount)

    # Clamp at zero (jitter can push tiny delays negative). Use a float
    # floor so the return type is float even when the clamp triggers
    # (max(0, x) would return the int 0).
    return max(0.0, delay)

127 

128 

129def _log_retry_attempt( 

130 func_name: str, 

131 attempt: int, 

132 exc: Exception, 

133 delay: float, 

134 config: RetryConfig, 

135) -> None: 

136 """Log a retry attempt via callback, structlog, or stdlib logger. 

137 

138 Args: 

139 func_name: Name of the retried function. 

140 attempt: Current attempt number. 

141 exc: The exception that triggered the retry. 

142 delay: Delay in seconds before the next attempt. 

143 config: Retry configuration. 

144 

145 """ 

146 if config.log_retries: 

147 logger.info( 

148 "Attempt %d/%d failed for %s: %s. Retrying in %.2f seconds...", 

149 attempt, 

150 config.max_attempts, 

151 func_name, 

152 str(exc), 

153 delay, 

154 ) 

155 

156 # Invoke callback or emit structured log if no callback set 

157 if config.on_retry is not None: 

158 config.on_retry(attempt, config.max_attempts, exc, delay) 

159 elif _HAS_STRUCTLOG and _structlog_logger is not None: # pragma: no branch 

160 _structlog_logger.warning( 

161 "retry_attempted", 

162 function=func_name, 

163 attempt=attempt, 

164 max_attempts=config.max_attempts, 

165 error=str(exc), 

166 delay_seconds=round(delay, 3), 

167 ) 

168 

169 

170def _log_all_failed( 

171 func_name: str, 

172 exc: Exception, 

173 config: RetryConfig, 

174) -> None: 

175 """Log when all retry attempts have been exhausted. 

176 

177 Args: 

178 func_name: Name of the retried function. 

179 exc: The last exception raised. 

180 config: Retry configuration. 

181 

182 """ 

183 if config.log_retries: 

184 logger.warning( 

185 "All %d attempts failed for %s: %s", 

186 config.max_attempts, 

187 func_name, 

188 str(exc), 

189 ) 

190 

191 

def _raise_retry_error(
    func_name: str,
    max_attempts: int,
    reraise: bool,
    last_exception: Exception | None,
) -> NoReturn:
    """Raise a RetryError after all attempts fail.

    The failure is always wrapped in RetryError; ``reraise`` controls
    only whether the original exception is chained as the cause.

    Args:
        func_name: Name of the retried function.
        max_attempts: Number of attempts made.
        reraise: Whether to chain the original exception as the cause
            (``raise ... from last_exception``).
        last_exception: The last exception that was raised.

    Raises:
        RetryError: Always; carries the attempt count and last exception.

    """
    # Build the error once instead of duplicating the constructor call
    # in both branches.
    error = RetryError(
        f"All {max_attempts} attempts failed for {func_name}",
        attempts=max_attempts,
        last_exception=last_exception,
    )
    # Explicitly chain the cause only when requested and available, so
    # tracebacks show "The above exception was the direct cause of ...".
    if reraise and last_exception is not None:
        raise error from last_exception
    raise error

222 

223 

def retry(
    *,
    max_attempts: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    on: tuple[type[Exception], ...] = (Exception,),
    reraise: bool = True,
    log_retries: bool = True,
    on_retry: Callable[[int, int, Exception, float], None] | None = None,
) -> RetryDecorator:
    """Retry a sync or async function with exponential backoff.

    Automatically retries the decorated function when specified
    exceptions are raised, with configurable backoff strategy.
    Detects coroutine functions and preserves their async nature.

    Args:
        max_attempts: Maximum number of retry attempts.
        initial_delay: Initial delay between retries in seconds.
        max_delay: Maximum delay between retries.
        exponential_base: Base for exponential backoff.
        jitter: Whether to add random jitter to delays.
        on: Exception types to retry on.
        reraise: Whether to chain the last exception as the cause of the
            final RetryError (the failure is always wrapped in RetryError).
        log_retries: Whether to log retry attempts.
        on_retry: Optional callback invoked on each retry with
            (attempt, max_attempts, exception, delay). Useful for
            custom monitoring or metrics collection.

    Returns:
        Decorated function with retry logic.

    Example:
        >>> @retry(max_attempts=3, on=(ConnectionError, TimeoutError))
        ... def fetch_data(url: str) -> dict:
        ...     return requests.get(url, timeout=10).json()

        >>> @retry(max_attempts=3, on_retry=lambda a, m, e, d: print(f"Retry {a}/{m}"))
        ... def fragile_operation() -> str:
        ...     return do_something()

    """
    # Snapshot the settings once; both wrappers below close over this
    # immutable config.
    config = RetryConfig(
        max_attempts=max_attempts,
        initial_delay=initial_delay,
        max_delay=max_delay,
        exponential_base=exponential_base,
        jitter=jitter,
        log_retries=log_retries,
        on_retry=on_retry,
    )

    def decorator(
        func: Callable[P, R] | Callable[P, Coroutine[Any, Any, R]],
    ) -> Callable[P, R] | Callable[P, Coroutine[Any, Any, R]]:
        # Coroutine functions get an async wrapper that sleeps without
        # blocking the event loop; everything else gets a sync wrapper.
        if inspect.iscoroutinefunction(func):
            func_coro = cast(Callable[P, Coroutine[Any, Any, R]], func)

            @functools.wraps(func_coro)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                last_exception: Exception | None = None

                for attempt in range(1, max_attempts + 1):  # pragma: no branch
                    try:
                        return await func_coro(*args, **kwargs)
                    except on as e:
                        last_exception = e

                        # Final attempt: log exhaustion and fall through
                        # to _raise_retry_error below.
                        if attempt == max_attempts:
                            _log_all_failed(
                                func_coro.__name__,
                                e,
                                config,
                            )
                            break

                        delay = calculate_delay(attempt, config)
                        _log_retry_attempt(
                            func_coro.__name__,
                            attempt,
                            e,
                            delay,
                            config,
                        )
                        await asyncio.sleep(delay)

                _raise_retry_error(
                    func_coro.__name__,
                    max_attempts,
                    reraise,
                    last_exception,
                )

            return async_wrapper

        func_sync = cast(Callable[P, R], func)

        @functools.wraps(func_sync)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            last_exception: Exception | None = None

            for attempt in range(1, max_attempts + 1):  # pragma: no branch
                try:
                    return func_sync(*args, **kwargs)
                except on as e:
                    last_exception = e

                    # Final attempt: log exhaustion and fall through to
                    # _raise_retry_error below.
                    if attempt == max_attempts:
                        _log_all_failed(
                            func_sync.__name__,
                            e,
                            config,
                        )
                        break

                    # Calculate delay and wait
                    delay = calculate_delay(attempt, config)
                    _log_retry_attempt(
                        func_sync.__name__,
                        attempt,
                        e,
                        delay,
                        config,
                    )
                    time.sleep(delay)

            _raise_retry_error(
                func_sync.__name__,
                max_attempts,
                reraise,
                last_exception,
            )

        return wrapper

    return decorator  # type: ignore[return-value]

362 

363 

def retry_on_exception(
    exception_types: tuple[type[Exception], ...],
    max_attempts: int = 3,
) -> RetryDecorator:
    """Retry on specific exceptions.

    A lightweight shortcut around :func:`retry` for callers that only
    need basic retry behavior (jitter and retry logging are disabled).

    Args:
        exception_types: Exception types to retry on.
        max_attempts: Maximum number of attempts.

    Returns:
        Decorated function with retry logic.

    Example:
        >>> @retry_on_exception((ValueError,), max_attempts=2)
        ... def parse_data(data: str) -> dict:
        ...     return json.loads(data)

    """
    decorator = retry(
        max_attempts=max_attempts,
        on=exception_types,
        jitter=False,
        log_retries=False,
    )
    return decorator

392 

393 

class Retrier:
    """Context manager for retry logic.

    Provides a context manager interface for retry logic when
    decorators are not suitable.

    NOTE(review): a ``with`` statement never re-executes its body after
    ``__exit__`` suppresses an exception, so a single ``with`` block can
    only sleep and swallow one retryable failure — it cannot re-run the
    operation. Additionally, ``__enter__`` resets the attempt counter,
    so re-entering the same instance from a caller-managed loop restarts
    the count each iteration. Confirm the intended usage pattern; for
    actual retries prefer the :func:`retry` decorator.

    Example:
        >>> retrier = Retrier(max_attempts=3, on=(ConnectionError,))
        >>> with retrier:
        ...     result = some_operation()

    """

    def __init__(
        self,
        *,
        max_attempts: int = 3,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        on: tuple[type[Exception], ...] = (Exception,),
    ) -> None:
        """Initialize Retrier.

        Args:
            max_attempts: Maximum retry attempts.
            initial_delay: Initial delay between retries.
            max_delay: Maximum delay between retries.
            on: Exception types to retry on.

        """
        self.config = RetryConfig(
            max_attempts=max_attempts,
            initial_delay=initial_delay,
            max_delay=max_delay,
        )
        # Exception types eligible for suppression; others propagate.
        self.exception_types = on
        # Count of failed attempts so far (reset on __enter__).
        self.attempt = 0
        # Most recent retryable exception, for caller inspection.
        self.last_exception: Exception | None = None

    def __enter__(self) -> "Retrier":
        """Enter the retry context, resetting the attempt state."""
        self.attempt = 0
        self.last_exception = None
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        _exc_tb: TracebackType | None,
    ) -> bool:
        """Exit the retry context.

        Returns True to suppress the exception if we should retry,
        False to let it propagate.
        """
        if exc_type is None:
            return False  # No exception, exit normally

        if not issubclass(exc_type, self.exception_types):
            return False  # Exception type not in retry list

        # Safe cast: issubclass guard above ensures exc_val is Exception
        self.last_exception = exc_val if isinstance(exc_val, Exception) else None
        self.attempt += 1

        if self.attempt >= self.config.max_attempts:
            return False  # Max attempts reached, propagate exception

        # Calculate delay and wait (blocking sleep) before suppressing.
        delay = calculate_delay(self.attempt, self.config)
        time.sleep(delay)

        return True  # Suppress exception and retry