# Source code for kstlib.auth.session

"""Authenticated HTTP session wrapper."""

from __future__ import annotations

from enum import Enum
from http import HTTPStatus
from typing import TYPE_CHECKING, Any

import httpx
from typing_extensions import Self

from kstlib.auth.errors import AuthError, TokenExpiredError
from kstlib.logging import TRACE_LEVEL, get_logger
from kstlib.ssl import build_ssl_context

if TYPE_CHECKING:
    import types

    from kstlib.auth.providers.base import AbstractAuthProvider

# Module-level logger, named after this module.
logger = get_logger(__name__)

# Default timeout for HTTP requests, in seconds.
# Used as the default value of the ``timeout`` argument of ``AuthSession``.
DEFAULT_TIMEOUT = 30.0


class AuthSession:
    """HTTP session with automatic token injection and refresh.

    Wraps httpx.Client (sync) or httpx.AsyncClient (async) to automatically:

    - Inject the provider's token in the Authorization header
    - Refresh expired tokens before making requests
    - Handle 401 responses by refreshing and retrying once

    The sync API (``get``, ``post``, ...) requires the session to be entered
    with ``with``; the async API (``aget``, ``apost``, ...) requires
    ``async with``. Calling a method outside the matching context manager
    raises ``AuthError``.

    Example (sync):
        >>> from kstlib.auth import AuthSession, get_provider  # doctest: +SKIP
        >>> provider = get_provider("corporate")  # doctest: +SKIP
        >>> with AuthSession(provider) as session:  # doctest: +SKIP
        ...     response = session.get("https://api.example.com/users/me")
        ...     print(response.json())

    Example (async):
        >>> async with AuthSession(provider) as session:  # doctest: +SKIP
        ...     response = await session.aget("https://api.example.com/users/me")
        ...     print(response.json())
    """

    def __init__(
        self,
        provider: AbstractAuthProvider,
        *,
        timeout: float = DEFAULT_TIMEOUT,
        auto_refresh: bool = True,
        retry_on_401: bool = True,
        ssl_verify: bool | None = None,
        ssl_ca_bundle: str | None = None,
    ) -> None:
        """Initialize authenticated session.

        Args:
            provider: Authentication provider to use for tokens.
            timeout: Default request timeout in seconds.
            auto_refresh: Automatically refresh expired tokens before requests.
            retry_on_401: Retry request after token refresh on 401 response.
            ssl_verify: Override SSL verification (True/False). If None, uses
                provider's SSL config or global config.
            ssl_ca_bundle: Override CA bundle path. If None, uses provider's
                SSL config or global config.
        """
        self.provider = provider
        self.timeout = timeout
        self.auto_refresh = auto_refresh
        self.retry_on_401 = retry_on_401

        # Resolve each SSL setting on its own: kwarg > provider config > global config.
        # Resolving them independently means that overriding only one of the two
        # arguments does not discard the provider's value for the other one.
        provider_config = getattr(provider, "config", None)
        if ssl_verify is None:
            provider_ssl_verify = getattr(provider_config, "ssl_verify", None)
            # Accept only a real bool (not e.g. a MagicMock attribute from tests)
            if isinstance(provider_ssl_verify, bool):
                ssl_verify = provider_ssl_verify
        if ssl_ca_bundle is None:
            provider_ca_bundle = getattr(provider_config, "ssl_ca_bundle", None)
            # Accept only a real str (not e.g. a MagicMock attribute from tests)
            if isinstance(provider_ca_bundle, str):
                ssl_ca_bundle = provider_ca_bundle

        # Values still None at this point are left for build_ssl_context to resolve.
        self._ssl_context = build_ssl_context(
            ssl_verify=ssl_verify,
            ssl_ca_bundle=ssl_ca_bundle,
        )

        # Clients are created lazily by the context managers.
        self._sync_client: httpx.Client | None = None
        self._async_client: httpx.AsyncClient | None = None

    # ─────────────────────────────────────────────────────────────────────────
    # Context managers
    # ─────────────────────────────────────────────────────────────────────────

    def __enter__(self) -> Self:
        """Enter sync context manager and create the underlying httpx.Client."""
        self._sync_client = httpx.Client(timeout=self.timeout, verify=self._ssl_context)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Exit sync context manager and close the underlying client, if any."""
        if self._sync_client is not None:
            self._sync_client.close()
            self._sync_client = None

    async def __aenter__(self) -> Self:
        """Enter async context manager and create the underlying httpx.AsyncClient."""
        self._async_client = httpx.AsyncClient(timeout=self.timeout, verify=self._ssl_context)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Exit async context manager and close the underlying client, if any."""
        if self._async_client is not None:
            await self._async_client.aclose()
            self._async_client = None

    # ─────────────────────────────────────────────────────────────────────────
    # Token handling
    # ─────────────────────────────────────────────────────────────────────────

    def _get_auth_header(self) -> dict[str, str]:
        """Get Authorization header with current token.

        Returns:
            Dict with Authorization header.

        Raises:
            TokenExpiredError: If no valid token is available.
        """
        token = self.provider.get_token(auto_refresh=self.auto_refresh)
        if token is None:
            raise TokenExpiredError("No token available - authentication required")

        if token.is_expired and not token.is_refreshable:
            raise TokenExpiredError("Token expired and cannot be refreshed")

        # Extract string value from TokenType enum or use as-is if already a string
        token_type = token.token_type.value if isinstance(token.token_type, Enum) else token.token_type
        return {"Authorization": f"{token_type} {token.access_token}"}

    def _with_auth_header(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``kwargs`` whose headers carry the Authorization header.

        The caller's ``kwargs`` and its ``headers`` object are never mutated.
        Any header form accepted by ``httpx.Headers`` (mapping, list of pairs,
        ``httpx.Headers`` or None) is supported, and an existing Authorization
        header is replaced regardless of the case of its name.

        Raises:
            TokenExpiredError: If no valid token is available.
        """
        prepared = dict(kwargs)
        headers = httpx.Headers(prepared.pop("headers", None))
        headers.update(self._get_auth_header())
        prepared["headers"] = headers
        return prepared

    def _should_retry(self, response: httpx.Response, retried: bool) -> bool:
        """Check if request should be retried after 401."""
        return (
            self.retry_on_401
            and not retried
            and response.status_code == HTTPStatus.UNAUTHORIZED
            and self.provider.get_token(auto_refresh=False) is not None
        )

    def _refresh_after_401(self) -> bool:
        """Refresh the token after a 401 response (best effort).

        Returns:
            True if ``provider.refresh()`` completed, False if it raised.
        """
        logger.debug("Got 401, attempting token refresh and retry")
        try:
            self.provider.refresh()
        except Exception:  # pylint: disable=broad-exception-caught
            # Intentional catch-all for best-effort refresh
            logger.warning("Token refresh failed, returning original 401 response")
            return False
        return True

    # ─────────────────────────────────────────────────────────────────────────
    # Sync HTTP methods
    # ─────────────────────────────────────────────────────────────────────────

    def _request(
        self,
        method: str,
        url: str,
        *,
        _retried: bool = False,
        **kwargs: Any,
    ) -> httpx.Response:
        """Make an authenticated HTTP request (sync).

        Args:
            method: HTTP method.
            url: Request URL.
            _retried: Internal flag to prevent infinite retry.
            **kwargs: Additional arguments for httpx.

        Returns:
            HTTP response. If a 401 retry is attempted but the refresh or the
            retried request fails, the original 401 response is returned.

        Raises:
            AuthError: If the session was not entered with ``with``.
            TokenExpiredError: If no valid token is available.
        """
        if self._sync_client is None:
            msg = "Session not initialized - use 'with AuthSession(...) as session:'"
            raise AuthError(msg)

        # Merge auth header with any existing headers (on a copy)
        request_kwargs = self._with_auth_header(kwargs)

        if logger.isEnabledFor(TRACE_LEVEL):
            logger.log(TRACE_LEVEL, "[SESSION] %s %s", method, url)

        response = self._sync_client.request(method, url, **request_kwargs)

        if logger.isEnabledFor(TRACE_LEVEL):
            logger.log(TRACE_LEVEL, "[SESSION] Response: %d %s", response.status_code, response.reason_phrase)

        # Retry on 401 if configured and the token could be refreshed
        if not self._should_retry(response, _retried) or not self._refresh_after_401():
            return response

        try:
            # Pass the caller's kwargs: the retry builds fresh headers from the new token
            return self._request(method, url, _retried=True, **kwargs)
        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: a failed retry must not hide the original 401 from the caller
            logger.warning("Retry after token refresh failed, returning original 401 response")
            return response

    def get(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated GET request."""
        return self._request("GET", url, **kwargs)

    def post(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated POST request."""
        return self._request("POST", url, **kwargs)

    def put(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated PUT request."""
        return self._request("PUT", url, **kwargs)

    def patch(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated PATCH request."""
        return self._request("PATCH", url, **kwargs)

    def delete(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated DELETE request."""
        return self._request("DELETE", url, **kwargs)

    def head(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated HEAD request."""
        return self._request("HEAD", url, **kwargs)

    def options(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated OPTIONS request."""
        return self._request("OPTIONS", url, **kwargs)

    # ─────────────────────────────────────────────────────────────────────────
    # Async HTTP methods
    # ─────────────────────────────────────────────────────────────────────────

    async def _arequest(
        self,
        method: str,
        url: str,
        *,
        _retried: bool = False,
        **kwargs: Any,
    ) -> httpx.Response:
        """Make an authenticated HTTP request (async).

        Args:
            method: HTTP method.
            url: Request URL.
            _retried: Internal flag to prevent infinite retry.
            **kwargs: Additional arguments for httpx.

        Returns:
            HTTP response. If a 401 retry is attempted but the refresh or the
            retried request fails, the original 401 response is returned.

        Raises:
            AuthError: If the session was not entered with ``async with``.
            TokenExpiredError: If no valid token is available.
        """
        if self._async_client is None:
            msg = "Session not initialized - use 'async with AuthSession(...) as session:'"
            raise AuthError(msg)

        # Merge auth header with any existing headers (on a copy)
        request_kwargs = self._with_auth_header(kwargs)

        if logger.isEnabledFor(TRACE_LEVEL):
            logger.log(TRACE_LEVEL, "[SESSION] %s %s (async)", method, url)

        response = await self._async_client.request(method, url, **request_kwargs)

        if logger.isEnabledFor(TRACE_LEVEL):
            logger.log(TRACE_LEVEL, "[SESSION] Response: %d %s", response.status_code, response.reason_phrase)

        # Retry on 401 if configured and the token could be refreshed.
        # NOTE(review): provider.get_token()/refresh() are called synchronously here,
        # so a slow refresh presumably blocks the event loop - confirm with provider API.
        if not self._should_retry(response, _retried) or not self._refresh_after_401():
            return response

        try:
            # Pass the caller's kwargs: the retry builds fresh headers from the new token
            return await self._arequest(method, url, _retried=True, **kwargs)
        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: a failed retry must not hide the original 401 from the caller
            logger.warning("Retry after token refresh failed, returning original 401 response")
            return response

    async def aget(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated async GET request."""
        return await self._arequest("GET", url, **kwargs)

    async def apost(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated async POST request."""
        return await self._arequest("POST", url, **kwargs)

    async def aput(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated async PUT request."""
        return await self._arequest("PUT", url, **kwargs)

    async def apatch(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated async PATCH request."""
        return await self._arequest("PATCH", url, **kwargs)

    async def adelete(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated async DELETE request."""
        return await self._arequest("DELETE", url, **kwargs)

    async def ahead(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated async HEAD request."""
        return await self._arequest("HEAD", url, **kwargs)

    async def aoptions(self, url: str, **kwargs: Any) -> httpx.Response:
        """Make authenticated async OPTIONS request."""
        return await self._arequest("OPTIONS", url, **kwargs)
# Names exported by ``from kstlib.auth.session import *``.
__all__ = ["AuthSession"]