Coverage for src / taipanstack / security / decorators.py: 100%
115 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
1"""
2Security decorators for robust Python applications.
4Provides decorators for input validation, exception handling,
5timeout control, and other security patterns. Compatible with
6any Python framework (Flask, FastAPI, Django, etc.).
7"""
9import functools
10import inspect
11import signal
12import sys
13import threading
14from collections.abc import Callable
15from types import FrameType
16from typing import Any, ParamSpec, Protocol, TypeVar
18from taipanstack.security.guards import SecurityError
20P = ParamSpec("P")
21R = TypeVar("R")
22T = TypeVar("T")
23V_contra = TypeVar("V_contra", contravariant=True)
24V_co = TypeVar("V_co", covariant=True)
class ValidatorFunc(Protocol[V_contra, V_co]):
    """Structural type for callables that can serve as input validators.

    A conforming object is called with a single positional-only value
    and hands back a result.
    """

    def __call__(self, value: V_contra, /) -> V_co:
        """Validate ``value`` and return the validator's result."""
class OperationTimeoutError(Exception):
    """Raised when a function exceeds its timeout limit."""

    def __init__(self, seconds: float, func_name: str = "function") -> None:
        """Initialize OperationTimeoutError.

        Args:
            seconds: The timeout that was exceeded.
            func_name: Name of the function that timed out.

        """
        self.seconds = seconds
        self.func_name = func_name
        super().__init__(f"{func_name} timed out after {seconds} seconds")

    def __reduce__(self) -> tuple[Any, ...]:
        """Describe how to rebuild this exception for ``copy`` and ``pickle``.

        The inherited implementation re-creates the instance by calling the
        class with ``self.args``, which here holds the already formatted
        message. That message would be received as ``seconds`` and formatted
        into a second, garbled message. Passing the original constructor
        arguments instead keeps the text intact, e.g. when the error crosses
        a process boundary.

        Returns:
            A ``(callable, args, state)`` triple; ``state`` carries the
            instance attributes so extra attributes survive as well.

        """
        return (type(self), (self.seconds, self.func_name), dict(vars(self)))
class ValidationError(Exception):
    """Signals that an input value was rejected by validation."""

    def __init__(
        self,
        message: str,
        param_name: str | None = None,
        value: object = None,
    ) -> None:
        """Build the error and remember which input was rejected.

        Args:
            message: Reason for the failure; becomes the exception text.
            param_name: Name of the offending parameter, if known.
            value: The rejected value as handed in by the caller; it is
                stored unchanged.

        """
        super().__init__(message)
        self.param_name = param_name
        self.value = value
