Source code for kstlib.auth.providers.base

"""Abstract base class for authentication providers."""

from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

from typing_extensions import Self

from kstlib.logging import TRACE_LEVEL, get_logger
from kstlib.ssl import validate_ca_bundle_path, validate_ssl_verify

if TYPE_CHECKING:
    import types

    from kstlib.auth.models import AuthFlow, PreflightReport, Token
    from kstlib.auth.token import AbstractTokenStorage

logger = get_logger(__name__)


@dataclass
class AuthProviderConfig:  # pylint: disable=too-many-instance-attributes
    """Configuration for an authentication provider.

    Supports three modes:

    1. **Auto discovery** (OIDC): Only ``issuer`` provided, endpoints discovered via
       ``.well-known/openid-configuration``.

    2. **Hybrid mode** (OIDC): ``issuer`` + some explicit endpoints. Discovery fills
       missing endpoints, explicit ones take precedence.

    3. **Full manual** (OAuth2/OIDC): No ``issuer``, all endpoints explicit.
       No discovery attempted.

    Attributes:
        client_id: OAuth2 client identifier.
        client_secret: Optional client secret (not needed for public clients with PKCE).
        authorize_url: Authorization endpoint URL.
        token_url: Token endpoint URL.
        revoke_url: Optional token revocation endpoint.
        userinfo_url: Optional UserInfo endpoint URL.
        jwks_uri: JWKS endpoint for ID token signature verification.
        end_session_endpoint: Logout/end session endpoint.
        issuer: OIDC issuer URL (enables discovery).
        scopes: List of OAuth2 scopes to request.
        redirect_uri: Callback URI for authorization code flow.
        pkce: Enable PKCE extension (default True for OIDC).
        discovery_ttl: Cache TTL for OIDC discovery document (seconds).
        headers: Custom HTTP headers to send with all IDP requests.
        ssl_verify: Enable SSL certificate verification (default True).
            Set to False only for development with self-signed certificates.
        ssl_ca_bundle: Path to custom CA bundle file for corporate PKI.
            If provided, ssl_verify is implicitly True.
        extra: Additional provider-specific configuration.

    Example:
        Auto discovery (Keycloak, Auth0, etc.)::

            AuthProviderConfig(
                client_id="my-app",
                issuer="http://localhost:8080/realms/test",
            )

        Hybrid mode (discovery + override)::

            AuthProviderConfig(
                client_id="my-app",
                issuer="https://idp.corp.local",
                end_session_endpoint="https://idp.corp.local/custom/logout",  # Override
            )

        Full manual (legacy IDP without discovery)::

            AuthProviderConfig(
                client_id="my-app",
                authorize_url="https://old-idp.local/auth",
                token_url="https://old-idp.local/token",
                jwks_uri="https://old-idp.local/certs",
            )

    """

    client_id: str
    client_secret: str | None = None
    authorize_url: str | None = None
    token_url: str | None = None
    revoke_url: str | None = None
    userinfo_url: str | None = None
    jwks_uri: str | None = None
    end_session_endpoint: str | None = None
    issuer: str | None = None
    scopes: list[str] = field(default_factory=lambda: ["openid"])
    redirect_uri: str = "http://127.0.0.1:8400/callback"
    pkce: bool = True
    discovery_ttl: int = 3600
    headers: dict[str, str] = field(default_factory=dict)
    ssl_verify: bool = True
    ssl_ca_bundle: str | None = None
    extra: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate configuration."""
        if not self.issuer and not (self.authorize_url and self.token_url):
            msg = "Either 'issuer' (OIDC with discovery) or both 'authorize_url' and 'token_url' (manual) required"
            raise ValueError(msg)

        # Validate redirect_uri scheme and host
        parsed_redirect = urlparse(self.redirect_uri)
        if parsed_redirect.scheme not in ("http", "https"):
            msg = f"redirect_uri must use http or https scheme, got: {parsed_redirect.scheme!r}"
            raise ValueError(msg)
        if parsed_redirect.hostname not in ("127.0.0.1", "localhost", "::1"):
            logger.warning(
                "[SECURITY] redirect_uri host '%s' is not localhost - ensure this is intentional",
                parsed_redirect.hostname,
            )

        # SSL/TLS validation (delegated to kstlib.ssl for DRY)
        validate_ssl_verify(self.ssl_verify)

        if self.ssl_ca_bundle is not None:
            validated_path = validate_ca_bundle_path(self.ssl_ca_bundle)
            object.__setattr__(self, "ssl_ca_bundle", validated_path)

    @property
    def has_explicit_endpoints(self) -> bool:
        """Check if any endpoints are explicitly configured."""
        return any(
            [
                self.authorize_url,
                self.token_url,
                self.userinfo_url,
                self.jwks_uri,
                self.end_session_endpoint,
                self.revoke_url,
            ]
        )


[docs] class AbstractAuthProvider(ABC): """Abstract base class for OAuth2/OIDC authentication providers. Subclasses must implement the abstract methods to handle the specific authentication flow (OAuth2, OIDC, etc.). Attributes: name: Provider identifier (matches config key). config: Provider configuration. token_storage: Storage backend for tokens. """
[docs] def __init__( self, name: str, config: AuthProviderConfig, token_storage: AbstractTokenStorage, ) -> None: """Initialize the provider. Args: name: Provider identifier. config: Provider configuration. token_storage: Token storage backend. """ self.name = name self.config = config self.token_storage = token_storage self._current_token: Token | None = None self._refresh_lock = threading.Lock()
# ───────────────────────────────────────────────────────────────────────── # Properties # ───────────────────────────────────────────────────────────────────────── @property def is_authenticated(self) -> bool: """Check if a valid (non-expired) token is available.""" token = self.get_token(auto_refresh=False) return token is not None and not token.is_expired @property @abstractmethod def flow(self) -> AuthFlow: """Return the OAuth2 flow used by this provider.""" # ───────────────────────────────────────────────────────────────────────── # Authorization flow (abstract) # ─────────────────────────────────────────────────────────────────────────
[docs] @abstractmethod def get_authorization_url(self, state: str | None = None) -> tuple[str, str]: """Generate the authorization URL for the user to visit. Args: state: Optional state parameter. Generated if not provided. Returns: Tuple of (authorization_url, state). """
[docs] @abstractmethod def exchange_code( self, code: str, state: str, *, code_verifier: str | None = None, ) -> Token: """Exchange an authorization code for tokens. Args: code: Authorization code from callback. state: State parameter for CSRF validation. code_verifier: PKCE code verifier (required if PKCE was used). Returns: Token object with access_token, refresh_token, etc. Raises: TokenExchangeError: If the exchange fails. """
[docs] @abstractmethod def refresh(self, token: Token | None = None) -> Token: """Refresh an expired token. Args: token: Token to refresh. Uses stored token if not provided. Returns: New Token object. Raises: TokenRefreshError: If refresh fails or no refresh_token available. """
[docs] @abstractmethod def revoke(self, token: Token | None = None) -> bool: """Revoke a token at the authorization server. Args: token: Token to revoke. Uses stored token if not provided. Returns: True if revoked successfully, False if revocation not supported. """
# ───────────────────────────────────────────────────────────────────────── # Token management # ─────────────────────────────────────────────────────────────────────────
[docs] def get_token(self, *, auto_refresh: bool = True) -> Token | None: """Get the current token, optionally refreshing if expired. Args: auto_refresh: If True and token is expired, attempt refresh. Returns: Token if available, None otherwise. """ if self._current_token is None: if logger.isEnabledFor(TRACE_LEVEL): logger.log(TRACE_LEVEL, "[AUTH] Loading token from storage for '%s'", self.name) self._current_token = self.token_storage.load(self.name) if self._current_token is None: if logger.isEnabledFor(TRACE_LEVEL): logger.log(TRACE_LEVEL, "[AUTH] No token found for '%s'", self.name) return None if self._current_token.should_refresh and auto_refresh: if self._current_token.is_refreshable: self._try_refresh_token() else: logger.debug("Token expired and not refreshable for provider '%s'", self.name) return self._current_token
def _try_refresh_token(self) -> None: """Attempt to refresh the current token with thread-safety.""" # Lock to prevent concurrent refresh attempts (e.g. multi-threaded bots) with self._refresh_lock: # Re-check after acquiring lock (another thread may have refreshed) if not self._current_token or not self._current_token.should_refresh: return if logger.isEnabledFor(TRACE_LEVEL): logger.log(TRACE_LEVEL, "[AUTH] Token needs refresh for '%s'", self.name) try: self._current_token = self.refresh(self._current_token) self.token_storage.save(self.name, self._current_token) if logger.isEnabledFor(TRACE_LEVEL): logger.log(TRACE_LEVEL, "[AUTH] Token refreshed successfully for '%s'", self.name) except Exception as e: # pylint: disable=broad-exception-caught # Option C : keep WARNING free of the server's error_description # which may carry user-enumeration hints ("User <email> not # found"), token state ("Token expired for principal CN=...") # or other diagnostic detail the OIDC server returns. Emit a # short generic WARNING and stash the redacted detail in # TRACE for explicit opt-in. from kstlib._shared.redaction import redact_sensitive logger.warning( "Token refresh failed for '%s' (see TRACE for details). Using cached token.", self.name, ) if logger.isEnabledFor(TRACE_LEVEL): logger.log(TRACE_LEVEL, "Token refresh error detail: %s", redact_sensitive(str(e))) logger.debug("Token refresh traceback:", exc_info=True)
[docs] def save_token(self, token: Token) -> None: """Save a token to storage. Args: token: Token to save. """ self._current_token = token self.token_storage.save(self.name, token)
[docs] def clear_token(self) -> None: """Clear the current token from memory and storage.""" self._current_token = None self.token_storage.delete(self.name)
# ───────────────────────────────────────────────────────────────────────── # Preflight validation # ─────────────────────────────────────────────────────────────────────────
[docs] @abstractmethod def preflight(self) -> PreflightReport: """Run preflight validation checks. Returns: PreflightReport with results for each validation step. """
# ───────────────────────────────────────────────────────────────────────── # Context manager support # ─────────────────────────────────────────────────────────────────────────
[docs] def __enter__(self) -> Self: """Enter context manager.""" return self
[docs] def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: types.TracebackType | None, ) -> None: """Exit context manager - clear sensitive data from memory.""" self._current_token = None
# ───────────────────────────────────────────────────────────────────────────── # Helper for from_config factory pattern # ───────────────────────────────────────────────────────────────────────────── def load_provider_from_config( provider_name: str, allowed_types: tuple[str, ...], type_label: str, config: dict[str, Any] | None = None, **overrides: Any, ) -> tuple[AuthProviderConfig, AbstractTokenStorage]: """Load provider configuration and token storage from config file. This helper factorizes the common logic for OAuth2Provider.from_config() and OIDCProvider.from_config(). Args: provider_name: Name of the provider in config. allowed_types: Tuple of allowed provider type strings (e.g., ("oidc", "openid")). type_label: Human-readable type label for error messages (e.g., "oidc"). config: Optional explicit config dict. **overrides: Direct overrides for provider config. Returns: Tuple of (AuthProviderConfig, AbstractTokenStorage). Raises: ConfigurationError: If provider not found or type mismatch. """ from kstlib.auth.config import ( build_provider_config, get_provider_config, get_token_storage_from_config, ) from kstlib.auth.errors import ConfigurationError # Validate provider exists provider_cfg = get_provider_config(provider_name, config=config) if provider_cfg is None: msg = f"Provider '{provider_name}' not found in auth.providers config" raise ConfigurationError(msg) # Verify provider type matches provider_type = provider_cfg.get("type", allowed_types[0]).lower() if provider_type not in allowed_types: msg = f"Provider '{provider_name}' has type '{provider_type}', expected '{type_label}'" raise ConfigurationError(msg) # Build AuthProviderConfig auth_config = build_provider_config(provider_name, config=config, **overrides) # Get token storage token_storage = get_token_storage_from_config( provider_name=provider_name, config=config, ) return auth_config, token_storage __all__ = [ "AbstractAuthProvider", "AuthProviderConfig", "load_provider_from_config", ]