Coverage for src / taipanstack / security / guards.py: 100%

157 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 14:54 +0000

1""" 

2Runtime guards for protection against errors and AI hallucinations. 

3 

4These guards provide runtime protection against common security issues 

5and programming errors that can occur from incorrect AI-generated code. 

6All guards raise SecurityError on violation. 

7""" 

8 

9import ipaddress 

10import os 

11import re 

12import socket 

13from collections.abc import Sequence 

14from pathlib import Path 

15from urllib.parse import urlparse 

16 

17from result import Err, Ok, Result 

18 

# Lexical (pre-resolution) detector for path-traversal tokens, matched
# case-insensitively:
#   * ".." plus its percent-encoded (%2e%2e) and double percent-encoded
#     (%252e%252e) spellings;
#   * "~" only when it opens the path or directly follows "/" or "\".
# Restricting "~" this way keeps Windows 8.3 short names such as RUNNER~1
# from being reported as home-directory references.
TRAVERSAL_REGEX = re.compile(
    r"(?:\.\.|%2e%2e|%252e%252e)|(?:^|[\\/])~",
    flags=re.IGNORECASE,
)

26 

# Shell metacharacters that guard_command_injection refuses to see in any
# argument, each paired with the label used in its error message.
#
# Order is significant. The compiled alternation below tries entries
# left-to-right at each position, so a single-character token shadows any
# longer token that starts with it. As a result:
#   * "$(" and "${" are reported as "variable expansion";
#   * ">>" is reported as "redirect";
#   * "||" is reported as "pipe";
#   * "&&" is reported as "background/and operator".
# The longer entries never surface as a match and only add lookup-table rows.
_DANGEROUS_COMMAND_PATTERNS: tuple[tuple[str, str], ...] = (
    (";", "command separator"),
    ("|", "pipe"),
    ("&", "background/and operator"),
    ("$", "variable expansion"),
    ("`", "command substitution"),
    ("$(", "command substitution"),
    ("${", "variable expansion"),
    (">", "redirect"),
    ("<", "redirect"),
    (">>", "redirect append"),
    ("||", "or operator"),
    ("&&", "and operator"),
    ("\n", "newline"),
    ("\r", "carriage return"),
    ("\x00", "null byte"),
)

# A single precompiled alternation over every token (escaped, original
# order), plus a token -> label map that turns a match back into its label.
_DANGEROUS_COMMAND_RE = re.compile(
    "|".join(map(re.escape, (token for token, _label in _DANGEROUS_COMMAND_PATTERNS)))
)
_DANGEROUS_COMMAND_LOOKUP = {
    token: label for token, label in _DANGEROUS_COMMAND_PATTERNS
}

50 

# Extensions rejected by guard_file_extension when the caller passes no
# denied_extensions. Entries are lowercase with no leading dot, matching how
# that function normalizes the extension it extracts.
_DEFAULT_DENIED_EXTENSIONS = frozenset(
    {
        # Native executables and shared libraries
        "exe", "dll", "so", "dylib",
        # Shell and OS scripts
        "sh", "bash", "zsh", "ps1", "bat", "cmd",
        # Server-side scripts
        "php", "jsp", "asp", "aspx",
    }
)

69 

# Environment variables that guard_env_variable refuses by exact
# (upper-cased) name when the caller passes no denied_names.
_DEFAULT_DENIED_ENV_VARS = frozenset(
    {
        # Cloud credentials
        "AWS_SECRET_ACCESS_KEY",
        "AWS_SESSION_TOKEN",
        # Code-hosting tokens
        "GITHUB_TOKEN",
        "GH_TOKEN",
        "GITLAB_TOKEN",
        # Database credentials
        "DATABASE_URL",
        "DB_PASSWORD",
        # Generic secrets
        "PASSWORD",
        "SECRET_KEY",
        "PRIVATE_KEY",
        "API_KEY",
        "API_SECRET",
    }
)

# Substring heuristic for names that look secret-bearing. It is applied with
# search() to the upper-cased name and is case-sensitive, which is why every
# alternative is written in upper case.
_SENSITIVE_ENV_VAR_PATTERN = re.compile(r"SECRET|PASSWORD|TOKEN|PRIVATE.*KEY|API.*KEY")

