Coverage for src / taipanstack / security / decorators.py: 100%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

1""" 

2Security decorators for robust Python applications. 

3 

4Provides decorators for input validation, exception handling, 

5timeout control, and other security patterns. Compatible with 

6any Python framework (Flask, FastAPI, Django, etc.). 

7""" 

8 

9import functools 

10import inspect 

11import math 

12import signal 

13import sys 

14import threading 

15from collections.abc import Callable 

16from types import FrameType 

17from typing import ParamSpec, Protocol, TypeVar 

18 

19from taipanstack.security.guards import SecurityError 

20 

21P = ParamSpec("P") 

22R = TypeVar("R") 

23T = TypeVar("T") 

24V_contra = TypeVar("V_contra", contravariant=True) 

25V_co = TypeVar("V_co", covariant=True) 

26 

27 

28class ValidatorFunc(Protocol[V_contra, V_co]): 

29 """Protocol defining the signature of input validators.""" 

30 

31 def __call__(self, value: V_contra, /) -> V_co: 

32 """Validate an input value.""" 

33 ... 

34 

35 

class OperationTimeoutError(Exception):
    """Raised when a function exceeds its timeout limit.

    Attributes:
        seconds: The timeout that was exceeded.
        func_name: Name of the function that timed out.
    """

    def __init__(self, seconds: float, func_name: str = "function") -> None:
        """Initialize OperationTimeoutError.

        Args:
            seconds: The timeout that was exceeded.
            func_name: Name of the function that timed out.

        """
        self.seconds = seconds
        self.func_name = func_name
        super().__init__(f"{func_name} timed out after {seconds} seconds")

    def __reduce__(self) -> tuple[object, ...]:
        """Rebuild from ``(seconds, func_name)`` when pickled or copied.

        The inherited ``BaseException.__reduce__`` re-invokes the class with
        ``self.args`` -- the already formatted message -- which ``__init__``
        would interpret as ``seconds``, yielding a garbled message after
        ``pickle`` (e.g. crossing a process boundary) or ``copy.copy``.
        Any extra instance attributes are carried along as state.
        """
        return (type(self), (self.seconds, self.func_name), dict(self.__dict__))

50 

51 

52class ValidationError(Exception): 

53 """Raised when input validation fails.""" 

54 

55 def __init__( 

56 self, 

57 message: str, 

58 param_name: str | None = None, 

59 value: object = None, 

60 ) -> None: 

61 """Initialize ValidationError. 

62 

63 Args: 

64 message: Description of the validation failure. 

65 param_name: Name of the parameter that failed. 

66 value: The invalid value (sanitized). 

67 

68 """ 

69 self.param_name = param_name 

70 self.value = value 

71 super().__init__(message) 

72 

73 

