Coverage for src / taipanstack / security / guards.py: 100%
200 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
1"""
2Runtime guards for protection against errors and AI hallucinations.
4These guards provide runtime protection against common security issues
5and programming errors that can occur from incorrect AI-generated code.
6All guards raise SecurityError on violation.
7"""
9import functools
10import ipaddress
11import os
12import re
13import socket
14import unicodedata
15from collections.abc import Sequence
16from pathlib import Path
17from urllib.parse import urlsplit
19from result import Err, Ok, Result
21from taipanstack.security.validators import MAX_URL_LENGTH
# Pre-compiled detector for path-traversal markers: a literal "..", its
# single- and double-URL-encoded forms, and a home-directory "~".
# A "~" only counts when it opens the path or directly follows a separator,
# so Windows short (8.3) names such as RUNNER~1 are not flagged.
TRAVERSAL_REGEX = re.compile(
    r"(?:\.\.|%2e%2e|%252e%252e)|(?:^|[\\/])~",
    flags=re.IGNORECASE,
)
31_DANGEROUS_COMMAND_PATTERNS: tuple[tuple[str, str], ...] = (
32 (";", "command separator"),
33 ("|", "pipe"),
34 ("&", "background/and operator"),
35 ("$", "variable expansion"),
36 ("`", "command substitution"),
37 ("$(", "command substitution"),
38 ("${", "variable expansion"),
39 (">", "redirect"),
40 ("<", "redirect"),
41 (">>", "redirect append"),
42 ("||", "or operator"),
43 ("&&", "and operator"),
44 ("\n", "newline"),
45 ("\r", "carriage return"),
46 ("\x00", "null byte"),
47)
49# Pre-compiled regex and lookup map for fast-path command injection detection
50_DANGEROUS_COMMAND_RE = re.compile(
51 "|".join(re.escape(p) for p, _ in _DANGEROUS_COMMAND_PATTERNS)
52)
53_DANGEROUS_COMMAND_LOOKUP = dict(_DANGEROUS_COMMAND_PATTERNS)
55_DEFAULT_DENIED_EXTENSIONS = frozenset(
56 [
57 "exe",
58 "dll",
59 "so",
60 "dylib", # Executables
61 "sh",
62 "bash",
63 "zsh",
64 "ps1",
65 "bat",
66 "cmd", # Scripts
67 "php",
68 "jsp",
69 "asp",
70 "aspx", # Server-side scripts
71 ]
72)
74_DEFAULT_DENIED_ENV_VARS = frozenset(
75 [
76 "AWS_SECRET_ACCESS_KEY",
77 "AWS_SESSION_TOKEN",
78 "GITHUB_TOKEN",
79 "GH_TOKEN",
80 "GITLAB_TOKEN",
81 "DATABASE_URL",
82 "DB_PASSWORD",
83 "PASSWORD",
84 "SECRET_KEY",
85 "PRIVATE_KEY",
86 "API_KEY",
87 "API_SECRET",
88 ]
89)
91_SENSITIVE_ENV_VAR_PATTERN = re.compile(
92 r"SECRET|PASSWORD|TOKEN|PRIVATE.*?KEY|API.*?KEY"
93)
95_SAFE_HASH_ALGORITHMS = frozenset(
96 [
97 "sha256",
98 "sha384",
99 "sha512",
100 "sha3_256",
101 "sha3_384",
102 "sha3_512",
103 "blake2b",
104 "blake2s",
105 ]
106)
109class SecurityError(Exception):
110 """Raised when a security guard detects a violation.
112 Attributes:
113 guard_name: Name of the guard that was triggered.
114 message: Description of the violation.
115 value: The offending value (if safe to log).
117 """
119 def __init__(
120 self,
121 message: str,
122 guard_name: str = "unknown",
123 value: str | None = None,
124 ) -> None:
125 """Initialize SecurityError.
127 Args:
128 message: Description of the violation.
129 guard_name: Name of the guard that triggered.
130 value: The offending value (sanitized).
132 """
133 self.guard_name = guard_name
134 self.value = value
135 super().__init__(f"[{guard_name}] {message}")
def _check_traversal_patterns(path_str: str) -> None:
    """Reject a raw path string that contains an explicit traversal marker.

    The check runs on the lower-cased text, before any filesystem resolution.

    Args:
        path_str: The path exactly as supplied by the caller.

    Raises:
        SecurityError: If ``TRAVERSAL_REGEX`` matches anywhere in the string.

    """
    hit = TRAVERSAL_REGEX.search(path_str.lower())
    if hit is None:
        return

    raise SecurityError(
        f"Path traversal pattern detected: {hit.group(0)}",
        guard_name="path_traversal",
        value=path_str[:50],  # only a short prefix of the input is kept
    )
