Coverage for src / taipanstack / security / guards.py: 100%

200 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

"""
Runtime guards for protection against errors and AI hallucinations.

These guards provide runtime protection against common security issues
and programming errors that can occur from incorrect AI-generated code.
Violations are reported as SecurityError: most guards raise it directly,
while guard_ssrf returns it wrapped in ``Err`` (a ``result.Result``).
Some guards also raise TypeError when an argument has the wrong type.
"""

8 

9import functools 

10import ipaddress 

11import os 

12import re 

13import socket 

14import unicodedata 

15from collections.abc import Sequence 

16from pathlib import Path 

17from urllib.parse import urlsplit 

18 

19from result import Err, Ok, Result 

20 

21from taipanstack.security.validators import MAX_URL_LENGTH 

22 

# Build regex for traversal patterns.
# Matches a literal "..", its URL-encoded forms ("%2e%2e" and the
# double-encoded "%252e%252e"), or a "~" at the very start of the path or
# directly after a "/" or "\" separator. Anchoring "~" this way avoids false
# positives with Windows short paths (e.g. RUNNER~1). Note that ".." is
# matched anywhere, so names such as "a..b" are rejected as well.
TRAVERSAL_REGEX = re.compile(
    r"(?:\.\.|%2e%2e|%252e%252e)|(?:^|[\\/])~",
    flags=re.I,
)

30 

# Shell metacharacter sequences paired with the human-readable description
# reported in SecurityError messages.
_DANGEROUS_COMMAND_PATTERNS: tuple[tuple[str, str], ...] = (
    (";", "command separator"),
    ("|", "pipe"),
    ("&", "background/and operator"),
    ("$", "variable expansion"),
    ("`", "command substitution"),
    ("$(", "command substitution"),
    ("${", "variable expansion"),
    (">", "redirect"),
    ("<", "redirect"),
    (">>", "redirect append"),
    ("||", "or operator"),
    ("&&", "and operator"),
    ("\n", "newline"),
    ("\r", "carriage return"),
    ("\x00", "null byte"),
)

# Pre-compiled regex and lookup map for fast-path command injection detection.
# Alternatives are ordered longest-first (sorted() is stable, so equal-length
# tokens keep their table order). Regex alternation picks the first
# alternative that matches at a given position, so in table order a
# one-character token such as "$" or ">" would shadow "$(" or ">>". The
# multi-character descriptions would then never be reported. Every
# multi-character token starts with a one-character token, so the set of
# positions that match (i.e. what gets rejected) is unchanged.
_DANGEROUS_COMMAND_RE = re.compile(
    "|".join(
        re.escape(token)
        for token in sorted(
            (p for p, _ in _DANGEROUS_COMMAND_PATTERNS), key=len, reverse=True
        )
    )
)
_DANGEROUS_COMMAND_LOOKUP = dict(_DANGEROUS_COMMAND_PATTERNS)

54 

# Extensions rejected when no custom deny list is supplied. Entries are
# stored lower-case without a leading dot, matching the normalized form of
# the extension being checked.
_DEFAULT_DENIED_EXTENSIONS = frozenset(
    {
        # Executables / shared libraries
        "exe", "dll", "so", "dylib",
        # Shell and batch scripts
        "sh", "bash", "zsh", "ps1", "bat", "cmd",
        # Server-side scripts
        "php", "jsp", "asp", "aspx",
    }
)

73 

# Environment variable names denied when no custom deny list is supplied.
# Stored upper-case; callers compare against the upper-cased variable name.
_DEFAULT_DENIED_ENV_VARS = frozenset(
    {
        # Cloud / VCS credentials
        "AWS_SECRET_ACCESS_KEY",
        "AWS_SESSION_TOKEN",
        "GITHUB_TOKEN",
        "GH_TOKEN",
        "GITLAB_TOKEN",
        # Database credentials
        "DATABASE_URL",
        "DB_PASSWORD",
        # Generic secrets
        "PASSWORD",
        "SECRET_KEY",
        "PRIVATE_KEY",
        "API_KEY",
        "API_SECRET",
    }
)

90 

