Coverage for src / taipanstack / utils / rate_limit.py: 100%
89 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
1"""
Rate limiting utilities.

Provides an in-memory token-bucket based rate limiting decorator
for both synchronous and asynchronous functions. The decorator
returns a ``Result`` type encapsulating the original return value
or a ``RateLimitError`` error.
8"""
10import functools
11import inspect
12import math
13import threading
14import time
15from collections.abc import Awaitable, Callable
16from typing import ParamSpec, Protocol, TypeVar, cast, overload
18from taipanstack.core.result import Err, Ok, Result
# Names exported by a wildcard import of this module.
__all__ = ["RateLimitError", "RateLimiter", "rate_limit"]

# Parameter specification of the callable being decorated.
P = ParamSpec("P")
# Value type produced by the callable being decorated.
T = TypeVar("T")
class RateLimitError(Exception):
    """Exception raised when a rate limit is exceeded."""

    def __init__(self, message: str = "Rate limit exceeded") -> None:
        """Initialize the RateLimitError.

        Args:
            message: The error message to display. Defaults to
                "Rate limit exceeded".

        """
        super().__init__(message)
39class RateLimiter:
40 """Token bucket rate limiter logic."""
42 def __init__(self, max_calls: int, time_window: float) -> None:
43 """Initialize the token bucket.
45 Args:
46 max_calls: The maximum number of calls allowed in the time window.
47 time_window: The time window in seconds.
49 """
50 if not math.isfinite(max_calls) or not math.isfinite(time_window):
51 raise ValueError("max_calls and time_window must be finite numbers")
52 if max_calls <= 0 or time_window <= 0:
53 raise ValueError("max_calls and time_window must be > 0.0")
54 self.capacity: float = float(max_calls)
55 self.time_window: float = float(time_window)
56 self.tokens: float = self.capacity
57 self.last_update: float = time.monotonic()
58 self._lock = threading.Lock()
60 def _is_valid_bucket_state(self) -> bool:
61 """Check if the bucket's time window and capacity are in a valid state."""
62 try:
63 if not math.isfinite(self.time_window) or self.time_window <= 0.0:
64 return False
65 return math.isfinite(self.capacity) and self.capacity > 0.0
66 except TypeError:
67 return False
69 def _calculate_new_tokens(self, elapsed: float) -> float | None:
70 """Calculate new tokens based on elapsed time."""
71 new_tokens = elapsed * (self.capacity / self.time_window)
72 return new_tokens if math.isfinite(new_tokens) else None
74 def _apply_new_tokens(self, new_tokens: float) -> bool:
75 """Apply new tokens to the bucket."""
76 self.tokens += new_tokens
77 if not math.isfinite(self.tokens):
78 # Reset to previous state or capacity if corrupted
79 self.tokens = self.capacity
80 return False
81 self.tokens = min(self.tokens, self.capacity)
82 return True
84 def _add_tokens(self, now: float) -> bool:
85 """Calculate and add new tokens to the bucket based on elapsed time.
87 Args:
88 now: Current monotonic time.
90 Returns:
91 True if token update succeeds, False if state corruption is detected.
93 """
94 try:
95 elapsed = max(0.0, now - self.last_update)
96 self.last_update = now
98 # Prevent state corruption or infinite elapsed time
99 if not (self._is_valid_bucket_state() and math.isfinite(elapsed)):
100 return False
102 new_tokens = self._calculate_new_tokens(elapsed)
103 if new_tokens is None:
104 return False
106 return self._apply_new_tokens(new_tokens)
107 except TypeError:
108 return False
110 def _try_consume(self, tokens: float) -> bool:
111 """Attempt to consume the tokens from the bucket if available."""
112 if self.tokens >= tokens:
113 self.tokens -= tokens
114 return True
115 return False
117 def consume(self, tokens: float = 1.0) -> bool:
118 """Try to consume tokens.
120 Args:
121 tokens: Number of tokens to consume. Defaults to 1.0.
123 Returns:
124 True if tokens were consumed (allow), False otherwise (limit exceeded).
126 """
127 if tokens <= 0:
128 return True
130 with self._lock:
131 try:
132 now = time.monotonic()
134 # Prevent time corruption from poisoning the bucket state.
135 # Only try to add tokens if time is finite.
136 if math.isfinite(now) and not self._add_tokens(now):
137 return False
139 return self._try_consume(tokens)
140 except TypeError:
141 return False
class RateLimitDecorator(Protocol):
    """Protocol for the rate limit decorator.

    Describes a callable that accepts a function and returns a wrapper with
    the same parameters whose result is wrapped in
    ``Result[T, RateLimitError]``. When the function returns an awaitable,
    the wrapper returns an awaitable of that ``Result``.
    """

    # NOTE(review): type checkers evaluate overloads in declaration order,
    # and an ``async def`` function also satisfies ``Callable[P, T]``;
    # confirm the first overload does not shadow the awaitable one.
    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, RateLimitError]]: ...

    @overload
    def __call__(
        self, func: Callable[P, Awaitable[T]]
    ) -> Callable[P, Awaitable[Result[T, RateLimitError]]]: ...
def rate_limit(
    max_calls: int,
    time_window: float,
) -> RateLimitDecorator:
    """Decorate a function to apply rate limiting.

    If the rate limit is exceeded, the wrapped function immediately returns
    an ``Err(RateLimitError)``. Uses an in-memory token bucket strategy.

    Args:
        max_calls: Maximum function executions allowed in the defined window.
        time_window: Time window size in seconds.

    Returns:
        Decorated function returning a ``Result[T, RateLimitError]``.

    Example:
        >>> @rate_limit(max_calls=2, time_window=1.0)
        ... def fetch_data() -> str:
        ...     return "data"
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Err(RateLimitError('Rate limit exceeded'))

    """

    def decorator(
        func: Callable[P, T] | Callable[P, Awaitable[T]],
    ) -> (
        Callable[P, Result[T, RateLimitError]]
        | Callable[P, Awaitable[Result[T, RateLimitError]]]
    ):
        # One bucket per decorated function, created at decoration time and
        # shared by every call to the resulting wrapper.
        limiter = RateLimiter(max_calls, time_window)

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(
                *args: P.args,
                **kwargs: P.kwargs,
            ) -> Result[T, RateLimitError]:
                # The token is taken before the coroutine runs.
                if not limiter.consume():
                    return Err(RateLimitError())
                return Ok(await func(*args, **kwargs))

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, RateLimitError]:
            # The token is taken before the wrapped function runs.
            if not limiter.consume():
                return Err(RateLimitError())
            func_sync = cast(Callable[P, T], func)
            return Ok(func_sync(*args, **kwargs))

        return wrapper

    return cast(RateLimitDecorator, decorator)