Coverage for src / taipanstack / bridges / web_bridge.py: 100%
74 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
"""
Web Bridge — ASGI middleware for rate limiting and security headers.

Provides a framework-agnostic ASGI middleware that integrates
TaipanStack's rate limiter and security headers into any ASGI
application (FastAPI, Litestar, Starlette, etc.).
"""
9from __future__ import annotations
11import json
12import logging
13from collections.abc import Awaitable, Callable, MutableMapping
14from dataclasses import dataclass
15from typing import TypeAlias, TypeVar
17from taipanstack.core.result import Err, Ok, Result
18from taipanstack.utils.rate_limit import RateLimiter
20logger = logging.getLogger("taipanstack.bridges.web")
22T = TypeVar("T")
24# ASGI type aliases
25Scope: TypeAlias = MutableMapping[str, object]
26Receive: TypeAlias = Callable[[], Awaitable[MutableMapping[str, object]]]
27Send: TypeAlias = Callable[[MutableMapping[str, object]], Awaitable[None]]
28ASGIApp: TypeAlias = Callable[[Scope, Receive, Send], Awaitable[None]]
@dataclass(frozen=True)
class SecurityHeadersConfig:
    """Immutable bundle of security header values for HTTP responses.

    Each field stores the literal value of one header; the header *names*
    are fixed and supplied by :meth:`to_headers`.

    Attributes:
        x_content_type_options: Value emitted as X-Content-Type-Options.
        x_frame_options: Value emitted as X-Frame-Options.
        x_xss_protection: Value emitted as X-XSS-Protection.
        strict_transport_security: Value emitted as Strict-Transport-Security.
        referrer_policy: Value emitted as Referrer-Policy.
        content_security_policy: Value emitted as Content-Security-Policy.

    """

    x_content_type_options: str = "nosniff"
    x_frame_options: str = "DENY"
    x_xss_protection: str = "1; mode=block"
    strict_transport_security: str = "max-age=31536000; includeSubDomains"
    referrer_policy: str = "strict-origin-when-cross-origin"
    content_security_policy: str = "default-src 'self'"

    def to_headers(self) -> list[tuple[bytes, bytes]]:
        """Render the configured values as ASGI header pairs.

        Returns:
            Six ``(name, value)`` byte pairs in a fixed order, with lowercase
            header names and values encoded with ``str.encode()`` (UTF-8).

        """
        # Ordered table of (header name, configured value).
        table = (
            (b"x-content-type-options", self.x_content_type_options),
            (b"x-frame-options", self.x_frame_options),
            (b"x-xss-protection", self.x_xss_protection),
            (b"strict-transport-security", self.strict_transport_security),
            (b"referrer-policy", self.referrer_policy),
            (b"content-security-policy", self.content_security_policy),
        )
        return [(header_name, raw.encode()) for header_name, raw in table]