# Heuristic for names that look like credentials. Compiled without
# re.IGNORECASE, so it only matches upper-case text; it is meant to be
# searched against the upper-cased variable name.
_SENSITIVE_ENV_VAR_PATTERN = re.compile(r"SECRET|PASSWORD|TOKEN|PRIVATE.*?KEY|API.*?KEY")

94 

# Hash algorithm names considered safe (SHA-2, SHA-3 and BLAKE2 families).
# NOTE(review): not referenced by the code visible in this module;
# presumably consumed elsewhere — confirm before changing.
_SAFE_HASH_ALGORITHMS = frozenset(
    {
        "sha256", "sha384", "sha512",
        "sha3_256", "sha3_384", "sha3_512",
        "blake2b", "blake2s",
    }
)

107 

108 

109class SecurityError(Exception): 

110 """Raised when a security guard detects a violation. 

111 

112 Attributes: 

113 guard_name: Name of the guard that was triggered. 

114 message: Description of the violation. 

115 value: The offending value (if safe to log). 

116 

117 """ 

118 

119 def __init__( 

120 self, 

121 message: str, 

122 guard_name: str = "unknown", 

123 value: str | None = None, 

124 ) -> None: 

125 """Initialize SecurityError. 

126 

127 Args: 

128 message: Description of the violation. 

129 guard_name: Name of the guard that triggered. 

130 value: The offending value (sanitized). 

131 

132 """ 

133 self.guard_name = guard_name 

134 self.value = value 

135 super().__init__(f"[{guard_name}] {message}") 

136 

137 

def _check_traversal_patterns(path_str: str) -> None:
    """Reject *path_str* if it contains an explicit traversal token.

    Runs on the raw path string before any filesystem resolution. The
    reported token comes from the lower-cased string; the stored value is
    the original string truncated to 50 characters.

    Raises:
        SecurityError: If ``TRAVERSAL_REGEX`` finds a match.

    """
    found = TRAVERSAL_REGEX.search(path_str.lower())
    if found is None:
        return
    raise SecurityError(
        f"Path traversal pattern detected: {found.group(0)}",
        guard_name="path_traversal",
        value=path_str[:50],  # Truncate for safety
    )

147 

148 

def _resolve_and_check_bounds(path: Path, base_dir: Path) -> tuple[Path, Path]:
    """Anchor *path* to *base_dir*, resolve it, and confirm it stays inside.

    Relative paths are joined onto ``base_dir``; absolute paths are used as
    given. The containment test is ``Path.is_relative_to``, a lexical
    comparison, so ``base_dir`` is expected to be already resolved
    (guard_path_traversal passes a resolved base).

    Returns:
        ``(full_path, resolved)``: the joined, unresolved path and its
        resolved form.

    Raises:
        SecurityError: If joining/resolving raises OSError, ValueError or
            RuntimeError, or if the resolved path lies outside ``base_dir``.

    """
    try:
        joined = base_dir / path if not path.is_absolute() else path
        canonical = joined.resolve()
    except (OSError, ValueError, RuntimeError) as exc:
        raise SecurityError(
            f"Invalid path: {exc}",
            guard_name="path_traversal",
        ) from exc

    if canonical.is_relative_to(base_dir):
        return joined, canonical
    raise SecurityError(
        "Path escapes base directory",
        guard_name="path_traversal",
    )

166 

167 

def _check_symlink_safety(full_path: Path, base_dir: Path) -> None:
    """Reject *full_path* if it or any ancestor below *base_dir* is a symlink.

    Walks upward from ``full_path`` and stops on reaching ``base_dir`` or
    the filesystem root, so components of the base directory itself are not
    inspected. If ``full_path`` is not lexically beneath ``base_dir`` (e.g.
    an absolute path that only resolves inside it), the walk continues to
    the root and also inspects ancestors outside ``base_dir``.

    Raises:
        SecurityError: If a component is a symlink, or if ``is_symlink``
            raises OSError.

    """
    node = full_path
    while node != base_dir and node != node.parent:
        # is_symlink() is used rather than exists(), which reports False for
        # a dangling symlink and would let it slip through.
        try:
            linked = node.is_symlink()
        except OSError as exc:
            raise SecurityError(
                f"Invalid path encountered during symlink check: {exc}",
                guard_name="path_traversal",
                value=str(node)[:50],
            ) from exc
        if linked:
            raise SecurityError(
                "Symlinks are not allowed",
                guard_name="path_traversal",
                value=str(node),
            )
        node = node.parent

