Coverage for src / taipanstack / utils / cache.py: 100%

99 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-12 21:18 +0000

1""" 

2Intelligent Cache decorator. 

3 

4Provides in-memory caching that respects the Result monad and TTL, 

5ignoring caching for Err() results. 

6""" 

7 

8import asyncio 

9import functools 

10import inspect 

11import time 

12from collections.abc import Awaitable, Callable 

13from typing import ParamSpec, Protocol, TypeAlias, TypeVar, cast, overload 

14 

15from taipanstack.core.result import Err, Ok, Result 

16 

17P = ParamSpec("P") 

18T = TypeVar("T") 

19E = TypeVar("E", bound=Exception) 

20 

21CacheKey: TypeAlias = tuple[object, ...] 

22CacheValue: TypeAlias = tuple[float, object] 

23CacheDict: TypeAlias = dict[CacheKey, CacheValue] 

24 

25 

class CacheDecorator(Protocol):
    """Callable shape of the decorator returned by ``cached(...)``.

    Applied to a function returning ``Result``, it produces a function with
    the same parameters and ``Result`` return type; applied to a function
    returning an awaitable ``Result``, it produces one returning an awaitable
    ``Result``.
    """

    # Synchronous function: Result in, Result out.
    @overload
    def __call__(
        self,
        func: Callable[P, Result[T, E]],
    ) -> Callable[P, Result[T, E]]: ...

    # Asynchronous function: Awaitable[Result] in, Awaitable[Result] out.
    @overload
    def __call__(
        self,
        func: Callable[P, Awaitable[Result[T, E]]],
    ) -> Callable[P, Awaitable[Result[T, E]]]: ...

38 

39 

40def cached(ttl: float, max_size: int = 1024) -> CacheDecorator: # noqa: PLR0915 

41 """Cache the Ok() results of a function for a given TTL. 

42 

43 Err() results are not cached. Supports both async and sync functions. 

44 Implements LRU (Least Recently Used) eviction when max_size is reached. 

45 

46 Args: 

47 ttl: Time to live in seconds. 

48 max_size: Maximum number of elements to store in the cache. 

49 

50 Returns: 

51 Decorator function. 

52 

53 """ 

54 if not isinstance(max_size, int) or isinstance(max_size, bool) or max_size <= 0: 

55 raise ValueError("max_size must be a positive integer") 

56 

57 _cache: CacheDict = {} 

58 _locks: dict[CacheKey, asyncio.Lock] = {} 

59 _lock_waiters: dict[CacheKey, int] = {} 

60 

61 def get_cache_key( 

62 func_name: str, args: tuple[object, ...], kwargs: dict[str, object] 

63 ) -> CacheKey: 

64 def _make_hashable(val: object) -> object: 

65 match val: 

66 case tuple() | list(): 

67 return tuple(_make_hashable(item) for item in val) 

68 case dict(): 

69 return tuple(sorted((k, _make_hashable(v)) for k, v in val.items())) 

70 case set(): 

71 return frozenset(_make_hashable(item) for item in val) 

72 case _: 

73 hash(val) 

74 return val 

75 

76 hashable_args = tuple(_make_hashable(arg) for arg in args) 

77 hashable_kwargs = tuple( 

78 sorted((k, _make_hashable(v)) for k, v in kwargs.items()) 

79 ) 

80 return (func_name, hashable_args, hashable_kwargs) 

81 

82 def decorator( # noqa: PLR0915 

83 func: Callable[P, Result[T, E]] | Callable[P, Awaitable[Result[T, E]]], 

84 ) -> Callable[P, Result[T, E]] | Callable[P, Awaitable[Result[T, E]]]: 

85 if inspect.iscoroutinefunction(func): 

86 

87 @functools.wraps(func) 

88 async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, E]: 

89 cache_key = get_cache_key( 

90 func.__name__, 

91 cast(tuple[object, ...], args), 

92 cast(dict[str, object], kwargs), 

93 ) 

94 

95 # Check cache before acquiring lock 

96 now = time.monotonic() 

97 if cache_key in _cache: 

98 expiry, value = _cache[cache_key] 

99 if now < expiry: 

100 # Move to end to mark as recently used 

101 _cache[cache_key] = _cache.pop(cache_key) 

102 return Ok(cast(T, value)) 

103 

104 if cache_key not in _locks: 

105 _locks[cache_key] = asyncio.Lock() 

106 _lock_waiters[cache_key] = 0 

107 

108 _lock_waiters[cache_key] += 1 

109 lock = _locks[cache_key] 

110 

111 try: 

112 async with lock: 

113 # Double-check cache after acquiring lock 

114 now = time.monotonic() 

115 if cache_key in _cache: 

116 expiry, value = _cache[cache_key] 

117 if now < expiry: 

118 # Move to end to mark as recently used 

119 _cache[cache_key] = _cache.pop(cache_key) 

120 return Ok(cast(T, value)) 

121 del _cache[cache_key] 

122 

123 func_coro = cast(Callable[P, Awaitable[Result[T, E]]], func) 

124 result = await func_coro(*args, **kwargs) 

125 

126 match result: 

127 case Ok(value): 

128 if len(_cache) >= max_size: 

129 # Evict least recently used (first item) 

130 lru_key = next(iter(_cache)) 

131 del _cache[lru_key] 

132 _cache[cache_key] = (now + ttl, value) 

133 case Err(_): 

134 pass 

135 

136 return result 

137 finally: 

138 _lock_waiters[cache_key] -= 1 

139 if _lock_waiters[cache_key] == 0: 

140 _locks.pop(cache_key, None) 

141 _lock_waiters.pop(cache_key, None) 

142 

143 return async_wrapper 

144 

145 @functools.wraps(func) 

146 def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, E]: 

147 cache_key = get_cache_key( 

148 func.__name__, 

149 cast(tuple[object, ...], args), 

150 cast(dict[str, object], kwargs), 

151 ) 

152 now = time.monotonic() 

153 

154 if cache_key in _cache: 

155 expiry, value = _cache[cache_key] 

156 if now < expiry: 

157 # Move to end to mark as recently used 

158 _cache[cache_key] = _cache.pop(cache_key) 

159 return Ok(cast(T, value)) 

160 del _cache[cache_key] 

161 

162 func_sync = cast(Callable[P, Result[T, E]], func) 

163 result = func_sync(*args, **kwargs) 

164 

165 match result: 

166 case Ok(value): 

167 if len(_cache) >= max_size: 

168 # Evict least recently used (first item) 

169 lru_key = next(iter(_cache)) 

170 del _cache[lru_key] 

171 _cache[cache_key] = (now + ttl, value) 

172 case Err(_): 

173 pass 

174 

175 return result 

176 

177 return sync_wrapper 

178 

179 return cast(CacheDecorator, decorator)