Coverage for src / taipanstack / utils / rate_limit.py: 100%
51 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
1"""
2Rate limiting utilities.
4Provides an in-memory token-bucket based rate limiting decorator
5for both synchronous and asynchronous functions. The decorator
6returns a ``Result`` type encapsulating the original return value
7or a ``RateLimitError`` error.
8"""
10import functools
11import inspect
12import threading
13import time
14from collections.abc import Callable, Coroutine
15from typing import Any, ParamSpec, Protocol, TypeVar, overload
17from taipanstack.core.result import Err, Ok, Result
__all__ = [
    "RateLimitError",
    "RateLimiter",
    "rate_limit",
]

# Parameter specification of the decorated callable, preserved by the wrappers.
P = ParamSpec("P")
# Return type of the decorated callable, which the wrappers wrap in ``Result``.
T = TypeVar("T")
class RateLimitError(Exception):
    """Exception raised when a rate limit is exceeded.

    The ``rate_limit`` decorator returns instances of this class wrapped in
    ``Err`` rather than raising them.
    """

    def __init__(self, message: str = "Rate limit exceeded") -> None:
        """Initialize the RateLimitError.

        Args:
            message: The error message to display. Defaults to
                "Rate limit exceeded".

        """
        super().__init__(message)
class RateLimiter:
    """Thread-safe token bucket rate limiter.

    The bucket holds at most ``max_calls`` tokens and starts full. Tokens
    refill continuously at ``max_calls / time_window`` per second, and each
    allowed call spends one token. This permits a burst of up to
    ``max_calls`` calls, followed by a sustained average of ``max_calls``
    calls per ``time_window`` seconds.
    """

    def __init__(self, max_calls: int, time_window: float) -> None:
        """Initialize a full token bucket.

        Args:
            max_calls: The maximum number of calls allowed in the time window.
                This is also the bucket capacity, i.e. the largest burst.
            time_window: The time window in seconds.

        Raises:
            ValueError: If ``max_calls`` or ``time_window`` is not strictly
                positive (NaN is rejected as well).

        """
        # ``not (x > 0)`` rather than ``x <= 0``: every comparison with NaN is
        # False, so ``<= 0`` would accept NaN. A NaN token count never
        # satisfies ``tokens >= 1.0``, so every call would be rejected forever.
        if not (max_calls > 0 and time_window > 0):
            raise ValueError("max_calls and time_window must be > 0.0")
        self.capacity: float = float(max_calls)
        self.time_window: float = float(time_window)
        self.tokens: float = self.capacity
        # Monotonic clock, so wall-clock adjustments cannot skew the refill.
        self.last_update: float = time.monotonic()
        self._lock = threading.Lock()

    def consume(self) -> bool:
        """Try to consume a single token.

        Returns:
            True if a token was consumed (allow), False otherwise (limit exceeded).

        """
        with self._lock:
            now = time.monotonic()
            # Defensive clamp: a non-advancing clock must never remove tokens.
            elapsed = max(0.0, now - self.last_update)
            self.last_update = now

            # Add tokens for elapsed time based on fill rate, capped at capacity.
            self.tokens += elapsed * (self.capacity / self.time_window)
            self.tokens = min(self.tokens, self.capacity)

            if self.tokens >= 1.0:
                self.tokens -= 1.0
                return True
            return False
class RateLimitDecorator(Protocol):
    """Protocol for the rate limit decorator.

    The overload order matters. Type checkers use the first overload whose
    parameters match. A coroutine function is also a ``Callable[P, T]``,
    with ``T`` bound to the coroutine type. If the generic overload came
    first, the async overload could never be selected, and ``async def``
    functions would be typed as returning ``Result[Coroutine, ...]``
    instead of a coroutine that resolves to ``Result[T, ...]``.
    """

    @overload
    def __call__(
        self, func: Callable[P, Coroutine[Any, Any, T]]
    ) -> Callable[
        P, Coroutine[Any, Any, Result[T, RateLimitError]]
    ]: ...  # pragma: no cover

    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, RateLimitError]]: ...  # pragma: no cover
def rate_limit(
    max_calls: int,
    time_window: float,
) -> RateLimitDecorator:
    """Return a decorator that throttles a function with a token bucket.

    Each function decorated by the returned decorator gets its own
    in-memory ``RateLimiter``. The bucket holds up to ``max_calls`` tokens
    and refills at ``max_calls / time_window`` tokens per second. Every
    call spends one token. When no token is available, the wrapped function
    is not invoked and the wrapper returns ``Err(RateLimitError())``.
    Coroutine functions get an ``async`` wrapper; all other callables get a
    synchronous one.

    Args:
        max_calls: Bucket capacity: the largest burst of calls, and the
            number of calls replenished per ``time_window``.
        time_window: Window length in seconds over which ``max_calls``
            tokens are replenished.

    Returns:
        A decorator that turns ``f`` into a function returning
        ``Ok(f(...))`` (or, for coroutine functions, an awaitable resolving
        to ``Ok(await f(...))``) while within the limit, and
        ``Err(RateLimitError())`` otherwise.

    Raises:
        ValueError: When the decorator is applied to a function, if
            ``RateLimiter`` rejects ``max_calls`` / ``time_window``
            (e.g. non-positive values).

    Note:
        Exceptions raised by the wrapped function propagate unchanged, and
        the token for that call has already been spent.

    Example:
        >>> @rate_limit(max_calls=2, time_window=1.0)
        ... def fetch_data() -> str:
        ...     return "data"
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Err(RateLimitError('Rate limit exceeded'))

    """

    def decorator(
        func: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]],
    ) -> (
        Callable[P, Result[T, RateLimitError]]
        | Callable[P, Coroutine[Any, Any, Result[T, RateLimitError]]]
    ):
        # One bucket per decorated function, shared by every call to it.
        bucket = RateLimiter(max_calls, time_window)

        # The coroutine check stays in the positive branch so type checkers
        # narrow ``func`` to an awaitable-returning callable inside it.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(
                *args: P.args,
                **kwargs: P.kwargs,
            ) -> Result[T, RateLimitError]:
                if bucket.consume():
                    return Ok(await func(*args, **kwargs))
                return Err(RateLimitError())

            return async_wrapper

        @functools.wraps(func)
        def wrapper(
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Result[T, RateLimitError]:
            if bucket.consume():
                return Ok(func(*args, **kwargs))  # type: ignore[arg-type]
            return Err(RateLimitError())

        return wrapper

    return decorator  # type: ignore[return-value]