88 

# Default whitelist for guard_hash_algorithm: the SHA-2, SHA-3 and BLAKE2
# families. Names are spelled as hashlib spells them (lowercase, "_" rather
# than "-").
_SAFE_HASH_ALGORITHMS = frozenset(
    {
        "sha256", "sha384", "sha512",
        "sha3_256", "sha3_384", "sha3_512",
        "blake2b", "blake2s",
    }
)

101 

102 

103class SecurityError(Exception): 

104 """Raised when a security guard detects a violation. 

105 

106 Attributes: 

107 guard_name: Name of the guard that was triggered. 

108 message: Description of the violation. 

109 value: The offending value (if safe to log). 

110 

111 """ 

112 

113 def __init__( 

114 self, 

115 message: str, 

116 guard_name: str = "unknown", 

117 value: str | None = None, 

118 ) -> None: 

119 """Initialize SecurityError. 

120 

121 Args: 

122 message: Description of the violation. 

123 guard_name: Name of the guard that triggered. 

124 value: The offending value (sanitized). 

125 

126 """ 

127 self.guard_name = guard_name 

128 self.value = value 

129 super().__init__(f"[{guard_name}] {message}") 

130 

131 

def _check_traversal_patterns(path_str: str) -> None:
    """Reject *path_str* if it contains an explicit traversal token.

    The scan runs on the raw string, before any filesystem resolution, using
    ``TRAVERSAL_REGEX`` against the lower-cased input. The reported token is
    therefore lowercase.

    Args:
        path_str: The path exactly as supplied by the caller.

    Raises:
        SecurityError: If a traversal token is found. The first 50
            characters of the input are attached as ``value``.

    """
    found = TRAVERSAL_REGEX.search(path_str.lower())
    if found is None:
        return
    raise SecurityError(
        f"Path traversal pattern detected: {found.group(0)}",
        guard_name="path_traversal",
        value=path_str[:50],  # Truncate for safety
    )

141 

142 

def _resolve_and_check_bounds(path: Path, base_dir: Path) -> tuple[Path, Path]:
    """Anchor *path* at *base_dir*, resolve it, and confirm it stays inside.

    Args:
        path: Candidate path. A relative path is joined onto *base_dir*; an
            absolute path is used as-is.
        base_dir: Directory the result must live under. The containment
            test compares against it verbatim, so it is expected to be
            resolved already (guard_path_traversal resolves it).

    Returns:
        ``(full_path, resolved)``: the joined but unresolved path, and its
        ``Path.resolve()`` form.

    Raises:
        SecurityError: If resolution fails with ``OSError``/``ValueError``,
            or the resolved path is not under *base_dir*.

    """
    try:
        if path.is_absolute():
            full_path = path
        else:
            full_path = base_dir / path
        resolved = full_path.resolve()
    except (OSError, ValueError) as exc:
        raise SecurityError(
            f"Invalid path: {exc}",
            guard_name="path_traversal",
        ) from exc

    if resolved.is_relative_to(base_dir):
        return full_path, resolved

    raise SecurityError(
        "Path escapes base directory",
        guard_name="path_traversal",
    )

160 

161 

def _check_symlink_safety(full_path: Path, base_dir: Path) -> None:
    """Reject symlinks between *full_path* and *base_dir*.

    Inspects *full_path* and each ancestor below *base_dir*. *base_dir* and
    everything above it are not examined. The walk also stops at a path that
    is its own parent (a filesystem root, or ``.`` for a relative path).

    Args:
        full_path: The joined, unresolved path to inspect.
        base_dir: The directory at which the walk stops.

    Raises:
        SecurityError: On the first symlink found, with that component's
            path as ``value``.

    """
    for component in (full_path, *full_path.parents):
        if component == base_dir or component == component.parent:
            break
        # is_symlink() rather than exists(): exists() returns False for a
        # broken symlink, which must still be rejected.
        if component.is_symlink():
            raise SecurityError(
                "Symlinks are not allowed",
                guard_name="path_traversal",
                value=str(component),
            )

175 

176 

177def guard_path_traversal( 

178 path: Path | str, 

179 base_dir: Path | str | None = None, 

180 *, 

181 allow_symlinks: bool = False, 

182) -> Path: 

183 """Prevent path traversal attacks. 

184 

185 Ensures that the given path does not escape the base directory 

186 using techniques like '..' or symlinks. 

187 

188 Args: 

189 path: The path to validate. 

190 base_dir: The base directory to constrain to. Defaults to cwd. 

191 allow_symlinks: Whether to allow symlinks (default: False). 

192 

193 Returns: 

194 The resolved, validated path. 

195 

196 Raises: 

197 SecurityError: If path traversal is detected. 

198 

199 Example: 

200 >>> guard_path_traversal("../etc/passwd", Path("/app")) 

201 SecurityError: [path_traversal] Path escapes base directory 

202 

203 """ 

204 if not isinstance(path, (str, Path)): 

205 raise TypeError(f"path must be str or Path, got {type(path).__name__}") 

206 

207 path_obj = Path(path) if isinstance(path, str) else path 

208 base = Path(base_dir).resolve() if base_dir else Path.cwd().resolve() 

209 

210 _check_traversal_patterns(str(path_obj)) 

211 full_path, resolved = _resolve_and_check_bounds(path_obj, base) 

212 

213 if not allow_symlinks: 

214 _check_symlink_safety(full_path, base) 

215 

216 return resolved 

217 

218 

219def guard_command_injection( 

220 command: Sequence[str], 

221 *, 

222 allowed_commands: Sequence[str] | None = None, 

223) -> list[str]: 

224 """Prevent command injection attacks. 

225 

226 Validates that command arguments don't contain shell metacharacters 

227 that could lead to command injection. 

228 

229 Args: 

230 command: The command and arguments as a sequence. 

231 allowed_commands: Optional whitelist of allowed base commands. 

232 

233 Returns: 

234 The validated command as a list. 

235 

236 Raises: 

237 SecurityError: If command injection is detected. 

238 

239 Example: 

240 >>> guard_command_injection(["echo", "hello; rm -rf /"]) 

241 SecurityError: [command_injection] Dangerous characters detected 

242 

243 """ 

244 if not command: 

245 raise SecurityError( 

246 "Empty command is not allowed", 

247 guard_name="command_injection", 

248 ) 

249 

250 cmd_list = list(command) 

251 

252 # Validate all items are strings and check for dangerous patterns 

253 for i, arg in enumerate(cmd_list): 

254 if not isinstance(arg, str): 

255 raise TypeError( 

256 f"All command arguments must be strings, " 

257 f"got {type(arg).__name__} at index {i}" 

258 ) 

259 

260 match = _DANGEROUS_COMMAND_RE.search(arg) 

261 if match: 

262 # We use the matched substring to look up the description. 

263 # Because the regex is an alternation of the patterns in order, 

264 # this preserves the existing behavior where earlier patterns 

265 # in the list take precedence (e.g. '>' matches before '>>'). 

266 description = _DANGEROUS_COMMAND_LOOKUP[match.group(0)] 

267 raise SecurityError( 

268 f"Dangerous shell character detected: {description}", 

269 guard_name="command_injection", 

270 value=arg[:50], 

271 ) 

272 

273 # Check against allowed commands whitelist 

274 if allowed_commands is not None: 

275 base_command = cmd_list[0] 

276 # Get just the command name without path 

277 command_name = Path(base_command).name 

278 cmd_not_allowed = ( 

279 command_name not in allowed_commands 

280 and base_command not in allowed_commands 

281 ) 

282 if cmd_not_allowed: 

283 raise SecurityError( 

284 f"Command not in allowed list: {command_name}", 

285 guard_name="command_injection", 

286 value=command_name, 

287 ) 

288 

289 return cmd_list 

290 

291 

292def guard_file_extension( 

293 filename: str | Path, 

294 *, 

295 allowed_extensions: Sequence[str] | None = None, 

296 denied_extensions: Sequence[str] | None = None, 

297) -> Path: 

