"""Connection pool with retry support for async SQLite.
Provides connection pooling with:
- Configurable pool size
- Connection health checks
- Automatic retry on transient failures
- Graceful shutdown
"""
from __future__ import annotations
import asyncio
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from kstlib.db.exceptions import DatabaseConnectionError, PoolExhaustedError
from kstlib.limits import (
HARD_MAX_DB_MAX_RETRIES,
HARD_MAX_DB_RETRY_DELAY,
HARD_MAX_POOL_ACQUIRE_TIMEOUT,
HARD_MAX_POOL_MAX_SIZE,
HARD_MAX_POOL_MIN_SIZE,
HARD_MIN_DB_MAX_RETRIES,
HARD_MIN_DB_RETRY_DELAY,
HARD_MIN_POOL_ACQUIRE_TIMEOUT,
HARD_MIN_POOL_MAX_SIZE,
HARD_MIN_POOL_MIN_SIZE,
)
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
import aiosqlite
log = logging.getLogger(__name__)
@dataclass
class PoolStats:
    """Counters describing the state and history of a connection pool.

    Every counter defaults to zero. The instance is a plain mutable record:
    whoever owns it updates the fields in place.

    Attributes:
        total_connections: Number of connections created so far.
        active_connections: Connections currently checked out.
        idle_connections: Connections sitting in the pool, ready for use.
        total_acquired: Number of acquire operations performed.
        total_released: Number of release operations performed.
        total_timeouts: Number of acquire operations that timed out.
        total_errors: Number of connection errors encountered.

    Examples:
        >>> stats = PoolStats()
        >>> stats.total_connections
        0
    """

    # Lifetime / gauge counters for connections.
    total_connections: int = 0
    active_connections: int = 0
    idle_connections: int = 0

    # Operation counters.
    total_acquired: int = 0
    total_released: int = 0
    total_timeouts: int = 0
    total_errors: int = 0
@dataclass
class ConnectionPool:
    """Async connection pool for SQLite/SQLCipher databases.

    Manages a pool of database connections with health checks
    and automatic retry on failures.

    The pool is filled lazily: the first ``acquire()`` opens ``min_size``
    connections. Additional connections (up to ``max_size``) are only opened
    after an ``acquire()`` has waited ``acquire_timeout`` seconds without an
    idle connection becoming available.

    All numeric settings are clamped to the ``HARD_MIN_*`` / ``HARD_MAX_*``
    limits in ``__post_init__``.

    Args:
        db_path: Path to database file.
        min_size: Number of connections opened when the pool is first used.
        max_size: Maximum connections allowed.
        acquire_timeout: Seconds to wait for an idle connection.
        max_retries: Number of acquire attempts before giving up
            (at least one attempt is always made).
        retry_delay: Seconds to sleep between failed attempts.
        cipher_key: Optional encryption key for SQLCipher.
        on_connect: Optional async callback awaited with each new connection.

    Examples:
        >>> pool = ConnectionPool(":memory:", min_size=1, max_size=5)
        >>> pool.max_size
        5
    """

    db_path: str
    min_size: int = 1
    max_size: int = 10
    acquire_timeout: float = 30.0
    max_retries: int = 3
    retry_delay: float = 0.5
    cipher_key: str | None = field(default=None, repr=False)
    on_connect: Any | None = None  # Callable[[aiosqlite.Connection], Awaitable[None]]
    # Idle connections ready to be handed out.
    _pool: asyncio.Queue[aiosqlite.Connection] = field(default_factory=lambda: asyncio.Queue(), repr=False)
    # Every connection owned by the pool, idle or checked out.
    _connections: set[aiosqlite.Connection] = field(default_factory=set, repr=False)
    # Serializes pool initialization, growth and shutdown.
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
    _closed: bool = field(default=False, repr=False)
    _stats: PoolStats = field(default_factory=PoolStats, repr=False)

    def __post_init__(self) -> None:
        """Validate and clamp configuration values to hard limits."""
        # Clamp pool sizes
        self.min_size = max(HARD_MIN_POOL_MIN_SIZE, min(HARD_MAX_POOL_MIN_SIZE, self.min_size))
        self.max_size = max(HARD_MIN_POOL_MAX_SIZE, min(HARD_MAX_POOL_MAX_SIZE, self.max_size))
        # Ensure min_size <= max_size
        self.min_size = min(self.min_size, self.max_size)
        # Clamp timeouts and delays
        self.acquire_timeout = max(
            HARD_MIN_POOL_ACQUIRE_TIMEOUT, min(HARD_MAX_POOL_ACQUIRE_TIMEOUT, self.acquire_timeout)
        )
        self.max_retries = max(HARD_MIN_DB_MAX_RETRIES, min(HARD_MAX_DB_MAX_RETRIES, self.max_retries))
        self.retry_delay = max(HARD_MIN_DB_RETRY_DELAY, min(HARD_MAX_DB_RETRY_DELAY, self.retry_delay))

    async def _create_connection(self) -> aiosqlite.Connection:
        """Create and configure a new database connection.

        Uses aiosqlcipher for encrypted connections (when cipher_key is set),
        or standard aiosqlite for unencrypted connections. If any setup step
        fails after the connection has been opened, the connection is closed
        before the error is re-raised so it is not leaked.

        Returns:
            The configured connection. It is not registered with the pool;
            the caller does that.
        """
        if self.cipher_key:
            # Use SQLCipher for encrypted database
            from kstlib.db.aiosqlcipher import connect as aiosqlcipher_connect

            conn = await aiosqlcipher_connect(self.db_path, cipher_key=self.cipher_key)
        else:
            # Use standard sqlite3 for unencrypted database
            import aiosqlite

            conn = await aiosqlite.connect(self.db_path, isolation_level=None)

        try:
            # Enable incremental auto-vacuum on new databases only.
            # auto_vacuum must be set before any write operation (including journal_mode).
            if self.db_path != ":memory:":
                cursor = await conn.execute("PRAGMA auto_vacuum")
                row = await cursor.fetchone()
                if row and row[0] == 0:  # NONE
                    cursor2 = await conn.execute("SELECT count(*) FROM sqlite_master")
                    row2 = await cursor2.fetchone()
                    if row2 and row2[0] == 0:  # empty DB
                        await conn.execute("PRAGMA auto_vacuum=INCREMENTAL")
            # Enable WAL mode for better concurrency (works with both)
            await conn.execute("PRAGMA journal_mode=WAL")
            await conn.execute("PRAGMA foreign_keys=ON")
            # Call custom on_connect handler
            if self.on_connect:
                await self.on_connect(conn)
        except (Exception, asyncio.CancelledError):
            # Setup failed on an already-open connection: close it so the
            # half-configured connection is not leaked, then propagate.
            try:
                await conn.close()
            except Exception:
                log.debug("Failed to close connection after setup error", exc_info=True)
            raise

        self._stats.total_connections += 1
        log.debug("Created new database connection (total: %d)", self._stats.total_connections)
        return conn

    async def _init_pool(self) -> None:
        """Initialize the connection pool with min_size connections.

        Does nothing if the pool is closed or already holds connections.
        """
        async with self._lock:
            # Callers test ``self._connections`` without holding the lock, so
            # re-check here: otherwise two concurrent first acquires would
            # both create ``min_size`` connections.
            if self._closed or self._connections:
                return
            for _ in range(self.min_size):
                conn = await self._create_connection()
                self._connections.add(conn)
                await self._pool.put(conn)
                # Count per connection so the gauge stays accurate even if a
                # later creation in this loop fails.
                self._stats.idle_connections += 1

    async def acquire(self) -> aiosqlite.Connection:
        """Acquire a connection from the pool.

        Waits up to ``acquire_timeout`` seconds for an idle connection. Only
        when that wait times out is a new connection created, provided the
        pool is below ``max_size``. Each connection is checked with
        ``SELECT 1`` before being returned; dead connections are dropped and
        the next attempt starts.

        NOTE(review): because growth happens only after a timed-out wait, a
        busy pool below ``max_size`` grows slowly. Left unchanged on purpose:
        each new connection to ``":memory:"`` would be a separate database.

        Returns:
            Database connection.

        Raises:
            PoolExhaustedError: If no connection available within timeout.
            DatabaseConnectionError: If the pool is closed, or connection
                creation fails after retries.
        """
        if self._closed:
            raise DatabaseConnectionError("Pool is closed")
        # Initialize pool on first acquire
        if not self._connections:
            await self._init_pool()

        # Always try at least once, even if max_retries was clamped to 0.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            if self._closed:
                # Pool was shut down while we were retrying.
                break
            try:
                # Try to get from pool with timeout
                try:
                    conn = await asyncio.wait_for(self._pool.get(), timeout=self.acquire_timeout)
                    self._stats.idle_connections -= 1
                except asyncio.TimeoutError:
                    # Pool empty, try to create new if under max
                    async with self._lock:
                        if self._closed:
                            # close() emptied the pool while we waited; do not
                            # open a connection nobody would ever close.
                            break
                        if len(self._connections) < self.max_size:
                            conn = await self._create_connection()
                            self._connections.add(conn)
                        else:
                            self._stats.total_timeouts += 1
                            raise PoolExhaustedError(
                                f"Pool exhausted (max={self.max_size}), timeout after {self.acquire_timeout}s"
                            ) from None
                # Verify connection is alive
                try:
                    await conn.execute("SELECT 1")
                except Exception:
                    # Connection dead, remove and retry
                    self._connections.discard(conn)
                    await conn.close()
                    self._stats.total_errors += 1
                    continue
                self._stats.active_connections += 1
                self._stats.total_acquired += 1
                return conn
            except PoolExhaustedError:
                raise
            except Exception as e:
                self._stats.total_errors += 1
                if attempt < attempts - 1:
                    log.warning(
                        "Connection attempt %d failed: %s, retrying...",
                        attempt + 1,
                        e,
                    )
                    await asyncio.sleep(self.retry_delay)
                else:
                    raise DatabaseConnectionError(
                        f"Failed to acquire connection after {attempts} attempts"
                    ) from e

        if self._closed:
            raise DatabaseConnectionError("Pool is closed")
        raise DatabaseConnectionError("Failed to acquire connection")

    async def release(self, conn: aiosqlite.Connection) -> None:
        """Release a connection back to the pool.

        If the pool has been closed, the connection is closed instead of
        being returned.

        Args:
            conn: Connection to release.
        """
        if self._closed:
            await conn.close()
            return
        self._stats.active_connections -= 1
        self._stats.total_released += 1
        # Return to pool
        await self._pool.put(conn)
        self._stats.idle_connections += 1

    @asynccontextmanager
    async def connection(self) -> AsyncGenerator[aiosqlite.Connection, None]:
        """Context manager for acquiring and releasing connections.

        The connection is released when the block exits, whether it exits
        normally or with an exception.

        Yields:
            Database connection.

        Examples:
            >>> async with pool.connection() as conn:  # doctest: +SKIP
            ...     await conn.execute("SELECT 1")
        """
        conn = await self.acquire()
        try:
            yield conn
        finally:
            await self.release(conn)

    async def close(self) -> None:
        """Close all connections and shutdown the pool.

        Each connection gets a best-effort ``PRAGMA optimize`` and WAL
        checkpoint, and is then closed. A failure in the maintenance step no
        longer prevents the connection from being closed. Errors are logged
        at debug level and never raised.
        """
        self._closed = True
        async with self._lock:
            # Iterate over a snapshot: acquire() may discard dead connections
            # from the live set while we are awaiting below.
            for conn in tuple(self._connections):
                try:  # reason: best-effort maintenance before closing
                    await conn.execute("PRAGMA optimize")
                    await conn.execute("PRAGMA wal_checkpoint(TRUNCATE)")
                except Exception:
                    log.debug("Pre-close maintenance failed (connection may be already closed)", exc_info=True)
                try:  # reason: per-connection best-effort close on pool teardown
                    await conn.close()
                except Exception:
                    # Connection may be already closed - intentional silent catch
                    log.debug("Failed to close connection (may be already closed)", exc_info=True)
            self._connections.clear()
            # Empty the queue
            while not self._pool.empty():
                try:  # reason: race-safe queue drain (QueueEmpty IS the loop terminator)
                    self._pool.get_nowait()
                except asyncio.QueueEmpty:
                    break
            self._stats.active_connections = 0
            self._stats.idle_connections = 0
            # Clear cipher key reference to reduce exposure window
            self.cipher_key = None
        log.debug("Connection pool closed")

    @property
    def stats(self) -> PoolStats:
        """Get pool statistics."""
        return self._stats

    @property
    def size(self) -> int:
        """Current number of connections in pool."""
        return len(self._connections)

    @property
    def is_closed(self) -> bool:
        """Whether the pool is closed."""
        return self._closed
# Public names exported by a star-import of this module.
__all__ = ["ConnectionPool", "PoolStats"]