"""HTTP client for RAPI module.
This module provides the RapiClient class for making REST API calls
with config-driven endpoints, multi-source credentials, and detailed logging.
"""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from urllib.parse import urlencode
import httpx
from kstlib.auth import AuthExpiredError
from kstlib.limits import get_rapi_limits
from kstlib.rapi.config import (
ApiConfig,
EndpointConfig,
HmacConfig,
MultipartConfig,
RapiConfigManager,
ServerConfig,
load_rapi_config,
)
from kstlib.rapi.credentials import CredentialRecord, CredentialResolver
from kstlib.rapi.exceptions import (
ConfirmationRequiredError,
RequestError,
ResponseTooLargeError,
)
from kstlib.ssl import build_ssl_context
if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from kstlib.logging import TRACE_LEVEL, get_logger
log = get_logger(__name__)
# Heuristic keywords used to detect HTTP 401 responses that signal
# access token expiration vs other 401 causes (permissions, audience
# mismatch, etc.). Match is case-insensitive against the response body;
# all entries are kept in lowercase.
# NOTE(review): if the matcher (not visible in this section) tests by
# substring, "token expired" is already covered by "expired" - confirm
# whether the longer entry is still needed.
_AUTH_EXPIRED_BODY_KEYWORDS: tuple[str, ...] = (
    "expired",
    "invalid_token",
    "token expired",
)
# WWW-Authenticate parameter value that signals token expiration per
# RFC 6750 Section 3 (Bearer error="invalid_token"). Match is
# case-insensitive against the header value.
_AUTH_EXPIRED_HEADER_MARKER: str = "invalid_token"
def _source_for_env_credential(cred_config: Mapping[str, Any]) -> str | None:
"""Derive a sanitized token source label for ``env`` credentials."""
var = cred_config.get("var") or cred_config.get("var_key") or (cred_config.get("fields") or {}).get("key")
return f"env:{var}" if var else None
def _source_for_file_credential(cred_config: Mapping[str, Any]) -> str | None:
"""Derive a sanitized token source label for ``file`` credentials."""
path = cred_config.get("path")
return str(path) if path else None
def _source_for_sops_credential(cred_config: Mapping[str, Any]) -> str | None:
"""Derive a sanitized token source label for ``sops`` credentials."""
path = cred_config.get("path")
return f"sops:{path}" if path else None
def _source_for_provider_credential(cred_config: Mapping[str, Any]) -> str | None:
"""Derive a sanitized token source label for ``provider`` credentials."""
provider = cred_config.get("provider")
return f"provider:{provider}" if provider else None
# Dispatch table for AuthExpiredError.token_source labels keyed by
# CredentialRecord.source. Each builder takes the raw credential
# config dict (possibly empty) and returns a sanitized label or
# ``None`` when the configured source is missing the expected key.
# Label shapes produced by the builders above: ``env:<var>``, the bare
# path for ``file``, ``sops:<path>`` and ``provider:<name>``.
_TOKEN_SOURCE_BUILDERS: dict[str, Callable[[Mapping[str, Any]], str | None]] = {
    "env": _source_for_env_credential,
    "file": _source_for_file_credential,
    "sops": _source_for_sops_credential,
    "provider": _source_for_provider_credential,
}
def _hint_for_file_credential(cred_config: Mapping[str, Any]) -> str:
"""Build a re-authentication hint for ``file`` credentials."""
path = str(cred_config.get("path", ""))
if path and "credentials.json" in path:
return "Re-authenticate with: sas-admin --profile $VIYA_HOST -k auth login -u <user>"
if path:
return f"Re-authenticate via your credentials file: {path}"
return "Re-authenticate via your credentials file."
def _hint_for_env_credential(cred_config: Mapping[str, Any]) -> str:
"""Build a re-authentication hint for ``env`` credentials."""
var = cred_config.get("var") or cred_config.get("var_key") or (cred_config.get("fields") or {}).get("key")
if var:
return f"Refresh and re-export env var: ${var}"
return "Refresh and re-export your env-based credential."
def _hint_for_sops_credential(cred_config: Mapping[str, Any]) -> str:
"""Build a re-authentication hint for ``sops`` credentials."""
path = cred_config.get("path", "")
if path:
return f"Update SOPS-encrypted file: {path}"
return "Update your SOPS-encrypted credential file."
def _hint_for_provider_credential(cred_config: Mapping[str, Any]) -> str:
"""Build a re-authentication hint for ``provider`` credentials."""
provider = cred_config.get("provider", "")
if provider:
return f"Re-authenticate via your provider: kstlib auth login --provider {provider}"
return "Re-authenticate via your provider."
# Dispatch table for AuthExpiredError.suggested_action hints keyed
# by CredentialRecord.source. Each builder takes the raw credential
# config dict (possibly empty) and returns a sanitized hint string
# (never includes secret material, only source labels such as path
# / env var / provider name). Unlike the token-source builders, every
# hint builder always returns a string: a generic hint is used when
# the expected config key is absent.
_AUTH_RENEW_HINT_BUILDERS: dict[str, Callable[[Mapping[str, Any]], str]] = {
    "file": _hint_for_file_credential,
    "env": _hint_for_env_credential,
    "sops": _hint_for_sops_credential,
    "provider": _hint_for_provider_credential,
}
def _log_trace(msg: str, *args: Any) -> None:
    """Emit ``msg`` on the module logger at ``TRACE_LEVEL``.

    Args:
        msg: %-style format string.
        *args: Lazy format arguments forwarded to the logger.
    """
    log.log(
        TRACE_LEVEL,
        msg,
        *args,
    )
[docs]
@dataclass
class FilePayload:
    """In-memory file for a multipart upload.

    Lets callers upload content that does not exist as a file on disk.

    Attributes:
        filename: Name reported for the file part (Content-Disposition).
        data: Raw bytes of the file.
        content_type: MIME type sent for the file part.
        field_name: Name of the form field carrying the file
            (defaults to ``"file"``).

    Examples:
        >>> payload = FilePayload(
        ...     filename="report.csv",
        ...     data=b"col1,col2",
        ...     content_type="text/csv",
        ... )
        >>> payload.field_name
        'file'
    """

    # Required parts of the upload.
    filename: str
    data: bytes
    content_type: str
    # Optional: form field name used for the file part.
    field_name: str = "file"