72def result_to_response(
73 result: Result[T, Exception],
74 *,
75 status_ok: int = 200,
76 status_err: int = 500,
77) -> dict[str, object]:
78 """Convert a ``Result`` to a JSON-friendly response dict.
80 Args:
81 result: The Result to convert.
82 status_ok: HTTP status for ``Ok`` values.
83 status_err: HTTP status for ``Err`` values.
85 Returns:
86 Dict with ``status``, ``data``/``error`` keys.
88 Example:
89 >>> result_to_response(Ok({"id": 1}))
90 {"status": 200, "data": {"id": 1}}
92 """
93 match result:
94 case Ok(value):
95 return {"status": status_ok, "data": value}
96 case Err(error):
97 return {"status": status_err, "error": str(error)}
100async def _send_json_response(
101 send: Send,
102 *,
103 status: int,
104 body: dict[str, object],
105 extra_headers: list[tuple[bytes, bytes]] | None = None,
106) -> None:
107 """Send a JSON response via ASGI send.
109 Args:
110 send: ASGI send callable.
111 status: HTTP status code.
112 body: JSON-serializable body.
113 extra_headers: Additional headers to include.
115 """
116 payload = json.dumps(body).encode("utf-8")
117 headers: list[tuple[bytes, bytes]] = [
118 (b"content-type", b"application/json"),
119 (b"content-length", str(len(payload)).encode()),
120 ]
121 if extra_headers:
122 headers.extend(extra_headers)
124 await send(
125 {
126 "type": "http.response.start",
127 "status": status,
128 "headers": headers,
129 }
130 )
131 await send(
132 {
133 "type": "http.response.body",
134 "body": payload,
135 }
136 )
class TaipanMiddleware:
    """ASGI middleware providing rate limiting and security headers.

    Only ``http`` scopes are rate limited and decorated. Any other scope
    type is forwarded to the wrapped application untouched. Security
    headers are added to ``http.response.start`` messages unless the
    application already set a header with the same (case-insensitive)
    name, so application-supplied values take precedence over the defaults.

    Args:
        app: The wrapped ASGI application.
        rate_limiter: Optional rate limiter instance.
        security_headers: Whether to inject security headers.
        headers_config: Custom security headers configuration.

    Example:
        >>> from taipanstack.utils.rate_limit import RateLimiter
        >>> app = TaipanMiddleware(
        ...     my_asgi_app,
        ...     rate_limiter=RateLimiter(max_calls=100, time_window=60),
        ...     security_headers=True,
        ... )

    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        rate_limiter: RateLimiter | None = None,
        security_headers: bool = True,
        headers_config: SecurityHeadersConfig | None = None,
    ) -> None:
        """Initialize the middleware.

        Args:
            app: ASGI application to wrap.
            rate_limiter: Optional rate limiter.
            security_headers: Inject security headers.
            headers_config: Custom headers config; a default
                ``SecurityHeadersConfig()`` is used when omitted.

        """
        self._app = app
        self._rate_limiter = rate_limiter
        self._security_headers = security_headers
        self._headers_config = headers_config or SecurityHeadersConfig()

    @staticmethod
    def _merge_security_headers(
        headers: object,
        security_headers: list[tuple[bytes, bytes]],
    ) -> list[tuple[bytes, bytes]]:
        """Return the app's headers followed by any security headers it lacks.

        Args:
            headers: The ``headers`` value of an ``http.response.start``
                message. Lists and tuples of pairs are kept (ASGI allows any
                iterable, so tuples must not be dropped). Any other value is
                treated as "no headers".
            security_headers: Default security header pairs to add.

        Returns:
            A new list containing the application's headers in their original
            order, then every security header whose name the application
            did not already set.

        """
        merged: list[tuple[bytes, bytes]] = (
            list(headers) if isinstance(headers, (list, tuple)) else []
        )
        # HTTP header names are case-insensitive, so compare lowercased names
        # to avoid emitting a second, possibly conflicting, copy.
        already_set = {name.lower() for name, _value in merged}
        merged.extend(
            pair for pair in security_headers if pair[0].lower() not in already_set
        )
        return merged

    def _wrap_send_with_security_headers(self, send: Send) -> Send:
        """Wrap the send callable to inject security headers if enabled.

        Args:
            send: The original ASGI send callable.

        Returns:
            ``send`` itself when security headers are disabled. Otherwise, a
            wrapper that merges the configured headers into
            ``http.response.start`` messages (in place) before forwarding.

        """
        if not self._security_headers:
            return send

        # Computed once per request rather than once per message.
        extra_headers = self._headers_config.to_headers()

        async def send_with_headers(message: MutableMapping[str, object]) -> None:
            if message.get("type") == "http.response.start":
                message["headers"] = self._merge_security_headers(
                    message.get("headers"), extra_headers
                )
            await send(message)

        return send_with_headers

    async def _handle_rate_limit(self, send: Send) -> bool:
        """Apply rate limiting and send a 429 response if exceeded.

        ``consume()`` is called without any client identifier.
        NOTE(review): presumably all requests share one budget; confirm in
        ``RateLimiter``.

        Args:
            send: ASGI send callable.

        Returns:
            True if rate limit was exceeded (a 429 JSON response has been
            sent), False otherwise.

        """
        if self._rate_limiter is None or self._rate_limiter.consume():
            return False

        logger.warning("Rate limit exceeded for request")
        security_hdrs = (
            self._headers_config.to_headers() if self._security_headers else None
        )
        await _send_json_response(
            send,
            status=429,
            body={"error": "Rate limit exceeded", "retry_after": 1},
            extra_headers=security_hdrs,
        )
        return True

    async def __call__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """Process an ASGI request.

        If the application raises before starting its response, a generic
        500 JSON response is sent instead. If it raises after
        ``http.response.start`` was already sent, the exception is logged and
        re-raised. ASGI forbids a second start message, so the server must
        abort the connection.

        Args:
            scope: ASGI scope dict.
            receive: ASGI receive callable.
            send: ASGI send callable.

        """
        # Only handle HTTP requests
        if scope.get("type") != "http":
            await self._app(scope, receive, send)
            return

        # Rate limiting
        if await self._handle_rate_limit(send):
            return

        # Wrap send to inject security headers
        secured_send = self._wrap_send_with_security_headers(send)
        response_started = False

        async def tracking_send(message: MutableMapping[str, object]) -> None:
            nonlocal response_started
            if message.get("type") == "http.response.start":
                # Set before forwarding. If this send itself fails, we still
                # must not attempt another start message afterwards.
                response_started = True
            await secured_send(message)

        # Call the actual application
        try:
            await self._app(scope, receive, tracking_send)
        except Exception:
            if response_started:
                logger.exception(
                    "Unhandled exception in ASGI app after response started"
                )
                raise
            logger.exception("Unhandled exception in ASGI app")
            await _send_json_response(
                tracking_send,
                status=500,
                body={"error": "Internal server error"},
            )