Coverage for src / taipanstack / utils / concurrency.py: 100%
56 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-23 14:54 +0000
1"""
2Concurrency utilities.
4Provides a bulkhead pattern concurrency limiter decorator for both
5synchronous and asynchronous functions. Uses an `OverloadError` and
6returns a `Result` type.
7"""
import asyncio
import functools
import inspect
import math
import threading
from collections.abc import Callable, Coroutine
from typing import Any, ParamSpec, Protocol, TypeVar, overload

from taipanstack.core.result import Err, Ok, Result
# Public names of this module (what ``from ... import *`` exposes).
__all__ = ["OverloadError", "limit_concurrency"]

# Parameter specification of the decorated callable; lets the wrappers keep
# the original call signature (see ``*args: P.args, **kwargs: P.kwargs``).
P = ParamSpec("P")
# Result type of the decorated callable (for coroutine functions, the value
# produced by awaiting it).
T = TypeVar("T")
class OverloadError(Exception):
    """Signal that no concurrency slot could be obtained.

    Produced when a concurrency limit is saturated, either immediately (when
    waiting is not allowed) or after the allowed wait time has elapsed.
    """

    def __init__(self, message: str = "Concurrency limit reached") -> None:
        """Store *message* as the exception's single argument.

        Args:
            message: Human-readable description of the failure.
                When omitted, "Concurrency limit reached" is used.

        """
        super().__init__(message)
class ConcurrencyLimitDecorator(Protocol):
    """Structural type of the decorator returned by ``limit_concurrency``.

    Coroutine functions are wrapped into coroutine functions whose awaited
    value is a ``Result``; any other callable is wrapped into a callable that
    returns a ``Result`` directly.
    """

    # Overload order matters: type checkers select the FIRST overload whose
    # parameter types match. ``Callable[P, T]`` also matches coroutine
    # functions (binding ``T`` to the coroutine type), so the coroutine
    # overload must come first or it could never be selected.
    # NOTE(review): some checkers may report these two overloads as
    # overlapping with incompatible return types; if the project's type-check
    # run flags it, an ``overload-overlap`` ignore on the first overload is the
    # usual remedy — confirm before adding one.
    @overload
    def __call__(
        self, func: Callable[P, Coroutine[Any, Any, T]]
    ) -> Callable[
        P, Coroutine[Any, Any, Result[T, OverloadError]]
    ]: ...  # pragma: no cover

    @overload
    def __call__(
        self, func: Callable[P, T]
    ) -> Callable[P, Result[T, OverloadError]]: ...  # pragma: no cover