def _validate_safeguard(
endpoint_config: EndpointConfig,
args: tuple[Any, ...],
kwargs: dict[str, Any],
confirm: str | None,
) -> None:
"""Validate safeguard confirmation for dangerous endpoints.
Args:
endpoint_config: Endpoint configuration.
args: Positional arguments for path/safeguard substitution.
kwargs: Keyword arguments for path/safeguard substitution.
confirm: Confirmation string provided by caller.
Raises:
ConfirmationRequiredError: If safeguard is required but confirm is missing or wrong.
"""
if endpoint_config.safeguard is None:
return
expected = endpoint_config.build_safeguard(*args, **kwargs)
if expected is None:
return
if confirm is None:
raise ConfirmationRequiredError(endpoint_config.full_ref, expected=expected)
if confirm != expected:
raise ConfirmationRequiredError(endpoint_config.full_ref, expected=expected, actual=confirm)
[docs]
@dataclass
class RapiResponse:
    """Result of an API call.

    Attributes:
        status_code: HTTP status code.
        headers: Response headers.
        data: Parsed JSON payload (``None`` when the body is not JSON).
        text: Raw response text.
        elapsed: Duration of the request in seconds.
        endpoint_ref: Full reference of the endpoint that was called.

    Examples:
        >>> response = RapiResponse(status_code=200, data={"ip": "1.2.3.4"})
        >>> response.ok
        True
        >>> response.data["ip"]
        '1.2.3.4'
    """

    status_code: int
    headers: dict[str, str] = field(default_factory=dict)
    data: Any = None
    text: str = ""
    elapsed: float = 0.0
    endpoint_ref: str = ""

    @property
    def ok(self) -> bool:
        """Whether the status code is in the 2xx success range."""
        code = self.status_code
        return code >= 200 and code < 300
[docs]
class RapiClient:
"""Config-driven REST API client.
Makes HTTP requests to configured API endpoints with automatic
credential resolution, header merging, and detailed logging.
Supports loading configuration from:
- kstlib.conf.yml (default)
- External ``*.rapi.yml`` files (via from_file)
- Auto-discovery of ``*.rapi.yml`` in current directory (via discover)
Args:
config_manager: Optional RapiConfigManager (loads from config if None).
credentials_config: Optional credentials configuration.
Examples:
>>> client = RapiClient() # doctest: +SKIP
>>> response = client.call("httpbin.get_ip") # doctest: +SKIP
>>> response.data # doctest: +SKIP
{'origin': '...'}
>>> client = RapiClient.from_file("github.rapi.yml") # doctest: +SKIP
>>> client = RapiClient.discover() # doctest: +SKIP
"""
[docs]
def __init__(
    self,
    config_manager: RapiConfigManager | None = None,
    credentials_config: Mapping[str, Any] | None = None,
    *,
    ssl_verify: bool | None = None,
    ssl_ca_bundle: str | None = None,
) -> None:
    """Initialize RapiClient.

    Args:
        config_manager: Optional RapiConfigManager instance. When falsy,
            the configuration is loaded via ``load_rapi_config()``.
        credentials_config: Optional credentials configuration. Entries
            override same-named inline credentials of the config manager.
        ssl_verify: Override SSL verification (True/False).
            If None, uses global config from kstlib.conf.yml.
        ssl_ca_bundle: Override CA bundle path.
            If None, uses global config from kstlib.conf.yml.
    """
    self._config_manager = config_manager or load_rapi_config()

    # Merge credentials: inline from config_manager first, then the
    # explicit credentials_config so that explicit entries win.
    merged_credentials: dict[str, Any] = {}
    # getattr + truthiness tolerates managers that lack the attribute as
    # well as managers exposing ``credentials_config = None``.
    inline_credentials = getattr(self._config_manager, "credentials_config", None)
    if inline_credentials:
        merged_credentials.update(inline_credentials)
    if credentials_config:
        merged_credentials.update(credentials_config)
    self._credential_resolver = CredentialResolver(merged_credentials or None)

    self._limits = get_rapi_limits()

    # Build SSL context (cascade: kwargs > global config > default)
    self._ssl_context = build_ssl_context(
        ssl_verify=ssl_verify,
        ssl_ca_bundle=ssl_ca_bundle,
    )
    log.debug(
        "RapiClient initialized (timeout=%.1fs, max_retries=%d)",
        self._limits.timeout,
        self._limits.max_retries,
    )
[docs]
@classmethod
def from_file(
    cls,
    path: str,
    credentials_config: Mapping[str, Any] | None = None,
) -> RapiClient:
    """Build a client from a single ``*.rapi.yml`` file.

    The file uses the simplified external YAML format and is loaded via
    ``RapiConfigManager.from_file``.

    Args:
        path: Path to the ``*.rapi.yml`` file.
        credentials_config: Additional credentials (merged with inline).

    Returns:
        Configured RapiClient instance.

    Raises:
        FileNotFoundError: If file does not exist.
        ValueError: If file format is invalid.

    Examples:
        >>> client = RapiClient.from_file("github.rapi.yml")  # doctest: +SKIP
        >>> response = client.call("github.user")  # doctest: +SKIP
    """
    return cls(RapiConfigManager.from_file(path), credentials_config)
[docs]
@classmethod
def discover(
    cls,
    directory: str | None = None,
    pattern: str = "*.rapi.yml",
    credentials_config: Mapping[str, Any] | None = None,
) -> RapiClient:
    """Build a client from every ``*.rapi.yml`` file found in a directory.

    Discovery is delegated to ``RapiConfigManager.discover``, which loads
    all files matching ``pattern``.

    Args:
        directory: Directory to search in (default: current directory).
        pattern: Glob pattern for files (default: ``*.rapi.yml``).
        credentials_config: Additional credentials (merged with inline).

    Returns:
        Configured RapiClient instance.

    Raises:
        FileNotFoundError: If no matching files found.

    Examples:
        >>> client = RapiClient.discover()  # doctest: +SKIP
        >>> client = RapiClient.discover("./apis/")  # doctest: +SKIP
    """
    return cls(RapiConfigManager.discover(directory, pattern), credentials_config)