189 

190 

def guard_path_traversal(
    path: Path | str,
    base_dir: Path | str | None = None,
    *,
    allow_symlinks: bool = False,
) -> Path:
    """Prevent path traversal attacks.

    The path is checked in three stages:

    1. its string form is scanned for explicit traversal tokens (``..``,
       URL-encoded ``..``, or a ``~`` at the start / after a separator);
    2. it is joined to the base directory, resolved, and required to stay
       inside it;
    3. unless ``allow_symlinks`` is true, no component between the target
       and the base directory may be a symlink.

    Args:
        path: The path to validate (``str`` or ``Path``).
        base_dir: Directory the path must stay within. A falsy value
            (``None`` or ``""``) means the current working directory.
        allow_symlinks: Skip the symlink check when true (default: False).

    Returns:
        The resolved, validated path.

    Raises:
        TypeError: If *path* is neither ``str`` nor ``Path``.
        SecurityError: If a traversal token, an escape from the base
            directory, or a disallowed symlink is detected.

    Example:
        >>> guard_path_traversal("../etc/passwd", Path("/app"))
        SecurityError: [path_traversal] Path traversal pattern detected: ..

    """
    if not isinstance(path, (str, Path)):
        raise TypeError(f"path must be str or Path, got {type(path).__name__}")

    candidate = path if isinstance(path, Path) else Path(path)
    root = (Path(base_dir) if base_dir else Path.cwd()).resolve()

    _check_traversal_patterns(str(candidate))
    joined, canonical = _resolve_and_check_bounds(candidate, root)
    if not allow_symlinks:
        _check_symlink_safety(joined, root)
    return canonical

231 

232 

def _check_command_not_empty(command: Sequence[str]) -> None:
    """Raise SecurityError when *command* has no elements."""
    if command:
        return
    raise SecurityError(
        "Empty command is not allowed",
        guard_name="command_injection",
    )

239 

240 

def _check_command_null_bytes(cmd_list: list[str]) -> None:
    """Raise SecurityError for the first string argument containing a NUL.

    Non-string arguments are skipped here; their type is rejected later by
    ``_check_command_patterns``. The error value is the offending argument
    truncated to 50 characters.
    """
    offending = next(
        (arg for arg in cmd_list if isinstance(arg, str) and "\x00" in arg),
        None,
    )
    if offending is None:
        return
    raise SecurityError(
        "Dangerous shell character detected: null byte",
        guard_name="command_injection",
        value=offending[:50],
    )

249 

250 

def _check_command_patterns(cmd_list: list[str]) -> None:
    """Validate each argument's type and scan it for shell metacharacters.

    Arguments are processed in order and the first problem found is raised.

    Raises:
        TypeError: If an argument is not a ``str`` (message includes its
            index).
        SecurityError: If an argument matches ``_DANGEROUS_COMMAND_RE``;
            the message uses the description for the matched token and the
            value is the argument truncated to 50 characters.

    """
    for position, argument in enumerate(cmd_list):
        if not isinstance(argument, str):
            raise TypeError(
                "All command arguments must be strings, "
                f"got {type(argument).__name__} at index {position}"
            )

        hit = _DANGEROUS_COMMAND_RE.search(argument)
        if hit is None:
            continue
        label = _DANGEROUS_COMMAND_LOOKUP[hit.group(0)]
        raise SecurityError(
            f"Dangerous shell character detected: {label}",
            guard_name="command_injection",
            value=argument[:50],
        )

267 

268 

269def _check_allowed_commands( 

270 cmd_list: list[str], allowed_commands: Sequence[str] | None 

271) -> None: 

272 if allowed_commands is None: 

273 return 

274 

275 base_command = cmd_list[0] 

276 command_name = Path(base_command).name 

277 cmd_not_allowed = ( 

278 command_name not in allowed_commands and base_command not in allowed_commands 

279 ) 

280 if cmd_not_allowed: 

281 raise SecurityError( 

282 f"Command not in allowed list: {command_name}", 

283 guard_name="command_injection", 

284 value=command_name, 

285 ) 

286 

287 

def guard_command_injection(
    command: Sequence[str],
    *,
    allowed_commands: Sequence[str] | None = None,
) -> list[str]:
    """Prevent command injection attacks.

    Checks, in order: the command is non-empty; no argument contains a NUL;
    every argument is a ``str`` free of shell metacharacters; and, if an
    allow list is given, the base command is on it.

    Note:
        The command is copied with ``list(command)``, so a bare ``str`` is
        treated as a sequence of single characters. Pass an argv-style list.

    Args:
        command: The command and arguments as a sequence.
        allowed_commands: Optional whitelist of allowed base commands.

    Returns:
        The validated command as a new list.

    Raises:
        TypeError: If an argument is not a ``str``.
        SecurityError: If command injection is detected.

    Example:
        >>> guard_command_injection(["echo", "hello; rm -rf /"])
        SecurityError: [command_injection] Dangerous shell character detected: command separator

    """
    argv = list(command)
    for check in (
        _check_command_not_empty,
        _check_command_null_bytes,
        _check_command_patterns,
    ):
        check(argv)
    _check_allowed_commands(argv, allowed_commands)
    return argv

322 

323 

def _check_filename_null_bytes(filename_str: str) -> None:
    """Raise SecurityError if *filename_str* contains a NUL character.

    The full, untruncated filename is attached as the error value.
    """
    if "\x00" not in filename_str:
        return
    raise SecurityError(
        "Filename contains null bytes",
        guard_name="file_extension",
        value=filename_str,
    )

331 

332 

def _clean_filename_end(clean_name: str) -> str:
    """Return *clean_name* with trailing junk characters removed.

    From the right only, strips any run of dots, characters whose Unicode
    category starts with "Z" (separators, including ordinary spaces) or "C"
    (control, format, surrogate, private-use, unassigned), and U+00AD.
    U+00AD (soft hyphen) is itself category Cf, so its explicit check is
    redundant but harmless. Presumably this defeats names such as
    "evil.php." or "evil.php " that some filesystems normalize away.
    """

    def _is_trailing_junk(ch: str) -> bool:
        # Unicode categories are two letters; the first gives the major class.
        return ch == "." or unicodedata.category(ch)[0] in "ZC" or ch == "\xad"

    cut = len(clean_name)
    while cut > 0 and _is_trailing_junk(clean_name[cut - 1]):
        cut -= 1
    return clean_name[:cut]

346 

347 

def _normalize_ext(e: str) -> str:
    """Canonicalize an extension: drop all leading dots, then lower-case it."""
    # No character lower-cases to ".", so stripping first is equivalent.
    return e.lstrip(".").lower()

350 

351 

def _check_denied_extension(
    ext: str, original_name: str, denied_extensions: Sequence[str] | None
) -> None:
    """Raise SecurityError if *ext* is on the deny list.

    Args:
        ext: Extension to test, expected lower-case with no leading dot (as
            produced by guard_file_extension).
        original_name: Filename reported as the error value.
        denied_extensions: Custom deny list, normalized with
            ``_normalize_ext``. ``None`` selects
            ``_DEFAULT_DENIED_EXTENSIONS``; a custom list replaces the
            defaults rather than extending them.

    """
    deny_set = (
        _DEFAULT_DENIED_EXTENSIONS
        if denied_extensions is None
        else frozenset(map(_normalize_ext, denied_extensions))
    )
    if ext not in deny_set:
        return
    raise SecurityError(
        f"File extension '{ext}' is not allowed",
        guard_name="file_extension",
        value=original_name,
    )

366 

367 

368def _check_allowed_extension( 

369 ext: str, original_name: str, allowed_extensions: Sequence[str] | None 

370) -> None: 

371 if allowed_extensions is not None: # pragma: no branch 

372 allowed = {_normalize_ext(e) for e in allowed_extensions} 

373 if ext not in allowed: 