149def _resolve_and_check_bounds(path: Path, base_dir: Path) -> tuple[Path, Path]:
150 """Resolve the path and check if it is within base_dir."""
151 try:
152 full_path = path if path.is_absolute() else (base_dir / path)
153 resolved = full_path.resolve()
154 except (OSError, ValueError, RuntimeError) as e:
155 raise SecurityError(
156 f"Invalid path: {e}",
157 guard_name="path_traversal",
158 ) from e
160 if not resolved.is_relative_to(base_dir):
161 raise SecurityError(
162 "Path escapes base directory",
163 guard_name="path_traversal",
164 )
165 return full_path, resolved
168def _check_symlink_safety(full_path: Path, base_dir: Path) -> None:
169 """Check for symlinks recursively up to the base directory."""
170 current = full_path
171 # Only check components from the user-provided path, not the base_dir
172 while current not in (base_dir, current.parent):
173 # We don't check .exists() because it returns False for broken symlinks
174 try:
175 is_symlink = current.is_symlink()
176 except OSError as e:
177 raise SecurityError(
178 f"Invalid path encountered during symlink check: {e}",
179 guard_name="path_traversal",
180 value=str(current)[:50],
181 ) from e
182 if is_symlink:
183 raise SecurityError(
184 "Symlinks are not allowed",
185 guard_name="path_traversal",
186 value=str(current),
187 )
188 current = current.parent
def guard_path_traversal(
    path: Path | str,
    base_dir: Path | str | None = None,
    *,
    allow_symlinks: bool = False,
) -> Path:
    """Prevent path traversal attacks.

    Runs three checks in order: a textual scan for traversal markers, a
    resolved-path containment check against the base directory, and (unless
    disabled) a symlink scan of the path components.

    Args:
        path: The path to validate.
        base_dir: The base directory to constrain to. Defaults to cwd.
        allow_symlinks: Whether to allow symlinks (default: False).

    Returns:
        The resolved, validated path.

    Raises:
        TypeError: If ``path`` is neither ``str`` nor ``Path``.
        SecurityError: If path traversal is detected.

    Example:
        >>> guard_path_traversal("../etc/passwd", Path("/app"))
        SecurityError: [path_traversal] Path traversal pattern detected: ..

    """
    if not isinstance(path, (str, Path)):
        raise TypeError(f"path must be str or Path, got {type(path).__name__}")

    candidate = path if isinstance(path, Path) else Path(path)
    root = (Path(base_dir) if base_dir else Path.cwd()).resolve()

    _check_traversal_patterns(str(candidate))
    joined, resolved = _resolve_and_check_bounds(candidate, root)

    if allow_symlinks:
        return resolved

    _check_symlink_safety(joined, root)
    return resolved
233def _check_command_not_empty(command: Sequence[str]) -> None:
234 if not command:
235 raise SecurityError(
236 "Empty command is not allowed",
237 guard_name="command_injection",
238 )
241def _check_command_null_bytes(cmd_list: list[str]) -> None:
242 for arg in cmd_list:
243 if isinstance(arg, str) and "\x00" in arg:
244 raise SecurityError(
245 "Dangerous shell character detected: null byte",
246 guard_name="command_injection",
247 value=arg[:50],
248 )
def _check_command_patterns(cmd_list: list[str]) -> None:
    """Scan every argument for shell metacharacters.

    Raises:
        TypeError: If an argument is not a string.
        SecurityError: If an argument matches ``_DANGEROUS_COMMAND_RE``.

    """
    for position, argument in enumerate(cmd_list):
        if not isinstance(argument, str):
            raise TypeError(
                "All command arguments must be strings, "
                f"got {type(argument).__name__} at index {position}"
            )

        found = _DANGEROUS_COMMAND_RE.search(argument)
        if found is None:
            continue

        label = _DANGEROUS_COMMAND_LOOKUP[found.group(0)]
        raise SecurityError(
            f"Dangerous shell character detected: {label}",
            guard_name="command_injection",
            value=argument[:50],
        )
