Source code for kstlib.resilience.rate_limiter

"""Rate limiter using the Token Bucket algorithm.

Provides configurable rate limiting for protecting against request floods
and respecting external API rate limits.

Examples:
    As a decorator:

    >>> @rate_limiter(rate=10, per=1.0)  # 10 requests per second
    ... def call_api():  # doctest: +SKIP
    ...     return requests.get("http://api.example.com")

    Direct usage:

    >>> limiter = RateLimiter(rate=5, per=1.0)
    >>> limiter.acquire()  # Blocks until token available
    True
    >>> limiter.try_acquire()  # Non-blocking, returns immediately
    True

    As context manager:

    >>> with RateLimiter(rate=100, per=60.0):  # doctest: +SKIP
    ...     call_api()

"""

from __future__ import annotations

import asyncio
import functools
import inspect
import threading
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload

from typing_extensions import Self

from kstlib.resilience.exceptions import RateLimitExceededError

if TYPE_CHECKING:
    from collections.abc import Callable

P = ParamSpec("P")
R = TypeVar("R")


[docs] @dataclass class RateLimiterStats: """Statistics for rate limiter monitoring. Attributes: total_acquired: Total number of tokens successfully acquired. total_rejected: Total number of acquire attempts that were rejected. total_waited: Total time spent waiting for tokens (seconds). Examples: >>> stats = RateLimiterStats() >>> stats.record_acquired() >>> stats.record_rejected() >>> stats.record_wait(0.5) >>> (stats.total_acquired, stats.total_rejected) (1, 1) >>> stats.total_waited 0.5 """ total_acquired: int = 0 total_rejected: int = 0 total_waited: float = 0.0
[docs] def record_acquired(self) -> None: """Record a successful token acquisition.""" self.total_acquired += 1
[docs] def record_rejected(self) -> None: """Record a rejected acquisition attempt.""" self.total_rejected += 1
[docs] def record_wait(self, seconds: float) -> None: """Record time spent waiting for a token.""" self.total_waited += seconds
[docs] class RateLimiter: """Token bucket rate limiter for controlling request throughput. Implements the token bucket algorithm where tokens are added at a fixed rate and each request consumes one token. Allows bursts up to the bucket capacity. Args: rate: Maximum number of tokens (requests) allowed per period. per: Time period in seconds (default 1.0 = per second). burst: Initial tokens available. If None, starts full (burst = rate). name: Optional name for logging and monitoring. Examples: Basic usage - 10 requests per second: >>> limiter = RateLimiter(rate=10, per=1.0) >>> int(limiter.tokens) # Starts full 10 With custom burst capacity: >>> limiter = RateLimiter(rate=10, per=1.0, burst=5) >>> int(limiter.tokens) # Starts with 5 tokens 5 Rate limiting API calls: >>> limiter = RateLimiter(rate=100, per=60.0) # 100 per minute >>> for _ in range(5): ... if limiter.try_acquire(): ... pass # call_api() >>> limiter.stats.total_acquired 5 """
[docs] def __init__( self, rate: float, per: float = 1.0, *, burst: float | None = None, name: str | None = None, ) -> None: """Initialize rate limiter. Args: rate: Maximum tokens (requests) per period. per: Period duration in seconds. burst: Initial token count. Defaults to rate (full bucket). name: Optional name for identification. Raises: ValueError: If rate or per is not positive. """ if rate <= 0: raise ValueError("rate must be positive") if per <= 0: raise ValueError("per must be positive") self._rate = float(rate) self._per = float(per) self._tokens = float(burst) if burst is not None else self._rate self._max_tokens = self._rate self._refill_rate = self._rate / self._per # tokens per second self._last_refill = time.monotonic() self._lock = threading.Lock() self._name = name self._stats = RateLimiterStats()
@property def rate(self) -> float: """Maximum tokens per period.""" return self._rate @property def per(self) -> float: """Period duration in seconds.""" return self._per @property def tokens(self) -> float: """Current available tokens (after refill).""" with self._lock: self._refill() return self._tokens @property def stats(self) -> RateLimiterStats: """Statistics for this rate limiter.""" return self._stats @property def name(self) -> str | None: """Name of this rate limiter.""" return self._name def _refill(self) -> None: """Refill tokens based on elapsed time. Must hold lock.""" now = time.monotonic() elapsed = now - self._last_refill self._tokens = min(self._max_tokens, self._tokens + elapsed * self._refill_rate) self._last_refill = now def _time_until_token(self) -> float: """Calculate time until at least 1 token is available. Must hold lock.""" if self._tokens >= 1.0: return 0.0 needed = 1.0 - self._tokens return needed / self._refill_rate
[docs] def time_until_token(self) -> float: """Calculate time until at least 1 token will be available. Returns: Seconds until a token is available. Returns 0.0 if token available now. Examples: >>> limiter = RateLimiter(rate=10, per=1.0) >>> limiter.time_until_token() # Tokens available 0.0 """ with self._lock: self._refill() return self._time_until_token()
[docs] def acquire(self, *, blocking: bool = True, timeout: float | None = None) -> bool: """Acquire a token from the bucket. Args: blocking: If True, wait until a token is available. timeout: Maximum time to wait in seconds (None = wait forever). Returns: True if token was acquired, False if non-blocking and no token. Raises: RateLimitExceededError: If timeout exceeded while waiting. Examples: >>> limiter = RateLimiter(rate=10, per=1.0) >>> limiter.acquire() # Blocks if needed True >>> limiter.acquire(blocking=False) # Returns immediately True """ start_time = time.monotonic() deadline = start_time + timeout if timeout is not None else None while True: with self._lock: self._refill() if self._tokens >= 1.0: self._tokens -= 1.0 wait_time = time.monotonic() - start_time if wait_time > 0.001: # Only record significant waits self._stats.record_wait(wait_time) self._stats.record_acquired() return True if not blocking: self._stats.record_rejected() return False # Calculate wait time wait_time = self._time_until_token() # Check timeout if deadline is not None: remaining = deadline - time.monotonic() if remaining <= 0: self._stats.record_rejected() raise RateLimitExceededError( f"Rate limit timeout after {timeout}s", retry_after=wait_time, ) wait_time = min(wait_time, remaining) # Wait outside the lock time.sleep(min(wait_time, 0.1)) # Cap sleep to allow interruption
[docs] def try_acquire(self) -> bool: """Try to acquire a token without blocking. Returns: True if token was acquired, False otherwise. Examples: >>> limiter = RateLimiter(rate=2, per=1.0) >>> limiter.try_acquire() True >>> limiter.try_acquire() True >>> limiter.try_acquire() # No tokens left False """ return self.acquire(blocking=False)
[docs] async def acquire_async(self, *, timeout: float | None = None) -> bool: """Acquire a token asynchronously. Args: timeout: Maximum time to wait in seconds. Returns: True when token is acquired. Raises: RateLimitExceededError: If timeout exceeded. Examples: >>> import asyncio >>> limiter = RateLimiter(rate=10, per=1.0) >>> asyncio.run(limiter.acquire_async()) True """ start_time = time.monotonic() deadline = start_time + timeout if timeout is not None else None while True: with self._lock: self._refill() if self._tokens >= 1.0: self._tokens -= 1.0 wait_time = time.monotonic() - start_time if wait_time > 0.001: self._stats.record_wait(wait_time) self._stats.record_acquired() return True wait_time = self._time_until_token() if deadline is not None: remaining = deadline - time.monotonic() if remaining <= 0: self._stats.record_rejected() raise RateLimitExceededError( f"Rate limit timeout after {timeout}s", retry_after=wait_time, ) wait_time = min(wait_time, remaining) # Async sleep outside the lock await asyncio.sleep(min(wait_time, 0.1))
[docs] def reset(self) -> None: """Reset the rate limiter to full capacity. Examples: >>> limiter = RateLimiter(rate=5, per=1.0) >>> for _ in range(5): ... limiter.try_acquire() True True True True True >>> limiter.try_acquire() False >>> limiter.reset() >>> limiter.try_acquire() True """ with self._lock: self._tokens = self._max_tokens self._last_refill = time.monotonic()
[docs] def __enter__(self) -> Self: """Enter context manager, acquiring a token.""" self.acquire() return self
[docs] def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object, ) -> None: """Exit context manager.""" pass
[docs] async def __aenter__(self) -> Self: """Enter async context manager, acquiring a token.""" await self.acquire_async() return self
[docs] async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object, ) -> None: """Exit async context manager.""" pass
[docs] def __repr__(self) -> str: """Return string representation.""" name_part = f", name={self._name!r}" if self._name else "" return f"RateLimiter(rate={self._rate}, per={self._per}{name_part})"
# Type overloads for the decorator @overload def rate_limiter(fn: Callable[P, R]) -> Callable[P, R]: ... @overload def rate_limiter( fn: None = None, *, rate: float = 10.0, per: float = 1.0, burst: float | None = None, blocking: bool = True, timeout: float | None = None, name: str | None = None, ) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
[docs] def rate_limiter( fn: Callable[P, R] | None = None, *, rate: float = 10.0, per: float = 1.0, burst: float | None = None, blocking: bool = True, timeout: float | None = None, name: str | None = None, ) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: """Rate limit calls to the decorated function. Can be used with or without arguments: - ``@rate_limiter`` - Uses defaults (10 requests/second) - ``@rate_limiter(rate=5, per=1.0)`` - 5 requests per second Args: fn: Function to decorate (when used without parentheses). rate: Maximum calls per period (default 10). per: Period in seconds (default 1.0). burst: Initial capacity (default = rate). blocking: If True, wait for token. If False, raise on limit. timeout: Maximum wait time in seconds. name: Name for the rate limiter. Returns: Decorated function that respects rate limits. Raises: RateLimitExceededError: If blocking=False and rate limit exceeded. Examples: Default rate limiting (10/sec): >>> @rate_limiter ... def call_api(): # doctest: +SKIP ... pass Custom rate: >>> @rate_limiter(rate=100, per=60.0) # 100 per minute ... def call_api(): # doctest: +SKIP ... pass Non-blocking mode: >>> @rate_limiter(rate=5, blocking=False) ... def fast_api(): # doctest: +SKIP ... pass # Raises RateLimitExceededError if limit hit """ # Create the limiter instance (shared across all calls) limiter = RateLimiter(rate=rate, per=per, burst=burst, name=name) def decorator(fn: Callable[P, R]) -> Callable[P, R]: # Check if function is async is_async = inspect.iscoroutinefunction(fn) if is_async: @functools.wraps(fn) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: if blocking: await limiter.acquire_async(timeout=timeout) elif not limiter.try_acquire(): raise RateLimitExceededError( f"Rate limit exceeded for {fn.__name__}", retry_after=limiter.time_until_token(), ) return await fn(*args, **kwargs) # type: ignore[no-any-return, misc] # Attach limiter for inspection async_wrapper._rate_limiter = limiter # type: ignore[attr-defined] return async_wrapper # type: ignore[return-value] @functools.wraps(fn) def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: if blocking: limiter.acquire(timeout=timeout) elif not limiter.try_acquire(): raise RateLimitExceededError( f"Rate limit exceeded for {fn.__name__}", retry_after=limiter.time_until_token(), ) return fn(*args, **kwargs) # Attach limiter for inspection sync_wrapper._rate_limiter = limiter # type: ignore[attr-defined] return sync_wrapper # Handle @rate_limiter vs @rate_limiter() if fn is not None: return decorator(fn) return decorator
__all__ = [ "RateLimiter", "RateLimiterStats", "rate_limiter", ]