"""Async database wrapper for SQLite/SQLCipher.
Provides a high-level async interface for database operations with:
- Connection pooling
- Automatic retry on transient failures
- SQLCipher encryption support
- Transaction management
- Query helpers
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
from typing_extensions import Self
from kstlib.db.exceptions import TransactionError
from kstlib.db.pool import ConnectionPool, PoolStats
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Sequence
from pathlib import Path
import aiosqlite
log = logging.getLogger(__name__)
# Maximum SQL length surfaced in TRACE logs. Long queries are truncated
# to avoid noise without losing the leading clause that tells you which
# table / operation is involved. Parameters are NEVER logged regardless.
_SQL_TRACE_TRUNCATE = 200
def _log_trace(msg: str, *args: object) -> None:
"""Log at TRACE level (custom level 5, below DEBUG)."""
from kstlib.logging import TRACE_LEVEL
log.log(TRACE_LEVEL, msg, *args)
def _truncate_sql(sql: str, limit: int = _SQL_TRACE_TRUNCATE) -> str:
"""Return ``sql`` truncated to ``limit`` characters, single-line."""
flat = " ".join(sql.split())
if len(flat) <= limit:
return flat
return flat[: limit - 3] + "..."
[docs]
@dataclass
class AsyncDatabase:
"""Async database interface for SQLite/SQLCipher.
Provides connection pooling, encryption, and query helpers
for async database operations.
Args:
path: Path to database file (or ":memory:" for in-memory).
cipher_key: Direct encryption key for SQLCipher.
cipher_env: Environment variable containing cipher key.
cipher_sops: Path to SOPS file containing cipher key.
cipher_sops_key: Key name in SOPS file (default: "db_key").
pool_min: Minimum pool connections.
pool_max: Maximum pool connections.
pool_timeout: Acquire timeout in seconds.
max_retries: Retry attempts on failure.
retry_delay: Delay between retries.
Examples:
Basic usage:
>>> db = AsyncDatabase(":memory:")
>>> db.path
':memory:'
With encryption:
>>> db = AsyncDatabase("app.db", cipher_key="secret") # doctest: +SKIP
With SOPS:
>>> db = AsyncDatabase("app.db", cipher_sops="secrets.yml") # doctest: +SKIP
"""
path: str | Path
cipher_key: str | None = field(default=None, repr=False)
cipher_env: str | None = None
cipher_sops: str | Path | None = None
cipher_sops_key: str = "db_key"
pool_min: int = 1
pool_max: int = 10
pool_timeout: float = 30.0
max_retries: int = 3
retry_delay: float = 0.5
_pool: ConnectionPool | None = field(default=None, repr=False)
_resolved_key: str | None = field(default=None, repr=False)
[docs]
def __post_init__(self) -> None:
"""Resolve cipher key and apply config defaults with hard limits."""
from kstlib.limits import get_db_limits
self.path = str(self.path)
# Load config defaults for any unset pool/retry params
limits = get_db_limits()
# Apply config defaults if using dataclass defaults (sentinel check)
# Use object.__setattr__ since dataclass fields are set
if self.pool_min == 1:
object.__setattr__(self, "pool_min", limits.pool_min_size)
if self.pool_max == 10:
object.__setattr__(self, "pool_max", limits.pool_max_size)
if self.pool_timeout == 30.0:
object.__setattr__(self, "pool_timeout", limits.pool_acquire_timeout)
if self.max_retries == 3:
object.__setattr__(self, "max_retries", limits.max_retries)
if self.retry_delay == 0.5:
object.__setattr__(self, "retry_delay", limits.retry_delay)
# Resolve encryption key if any source provided
if self.cipher_key or self.cipher_env or self.cipher_sops:
from kstlib.db.cipher import resolve_cipher_key
self._resolved_key = resolve_cipher_key(
passphrase=self.cipher_key,
env_var=self.cipher_env,
sops_path=self.cipher_sops,
sops_key=self.cipher_sops_key,
)
def _ensure_pool(self) -> ConnectionPool:
"""Ensure connection pool is initialized."""
if self._pool is None:
# path is already converted to str in __post_init__
db_path = self.path if isinstance(self.path, str) else str(self.path)
self._pool = ConnectionPool(
db_path=db_path,
min_size=self.pool_min,
max_size=self.pool_max,
acquire_timeout=self.pool_timeout,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
cipher_key=self._resolved_key,
)
return self._pool
[docs]
async def connect(self) -> None:
"""Initialize the connection pool.
Called automatically on first operation, but can be
called explicitly for eager initialization.
"""
pool = self._ensure_pool()
await pool._init_pool()
log.info("Database connected: %s", self.path)
[docs]
async def close(self) -> None:
"""Close all connections and shutdown the pool."""
if self._pool:
await self._pool.close()
self._pool = None
# Scrub resolved key from memory
was_encrypted = self._resolved_key is not None
self._resolved_key = None
log.info("Database closed: %s (encrypted=%s)", self.path, was_encrypted)
[docs]
async def __aenter__(self) -> Self:
"""Async context manager entry."""
await self.connect()
return self
[docs]
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: object,
) -> None:
"""Async context manager exit."""
await self.close()
[docs]
@asynccontextmanager
async def connection(self) -> AsyncGenerator[aiosqlite.Connection, None]:
"""Get a connection from the pool.
Yields:
Database connection.
Examples:
>>> async with db.connection() as conn: # doctest: +SKIP
... await conn.execute("SELECT 1")
"""
pool = self._ensure_pool()
async with pool.connection() as conn:
yield conn
[docs]
@asynccontextmanager
async def transaction(self) -> AsyncGenerator[aiosqlite.Connection, None]:
"""Execute operations within a transaction.
Automatically commits on success, rolls back on error.
Yields:
Database connection within transaction.
Raises:
TransactionError: If transaction fails.
Examples:
>>> async with db.transaction() as conn: # doctest: +SKIP
... await conn.execute("INSERT INTO users VALUES (?)", ("alice",))
... await conn.execute("INSERT INTO users VALUES (?)", ("bob",))
"""
pool = self._ensure_pool()
conn = await pool.acquire()
try:
await conn.execute("BEGIN")
yield conn
await conn.commit()
except Exception as e:
try:
await conn.rollback()
except Exception:
# Rollback may fail on closed connection - intentional silent catch
log.debug("Rollback failed (connection may be closed)", exc_info=True)
raise TransactionError(f"Transaction failed: {e}") from e
finally:
await pool.release(conn)
[docs]
async def execute(
self,
sql: str,
parameters: Sequence[Any] | None = None,
) -> aiosqlite.Cursor:
"""Execute a single SQL statement.
Args:
sql: SQL statement to execute.
parameters: Query parameters.
Returns:
Cursor with results.
Examples:
>>> await db.execute("CREATE TABLE test (id INTEGER)") # doctest: +SKIP
"""
# Sanitization invariant : log the truncated SQL only, NEVER the
# parameters tuple (would leak PII / credentials in WHERE/INSERT
# values). Truncation prevents giant DDL from drowning the trace.
_log_trace("[DB] Execute: %s (params=%s)", _truncate_sql(sql), "yes" if parameters else "no")
pool = self._ensure_pool()
async with pool.connection() as conn:
if parameters:
return await conn.execute(sql, parameters)
return await conn.execute(sql)
[docs]
async def executemany(
self,
sql: str,
parameters: Sequence[Sequence[Any]],
) -> aiosqlite.Cursor:
"""Execute SQL statement for multiple parameter sets.
Args:
sql: SQL statement to execute.
parameters: Sequence of parameter tuples.
Returns:
Cursor with results.
Examples:
>>> await db.executemany( # doctest: +SKIP
... "INSERT INTO test VALUES (?)",
... [(1,), (2,), (3,)]
... )
"""
pool = self._ensure_pool()
async with pool.connection() as conn:
return await conn.executemany(sql, parameters)
[docs]
async def fetch_one(
self,
sql: str,
parameters: Sequence[Any] | None = None,
) -> tuple[Any, ...] | None:
"""Fetch a single row.
Args:
sql: SQL query.
parameters: Query parameters.
Returns:
Row tuple or None if no results.
Examples:
>>> row = await db.fetch_one("SELECT * FROM test WHERE id=?", (1,)) # doctest: +SKIP
"""
pool = self._ensure_pool()
async with pool.connection() as conn:
if parameters:
cursor = await conn.execute(sql, parameters)
else:
cursor = await conn.execute(sql)
row = await cursor.fetchone()
return cast("tuple[Any, ...] | None", row)
[docs]
async def fetch_all(
self,
sql: str,
parameters: Sequence[Any] | None = None,
) -> list[tuple[Any, ...]]:
"""Fetch all rows.
Args:
sql: SQL query.
parameters: Query parameters.
Returns:
List of row tuples.
Examples:
>>> rows = await db.fetch_all("SELECT * FROM test") # doctest: +SKIP
"""
pool = self._ensure_pool()
async with pool.connection() as conn:
if parameters:
cursor = await conn.execute(sql, parameters)
else:
cursor = await conn.execute(sql)
rows = await cursor.fetchall()
return cast("list[tuple[Any, ...]]", rows)
[docs]
async def fetch_value(
self,
sql: str,
parameters: Sequence[Any] | None = None,
) -> Any:
"""Fetch a single value (first column of first row).
Args:
sql: SQL query.
parameters: Query parameters.
Returns:
Single value or None.
Examples:
>>> count = await db.fetch_value("SELECT count(*) FROM test") # doctest: +SKIP
"""
row = await self.fetch_one(sql, parameters)
return row[0] if row else None
[docs]
async def table_exists(self, table_name: str) -> bool:
"""Check if a table exists.
Args:
table_name: Name of the table.
Returns:
True if table exists.
Examples:
>>> await db.table_exists("users") # doctest: +SKIP
False
"""
sql = "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?"
count = await self.fetch_value(sql, (table_name,))
return bool(count)
@property
def stats(self) -> PoolStats:
"""Get connection pool statistics."""
if self._pool:
return self._pool.stats
return PoolStats()
@property
def is_encrypted(self) -> bool:
"""Whether database is configured for encryption."""
return self._resolved_key is not None
@property
def pool_size(self) -> int:
"""Current number of connections in pool."""
if self._pool:
return self._pool.size
return 0
__all__ = ["AsyncDatabase"]