269def _check_allowed_commands(
270 cmd_list: list[str], allowed_commands: Sequence[str] | None
271) -> None:
272 if allowed_commands is None:
273 return
275 base_command = cmd_list[0]
276 command_name = Path(base_command).name
277 cmd_not_allowed = (
278 command_name not in allowed_commands and base_command not in allowed_commands
279 )
280 if cmd_not_allowed:
281 raise SecurityError(
282 f"Command not in allowed list: {command_name}",
283 guard_name="command_injection",
284 value=command_name,
285 )
def guard_command_injection(
    command: Sequence[str],
    *,
    allowed_commands: Sequence[str] | None = None,
) -> list[str]:
    """Prevent command injection attacks.

    Copies the command into a list and runs, in order: the emptiness check,
    the null-byte check, the shell-metacharacter check and the optional
    allow-list check.

    Args:
        command: The command and arguments as a sequence.
        allowed_commands: Optional whitelist of allowed base commands.

    Returns:
        The validated command as a list.

    Raises:
        TypeError: If an argument is not a string.
        SecurityError: If command injection is detected.

    Example:
        >>> guard_command_injection(["echo", "hello; rm -rf /"])
        SecurityError: [command_injection] Dangerous shell character detected: command separator

    """
    argv = list(command)

    _check_command_not_empty(argv)
    _check_command_null_bytes(argv)
    _check_command_patterns(argv)
    _check_allowed_commands(argv, allowed_commands)

    return argv
324def _check_filename_null_bytes(filename_str: str) -> None:
325 if "\x00" in filename_str:
326 raise SecurityError(
327 "Filename contains null bytes",
328 guard_name="file_extension",
329 value=filename_str,
330 )
333def _clean_filename_end(clean_name: str) -> str:
334 end_idx = len(clean_name)
335 while end_idx > 0:
336 char = clean_name[end_idx - 1]
337 if (
338 char == "."
339 or unicodedata.category(char).startswith(("Z", "C"))
340 or char == "\xad"
341 ):
342 end_idx -= 1
343 else:
344 break
345 return clean_name[:end_idx]
348def _normalize_ext(e: str) -> str:
349 return e.lower().lstrip(".")
def _check_denied_extension(
    ext: str, original_name: str, denied_extensions: Sequence[str] | None
) -> None:
    """Reject *ext* when it appears in the effective deny list.

    Args:
        ext: The extension to test.
        original_name: The filename reported in the error.
        denied_extensions: Caller-supplied deny list (normalized before use),
            or ``None`` to fall back to ``_DEFAULT_DENIED_EXTENSIONS``.

    Raises:
        SecurityError: If ``ext`` is in the deny list.

    """
    if denied_extensions is None:
        blocked = _DEFAULT_DENIED_EXTENSIONS
    else:
        blocked = frozenset(map(_normalize_ext, denied_extensions))

    if ext not in blocked:
        return

    raise SecurityError(
        f"File extension '{ext}' is not allowed",
        guard_name="file_extension",
        value=original_name,
    )
def _check_allowed_extension(
    ext: str, original_name: str, allowed_extensions: Sequence[str] | None
) -> None:
    """Reject *ext* when an allow list is given and does not contain it.

    Args:
        ext: The extension to test.
        original_name: The filename reported in the error.
        allowed_extensions: Caller-supplied allow list (normalized before
            use), or ``None`` to skip the check.

    Raises:
        SecurityError: If an allow list is given and ``ext`` is not in it.

    """
    if allowed_extensions is None:  # pragma: no branch
        return

    permitted = set(map(_normalize_ext, allowed_extensions))
    if ext in permitted:
        return

    raise SecurityError(
        f"File extension '{ext}' is not in allowed list",
        guard_name="file_extension",
        value=original_name,
    )
def guard_file_extension(
    filename: str | Path,
    *,
    allowed_extensions: Sequence[str] | None = None,
    denied_extensions: Sequence[str] | None = None,
) -> Path:
    """Validate file extension against allow/deny lists.

    The extension is taken from the final path component after trailing dots,
    separators and control characters have been stripped from it, and is
    compared lower-cased without its leading dot. The deny list is checked
    before the allow list.

    Args:
        filename: The filename to check.
        allowed_extensions: Extensions to allow (with or without dot).
        denied_extensions: Extensions to deny (with or without dot).

    Returns:
        The filename as a Path.

    Raises:
        SecurityError: If the name is too long, contains a null byte, or its
            extension is not allowed or is denied.

    """
    raw_name = str(filename)
    # The URL length cap doubles as the filename length cap.
    if len(raw_name) > MAX_URL_LENGTH:
        raise SecurityError(
            f"Filename length exceeds maximum allowed limit of {MAX_URL_LENGTH}",
            guard_name="file_extension",
            value=raw_name[:80],
        )
    _check_filename_null_bytes(raw_name)

    path = Path(filename)
    last_component = str(path.name)
    trimmed = _clean_filename_end(path.name)

    if trimmed:
        ext = Path(trimmed).suffix.lower().lstrip(".")
    else:
        ext = ""

    _check_denied_extension(ext, last_component, denied_extensions)
    _check_allowed_extension(ext, last_component, allowed_extensions)

    return path