@property
def config_manager(self) -> RapiConfigManager:
"""Get the configuration manager.
Returns:
RapiConfigManager instance.
"""
return self._config_manager
[docs]
def list_apis(self) -> list[str]:
"""List all configured API names.
Returns:
List of API names.
"""
return self._config_manager.list_apis()
[docs]
def list_endpoints(self, api_name: str | None = None) -> list[str]:
"""List endpoint references.
Args:
api_name: Filter by API name (optional).
Returns:
List of full endpoint references (api.endpoint).
"""
return self._config_manager.list_endpoints(api_name)
[docs]
def call(
    self,
    endpoint_ref: str,
    *args: Any,
    body: Any = None,
    headers: Mapping[str, str] | None = None,
    timeout: float | None = None,
    confirm: str | None = None,
    server: str | None = None,
    **kwargs: Any,
) -> RapiResponse:
    """Perform a blocking API call.

    Steps: resolve the endpoint, resolve the effective server profile,
    validate the safeguard, build the request, then execute it with retries.

    Args:
        endpoint_ref: Endpoint reference (full: api.endpoint or short: endpoint).
        *args: Positional arguments for path parameters.
        body: Request body (dict for JSON, str for raw).
        headers: Runtime headers (override service/endpoint headers).
        timeout: Request timeout (uses config default if None).
        confirm: Confirmation string for dangerous endpoints with safeguard.
        server: Optional named server profile from ``rapi.servers`` to
            use for this call. Overrides any ``server:`` directive set
            at the endpoint or file level in the YAML config. Cascade:
            runtime ``server`` > endpoint ``server:`` > file ``server:``
            > static ``ApiConfig`` (no server).
        **kwargs: Keyword arguments for path parameters and query params.

    Returns:
        RapiResponse with parsed data.

    Raises:
        AuthExpiredError: If the response signals access token
            expiration (HTTP 401 with body keyword or
            ``WWW-Authenticate`` ``invalid_token`` marker).
            Terminal, not retried.
        ConfirmationRequiredError: If safeguard requires confirmation.
        RequestError: If request fails after retries.
        ResponseTooLargeError: If response exceeds max size.
        ServerNotFoundError: If ``server`` (or any cascading directive)
            does not exist in ``rapi.servers``.

    Examples:
        >>> client = RapiClient()  # doctest: +SKIP
        >>> client.call("httpbin.get_ip")  # doctest: +SKIP
        >>> client.call("httpbin.delayed", 5)  # doctest: +SKIP
        >>> client.call("httpbin.post_data", body={"key": "value"})  # doctest: +SKIP
        >>> client.call("admin.delete_user", userId="123", confirm="DELETE USER 123")  # doctest: +SKIP
        >>> client.call("github.repos-list", server="github")  # doctest: +SKIP
    """
    log.debug("Calling endpoint: %s", endpoint_ref)

    manager = self._config_manager
    api_cfg, endpoint_cfg = manager.resolve(endpoint_ref)
    _log_trace("Resolved to: %s", endpoint_cfg.full_ref)

    # Cascade: runtime > endpoint > file > None
    server_profile = manager.resolve_effective_server(api_cfg, endpoint_cfg, runtime_server=server)
    if server_profile is not None:
        _log_trace("Effective server profile: %s", server_profile.name)

    # Dangerous endpoints must be confirmed before any request is built.
    _validate_safeguard(endpoint_cfg, args, kwargs, confirm)

    prepared = self._build_request(
        api_cfg,
        endpoint_cfg,
        args,
        kwargs,
        body,
        headers,
        effective_server=server_profile,
    )

    if timeout is None:
        timeout = self._limits.timeout
    return self._execute_with_retry(
        prepared,
        endpoint_cfg,
        timeout,
        api_config=api_cfg,
        effective_server=server_profile,
    )
[docs]
async def call_async(
    self,
    endpoint_ref: str,
    *args: Any,
    body: Any = None,
    headers: Mapping[str, str] | None = None,
    timeout: float | None = None,
    confirm: str | None = None,
    server: str | None = None,
    **kwargs: Any,
) -> RapiResponse:
    """Perform an API call without blocking the event loop.

    Resolution, safeguard validation and request building are synchronous;
    only the execution with retries is awaited.

    Args:
        endpoint_ref: Endpoint reference (full: api.endpoint or short: endpoint).
        *args: Positional arguments for path parameters.
        body: Request body (dict for JSON, str for raw).
        headers: Runtime headers (override service/endpoint headers).
        timeout: Request timeout (uses config default if None).
        confirm: Confirmation string for dangerous endpoints with safeguard.
        server: Optional named server profile from ``rapi.servers`` to
            use for this call. See :meth:`call` for cascade rules.
        **kwargs: Keyword arguments for path parameters and query params.

    Returns:
        RapiResponse with parsed data.

    Raises:
        AuthExpiredError: If the response signals access token
            expiration (HTTP 401 with body keyword or
            ``WWW-Authenticate`` ``invalid_token`` marker).
            Terminal, not retried.
        ConfirmationRequiredError: If safeguard requires confirmation.
        RequestError: If request fails after retries.
        ResponseTooLargeError: If response exceeds max size.
        ServerNotFoundError: If ``server`` (or any cascading directive)
            does not exist in ``rapi.servers``.
    """
    log.debug("Calling endpoint (async): %s", endpoint_ref)

    manager = self._config_manager
    api_cfg, endpoint_cfg = manager.resolve(endpoint_ref)
    _log_trace("Resolved to: %s", endpoint_cfg.full_ref)

    # Cascade: runtime > endpoint > file > None
    server_profile = manager.resolve_effective_server(api_cfg, endpoint_cfg, runtime_server=server)
    if server_profile is not None:
        _log_trace("Effective server profile: %s", server_profile.name)

    # Dangerous endpoints must be confirmed before any request is built.
    _validate_safeguard(endpoint_cfg, args, kwargs, confirm)

    prepared = self._build_request(
        api_cfg,
        endpoint_cfg,
        args,
        kwargs,
        body,
        headers,
        effective_server=server_profile,
    )

    if timeout is None:
        timeout = self._limits.timeout
    return await self._execute_with_retry_async(
        prepared,
        endpoint_cfg,
        timeout,
        api_config=api_cfg,
        effective_server=server_profile,
    )