298 """Validate file extension against allow/deny lists. 

299 

300 Args: 

301 filename: The filename to check. 

302 allowed_extensions: Extensions to allow (with or without dot). 

303 denied_extensions: Extensions to deny (with or without dot). 

304 

305 Returns: 

306 The filename as a Path. 

307 

308 Raises: 

309 SecurityError: If extension is not allowed or is denied. 

310 

311 """ 

312 path = Path(filename) 

313 ext = path.suffix.lower().lstrip(".") 

314 

315 # Normalize extension lists 

316 def normalize_ext(e: str) -> str: 

317 return e.lower().lstrip(".") 

318 

319 if denied_extensions is not None: 

320 denied = frozenset(normalize_ext(e) for e in denied_extensions) 

321 else: 

322 denied = _DEFAULT_DENIED_EXTENSIONS 

323 

324 if ext in denied: 

325 raise SecurityError( 

326 f"File extension '{ext}' is not allowed", 

327 guard_name="file_extension", 

328 value=str(path.name), 

329 ) 

330 

331 if allowed_extensions is not None: # pragma: no branch 

332 allowed = {normalize_ext(e) for e in allowed_extensions} 

333 if ext not in allowed: 

334 raise SecurityError( 

335 f"File extension '{ext}' is not in allowed list", 

336 guard_name="file_extension", 

337 value=str(path.name), 

338 ) 

339 

340 return path 

341 

342 

343def _check_env_denied( 

344 name_upper: str, 

345 name: str, 

346 denied_names: Sequence[str] | None, 

347) -> None: 

348 """Check if the environment variable is in the denied list.""" 

349 if denied_names is not None: 

350 denied = frozenset(n.upper() for n in denied_names) 

351 else: 

352 denied = _DEFAULT_DENIED_ENV_VARS 

353 

354 if name_upper in denied: 

355 raise SecurityError( 

356 f"Access to sensitive variable '{name}' is denied", 

357 guard_name="env_variable", 

358 value=name, 

359 ) 

360 

361 

def _check_env_sensitive(
    name_upper: str,
    name: str,
    allowed_names: Sequence[str] | None,
) -> None:
    """Raise if the name looks secret-bearing and is not explicitly allowed.

    Args:
        name_upper: The variable name, expected to be upper-cased already
            (``_SENSITIVE_ENV_VAR_PATTERN`` is case-sensitive).
        name: The name as the caller spelled it (used in the error).
        allowed_names: Names exempt from this heuristic, compared after
            upper-casing.

    Raises:
        SecurityError: If the name matches the sensitive pattern and is not
            in *allowed_names*.

    """
    if _SENSITIVE_ENV_VAR_PATTERN.search(name_upper) is None:
        return

    # Pattern hit: only an explicit allow-list entry lets it through.
    if allowed_names is not None and name_upper in {
        entry.upper() for entry in allowed_names
    }:
        return

    raise SecurityError(
        f"Access to potentially sensitive variable '{name}' is denied",
        guard_name="env_variable",
        value=name,
    )

382 

383 

384def guard_env_variable( 

385 name: str, 

386 *, 

387 allowed_names: Sequence[str] | None = None, 

388 denied_names: Sequence[str] | None = None, 

389) -> str: 

390 """Guard against accessing sensitive environment variables. 

391 

392 Args: 

393 name: The environment variable name. 

394 allowed_names: Variable names to allow. 

395 denied_names: Variable names to deny. 

396 

397 Returns: 

398 The environment variable value if safe. 

399 

400 Raises: 

401 SecurityError: If variable access is not allowed. 

402 

403 """ 

404 # Validate input type 

405 if not isinstance(name, str): 

406 raise TypeError(f"Variable name must be str, got {type(name).__name__}") 

407 

408 # Reject empty/whitespace-only variable names 

409 if not name or not name.strip(): 

410 raise SecurityError( 

411 "Environment variable name cannot be empty or whitespace", 

412 guard_name="env_variable", 

413 ) 

414 

415 name_upper = name.upper() 

416 

417 _check_env_denied(name_upper, name, denied_names) 

418 _check_env_sensitive(name_upper, name, allowed_names) 

419 

420 # Get the variable 

421 value = os.environ.get(name) 

422 if value is None: 

