Coverage for src / taipanstack / security / decorators.py: 100%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 14:54 +0000

1""" 

2Security decorators for robust Python applications. 

3 

4Provides decorators for input validation, exception handling, 

5timeout control, and other security patterns. Compatible with 

6any Python framework (Flask, FastAPI, Django, etc.). 

7""" 

8 

9import functools 

10import inspect 

11import signal 

12import sys 

13import threading 

14from collections.abc import Callable 

15from types import FrameType 

16from typing import Any, ParamSpec, Protocol, TypeVar 

17 

18from taipanstack.security.guards import SecurityError 

19 

20P = ParamSpec("P") 

21R = TypeVar("R") 

22T = TypeVar("T") 

23V_contra = TypeVar("V_contra", contravariant=True) 

24V_co = TypeVar("V_co", covariant=True) 

25 

26 

class ValidatorFunc(Protocol[V_contra, V_co]):
    """Structural type for a callable that checks one value.

    Any callable accepting a single positional argument and returning a
    result satisfies this protocol; no inheritance is required.
    """

    def __call__(self, value: V_contra, /) -> V_co:
        """Check ``value`` and return the validated result."""

33 

34 

class OperationTimeoutError(Exception):
    """Signal that a wrapped callable ran past its allotted time budget."""

    def __init__(self, seconds: float, func_name: str = "function") -> None:
        """Record the exceeded limit and the name of the offending callable.

        Args:
            seconds: Time limit, in seconds, that was exceeded.
            func_name: Name of the callable that ran too long.

        """
        text = f"{func_name} timed out after {seconds} seconds"
        self.seconds = seconds
        self.func_name = func_name
        super().__init__(text)

49 

50 

51class ValidationError(Exception): 

52 """Raised when input validation fails.""" 

53 

54 def __init__( 

55 self, 

56 message: str, 

57 param_name: str | None = None, 

58 value: object = None, 

59 ) -> None: 

60 """Initialize ValidationError. 

61 

62 Args: 

63 message: Description of the validation failure. 

64 param_name: Name of the parameter that failed. 

65 value: The invalid value (sanitized). 

66 

67 """ 

68 self.param_name = param_name 

69 self.value = value 

70 super().__init__(message) 

71 

72 

def validate_inputs(
    **validators: ValidatorFunc[Any, Any],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Run per-parameter validators before invoking the decorated function.

    On every call the arguments are bound to the function's signature (with
    defaults filled in), and each named validator is applied to its
    parameter. A validator signals rejection by raising ``ValueError`` or
    ``TypeError``; either is converted into ``ValidationError``, carrying the
    parameter name and a repr of the value truncated to 100 characters.
    Validators may instead raise ``ValidationError`` directly, which
    propagates unchanged. A non-``None`` return value replaces the argument
    passed on to the function.

    Args:
        **validators: Parameter name -> validator callable.

    Returns:
        A decorator that adds the validation step.

    Example:
        >>> from taipanstack.security.validators import validate_email, validate_port
        >>> @validate_inputs(email=validate_email, port=validate_port)
        ... def connect(email: str, port: int) -> None:
        ...     pass
        >>> connect(email="invalid", port=8080)
        ValidationError: Invalid email format: invalid

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            call = signature.bind(*args, **kwargs)
            call.apply_defaults()
            arguments = call.arguments

            for name, check in validators.items():  # pragma: no branch
                # Validators keyed on unknown parameter names are ignored.
                if name not in arguments:  # pragma: no cover
                    continue
                original = arguments[name]
                try:
                    cleaned = check(original)
                except (ValueError, TypeError) as exc:
                    raise ValidationError(
                        str(exc),
                        param_name=name,
                        value=repr(original)[:100],
                    ) from exc
                if cleaned is not None:  # pragma: no branch
                    arguments[name] = cleaned

            return func(*call.args, **call.kwargs)

        return wrapper

    return decorator

129 

130 

