Coverage for src / taipanstack / bridges / web_bridge.py: 100%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

1""" 

2Web Bridge — ASGI middleware for rate limiting and security headers. 

3 

4Provides a framework-agnostic ASGI middleware that integrates 

5TaipanStack's rate limiter and security headers into any ASGI 

6application (FastAPI, Litestar, Starlette, etc.). 

7""" 

8 

9from __future__ import annotations 

10 

11import json 

12import logging 

13from collections.abc import Awaitable, Callable, MutableMapping 

14from dataclasses import dataclass 

15from typing import TypeAlias, TypeVar 

16 

17from taipanstack.core.result import Err, Ok, Result 

18from taipanstack.utils.rate_limit import RateLimiter 

19 

20logger = logging.getLogger("taipanstack.bridges.web") 

21 

22T = TypeVar("T") 

23 

24# ASGI type aliases 

25Scope: TypeAlias = MutableMapping[str, object] 

26Receive: TypeAlias = Callable[[], Awaitable[MutableMapping[str, object]]] 

27Send: TypeAlias = Callable[[MutableMapping[str, object]], Awaitable[None]] 

28ASGIApp: TypeAlias = Callable[[Scope, Receive, Send], Awaitable[None]] 

29 

30 

@dataclass(frozen=True)
class SecurityHeadersConfig:
    """Immutable set of security response-header values.

    Each field holds the value of one HTTP response header. Instances are
    frozen, so a configuration cannot change after construction.

    Attributes:
        x_content_type_options: Value for X-Content-Type-Options.
        x_frame_options: Value for X-Frame-Options.
        x_xss_protection: Value for X-XSS-Protection.
        strict_transport_security: Value for Strict-Transport-Security.
        referrer_policy: Value for Referrer-Policy.
        content_security_policy: Value for Content-Security-Policy.

    """

    x_content_type_options: str = "nosniff"
    x_frame_options: str = "DENY"
    x_xss_protection: str = "1; mode=block"
    strict_transport_security: str = "max-age=31536000; includeSubDomains"
    referrer_policy: str = "strict-origin-when-cross-origin"
    content_security_policy: str = "default-src 'self'"

    def to_headers(self) -> list[tuple[bytes, bytes]]:
        """Encode the configured values as ASGI header pairs.

        Header names are lowercase byte strings. Values are encoded with
        ``str.encode()`` (UTF-8).

        Returns:
            A new list of ``(name, value)`` byte tuples, one per field, in
            field declaration order.

        """
        # Name/value table in the same order as the fields above.
        name_value_table = (
            (b"x-content-type-options", self.x_content_type_options),
            (b"x-frame-options", self.x_frame_options),
            (b"x-xss-protection", self.x_xss_protection),
            (b"strict-transport-security", self.strict_transport_security),
            (b"referrer-policy", self.referrer_policy),
            (b"content-security-policy", self.content_security_policy),
        )
        return [(header, raw.encode()) for header, raw in name_value_table]

70 

71 

72def result_to_response( 

73 result: Result[T, Exception], 

74 *, 

75 status_ok: int = 200, 

76 status_err: int = 500, 

77) -> dict[str, object]: 

78 """Convert a ``Result`` to a JSON-friendly response dict. 

79 

80 Args: 

81 result: The Result to convert. 

82 status_ok: HTTP status for ``Ok`` values. 

83 status_err: HTTP status for ``Err`` values. 

84 

85 Returns: 

86 Dict with ``status``, ``data``/``error`` keys. 

87 

88 Example: 

89 >>> result_to_response(Ok({"id": 1})) 

90 {"status": 200, "data": {"id": 1}} 

91 

92 """ 

93 match result: 

94 case Ok(value): 

95 return {"status": status_ok, "data": value} 

96 case Err(error): 

97 return {"status": status_err, "error": str(error)} 

98 

99 

100async def _send_json_response( 

101 send: Send, 

102 *, 

103 status: int, 

104 body: dict[str, object], 

105 extra_headers: list[tuple[bytes, bytes]] | None = None, 

106) -> None: 

107 """Send a JSON response via ASGI send. 

108 

109 Args: 

110 send: ASGI send callable. 

111 status: HTTP status code. 

112 body: JSON-serializable body. 

113 extra_headers: Additional headers to include. 

114 

115 """ 

116 payload = json.dumps(body).encode("utf-8") 

117 headers: list[tuple[bytes, bytes]] = [ 

118 (b"content-type", b"application/json"), 

119 (b"content-length", str(len(payload)).encode()), 

120 ] 

121 if extra_headers: 

122 headers.extend(extra_headers) 

123 

124 await send( 

125 { 

126 "type": "http.response.start", 

127 "status": status, 

128 "headers": headers, 

129 } 

130 ) 

131 await send( 

132 { 

133 "type": "http.response.body", 

134 "body": payload, 

135 } 

136 ) 

137 

138 

