Coverage for src / taipanstack / utils / concurrency.py: 100%

56 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 14:54 +0000

1""" 

2Concurrency utilities. 

3 

Provides a bulkhead-pattern concurrency limiter decorator for both
synchronous and asynchronous functions. Calls that cannot obtain an
execution slot return ``Err(OverloadError())`` instead of running;
successful calls return ``Ok(value)``.

7""" 

8 

9import asyncio 

10import functools 

11import inspect 

12import threading 

13from collections.abc import Callable, Coroutine 

14from typing import Any, ParamSpec, Protocol, TypeVar, overload 

15 

16from taipanstack.core.result import Err, Ok, Result 

17 

# Names re-exported by ``from ... import *``.
__all__ = ["OverloadError", "limit_concurrency"]

# Generic parameters shared by the decorator signatures in this module.
T = TypeVar("T")  # return type of the wrapped callable
P = ParamSpec("P")  # parameter specification of the wrapped callable

22 

23 

class OverloadError(Exception):
    """Signal that a call was rejected because no concurrency slot was free.

    The ``limit_concurrency`` wrappers return instances of this class inside
    ``Err`` when the limit is hit or the wait for a slot times out.
    """

    def __init__(self, message: str = "Concurrency limit reached") -> None:
        """Store *message* as the exception's single argument.

        Args:
            message: Human-readable description of the rejection.
                Falls back to "Concurrency limit reached" when omitted.

        """
        super().__init__(message)

36 

37 

class ConcurrencyLimitDecorator(Protocol):
    """Protocol for the decorator returned by ``limit_concurrency``.

    Type checkers try overloads top to bottom. ``Callable[P, T]`` also matches
    an ``async def`` function (with ``T`` bound to the coroutine type), so the
    coroutine overload must be listed first. Otherwise decorated coroutine
    functions would be typed as returning a plain, non-awaitable ``Result``.
    """

    @overload
    def __call__(
        self, func: Callable[P, Coroutine[Any, Any, T]]
    ) -> Callable[
        P, Coroutine[Any, Any, Result[T, OverloadError]]
    ]: ...  # pragma: no cover

    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, OverloadError]]: ...  # pragma: no cover

52 

53 

def _handle_async_concurrency(
    func: Callable[P, Coroutine[Any, Any, T]],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Coroutine[Any, Any, Result[T, OverloadError]]]:
    """Build the bulkhead wrapper for a coroutine function.

    All calls of the returned coroutine function share a single
    ``asyncio.Semaphore`` holding ``max_tasks`` slots.

    Args:
        func: Coroutine function to protect.
        max_tasks: Number of calls allowed to run at the same time.
        timeout: Seconds to wait for a free slot. Any value that is not
            ``> 0.0`` means: reject immediately when no slot is free.

    Returns:
        A coroutine function producing ``Ok(result)`` on success, or
        ``Err(OverloadError())`` when no slot could be obtained. Exceptions
        raised by *func* are not caught.

    """
    slots = asyncio.Semaphore(max_tasks)

    async def _acquire_slot() -> bool:
        """Try to take one slot and report whether it was obtained."""
        if timeout > 0.0:
            try:
                async with asyncio.timeout(timeout):
                    await slots.acquire()
            except TimeoutError:
                return False
            return True
        # Fail-fast mode: only await acquire() after locked() reported that a
        # slot is currently available.
        if slots.locked():
            return False
        await slots.acquire()
        return True

    @functools.wraps(func)
    async def async_wrapper(
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Result[T, OverloadError]:
        if not await _acquire_slot():
            return Err(OverloadError())
        try:
            return Ok(await func(*args, **kwargs))
        finally:
            # The slot is given back whether func returned or raised.
            slots.release()

    return async_wrapper

84 

85 

def _handle_sync_concurrency(
    func: Callable[P, T],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Result[T, OverloadError]]:
    """Build the bulkhead wrapper for a regular (synchronous) function.

    All calls of the returned function share a single ``threading.Semaphore``
    holding ``max_tasks`` slots.

    Args:
        func: Function to protect.
        max_tasks: Number of calls allowed to run at the same time.
        timeout: Seconds to block waiting for a free slot. Any value that is
            not ``> 0.0`` means a non-blocking attempt.

    Returns:
        A function producing ``Ok(result)`` on success, or
        ``Err(OverloadError())`` when no slot could be obtained. Exceptions
        raised by *func* are not caught.

    """
    slots = threading.Semaphore(max_tasks)

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, OverloadError]:
        got_slot = (
            slots.acquire(timeout=timeout)
            if timeout > 0.0
            else slots.acquire(blocking=False)
        )
        if not got_slot:
            return Err(OverloadError())
        try:
            return Ok(func(*args, **kwargs))
        finally:
            # The slot is given back whether func returned or raised.
            slots.release()

    return wrapper

111 

112 

def limit_concurrency(
    max_tasks: int,
    timeout: float = 0.0,
) -> ConcurrencyLimitDecorator:
    """Decorate a function to apply the bulkhead concurrency limit pattern.

    Works with both regular functions and coroutine functions; the choice is
    made with ``inspect.iscoroutinefunction``. Every decorated function gets
    its own semaphore with ``max_tasks`` slots.

    If the maximum number of concurrent executions is reached, the wrapper
    waits up to ``timeout`` seconds to acquire an execution slot. With the
    default ``timeout=0.0`` it does not wait at all. If no slot is obtained,
    the call returns ``Err(OverloadError())`` without invoking the function.
    Exceptions raised by the wrapped function itself propagate unchanged.

    Args:
        max_tasks: Maximum concurrent function executions allowed. Must be > 0.
        timeout: Maximum time in seconds to wait for a slot if the limit is
            reached. ``0.0`` means fail fast. Must be >= 0.0 and not NaN.

    Returns:
        A decorator producing a function that returns
        ``Result[T, OverloadError]``. For coroutine functions, the result is
        awaitable.

    Raises:
        ValueError: If ``max_tasks`` is <= 0, or ``timeout`` is negative or NaN.

    Example:
        >>> @limit_concurrency(max_tasks=2, timeout=0.1)
        ... def process_data() -> str:
        ...     return "data"
        >>> process_data()
        Ok('data')

    """
    import math

    if max_tasks <= 0:
        raise ValueError("max_tasks must be > 0")
    # NaN compares False against everything, so ``timeout < 0.0`` alone would
    # let it through. It would then silently act like ``timeout=0.0``,
    # because the wrappers only wait when ``timeout > 0.0``.
    if timeout < 0.0 or math.isnan(timeout):
        raise ValueError("timeout must be >= 0.0")

    def decorator(
        func: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]],
    ) -> (
        Callable[P, Result[T, OverloadError]]
        | Callable[P, Coroutine[Any, Any, Result[T, OverloadError]]]
    ):
        if inspect.iscoroutinefunction(func):
            return _handle_async_concurrency(
                func,
                max_tasks,
                timeout,
            )

        return _handle_sync_concurrency(
            func,  # type: ignore[arg-type]
            max_tasks,
            timeout,
        )

    return decorator  # type: ignore[return-value]