Coverage for src / taipanstack / utils / subprocess.py: 100%
84 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
1"""
2Safe subprocess execution with security guards.
4Provides secure wrappers around subprocess execution with
5command validation, timeout handling, and retry logic.
6"""
8import math
9import os
10import shutil
11import subprocess # nosec B404
12import time
13from collections.abc import Sequence
14from dataclasses import dataclass
15from pathlib import Path
17from taipanstack.security.guards import (
18 SecurityError,
19 guard_command_injection,
20)
@dataclass(frozen=True)
class SafeCommandResult:
    """Immutable record describing one finished (or simulated) command run.

    Attributes:
        command: Argument vector that was executed.
        returncode: Exit status reported for the command.
        stdout: Captured standard output (empty string by default).
        stderr: Captured standard error (empty string by default).
        duration_seconds: Elapsed run time in seconds (0.0 by default).
        success: Read-only flag, True exactly when ``returncode`` is 0.

    """

    command: list[str]
    returncode: int
    stdout: str = ""
    stderr: str = ""
    duration_seconds: float = 0.0

    @property
    def success(self) -> bool:
        """Tell whether the exit status is zero."""
        return self.returncode == 0

    def raise_on_error(self) -> "SafeCommandResult":
        """Turn a failed result into an exception, pass a good one through.

        Returns:
            This same instance when the command succeeded, so the call
            can be chained.

        Raises:
            subprocess.CalledProcessError: If the exit status is non-zero.
                The exception carries the return code, command, stdout
                and stderr stored on this result.

        """
        # Guard clause: the happy path leaves immediately.
        if self.success:
            return self

        raise subprocess.CalledProcessError(
            returncode=self.returncode,
            cmd=self.command,
            output=self.stdout,
            stderr=self.stderr,
        )
# Executable names accepted when the caller supplies no whitelist of its own.
DEFAULT_ALLOWED_COMMANDS: frozenset[str] = frozenset(
    (
        # Python interpreters and packaging front-ends
        "python", "python3", "pip", "pip3", "poetry", "pipx",
        # Version control
        "git",
        # Build tools
        "make",
        # Testing, linting and security scanners
        "pytest", "mypy", "ruff", "bandit", "safety", "semgrep", "pre-commit",
        # Basic system utilities
        "echo", "cat", "ls", "pwd", "mkdir", "rm", "cp", "mv", "touch",
        "chmod", "which",
    )
)
def _validate_and_resolve_command(
    command: Sequence[str],
    allowed_commands: Sequence[str] | None,
) -> list[str]:
    """Check a command against the whitelist and confirm it is installed.

    Args:
        command: Executable name followed by its arguments.
        allowed_commands: Permitted executable names; when None the
            module default ``DEFAULT_ALLOWED_COMMANDS`` is used.

    Returns:
        The command list as returned by ``guard_command_injection``.

    Raises:
        SecurityError: If the command is empty or its executable cannot
            be located with ``shutil.which``. The injection guard is
            also called here and may reject the command itself.

    """
    argv = list(command)
    if not argv:
        raise SecurityError(
            "Empty command is not allowed",
            guard_name="safe_command",
        )

    permitted = list(
        DEFAULT_ALLOWED_COMMANDS if allowed_commands is None else allowed_commands
    )
    validated_cmd = guard_command_injection(argv, allowed_commands=permitted)

    executable = validated_cmd[0]
    if shutil.which(executable):
        return validated_cmd

    # Nothing by that name could be located, so refuse to run it.
    raise SecurityError(
        f"Command not found: {executable}",
        guard_name="safe_command",
        value=executable,
    )
149def _get_allowed_keys(allowed_env_vars: Sequence[str] | None) -> set[str]:
150 """Get the allowed environment variable keys.
152 Args:
153 allowed_env_vars: Whitelist of allowed environment variable names.
155 Returns:
156 Set of allowed keys in uppercase.
158 """
159 if allowed_env_vars is None:
160 return {"PATH"}
161 return {k.upper() for k in allowed_env_vars}
def _filter_environment(
    env: dict[str, str] | None,
    allowed_env_vars: Sequence[str] | None = None,
) -> dict[str, str]:
    """Reduce an environment mapping to its whitelisted variables.

    Args:
        env: Mapping to filter; a copy of ``os.environ`` is used when None.
        allowed_env_vars: Names allowed through (compared case-insensitively
            via upper-casing). None keeps only the minimal default (PATH);
            an empty sequence yields an empty environment.

    Returns:
        New dict holding only the permitted entries, with values coerced
        to ``str``.

    """
    source = dict(os.environ) if env is None else env
    permitted = _get_allowed_keys(allowed_env_vars)

    # An empty whitelist means the caller wants a completely bare environment.
    if not permitted:
        return {}

    filtered: dict[str, str] = {}
    for name, value in source.items():
        if name.upper() in permitted:
            filtered[name] = str(value)
    return filtered