class TaipanMiddleware:
    """ASGI middleware providing rate limiting and security headers.

    Only ``http`` scopes are processed. Every other scope type is forwarded
    to the wrapped app untouched. For HTTP requests the middleware:

    1. Calls the optional rate limiter's ``consume()``. If it returns a
       falsy value, it answers ``429`` with a JSON body and a
       ``Retry-After`` header.
    2. Appends the configured security headers to the app's
       ``http.response.start`` message, without overriding any header of
       the same name that the app already set.
    3. Handles an unhandled exception from the app:
       - if the response has not started, turns it into a JSON ``500``
         response;
       - if the response has already started, re-raises it, because a
         second response start would be invalid.

    Args:
        app: The wrapped ASGI application.
        rate_limiter: Optional rate limiter instance.
        security_headers: Whether to inject security headers.
        headers_config: Custom security headers configuration.

    Example:
        >>> from taipanstack.utils.rate_limit import RateLimiter
        >>> app = TaipanMiddleware(
        ...     my_asgi_app,
        ...     rate_limiter=RateLimiter(max_calls=100, time_window=60),
        ...     security_headers=True,
        ... )

    """

    # Delay advertised to throttled clients, both in the JSON body and in the
    # ``Retry-After`` header.
    _RETRY_AFTER_SECONDS = 1

    def __init__(
        self,
        app: ASGIApp,
        *,
        rate_limiter: RateLimiter | None = None,
        security_headers: bool = True,
        headers_config: SecurityHeadersConfig | None = None,
    ) -> None:
        """Initialize the middleware.

        Args:
            app: ASGI application to wrap.
            rate_limiter: Optional rate limiter.
            security_headers: Inject security headers.
            headers_config: Custom headers config. Defaults to
                ``SecurityHeadersConfig()``.

        """
        self._app = app
        self._rate_limiter = rate_limiter
        self._security_headers = security_headers
        self._headers_config = headers_config or SecurityHeadersConfig()
        # The config is a frozen dataclass, so the encoded header pairs are
        # computed once here instead of on every request.
        self._security_header_pairs: tuple[tuple[bytes, bytes], ...] = tuple(
            self._headers_config.to_headers()
        )

    @staticmethod
    def _merge_headers(
        headers: object,
        extra: tuple[tuple[bytes, bytes], ...],
    ) -> list[tuple[bytes, bytes]]:
        """Return *headers* plus every pair from *extra* whose name is absent.

        Args:
            headers: The ``headers`` value from an ``http.response.start``
                message. Lists and tuples of pairs are accepted. Anything
                else, including a missing value, is treated as "no headers".
            extra: Header pairs to add when the app has not set them.

        Returns:
            A new list: the app's headers in their original order, followed
            by the non-conflicting pairs from *extra*. Header names are
            compared case-insensitively.

        """
        merged: list[tuple[bytes, bytes]] = (
            list(headers) if isinstance(headers, (list, tuple)) else []
        )
        present = {name.lower() for name, _ in merged}
        merged.extend(pair for pair in extra if pair[0].lower() not in present)
        return merged

    def _wrap_send_with_security_headers(self, send: Send) -> Send:
        """Wrap the send callable to inject security headers if enabled.

        Args:
            send: The original ASGI send callable.

        Returns:
            The wrapped ASGI send callable, or *send* itself when security
            headers are disabled.

        """
        if not self._security_headers:
            return send

        extra_headers = self._security_header_pairs
        merge = self._merge_headers

        async def send_with_headers(message: MutableMapping[str, object]) -> None:
            if message.get("type") == "http.response.start":
                # Work on a copy so a message object the app reuses across
                # requests is never mutated and never accumulates headers.
                message = dict(message)
                message["headers"] = merge(message.get("headers"), extra_headers)
            await send(message)

        return send_with_headers

    async def _handle_rate_limit(self, send: Send) -> bool:
        """Apply rate limiting and send a ``429`` response if exceeded.

        ``consume()`` is called once per HTTP request. A falsy return value
        means the limit was exceeded.

        Args:
            send: ASGI send callable (unwrapped). Security headers are added
                explicitly here when enabled.

        Returns:
            True if the rate limit was exceeded and a response was sent,
            False otherwise (including when no limiter is configured).

        """
        if self._rate_limiter is None or self._rate_limiter.consume():
            return False

        logger.warning("Rate limit exceeded for request")
        headers: list[tuple[bytes, bytes]] = [
            (b"retry-after", str(self._RETRY_AFTER_SECONDS).encode()),
        ]
        if self._security_headers:
            headers.extend(self._security_header_pairs)
        await _send_json_response(
            send,
            status=429,
            body={
                "error": "Rate limit exceeded",
                "retry_after": self._RETRY_AFTER_SECONDS,
            },
            extra_headers=headers,
        )
        return True

    async def __call__(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """Process an ASGI request.

        Args:
            scope: ASGI scope dict.
            receive: ASGI receive callable.
            send: ASGI send callable.

        Raises:
            Exception: Re-raises an exception from the wrapped app if the
                app had already sent ``http.response.start``.

        """
        # Only handle HTTP requests
        if scope.get("type") != "http":
            await self._app(scope, receive, send)
            return

        # Rate limiting
        if await self._handle_rate_limit(send):
            return

        # Wrap send to inject security headers
        secured_send = self._wrap_send_with_security_headers(send)
        response_started = False

        async def tracking_send(message: MutableMapping[str, object]) -> None:
            nonlocal response_started
            # Flag before sending so a failure during the send still counts
            # as "started".
            if message.get("type") == "http.response.start":
                response_started = True
            await secured_send(message)

        # Call the actual application
        try:
            await self._app(scope, receive, tracking_send)
        except Exception:
            logger.exception("Unhandled exception in ASGI app")
            if response_started:
                # Headers already went out. Another http.response.start would
                # violate the ASGI protocol, so let the server abort instead.
                raise
            await _send_json_response(
                secured_send,
                status=500,
                body={"error": "Internal server error"},
            )