Coverage for src / taipanstack / utils / rate_limit.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 14:54 +0000

1""" 

2Rate limiting utilities. 

3 

4Provides an in-memory token-bucket based rate limiting decorator 

5for both synchronous and asynchronous functions. The decorator 

6returns a ``Result`` type encapsulating the original return value 

7or a ``RateLimitError`` error. 

8""" 

9 

10import functools 

11import inspect 

12import threading 

13import time 

14from collections.abc import Callable, Coroutine 

15from typing import Any, ParamSpec, Protocol, TypeVar, overload 

16 

17from taipanstack.core.result import Err, Ok, Result 

18 

__all__ = [
    "RateLimitError",
    "RateLimiter",
    "rate_limit",
]

# Type variables describing the decorated callable: P captures its parameter
# list and T its return type. Both are used by RateLimitDecorator and
# rate_limit's inner wrappers.
P = ParamSpec("P")
T = TypeVar("T")

23 

24 

class RateLimitError(Exception):
    """Signal that a rate-limited call was rejected.

    This is an ordinary ``Exception`` subclass carrying a single message
    argument, so it can be raised or stored inside an error container.
    """

    def __init__(self, message: str = "Rate limit exceeded") -> None:
        """Store *message* as the exception's only argument.

        Args:
            message: Human-readable description of the failure. Defaults to
                "Rate limit exceeded".

        """
        super().__init__(message)

36 

37 

class RateLimiter:
    """Token bucket rate limiter logic.

    The bucket starts full with ``max_calls`` tokens. It refills continuously
    at ``max_calls / time_window`` tokens per second, capped at ``max_calls``.
    Every successful :meth:`consume` removes one token. Bucket state is only
    read and written while holding an internal :class:`threading.Lock`, so one
    instance can be shared between threads.
    """

    def __init__(self, max_calls: int, time_window: float) -> None:
        """Initialize the token bucket.

        Args:
            max_calls: The maximum number of calls allowed in the time window.
            time_window: The time window in seconds.

        Raises:
            ValueError: If ``max_calls`` or ``time_window`` is not strictly
                positive. This includes NaN.

        """
        # Phrased as ``not (x > 0)`` instead of ``x <= 0`` so that NaN is
        # rejected too: every comparison with NaN is False, so ``nan <= 0``
        # would pass. The NaN would then poison ``tokens`` and deny every
        # call forever.
        if not (max_calls > 0 and time_window > 0):
            raise ValueError("max_calls and time_window must be > 0.0")
        self.capacity: float = float(max_calls)
        self.time_window: float = float(time_window)
        # Start full so the first ``max_calls`` calls are allowed immediately.
        self.tokens: float = self.capacity
        self.last_update: float = time.monotonic()
        self._lock = threading.Lock()

    def consume(self) -> bool:
        """Try to consume a single token.

        Before deciding, the bucket is credited for the time elapsed since the
        previous call. The credit is capped at ``capacity``.

        Returns:
            True if a token was consumed (allow), False otherwise (limit exceeded).

        """
        with self._lock:
            now = time.monotonic()
            # Clamp at zero so a non-positive delta can never remove tokens.
            elapsed = max(0.0, now - self.last_update)
            # Advance the timestamp even when the call is denied: the elapsed
            # time has already been credited below.
            self.last_update = now

            # Add tokens for elapsed time based on fill rate
            self.tokens += elapsed * (self.capacity / self.time_window)
            self.tokens = min(self.tokens, self.capacity)

            if self.tokens >= 1.0:
                self.tokens -= 1.0
                return True
            return False

77 

78 

class RateLimitDecorator(Protocol):
    """Protocol for the rate limit decorator.

    Describes the object returned by ``rate_limit``. It accepts either a
    coroutine function or a plain function and returns a wrapper with the same
    parameters. The wrapper's result is wrapped in
    ``Result[T, RateLimitError]``, and it is awaitable for coroutine functions.
    """

    # Type checkers try overloads top to bottom and use the first match.
    # ``Callable[P, T]`` also matches ``async def`` functions (binding T to
    # the coroutine type), so the coroutine overload must come first.
    # Otherwise async callers would be typed as getting a non-awaitable
    # ``Result[Coroutine[...], ...]``.
    @overload
    def __call__(
        self, func: Callable[P, Coroutine[Any, Any, T]]
    ) -> Callable[
        P, Coroutine[Any, Any, Result[T, RateLimitError]]
    ]: ...  # pragma: no cover

    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, RateLimitError]]: ...  # pragma: no cover

93 

94 

def rate_limit(
    max_calls: int,
    time_window: float,
) -> RateLimitDecorator:
    """Wrap a sync or async function in an in-memory token-bucket limit.

    Each decorated function gets its own bucket. The bucket holds up to
    ``max_calls`` tokens and refills at ``max_calls / time_window`` tokens per
    second. A call that finds no token available returns
    ``Err(RateLimitError())`` immediately, without invoking the wrapped
    function. Otherwise the function's return value is wrapped in ``Ok``.
    Exceptions raised by the wrapped function are not caught and propagate
    unchanged.

    Args:
        max_calls: Maximum function executions allowed in the defined window.
        time_window: Time window size in seconds.

    Returns:
        Decorated function returning a ``Result[T, RateLimitError]``.

    Example:
        >>> @rate_limit(max_calls=2, time_window=1.0)
        ... def fetch_data() -> str:
        ...     return "data"
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Ok('data')
        >>> fetch_data()
        Err(RateLimitError('Rate limit exceeded'))

    """

    def decorator(
        func: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]],
    ) -> (
        Callable[P, Result[T, RateLimitError]]
        | Callable[P, Coroutine[Any, Any, Result[T, RateLimitError]]]
    ):
        # A new bucket is built every time the decorator is applied, so two
        # functions decorated with the same rate_limit(...) value do not share
        # a quota. Invalid arguments surface here as ValueError.
        bucket = RateLimiter(max_calls, time_window)

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def limited_async(
                *args: P.args,
                **kwargs: P.kwargs,
            ) -> Result[T, RateLimitError]:
                # The bucket is checked before func is called, so a rejected
                # call never creates the coroutine at all.
                if bucket.consume():
                    return Ok(await func(*args, **kwargs))
                return Err(RateLimitError())

            return limited_async

        @functools.wraps(func)
        def limited_sync(
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Result[T, RateLimitError]:
            if bucket.consume():
                return Ok(func(*args, **kwargs))  # type: ignore[arg-type]
            return Err(RateLimitError())

        return limited_sync

    return decorator  # type: ignore[return-value]