423 raise SecurityError( 

424 f"Environment variable '{name}' is not set", 

425 guard_name="env_variable", 

426 value=name, 

427 ) 

428 

429 return value 

430 

431 

432def guard_hash_algorithm( 

433 algorithm: str, 

434 *, 

435 allowed_algorithms: Sequence[str] | None = None, 

436) -> str: 

437 """Validate hash algorithm against a whitelist of secure ones. 

438 

439 Args: 

440 algorithm: The name of the hash algorithm to validate. 

441 allowed_algorithms: Optional whitelist of allowed algorithms. 

442 Defaults to a secure set (SHA-256, SHA-512, etc.). 

443 

444 Returns: 

445 The normalized (lowercase) algorithm name. 

446 

447 Raises: 

448 SecurityError: If the algorithm is potentially weak or not allowed. 

449 

450 """ 

451 if not isinstance(algorithm, str): 

452 raise TypeError(f"Algorithm name must be str, got {type(algorithm).__name__}") 

453 

454 algo_lower = algorithm.lower().replace("-", "") 

455 

456 allowed: frozenset[str] 

457 if allowed_algorithms is not None: 

458 allowed = frozenset(a.lower().replace("-", "") for a in allowed_algorithms) 

459 else: 

460 allowed = _SAFE_HASH_ALGORITHMS 

461 

462 if algo_lower not in allowed: 

463 raise SecurityError( 

464 f"Hash algorithm '{algorithm}' is considered weak or is not allowed", 

465 guard_name="hash_algorithm", 

466 value=algorithm, 

467 ) 

468 

469 return algo_lower 

470 

471 

# ── SSRF constants ───────────────────────────────────────────────────────────
# URL schemes that guard_ssrf accepts by default. Entries must be lowercase,
# because the URL's scheme is lower-cased before the membership test.
_ALLOWED_SSRF_SCHEMES: frozenset[str] = frozenset(("http", "https"))

474 

475 

def _validate_ssrf_url(
    url: str,
    allowed_schemes: frozenset[str],
) -> Result[str, SecurityError]:
    """Check a URL's format and scheme and extract its hostname.

    Args:
        url: The candidate URL.
        allowed_schemes: Accepted schemes. The URL's scheme is lower-cased
            before the membership test, so entries should be lowercase.

    Returns:
        ``Ok(hostname)`` with the hostname reported by ``urlparse``.
        ``Err(SecurityError)`` when the URL is empty, contains a backslash
        or an ASCII control character, cannot be parsed, has a disallowed
        scheme, or has no hostname.

    Raises:
        TypeError: If *url* is not a ``str``.

    """
    if not isinstance(url, str):
        raise TypeError(f"URL must be str, got {type(url).__name__}")

    if not url:
        return Err(SecurityError("URL cannot be empty", guard_name="ssrf"))

    # Reject input on which URL parsers are known to disagree.
    #
    # Backslashes: urlparse() ends the authority only at "/", "?" or "#",
    # whereas WHATWG-style parsers treat "\" like "/" for http(s). For
    # "http://169.254.169.254\@example.com", urlparse() reports the host as
    # "example.com", but such a parser yields 169.254.169.254.
    #
    # Control characters: urlsplit() silently strips or deletes some ASCII
    # control characters, so the string validated here could differ from
    # the one that is later sent.
    if re.search(r"[\\\x00-\x1f\x7f]", url):
        return Err(
            SecurityError(
                "URL contains a backslash or control character",
                guard_name="ssrf",
                value=url[:80],
            )
        )

    try:
        parsed = urlparse(url)
    except ValueError as exc:  # pragma: no cover
        return Err(
            SecurityError(
                f"Malformed URL: {exc}",
                guard_name="ssrf",
                value=url[:80],
            )
        )

    if not parsed.scheme or parsed.scheme.lower() not in allowed_schemes:
        return Err(
            SecurityError(
                f"URL scheme '{parsed.scheme}' is not allowed",
                guard_name="ssrf",
                value=url[:80],
            )
        )

    hostname = parsed.hostname
    if not hostname:
        return Err(
            SecurityError(
                "URL has no resolvable hostname",
                guard_name="ssrf",
                value=url[:80],
            )
        )

    return Ok(hostname)