374 raise SecurityError( 

375 f"File extension '{ext}' is not in allowed list", 

376 guard_name="file_extension", 

377 value=original_name, 

378 ) 

379 

380 

def guard_file_extension(
    filename: str | Path,
    *,
    allowed_extensions: Sequence[str] | None = None,
    denied_extensions: Sequence[str] | None = None,
) -> Path:
    """Validate file extension against allow/deny lists.

    Only the final path component is inspected. Trailing dots and trailing
    separator/control/format characters are stripped first; the last
    suffix is then lower-cased and its dot removed. The deny list is checked
    before the allow list.

    Args:
        filename: The filename to check.
        allowed_extensions: If given, the only extensions accepted (with or
            without a leading dot, case-insensitive).
        denied_extensions: Extensions to reject (with or without dot). When
            omitted, a built-in list of executables and scripts is used;
            passing a list replaces it.

    Returns:
        ``Path(filename)``, unchanged.

    Raises:
        SecurityError: If the filename is too long or contains a NUL, or
            the extension is denied or missing from the allow list.

    """
    as_text = str(filename)
    # MAX_URL_LENGTH doubles as the filename limit to avoid a magic number (PLR2004).
    if len(as_text) > MAX_URL_LENGTH:
        raise SecurityError(
            f"Filename length exceeds maximum allowed limit of {MAX_URL_LENGTH}",
            guard_name="file_extension",
            value=as_text[:80],
        )
    _check_filename_null_bytes(as_text)

    target = Path(filename)
    base_name = target.name
    trimmed = _clean_filename_end(base_name)
    extension = Path(trimmed).suffix.lower().lstrip(".") if trimmed else ""

    _check_denied_extension(extension, base_name, denied_extensions)
    _check_allowed_extension(extension, base_name, allowed_extensions)
    return target

419 

420 

421def _check_env_denied( 

422 name_upper: str, 

423 name: str, 

424 denied_names: Sequence[str] | None, 

425) -> None: 

426 """Check if the environment variable is in the denied list.""" 

427 if denied_names is not None: 

428 denied = frozenset(n.upper() for n in denied_names) 

429 else: 

430 denied = _DEFAULT_DENIED_ENV_VARS 

431 

432 if name_upper in denied: 

433 raise SecurityError( 

434 f"Access to sensitive variable '{name}' is denied", 

435 guard_name="env_variable", 

436 value=name, 

437 ) 

438 

439 

def _check_env_sensitive(
    name_upper: str,
    name: str,
    allowed_names: Sequence[str] | None,
) -> None:
    """Raise SecurityError if the name looks sensitive and is not allowed.

    A name "looks sensitive" when ``_SENSITIVE_ENV_VAR_PATTERN`` matches the
    upper-cased name. Listing it in ``allowed_names`` (case-insensitive)
    exempts it from this check only; the allow list is consulted only for
    names that match the pattern.
    """
    if _SENSITIVE_ENV_VAR_PATTERN.search(name_upper) is None:
        return

    # Only block if not explicitly allowed
    if allowed_names is not None and name_upper in {
        entry.upper() for entry in allowed_names
    }:
        return

    raise SecurityError(
        f"Access to potentially sensitive variable '{name}' is denied",
        guard_name="env_variable",
        value=name,
    )

460 

461 

def guard_env_variable(
    name: str,
    *,
    allowed_names: Sequence[str] | None = None,
    denied_names: Sequence[str] | None = None,
) -> str:
    """Guard against accessing sensitive environment variables.

    Checks run in order: argument type, blank name, NUL character, deny
    list (``denied_names`` or the built-in default), then the
    sensitive-name pattern, from which ``allowed_names`` can exempt a name.
    Name comparisons are case-insensitive; the ``os.environ`` lookup uses
    *name* exactly as given.

    Args:
        name: The environment variable name.
        allowed_names: Names exempted from the sensitive-pattern check.
            They do not override the deny list.
        denied_names: Names to deny; replaces the built-in deny list.

    Returns:
        The environment variable's value.

    Raises:
        TypeError: If *name* is not a ``str``.
        SecurityError: If the name is blank or contains a NUL, is denied,
            looks sensitive without being allowed, or the variable is unset.

    """
    if not isinstance(name, str):
        raise TypeError(f"Variable name must be str, got {type(name).__name__}")

    # "".strip() is also empty, so this covers empty and whitespace-only names.
    if not name.strip():
        raise SecurityError(
            "Environment variable name cannot be empty or whitespace",
            guard_name="env_variable",
        )
    if "\x00" in name:
        raise SecurityError(
            "Environment variable name cannot contain null bytes",
            guard_name="env_variable",
        )

    upper = name.upper()
    _check_env_denied(upper, name, denied_names)
    _check_env_sensitive(upper, name, allowed_names)

    value = os.environ.get(name)
    if value is not None:
        return value
    raise SecurityError(
        f"Environment variable '{name}' is not set",
        guard_name="env_variable",
        value=name,
    )

514 

515 

# ── SSRF Constants ───────────────────────────────────────────────────────────
# URL schemes guard_ssrf accepts by default. The parsed URL's scheme is
# lower-cased before it is compared against this set.
_ALLOWED_SSRF_SCHEMES: frozenset[str] = frozenset(("http", "https"))

518 

519 

def _validate_ssrf_url_type_and_length(url: str) -> Result[str, SecurityError]:
    """Check that *url* is a non-empty ``str`` of at most MAX_URL_LENGTH chars.

    A wrong type is reported as ``Err`` rather than raised, keeping the
    Result-based contract of guard_ssrf.

    Returns:
        ``Ok(url)`` on success, otherwise ``Err(SecurityError)``.

    """
    if not isinstance(url, str):
        failure = SecurityError(
            f"URL must be str, got {type(url).__name__}",
            guard_name="ssrf",
        )
    elif not url:
        failure = SecurityError("URL cannot be empty", guard_name="ssrf")
    elif len(url) > MAX_URL_LENGTH:
        failure = SecurityError(
            "URL length exceeds maximum allowed limit",
            guard_name="ssrf",
            value=url[:80],
        )
    else:
        return Ok(url)
    return Err(failure)

541 

542 

def _validate_ssrf_url_parse(
    url: str, allowed_schemes: frozenset[str]
) -> Result[str, SecurityError]:
    """Parse *url*, enforce the scheme allow list, and extract the hostname.

    Args:
        url: URL already checked for type and length.
        allowed_schemes: Accepted schemes; compared against the lower-cased
            parsed scheme.

    Returns:
        ``Ok(hostname)`` on success; ``Err(SecurityError)`` if ``urlsplit``
        raises ValueError, the scheme is missing or not allowed, or there
        is no hostname.

    """
    try:
        parts = urlsplit(url)
    except ValueError as exc:
        return Err(
            SecurityError(
                f"Malformed URL: {exc}",
                guard_name="ssrf",
                value=url[:80],
            )
        )

    scheme = parts.scheme
    if not scheme or scheme.lower() not in allowed_schemes:
        return Err(
            SecurityError(
                f"URL scheme '{scheme}' is not allowed",
                guard_name="ssrf",
                value=url[:80],
            )
        )

    host = parts.hostname
    if host:
        return Ok(host)
    return Err(
        SecurityError(
            "URL has no resolvable hostname",
            guard_name="ssrf",
            value=url[:80],
        )
    )

577 

578 

def _validate_ssrf_url(
    url: str,
    allowed_schemes: frozenset[str],
) -> Result[str, SecurityError]:
    """Validate URL type/length, then scheme and hostname presence.

    Returns:
        ``Ok(hostname)`` if both stages pass, otherwise the first ``Err``.

    """
    checked = _validate_ssrf_url_type_and_length(url)
    if isinstance(checked, Ok):
        return _validate_ssrf_url_parse(url, allowed_schemes)
    return checked

589 

590 

@functools.lru_cache(maxsize=1024)
def _is_ip_safe(raw_ip: str) -> bool:
    """Return True only if *raw_ip* is a publicly routable unicast address.

    Rejects private, loopback, link-local, reserved, multicast and
    unspecified addresses, plus anything ``ipaddress`` does not consider
    globally reachable. The last rule covers the IPv4 shared address space
    100.64.0.0/10, which ``is_private`` does not flag. That range hosts
    cloud metadata endpoints such as 100.100.100.200.

    Args:
        raw_ip: Textual IP address, e.g. as returned in a
            ``socket.getaddrinfo`` sockaddr.

    Returns:
        True if the address is safe to connect to, False otherwise,
        including when *raw_ip* cannot be parsed (fail closed).

    """
    try:
        addr = ipaddress.ip_address(raw_ip)
    except ValueError:
        # Fail closed: an address we cannot classify must not be trusted.
        return False

    return not (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_reserved
        or addr.is_multicast
        or addr.is_unspecified
        or not addr.is_global
    )

602 

603 

def _check_ip_safety(hostname: str) -> Result[None, SecurityError]:
    """Resolve *hostname* and check every resulting address for SSRF risk.

    Every address returned by ``socket.getaddrinfo`` must pass
    ``_is_ip_safe``; a single unsafe address rejects the hostname.
    Resolution failures are reported as ``Err`` and never raised, so callers
    can rely on the Result contract even for hostile input.

    Args:
        hostname: Host name or IP literal extracted from the URL.

    Returns:
        ``Ok(None)`` if all resolved addresses are safe, otherwise
        ``Err(SecurityError)``.

    """
    unresolved_msg = "Hostname could not be resolved or contains invalid characters"
    try:
        addr_infos = socket.getaddrinfo(hostname, None)
    except (OSError, ValueError, TypeError):
        # OSError covers socket.gaierror. ValueError covers UnicodeError from
        # IDNA encoding (e.g. over-long labels). CPython rejects host names
        # containing a NUL character with TypeError/ValueError instead of
        # gaierror, and urlsplit may leave such characters in the hostname.
        return Err(SecurityError(unresolved_msg, guard_name="ssrf"))

    if not addr_infos:
        # Defensive: an empty resolution result gives nothing to vet.
        return Err(SecurityError(unresolved_msg, guard_name="ssrf"))

    for addr_info in addr_infos:
        # addr_info[4] is the sockaddr tuple; its first element is the IP text.
        raw_ip = addr_info[4][0]
        if not _is_ip_safe(raw_ip):
            return Err(
                SecurityError(
                    "SSRF detected: hostname resolves to private/reserved address",
                    guard_name="ssrf",
                )
            )

    return Ok(None)

627 

628 

def guard_ssrf(
    url: str,
    *,
    allowed_schemes: frozenset[str] = _ALLOWED_SSRF_SCHEMES,
) -> Result[str, SecurityError]:
    """Validate a URL against Server-Side Request Forgery (SSRF) attacks.

    Parse the URL, resolve its hostname via DNS, and reject it when any
    resulting IP address falls inside a non-public range (e.g. private,
    loopback, link-local, or reserved). Unresolvable hostnames are rejected
    too. All failures, including a non-``str`` argument, are returned as
    ``Err``; this function does not raise for invalid input.

    NOTE(review): the hostname is resolved here, but only the URL is
    returned. A later fetch performs its own DNS lookup, which may return a
    different address (DNS rebinding). Consider connecting to the vetted IP
    if that matters for the caller.

    Args:
        url: The URL string to validate.
        allowed_schemes: Set of URL schemes considered safe, compared
            against the lower-cased URL scheme. Defaults to
            ``{"http", "https"}``.

    Returns:
        ``Ok(url)`` when the URL is safe to fetch.
        ``Err(SecurityError)`` when the URL is invalid (wrong type, empty,
        too long, malformed, disallowed scheme, no hostname), cannot be
        resolved, or an SSRF risk is detected.

    Example:
        >>> guard_ssrf("https://example.com")
        Ok('https://example.com')
        >>> guard_ssrf("http://169.254.169.254/metadata")
        Err(SecurityError('[ssrf] SSRF detected: ...'))

    """
    # 1. Validate format and scheme
    val_res = _validate_ssrf_url(url, allowed_schemes)
    if not isinstance(val_res, Ok):
        return val_res

    # 2. Check IP safety of every address the hostname resolves to
    ip_res = _check_ip_safety(val_res.ok_value)
    if not isinstance(ip_res, Ok):
        return ip_res

    return Ok(url)