Coverage for src/taipanstack/utils/concurrency.py: 100%
62 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-12 21:18 +0000
"""
Concurrency utilities.

Provides a bulkhead-pattern concurrency limiter decorator for both
synchronous and asynchronous functions. A call that cannot obtain an
execution slot is not run; the wrapper returns ``Err(OverloadError)``
instead, while a successful call returns ``Ok(value)`` (a ``Result``).
"""
9import asyncio
10import functools
11import inspect
12import threading
13from collections.abc import Awaitable, Callable
14from typing import ParamSpec, Protocol, TypeVar, cast, overload
16from taipanstack.core.result import Err, Ok, Result
# Generic type variables shared by the decorator signatures below.
T = TypeVar("T")  # return type of the wrapped callable
P = ParamSpec("P")  # parameter specification of the wrapped callable

# Public API of this module.
__all__ = ["OverloadError", "limit_concurrency"]
class OverloadError(Exception):
    """Signals that no concurrency slot could be obtained.

    Used both when the concurrency limit is already reached and when waiting
    for a free slot exceeded the configured timeout. The limiter returns it
    wrapped in an ``Err`` rather than raising it.
    """

    def __init__(self, message: str = "Concurrency limit reached") -> None:
        """Build the error with a human-readable reason.

        Args:
            message: Text explaining why the call was rejected. When omitted,
                ``"Concurrency limit reached"`` is used.
        """
        super().__init__(message)
class ConcurrencyLimitDecorator(Protocol):
    """Protocol for the decorator returned by ``limit_concurrency``.

    Applied to a callable, it returns a wrapper whose result is a
    ``Result[T, OverloadError]``; for awaitable-returning callables the
    wrapper itself is awaitable.
    """

    # The awaitable overload must be listed FIRST. ``Callable[P, T]`` also
    # matches ``async def`` functions (with ``T`` bound to the coroutine
    # type), so if it came first the async signature would never be selected
    # and type checkers would claim async wrappers return a plain ``Result``.
    # NOTE(review): a plain function that merely *returns* an awaitable also
    # matches this overload, while ``limit_concurrency`` only takes the async
    # path for ``inspect.iscoroutinefunction`` callables.
    @overload
    def __call__(
        self, func: Callable[P, Awaitable[T]]
    ) -> Callable[P, Awaitable[Result[T, OverloadError]]]: ...

    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, OverloadError]]: ...