def validate_inputs(
    **validators: ValidatorFunc[object, object],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to validate function inputs.

    Validates function arguments using provided validator functions.
    Validators should raise ValueError or ValidationError on invalid input;
    a ValueError or TypeError raised by a validator is converted into a
    ValidationError carrying the parameter name and a ``repr`` of the
    rejected value truncated to 100 characters. A validator that returns a
    non-None value replaces the argument with that value; returning None
    keeps the original argument.

    Validator names are matched against the decorated function's parameter
    names (defaults are applied first, so defaulted parameters are validated
    too). If the function accepts ``**kwargs``, a name that is not a declared
    parameter is validated whenever the caller passes it as an extra keyword
    argument.

    Args:
        **validators: Mapping of parameter names to validator functions.

    Returns:
        Decorated function with input validation.

    Raises:
        TypeError: At decoration time, if a validator name matches no
            parameter and the function has no ``**kwargs``. Such a validator
            could never run, so it would silently disable the check.

    Example:
        >>> from taipanstack.security.validators import validate_email, validate_port
        >>> @validate_inputs(email=validate_email, port=validate_port)
        ... def connect(email: str, port: int) -> None:
        ...     pass
        >>> connect(email="invalid", port=8080)
        ValidationError: Invalid email format: invalid

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        sig = inspect.signature(func)
        params = sig.parameters
        # Name of the function's **kwargs parameter, if it has one.
        var_kw = next(
            (
                p.name
                for p in params.values()
                if p.kind is inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )
        # Fail closed: a misspelled validator name would otherwise be skipped
        # on every call, silently turning the validation off.
        unknown = sorted(name for name in validators if name not in params)
        if unknown and var_kw is None:
            raise TypeError(
                f"validate_inputs: {func.__name__}() has no parameter(s) "
                f"named {', '.join(unknown)}"
            )

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Bind and apply defaults on each call
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for param_name, validator in validators.items():  # pragma: no branch
                # Locate the mapping that holds this argument: the bound
                # parameters themselves, or the extra-keyword dict.
                if param_name in params:
                    container = bound.arguments
                elif var_kw is not None and param_name in bound.arguments[var_kw]:
                    container = bound.arguments[var_kw]
                else:
                    # Optional extra keyword argument not supplied this call.
                    continue

                value = container[param_name]
                try:
                    # Call validator - it should raise on invalid input
                    validated = validator(value)
                except (ValueError, TypeError) as e:
                    raise ValidationError(
                        str(e),
                        param_name=param_name,
                        value=repr(value)[:100],
                    ) from e
                # Update to validated value if returned
                if validated is not None:
                    container[param_name] = validated

            # Call original function with validated arguments
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator

130 

131 

def guard_exceptions(
    *,
    catch: tuple[type[Exception], ...] = (Exception,),
    reraise_as: type[Exception] | None = None,
    default: T | None = None,
    log_errors: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R | T | None]]:
    """Wrap a function so selected exceptions are translated or suppressed.

    When the wrapped call raises one of ``catch``, the exception is
    optionally logged at warning level on the ``taipanstack.security``
    logger. It is then either re-raised as ``reraise_as`` (chained to the
    original) or replaced by ``default`` as the return value. Exceptions not
    listed in ``catch`` propagate unchanged.

    Args:
        catch: Exception types to intercept.
        reraise_as: Exception type to raise instead (None = don't reraise).
            ``SecurityError`` is constructed with
            ``guard_name="guard_exceptions"``; any other type receives only
            the original message.
        default: Value returned when an exception is intercepted and not
            re-raised.
        log_errors: Whether intercepted exceptions are logged.

    Returns:
        Decorator producing the guarded function.

    Example:
        >>> @guard_exceptions(catch=(IOError,), reraise_as=SecurityError)
        ... def read_file(path: str) -> str:
        ...     return open(path).read()
        >>> read_file("/nonexistent")
        SecurityError: [guard_exceptions] ...

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R | T | None]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
            try:
                return func(*args, **kwargs)
            except catch as exc:
                if log_errors:  # pragma: no branch
                    import logging

                    security_log = logging.getLogger("taipanstack.security")
                    security_log.warning(
                        "Exception caught in %s: %s",
                        func.__name__,
                        str(exc),
                    )

                # Suppression mode: swallow and hand back the fallback.
                if reraise_as is None:
                    return default

                # SecurityError needs the guard name for its message.
                if reraise_as is SecurityError:
                    raise SecurityError(
                        str(exc),
                        guard_name="guard_exceptions",
                    ) from exc
                raise reraise_as(str(exc)) from exc

        return wrapper

    return decorator

190 

191 

def timeout(
    seconds: float,
    *,
    use_signal: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to limit function execution time.

    The strategy is chosen on every call:

    * The signal path interrupts the function with SIGALRM via
      ``signal.setitimer``. It is used when ``use_signal`` is true, the
      platform is not Windows, ``signal.setitimer`` is available, and the
      call is made from the main thread.
    * The thread path is used otherwise. It runs the function in a daemon
      worker thread and stops waiting after ``seconds``. The worker thread
      is NOT stopped and keeps running in the background after a timeout.

    Args:
        seconds: Maximum execution time in seconds (finite and >= 0).
        use_signal: Prefer the signal-based strategy when it is available.

    Returns:
        Decorated function with timeout.

    Raises:
        ValueError: If ``seconds`` is NaN, infinite, or negative.

    The decorated function raises ``OperationTimeoutError`` when a call
    exceeds ``seconds``.

    Example:
        >>> @timeout(5.0)
        ... def slow_operation() -> str:
        ...     import time
        ...     time.sleep(10)
        ...     return "done"
        >>> slow_operation()
        OperationTimeoutError: slow_operation timed out after 5.0 seconds

    """
    # Security Enhancement: explicitly validate bounds using math.isfinite()
    # and check for non-negative limits to prevent silent NaN propagation,
    # unhandled ValueError exceptions from threading/asyncio primitives,
    # or unexpected infinite blocking behaviors.
    if not (math.isfinite(seconds) and seconds >= 0):
        raise ValueError("timeout must be a finite non-negative number")

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Decided per call: the calling thread can differ between calls.
            can_use_signal = (
                use_signal
                and sys.platform != "win32"
                # signal.setitimer is documented as Unix-only; guard against
                # non-Windows platforms that still omit it.
                and hasattr(signal, "setitimer")
                and threading.current_thread() is threading.main_thread()
            )

            if can_use_signal:
                return _timeout_with_signal(func, seconds, args, dict(kwargs))
            return _timeout_with_thread(func, seconds, args, dict(kwargs))

        return wrapper

    return decorator

253 

254 

def _timeout_with_signal(
    func: Callable[..., R],
    seconds: float,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> R:
    """Implement timeout using Unix signals (SIGALRM + ITIMER_REAL).

    Must run in the main thread, because Python only allows the main thread
    to install signal handlers. If another ITIMER_REAL timer was already
    armed (for example by an enclosing ``timeout``), it is re-armed with its
    remaining time when this call finishes, instead of being cancelled.
    While this call is active, the enclosing timer is suspended. If its
    deadline passes in the meantime, it fires right after this call ends.

    Raises:
        OperationTimeoutError: If ``func`` is still running after
            ``seconds``. Also raised immediately, without calling ``func``,
            when ``seconds`` is not positive.
    """
    import time

    if seconds <= 0:
        # setitimer(ITIMER_REAL, 0) *disarms* the timer, so a zero budget
        # would otherwise let func run with no limit at all (fail-open).
        raise OperationTimeoutError(seconds, func.__name__)

    def handler(_signum: int, _frame: FrameType | None) -> None:
        raise OperationTimeoutError(seconds, func.__name__)

    # Set up signal handler; setitimer returns the previously armed
    # (delay, interval) so an outer timer can be restored afterwards.
    old_handler = signal.signal(signal.SIGALRM, handler)
    outer_delay, outer_interval = signal.setitimer(signal.ITIMER_REAL, seconds)
    started = time.monotonic()

    try:
        return func(*args, **kwargs)
    finally:
        # Cancel our alarm and restore the previous handler
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)
        if outer_delay > 0:
            # Re-arm the enclosing timer with what is left of its budget; if
            # it already expired, fire (almost) immediately.
            remaining = outer_delay - (time.monotonic() - started)
            signal.setitimer(
                signal.ITIMER_REAL,
                max(remaining, 1e-6),
                outer_interval,
            )

276 

277 

def _timeout_with_thread(
    func: Callable[..., R],
    seconds: float,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> R:
    """Run *func* in a daemon thread and wait at most *seconds* for it.

    The worker cannot be interrupted. On timeout the caller receives
    OperationTimeoutError while the worker keeps running in the background.
    Anything *func* raises (BaseException subclasses included) is re-raised
    in the calling thread.
    """
    returned: list[R] = []
    raised: list[BaseException] = []

    def target() -> None:
        # Record the outcome so the waiting caller can pick it up.
        try:
            value = func(*args, **kwargs)
        except BaseException as exc:
            raised.append(exc)
        else:
            returned.append(value)

    worker = threading.Thread(target=target, daemon=True)
    worker.start()
    worker.join(timeout=seconds)

    # Still alive after the join deadline: the budget was exceeded.
    if worker.is_alive():
        raise OperationTimeoutError(seconds, func.__name__)

    if raised:
        raise raised[0]

    return returned[0]

307 

308 

def deprecated(
    message: str = "",
    *,
    removal_version: str | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Mark a function as deprecated.

    Every call to the decorated function emits a ``DeprecationWarning``,
    attributed to the caller (``stacklevel=2``), before delegating to the
    original function.

    Args:
        message: Extra guidance appended to the warning (e.g. what to use
            instead). Omitted when empty.
        removal_version: Version in which the function will be removed;
            mentioned in the warning when given.

    Returns:
        Decorator producing a function that warns on every call.

    Example:
        >>> @deprecated("Use new_function instead", removal_version="2.0")
        ... def old_function() -> None:
        ...     pass

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            import warnings

            # Sentences are joined with single spaces into one message.
            parts = [f"{func.__name__} is deprecated."]
            if removal_version:
                parts.append(f"Will be removed in version {removal_version}.")
            if message:
                parts.append(message)

            warnings.warn(" ".join(parts), DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator

349 

350 

def require_type(
    **type_hints: type,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to enforce runtime type checking.

    Validates that arguments match specified types at runtime with
    ``isinstance``. Defaults are applied before checking, so default values
    are checked too. Names are matched against the decorated function's
    parameters. If the function accepts ``**kwargs``, other names are
    checked against extra keyword arguments whenever the caller supplies
    them.

    Args:
        **type_hints: Mapping of parameter names to expected types (any
            target ``isinstance`` accepts, e.g. a class, a tuple of classes,
            or ``int | str``).

    Returns:
        Decorated function with type checking.

    Raises:
        TypeError: At decoration time, if a name matches no parameter and
            the function has no ``**kwargs`` (the check could never run).
            At call time, if an argument has the wrong type.

    Example:
        >>> @require_type(name=str, count=int)
        ... def greet(name: str, count: int) -> None:
        ...     print(f"Hello {name}" * count)
        >>> greet(name=123, count=2)
        TypeError: Parameter 'name' expected str, got int

    """

    def _type_label(expected: object) -> str:
        """Return a readable name for an isinstance() target."""
        # Tuples and unions have no usable __name__; describe them instead
        # of crashing with AttributeError while building the error message.
        if isinstance(expected, tuple):
            return " | ".join(_type_label(item) for item in expected)
        name = getattr(expected, "__name__", None)
        return name if isinstance(name, str) else repr(expected)

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        sig = inspect.signature(func)
        params = sig.parameters
        # Name of the function's **kwargs parameter, if it has one.
        var_kw = next(
            (
                p.name
                for p in params.values()
                if p.kind is inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )
        # Fail closed: a misspelled name would otherwise never be checked.
        unknown = sorted(name for name in type_hints if name not in params)
        if unknown and var_kw is None:
            raise TypeError(
                f"require_type: {func.__name__}() has no parameter(s) "
                f"named {', '.join(unknown)}"
            )

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for param_name, expected_type in type_hints.items():  # pragma: no branch
                if param_name in params:
                    value = bound.arguments[param_name]
                elif var_kw is not None and param_name in bound.arguments[var_kw]:
                    value = bound.arguments[var_kw][param_name]
                else:
                    # Optional extra keyword argument not supplied this call.
                    continue

                if not isinstance(value, expected_type):
                    raise TypeError(
                        f"Parameter '{param_name}' expected "
                        f"{_type_label(expected_type)}, got {type(value).__name__}"
                    )

            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator