Coverage for src / taipanstack / security / decorators.py: 100%
129 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
1"""
2Security decorators for robust Python applications.
4Provides decorators for input validation, exception handling,
5timeout control, and other security patterns. Compatible with
6any Python framework (Flask, FastAPI, Django, etc.).
7"""
9import functools
10import inspect
11import math
12import signal
13import sys
14import threading
15from collections.abc import Callable
16from types import FrameType
17from typing import ParamSpec, Protocol, TypeVar
19from taipanstack.security.guards import SecurityError
21P = ParamSpec("P")
22R = TypeVar("R")
23T = TypeVar("T")
24V_contra = TypeVar("V_contra", contravariant=True)
25V_co = TypeVar("V_co", covariant=True)
class ValidatorFunc(Protocol[V_contra, V_co]):
    """Structural type for the validators accepted by ``validate_inputs``.

    Any callable taking a single positional-only value and returning a
    (possibly transformed) value satisfies this protocol.
    """

    def __call__(self, value: V_contra, /) -> V_co:
        """Check ``value`` and return it, or a replacement for it."""
class OperationTimeoutError(Exception):
    """Signals that a call ran past the time limit set by ``timeout``.

    Attributes:
        seconds: The time limit that was exceeded.
        func_name: Name of the callable that exceeded it.
    """

    def __init__(self, seconds: float, func_name: str = "function") -> None:
        """Build the error and its human-readable message.

        Args:
            seconds: The time limit that was exceeded.
            func_name: Name of the callable that exceeded it.
        """
        text = f"{func_name} timed out after {seconds} seconds"
        super().__init__(text)
        self.seconds = seconds
        self.func_name = func_name
52class ValidationError(Exception):
53 """Raised when input validation fails."""
55 def __init__(
56 self,
57 message: str,
58 param_name: str | None = None,
59 value: object = None,
60 ) -> None:
61 """Initialize ValidationError.
63 Args:
64 message: Description of the validation failure.
65 param_name: Name of the parameter that failed.
66 value: The invalid value (sanitized).
68 """
69 self.param_name = param_name
70 self.value = value
71 super().__init__(message)
def validate_inputs(
    **validators: ValidatorFunc[object, object],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Run per-parameter validators before calling the decorated function.

    On every call, the arguments are bound to the function's signature, with
    defaults filled in. Each registered validator then receives its
    parameter's value:

    * If the validator returns something other than ``None``, that result
      replaces the argument.
    * If it raises ``ValueError`` or ``TypeError``, the error is re-raised as
      ``ValidationError``, carrying the parameter name and a ``repr`` of the
      value truncated to 100 characters.
    * Any other exception (including ``ValidationError``) propagates
      unchanged.

    NOTE: a validator keyed by a name that is not a parameter of the
    decorated function is silently skipped.

    Args:
        **validators: Parameter name -> validator callable.

    Returns:
        A decorator that wraps a function with these checks.

    Example:
        >>> def positive(value: int) -> int:
        ...     if value <= 0:
        ...         raise ValueError("must be positive")
        ...     return value
        >>> @validate_inputs(port=positive)
        ... def connect(port: int) -> None:
        ...     pass
        >>> connect(port=0)
        ValidationError: must be positive
    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        # Computed once per decorated function, reused on every call.
        signature = inspect.signature(func)

        @functools.wraps(func)
        def checked(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            arguments = bound.arguments  # mutated in place below

            for name, check in validators.items():  # pragma: no branch
                if name not in arguments:  # pragma: no branch
                    continue
                original = arguments[name]
                try:
                    replacement = check(original)
                except (ValueError, TypeError) as err:
                    raise ValidationError(
                        str(err),
                        param_name=name,
                        value=repr(original)[:100],
                    ) from err
                # ``None`` means "accepted as-is"; anything else is substituted.
                if replacement is not None:  # pragma: no branch
                    arguments[name] = replacement

            return func(*bound.args, **bound.kwargs)

        return checked

    return decorator
def guard_exceptions(
    *,
    catch: tuple[type[Exception], ...] = (Exception,),
    reraise_as: type[Exception] | None = None,
    default: T | None = None,
    log_errors: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R | T | None]]:
    """Convert selected exceptions into a fallback value or another exception.

    When the wrapped function raises one of ``catch``, the exception is first
    logged (if enabled) as a warning on the ``taipanstack.security`` logger.
    Then:

    * if ``reraise_as`` is ``None``, ``default`` is returned;
    * if ``reraise_as`` is exactly ``SecurityError``, a ``SecurityError`` is
      raised with ``guard_name="guard_exceptions"``;
    * otherwise ``reraise_as(str(exc))`` is raised.

    In both re-raise cases the original exception is chained as the cause.
    Exceptions not listed in ``catch`` propagate untouched.

    Args:
        catch: Exception types to intercept.
        reraise_as: Exception type to raise instead, or ``None`` to swallow.
        default: Value returned when an exception is swallowed.
        log_errors: Whether intercepted exceptions are logged.

    Returns:
        A decorator applying this handling to a function.

    Example:
        >>> @guard_exceptions(catch=(IOError,), reraise_as=SecurityError)
        ... def read_file(path: str) -> str:
        ...     return open(path).read()
        >>> read_file("/nonexistent")
        SecurityError: ...
    """

    def decorator(func: Callable[P, R]) -> Callable[P, R | T | None]:
        @functools.wraps(func)
        def guarded(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
            try:
                return func(*args, **kwargs)
            except catch as exc:
                if log_errors:  # pragma: no branch
                    import logging

                    logger = logging.getLogger("taipanstack.security")
                    logger.warning(
                        "Exception caught in %s: %s",
                        func.__name__,
                        str(exc),
                    )

                if reraise_as is None:
                    return default

                # SecurityError takes an extra guard_name describing the source.
                if reraise_as == SecurityError:
                    raise SecurityError(
                        str(exc),
                        guard_name="guard_exceptions",
                    ) from exc
                raise reraise_as(str(exc)) from exc

        return guarded

    return decorator
def timeout(
    seconds: float,
    *,
    use_signal: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Bound how long the decorated function may run.

    Each call picks a strategy:

    * ``SIGALRM``-based, used only when ``use_signal`` is true, the platform
      is not Windows, and the call is made from the main thread;
    * otherwise thread-based: the function runs in a daemon worker thread
      that is waited on for up to ``seconds``. The worker is not stopped,
      and may keep running after the timeout is reported.

    Args:
        seconds: Maximum execution time in seconds.
        use_signal: Allow the signal-based strategy when it is available.

    Returns:
        A decorator enforcing the limit.

    Raises:
        ValueError: At decoration time, if ``seconds`` is NaN, infinite, or
            negative.

    Example:
        >>> @timeout(5.0)
        ... def slow_operation() -> str:
        ...     import time
        ...     time.sleep(10)
        ...     return "done"
        >>> slow_operation()
        OperationTimeoutError: slow_operation timed out after 5.0 seconds
    """
    # Reject NaN/inf/negative bounds up front. Otherwise they would propagate
    # silently, surface as obscure errors from the timing primitives, or block
    # forever.
    if not math.isfinite(seconds) or seconds < 0:
        raise ValueError("timeout must be a finite non-negative number")

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Signals are only usable on POSIX, and only from the main thread.
            signal_ok = (
                use_signal
                and sys.platform != "win32"
                and threading.current_thread() is threading.main_thread()
            )
            runner = _timeout_with_signal if signal_ok else _timeout_with_thread
            return runner(func, seconds, args, dict(kwargs))

        return wrapper

    return decorator
255def _timeout_with_signal(
256 func: Callable[..., R],
257 seconds: float,
258 args: tuple[object, ...],
259 kwargs: dict[str, object],
260) -> R:
261 """Implement timeout using Unix signals."""
263 def handler(_signum: int, _frame: FrameType | None) -> None:
264 raise OperationTimeoutError(seconds, func.__name__)
266 # Set up signal handler
267 old_handler = signal.signal(signal.SIGALRM, handler)
268 signal.setitimer(signal.ITIMER_REAL, seconds)
270 try:
271 return func(*args, **kwargs)
272 finally:
273 # Restore old handler and cancel alarm
274 signal.setitimer(signal.ITIMER_REAL, 0)
275 signal.signal(signal.SIGALRM, old_handler)
278def _timeout_with_thread(
279 func: Callable[..., R],
280 seconds: float,
281 args: tuple[object, ...],
282 kwargs: dict[str, object],
283) -> R:
284 """Implement timeout using a separate thread."""
285 result: list[R] = []
286 exception: list[BaseException] = []
288 def target() -> None:
289 try:
290 result.append(func(*args, **kwargs))
291 except BaseException as e:
292 exception.append(e)
294 thread = threading.Thread(target=target)
295 thread.daemon = True
296 thread.start()
297 thread.join(timeout=seconds)
299 if thread.is_alive():
300 # Thread still running - timeout occurred
301 raise OperationTimeoutError(seconds, func.__name__)
303 if exception:
304 raise exception[0]
306 return result[0]
def deprecated(
    message: str = "",
    *,
    removal_version: str | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Flag a function as deprecated.

    Every call emits a ``DeprecationWarning`` of the form
    ``"<name> is deprecated. Will be removed in version <v>. <message>"``.
    The removal-version and message parts are omitted when empty. The warning
    is attributed to the caller (``stacklevel=2``).

    Args:
        message: Extra guidance appended to the warning.
        removal_version: Version in which the function will be removed.

    Returns:
        A decorator that adds the warning.

    Example:
        >>> @deprecated("Use new_function instead", removal_version="2.0")
        ... def old_function() -> None:
        ...     pass
    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            import warnings

            pieces = [f"{func.__name__} is deprecated."]
            if removal_version:
                pieces.append(f"Will be removed in version {removal_version}.")
            if message:
                pieces.append(f"{message}")
            warnings.warn(" ".join(pieces), DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
def _type_display_name(expected: object) -> str:
    """Return a readable name for an ``isinstance`` target, for error messages.

    Plain classes use ``__name__``, keeping the original message format.
    Tuples of types are joined with ``" | "``. Anything else (for example
    ``int | str`` union objects) falls back to ``repr``.
    """
    if isinstance(expected, type):
        return expected.__name__
    if isinstance(expected, tuple):
        return " | ".join(_type_display_name(item) for item in expected)
    return repr(expected)


def require_type(
    **type_hints: type | tuple[type, ...],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to enforce runtime type checking.

    On every call, arguments are bound to the function's signature (with
    defaults applied). Each named parameter is then checked with
    ``isinstance`` against its expected type, which may be a class or a
    tuple of classes. Subclass instances pass: ``isinstance(True, int)``
    holds, for example. Names that are not parameters of the function are
    ignored.

    Args:
        **type_hints: Mapping of parameter names to expected types.

    Returns:
        Decorated function with type checking.

    Raises:
        TypeError: When an argument does not match its expected type.

    Example:
        >>> @require_type(name=str, count=int)
        ... def greet(name: str, count: int) -> None:
        ...     print(f"Hello {name}" * count)
        >>> greet(name=123, count=2)
        TypeError: Parameter 'name' expected str, got int
    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for param_name, expected_type in type_hints.items():  # pragma: no branch
                if param_name in bound.arguments:  # pragma: no branch
                    value = bound.arguments[param_name]
                    if not isinstance(value, expected_type):
                        # Tuples have no __name__, so format via the helper.
                        raise TypeError(
                            f"Parameter '{param_name}' expected "
                            f"{_type_display_name(expected_type)}, "
                            f"got {type(value).__name__}"
                        )

            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator