Coverage for src / taipanstack / security / guards.py: 100%
157 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
"""
Runtime guards for protection against errors and AI hallucinations.

These guards provide runtime protection against common security issues
and programming errors that can occur from incorrect AI-generated code.
Most guards raise SecurityError on violation; guard_ssrf instead returns
the SecurityError wrapped in an Err result.
"""
9import ipaddress
10import os
11import re
12import socket
13from collections.abc import Sequence
14from pathlib import Path
15from urllib.parse import urlparse
17from result import Err, Ok, Result
# Regex for explicit traversal patterns, applied before path resolution.
#
# First alternative: two consecutive "dot" tokens, where each dot may be
# literal ("."), URL-encoded ("%2e") or double-encoded ("%252e"). Matching the
# tokens independently (instead of only "..", "%2e%2e" and "%252e%252e") also
# catches mixed spellings such as ".%2e" or "%2e." that decode to "..".
#
# Second alternative: "~" is handled specially and only matches at the start of
# the path or right after a separator, to avoid false positives with Windows
# short paths (e.g., RUNNER~1).
TRAVERSAL_REGEX = re.compile(
    r"(?:\.|%2e|%252e){2}|(?:^|[\\/])~",
    re.IGNORECASE,
)
# Shell metacharacters rejected in command arguments, each paired with the
# description reported in the resulting error message. Order is significant:
# the regex below tries the tokens in this order, so at any given position an
# earlier token wins over a later one (e.g. ">" is reported before ">>").
_DANGEROUS_COMMAND_PATTERNS: tuple[tuple[str, str], ...] = (
    (";", "command separator"),
    ("|", "pipe"),
    ("&", "background/and operator"),
    ("$", "variable expansion"),
    ("`", "command substitution"),
    ("$(", "command substitution"),
    ("${", "variable expansion"),
    (">", "redirect"),
    ("<", "redirect"),
    (">>", "redirect append"),
    ("||", "or operator"),
    ("&&", "and operator"),
    ("\n", "newline"),
    ("\r", "carriage return"),
    ("\x00", "null byte"),
)

# Token -> description map, used to label whichever token the regex matched.
# Tokens are unique, so the dict keeps them in the declaration order above.
_DANGEROUS_COMMAND_LOOKUP = {
    token: label for token, label in _DANGEROUS_COMMAND_PATTERNS
}

# Pre-compiled alternation of every escaped token, for fast-path command
# injection detection with a single search per argument.
_DANGEROUS_COMMAND_RE = re.compile(
    "|".join(map(re.escape, _DANGEROUS_COMMAND_LOOKUP))
)
# File extensions (lowercase, without the leading dot) that are rejected when
# the caller does not supply an explicit deny list.
_DEFAULT_DENIED_EXTENSIONS = frozenset(
    {
        # Executables and shared libraries
        "exe",
        "dll",
        "so",
        "dylib",
        # Shell and batch scripts
        "sh",
        "bash",
        "zsh",
        "ps1",
        "bat",
        "cmd",
        # Server-side scripts
        "php",
        "jsp",
        "asp",
        "aspx",
    }
)
# Environment variable names (uppercase) that are refused when the caller does
# not supply an explicit deny list.
_DEFAULT_DENIED_ENV_VARS = frozenset(
    {
        # Cloud and source-hosting credentials
        "AWS_SECRET_ACCESS_KEY",
        "AWS_SESSION_TOKEN",
        "GITHUB_TOKEN",
        "GH_TOKEN",
        "GITLAB_TOKEN",
        # Database access
        "DATABASE_URL",
        "DB_PASSWORD",
        # Generic secrets
        "PASSWORD",
        "SECRET_KEY",
        "PRIVATE_KEY",
        "API_KEY",
        "API_SECRET",
    }
)

# Name fragments that mark a variable as potentially sensitive. The pattern is
# case-sensitive and is searched against the upper-cased variable name.
_SENSITIVE_ENV_VAR_PATTERN = re.compile(r"SECRET|PASSWORD|TOKEN|PRIVATE.*KEY|API.*KEY")
# Hash algorithm names (lowercase) accepted when the caller does not supply an
# explicit allow list.
_SAFE_HASH_ALGORITHMS = frozenset(
    {
        # SHA-2 family
        "sha256",
        "sha384",
        "sha512",
        # SHA-3 family
        "sha3_256",
        "sha3_384",
        "sha3_512",
        # BLAKE2 family
        "blake2b",
        "blake2s",
    }
)
103class SecurityError(Exception):
104 """Raised when a security guard detects a violation.
106 Attributes:
107 guard_name: Name of the guard that was triggered.
108 message: Description of the violation.
109 value: The offending value (if safe to log).
111 """
113 def __init__(
114 self,
115 message: str,
116 guard_name: str = "unknown",
117 value: str | None = None,
118 ) -> None:
119 """Initialize SecurityError.
121 Args:
122 message: Description of the violation.
123 guard_name: Name of the guard that triggered.
124 value: The offending value (sanitized).
126 """
127 self.guard_name = guard_name
128 self.value = value
129 super().__init__(f"[{guard_name}] {message}")
def _check_traversal_patterns(path_str: str) -> None:
    """Reject a raw path string that contains an explicit traversal pattern.

    The check runs on the lower-cased text, before any path resolution.

    Args:
        path_str: The path exactly as supplied by the caller.

    Raises:
        SecurityError: If ``TRAVERSAL_REGEX`` matches anywhere in the string.

    """
    hit = TRAVERSAL_REGEX.search(path_str.lower())
    if hit is None:
        return

    offending = hit.group(0)
    raise SecurityError(
        f"Path traversal pattern detected: {offending}",
        guard_name="path_traversal",
        value=path_str[:50],  # Truncate for safety
    )
def _resolve_and_check_bounds(path: Path, base_dir: Path) -> tuple[Path, Path]:
    """Resolve *path* and verify the result stays inside *base_dir*.

    Args:
        path: The candidate path; relative paths are joined onto *base_dir*.
        base_dir: The directory the resolved path must live under.

    Returns:
        A ``(full_path, resolved)`` pair: the unresolved absolute-or-joined
        path, and the same path after ``Path.resolve()``.

    Raises:
        SecurityError: If resolution fails or the resolved path is not
            relative to *base_dir*.

    """
    try:
        if path.is_absolute():
            candidate = path
        else:
            candidate = base_dir / path
        canonical = candidate.resolve()
    except (OSError, ValueError) as exc:
        raise SecurityError(
            f"Invalid path: {exc}",
            guard_name="path_traversal",
        ) from exc

    if canonical.is_relative_to(base_dir):
        return candidate, canonical

    raise SecurityError(
        "Path escapes base directory",
        guard_name="path_traversal",
    )
def _check_symlink_safety(full_path: Path, base_dir: Path) -> None:
    """Reject *full_path* if it or any ancestor below *base_dir* is a symlink.

    Walks from *full_path* upwards and stops on reaching *base_dir* or the
    filesystem root, so only the components contributed by the caller are
    inspected, never *base_dir* itself.

    Args:
        full_path: The unresolved path to inspect.
        base_dir: The directory at which the upward walk stops.

    Raises:
        SecurityError: If any inspected component is a symlink.

    """
    for node in (full_path, *full_path.parents):
        # Stop at the base directory, or at the root (which is its own parent).
        if node == base_dir or node == node.parent:
            break
        # No .exists() check here: it returns False for broken symlinks,
        # which must still be rejected.
        if node.is_symlink():
            raise SecurityError(
                "Symlinks are not allowed",
                guard_name="path_traversal",
                value=str(node),
            )
177def guard_path_traversal(
178 path: Path | str,
179 base_dir: Path | str | None = None,
180 *,
181 allow_symlinks: bool = False,
182) -> Path:
183 """Prevent path traversal attacks.
185 Ensures that the given path does not escape the base directory
186 using techniques like '..' or symlinks.
188 Args:
189 path: The path to validate.
190 base_dir: The base directory to constrain to. Defaults to cwd.
191 allow_symlinks: Whether to allow symlinks (default: False).
193 Returns:
194 The resolved, validated path.
196 Raises:
197 SecurityError: If path traversal is detected.
199 Example:
200 >>> guard_path_traversal("../etc/passwd", Path("/app"))
201 SecurityError: [path_traversal] Path escapes base directory
203 """
204 if not isinstance(path, (str, Path)):
205 raise TypeError(f"path must be str or Path, got {type(path).__name__}")
207 path_obj = Path(path) if isinstance(path, str) else path
208 base = Path(base_dir).resolve() if base_dir else Path.cwd().resolve()
210 _check_traversal_patterns(str(path_obj))
211 full_path, resolved = _resolve_and_check_bounds(path_obj, base)
213 if not allow_symlinks:
214 _check_symlink_safety(full_path, base)
216 return resolved
219def guard_command_injection(
220 command: Sequence[str],
221 *,
222 allowed_commands: Sequence[str] | None = None,
223) -> list[str]:
224 """Prevent command injection attacks.
226 Validates that command arguments don't contain shell metacharacters
227 that could lead to command injection.
229 Args:
230 command: The command and arguments as a sequence.
231 allowed_commands: Optional whitelist of allowed base commands.
233 Returns:
234 The validated command as a list.
236 Raises:
237 SecurityError: If command injection is detected.
239 Example:
240 >>> guard_command_injection(["echo", "hello; rm -rf /"])
241 SecurityError: [command_injection] Dangerous characters detected
243 """
244 if not command:
245 raise SecurityError(
246 "Empty command is not allowed",
247 guard_name="command_injection",
248 )
250 cmd_list = list(command)
252 # Validate all items are strings and check for dangerous patterns
253 for i, arg in enumerate(cmd_list):
254 if not isinstance(arg, str):
255 raise TypeError(
256 f"All command arguments must be strings, "
257 f"got {type(arg).__name__} at index {i}"
258 )
260 match = _DANGEROUS_COMMAND_RE.search(arg)
261 if match:
262 # We use the matched substring to look up the description.
263 # Because the regex is an alternation of the patterns in order,
264 # this preserves the existing behavior where earlier patterns
265 # in the list take precedence (e.g. '>' matches before '>>').
266 description = _DANGEROUS_COMMAND_LOOKUP[match.group(0)]
267 raise SecurityError(
268 f"Dangerous shell character detected: {description}",
269 guard_name="command_injection",
270 value=arg[:50],
271 )
273 # Check against allowed commands whitelist
274 if allowed_commands is not None:
275 base_command = cmd_list[0]
276 # Get just the command name without path
277 command_name = Path(base_command).name
278 cmd_not_allowed = (
279 command_name not in allowed_commands
280 and base_command not in allowed_commands
281 )
282 if cmd_not_allowed:
283 raise SecurityError(
284 f"Command not in allowed list: {command_name}",
285 guard_name="command_injection",
286 value=command_name,
287 )
289 return cmd_list
292def guard_file_extension(
293 filename: str | Path,
294 *,
295 allowed_extensions: Sequence[str] | None = None,
296 denied_extensions: Sequence[str] | None = None,
297) -> Path:
298 """Validate file extension against allow/deny lists.
300 Args:
301 filename: The filename to check.
302 allowed_extensions: Extensions to allow (with or without dot).
303 denied_extensions: Extensions to deny (with or without dot).
305 Returns:
306 The filename as a Path.
308 Raises:
309 SecurityError: If extension is not allowed or is denied.
311 """
312 path = Path(filename)
313 ext = path.suffix.lower().lstrip(".")
315 # Normalize extension lists
316 def normalize_ext(e: str) -> str:
317 return e.lower().lstrip(".")
319 if denied_extensions is not None:
320 denied = frozenset(normalize_ext(e) for e in denied_extensions)
321 else:
322 denied = _DEFAULT_DENIED_EXTENSIONS
324 if ext in denied:
325 raise SecurityError(
326 f"File extension '{ext}' is not allowed",
327 guard_name="file_extension",
328 value=str(path.name),
329 )
331 if allowed_extensions is not None: # pragma: no branch
332 allowed = {normalize_ext(e) for e in allowed_extensions}
333 if ext not in allowed:
334 raise SecurityError(
335 f"File extension '{ext}' is not in allowed list",
336 guard_name="file_extension",
337 value=str(path.name),
338 )
340 return path
343def _check_env_denied(
344 name_upper: str,
345 name: str,
346 denied_names: Sequence[str] | None,
347) -> None:
348 """Check if the environment variable is in the denied list."""
349 if denied_names is not None:
350 denied = frozenset(n.upper() for n in denied_names)
351 else:
352 denied = _DEFAULT_DENIED_ENV_VARS
354 if name_upper in denied:
355 raise SecurityError(
356 f"Access to sensitive variable '{name}' is denied",
357 guard_name="env_variable",
358 value=name,
359 )
def _check_env_sensitive(
    name_upper: str,
    name: str,
    allowed_names: Sequence[str] | None,
) -> None:
    """Reject a variable whose name looks sensitive unless explicitly allowed.

    Args:
        name_upper: The variable name, already upper-cased, used for matching.
        name: The variable name as supplied, used in the error report.
        allowed_names: Names exempt from the sensitive-pattern check
            (compared case-insensitively), or ``None`` for no exemptions.

    Raises:
        SecurityError: If the name matches the sensitive pattern and is not
            in *allowed_names*.

    """
    looks_sensitive = _SENSITIVE_ENV_VAR_PATTERN.search(name_upper)
    if not looks_sensitive:
        return

    # An explicit allow-list entry overrides the pattern match.
    explicitly_allowed = allowed_names is not None and name_upper in {
        entry.upper() for entry in allowed_names
    }
    if explicitly_allowed:
        return

    raise SecurityError(
        f"Access to potentially sensitive variable '{name}' is denied",
        guard_name="env_variable",
        value=name,
    )
384def guard_env_variable(
385 name: str,
386 *,
387 allowed_names: Sequence[str] | None = None,
388 denied_names: Sequence[str] | None = None,
389) -> str:
390 """Guard against accessing sensitive environment variables.
392 Args:
393 name: The environment variable name.
394 allowed_names: Variable names to allow.
395 denied_names: Variable names to deny.
397 Returns:
398 The environment variable value if safe.
400 Raises:
401 SecurityError: If variable access is not allowed.
403 """
404 # Validate input type
405 if not isinstance(name, str):
406 raise TypeError(f"Variable name must be str, got {type(name).__name__}")
408 # Reject empty/whitespace-only variable names
409 if not name or not name.strip():
410 raise SecurityError(
411 "Environment variable name cannot be empty or whitespace",
412 guard_name="env_variable",
413 )
415 name_upper = name.upper()
417 _check_env_denied(name_upper, name, denied_names)
418 _check_env_sensitive(name_upper, name, allowed_names)
420 # Get the variable
421 value = os.environ.get(name)
422 if value is None:
423 raise SecurityError(
424 f"Environment variable '{name}' is not set",
425 guard_name="env_variable",
426 value=name,
427 )
429 return value
432def guard_hash_algorithm(
433 algorithm: str,
434 *,
435 allowed_algorithms: Sequence[str] | None = None,
436) -> str:
437 """Validate hash algorithm against a whitelist of secure ones.
439 Args:
440 algorithm: The name of the hash algorithm to validate.
441 allowed_algorithms: Optional whitelist of allowed algorithms.
442 Defaults to a secure set (SHA-256, SHA-512, etc.).
444 Returns:
445 The normalized (lowercase) algorithm name.
447 Raises:
448 SecurityError: If the algorithm is potentially weak or not allowed.
450 """
451 if not isinstance(algorithm, str):
452 raise TypeError(f"Algorithm name must be str, got {type(algorithm).__name__}")
454 algo_lower = algorithm.lower().replace("-", "")
456 allowed: frozenset[str]
457 if allowed_algorithms is not None:
458 allowed = frozenset(a.lower().replace("-", "") for a in allowed_algorithms)
459 else:
460 allowed = _SAFE_HASH_ALGORITHMS
462 if algo_lower not in allowed:
463 raise SecurityError(
464 f"Hash algorithm '{algorithm}' is considered weak or is not allowed",
465 guard_name="hash_algorithm",
466 value=algorithm,
467 )
469 return algo_lower
# ── SSRF Constants ───────────────────────────────────────────────────────────
# URL schemes accepted by default when validating a URL for SSRF risk.
_ALLOWED_SSRF_SCHEMES: frozenset[str] = frozenset(("http", "https"))
def _validate_ssrf_url(
    url: str,
    allowed_schemes: frozenset[str],
) -> Result[str, SecurityError]:
    """Check a URL's shape and extract its hostname.

    Args:
        url: The URL string to validate.
        allowed_schemes: Lowercase schemes that are acceptable.

    Returns:
        ``Ok(hostname)`` when the URL parses, uses an allowed scheme and has
        a hostname; otherwise ``Err(SecurityError)`` describing the problem.

    Raises:
        TypeError: If *url* is not a ``str``.

    """
    if not isinstance(url, str):
        raise TypeError(f"URL must be str, got {type(url).__name__}")

    if not url:
        return Err(SecurityError("URL cannot be empty", guard_name="ssrf"))

    def reject(reason: str) -> Result[str, SecurityError]:
        # Every rejection past this point reports a truncated copy of the URL.
        return Err(SecurityError(reason, guard_name="ssrf", value=url[:80]))

    try:
        parts = urlparse(url)
    except ValueError as exc:  # pragma: no cover
        return reject(f"Malformed URL: {exc}")

    scheme = parts.scheme
    if not scheme or scheme.lower() not in allowed_schemes:
        return reject(f"URL scheme '{scheme}' is not allowed")

    host = parts.hostname
    if not host:
        return reject("URL has no resolvable hostname")

    return Ok(host)
def _check_ip_safety(hostname: str) -> Result[None, SecurityError]:
    """Resolve hostname to IP addresses and check for SSRF risk.

    Every address returned by the resolver is inspected; a single unsafe
    address is enough to reject the hostname.

    Args:
        hostname: The hostname (or IP literal) extracted from the URL.

    Returns:
        ``Ok(None)`` when no resolved address is private, loopback,
        link-local, reserved, multicast or unspecified; otherwise
        ``Err(SecurityError)``. A hostname that cannot be resolved also
        yields ``Err(SecurityError)``.

    """
    try:
        addr_infos = socket.getaddrinfo(hostname, None)
    except (socket.gaierror, UnicodeError):
        # UnicodeError: getaddrinfo IDNA-encodes the host first, which fails
        # for labels that are empty or too long (e.g. "a..b" or a 64+ char
        # label). The hostname comes from an untrusted URL, so report this
        # as a resolution failure instead of letting the exception escape.
        return Err(
            SecurityError(
                "Hostname could not be resolved",
                guard_name="ssrf",
            )
        )

    for addr_info in addr_infos:
        # addr_info is (family, type, proto, canonname, sockaddr); the first
        # element of sockaddr is the textual IP address.
        raw_ip = addr_info[4][0]
        try:
            addr = ipaddress.ip_address(raw_ip)
        except ValueError:
            # Entries that cannot be parsed as an IP address are skipped.
            continue

        if (
            addr.is_private
            or addr.is_loopback
            or addr.is_link_local
            or addr.is_reserved
            # Multicast and unspecified ("0.0.0.0" / "::") addresses are not
            # legitimate fetch targets either; reject them explicitly.
            or addr.is_multicast
            or addr.is_unspecified
        ):
            return Err(
                SecurityError(
                    "SSRF detected: hostname resolves to private/reserved address",
                    guard_name="ssrf",
                )
            )

    return Ok(None)
555def guard_ssrf(
556 url: str,
557 *,
558 allowed_schemes: frozenset[str] = _ALLOWED_SSRF_SCHEMES,
559) -> Result[str, SecurityError]:
560 """Validate a URL against Server-Side Request Forgery (SSRF) attacks.
562 Parse the URL, resolve its hostname via DNS, and reject it when the
563 resulting IP address falls inside a private, loopback, link-local, or
564 otherwise reserved network range.
566 Args:
567 url: The URL string to validate.
568 allowed_schemes: Set of URL schemes considered safe.
569 Defaults to ``{"http", "https"}``.
571 Returns:
572 ``Ok(url)`` when the URL is safe to fetch.
573 ``Err(SecurityError)`` when an SSRF risk is detected.
575 Raises:
576 TypeError: If *url* is not a :class:`str`.
578 Example:
579 >>> guard_ssrf("https://example.com")
580 Ok('https://example.com')
581 >>> guard_ssrf("http://169.254.169.254/metadata")
582 Err(SecurityError('[ssrf] ...))
584 """
585 # 1. Validate format and scheme
586 match _validate_ssrf_url(url, allowed_schemes):
587 case Err(e):
588 return Err(e)
589 case Ok(hostname):
590 # 2. Check IP safety
591 match _check_ip_safety(hostname):
592 case Err(e):
593 return Err(e)
594 case Ok():
595 return Ok(url)