def _extract_query_params(
self,
endpoint_config: EndpointConfig,
kwargs: dict[str, Any],
) -> dict[str, str]:
"""Extract query parameters from kwargs (excluding path params)."""
import re
# Start with non-None defaults; None means "available but not sent"
query_params = {k: v for k, v in endpoint_config.query.items() if v is not None}
path_params: set[str] = set()
for match in re.finditer(r"\{([a-zA-Z_][a-zA-Z0-9_]*|\d+)\}", endpoint_config.path):
param = match.group(1)
if not param.isdigit():
path_params.add(param)
for key, value in kwargs.items():
if key not in path_params:
if value is None:
query_params.pop(key, None)
else:
query_params[key] = str(value)
return query_params
def _prepare_body(
self,
body: Any,
headers: dict[str, str],
) -> bytes | None:
"""Prepare request body and set Content-Type header if needed."""
if body is None:
return None
content: bytes | None = None
if isinstance(body, dict):
content = json.dumps(body).encode("utf-8")
if "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
elif isinstance(body, str):
content = body.encode("utf-8")
elif isinstance(body, bytes):
content = body
if content:
_log_trace("Request body size: %d bytes", len(content))
# Log body content (truncate if too large)
body_preview = content.decode("utf-8", errors="replace")
if len(body_preview) > 1000:
_log_trace(">>> Body: %s... [truncated]", body_preview[:1000])
else:
_log_trace(">>> Body: %s", body_preview)
return content
def _prepare_multipart(
    self,
    body: Any,
    endpoint_config: EndpointConfig,
) -> list[tuple[str, tuple[str, bytes, str]]]:
    """Prepare multipart file upload from body.

    Args:
        body: Body value (expected: "@filepath" string or FilePayload).
        endpoint_config: Endpoint configuration with optional multipart config.

    Returns:
        httpx files parameter: list of (field_name, (filename, data, content_type)).

    Raises:
        RequestError: If body is not a valid file reference for multipart,
            the path does not exist or is not a regular file, or the file
            exceeds the configured upload size limit.
    """
    import mimetypes
    from pathlib import Path

    mp_config = endpoint_config.multipart or MultipartConfig()

    if isinstance(body, FilePayload):
        # NOTE(review): in-memory payloads are not checked against
        # max_upload_size, unlike "@path" uploads below - confirm intent.
        _log_trace(
            "Multipart upload (FilePayload): field=%s, file=%s, type=%s, size=%d",
            body.field_name,
            body.filename,
            body.content_type,
            len(body.data),
        )
        return [(body.field_name, (body.filename, body.data, body.content_type))]

    if not isinstance(body, str) or not body.startswith("@"):
        raise RequestError(
            f"Multipart endpoint '{endpoint_config.full_ref}' requires a file body "
            f"(@/path/to/file) or FilePayload, got: {type(body).__name__}",
            retryable=False,
        )

    filepath = Path(body[1:])
    if not filepath.exists():
        raise RequestError(
            f"File not found for multipart upload: {filepath}",
            retryable=False,
        )
    if not filepath.is_file():
        # Directories and other special files would otherwise surface as a
        # raw OSError from read_bytes().
        raise RequestError(
            f"Not a regular file for multipart upload: {filepath}",
            retryable=False,
        )

    limits = get_rapi_limits()

    def _reject_if_too_large(size: int) -> None:
        """Raise RequestError when ``size`` exceeds the upload limit."""
        if size > limits.max_upload_size:
            raise RequestError(
                f"File too large for upload: {size} bytes (max: {limits.max_upload_size_display})",
                retryable=False,
            )

    # Check the on-disk size first so an oversized file is rejected
    # without being loaded into memory.
    _reject_if_too_large(filepath.stat().st_size)
    file_data = filepath.read_bytes()
    # Re-check what was actually read: the file may have grown in between.
    _reject_if_too_large(len(file_data))

    filename = filepath.name
    content_type = mp_config.content_type
    if content_type is None:
        guessed, _ = mimetypes.guess_type(filename)
        content_type = guessed or "application/octet-stream"
    field_name = mp_config.field_name

    _log_trace(
        "Multipart upload: field=%s, file=%s, type=%s, size=%d bytes",
        field_name,
        filename,
        content_type,
        len(file_data),
    )
    return [(field_name, (filename, file_data, content_type))]
def _build_request(
    self,
    api_config: ApiConfig,
    endpoint_config: EndpointConfig,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    body: Any,
    runtime_headers: Mapping[str, str] | None,
    *,
    effective_server: ServerConfig | None = None,
) -> httpx.Request:
    """Build HTTP request from configuration.

    Args:
        api_config: API service configuration.
        endpoint_config: Endpoint configuration.
        args: Positional path parameters.
        kwargs: Keyword parameters (path + query).
        body: Request body.
        runtime_headers: Runtime header overrides.
        effective_server: Optional resolved server profile that
            overrides ``api_config.base_url`` and contributes
            ``credentials`` + ``headers`` to the request build.
            When None, the static ApiConfig is used (backward
            compatible behavior).

    Returns:
        Prepared httpx.Request.

    Raises:
        RequestError: If the URL contains a null byte, does not use the
            http/https scheme, or the multipart body is invalid.
    """
    # Build URL with path parameter substitution
    _log_trace("Path template: %s", endpoint_config.path)
    if args:
        _log_trace("Path args (positional): %s", args)
    if kwargs:
        _log_trace("Path/query kwargs: %s", kwargs)
    path = endpoint_config.build_path(*args, **kwargs)

    # Server profile (when present) overrides the static api_config base_url
    base_url = effective_server.base_url if effective_server else api_config.base_url
    url = f"{base_url}{path}"

    # Security: reject null bytes in URL to prevent injection
    if "\x00" in url:
        raise RequestError("Null bytes not allowed in URL", retryable=False)
    # Security: validate URL scheme to prevent SSRF via config injection
    if not url.lower().startswith(("http://", "https://")):
        raise RequestError(f"Invalid URL scheme (only http/https allowed): {url}", retryable=False)

    # Extract query params from kwargs
    query_params = self._extract_query_params(endpoint_config, kwargs)
    _log_trace("Final URL: %s", url)
    if query_params:
        _log_trace("Query params: %s", query_params)

    # Merge headers - cascade: service < server < endpoint < runtime
    merged_headers = self._merge_headers(
        api_config.headers,
        endpoint_config.headers,
        dict(runtime_headers) if runtime_headers else {},
        server_headers=effective_server.headers if effective_server else None,
    )

    # Auth is applied when the endpoint allows it (auth: false disables it)
    # and either the effective server profile or the static ApiConfig
    # provides credentials.
    has_credentials = bool((effective_server and effective_server.credentials) or api_config.credentials)
    needs_auth = bool(endpoint_config.auth) and has_credentials

    # Detect multipart mode from the merged Content-Type header. Header
    # names are case-insensitive, so every casing is considered.
    content_type_names = [name for name in merged_headers if name.lower() == "content-type"]
    is_multipart = any(
        "multipart/form-data" in str(merged_headers[name]).lower() for name in content_type_names
    )

    if is_multipart:
        _log_trace("Multipart mode detected from Content-Type header")
        # Multipart file upload: use httpx files= parameter
        files_param = self._prepare_multipart(body, endpoint_config)
        # Remove Content-Type (any casing): httpx generates it with the boundary
        for name in content_type_names:
            merged_headers.pop(name, None)
        # No body content is available for HMAC signing in multipart mode.
        if needs_auth:
            self._apply_auth(
                merged_headers,
                api_config,
                query_params,
                None,
                effective_server=effective_server,
            )
        request = httpx.Request(
            method=endpoint_config.method,
            url=url,
            params=query_params if query_params else None,
            headers=merged_headers,
            files=files_param,
        )
        # Log multipart request structure at TRACE level
        generated_ct = request.headers.get("content-type", "")
        _log_trace("Multipart Content-Type (generated): %s", generated_ct)
        for field_name, (filename, data, ct) in files_param:
            _log_trace(
                "Multipart part: field=%s, filename=%s, content_type=%s, size=%d bytes",
                field_name,
                filename,
                ct,
                len(data),
            )
    else:
        # Normal body processing
        content = self._prepare_body(body, merged_headers)
        # Authentication may modify headers and query_params (HMAC).
        if needs_auth:
            self._apply_auth(
                merged_headers,
                api_config,
                query_params,
                content,
                effective_server=effective_server,
            )
        request = httpx.Request(
            method=endpoint_config.method,
            url=url,
            params=query_params if query_params else None,
            headers=merged_headers,
            content=content,
        )

    self._log_request(request)
    return request
def _merge_headers(
    self,
    service_headers: dict[str, str],
    endpoint_headers: dict[str, str],
    runtime_headers: dict[str, str],
    *,
    server_headers: dict[str, str] | None = None,
) -> dict[str, str]:
    """Combine up to four header layers into one dictionary.

    Precedence, lowest to highest:
        ``service < server < endpoint < runtime``

    The server layer sits between service and endpoint: it describes
    what the API server itself expects (e.g.
    ``Accept: application/vnd.github+json``), whereas the endpoint and
    runtime layers carry per-call overrides.

    Args:
        service_headers: Service-level headers (lowest priority).
        endpoint_headers: Endpoint-level headers.
        runtime_headers: Runtime headers (highest priority).
        server_headers: Optional headers from the effective server
            profile, layered between service and endpoint headers.

    Returns:
        A new dictionary with the merged headers.
    """
    server_layer = server_headers or {}
    combined: dict[str, str] = {}
    # Later layers overwrite keys set by earlier ones.
    for layer in (service_headers, server_layer, endpoint_headers, runtime_headers):
        combined.update(layer)

    _log_trace(
        "Headers merged: service=%d, server=%d, endpoint=%d, runtime=%d -> total=%d",
        len(service_headers),
        len(server_layer),
        len(endpoint_headers),
        len(runtime_headers),
        len(combined),
    )
    return combined
def _apply_auth(
    self,
    headers: dict[str, str],
    api_config: ApiConfig,
    query_params: dict[str, str] | None = None,
    body_content: bytes | None = None,
    *,
    effective_server: ServerConfig | None = None,
) -> None:
    """Add authentication data to the request headers / query params.

    Inline credentials of ``effective_server`` take precedence and are
    resolved with :meth:`CredentialResolver.resolve_inline`; otherwise the
    ``api_config.credentials`` reference is resolved. The auth type follows
    the same order: server profile, then ``api_config.auth_type``, then
    ``"bearer"``. A credential that cannot be resolved is logged as a
    warning and the request is left unauthenticated.

    Args:
        headers: Headers dict to modify.
        api_config: API config with credentials reference (fallback).
        query_params: Query params dict to modify (for HMAC signing).
        body_content: Request body content (for HMAC signing).
        effective_server: Optional resolved server profile providing
            inline credentials and/or auth type override.
    """
    resolver = self._credential_resolver
    server_credentials = effective_server.credentials if effective_server else None

    cred: CredentialRecord | None = None
    if server_credentials:
        try:
            cred = resolver.resolve_inline(
                server_credentials,
                name_hint=f"server.{effective_server.name}",
            )
        except Exception as exc:
            log.warning(
                "Failed to resolve inline credentials for server '%s': %s",
                effective_server.name,
                exc,
            )
            return
    elif api_config.credentials:
        try:
            cred = resolver.resolve(api_config.credentials)
        except Exception as exc:
            log.warning("Failed to resolve credential '%s': %s", api_config.credentials, exc)
            return
    else:
        # Nothing configured: leave the request unauthenticated.
        return

    # Auth type cascade: server profile > api_config > default "bearer"
    server_auth = effective_server.auth if effective_server else None
    auth_type = server_auth or api_config.auth_type or "bearer"

    if auth_type == "hmac":
        self._apply_hmac_auth(
            headers,
            api_config,
            cred,
            {} if query_params is None else query_params,
            body_content,
        )
        return
    if auth_type == "bearer":
        headers["Authorization"] = f"Bearer {cred.value}"
        _log_trace("Applied Bearer auth")
        return
    if auth_type == "basic":
        # An absent secret yields "user:" (empty password).
        pair = f"{cred.value}:{cred.secret}" if cred.secret else f"{cred.value}:"
        token = base64.b64encode(pair.encode()).decode()
        headers["Authorization"] = f"Basic {token}"
        _log_trace("Applied Basic auth")
        return
    if auth_type == "api_key":
        headers["X-API-Key"] = cred.value
        _log_trace("Applied API Key auth")
        return
    log.warning("Unknown auth_type: %s", auth_type)
def _apply_hmac_auth(
self,
headers: dict[str, str],
api_config: ApiConfig,
cred: CredentialRecord,
query_params: dict[str, str],
body_content: bytes | None,
) -> None:
"""Apply HMAC authentication.
Supports various exchange APIs like Binance (SHA256, hex) and
Kraken (SHA512, base64).
Args:
headers: Headers dict to modify.
api_config: API config with HMAC configuration.
cred: Resolved credential with API key and secret.
query_params: Query params dict to modify (timestamp/signature added).
body_content: Request body content (for signing if sign_body=True).
Raises:
ValueError: If secret is not available in credentials.
"""
if not cred.secret:
raise ValueError("HMAC auth requires secret_key in credentials")
hmac_cfg = api_config.hmac_config or HmacConfig()
# 1. Generate timestamp or nonce
ts_value = str(int(time.time() * 1000))
ts_field = hmac_cfg.nonce_field or hmac_cfg.timestamp_field
# Add timestamp/nonce to query params
query_params[ts_field] = ts_value
# 2. Build payload to sign
if hmac_cfg.sign_body and body_content:
payload = body_content.decode("utf-8", errors="replace")
else:
# Query string with timestamp (same order as httpx will send)
payload = urlencode(query_params)
# 3. Generate signature
hash_func = hashlib.sha512 if hmac_cfg.algorithm == "sha512" else hashlib.sha256
signature = hmac.new(
cred.secret.encode("utf-8"),
payload.encode("utf-8"),
hash_func,
)
if hmac_cfg.signature_format == "base64":
sig_value = base64.b64encode(signature.digest()).decode("utf-8")
else:
sig_value = signature.hexdigest()
# 4. Add signature to query params
query_params[hmac_cfg.signature_field] = sig_value
# 5. Set API key header if configured
if hmac_cfg.key_header:
headers[hmac_cfg.key_header] = cred.value
_log_trace(
"Applied HMAC auth (algorithm=%s, format=%s)",
hmac_cfg.algorithm,
hmac_cfg.signature_format,
)
def _log_request(self, request: httpx.Request) -> None:
    """Trace the outgoing request line and its headers.

    Values of ``Authorization``, ``X-API-Key`` and ``Cookie`` headers
    (matched case-insensitively) are replaced by ``[REDACTED]``.

    Args:
        request: Prepared request to log.
    """
    _log_trace(">>> %s %s", request.method, request.url)

    sensitive = ("authorization", "x-api-key", "cookie")
    for header_name, header_value in request.headers.items():
        if header_name.lower() not in sensitive:
            _log_trace(">>> %s: %s", header_name, header_value)
            continue
        _log_trace(">>> %s: [REDACTED]", header_name)
def _log_response(self, response: httpx.Response, elapsed: float) -> None:
    """Trace the status line, key headers and body of a response.

    Bodies longer than 2000 characters are truncated in the log; a body
    that cannot be read as text is reported as undecodable.

    Args:
        response: Received HTTP response.
        elapsed: Request duration in seconds.
    """
    response_headers = response.headers
    _log_trace("<<< %d %s (%.3fs)", response.status_code, response.reason_phrase, elapsed)
    _log_trace("<<< Content-Type: %s", response_headers.get("content-type", "unknown"))
    _log_trace("<<< Content-Length: %s", response_headers.get("content-length", "unknown"))

    try:
        text = response.text
        total = len(text)
        if total <= 2000:
            _log_trace("<<< Body: %s", text)
        else:
            _log_trace("<<< Body: %s... [truncated, %d bytes total]", text[:2000], total)
    except Exception:
        _log_trace("<<< Body: [unable to decode]")
def _check_response_size(self, response: httpx.Response) -> None:
"""Validate response size against configured limits.
Args:
response: HTTP response to check.
Raises:
ResponseTooLargeError: If response exceeds max size.
"""
content_length = response.headers.get("content-length")
if content_length and int(content_length) > self._limits.max_response_size:
raise ResponseTooLargeError(
int(content_length),
self._limits.max_response_size,
)
actual_size = len(response.content)
if actual_size > self._limits.max_response_size:
raise ResponseTooLargeError(actual_size, self._limits.max_response_size)
def _handle_retry_error(
    self,
    exc: Exception,
    attempt: int,
    endpoint_config: EndpointConfig,
) -> RapiResponse | None:
    """Classify a failed attempt: final 4xx response or "keep retrying".

    Timeouts, network errors and non-4xx HTTP status errors are logged as
    warnings and reported as retryable. A 4xx status error is parsed into
    a response (with an elapsed time of 0.0) and returned.

    Args:
        exc: The caught exception.
        attempt: Current attempt number (1-based).
        endpoint_config: Endpoint configuration for parsing 4xx responses.

    Returns:
        RapiResponse for 4xx client errors, None to continue retrying.
    """
    if isinstance(exc, httpx.TimeoutException):
        log.warning("Request timeout (attempt %d): %s", attempt, exc)
        return None
    if isinstance(exc, httpx.NetworkError):
        log.warning("Network error (attempt %d): %s", attempt, exc)
        return None
    if not isinstance(exc, httpx.HTTPStatusError):
        return None

    failed = exc.response
    if 400 <= failed.status_code < 500:
        # Client errors are handed back to the caller instead of retried.
        return self._parse_response(failed, endpoint_config, 0.0)
    log.warning("HTTP error (attempt %d): %s", attempt, exc)
    return None
def _execute_with_retry(
    self,
    request: httpx.Request,
    endpoint_config: EndpointConfig,
    timeout: float,
    *,
    api_config: ApiConfig,
    effective_server: ServerConfig | None = None,
) -> RapiResponse:
    """Execute request with retry logic.

    Up to ``max_retries + 1`` attempts are made. Before every retry
    the loop sleeps for the current delay, which is then multiplied
    by ``retry_backoff``. A fresh ``httpx.Client`` (redirects
    disabled) is created for each attempt.

    Args:
        request: Prepared HTTP request.
        endpoint_config: Endpoint configuration.
        timeout: Request timeout in seconds.
        api_config: Resolved API configuration, forwarded to
            :meth:`_check_auth_expired` for credential-context
            lookup on HTTP 401 detection.
        effective_server: Optional resolved server profile,
            forwarded to :meth:`_check_auth_expired`.

    Returns:
        RapiResponse.

    Raises:
        AuthExpiredError: If the response signals access token
            expiration (HTTP 401 with expiration markers).
            Terminal, not retried.
        RequestError: If all retries fail. The last underlying
            exception is attached as ``__cause__``.
        ResponseTooLargeError: If response is too large.
    """
    max_retries = self._limits.max_retries
    last_error: Exception | None = None
    delay = self._limits.retry_delay

    for attempt in range(max_retries + 1):
        if attempt > 0:
            log.debug("Retry %d/%d after %.1fs", attempt, max_retries, delay)
            _log_trace("Waiting %.1fs before retry...", delay)
            time.sleep(delay)
            delay *= self._limits.retry_backoff

        _log_trace("Attempt %d/%d", attempt + 1, max_retries + 1)

        try:
            start_time = time.monotonic()
            with httpx.Client(timeout=timeout, verify=self._ssl_context, follow_redirects=False) as client:
                response = client.send(request)
            elapsed = time.monotonic() - start_time

            self._log_response(response, elapsed)
            self._check_response_size(response)

            parsed = self._parse_response(response, endpoint_config, elapsed)
            # May raise AuthExpiredError, which is terminal by design.
            self._check_auth_expired(parsed, api_config, effective_server)
            return parsed
        except ResponseTooLargeError:
            # Never retried: an oversized response will not shrink.
            raise
        except (httpx.TimeoutException, httpx.NetworkError, httpx.HTTPStatusError) as e:
            result = self._handle_retry_error(e, attempt + 1, endpoint_config)
            if result is not None:
                self._check_auth_expired(result, api_config, effective_server)
                return result
            last_error = e

    # Chain the last failure so its traceback is not lost to callers.
    raise RequestError(
        f"Request failed after {max_retries + 1} attempts: {last_error}",
        retryable=False,
    ) from last_error
async def _execute_with_retry_async(
    self,
    request: httpx.Request,
    endpoint_config: EndpointConfig,
    timeout: float,
    *,
    api_config: ApiConfig,
    effective_server: ServerConfig | None = None,
) -> RapiResponse:
    """Execute async request with retry logic.

    Up to ``max_retries + 1`` attempts are made. Before every retry
    the coroutine awaits ``asyncio.sleep`` for the current delay,
    which is then multiplied by ``retry_backoff``. A fresh
    ``httpx.AsyncClient`` (redirects disabled) is created for each
    attempt.

    Args:
        request: Prepared HTTP request.
        endpoint_config: Endpoint configuration.
        timeout: Request timeout in seconds.
        api_config: Resolved API configuration, forwarded to
            :meth:`_check_auth_expired` for credential-context
            lookup on HTTP 401 detection.
        effective_server: Optional resolved server profile,
            forwarded to :meth:`_check_auth_expired`.

    Returns:
        RapiResponse.

    Raises:
        AuthExpiredError: If the response signals access token
            expiration (HTTP 401 with expiration markers).
            Terminal, not retried.
        RequestError: If all retries fail. The last underlying
            exception is attached as ``__cause__``.
        ResponseTooLargeError: If response is too large.
    """
    import asyncio

    max_retries = self._limits.max_retries
    last_error: Exception | None = None
    delay = self._limits.retry_delay

    for attempt in range(max_retries + 1):
        if attempt > 0:
            log.debug("Retry %d/%d after %.1fs", attempt, max_retries, delay)
            _log_trace("Waiting %.1fs before retry...", delay)
            await asyncio.sleep(delay)
            delay *= self._limits.retry_backoff

        _log_trace("Attempt %d/%d", attempt + 1, max_retries + 1)

        try:
            start_time = time.monotonic()
            async with httpx.AsyncClient(
                timeout=timeout, verify=self._ssl_context, follow_redirects=False
            ) as client:
                response = await client.send(request)
            elapsed = time.monotonic() - start_time

            self._log_response(response, elapsed)
            self._check_response_size(response)

            parsed = self._parse_response(response, endpoint_config, elapsed)
            # May raise AuthExpiredError, which is terminal by design.
            self._check_auth_expired(parsed, api_config, effective_server)
            return parsed
        except ResponseTooLargeError:
            # Never retried: an oversized response will not shrink.
            raise
        except (httpx.TimeoutException, httpx.NetworkError, httpx.HTTPStatusError) as e:
            result = self._handle_retry_error(e, attempt + 1, endpoint_config)
            if result is not None:
                self._check_auth_expired(result, api_config, effective_server)
                return result
            last_error = e

    # Chain the last failure so its traceback is not lost to callers.
    raise RequestError(
        f"Request failed after {max_retries + 1} attempts: {last_error}",
        retryable=False,
    ) from last_error
def _parse_response(
    self,
    response: httpx.Response,
    endpoint_config: EndpointConfig,
    elapsed: float,
) -> RapiResponse:
    """Parse HTTP response into RapiResponse.

    The body is decoded as JSON only when the ``Content-Type`` header
    announces JSON (``application/json`` or a vendor type of the form
    ``application/vnd.*+json``). The header is compared
    case-insensitively. When decoding fails, ``data`` stays ``None``
    and the raw text remains available on the result.

    Args:
        response: Raw HTTP response.
        endpoint_config: Endpoint configuration; its ``full_ref`` is
            recorded on the result.
        elapsed: Request duration in seconds.

    Returns:
        Parsed RapiResponse.
    """
    text = response.text
    data: Any = None

    # Media type names are case-insensitive, so normalise before matching.
    content_type = response.headers.get("content-type", "").lower()
    announces_json = "application/json" in content_type or (
        "application/vnd." in content_type and "+json" in content_type
    )

    if announces_json:
        try:
            data = response.json()
        except ValueError:
            # Covers json.JSONDecodeError as well as UnicodeDecodeError
            # (body bytes that are not valid text); both are ValueError
            # subclasses and both mean "not usable JSON".
            log.debug("Response is not valid JSON despite content-type")

    return RapiResponse(
        status_code=response.status_code,
        headers=dict(response.headers),
        data=data,
        text=text,
        elapsed=elapsed,
        endpoint_ref=endpoint_config.full_ref,
    )
def _check_auth_expired(
self,
response: RapiResponse,
api_config: ApiConfig,
effective_server: ServerConfig | None,
) -> None:
"""Detect HTTP 401 indicating access token expiration and raise.
Inspects an already-parsed response and, when the heuristic
matches, raises :class:`AuthExpiredError` with a contextual
hint derived from the credentials currently active for this
request. Non-401 responses and 401 responses without
expiration markers are left untouched (backward compat
preserved : the caller receives a ``RapiResponse(ok=False)``
for non-expiration 401s).
Heuristic (any match triggers detection) :
- response body (case-insensitive) contains one of the
keywords listed in ``_AUTH_EXPIRED_BODY_KEYWORDS``
- response ``WWW-Authenticate`` header (case-insensitive)
contains ``invalid_token`` (RFC 6750 Bearer
error="invalid_token")
Args:
response: Parsed ``RapiResponse`` to inspect.
api_config: Resolved API configuration used to send the
request.
effective_server: Optional resolved server profile
providing the effective credentials for this call.
``None`` when falling back to the static
``api_config``.
Raises:
AuthExpiredError: If the response signals access token
expiration via body keywords or ``WWW-Authenticate``
header. The exception carries a sanitized
``token_source`` label and a contextual
``suggested_action`` hint.
"""
if response.status_code != 401:
return
body_lower = (response.text or "").lower()
has_body_keyword = any(kw in body_lower for kw in _AUTH_EXPIRED_BODY_KEYWORDS)
www_auth_value = response.headers.get("www-authenticate") or response.headers.get(
"WWW-Authenticate",
"",
)
has_header_marker = _AUTH_EXPIRED_HEADER_MARKER in www_auth_value.lower()
if not (has_body_keyword or has_header_marker):
# Backward compat: 401 without expiration markers returns
# a RapiResponse(ok=False) as before, not AuthExpiredError.
return
cred, credential_name = self._resolve_active_credential(api_config, effective_server)
token_source = self._derive_token_source(cred, credential_name, api_config, effective_server)
suggested_action = self._auth_renew_hint(cred, credential_name, api_config, effective_server)
# SECURITY (rules/code-rules.md Section 10) : tagged WARNING
# before raise, sanitized payload only (status code, content
# type, credential source label). NEVER log token value,
# response body, or Authorization header.
log.warning(
"[SECURITY] AuthExpired detected on endpoint %s (status=%d, content-type=%s, cred.source=%s)",
response.endpoint_ref,
response.status_code,
response.headers.get("content-type", "unknown"),
cred.source if cred else "unknown",
)
# User-facing ERROR (visible to apps consuming kstlib as a
# library) with the actionable re-authentication hint. The
# [SECURITY] WARNING above stays for audit/observability;
# this log targets human operators who need a clear next
# step. Sanitization rules apply equally: hint references
# source labels only (file path / env var name / SOPS path
# / provider name), never the token value.
log.error(
"Access token expired on endpoint '%s'. %s",
response.endpoint_ref,
suggested_action or "Please re-authenticate via your usual channel.",
)
raise AuthExpiredError(
f"Access token expired or invalidated (HTTP 401) on endpoint '{response.endpoint_ref}'.",
token_source=token_source,
suggested_action=suggested_action,
)
def _resolve_active_credential(
self,
api_config: ApiConfig,
effective_server: ServerConfig | None,
) -> tuple[CredentialRecord | None, str | None]:
"""Re-derive the credential active for the current request.
Used by :meth:`_check_auth_expired` to source a contextual
hint without mutating the request flow. Resolution errors
are swallowed so 401 detection always raises a meaningful
``AuthExpiredError`` rather than being masked by a stale
credential resolution failure.
Returns:
Tuple ``(cred, credential_name)`` where ``cred`` is the
resolved record (or ``None`` on failure) and
``credential_name`` is the registered name (or a
``server.<profile>`` hint for inline server
credentials).
"""
try:
if effective_server and effective_server.credentials:
cred = self._credential_resolver.resolve_inline(
effective_server.credentials,
name_hint=f"server.{effective_server.name}",
)
return cred, f"server.{effective_server.name}"
if api_config.credentials:
cred = self._credential_resolver.resolve(api_config.credentials)
return cred, api_config.credentials
except Exception:
# Credential resolution may fail between send time and
# 401 detection (env var unset, file removed, etc.). Fall
# back to a generic hint rather than mask the expiration.
log.debug("Credential re-resolution failed during AuthExpired detection", exc_info=True)
return None, None
def _credential_config(
self,
credential_name: str | None,
effective_server: ServerConfig | None,
) -> Mapping[str, Any] | None:
"""Return the raw credential configuration dict for hint derivation.
Args:
credential_name: Registered name of the credential, or
``server.<profile>`` when the credential is declared
inline on a server profile.
effective_server: Optional resolved server profile.
Returns:
The credential config dict, or ``None`` when no source
is available (caller falls back to a generic hint).
"""
if effective_server and effective_server.credentials:
return effective_server.credentials
if credential_name:
cfg = self._credential_resolver._config.get(credential_name) # noqa: SLF001
if isinstance(cfg, dict):
return cfg
return None
def _derive_token_source(
self,
cred: CredentialRecord | None,
credential_name: str | None,
api_config: ApiConfig,
effective_server: ServerConfig | None,
) -> str | None:
"""Build a sanitized label identifying where the access token came from.
Used as the ``token_source`` attribute of
:class:`AuthExpiredError`. Returns ``None`` when no
credential context is available so callers do not surface a
misleading default.
Examples of generated labels :
- ``"~/.sas/credentials.json"`` (file-based, SAS Viya)
- ``"env:KSTLIB_TOKEN"`` (environment variable)
- ``"sops:secrets/viya.sops.json"`` (SOPS-encrypted)
- ``"provider:corporate"`` (kstlib.auth OIDC provider)
"""
del api_config # reserved for future per-API token source resolution
if cred is None:
return None
cred_config = self._credential_config(credential_name, effective_server)
fallback = f"{cred.source}:{credential_name}" if credential_name else None
if cred_config is None:
return fallback
builder = _TOKEN_SOURCE_BUILDERS.get(cred.source)
if builder is None:
return fallback
return builder(cred_config) or fallback
def _auth_renew_hint(
self,
cred: CredentialRecord | None,
credential_name: str | None,
api_config: ApiConfig,
effective_server: ServerConfig | None,
) -> str:
"""Build a contextual re-authentication hint for the user.
Used as the ``suggested_action`` attribute of
:class:`AuthExpiredError`. The hint is best-effort and
NEVER includes secret material : it only references
credential sources (file path, env var name, SOPS file
path, provider name).
"""
del api_config # reserved for future per-API hint customization
if cred is None:
return "Re-authenticate via your usual channel."
cred_config = self._credential_config(credential_name, effective_server) or {}
builder = _AUTH_RENEW_HINT_BUILDERS.get(cred.source)
if builder is None:
return "Re-authenticate via your usual channel."
return builder(cred_config)
[docs]
def call(
    endpoint_ref: str,
    *args: Any,
    body: Any = None,
    headers: Mapping[str, str] | None = None,
    confirm: str | None = None,
    server: str | None = None,
    **kwargs: Any,
) -> RapiResponse:
    """Perform a one-off synchronous API call.

    A throwaway ``RapiClient`` is instantiated with default arguments
    and every argument is forwarded unchanged to its ``call`` method.

    Args:
        endpoint_ref: Endpoint reference.
        *args: Positional path parameters.
        body: Request body.
        headers: Runtime headers.
        confirm: Confirmation string for dangerous endpoints with safeguard.
        server: Optional named server profile from ``rapi.servers``
            (see :meth:`RapiClient.call` for cascade rules).
        **kwargs: Keyword parameters.

    Returns:
        RapiResponse.

    Examples:
        >>> from kstlib.rapi import call  # doctest: +SKIP
        >>> response = call("httpbin.get_ip")  # doctest: +SKIP
        >>> response = call("github.repos-list", server="github")  # doctest: +SKIP
    """
    return RapiClient().call(
        endpoint_ref,
        *args,
        body=body,
        headers=headers,
        confirm=confirm,
        server=server,
        **kwargs,
    )
[docs]
async def call_async(
    endpoint_ref: str,
    *args: Any,
    body: Any = None,
    headers: Mapping[str, str] | None = None,
    confirm: str | None = None,
    server: str | None = None,
    **kwargs: Any,
) -> RapiResponse:
    """Perform a one-off asynchronous API call.

    A throwaway ``RapiClient`` is instantiated with default arguments
    and every argument is forwarded unchanged to its ``call_async``
    method, whose result is awaited.

    Args:
        endpoint_ref: Endpoint reference.
        *args: Positional path parameters.
        body: Request body.
        headers: Runtime headers.
        confirm: Confirmation string for dangerous endpoints with safeguard.
        server: Optional named server profile from ``rapi.servers``
            (see :meth:`RapiClient.call` for cascade rules).
        **kwargs: Keyword parameters.

    Returns:
        RapiResponse.

    Examples:
        >>> from kstlib.rapi import call_async  # doctest: +SKIP
        >>> response = await call_async("httpbin.get_ip")  # doctest: +SKIP
    """
    return await RapiClient().call_async(
        endpoint_ref,
        *args,
        body=body,
        headers=headers,
        confirm=confirm,
        server=server,
        **kwargs,
    )
# Names exported by ``from <this module> import *``: the client class,
# its response type, and the two one-shot convenience functions.
__all__ = [
"RapiClient",
"RapiResponse",
"call",
"call_async",
]