421def _check_env_denied(
422 name_upper: str,
423 name: str,
424 denied_names: Sequence[str] | None,
425) -> None:
426 """Check if the environment variable is in the denied list."""
427 if denied_names is not None:
428 denied = frozenset(n.upper() for n in denied_names)
429 else:
430 denied = _DEFAULT_DENIED_ENV_VARS
432 if name_upper in denied:
433 raise SecurityError(
434 f"Access to sensitive variable '{name}' is denied",
435 guard_name="env_variable",
436 value=name,
437 )
def _check_env_sensitive(
    name_upper: str,
    name: str,
    allowed_names: Sequence[str] | None,
) -> None:
    """Reject names that look sensitive unless they are explicitly allowed.

    Args:
        name_upper: Upper-cased variable name tested against
            ``_SENSITIVE_ENV_VAR_PATTERN`` and the allow list.
        name: Variable name as supplied, used in the error.
        allowed_names: Names exempt from the pattern check (compared
            upper-cased), or ``None`` for no exemptions.

    Raises:
        SecurityError: If the name matches the sensitive pattern and is not
            in ``allowed_names``.

    """
    looks_sensitive = _SENSITIVE_ENV_VAR_PATTERN.search(name_upper)
    if not looks_sensitive:
        return

    # An explicit allow-list entry overrides the pattern match.
    if allowed_names is not None and name_upper in {
        entry.upper() for entry in allowed_names
    }:
        return

    raise SecurityError(
        f"Access to potentially sensitive variable '{name}' is denied",
        guard_name="env_variable",
        value=name,
    )
def guard_env_variable(
    name: str,
    *,
    allowed_names: Sequence[str] | None = None,
    denied_names: Sequence[str] | None = None,
) -> str:
    """Guard against accessing sensitive environment variables.

    The name is validated, checked against the deny list, then against the
    sensitive-name pattern (which ``allowed_names`` can override), and only
    then read from the environment.

    Args:
        name: The environment variable name.
        allowed_names: Variable names to allow.
        denied_names: Variable names to deny.

    Returns:
        The environment variable value if safe.

    Raises:
        TypeError: If ``name`` is not a string.
        SecurityError: If variable access is not allowed, the name is empty,
            whitespace-only or contains a null byte, or the variable is not
            set.

    """
    if not isinstance(name, str):
        raise TypeError(f"Variable name must be str, got {type(name).__name__}")

    # An empty string strips to an empty string, so this covers both the
    # empty and the whitespace-only case.
    if not name.strip():
        raise SecurityError(
            "Environment variable name cannot be empty or whitespace",
            guard_name="env_variable",
        )

    if "\x00" in name:
        raise SecurityError(
            "Environment variable name cannot contain null bytes",
            guard_name="env_variable",
        )

    upper_name = name.upper()
    _check_env_denied(upper_name, name, denied_names)
    _check_env_sensitive(upper_name, name, allowed_names)

    found = os.environ.get(name)
    if found is not None:
        return found

    raise SecurityError(
        f"Environment variable '{name}' is not set",
        guard_name="env_variable",
        value=name,
    )
516# ── SSRF Private-Range Constants ─────────────────────────────────────────────
517_ALLOWED_SSRF_SCHEMES: frozenset[str] = frozenset({"http", "https"})
def _validate_ssrf_url_type_and_length(url: str) -> Result[str, SecurityError]:
    """Check that *url* is a non-empty string within ``MAX_URL_LENGTH``.

    Returns:
        ``Ok(url)`` when all three checks pass, otherwise
        ``Err(SecurityError)`` describing the first failed check.

    """
    if not isinstance(url, str):
        failure = SecurityError(
            f"URL must be str, got {type(url).__name__}",
            guard_name="ssrf",
        )
    elif not url:
        failure = SecurityError("URL cannot be empty", guard_name="ssrf")
    elif len(url) > MAX_URL_LENGTH:
        failure = SecurityError(
            "URL length exceeds maximum allowed limit",
            guard_name="ssrf",
            value=url[:80],
        )
    else:
        return Ok(url)

    return Err(failure)