def _handle_async_concurrency(
    func: Callable[P, Awaitable[T]],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Awaitable[Result[T, OverloadError]]]:
    """Wrap a coroutine function so at most ``max_tasks`` calls run at once.

    Args:
        func: Coroutine function to guard.
        max_tasks: Number of slots in the wrapper's private semaphore.
        timeout: Seconds to wait for a free slot; ``0.0`` means do not wait.

    Returns:
        A coroutine function resolving to ``Ok(result)`` when ``func`` ran,
        or ``Err(OverloadError)`` when no slot could be obtained (including
        when acquiring a slot raised ``RuntimeError``, ``OSError`` or
        ``MemoryError``). Exceptions raised by ``func`` propagate.
    """
    slots = asyncio.Semaphore(max_tasks)
    should_wait = timeout > 0.0

    @functools.wraps(func)
    async def guarded(
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Result[T, OverloadError]:
        try:
            if not should_wait:
                # Fail fast: with no waiting budget, a full pool is an overload.
                if slots.locked():
                    return Err(OverloadError())
                await slots.acquire()
            else:
                try:
                    async with asyncio.timeout(timeout):
                        await slots.acquire()
                except TimeoutError:
                    return Err(OverloadError())
        except (RuntimeError, OSError, MemoryError) as exc:
            return Err(OverloadError(f"Resource exhaustion: {exc!s}"))

        # A slot is held from here on; always hand it back.
        try:
            wrapped = Ok(await func(*args, **kwargs))
        finally:
            slots.release()
        return wrapped

    return guarded
def _handle_sync_concurrency(
    func: Callable[P, T],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Result[T, OverloadError]]:
    """Wrap a regular callable so at most ``max_tasks`` calls run at once.

    Args:
        func: Synchronous callable to guard.
        max_tasks: Number of slots in the wrapper's private semaphore.
        timeout: Seconds to block waiting for a free slot; ``0.0`` means a
            single non-blocking attempt.

    Returns:
        A function returning ``Ok(result)`` when ``func`` ran, or
        ``Err(OverloadError)`` when no slot could be obtained (including when
        acquiring a slot raised ``RuntimeError``, ``OSError`` or
        ``MemoryError``). Exceptions raised by ``func`` propagate.
    """
    slots = threading.Semaphore(max_tasks)
    should_wait = timeout > 0.0

    def _try_acquire() -> bool:
        # Either block for up to ``timeout`` seconds or probe exactly once.
        if should_wait:
            return slots.acquire(timeout=timeout)
        return slots.acquire(blocking=False)

    @functools.wraps(func)
    def guarded(*args: P.args, **kwargs: P.kwargs) -> Result[T, OverloadError]:
        try:
            if not _try_acquire():
                return Err(OverloadError())
        except (RuntimeError, OSError, MemoryError) as exc:
            return Err(OverloadError(f"Resource exhaustion: {exc!s}"))

        # A slot is held from here on; always hand it back.
        try:
            wrapped = Ok(func(*args, **kwargs))
        finally:
            slots.release()
        return wrapped

    return guarded
def limit_concurrency(
    max_tasks: int,
    timeout: float = 0.0,
) -> ConcurrencyLimitDecorator:
    """Decorate a function to apply the bulkhead concurrency limit pattern.

    Every function decorated gets its own independent pool of ``max_tasks``
    execution slots. Coroutine functions are wrapped in a coroutine function;
    any other callable is wrapped in a plain function. If all slots are busy,
    the wrapper waits up to ``timeout`` seconds to acquire an execution slot
    (``0.0`` means it does not wait at all). If it fails, the wrapped function
    is not called and ``Err(OverloadError)`` is returned; otherwise the
    function's return value is wrapped in ``Ok``. Exceptions raised by the
    wrapped function itself propagate unchanged.

    Args:
        max_tasks: Maximum concurrent function executions allowed. Must be a
            positive whole number; an integral float such as ``2.0`` is
            accepted and converted to ``int``.
        timeout: Maximum time in seconds to wait for a slot if limit is
            reached. Must be ``>= 0.0``; NaN is rejected.

    Returns:
        Decorator producing functions that return ``Result[T, OverloadError]``
        (or an awaitable of it for coroutine functions).

    Raises:
        ValueError: If ``max_tasks`` is not a positive whole number, or if
            ``timeout`` is negative or NaN.

    Example:
        >>> @limit_concurrency(max_tasks=2, timeout=0.1)
        ... def process_data() -> str:
        ...     return "data"
        >>> process_data()
        Ok('data')

    """
    if max_tasks <= 0:
        raise ValueError("max_tasks must be > 0")
    # The stdlib semaphores only block when their counter is exactly 0. A
    # fractional count (2.5 -> 1.5 -> 0.5 -> -0.5 ...), NaN or inf never hits
    # 0, which would silently disable the limit, so insist on whole numbers.
    if isinstance(max_tasks, float):
        if not max_tasks.is_integer():
            raise ValueError("max_tasks must be a whole number")
        max_tasks = int(max_tasks)
    # ``not >=`` (rather than ``<``) also rejects NaN, which compares False
    # against everything and would otherwise silently mean "never wait".
    if not timeout >= 0.0:
        raise ValueError("timeout must be >= 0.0")

    def decorator(
        func: Callable[P, T] | Callable[P, Awaitable[T]],
    ) -> (
        Callable[P, Result[T, OverloadError]]
        | Callable[P, Awaitable[Result[T, OverloadError]]]
    ):
        # Only true coroutine functions get the asyncio-based limiter.
        if inspect.iscoroutinefunction(func):
            return _handle_async_concurrency(
                func,
                max_tasks,
                timeout,
            )

        return _handle_sync_concurrency(
            cast(Callable[P, T], func),
            max_tasks,
            timeout,
        )

    return cast(ConcurrencyLimitDecorator, decorator)