Coverage for src / taipanstack / utils / rate_limit.py: 100%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

1""" 

2Rate limiting utilities. 

3 

4Provides an in-memory token-bucket based rate limiting decorator 

5for both synchronous and asynchronous functions. The decorator 

6returns a ``Result`` type encapsulating the original return value 

7or a ``RateLimitError`` error. 

8""" 

9 

10import functools 

11import inspect 

12import math 

13import threading 

14import time 

15from collections.abc import Awaitable, Callable 

16from typing import ParamSpec, Protocol, TypeVar, cast, overload 

17 

18from taipanstack.core.result import Err, Ok, Result 

19 

# Names exported by ``from ... import *``.
__all__ = [
    "RateLimitError",
    "RateLimiter",
    "rate_limit",
]

# Type variables threaded through the ``rate_limit`` decorator: ``P`` carries
# the wrapped callable's parameters and ``T`` its (unwrapped) return type.
P = ParamSpec("P")
T = TypeVar("T")

24 

25 

class RateLimitError(Exception):
    """Error signalling that a call was rejected by a rate limiter.

    The ``rate_limit`` decorator returns an instance of this class wrapped
    in ``Err`` instead of invoking the decorated function.
    """

    def __init__(self, message: str = "Rate limit exceeded") -> None:
        """Create the error with a human-readable message.

        Args:
            message: Text stored as the exception's sole argument.
                Defaults to "Rate limit exceeded".

        """
        # Delegate to Exception so ``str(err)`` and ``err.args`` expose it.
        super().__init__(message)

37 

38 

39class RateLimiter: 

40 """Token bucket rate limiter logic.""" 

41 

42 def __init__(self, max_calls: int, time_window: float) -> None: 

43 """Initialize the token bucket. 

44 

45 Args: 

46 max_calls: The maximum number of calls allowed in the time window. 

47 time_window: The time window in seconds. 

48 

49 """ 

50 if not math.isfinite(max_calls) or not math.isfinite(time_window): 

51 raise ValueError("max_calls and time_window must be finite numbers") 

52 if max_calls <= 0 or time_window <= 0: 

53 raise ValueError("max_calls and time_window must be > 0.0") 

54 self.capacity: float = float(max_calls) 

55 self.time_window: float = float(time_window) 

56 self.tokens: float = self.capacity 

57 self.last_update: float = time.monotonic() 

58 self._lock = threading.Lock() 

59 

60 def _is_valid_bucket_state(self) -> bool: 

61 """Check if the bucket's time window and capacity are in a valid state.""" 

62 try: 

63 if not math.isfinite(self.time_window) or self.time_window <= 0.0: 

64 return False 

65 return math.isfinite(self.capacity) and self.capacity > 0.0 

66 except TypeError: 

67 return False 

68 

69 def _calculate_new_tokens(self, elapsed: float) -> float | None: 

70 """Calculate new tokens based on elapsed time.""" 

71 new_tokens = elapsed * (self.capacity / self.time_window) 

72 return new_tokens if math.isfinite(new_tokens) else None 

73 

74 def _apply_new_tokens(self, new_tokens: float) -> bool: 

75 """Apply new tokens to the bucket.""" 

76 self.tokens += new_tokens 

77 if not math.isfinite(self.tokens): 

78 # Reset to previous state or capacity if corrupted 

79 self.tokens = self.capacity 

80 return False 

81 self.tokens = min(self.tokens, self.capacity) 

82 return True 

83 

84 def _add_tokens(self, now: float) -> bool: 

85 """Calculate and add new tokens to the bucket based on elapsed time. 

86 

87 Args: 

88 now: Current monotonic time. 

89 

90 Returns: 

91 True if token update succeeds, False if state corruption is detected. 

92 

93 """ 

94 try: 

95 elapsed = max(0.0, now - self.last_update) 

96 self.last_update = now 

97 

98 # Prevent state corruption or infinite elapsed time 

99 if not (self._is_valid_bucket_state() and math.isfinite(elapsed)): 

100 return False 

101 

102 new_tokens = self._calculate_new_tokens(elapsed) 

103 if new_tokens is None: 

104 return False 

105 

106 return self._apply_new_tokens(new_tokens) 

107 except TypeError: 

108 return False 

109 