def _validate_ssrf_url_parse(
    url: str, allowed_schemes: frozenset[str]
) -> Result[str, SecurityError]:
    """Parse *url* and extract its hostname.

    Args:
        url: The URL string to parse.
        allowed_schemes: Schemes accepted; the URL's scheme is compared
            lower-cased.

    Returns:
        ``Ok(hostname)`` on success. ``Err(SecurityError)`` if the URL cannot
        be parsed, its scheme is missing or not allowed, or it has no
        hostname.

    """

    def reject(message: str) -> Result[str, SecurityError]:
        # Every failure carries the same guard name and an 80-char URL prefix.
        return Err(SecurityError(message, guard_name="ssrf", value=url[:80]))

    try:
        parts = urlsplit(url)
    except ValueError as exc:
        return reject(f"Malformed URL: {exc}")

    scheme = parts.scheme
    if not scheme or scheme.lower() not in allowed_schemes:
        return reject(f"URL scheme '{scheme}' is not allowed")

    host = parts.hostname
    if host:
        return Ok(host)

    return reject("URL has no resolvable hostname")
def _validate_ssrf_url(
    url: str,
    allowed_schemes: frozenset[str],
) -> Result[str, SecurityError]:
    """Validate the URL's type and length, then its scheme and hostname.

    Returns:
        ``Ok(hostname)`` when every check passes, otherwise the first
        ``Err(SecurityError)`` produced.

    """
    precheck = _validate_ssrf_url_type_and_length(url)
    if isinstance(precheck, Ok):
        return _validate_ssrf_url_parse(url, allowed_schemes)

    return precheck
591@functools.lru_cache(maxsize=1024)
592def _is_ip_safe(raw_ip: str) -> bool:
593 """Check if a single IP address is safe (not private/loopback/reserved)."""
594 try:
595 addr = ipaddress.ip_address(raw_ip)
596 except ValueError:
597 return True
599 return not (
600 addr.is_private or addr.is_loopback or addr.is_link_local or addr.is_reserved
601 )
def _check_ip_safety(hostname: str) -> Result[None, SecurityError]:
    """Resolve *hostname* and check every returned address for SSRF risk.

    Returns:
        ``Ok(None)`` when all resolved addresses pass ``_is_ip_safe``.
        ``Err(SecurityError)`` if resolution fails or any address is unsafe.

    """
    try:
        resolved = socket.getaddrinfo(hostname, None)
    except (socket.gaierror, UnicodeError):
        return Err(
            SecurityError(
                "Hostname could not be resolved or contains invalid characters",
                guard_name="ssrf",
            )
        )

    # Each entry's fifth field is the socket address; its first item is the
    # IP. One unsafe address is enough to reject the hostname.
    if any(not _is_ip_safe(entry[4][0]) for entry in resolved):
        return Err(
            SecurityError(
                "SSRF detected: hostname resolves to private/reserved address",
                guard_name="ssrf",
            )
        )

    return Ok(None)
def guard_ssrf(
    url: str,
    *,
    allowed_schemes: frozenset[str] = _ALLOWED_SSRF_SCHEMES,
) -> Result[str, SecurityError]:
    """Validate a URL against Server-Side Request Forgery (SSRF) attacks.

    Parse the URL, resolve its hostname via DNS, and reject it when a
    resulting IP address is judged unsafe (private, loopback, link-local or
    otherwise reserved).

    Failures are returned, not raised: a non-string, empty, over-long or
    malformed URL also comes back as ``Err(SecurityError)``.

    Args:
        url: The URL string to validate.
        allowed_schemes: Set of URL schemes considered safe.
            Defaults to ``{"http", "https"}``.

    Returns:
        ``Ok(url)`` when the URL is safe to fetch.
        ``Err(SecurityError)`` when validation fails or an SSRF risk is
        detected.

    Example:
        >>> guard_ssrf("https://example.com")
        Ok('https://example.com')
        >>> guard_ssrf("http://169.254.169.254/metadata")
        Err(SecurityError('[ssrf] ...))

    """
    # Step 1: format, scheme and hostname extraction.
    host_result = _validate_ssrf_url(url, allowed_schemes)
    if isinstance(host_result, Ok):
        # Step 2: DNS resolution and address classification.
        safety = _check_ip_safety(host_result.ok_value)
        return Ok(url) if isinstance(safety, Ok) else safety

    return host_result