def guard_exceptions(
    *,
    catch: tuple[type[Exception], ...] = (Exception,),
    reraise_as: type[Exception] | None = None,
    default: T | None = None,
    log_errors: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R | T | None]]:
    """Wrap a function so selected exceptions are logged and translated.

    When the wrapped function raises one of ``catch``, the error is
    optionally logged to the ``taipanstack.security`` logger. It is then
    either converted into ``reraise_as`` (chained to the original) or
    swallowed, in which case ``default`` is returned. ``SecurityError`` is
    built with ``guard_name="guard_exceptions"``; any other target type
    receives only the message text.

    Args:
        catch: Exception types to intercept.
        reraise_as: Type to convert intercepted errors into; ``None`` means
            swallow them and return ``default``.
        default: Value returned when an error is swallowed.
        log_errors: Whether to emit a warning log for each intercepted error.

    Returns:
        A decorator adding the exception guard.

    Example:
        >>> @guard_exceptions(catch=(IOError,), reraise_as=SecurityError)
        ... def read_file(path: str) -> str:
        ...     return open(path).read()
        >>> read_file("/nonexistent")
        SecurityError: [guard_exceptions] ...

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R | T | None]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
            try:
                return func(*args, **kwargs)
            except catch as caught:
                if log_errors:  # pragma: no branch
                    import logging

                    security_log = logging.getLogger("taipanstack.security")
                    security_log.warning(
                        "Exception caught in %s: %s",
                        func.__name__,
                        str(caught),
                    )

                if reraise_as is None:
                    return default
                if reraise_as is SecurityError:
                    raise SecurityError(
                        str(caught),
                        guard_name="guard_exceptions",
                    ) from caught
                raise reraise_as(str(caught)) from caught

        return wrapper

    return decorator

189 

190 

def timeout(
    seconds: float,
    *,
    use_signal: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to limit function execution time.

    Uses a SIGALRM-based timer when possible: ``use_signal`` enabled, a
    non-Windows platform, and a call made from the main thread. The timer's
    handler raises inside the running function. Otherwise the function runs
    in a daemon worker thread, and the caller merely stops waiting for it.
    A function that times out on the thread path keeps running in the
    background.

    Args:
        seconds: Maximum execution time in seconds. Must be positive.
        use_signal: Use signal-based timeout (Unix only, main thread only).

    Returns:
        Decorated function with timeout.

    Raises:
        ValueError: If ``seconds`` is not a positive number. This is raised
            when the decorator is created, not when the function is called.

    Example:
        >>> @timeout(5.0)
        ... def slow_operation() -> str:
        ...     import time
        ...     time.sleep(10)
        ...     return "done"
        >>> slow_operation()
        OperationTimeoutError: slow_operation timed out after 5.0 seconds

    """
    # A zero delay passed to setitimer *disarms* the timer (no limit at all),
    # while the thread backend would time out immediately. Reject such values
    # up front so both backends behave consistently. ``not x > 0`` also
    # rejects NaN.
    if not seconds > 0:
        msg = f"timeout seconds must be positive, got {seconds!r}"
        raise ValueError(msg)

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Signals are only usable on POSIX, and only from the main thread.
            can_use_signal = (
                use_signal
                and sys.platform != "win32"
                and threading.current_thread() is threading.main_thread()
            )

            if can_use_signal:  # pragma: no cover
                return _timeout_with_signal(
                    func,
                    seconds,
                    args,
                    dict(kwargs),
                )
            return _timeout_with_thread(
                func,
                seconds,
                args,
                dict(kwargs),
            )

        return wrapper

    return decorator

246 

247 

def _timeout_with_signal(  # pragma: no cover
    func: Callable[P, R],
    seconds: float,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> R:
    """Implement timeout using Unix signals.

    Installs a SIGALRM handler that raises ``OperationTimeoutError`` and
    arms ``ITIMER_REAL`` for ``seconds``. Any real-time timer that was
    already armed, such as one from an enclosing ``@timeout``, is captured
    and re-armed with its remaining time afterwards. Nested timeouts
    therefore no longer silently cancel the outer one. The previous SIGALRM
    handler is always restored.

    Args:
        func: Callable to run.
        seconds: Time limit in seconds.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        OperationTimeoutError: If ``func`` is still running after ``seconds``.

    """
    import time

    def handler(_signum: int, _frame: FrameType | None) -> None:
        raise OperationTimeoutError(seconds, func.__name__)

    old_handler = signal.signal(signal.SIGALRM, handler)
    outer_delay, outer_interval = 0.0, 0.0
    started = time.monotonic()
    try:
        # setitimer returns the previously armed (delay, interval) pair. Arm
        # it inside ``try`` so the handler is restored even if this fails.
        outer_delay, outer_interval = signal.setitimer(
            signal.ITIMER_REAL, seconds
        )
        return func(*args, **kwargs)
    finally:
        # Disarm first so our handler cannot fire after it is swapped out.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)
        if outer_delay > 0:
            # Re-arm the pre-existing timer with the time it had left. If its
            # deadline already passed, fire (almost) immediately.
            remaining = outer_delay - (time.monotonic() - started)
            signal.setitimer(
                signal.ITIMER_REAL, max(remaining, 1e-6), outer_interval
            )

269 

270 

def _timeout_with_thread(
    func: Callable[P, R],
    seconds: float,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> R:
    """Implement timeout using a separate daemon thread.

    The caller waits up to ``seconds`` for the worker. The worker is not
    stopped on timeout: it keeps running in the background until ``func``
    returns on its own (it is a daemon, so it will not block interpreter
    exit). Anything ``func`` raises, including ``BaseException`` subclasses
    such as ``SystemExit``, is re-raised in the calling thread.

    Args:
        func: Callable to run.
        seconds: Time limit in seconds.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        OperationTimeoutError: If the worker is still alive after ``seconds``.

    """
    result: list[R] = []
    exception: list[BaseException] = []

    def target() -> None:
        try:
            result.append(func(*args, **kwargs))
        except BaseException as e:  # re-raised in the caller below
            # Catching only Exception let e.g. SystemExit end the thread
            # silently, leaving ``result`` empty and causing an IndexError.
            exception.append(e)

    thread = threading.Thread(target=target, daemon=True)
    thread.start()
    thread.join(timeout=seconds)

    if thread.is_alive():
        # Thread still running - timeout occurred.
        raise OperationTimeoutError(seconds, func.__name__)

    if exception:
        raise exception[0]

    return result[0]

300 

301 

def deprecated(
    message: str = "",
    *,
    removal_version: str | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Flag a callable as deprecated.

    Every call to the decorated function emits a ``DeprecationWarning``,
    attributed to the caller's frame, before delegating to the original.

    Args:
        message: Extra text appended to the warning (e.g. the replacement).
        removal_version: Version in which the function is slated for removal.

    Returns:
        A decorator producing a warning-emitting wrapper.

    Example:
        >>> @deprecated("Use new_function instead", removal_version="2.0")
        ... def old_function() -> None:
        ...     pass

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            import warnings

            msg = "".join(
                [
                    f"{func.__name__} is deprecated.",
                    f" Will be removed in version {removal_version}."
                    if removal_version
                    else "",
                    f" {message}" if message else "",
                ]
            )
            # stacklevel=2 points the warning at the code calling the wrapper.
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator

342 

343 

def require_type(
    **type_hints: type | tuple[type, ...],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to enforce runtime type checking.

    Validates, via ``isinstance``, that bound arguments (defaults included)
    match the specified types at runtime. Each expected type may be
    anything ``isinstance`` accepts: a class, a tuple of classes, or a
    ``X | Y`` union. Names that do not match a bound parameter are ignored.

    Args:
        **type_hints: Mapping of parameter names to expected types.

    Returns:
        Decorated function with type checking.

    Raises:
        TypeError: At call time, if an argument is not an instance of its
            expected type.

    Example:
        >>> @require_type(name=str, count=int)
        ... def greet(name: str, count: int) -> None:
        ...     print(f"Hello {name}" * count)
        >>> greet(name=123, count=2)
        TypeError: Parameter 'name' expected str, got int

    """

    def _describe(expected: object) -> str:
        # Plain classes keep the "expected str" wording. Tuples and unions
        # have no ``__name__``, so render them readably instead of letting an
        # AttributeError mask the intended TypeError.
        if isinstance(expected, type):
            return expected.__name__
        if isinstance(expected, tuple):
            return " | ".join(_describe(item) for item in expected)
        return repr(expected)

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for param_name, expected_type in type_hints.items():  # pragma: no branch
                if param_name in bound.arguments:  # pragma: no branch
                    value = bound.arguments[param_name]
                    if not isinstance(value, expected_type):
                        raise TypeError(
                            f"Parameter '{param_name}' expected "
                            f"{_describe(expected_type)}, got {type(value).__name__}"
                        )

            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator