Coverage for src / taipanstack / security / sanitizers.py: 100%

201 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

1""" 

2Input sanitizers for cleaning untrusted data. 

3 

4Provides functions to sanitize strings, filenames, and paths 

5to remove potentially dangerous characters. 

6""" 

7 

8import re 

9from pathlib import Path 

10 

# Named limits so that checks never rely on magic numbers (PLR2004)
MAX_SQL_IDENTIFIER_LENGTH = 128  # pragma: no mutate
MAX_PATH_LENGTH = 4096  # pragma: no mutate

# Patterns and lookup sets are built once at import time and reused by every call.
# Characters rejected in filenames: reserved punctuation plus the C0 controls.
_INVALID_FILENAME_CHARS_RE = re.compile(r'[<>:"/\\|?*\x00-\x1f]')  # pragma: no mutate
# Anything that is not an ASCII letter, an ASCII digit or an underscore.
_SQL_IDENTIFIER_DENY_RE = re.compile(r"[^a-zA-Z0-9_]")  # pragma: no mutate
# A "<...>" run with at least one character between the angle brackets.
_HTML_TAGS_RE = re.compile(r"<[^>]+>")  # pragma: no mutate
# C0 and C1 control characters, except tab (\x09), line feed (\x0a) and
# carriage return (\x0d), which are deliberately left alone.
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]")  # pragma: no mutate
_VALID_SQL_PREFIX = frozenset(  # pragma: no mutate
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_"
)
# Device names that Windows reserves; compared against the upper-cased stem.
_WINDOWS_RESERVED_NAMES = frozenset(  # pragma: no mutate
    ("CON", "PRN", "AUX", "NUL")
    + tuple(f"COM{port}" for port in range(1, 10))
    + tuple(f"LPT{port}" for port in range(1, 10))
)

52 

53 

def _handle_html(result: str, allow_html: bool) -> str:
    """Strip tags and entity-encode ``&``, ``<`` and ``>`` unless HTML is allowed.

    Args:
        result: Text to process.
        allow_html: When true the text is returned untouched.

    Returns:
        The text with ``<...>`` runs removed and the three remaining special
        characters replaced by their entities, or the input unchanged when
        ``allow_html`` is true.

    """
    if allow_html:
        return result

    escaped = _HTML_TAGS_RE.sub("", result)
    # "&" has to be encoded first; otherwise the ampersands introduced by the
    # "&lt;" / "&gt;" entities would themselves be encoded again.
    for raw, entity in (("&", "&amp;"), ("<", "&lt;"), (">", "&gt;")):
        escaped = escaped.replace(raw, entity)
    return escaped

62 

63 

def _handle_unicode(result: str, allow_unicode: bool) -> str:
    """Drop every non-ASCII character unless unicode is allowed.

    Args:
        result: Text to process.
        allow_unicode: When true the text is returned untouched.

    Returns:
        The input unchanged, or the input with all non-ASCII characters
        silently removed.

    """
    if not allow_unicode:
        # errors="ignore" discards characters that cannot be encoded as ASCII.
        ascii_only = result.encode("ascii", errors="ignore")
        return ascii_only.decode("ascii")
    return result

69 

70 

def sanitize_string(
    value: str,
    *,
    max_length: int | None = None,
    allow_html: bool = False,
    allow_unicode: bool = True,
    strip_whitespace: bool = True,
) -> str:
    """Sanitize a string by removing dangerous characters.

    Processing order: control characters are removed, surrounding whitespace
    is trimmed, HTML is stripped/escaped, non-ASCII characters are dropped,
    and finally the result is truncated.

    Args:
        value: The string to sanitize.
        max_length: Maximum length to truncate to. Truncation is applied to
            the already escaped text, so it may cut through an entity such
            as ``&amp;``.
        allow_html: Whether to keep HTML tags (default: False).
        allow_unicode: Whether to keep non-ASCII characters.
        strip_whitespace: Whether to strip leading/trailing whitespace.

    Returns:
        The sanitized string.

    Raises:
        TypeError: If value is not a string.

    Example:
        ```python
        sanitize_string("<script>alert('xss')</script>Hello")
        # Returns: "alert('xss')Hello"
        ```

    """
    if not isinstance(value, str):
        raise TypeError(f"value must be str, got {type(value).__name__}")

    if not value:
        return ""

    # Remove control characters *before* trimming: a control character sitting
    # outside the whitespace (e.g. "\x00 text") would otherwise shield that
    # whitespace from strip() and leave it in the result.
    result = _CONTROL_CHARS_RE.sub("", value)
    if strip_whitespace:
        result = result.strip()
    result = _handle_html(result, allow_html)
    result = _handle_unicode(result, allow_unicode)

    if max_length is not None and len(result) > max_length:
        return result[:max_length]
    return result

112 

113 

def _get_filename_from_path(filename: str) -> str:
    """Return whatever follows the last ``/`` or ``\\`` in ``filename``.

    Args:
        filename: A bare filename or a path using either separator.

    Returns:
        The final component; the whole input when it has no separator, and
        an empty string when the input ends with a separator.

    """
    last_separator = max(filename.rfind("/"), filename.rfind("\\"))
    # rfind() yields -1 when nothing is found, so the slice then starts at 0
    # and returns the input unchanged.
    return filename[last_separator + 1 :]

120 

121 

def _has_valid_extension(name: str, idx: int) -> bool:
    """Tell whether the dot at ``idx`` separates a stem from an extension.

    Args:
        name: The filename being inspected.
        idx: Position of the last dot in ``name`` (``-1`` when absent).

    Returns:
        True when the dot is not the first character and ``name`` is not
        made up solely of dots.

    """
    if idx <= 0:
        # No dot at all, or a leading dot.
        return False
    # Names consisting only of dots (".", "..", "...") carry no extension.
    return any(char != "." for char in name)

125 

126 

def _extract_stem_and_suffix(
    filename: str, preserve_extension: bool
) -> tuple[str, str]:
    """Split the final path component into ``(stem, suffix)``.

    Args:
        filename: A filename, possibly prefixed by directories.
        preserve_extension: When false the suffix is reported as ``""``
            (the stem is still cut at the last dot).

    Returns:
        The stem and the suffix (including its leading dot). When the name
        has no usable extension the whole name is the stem.

    """
    base_name = _get_filename_from_path(filename)
    dot_at = base_name.rfind(".")

    if not _has_valid_extension(base_name, dot_at):
        return base_name, ""

    extension = base_name[dot_at:] if preserve_extension else ""
    return base_name[:dot_at], extension

140 

141 

def _remove_invalid_chars(stem: str, replacement: str) -> str:
    """Remove or replace invalid characters in a filename stem.

    Args:
        stem: The filename stem to clean.
        replacement: Text substituted for every invalid character.

    Returns:
        The stem with invalid characters replaced and with leading/trailing
        dots and spaces removed.

    """
    # A callable replacement is always inserted literally, so backslashes in
    # ``replacement`` are never parsed as regex template escapes. This makes
    # a separate backslash branch and an ``re.error`` fallback unnecessary:
    # template parsing can only fail on backslash sequences.
    safe_stem = _INVALID_FILENAME_CHARS_RE.sub(lambda _match: replacement, stem)

    # Remove leading/trailing dots and spaces (Windows issues)
    safe_stem = safe_stem.strip(". ")

    # Remove path separators that might have snuck through
    safe_stem = safe_stem.replace("/", replacement)
    safe_stem = safe_stem.replace("\\", replacement)

    return safe_stem

161 

162 

def _collapse_replacements(safe_stem: str, replacement: str) -> str:
    """Squash runs of the replacement text and trim it from both ends.

    Args:
        safe_stem: The stem to tidy up.
        replacement: The replacement text; an empty string disables the
            whole step.

    Returns:
        The stem with no two consecutive replacements and without leading
        or trailing replacement characters.

    """
    if not replacement:
        return safe_stem

    doubled = replacement * 2
    collapsed = safe_stem
    # A single replace() pass can leave new pairs behind, hence the loop.
    while doubled in collapsed:
        collapsed = collapsed.replace(doubled, replacement)
    return collapsed.strip(replacement)

171 

172 

def _truncate_filename(safe_stem: str, suffix: str, max_length: int) -> str:
    """Fit ``safe_stem + suffix`` into ``max_length`` characters.

    Args:
        safe_stem: The sanitized stem.
        suffix: The extension including its leading dot, or ``""``.
        max_length: Maximum length of the returned name.

    Returns:
        The joined name when it already fits. Otherwise the stem is
        shortened so the suffix survives; if the suffix alone fills or
        exceeds the limit, the joined name is simply cut at ``max_length``.

    """
    full_name = f"{safe_stem}{suffix}"
    if len(full_name) <= max_length:
        return full_name

    room_for_stem = max_length - len(suffix)
    if room_for_stem <= 0:
        # No space left for even one stem character: hard cut.
        return full_name[:max_length]
    return f"{safe_stem[:room_for_stem]}{suffix}"

184 

185 

def _is_filename_safe(filename: str, max_length: int, stem: str) -> bool:
    """Decide whether ``filename`` can be used without any modification.

    Args:
        filename: The filename exactly as supplied by the caller.
        max_length: Maximum permitted length.
        stem: The stem previously extracted from ``filename``.

    Returns:
        True only when the name fits the limit, is neither ``.`` nor ``..``,
        its stem is not a reserved Windows device name, and it consists
        purely of ASCII letters, digits, dots, hyphens and underscores.

    """
    if len(filename) > max_length or filename in {"..", "."}:
        return False
    if stem.upper() in _WINDOWS_RESERVED_NAMES or not filename.isascii():
        return False
    # Drop the three permitted punctuation characters; what is left must be
    # non-empty and alphanumeric.
    without_punctuation = filename.translate(str.maketrans("", "", ".-_"))
    return without_punctuation.isalnum()

195 

196 

def _finalize_filename(
    safe_stem: str, replacement: str, suffix: str, max_length: int
) -> str:
    """Finalize the sanitized filename by handling reserved names and empty results.

    Args:
        safe_stem: The sanitized stem.
        replacement: Text placed in front of a reserved name; ``"_"`` is
            used instead when it is empty.
        suffix: The extension including its leading dot, or ``""``.
        max_length: Maximum length of the returned name.

    Returns:
        The final filename, truncated to ``max_length``.

    """
    # Handle reserved names (Windows). An empty replacement would leave the
    # reserved name untouched, so fall back to an underscore in that case.
    if safe_stem.upper() in _WINDOWS_RESERVED_NAMES:
        prefix = replacement or "_"
        safe_stem = f"{prefix}{safe_stem}"

    # Handle empty result
    if not safe_stem:
        safe_stem = "unnamed"

    return _truncate_filename(safe_stem, suffix, max_length)

210 

211 

def sanitize_filename(
    filename: str,
    *,
    max_length: int = 255,
    replacement: str = "_",
    preserve_extension: bool = True,
) -> str:
    """Sanitize a filename to be safe for filesystem use.

    Removes or replaces characters that are:
    - Not allowed in filenames on various OSes
    - Potentially dangerous (path separators, etc.)

    Any directory portion of the input is discarded; only the final
    component is kept.

    Args:
        filename: The filename to sanitize.
        max_length: Maximum length for the filename.
        replacement: Character to replace invalid chars with. If it contains
            a character that is itself invalid in filenames, ``"_"`` is used
            instead.
        preserve_extension: Keep original extension.

    Returns:
        The sanitized filename.

    Raises:
        TypeError: If filename is not a string.

    Example:
        ```python
        sanitize_filename("my/../file<>:name.txt")
        # Returns: 'file_name.txt'
        ```

    """
    if not isinstance(filename, str):
        raise TypeError(f"filename must be str, got {type(filename).__name__}")

    if not filename:
        filename = "unnamed"

    stem, suffix = _extract_stem_and_suffix(filename, preserve_extension)

    # Fast path: nothing to change.
    if _is_filename_safe(filename, max_length, stem):
        return f"{stem}{suffix}"

    # A replacement containing a forbidden character (e.g. "/" or "\\") would
    # put that character straight back into the result.
    if _INVALID_FILENAME_CHARS_RE.search(replacement):
        replacement = "_"

    safe_stem = _remove_invalid_chars(stem, replacement)
    safe_stem = _collapse_replacements(safe_stem, replacement)

    # The extension comes from the same untrusted input as the stem and must
    # be cleaned too; a callable keeps the replacement literal.
    safe_suffix = _INVALID_FILENAME_CHARS_RE.sub(lambda _match: replacement, suffix)

    return _finalize_filename(safe_stem, replacement, safe_suffix, max_length)

256 

257 

def _get_stem(part: str) -> str:
    """Return ``part`` without its last extension.

    Args:
        part: A single path component.

    Returns:
        The text before the last dot. The component is returned unchanged
        when it has no dot, starts with its only dot, or consists solely of
        dots.

    """
    dot_at = part.rfind(".")
    if dot_at <= 0 or all(char == "." for char in part):
        return part
    return part[:dot_at]

262 

263 

def _is_safe_path_part(part: str, stem: str) -> bool:
    """Decide whether a path component can be kept exactly as it is.

    Args:
        part: The path component.
        stem: The stem previously extracted from ``part``.

    Returns:
        True when the component is at most 255 characters long, pure ASCII,
        made only of letters, digits, dots, hyphens and underscores, and its
        stem is not a reserved Windows device name.

    """
    if len(part) > 255 or not part.isascii():  # noqa: PLR2004
        return False
    # Drop the three permitted punctuation characters; what is left must be
    # non-empty and alphanumeric.
    without_punctuation = part.translate(str.maketrans("", "", ".-_"))
    if not without_punctuation.isalnum():
        return False
    return stem.upper() not in _WINDOWS_RESERVED_NAMES

272 

273 

def _handle_dot_dot(parts: list[str], anchor: str) -> None:
    """Apply a ``..`` component by discarding the most recent kept part.

    Nothing happens when ``parts`` is empty or when its last entry is
    ``..`` or equals ``anchor``.

    Args:
        parts: Components collected so far; modified in place.
        anchor: The anchor (root/drive) of the path being processed.

    """
    if not parts:
        return
    if parts[-1] in ("..", anchor):
        return
    parts.pop()

278 

279 

def _handle_normal_part(part: str, parts: list[str]) -> None:
    """Append ``part`` to ``parts``, sanitizing it first when it is not safe.

    Args:
        part: A path component other than ``.`` or ``..``.
        parts: Components collected so far; modified in place.

    """
    if _is_safe_path_part(part, _get_stem(part)):
        parts.append(part)
        return

    cleaned = sanitize_filename(part, preserve_extension=True)
    # Never let sanitizing produce an empty component or a parent reference.
    if cleaned and cleaned != "..":  # pragma: no branch
        parts.append(cleaned)

289 

290 

def _process_path_part(part: str, parts: list[str], anchor: str) -> None:
    """Fold one path component into ``parts``.

    ``.`` is ignored, ``..`` is delegated to ``_handle_dot_dot`` and every
    other component to ``_handle_normal_part``.

    Args:
        part: The component to process.
        parts: Components collected so far; modified in place.
        anchor: The anchor (root/drive) of the path being processed.

    """
    if part == ".":
        return
    if part == "..":
        _handle_dot_dot(parts, anchor)
        return
    _handle_normal_part(part, parts)

297 

298 

def _clean_path_parts(path: Path) -> list[str]:
    """Clean and sanitize individual path components.

    Args:
        path: The path whose components are to be cleaned.

    Returns:
        The sanitized components in order, without the anchor (root/drive).

    """
    parts: list[str] = []
    anchor = path.anchor
    for part in path.parts:
        if anchor and part == anchor:
            # ``Path.parts`` lists the anchor ("/" or "C:\\") as a component.
            # It is not a filename: sanitizing it would turn it into
            # "unnamed", and ``_reconstruct_path`` re-attaches the anchor
            # from ``path.anchor`` anyway.
            continue
        _process_path_part(part, parts, anchor)
    return parts

306 

307 

308def _apply_base_dir_constraint( 

309 sanitized: Path, 

310 base_dir: Path | str | None, 

311 resolve: bool, 

312) -> Path: 

313 """Apply base directory constraints to a sanitized path.""" 

314 if base_dir is None: 

315 return sanitized 

316 

317 base = Path(base_dir).resolve() 

318 if resolve: 

319 try: 

320 return sanitized.resolve() 

321 except (OSError, RuntimeError) as e: 

322 msg = f"Cannot resolve path: {e}" 

323 raise ValueError(msg) from e 

324 

325 # Make absolute relative to base 

326 if not sanitized.is_absolute(): # pragma: no branch 

327 return base / sanitized 

328 

329 return sanitized 

330 

331 

def _normalize_path_input(path: str | Path) -> Path:
    """Turn the caller's input into a ``Path`` after a length check.

    Args:
        path: The raw path, as a string or a ``Path``.

    Returns:
        A ``Path`` built from the input. NUL characters are removed from
        string inputs; ``Path`` inputs are passed through as they are.

    Raises:
        ValueError: If the textual form of the input is longer than
            ``MAX_PATH_LENGTH``.

    """
    is_text = isinstance(path, str)
    as_text = path if is_text else str(path)

    # The limit is checked on the raw text, before any NUL removal.
    if len(as_text) > MAX_PATH_LENGTH:
        msg = "Path length exceeds maximum allowed"
        raise ValueError(msg)

    if is_text:
        return Path(as_text.replace("\x00", ""))
    return Path(path)

346 

347 

def _reconstruct_path(original_path: Path, parts: list[str]) -> Path:
    """Build a path from sanitized components.

    Args:
        original_path: The path the components came from; only its
            absoluteness and anchor are consulted.
        parts: The sanitized components.

    Returns:
        The components joined onto the original anchor when the original
        path was absolute, otherwise joined onto an empty relative path.

    """
    if original_path.is_absolute():  # pragma: no branch
        # Use the anchor so absolute roots such as "C:\\" survive on Windows.
        root = Path(original_path.anchor)
    else:
        root = Path()
    # joinpath() without arguments leaves the root as it is.
    return root.joinpath(*parts)

359 

360 

361def _validate_path_depth(path: Path, max_depth: int | None) -> None: 

362 """Validate that the path depth does not exceed max_depth.""" 

363 depth = len(path.parts) 

364 if max_depth is not None and depth > max_depth: 

365 msg = f"Path depth {depth} exceeds maximum of {max_depth}" 

366 raise ValueError(msg) 

367 

368 

def sanitize_path(
    path: str | Path,
    *,
    base_dir: Path | None = None,
    max_depth: int | None = 10,
    resolve: bool = False,
) -> Path:
    """Sanitize a path to prevent traversal and normalize it.

    The input is length-checked and converted to a ``Path``, each component
    is cleaned, the path is rebuilt, its depth is validated and finally the
    base-directory handling is applied.

    Args:
        path: The path to sanitize.
        base_dir: Optional base directory, handed to
            ``_apply_base_dir_constraint``.
        max_depth: Maximum number of path components allowed; ``None``
            disables the check.
        resolve: Whether to resolve the path. Only takes effect when
            ``base_dir`` is given.

    Returns:
        The sanitized Path object.

    Raises:
        ValueError: If the path is too long, too deep, or cannot be
            resolved.

    """
    source = _normalize_path_input(path)
    rebuilt = _reconstruct_path(source, _clean_path_parts(source))

    _validate_path_depth(rebuilt, max_depth)
    return _apply_base_dir_constraint(rebuilt, base_dir, resolve)

397 

398 

def _sanitize_env_multiline(value: str, max_length: int) -> str:
    """Remove NUL characters from a value while keeping its line breaks.

    Args:
        value: The raw environment value.
        max_length: Not used for truncation here; ``sanitize_env_value``
            applies the length limit to the returned text.

    Returns:
        ``value`` without any NUL characters.

    """
    if "\x00" not in value:
        # Nothing to remove.
        return value
    return value.replace("\x00", "")

404 

405 

def _sanitize_env_singleline(value: str, max_length: int) -> str:
    """Remove NUL characters and flatten line breaks into spaces.

    Args:
        value: The raw environment value.
        max_length: Not used for truncation here; ``sanitize_env_value``
            applies the length limit to the returned text.

    Returns:
        ``value`` with NUL characters removed and every line feed or
        carriage return replaced by one space.

    """
    # One pass over the string: drop NUL, map "\n" and "\r" to a space each.
    table = str.maketrans({"\x00": None, "\n": " ", "\r": " "})
    return value.translate(table)

416 

417 

def sanitize_env_value(
    value: str,
    *,
    max_length: int = 4096,
    allow_multiline: bool = False,
) -> str:
    """Sanitize a value for use as an environment variable.

    NUL characters are always removed. Unless ``allow_multiline`` is set,
    line feeds and carriage returns are replaced by spaces. The result is
    then cut to ``max_length`` characters.

    Args:
        value: The value to sanitize.
        max_length: Maximum length allowed.
        allow_multiline: Whether to allow newlines.

    Returns:
        The sanitized value.

    Raises:
        TypeError: If value is not a string.

    """
    if not isinstance(value, str):
        raise TypeError(f"value must be str, got {type(value).__name__}")

    if not value:
        return ""

    cleaner = _sanitize_env_multiline if allow_multiline else _sanitize_env_singleline
    # Slicing is a no-op for text that already fits the limit.
    return cleaner(value, max_length)[:max_length]

452 

453 

def _sanitize_sql_identifier_slow_path(identifier: str) -> str:
    """Rewrite an identifier that failed the fast-path check.

    Every character other than an ASCII letter, digit or underscore is
    removed; an underscore is prepended when the remainder does not start
    with a letter or underscore; the result is cut to
    ``MAX_SQL_IDENTIFIER_LENGTH`` characters.

    Args:
        identifier: The raw identifier.

    Returns:
        The cleaned identifier.

    Raises:
        ValueError: If nothing is left after removing invalid characters.

    """
    cleaned = _SQL_IDENTIFIER_DENY_RE.sub("", identifier)

    if not cleaned:
        msg = "SQL identifier contains no valid characters"
        raise ValueError(msg)

    # Must start with letter or underscore
    first = cleaned[0]
    if not (first.isalpha() or first == "_"):
        cleaned = f"_{cleaned}"

    # Most DBs limit identifiers to 128 chars; slicing is a no-op below that.
    return cleaned[:MAX_SQL_IDENTIFIER_LENGTH]

471 

472 

def sanitize_sql_identifier(identifier: str) -> str:
    """Sanitize a SQL identifier (table/column name).

    Note: This is NOT for SQL values - use parameterized queries for those!

    Args:
        identifier: The identifier to sanitize.

    Returns:
        The identifier unchanged when it is already a short ASCII
        identifier; otherwise a cleaned version with invalid characters
        removed, a leading underscore added if needed, and truncated to
        ``MAX_SQL_IDENTIFIER_LENGTH`` characters.

    Raises:
        TypeError: If identifier is not a string.
        ValueError: If identifier is empty or contains no valid characters.

    """
    # isinstance (rather than an exact type check) matches the other
    # sanitizers in this module and accepts str subclasses.
    if not isinstance(identifier, str):
        raise TypeError(f"identifier must be str, got {type(identifier).__name__}")

    # Fast path: already a valid identifier, nothing to rewrite.
    if (
        len(identifier) <= MAX_SQL_IDENTIFIER_LENGTH
        and identifier.isascii()
        and identifier.isidentifier()
    ):
        return identifier

    if not identifier:
        msg = "SQL identifier cannot be empty"
        raise ValueError(msg)

    return _sanitize_sql_identifier_slow_path(identifier)