189def _extract_timeout_stdout(e: subprocess.TimeoutExpired) -> str:
190 """Extract stdout from a TimeoutExpired exception safely.
192 Args:
193 e: The TimeoutExpired exception.
195 Returns:
196 The extracted stdout string.
198 """
199 if not hasattr(e, "stdout") or e.stdout is None:
200 return ""
201 if isinstance(e.stdout, str):
202 return e.stdout
203 return e.stdout.decode("utf-8", errors="replace")
def _execute_command(
    validated_cmd: list[str],
    cwd: Path | None,
    timeout: float,
    capture_output: bool,
    safe_env: dict[str, str],
) -> SafeCommandResult:
    """Run the command and convert a timeout into a failed result.

    Args:
        validated_cmd: Validated command list.
        cwd: Resolved working directory, or None.
        timeout: Execution timeout in seconds.
        capture_output: Whether stdout/stderr are captured.
        safe_env: Filtered environment passed to the child process.

    Returns:
        The execution result. When the timeout expires no exception is
        propagated: the result has ``returncode=-1``, any partial stdout
        attached to the exception, and a timeout message in ``stderr``.

    """
    # perf_counter is monotonic, so the measured duration cannot go
    # negative or jump when the system wall clock is adjusted mid-run.
    started = time.perf_counter()

    try:
        result = subprocess.run(  # nosec B603
            validated_cmd,
            cwd=cwd,
            timeout=timeout,
            capture_output=capture_output,
            text=True,
            encoding="utf-8",
            # Tolerate non-UTF-8 output instead of raising
            # UnicodeDecodeError, matching how partial output is decoded
            # on the timeout path.
            errors="replace",
            env=safe_env,
            check=False,
        )
    except subprocess.TimeoutExpired as e:
        duration = time.perf_counter() - started
        return SafeCommandResult(
            command=validated_cmd,
            returncode=-1,
            stdout=_extract_timeout_stdout(e),
            stderr=f"Command timed out after {timeout}s",
            duration_seconds=duration,
        )

    duration = time.perf_counter() - started

    return SafeCommandResult(
        command=validated_cmd,
        returncode=result.returncode,
        # stdout/stderr are None when output is not captured.
        stdout=result.stdout or "",
        stderr=result.stderr or "",
        duration_seconds=duration,
    )
260def _validate_cwd(cwd: Path | str | None) -> Path | None:
261 """Resolve and validate the working directory."""
262 if cwd is None:
263 return None
264 resolved_cwd = Path(cwd).resolve()
265 if not resolved_cwd.exists():
266 raise SecurityError(
267 f"Working directory does not exist: {cwd}",
268 guard_name="safe_command",
269 )
270 return resolved_cwd
def _handle_dry_run(validated_cmd: list[str]) -> SafeCommandResult:
    """Build the placeholder result reported when nothing is executed.

    Args:
        validated_cmd: Command that would have been run.

    Returns:
        A successful result (return code 0, zero duration, empty stderr)
        whose stdout names the space-joined command.

    """
    preview = " ".join(validated_cmd)
    message = f"[DRY-RUN] Would execute: {preview}"
    return SafeCommandResult(
        command=validated_cmd,
        returncode=0,
        stdout=message,
        stderr="",
        duration_seconds=0.0,
    )
284def _validate_timeout(timeout: float) -> None:
285 """Validate that the timeout is a finite, non-negative number.
287 Args:
288 timeout: The timeout value to validate.
290 Raises:
291 ValueError: If timeout is invalid.
293 """
294 if timeout is not None and not (math.isfinite(timeout) and timeout >= 0):
295 raise ValueError("timeout must be a finite non-negative number")
def run_safe_command(
    command: Sequence[str],
    *,
    cwd: Path | str | None = None,
    timeout: float = 300.0,
    capture_output: bool = True,
    check: bool = False,
    allowed_commands: Sequence[str] | None = None,
    env: dict[str, str] | None = None,
    allowed_env_vars: Sequence[str] | None = None,
    dry_run: bool = False,
) -> SafeCommandResult:
    """Execute a command safely with security guards.

    This function provides a secure wrapper around subprocess.run
    with command injection protection, timeout handling, command
    whitelisting and environment filtering.

    Args:
        command: Command and arguments as a sequence.
        cwd: Working directory for the command.
        timeout: Maximum execution time in seconds; must be finite and
            non-negative.
        capture_output: Whether to capture stdout/stderr.
        check: Whether to raise on non-zero exit.
        allowed_commands: Whitelist of allowed commands. If None, the
            module default whitelist is used.
        env: Source environment to filter. If None, the current process
            environment is used. Only variables named in
            ``allowed_env_vars`` reach the child process.
        allowed_env_vars: Whitelist of allowed environment variables.
            If None, only PATH is passed through.
        dry_run: If True, validate the command but don't execute it.
            The result is returned before the working directory is
            validated.

    Returns:
        SafeCommandResult with execution details. A timeout does not
        raise: it is reported as a result with ``returncode=-1`` and a
        timeout message in ``stderr``.

    Raises:
        ValueError: If timeout is not a finite, non-negative number.
        SecurityError: If command validation fails or the working
            directory does not exist.
        subprocess.CalledProcessError: If check=True and the result has
            a non-zero return code (this includes the -1 used for
            timeouts).

    Example:
        >>> result = run_safe_command(["poetry", "install"])
        >>> if result.success:
        ...     print("Installation complete!")

    """
    # Security Enhancement: Prevent DoS via infinite blocking behaviors or
    # unhandled exceptions from threading/asyncio primitives by explicitly
    # validating that timeout is a finite, non-negative number.
    _validate_timeout(timeout)

    validated_cmd = _validate_and_resolve_command(command, allowed_commands)
    safe_env = _filter_environment(env, allowed_env_vars)

    if dry_run:
        return _handle_dry_run(validated_cmd)

    resolved_cwd = _validate_cwd(cwd)

    safe_result = _execute_command(
        validated_cmd,
        resolved_cwd,
        timeout,
        capture_output,
        safe_env,
    )

    if check:
        safe_result.raise_on_error()

    return safe_result