def validate_inputs(
    **validators: ValidatorFunc[Any, Any],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Build a decorator that runs validators over selected arguments.

    On every call the arguments are bound to the wrapped function's
    signature (defaults included). Each parameter that has a validator is
    passed through it. A validator signals rejection by raising
    ``ValueError`` or ``TypeError``, which is re-raised as
    ``ValidationError``. A non-``None`` return value replaces the argument
    before the wrapped function runs.

    Args:
        **validators: Validator callables keyed by parameter name. Names
            that are not among the bound arguments are ignored.

    Returns:
        A decorator producing a validating wrapper around its target.

    Raises:
        ValidationError: From the wrapper, when a validator raises
            ``ValueError`` or ``TypeError``. The error records the
            parameter name and the first 100 characters of the value's
            ``repr``.

    Example:
        >>> from taipanstack.security.validators import validate_email, validate_port
        >>> @validate_inputs(email=validate_email, port=validate_port)
        ... def connect(email: str, port: int) -> None:
        ...     pass
        >>> connect(email="invalid", port=8080)
        ValidationError: Invalid email format: invalid

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Re-bind on every call so defaults and keywords are resolved
            # to parameter names before validation.
            call = signature.bind(*args, **kwargs)
            call.apply_defaults()

            for name, check in validators.items():  # pragma: no branch
                if name not in call.arguments:  # pragma: no branch
                    continue
                raw = call.arguments[name]
                try:
                    cleaned = check(raw)
                except (ValueError, TypeError) as exc:
                    raise ValidationError(
                        str(exc),
                        param_name=name,
                        value=repr(raw)[:100],
                    ) from exc
                # A validator may hand back a normalised value to use instead.
                if cleaned is not None:  # pragma: no branch
                    call.arguments[name] = cleaned

            return func(*call.args, **call.kwargs)

        return wrapper

    return decorator
def guard_exceptions(
    *,
    catch: tuple[type[Exception], ...] = (Exception,),
    reraise_as: type[Exception] | None = None,
    default: T | None = None,
    log_errors: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R | T | None]]:
    """Build a decorator that contains exceptions raised by a function.

    A caught exception is optionally logged, then either translated into
    another exception type or replaced by a fallback return value.

    Args:
        catch: Exception types the wrapper intercepts; anything else
            propagates untouched.
        reraise_as: Exception type to raise in place of the caught one,
            chained to it. ``None`` means return ``default`` instead.
        default: Value returned when an exception is caught and not
            re-raised.
        log_errors: When true, a warning is emitted on the
            ``taipanstack.security`` logger for each caught exception.

    Returns:
        A decorator producing the guarded wrapper.

    Example:
        >>> @guard_exceptions(catch=(IOError,), reraise_as=SecurityError)
        ... def read_file(path: str) -> str:
        ...     return open(path).read()
        >>> read_file("/nonexistent")  # raises SecurityError

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R | T | None]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | T | None:
            try:
                outcome = func(*args, **kwargs)
            except catch as exc:
                if log_errors:  # pragma: no branch
                    import logging

                    logger = logging.getLogger("taipanstack.security")
                    logger.warning(
                        "Exception caught in %s: %s",
                        func.__name__,
                        str(exc),
                    )

                if reraise_as is None:
                    return default

                # SecurityError is built with an extra guard_name argument,
                # so it is constructed separately from other types.
                if reraise_as == SecurityError:
                    raise SecurityError(
                        str(exc),
                        guard_name="guard_exceptions",
                    ) from exc
                raise reraise_as(str(exc)) from exc
            return outcome

        return wrapper

    return decorator
def timeout(
    seconds: float,
    *,
    use_signal: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Build a decorator that caps how long a function may run.

    The strategy is chosen on every call. The signal-based implementation
    is used when ``use_signal`` is true, the platform is not Windows and
    the call happens on the main thread. In every other case the function
    runs in a helper thread that is waited on for ``seconds``.

    Args:
        seconds: Maximum execution time in seconds.
        use_signal: Allow the signal-based strategy when it is available.

    Returns:
        A decorator producing the time-limited wrapper.

    Example:
        >>> @timeout(5.0)
        ... def slow_operation() -> str:
        ...     import time
        ...     time.sleep(10)
        ...     return "done"
        >>> slow_operation()
        OperationTimeoutError: slow_operation timed out after 5.0 seconds

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Signals need all three conditions; fall back to the thread
            # strategy as soon as one of them does not hold.
            thread_only = (
                not use_signal
                or sys.platform == "win32"
                or threading.current_thread() is not threading.main_thread()
            )
            call_kwargs = dict(kwargs)

            if thread_only:
                return _timeout_with_thread(func, seconds, args, call_kwargs)
            return _timeout_with_signal(  # pragma: no cover
                func,
                seconds,
                args,
                call_kwargs,
            )

        return wrapper

    return decorator
def _timeout_with_signal(  # pragma: no cover
    func: Callable[P, R],
    seconds: float,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> R:
    """Run ``func`` under a ``SIGALRM``-based time limit.

    A temporary ``SIGALRM`` handler and a real-time interval timer are
    installed around the call. Both are always removed afterwards. A
    timer that was already armed by the caller (for example an enclosing
    timeout) is re-armed with its remaining time instead of being
    cancelled.

    Args:
        func: Callable to execute.
        seconds: Time limit in seconds.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        OperationTimeoutError: If the alarm fires before ``func`` returns.

    """
    import time

    # Callables such as functools.partial objects carry no __name__.
    func_name = getattr(func, "__name__", "function")

    def handler(_signum: int, _frame: FrameType | None) -> None:
        raise OperationTimeoutError(seconds, func_name)

    old_handler = signal.signal(signal.SIGALRM, handler)
    outer_delay = 0.0
    outer_interval = 0.0
    started = time.monotonic()

    try:
        # Arm inside the try block so that a failure here (e.g. an invalid
        # ``seconds``) still restores the previous handler below.
        # NOTE(review): a ``seconds`` of 0 disarms the timer rather than
        # expiring immediately.
        outer_delay, outer_interval = signal.setitimer(
            signal.ITIMER_REAL,
            seconds,
        )
        return func(*args, **kwargs)
    finally:
        # Cancel our alarm before handing the signal back to its owner.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)
        if outer_delay > 0:
            # Give the pre-existing timer back the time it has left. If it
            # is already overdue, let it fire almost immediately.
            remaining = outer_delay - (time.monotonic() - started)
            signal.setitimer(
                signal.ITIMER_REAL,
                max(remaining, 1e-6),
                outer_interval,
            )
def _timeout_with_thread(
    func: Callable[P, R],
    seconds: float,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> R:
    """Run ``func`` in a daemon thread and wait at most ``seconds`` for it.

    The worker is not stopped on timeout. It is a daemon thread that keeps
    running in the background while the caller receives the timeout error.

    Args:
        func: Callable to execute.
        seconds: Maximum time to wait for the worker, in seconds.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        OperationTimeoutError: If the worker is still running after
            ``seconds``.
        BaseException: Whatever ``func`` raised, re-raised in the caller.

    """
    results: list[R] = []
    failures: list[BaseException] = []

    def target() -> None:
        try:
            results.append(func(*args, **kwargs))
        except BaseException as exc:  # noqa: BLE001
            # Capture everything, including SystemExit. Otherwise the
            # worker ends with neither a result nor an error recorded and
            # the caller fails with an unrelated IndexError.
            failures.append(exc)

    worker = threading.Thread(target=target, daemon=True)
    worker.start()
    worker.join(timeout=seconds)

    if worker.is_alive():
        # Callables such as functools.partial objects carry no __name__.
        raise OperationTimeoutError(
            seconds,
            getattr(func, "__name__", "function"),
        )

    if failures:
        raise failures[0]

    return results[0]
def deprecated(
    message: str = "",
    *,
    removal_version: str | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Build a decorator that flags a function as deprecated.

    Every call of the decorated function issues a ``DeprecationWarning``
    attributed to the caller, then runs the function as usual.

    Args:
        message: Extra guidance appended to the warning text.
        removal_version: Version in which the function will be removed;
            mentioned in the warning when given.

    Returns:
        A decorator producing the warning wrapper.

    Example:
        >>> @deprecated("Use new_function instead", removal_version="2.0")
        ... def old_function() -> None:
        ...     pass

    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            import warnings

            notice = f"{func.__name__} is deprecated."
            notice += (
                f" Will be removed in version {removal_version}."
                if removal_version
                else ""
            )
            notice += f" {message}" if message else ""

            # stacklevel=2 points the warning at the caller of the
            # deprecated function rather than at this wrapper.
            warnings.warn(notice, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
def require_type(
    **type_hints: type | tuple[type, ...],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator to enforce runtime type checking.

    Validates with ``isinstance`` that arguments match the specified
    types on every call. Arguments are bound to the function's signature
    first, so defaults are checked as well.

    Args:
        **type_hints: Mapping of parameter names to expected types. Anything
            ``isinstance`` accepts is allowed, including a tuple of types.
            Names that are not among the bound arguments are ignored.

    Returns:
        Decorated function with type checking.

    Raises:
        TypeError: From the wrapper, when an argument is not an instance
            of its expected type.

    Example:
        >>> @require_type(name=str, count=int)
        ... def greet(name: str, count: int) -> None:
        ...     print(f"Hello {name}" * count)
        >>> greet(name=123, count=2)
        TypeError: Parameter 'name' expected str, got int

    """

    def describe(expected: object) -> str:
        """Return a readable label for a type or tuple of types."""
        if isinstance(expected, tuple):
            return " | ".join(describe(member) for member in expected)
        # Tuples and union objects have no usable __name__. Falling back
        # to repr keeps the mismatch a TypeError, not an AttributeError.
        return getattr(expected, "__name__", repr(expected))

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()

            for param_name, expected_type in type_hints.items():  # pragma: no branch
                if param_name not in bound.arguments:  # pragma: no branch
                    continue
                value = bound.arguments[param_name]
                if not isinstance(value, expected_type):
                    raise TypeError(
                        f"Parameter '{param_name}' expected "
                        f"{describe(expected_type)}, got {type(value).__name__}"
                    )

            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator