# Source code for kstlib.auth.check

"""JWT token validation with cryptographic proof.

Validates any RSA-signed JWT whose issuer exposes an OpenID Connect discovery
endpoint (`.well-known/openid-configuration`). This covers both OIDC id_tokens
and OAuth2 access_tokens issued as JWTs by OIDC-capable providers.

Example:
    >>> from kstlib.auth.check import TokenChecker  # doctest: +SKIP
    >>> checker = TokenChecker(http_client)  # doctest: +SKIP
    >>> report = checker.check("eyJhbGci...")  # doctest: +SKIP
    >>> report.valid  # doctest: +SKIP
    True

"""

from __future__ import annotations

import base64
import hashlib
import json
import logging
import math
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any

import httpx

logger = logging.getLogger(__name__)

# Tolerance (in seconds) applied when comparing the exp / iat / nbf claims
# against the current time: five minutes either way.
CLOCK_SKEW_SECONDS = 5 * 60

# TokenChecker method names, executed in this order by TokenChecker.check();
# the chain stops at the first method that returns False.
_STEPS: tuple[str, ...] = (
    "_decode_structure",  # 1. split + base64url-decode header/payload
    "_discover_issuer",  # 2. OIDC discovery document
    "_fetch_jwks",  # 3. JSON Web Key Set
    "_extract_public_key",  # 4. JWK -> PEM + fingerprint
    "_verify_signature",  # 5. cryptographic proof
    "_validate_claims",  # 6. iss / aud / exp / iat / nbf
)


[docs] @dataclass(frozen=True, slots=True) class ValidationStep: """Result of a single validation step. Attributes: name: Step identifier (e.g. "decode_structure", "verify_signature"). passed: Whether the step succeeded. message: Human-readable result description. details: Optional extra information for verbose output. """ name: str passed: bool message: str details: dict[str, Any] = field(default_factory=dict)
[docs] @dataclass(slots=True) class TokenCheckReport: """Complete token validation report. Attributes: token_type: Whether "id_token" or "access_token" was checked. valid: Overall result (True only if ALL steps passed). steps: Ordered list of validation steps executed. header: Decoded JWT header (alg, kid, typ, ...). payload: Decoded JWT payload (claims). signature_algorithm: Algorithm from JWT header (e.g. "RS256", "RS512"). key_id: Key ID from JWT header (kid). discovery_url: OpenID Connect discovery URL used. discovery_data: Relevant fields from discovery document. jwks_uri: JWKS endpoint URL. public_key_pem: PEM-encoded public key used for verification. key_fingerprint: SHA-256 fingerprint of the public key (hex). issuer_match: Whether iss claim matches expected issuer. audience_match: Whether aud claim matches expected audience. error: Error message if validation failed. """ token_type: str = "id_token" valid: bool = False steps: list[ValidationStep] = field(default_factory=list) header: dict[str, Any] = field(default_factory=dict) payload: dict[str, Any] = field(default_factory=dict) signature_algorithm: str | None = None key_id: str | None = None discovery_url: str | None = None discovery_data: dict[str, Any] = field(default_factory=dict) jwks_uri: str | None = None jwks_data: dict[str, Any] = field(default_factory=dict) public_key_pem: str | None = None key_fingerprint: str | None = None key_type: str | None = None key_size_bits: int | None = None x509_info: dict[str, Any] = field(default_factory=dict) issuer_match: bool | None = None audience_match: bool | None = None error: str | None = None
[docs] def to_dict(self) -> dict[str, Any]: """Serialize report to dictionary for JSON output. Returns: Dictionary representation of the report. """ return { "token_type": self.token_type, "valid": self.valid, "steps": [ { "name": s.name, "passed": s.passed, "message": s.message, "details": s.details, } for s in self.steps ], "header": self.header, "payload": self.payload, "signature_algorithm": self.signature_algorithm, "key_id": self.key_id, "discovery_url": self.discovery_url, "discovery_data": self.discovery_data, "jwks_uri": self.jwks_uri, "public_key_pem": self.public_key_pem, "key_fingerprint": self.key_fingerprint, "key_type": self.key_type, "key_size_bits": self.key_size_bits, "x509_info": self.x509_info or None, "issuer_match": self.issuer_match, "audience_match": self.audience_match, "error": self.error, }
class TokenChecker:
    """Validates JWT tokens with full cryptographic verification.

    Runs a 6-step validation chain, stopping at the first step that fails:

    1. Decode JWT structure (header + payload)
    2. Discover issuer (OIDC discovery document)
    3. Fetch JWKS (JSON Web Key Set)
    4. Extract public key (JWK to PEM, SHA-256 fingerprint)
    5. Verify signature (cryptographic proof)
    6. Validate claims (iss, aud, exp, iat, nbf)

    Network errors and malformed remote documents (non-JSON or non-object
    discovery/JWKS responses) are recorded as failed steps, not raised.

    Args:
        http_client: httpx.Client for HTTP requests.
        expected_issuer: Expected issuer URL (for claim validation).
            If None, the token's own ``iss`` claim is accepted as-is. The
            discovery URL is always derived from the token's ``iss`` claim.
        expected_audience: Expected audience (client_id).
            If None, audience validation is skipped.
        timeout: Per-request timeout in seconds for the discovery and JWKS
            HTTP calls (default 10).

    Example:
        >>> import httpx  # doctest: +SKIP
        >>> client = httpx.Client(verify="/path/to/ca.pem")  # doctest: +SKIP
        >>> checker = TokenChecker(client, expected_audience="my-app")  # doctest: +SKIP
        >>> report = checker.check(token_string)  # doctest: +SKIP
    """

    def __init__(
        self,
        http_client: httpx.Client,
        expected_issuer: str | None = None,
        expected_audience: str | None = None,
        timeout: float = 10.0,
    ) -> None:
        """Initialize the token checker with an HTTP client and optional expected claims."""
        self._http = http_client
        self._expected_issuer = expected_issuer
        self._expected_audience = expected_audience
        self._timeout = timeout
        # JWKS documents cached per jwks_uri. A single shared slot would let keys
        # fetched for one issuer be used to verify tokens claiming another issuer.
        self._jwks_cache: dict[str, dict[str, Any]] = {}
        self._token_str = ""

    def check(
        self,
        token_str: str,
        *,
        token_type: str = "id_token",
    ) -> TokenCheckReport:
        """Run full token validation chain.

        Args:
            token_str: Raw JWT string to validate.
            token_type: Label for the report ("id_token" or "access_token").

        Returns:
            TokenCheckReport with all validation results. ``valid`` is True only
            when every step passed; otherwise ``error`` holds the message of the
            first failing step.
        """
        report = TokenCheckReport(token_type=token_type)
        # The step methods read the raw token from the instance.
        self._token_str = token_str
        for step_name in _STEPS:
            if not getattr(self, step_name)(report):
                return report
        report.valid = True
        return report

    @staticmethod
    def _fail(report: TokenCheckReport, step_name: str, message: str) -> bool:
        """Record a failed step on *report*, set ``report.error``, and return False."""
        report.steps.append(_fail_step(step_name, message))
        report.error = message
        return False

    def _decode_structure(self, report: TokenCheckReport) -> bool:
        """Step 1: Split JWT and decode header/payload."""
        step = "decode_structure"
        try:
            parts = self._token_str.split(".")
            if len(parts) != 3:
                return self._fail(report, step, f"Invalid JWT format: expected 3 parts, got {len(parts)}")
            header = json.loads(_b64url_decode(parts[0]))
            payload = json.loads(_b64url_decode(parts[1]))
        except Exception as exc:  # any decoding problem is a step failure, not a crash
            return self._fail(report, step, f"Failed to decode JWT: {exc}")

        # Later steps call .get() on both; reject JSON arrays/strings/numbers here.
        if not isinstance(header, dict) or not isinstance(payload, dict):
            return self._fail(report, step, "Failed to decode JWT: header and payload must be JSON objects")

        report.header = header
        report.payload = payload
        report.signature_algorithm = header.get("alg")
        report.key_id = header.get("kid")
        report.steps.append(
            ValidationStep(
                name=step,
                passed=True,
                message=f"JWT decoded: alg={header.get('alg')}, kid={header.get('kid')}",
                details={
                    "alg": header.get("alg"),
                    "kid": header.get("kid"),
                    "typ": header.get("typ"),
                    "claims": list(payload.keys()),
                },
            )
        )
        return True

    def _discover_issuer(self, report: TokenCheckReport) -> bool:
        """Step 2: Fetch OIDC discovery document from issuer."""
        step = "discover_issuer"
        issuer = report.payload.get("iss")
        if not issuer:
            return self._fail(report, step, "No 'iss' claim found in token payload")
        if not isinstance(issuer, str):
            return self._fail(report, step, f"Invalid 'iss' claim: expected a string, got {type(issuer).__name__}")

        discovery_url = f"{issuer.rstrip('/')}/.well-known/openid-configuration"
        report.discovery_url = discovery_url
        try:
            resp = self._http.get(discovery_url, timeout=self._timeout)
            resp.raise_for_status()
            discovery = resp.json()
        except httpx.HTTPStatusError as exc:
            return _record_http_error(
                report, step, f"Discovery failed: HTTP {exc.response.status_code} from {discovery_url}"
            )
        except httpx.RequestError as exc:
            return _record_http_error(report, step, f"Discovery failed: {exc}")
        except ValueError as exc:
            # Body was not valid JSON (e.g. an HTML error page served with 200).
            return self._fail(report, step, f"Discovery failed: invalid JSON from {discovery_url}: {exc}")
        if not isinstance(discovery, dict):
            return self._fail(report, step, f"Discovery failed: expected a JSON object from {discovery_url}")

        jwks_uri = discovery.get("jwks_uri")
        supported_algs = discovery.get("id_token_signing_alg_values_supported", [])
        report.discovery_data = {
            "issuer": discovery.get("issuer"),
            "jwks_uri": jwks_uri,
            "id_token_signing_alg_values_supported": supported_algs,
            "authorization_endpoint": discovery.get("authorization_endpoint"),
            "token_endpoint": discovery.get("token_endpoint"),
        }
        report.jwks_uri = jwks_uri
        report.steps.append(
            ValidationStep(
                name=step,
                passed=True,
                message=f"Discovery OK: {discovery_url}",
                details={"issuer": discovery.get("issuer"), "supported_algs": supported_algs, "jwks_uri": jwks_uri},
            )
        )
        return True

    def _fetch_jwks(self, report: TokenCheckReport) -> bool:
        """Step 3: Fetch JSON Web Key Set from jwks_uri (cached per URI)."""
        step = "fetch_jwks"
        jwks_uri = report.jwks_uri
        if not jwks_uri:
            return self._fail(report, step, "No jwks_uri found in discovery document")
        if not isinstance(jwks_uri, str):
            return self._fail(report, step, f"Invalid jwks_uri in discovery document: {jwks_uri!r}")

        cached = self._jwks_cache.get(jwks_uri)
        if cached is not None:
            report.jwks_data = cached
            report.steps.append(
                ValidationStep(
                    name=step,
                    passed=True,
                    message="JWKS loaded from cache",
                    details={"key_count": len(cached.get("keys", []))},
                )
            )
            return True

        try:
            resp = self._http.get(jwks_uri, timeout=self._timeout)
            resp.raise_for_status()
            jwks = resp.json()
        except httpx.HTTPStatusError as exc:
            return _record_http_error(
                report, step, f"JWKS fetch failed: HTTP {exc.response.status_code} from {jwks_uri}"
            )
        except httpx.RequestError as exc:
            return _record_http_error(report, step, f"JWKS fetch failed: {exc}")
        except ValueError as exc:
            return self._fail(report, step, f"JWKS fetch failed: invalid JSON from {jwks_uri}: {exc}")

        # A JWKS must be a JSON object whose "keys" member is an array.
        keys = jwks.get("keys", []) if isinstance(jwks, dict) else None
        if not isinstance(keys, list):
            return self._fail(report, step, f"JWKS fetch failed: {jwks_uri} did not return a JSON Web Key Set")

        # Only well-formed documents are cached.
        self._jwks_cache[jwks_uri] = jwks
        report.jwks_data = jwks
        key_count = len(keys)
        report.steps.append(
            ValidationStep(
                name=step,
                passed=True,
                message=f"JWKS fetched: {key_count} key(s) from {jwks_uri}",
                details={
                    "key_count": key_count,
                    "key_ids": [k.get("kid", "no-kid") for k in keys if isinstance(k, dict)],
                },
            )
        )
        return True

    def _extract_public_key(self, report: TokenCheckReport) -> bool:
        """Step 4: Match key by kid and convert JWK to PEM."""
        step = "extract_public_key"
        kid = report.key_id
        # Skip malformed (non-object) entries instead of crashing on them.
        keys = [k for k in report.jwks_data.get("keys", []) if isinstance(k, dict)]
        if not keys:
            return self._fail(report, step, "JWKS contains no keys")

        matching_key = _find_matching_key(keys, kid)
        if matching_key is None:
            available_kids = [k.get("kid", "no-kid") for k in keys]
            return self._fail(report, step, f"Key not found: kid={kid!r} not in {available_kids}")

        try:
            pem_str, fingerprint, key_size = _jwk_to_pem(matching_key)
        except ImportError:
            return self._fail(report, step, "authlib is required for key extraction (pip install authlib)")
        except Exception as exc:  # malformed JWK material is a step failure
            return self._fail(report, step, f"Failed to extract public key: {exc}")

        report.public_key_pem = pem_str
        report.key_fingerprint = fingerprint
        report.key_type = matching_key.get("kty", "unknown")
        report.key_size_bits = key_size

        # Extract X.509 certificate info from the first x5c entry, if present.
        x5c = matching_key.get("x5c")
        if isinstance(x5c, list) and x5c:
            report.x509_info = _parse_x509_info(x5c[0])

        kty = report.key_type
        size_label = f" {key_size}-bit" if key_size else ""
        report.steps.append(
            ValidationStep(
                name=step,
                passed=True,
                message=f"Public key extracted: {kty}{size_label}, kid={kid}, fingerprint={fingerprint[:16]}...",
                details={
                    "kid": kid,
                    "kty": kty,
                    "key_size_bits": key_size,
                    "alg": matching_key.get("alg"),
                    "fingerprint_sha256": fingerprint,
                    "has_x509": bool(x5c),
                },
            )
        )
        return True

    def _verify_signature(self, report: TokenCheckReport) -> bool:
        """Step 5: Verify JWT signature using JWKS."""
        step = "verify_signature"
        alg = report.signature_algorithm
        # The JWKS only carries public keys, so "none" and HMAC (HS*) tokens can
        # never be legitimately verified against it. Refusing them up front
        # closes the classic JWT algorithm-confusion attack on untrusted input.
        if not isinstance(alg, str) or alg.lower() == "none" or alg.upper().startswith("HS"):
            return self._fail(
                report,
                step,
                f"Signature verification failed: unsupported algorithm {alg!r} for public-key verification",
            )

        try:
            from authlib.jose import jwt
            from authlib.jose.errors import JoseError
        except ImportError:
            return self._fail(report, step, "authlib is required for signature verification")

        try:
            jwt.decode(self._token_str, report.jwks_data)
        except (JoseError, ValueError) as exc:
            # authlib's key-set lookup may raise ValueError (e.g. key not found)
            # rather than a JoseError; both mean the signature was not proven.
            return self._fail(report, step, f"Signature verification failed: {exc}")

        report.steps.append(
            ValidationStep(
                name=step,
                passed=True,
                message=f"Signature valid ({report.signature_algorithm})",
                details={"algorithm": report.signature_algorithm},
            )
        )
        return True

    def _validate_claims(self, report: TokenCheckReport) -> bool:
        """Step 6: Validate standard JWT claims (iss, aud, exp, iat, nbf)."""
        payload = report.payload
        now = time.time()
        # The issuer/audience checks also record issuer_match/audience_match on
        # the report, so the details below must be built after them.
        issues = [
            *_check_issuer(payload, self._expected_issuer, report),
            *_check_audience(payload, self._expected_audience, report),
            *_check_temporal_claims(payload, now),
        ]
        passed = not issues
        report.steps.append(
            ValidationStep(
                name="validate_claims",
                passed=passed,
                message="All claims valid" if passed else "; ".join(issues),
                details=_claims_details(payload, report),
            )
        )
        if not passed:
            report.error = report.steps[-1].message
        return passed
# ─────────────────────────────────────────────────────────────────────────────
# Private helpers
# ─────────────────────────────────────────────────────────────────────────────


def _fail_step(name: str, message: str) -> ValidationStep:
    """Create a failing validation step (with no details)."""
    return ValidationStep(name=name, passed=False, message=message)


def _record_http_error(report: TokenCheckReport, step_name: str, message: str) -> bool:
    """Record an HTTP error as a failed step, set ``report.error``, and return False."""
    report.steps.append(_fail_step(step_name, message))
    report.error = message
    return False


def _find_matching_key(keys: list[dict[str, Any]], kid: str | None) -> dict[str, Any] | None:
    """Find a JWK matching the given kid.

    Args:
        keys: List of JWK dicts from the JWKS.
        kid: Key ID to match, or None.

    Returns:
        The first key whose ``kid`` equals *kid*. When *kid* is empty, the sole
        key if the set holds exactly one key. Otherwise None.
    """
    if kid:
        return next((key for key in keys if key.get("kid") == kid), None)
    # No kid in header: picking a key is only unambiguous for a one-key set.
    return keys[0] if len(keys) == 1 else None


def _jwk_to_pem(jwk_data: dict[str, Any]) -> tuple[str, str, int | None]:
    """Convert a JWK dict to PEM string, SHA-256 fingerprint, and key size.

    Args:
        jwk_data: JWK dict from JWKS.

    Returns:
        Tuple of (pem_string, sha256_hex_fingerprint, key_size_bits).
        The fingerprint is computed over the DER SubjectPublicKeyInfo.
        key_size_bits is None for key types without a ``key_size``
        (e.g. Ed25519).

    Raises:
        ImportError: If authlib (or cryptography) is not installed.
    """
    from authlib.jose import JsonWebKey
    from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat

    jwk_obj = JsonWebKey.import_key(jwk_data)
    public_key = jwk_obj.get_public_key()

    pem_bytes = public_key.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo)
    der_bytes = public_key.public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)
    fingerprint = hashlib.sha256(der_bytes).hexdigest()

    # RSA and EC public keys both expose ``key_size``; EdDSA keys do not.
    key_size: int | None = getattr(public_key, "key_size", None)

    return pem_bytes.decode("utf-8"), fingerprint, key_size


def _parse_x509_info(cert_b64: str) -> dict[str, Any]:
    """Parse X.509 certificate info from base64-encoded DER (x5c field).

    Args:
        cert_b64: Base64-encoded DER certificate from x5c array.

    Returns:
        Dictionary with subject/issuer CN, serial (hex), not_before and
        not_after (ISO-8601, UTC). Empty dict if parsing fails.
    """
    try:
        from cryptography import x509

        cert = x509.load_der_x509_certificate(base64.b64decode(cert_b64))
        common_name_oid = x509.oid.NameOID.COMMON_NAME

        def common_name(name: Any) -> str:
            # The CN attribute value, "" if absent. Falls back to the full
            # RFC 4514 name if the lookup itself fails.
            try:
                cn_attrs = name.get_attributes_for_oid(common_name_oid)
                return str(cn_attrs[0].value) if cn_attrs else ""
            except Exception:
                return name.rfc4514_string()

        not_before = getattr(cert, "not_valid_before_utc", None)
        not_after = getattr(cert, "not_valid_after_utc", None)
        if not_before is None or not_after is None:
            # cryptography < 42 only offers naive datetimes (implicitly UTC).
            not_before = cert.not_valid_before.replace(tzinfo=timezone.utc)
            not_after = cert.not_valid_after.replace(tzinfo=timezone.utc)

        return {
            "subject_cn": common_name(cert.subject),
            "issuer_cn": common_name(cert.issuer),
            "serial_number": format(cert.serial_number, "x"),
            "not_before": not_before.isoformat(),
            "not_after": not_after.isoformat(),
        }
    except Exception:
        # Best-effort: certificate details are informational only.
        logger.debug("Failed to parse x5c certificate", exc_info=True)
        return {}


def _check_issuer(
    payload: dict[str, Any],
    expected: str | None,
    report: TokenCheckReport,
) -> list[str]:
    """Validate issuer claim and record ``report.issuer_match``.

    With an expected issuer, the claim must equal it exactly. Without one,
    ``issuer_match`` is True when an ``iss`` claim is present, else None.
    """
    iss = payload.get("iss")
    issues: list[str] = []
    if expected:
        report.issuer_match = iss == expected
        if not report.issuer_match:
            issues.append(f"Issuer mismatch: got {iss!r}, expected {expected!r}")
    else:
        report.issuer_match = True if iss else None
    return issues


def _check_audience(
    payload: dict[str, Any],
    expected: str | None,
    report: TokenCheckReport,
) -> list[str]:
    """Validate audience claim and record ``report.audience_match``.

    ``aud`` may be a single value or a list. With an expected audience, it
    must equal the value or be contained in the list. Without one,
    ``audience_match`` is True when an ``aud`` claim is present, else None.
    """
    aud = payload.get("aud")
    issues: list[str] = []
    if expected:
        report.audience_match = expected in aud if isinstance(aud, list) else aud == expected
        if not report.audience_match:
            issues.append(f"Audience mismatch: got {aud!r}, expected {expected!r}")
    else:
        report.audience_match = True if aud else None
    return issues


def _is_numeric_date(value: Any) -> bool:
    """Return True if *value* is a finite JSON number (RFC 7519 NumericDate)."""
    # bool is an int subclass, but true/false are not timestamps.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    # Python ints are always finite; floats may be NaN/inf (json accepts them).
    return isinstance(value, int) or math.isfinite(value)


def _format_timestamp(ts: float) -> str:
    """Render a NumericDate as ISO-8601 UTC, or its raw value if out of range."""
    try:
        return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
    except (OverflowError, OSError, ValueError):
        return str(ts)


def _check_temporal_claims(payload: dict[str, Any], now: float) -> list[str]:
    """Validate exp, iat, nbf claims with clock skew tolerance.

    Each claim is optional. A present claim that is not a NumericDate is
    reported as an issue rather than raising.

    Args:
        payload: Decoded JWT claims.
        now: Current time as a Unix timestamp.

    Returns:
        Human-readable issues, in exp / iat / nbf order (empty when valid).
    """
    issues: list[str] = []

    exp = payload.get("exp")
    if exp is not None:
        if not _is_numeric_date(exp):
            issues.append(f"Invalid 'exp' claim: expected a NumericDate, got {exp!r}")
        elif now > exp + CLOCK_SKEW_SECONDS:
            issues.append(f"Token expired at {_format_timestamp(exp)}")

    iat = payload.get("iat")
    if iat is not None:
        if not _is_numeric_date(iat):
            issues.append(f"Invalid 'iat' claim: expected a NumericDate, got {iat!r}")
        elif iat > now + CLOCK_SKEW_SECONDS:
            issues.append(f"Token issued in the future: iat={iat}")

    nbf = payload.get("nbf")
    if nbf is not None:
        if not _is_numeric_date(nbf):
            issues.append(f"Invalid 'nbf' claim: expected a NumericDate, got {nbf!r}")
        elif now < nbf - CLOCK_SKEW_SECONDS:
            issues.append(f"Token not yet valid: nbf={nbf}")

    return issues


def _claims_details(payload: dict[str, Any], report: TokenCheckReport) -> dict[str, Any]:
    """Build details dict for claims validation step."""
    return {
        "iss": payload.get("iss"),
        "aud": payload.get("aud"),
        "exp": payload.get("exp"),
        "iat": payload.get("iat"),
        "nbf": payload.get("nbf"),
        "issuer_match": report.issuer_match,
        "audience_match": report.audience_match,
    }


def _b64url_decode(data: str) -> bytes:
    """Decode base64url-encoded data, restoring the padding JWTs omit.

    Args:
        data: Base64url-encoded string.

    Returns:
        Decoded bytes.
    """
    return base64.urlsafe_b64decode(data + "=" * (-len(data) % 4))


__all__ = [
    "CLOCK_SKEW_SECONDS",
    "TokenCheckReport",
    "TokenChecker",
    "ValidationStep",
]