Coverage for src / taipanstack / utils / concurrency.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

"""
Concurrency utilities.

Provides a bulkhead-pattern concurrency limiter decorator for both
synchronous and asynchronous functions. Decorated functions do not raise
when the limit is hit; they return a ``Result`` holding either the wrapped
function's value (``Ok``) or an ``OverloadError`` (``Err``).
"""

8 

9import asyncio 

10import functools 

11import inspect 

12import threading 

13from collections.abc import Awaitable, Callable 

14from typing import ParamSpec, Protocol, TypeVar, cast, overload 

15 

16from taipanstack.core.result import Err, Ok, Result 

17 

# Names exported by ``from taipanstack.utils.concurrency import *``.
__all__ = [
    "OverloadError",
    "limit_concurrency",
]

# Type variables shared by the decorator signatures in this module:
# ``P`` captures the wrapped callable's parameters, ``T`` its return value.
P = ParamSpec("P")
T = TypeVar("T")

22 

23 

class OverloadError(Exception):
    """Error describing a call rejected by a concurrency limit.

    This module never raises it: the limiter returns it wrapped in ``Err``
    when no execution slot could be obtained (immediately or within the
    configured wait), or when acquiring a slot failed.
    """

    def __init__(self, message: str = "Concurrency limit reached") -> None:
        """Build the error with a human-readable reason.

        Args:
            message: Explanation stored as the exception's sole argument.
                Falls back to "Concurrency limit reached" when omitted.

        """
        super().__init__(message)

36 

37 

class ConcurrencyLimitDecorator(Protocol):
    """Static type of the decorator returned by ``limit_concurrency``.

    Decorating a coroutine function yields a callable whose awaited value is a
    ``Result``; decorating a plain function yields a callable returning a
    ``Result`` directly.
    """

    # The coroutine overload MUST come first. ``Callable[P, Awaitable[T]]`` is
    # a special case of ``Callable[P, T]``, so if the sync overload were
    # listed first a type checker would match async functions against it
    # (binding ``T`` to the coroutine type) and never consider this one.
    @overload
    def __call__(
        self, func: Callable[P, Awaitable[T]]
    ) -> Callable[P, Awaitable[Result[T, OverloadError]]]: ...

    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, OverloadError]]: ...

50 

51 

def _handle_async_concurrency(
    func: Callable[P, Awaitable[T]],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Awaitable[Result[T, OverloadError]]]:
    """Guard a coroutine function with an ``asyncio.Semaphore`` bulkhead.

    Args:
        func: The coroutine function to guard.
        max_tasks: Number of semaphore permits, i.e. concurrent calls allowed.
        timeout: Seconds to wait for a permit. A positive value waits through
            ``asyncio.timeout``; any other value fails fast when no permit
            is free.

    Returns:
        A coroutine function with ``func``'s signature whose awaited value is
        ``Ok(result)`` on success, or ``Err(OverloadError)`` when no permit was
        obtained or acquiring one raised ``RuntimeError``, ``OSError`` or
        ``MemoryError``. Exceptions raised by ``func`` itself propagate, and
        the permit is released either way.
    """
    permits = asyncio.Semaphore(max_tasks)

    async def _wait_for_permit() -> bool:
        # True means a permit is now held; False means none was obtained.
        if timeout > 0.0:
            try:
                async with asyncio.timeout(timeout):
                    await permits.acquire()
            except TimeoutError:
                return False
            return True
        # Fail fast instead of queueing when every permit is taken.
        if permits.locked():
            return False
        await permits.acquire()
        return True

    @functools.wraps(func)
    async def guarded(
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Result[T, OverloadError]:
        try:
            have_permit = await _wait_for_permit()
        except (RuntimeError, OSError, MemoryError) as exc:
            return Err(OverloadError(f"Resource exhaustion: {exc!s}"))
        if not have_permit:
            return Err(OverloadError())

        try:
            return Ok(await func(*args, **kwargs))
        finally:
            permits.release()

    return guarded

85 

86 

def _handle_sync_concurrency(
    func: Callable[P, T],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Result[T, OverloadError]]:
    """Guard a plain function with a ``threading.Semaphore`` bulkhead.

    Args:
        func: The function to guard.
        max_tasks: Number of semaphore permits, i.e. concurrent calls allowed.
        timeout: Seconds to block waiting for a permit. A positive value
            blocks (capped at ``threading.TIMEOUT_MAX``); any other value
            tries once without blocking.

    Returns:
        A function with ``func``'s signature returning ``Ok(result)`` on
        success, or ``Err(OverloadError)`` when no permit was obtained or
        acquiring one raised ``RuntimeError``, ``OSError`` or
        ``MemoryError``. Exceptions raised by ``func`` itself propagate, and
        the permit is released either way.
    """
    sync_semaphore = threading.Semaphore(max_tasks)
    # Thread locks reject waits longer than TIMEOUT_MAX with OverflowError
    # (e.g. ``timeout=float("inf")`` under contention), which would escape
    # the wrapper. Capping keeps "wait practically forever" working, matching
    # the asyncio path, which accepts such timeouts.
    wait_limit = min(timeout, threading.TIMEOUT_MAX)

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, OverloadError]:
        try:
            if timeout > 0.0:
                acquired = sync_semaphore.acquire(timeout=wait_limit)
            else:
                acquired = sync_semaphore.acquire(blocking=False)
        except (RuntimeError, OSError, MemoryError) as e:
            return Err(OverloadError(f"Resource exhaustion: {e!s}"))

        if not acquired:
            return Err(OverloadError())

        try:
            return Ok(func(*args, **kwargs))
        finally:
            sync_semaphore.release()

    return wrapper

115 

116 

def limit_concurrency(
    max_tasks: int,
    timeout: float = 0.0,
) -> ConcurrencyLimitDecorator:
    """Decorate a function to apply the bulkhead concurrency limit pattern.

    At most ``max_tasks`` calls of the decorated function may run at the same
    time. When all slots are taken, the wrapper waits up to ``timeout``
    seconds to acquire an execution slot; with the default ``timeout=0.0`` it
    does not wait at all. If no slot is obtained, the call returns
    ``Err(OverloadError)`` instead of running the function. Both plain
    functions and coroutine functions are supported, and each decorated
    function gets its own independent pool of slots.

    Args:
        max_tasks: Maximum concurrent function executions allowed. Must be > 0.
        timeout: Maximum time in seconds to wait for a slot if the limit is
            reached. Must be >= 0.0; ``0.0`` means fail immediately.

    Returns:
        A decorator. The decorated function returns a
        ``Result[T, OverloadError]`` (awaitable for coroutine functions).
        Exceptions raised by the wrapped function itself propagate unchanged.

    Raises:
        ValueError: If ``max_tasks`` is not positive, or ``timeout`` is
            negative or NaN.

    Example:
        >>> @limit_concurrency(max_tasks=2, timeout=0.1)
        ... def process_data() -> str:
        ...     return "data"
        >>> process_data()
        Ok('data')

    """
    if max_tasks <= 0:
        raise ValueError("max_tasks must be > 0")
    # Written as ``not >=`` so NaN (which compares False against everything)
    # is rejected too, instead of silently behaving like ``timeout=0.0``.
    if not timeout >= 0.0:
        raise ValueError("timeout must be >= 0.0")

    def decorator(
        func: Callable[P, T] | Callable[P, Awaitable[T]],
    ) -> (
        Callable[P, Result[T, OverloadError]]
        | Callable[P, Awaitable[Result[T, OverloadError]]]
    ):
        # Coroutine functions get an asyncio semaphore; everything else is
        # treated as a blocking call guarded by a threading semaphore.
        if inspect.iscoroutinefunction(func):
            return _handle_async_concurrency(
                func,
                max_tasks,
                timeout,
            )

        return _handle_sync_concurrency(
            cast(Callable[P, T], func),
            max_tasks,
            timeout,
        )

    return cast(ConcurrencyLimitDecorator, decorator)