518 

519 

def _is_unsafe_ip(addr: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
    """Return True if a server-side request must not be sent to *addr*."""
    # An IPv4-mapped IPv6 address (::ffff:a.b.c.d) reaches the embedded IPv4
    # host, so classify it by that address.
    if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
        addr = addr.ipv4_mapped
    return (
        addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_reserved
        or addr.is_multicast
        or addr.is_unspecified
        # is_global is also False for ranges that is_private does not flag,
        # notably the 100.64.0.0/10 shared (carrier-grade NAT) space, where
        # some cloud providers host internal services.
        or not addr.is_global
    )


def _check_ip_safety(hostname: str) -> Result[None, SecurityError]:
    """Resolve *hostname* and reject it if any resolved address is unsafe.

    Every address returned by ``socket.getaddrinfo`` is classified with
    ``_is_unsafe_ip``. The check fails closed: a resolution error, an empty
    answer, or an address that cannot be parsed all produce ``Err``.

    Args:
        hostname: Hostname (or IP literal) taken from the URL.

    Returns:
        ``Ok(None)`` if every resolved address is public, else
        ``Err(SecurityError)``.

    """
    try:
        addr_infos = socket.getaddrinfo(hostname, None)
    except (OSError, ValueError):
        # socket.gaierror is an OSError subclass. ValueError (including
        # UnicodeError) is raised for hostnames the resolver cannot encode,
        # e.g. an over-long IDNA label. Previously it escaped this
        # Result-returning API as an exception.
        return Err(
            SecurityError(
                "Hostname could not be resolved",
                guard_name="ssrf",
            )
        )

    if not addr_infos:
        return Err(
            SecurityError(
                "Hostname could not be resolved",
                guard_name="ssrf",
            )
        )

    for addr_info in addr_infos:
        raw_ip = addr_info[4][0]
        try:
            addr = ipaddress.ip_address(raw_ip)
        except ValueError:
            # Fail closed. Skipping an address we cannot classify would
            # approve a hostname whose destination is unknown.
            return Err(
                SecurityError(
                    "SSRF detected: hostname resolves to an unrecognized address",
                    guard_name="ssrf",
                )
            )

        if _is_unsafe_ip(addr):
            return Err(
                SecurityError(
                    "SSRF detected: hostname resolves to private/reserved address",
                    guard_name="ssrf",
                )
            )

    return Ok(None)

553 

554 

def guard_ssrf(
    url: str,
    *,
    allowed_schemes: frozenset[str] = _ALLOWED_SSRF_SCHEMES,
) -> Result[str, SecurityError]:
    """Validate a URL against Server-Side Request Forgery (SSRF) attacks.

    The URL is checked in two stages; the first failure is returned:

    1. ``_validate_ssrf_url`` checks the URL's format and scheme and extracts
       its hostname.
    2. ``_check_ip_safety`` resolves that hostname via DNS and rejects it if
       a resolved address falls in a private, loopback, link-local, or
       otherwise reserved range.

    Only a verdict is returned, not the vetted addresses. NOTE(review): an
    HTTP client that resolves the hostname again may receive a different
    answer (DNS rebinding); callers needing a hard guarantee should connect
    to the vetted address.

    Args:
        url: The URL string to validate.
        allowed_schemes: Set of URL schemes considered safe.
            Defaults to ``{"http", "https"}``.

    Returns:
        ``Ok(url)`` when the URL is safe to fetch.
        ``Err(SecurityError)`` when an SSRF risk is detected.

    Raises:
        TypeError: If *url* is not a :class:`str`.

    Example:
        >>> guard_ssrf("https://example.com")
        Ok('https://example.com')
        >>> guard_ssrf("http://169.254.169.254/metadata")
        Err(SecurityError('[ssrf] ...'))

    """
    validation = _validate_ssrf_url(url, allowed_schemes)
    if isinstance(validation, Err):
        return Err(validation.unwrap_err())

    resolution = _check_ip_safety(validation.unwrap())
    if isinstance(resolution, Err):
        return Err(resolution.unwrap_err())

    return Ok(url)