def _handle_async_concurrency(
    func: Callable[P, Coroutine[Any, Any, T]],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Coroutine[Any, Any, Result[T, OverloadError]]]:
    """Wrap the coroutine function *func* in an ``asyncio.Semaphore`` bulkhead.

    Args:
        func: Coroutine function to protect.
        max_tasks: Number of invocations allowed to run at the same time.
        timeout: Seconds to wait for a free slot. Any value that is not
            greater than zero means "do not wait": the call is rejected as
            soon as every slot is taken.

    Returns:
        A coroutine function with *func*'s signature whose awaited value is
        ``Ok(result)`` on success or ``Err(OverloadError())`` when no slot
        was obtained. Exceptions raised by *func* propagate unchanged; the
        slot is released in every case once *func* has started.

    """
    # One semaphore per decorated function, shared by every call.
    # NOTE(review): asyncio primitives bind to the first event loop that has
    # to wait on them, so using the wrapped function from several event loops
    # (e.g. repeated ``asyncio.run`` calls) may raise RuntimeError under
    # contention — confirm whether that usage matters for callers.
    slots = asyncio.Semaphore(max_tasks)

    async def _claim_slot() -> bool:
        """Take one slot from ``slots``; return False if none could be had."""
        if timeout > 0.0:
            try:
                async with asyncio.timeout(timeout):
                    await slots.acquire()
            except TimeoutError:
                return False
            return True
        # Fail fast: reject instead of queueing. Nothing runs between the
        # ``locked()`` check and the acquire, so the acquire completes
        # without waiting.
        if slots.locked():
            return False
        await slots.acquire()
        return True

    @functools.wraps(func)
    async def async_wrapper(
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Result[T, OverloadError]:
        if not await _claim_slot():
            return Err(OverloadError())
        try:
            outcome = await func(*args, **kwargs)
            return Ok(outcome)
        finally:
            slots.release()

    return async_wrapper
def _handle_sync_concurrency(
    func: Callable[P, T],
    max_tasks: int,
    timeout: float,
) -> Callable[P, Result[T, OverloadError]]:
    """Wrap the plain callable *func* in a ``threading.Semaphore`` bulkhead.

    Args:
        func: Callable to protect.
        max_tasks: Number of invocations allowed to run at the same time.
        timeout: Seconds to wait for a free slot. Any value that is not
            greater than zero means "do not wait": the call is rejected as
            soon as every slot is taken. Values above
            ``threading.TIMEOUT_MAX`` (including ``float("inf")``) wait for
            ``threading.TIMEOUT_MAX`` seconds.

    Returns:
        A callable with *func*'s signature returning ``Ok(result)`` on
        success or ``Err(OverloadError())`` when no slot was obtained.
        Exceptions raised by *func* propagate unchanged; the slot is
        released in every case once *func* has started.

    """
    # One semaphore per decorated function, shared by every call/thread.
    sync_semaphore = threading.Semaphore(max_tasks)
    # Lock-based waits raise OverflowError for timeouts above
    # threading.TIMEOUT_MAX, so huge values (float("inf") being the natural
    # way to say "wait indefinitely") are clamped to that ceiling.
    wait_seconds = min(timeout, threading.TIMEOUT_MAX)

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, OverloadError]:
        if timeout > 0.0:
            acquired = sync_semaphore.acquire(timeout=wait_seconds)
        else:
            # Fail fast: a single non-blocking attempt.
            acquired = sync_semaphore.acquire(blocking=False)
        if not acquired:
            return Err(OverloadError())

        try:
            return Ok(func(*args, **kwargs))
        finally:
            sync_semaphore.release()

    return wrapper
def limit_concurrency(
    max_tasks: int,
    timeout: float = 0.0,
) -> ConcurrencyLimitDecorator:
    """Decorate a function to apply the bulkhead concurrency limit pattern.

    At most ``max_tasks`` calls of the decorated function may run at the
    same time. When every slot is taken, the wrapper waits up to ``timeout``
    seconds to acquire an execution slot; if none frees up, it returns an
    ``Err(OverloadError)`` without calling the function. With the default
    ``timeout=0.0`` the call is rejected immediately instead of waiting.

    Coroutine functions (as detected by ``inspect.iscoroutinefunction``) are
    wrapped into coroutine functions and must still be awaited; every other
    callable is wrapped synchronously. The slots are created once per
    decorated function, so all callers of that function share them.

    Args:
        max_tasks: Maximum concurrent function executions allowed.
        timeout: Maximum time in seconds to wait for a slot if the limit is
            reached. ``0.0`` means "do not wait".

    Returns:
        A decorator whose wrapped function returns ``Result[T, OverloadError]``
        (as the awaited value for coroutine functions). Exceptions raised by
        the wrapped function itself are not caught.

    Raises:
        ValueError: If ``max_tasks`` is not positive, or if ``timeout`` is
            negative or NaN.

    Example:
        >>> @limit_concurrency(max_tasks=2, timeout=0.1)
        ... def process_data() -> str:
        ...     return "data"
        >>> process_data()
        Ok('data')

    """
    if max_tasks <= 0:
        raise ValueError("max_tasks must be > 0")
    # ``timeout < 0.0`` alone is False for NaN, which would otherwise slip
    # through and silently behave like a zero (non-blocking) timeout.
    if timeout < 0.0 or math.isnan(timeout):
        raise ValueError("timeout must be >= 0.0")

    def decorator(
        func: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]],
    ) -> (
        Callable[P, Result[T, OverloadError]]
        | Callable[P, Coroutine[Any, Any, Result[T, OverloadError]]]
    ):
        # Dispatch on the kind of callable: coroutine functions need an
        # async wrapper so the slot is held until the coroutine finishes.
        if inspect.iscoroutinefunction(func):
            return _handle_async_concurrency(
                func,
                max_tasks,
                timeout,
            )

        return _handle_sync_concurrency(
            func,  # type: ignore[arg-type]
            max_tasks,
            timeout,
        )

    return decorator  # type: ignore[return-value]