110 def _try_consume(self, tokens: float) -> bool: 

111 """Attempt to consume the tokens from the bucket if available.""" 

112 if self.tokens >= tokens: 

113 self.tokens -= tokens 

114 return True 

115 return False 

116 

117 def consume(self, tokens: float = 1.0) -> bool: 

118 """Try to consume tokens. 

119 

120 Args: 

121 tokens: Number of tokens to consume. Defaults to 1.0. 

122 

123 Returns: 

124 True if tokens were consumed (allow), False otherwise (limit exceeded). 

125 

126 """ 

127 if tokens <= 0: 

128 return True 

129 

130 with self._lock: 

131 try: 

132 now = time.monotonic() 

133 

134 # Prevent time corruption from poisoning the bucket state. 

135 # Only try to add tokens if time is finite. 

136 if math.isfinite(now) and not self._add_tokens(now): 

137 return False 

138 

139 return self._try_consume(tokens) 

140 except TypeError: 

141 return False 

142 

143 

class RateLimitDecorator(Protocol):
    """Structural type of the decorator returned by ``rate_limit``.

    Callables returning an awaitable are typed as returning an awaitable
    ``Result``; any other callable is typed as returning a ``Result``
    directly.
    """

    # Overloads are matched top to bottom, so the narrower awaitable form
    # must come first. ``Callable[P, T]`` also matches async functions
    # (binding ``T`` to the coroutine type), which previously made this
    # overload unreachable and typed async functions as returning
    # ``Result[Coroutine, ...]``.
    # NOTE(review): the runtime dispatch in ``rate_limit`` uses
    # ``inspect.iscoroutinefunction``; a plain function that merely returns
    # an awaitable matches this overload statically but is wrapped
    # synchronously at runtime.
    @overload
    def __call__(
        self, func: Callable[P, Awaitable[T]]
    ) -> Callable[P, Awaitable[Result[T, RateLimitError]]]: ...

    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, RateLimitError]]: ...

156 

157 

def rate_limit(
    max_calls: int,
    time_window: float,
) -> RateLimitDecorator:
    """Decorate a function to apply rate limiting.

    Every function the returned decorator is applied to gets its own
    in-memory ``RateLimiter`` (token bucket). Each call first asks the bucket
    for one token: if granted, the wrapped function runs and its return value
    is wrapped in ``Ok``; otherwise the function is not called and
    ``Err(RateLimitError())`` is returned. Coroutine functions (as detected by
    ``inspect.iscoroutinefunction``) receive an ``async`` wrapper.

    Args:
        max_calls: Maximum function executions allowed in the defined window.
        time_window: Time window size in seconds.

    Returns:
        Decorated function returning a ``Result[T, RateLimitError]``.

    Raises:
        ValueError: When the decorator is applied, if ``max_calls`` or
            ``time_window`` is rejected by ``RateLimiter``.

    Example:
        >>> @rate_limit(max_calls=2, time_window=1.0)
        ... def fetch_data() -> str:
        ...     return "data"
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Err(RateLimitError('Rate limit exceeded'))

    """

    def decorator(
        func: Callable[P, T] | Callable[P, Awaitable[T]],
    ) -> (
        Callable[P, Result[T, RateLimitError]]
        | Callable[P, Awaitable[Result[T, RateLimitError]]]
    ):
        # One bucket per decorated function, created at decoration time.
        limiter = RateLimiter(max_calls, time_window)

        if not inspect.iscoroutinefunction(func):
            sync_func = cast(Callable[P, T], func)

            @functools.wraps(func)
            def wrapper(
                *args: P.args, **kwargs: P.kwargs
            ) -> Result[T, RateLimitError]:
                # Invoke the wrapped function only when a token was granted.
                if limiter.consume():
                    return Ok(sync_func(*args, **kwargs))
                return Err(RateLimitError())

            return wrapper

        async_func = cast(Callable[P, Awaitable[T]], func)

        @functools.wraps(func)
        async def async_wrapper(
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Result[T, RateLimitError]:
            # The token is taken before awaiting, so a rejected call never
            # creates the coroutine at all.
            if limiter.consume():
                return Ok(await async_func(*args, **kwargs))
            return Err(RateLimitError())

        return async_wrapper

    return cast(RateLimitDecorator, decorator)