Source code for kstlib.auth.models

"""Data models for the authentication module."""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any


class AuthFlow(str, Enum):
    """Grant flows (OAuth2 / OpenID Connect) that this module knows how to run.

    Each member's value is the lowercase snake_case identifier of the flow, and
    because the enum mixes in ``str`` a member compares equal to that string.

    Attributes:
        AUTHORIZATION_CODE: Plain redirect-based Authorization Code grant.
        AUTHORIZATION_CODE_PKCE: Authorization Code grant hardened with PKCE
            (the recommended choice for interactive logins).
        CLIENT_CREDENTIALS: Service-to-service grant with no end user involved.
        DEVICE_CODE: Grant for devices that cannot easily accept typed input.
        REFRESH_TOKEN: Exchange of a refresh token for new tokens (used internally).

    """

    # Interactive, browser-redirect flows.
    AUTHORIZATION_CODE = "authorization_code"
    AUTHORIZATION_CODE_PKCE = "authorization_code_pkce"
    # Non-interactive machine flow.
    CLIENT_CREDENTIALS = "client_credentials"
    # Input-constrained device flow.
    DEVICE_CODE = "device_code"
    # Token renewal.
    REFRESH_TOKEN = "refresh_token"


class TokenType(str, Enum):
    """Token type as returned by the authorization server.

    RFC 6749 section 7.1 specifies that ``token_type`` values are case
    insensitive, and many servers send e.g. ``"bearer"``. Value lookups
    (``TokenType("bearer")``) therefore fall back to a case-insensitive match.
    Unknown values still raise ``ValueError``.
    """

    BEARER = "Bearer"
    MAC = "MAC"
    DPOP = "DPoP"

    @classmethod
    def _missing_(cls, value: object) -> TokenType | None:
        """Resolve ``value`` case-insensitively after an exact lookup failed.

        Args:
            value: The value passed to ``TokenType(...)``.

        Returns:
            The member whose value matches ignoring case, or None so that
            ``Enum`` raises its usual ``ValueError``.

        """
        if not isinstance(value, str):
            return None
        folded = value.casefold()
        for member in cls:
            if member.value.casefold() == folded:
                return member
        return None


class PreflightStatus(str, Enum):
    """Possible outcomes of a single preflight validation step.

    Members are ``str`` subclasses, so each compares equal to its lowercase value.
    """

    # The step completed and the check passed.
    SUCCESS = "success"
    # The check did not pass.
    FAILURE = "failure"
    # The check passed but raised a concern worth reporting.
    WARNING = "warning"
    # The step was not executed.
    SKIPPED = "skipped"


@dataclass(slots=True)
class Token:  # pylint: disable=too-many-instance-attributes
    """Represents an OAuth2/OIDC token set.

    Attributes:
        access_token: The access token issued by the authorization server.
        token_type: Token type (usually "Bearer").
        expires_at: Absolute expiration time (UTC). None if unknown.
        refresh_token: Optional refresh token for obtaining new access tokens.
        scope: List of granted scopes.
        id_token: OIDC ID token (JWT) containing user claims. None for pure OAuth2.
        issued_at: When the token was issued (UTC).
        metadata: Additional provider-specific data.

    Example:
        >>> from datetime import datetime, timezone
        >>> token = Token(
        ...     access_token="eyJhbGc...",
        ...     expires_at=datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
        ...     refresh_token="dGhpcyBpcyBh...",
        ...     scope=["openid", "profile"],
        ... )
        >>> token.is_expired
        True
        >>> token.is_refreshable
        True

    """

    access_token: str = field(repr=False)
    token_type: TokenType | str = TokenType.BEARER
    expires_at: datetime | None = None
    refresh_token: str | None = field(default=None, repr=False)
    scope: list[str] = field(default_factory=list)
    id_token: str | None = field(default=None, repr=False)
    issued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] = field(default_factory=dict, repr=False)

    @property
    def is_expired(self) -> bool:
        """Check if the access token has expired.

        Returns:
            True if ``expires_at`` has been reached, or, when the expiry is
            unknown, if the token was issued more than one hour ago.

        """
        now = datetime.now(timezone.utc)
        if self.expires_at is None:
            # Conservative: assume expired after 1 hour if no expiry info
            return now > self.issued_at + timedelta(hours=1)
        return now >= self.expires_at

    @property
    def is_refreshable(self) -> bool:
        """Check if the token can be refreshed.

        Returns:
            True if a refresh_token is available.

        """
        return self.refresh_token is not None

    @property
    def expires_in(self) -> int | None:
        """Seconds until expiration.

        Returns:
            None if the expiry is unknown. Otherwise the remaining lifetime in
            whole seconds, rounded down. It is zero or negative once the token
            has expired; truncating with ``int()`` would wrongly report 0 for the
            first second after expiry.

        """
        if self.expires_at is None:
            return None
        # timedelta // timedelta performs floor division and yields an int.
        return (self.expires_at - datetime.now(timezone.utc)) // timedelta(seconds=1)

    @property
    def should_refresh(self) -> bool:
        """Check if the token should be proactively refreshed.

        Returns:
            True if token expires within 60 seconds or is already expired.

        """
        if self.expires_at is None:
            return self.is_expired
        # Refresh 60 seconds before actual expiry
        buffer = timedelta(seconds=60)
        return datetime.now(timezone.utc) >= (self.expires_at - buffer)

    @staticmethod
    def _parse_datetime(value: str | datetime) -> datetime:
        """Parse a stored timestamp into a timezone-aware datetime.

        Accepts ISO-8601 strings (including a trailing ``Z``, which
        ``datetime.fromisoformat`` only understands from Python 3.11) or
        datetime instances. Naive results are assumed to be UTC, matching the
        class contract. This keeps comparisons against ``datetime.now(timezone.utc)``
        from raising ``TypeError``.
        """
        if isinstance(value, datetime):
            parsed = value
        else:
            text = value
            if text.endswith(("Z", "z")):
                text = text[:-1] + "+00:00"
            parsed = datetime.fromisoformat(text)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @staticmethod
    def _parse_scope(raw: Any) -> list[str]:
        """Normalize a scope value into a fresh list of scope strings.

        Space-delimited strings are split. Other iterables are copied, and
        missing or empty values (None, "", []) give an empty list.
        """
        if not raw:
            return []
        if isinstance(raw, str):
            return raw.split()
        return list(raw)

    @classmethod
    def from_response(cls, data: dict[str, Any]) -> Token:
        """Create a Token from an OAuth2 token response.

        Args:
            data: Raw token response from the authorization server. Optional
                members sent as JSON ``null`` are treated as absent.

        Returns:
            Token instance populated from the response. Unrecognized keys are
            kept in ``metadata``.

        Raises:
            KeyError: If ``access_token`` is missing from ``data``.

        Example:
            >>> response = {
            ...     "access_token": "eyJhbGc...",
            ...     "token_type": "Bearer",
            ...     "expires_in": 3600,
            ...     "refresh_token": "dGhpcyBpcyBh...",
            ...     "scope": "openid profile",
            ...     "id_token": "eyJhbGc...",
            ... }
            >>> token = Token.from_response(response)
            >>> token.scope
            ['openid', 'profile']

        """
        now = datetime.now(timezone.utc)

        # Prefer the relative lifetime. Some servers return an absolute epoch
        # timestamp instead, and some send explicit nulls that must not crash int().
        expires_at: datetime | None = None
        if data.get("expires_in") is not None:
            expires_at = now + timedelta(seconds=int(data["expires_in"]))
        elif data.get("expires_at") is not None:
            expires_at = datetime.fromtimestamp(float(data["expires_at"]), tz=timezone.utc)

        # Extract known fields, rest goes to metadata
        known_fields = frozenset(
            {
                "access_token",
                "token_type",
                "expires_in",
                "expires_at",
                "refresh_token",
                "scope",
                "id_token",
            }
        )
        metadata = {k: v for k, v in data.items() if k not in known_fields}

        return cls(
            access_token=data["access_token"],
            # `or` also covers an explicit null, which would otherwise be
            # serialized as the string "None" by to_dict().
            token_type=data.get("token_type") or TokenType.BEARER,
            expires_at=expires_at,
            refresh_token=data.get("refresh_token"),
            scope=cls._parse_scope(data.get("scope")),
            id_token=data.get("id_token"),
            issued_at=now,
            metadata=metadata,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize token to dictionary for storage.

        Returns:
            Dictionary representation suitable for JSON serialization. ``scope``
            and ``metadata`` are shallow copies, so mutating the result does not
            alter this token.

        """
        if isinstance(self.token_type, TokenType):
            token_type = str(self.token_type.value)
        else:
            token_type = str(self.token_type)
        return {
            "access_token": self.access_token,
            "token_type": token_type,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "refresh_token": self.refresh_token,
            "scope": list(self.scope),
            "id_token": self.id_token,
            "issued_at": self.issued_at.isoformat(),
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Token:
        """Deserialize token from dictionary (storage retrieval).

        Args:
            data: Dictionary from to_dict() or storage. Timestamps are ISO-8601
                strings; offset-less values are interpreted as UTC.

        Returns:
            Token instance that does not share mutable state with ``data``.

        Raises:
            KeyError: If ``access_token`` is missing from ``data``.

        """
        expires_raw = data.get("expires_at")
        expires_at = cls._parse_datetime(expires_raw) if expires_raw else None

        issued_raw = data.get("issued_at")
        issued_at = cls._parse_datetime(issued_raw) if issued_raw else datetime.now(timezone.utc)

        return cls(
            access_token=data["access_token"],
            token_type=data.get("token_type") or TokenType.BEARER,
            expires_at=expires_at,
            refresh_token=data.get("refresh_token"),
            scope=cls._parse_scope(data.get("scope")),
            id_token=data.get("id_token"),
            issued_at=issued_at,
            metadata=dict(data.get("metadata") or {}),
        )
@dataclass(slots=True)
class PreflightResult:
    """Outcome of one step of a preflight validation run.

    Attributes:
        step: Identifier of the validation step that ran.
        status: How the step ended (success, failure, warning, skipped).
        message: Human-readable summary of what happened.
        details: Extra step-specific data (checked URLs, errors, ...).
        duration_ms: Wall time spent on the step, in milliseconds, if measured.

    Example:
        >>> result = PreflightResult(
        ...     step="discovery",
        ...     status=PreflightStatus.SUCCESS,
        ...     message="Discovery document fetched successfully",
        ...     details={"issuer": "https://idp.example.com", "endpoints": 5},
        ...     duration_ms=234,
        ... )
        >>> result.success
        True

    """

    step: str
    status: PreflightStatus
    message: str
    details: dict[str, Any] = field(default_factory=dict)
    duration_ms: int | None = None

    @property
    def success(self) -> bool:
        """Whether the step passed, either cleanly or with a warning."""
        return self.status in (PreflightStatus.SUCCESS, PreflightStatus.WARNING)

    @property
    def failed(self) -> bool:
        """Whether the step ended in failure."""
        is_failure = self.status == PreflightStatus.FAILURE
        return is_failure


@dataclass(slots=True)
class PreflightReport:
    """Roll-up of every step executed during one preflight run.

    Attributes:
        provider_name: Provider that was validated.
        results: Per-step results, in the order they were recorded.
        started_at: Start time of the run (UTC).
        completed_at: End time of the run, or None while still in progress.

    """

    provider_name: str
    results: list[PreflightResult] = field(default_factory=list)
    started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: datetime | None = None

    @property
    def success(self) -> bool:
        """Whether no recorded step failed (an empty report counts as passing)."""
        return not any(outcome.failed for outcome in self.results)

    @property
    def total_duration_ms(self) -> int:
        """Sum of all measured step durations; unmeasured steps count as zero."""
        elapsed = 0
        for outcome in self.results:
            if outcome.duration_ms:
                elapsed += outcome.duration_ms
        return elapsed

    @property
    def failed_steps(self) -> list[PreflightResult]:
        """Steps whose outcome was a failure, in recorded order."""
        return [outcome for outcome in self.results if outcome.failed]

    @property
    def warnings(self) -> list[PreflightResult]:
        """Steps that finished with a warning, in recorded order."""
        return [
            outcome
            for outcome in self.results
            if outcome.status == PreflightStatus.WARNING
        ]


__all__ = [
    "AuthFlow",
    "PreflightReport",
    "PreflightResult",
    "PreflightStatus",
    "Token",
    "TokenType",
]