"""Configuration management for RAPI module.
This module handles loading and resolving endpoint configurations
from kstlib.conf.yml or external ``*.rapi.yml`` files.
Supports:
- Loading from kstlib.conf.yml (default)
- Loading from external YAML files (``*.rapi.yml``)
- Auto-discovery of ``*.rapi.yml`` files in current directory
- Include patterns in kstlib.conf.yml
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
from kstlib.logging import TRACE_LEVEL
from kstlib.rapi.exceptions import (
EndpointAmbiguousError,
EndpointCollisionError,
EndpointNotFoundError,
EnvVarError,
SafeguardMissingError,
ServerNotFoundError,
)
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
log = logging.getLogger(__name__)
def _log_trace(msg: str, *args: object) -> None:
    """Forward a message to the module logger at ``TRACE_LEVEL``.

    Args:
        msg: %-style format string.
        *args: Values for the format string, handed to the logger as-is.
    """
    log.log(TRACE_LEVEL, msg, *args)
# Pattern for path parameters: {param} or {0}, {1}.
# Group 1 captures the text between the braces: either an identifier
# (letters, digits, underscore; not starting with a digit) or a run of digits.
_PATH_PARAM_PATTERN = re.compile(r"\{([a-zA-Z_][a-zA-Z0-9_]*|\d+)\}")
# Substrings that indicate path traversal or injection in URL path parameters.
# NOTE(review): the bare ".." entry already matches everything the first two
# entries match, so any value containing two consecutive dots is rejected.
_DANGEROUS_PATH_SEGMENTS = ("../", "..\\", "..")
def _validate_path_param(name: str, value: str) -> None:
"""Validate a path parameter value for injection attacks.
Args:
name: Parameter name (for error messages).
value: Substituted value to check.
Raises:
ValueError: If the value contains path traversal or null bytes.
"""
if "\x00" in value:
raise ValueError(f"Null bytes not allowed in path parameter '{name}'")
if any(seg in value for seg in _DANGEROUS_PATH_SEGMENTS):
raise ValueError(f"Path traversal not allowed in parameter '{name}': {value!r}")
# Deep defense: allowed values for HMAC config (hardcoded limits).
# These are checked in HmacConfig.__post_init__.
_ALLOWED_HMAC_ALGORITHMS = frozenset({"sha256", "sha512"})
_ALLOWED_SIGNATURE_FORMATS = frozenset({"hex", "base64"})
_MAX_FIELD_NAME_LENGTH = 64 # Max length for field names (timestamp_field, etc.)
_MAX_HEADER_NAME_LENGTH = 128 # Max length for header names
# Deep defense: safeguard validation (checked in EndpointConfig.__post_init__).
# Allowed safeguard characters: letters, digits, "_", "-", whitespace, "{", "}", "/".
_MAX_SAFEGUARD_LENGTH = 128
_SAFEGUARD_PATTERN = re.compile(r"^[A-Za-z0-9_\-\s\{\}/]+$")
# Default HTTP methods that require safeguard (SafeguardConfig default).
_DEFAULT_SAFEGUARD_METHODS = frozenset({"DELETE", "PUT"})
# Deep defense: server profiles
_MAX_SERVERS = 20 # Max named server profiles
_MAX_SERVER_NAME_LENGTH = 64
# Server names start with a letter, followed by letters, digits, "_" or "-".
_SERVER_NAME_PATTERN = re.compile(r"^[a-zA-Z][a-zA-Z0-9_-]*$")
_ALLOWED_SERVER_KEYS = frozenset({"base_url", "credentials", "auth", "headers"})
_ALLOWED_AUTH_TYPES = frozenset({"bearer", "basic", "api_key", "hmac"})
_ALLOWED_URL_SCHEMES = frozenset({"http", "https"})
# Pattern for environment variable substitution: ${VAR} or ${VAR:-default}.
# Group 1 is the variable name; group 2 is the default text (None when the
# ":-" part is absent).
_ENV_VAR_PATTERN = re.compile(r"\$\{([a-zA-Z_][a-zA-Z0-9_]*)(?::-([^}]*))?\}")
def _expand_env_vars(value: str, source: str | None = None) -> str:
    """Substitute ``${VAR}`` / ``${VAR:-default}`` references in a string.

    Two forms are recognised:

    - ``${VAR}``: the variable must be set, otherwise ``EnvVarError`` is raised.
    - ``${VAR:-default}``: ``default`` is used when the variable is not set.

    A variable that is set to the empty string counts as set, so its empty
    value wins over the default.

    Args:
        value: String that may contain ``${VAR}`` references.
        source: Source file name, forwarded to ``EnvVarError``.

    Returns:
        ``value`` with every reference replaced.

    Raises:
        EnvVarError: If a referenced variable is unset and has no default.

    Examples:
        >>> import os
        >>> os.environ["TEST_VAR"] = "hello"
        >>> _expand_env_vars("${TEST_VAR} world")
        'hello world'
        >>> _expand_env_vars("${MISSING:-default}")
        'default'
    """
    import os

    def _substitute(found: re.Match[str]) -> str:
        var_name, fallback = found.group(1), found.group(2)
        # environ.get only falls back when the variable is absent.
        resolved = os.environ.get(var_name, fallback)
        if resolved is None:
            raise EnvVarError(var_name, source)
        return resolved

    return _ENV_VAR_PATTERN.sub(_substitute, value)
def _expand_env_vars_recursive(data: Any, source: str | None = None) -> Any:
"""Recursively expand environment variables in config data.
Applies ``_expand_env_vars`` to all string values in dicts and lists.
Args:
data: Configuration data (dict, list, or scalar).
source: Source file for error messages.
Returns:
Data with all environment variables expanded.
Examples:
>>> import os
>>> os.environ["HOST"] = "example.com"
>>> _expand_env_vars_recursive({"url": "https://${HOST}"})
{'url': 'https://example.com'}
"""
if isinstance(data, dict):
return {k: _expand_env_vars_recursive(v, source) for k, v in data.items()}
if isinstance(data, list):
return [_expand_env_vars_recursive(item, source) for item in data]
if isinstance(data, str):
return _expand_env_vars(data, source)
return data
@dataclass(frozen=True, slots=True)
class HmacConfig:
    """Settings describing how requests are HMAC-signed.

    All values are validated on construction against the module-level
    allow-lists and length limits.

    Attributes:
        algorithm: Hash algorithm (sha256, sha512).
        timestamp_field: Query param name for timestamp.
        nonce_field: Query param name for nonce (alternative to timestamp).
        signature_field: Query param name for signature.
        signature_format: Output format (hex, base64).
        key_header: Header name for API key.
        sign_body: If True, sign request body instead of query string.

    Examples:
        >>> config = HmacConfig(algorithm="sha512", signature_format="base64")
        >>> config.algorithm
        'sha512'
    """

    algorithm: str = "sha256"
    timestamp_field: str = "timestamp"
    nonce_field: str | None = None
    signature_field: str = "signature"
    signature_format: str = "hex"
    key_header: str | None = None
    sign_body: bool = False

    def __post_init__(self) -> None:
        """Validate HMAC config values (deep defense).

        Raises:
            ValueError: If the algorithm or signature format is not allowed,
                or if a field/header name exceeds its length limit.
        """
        if self.algorithm not in _ALLOWED_HMAC_ALGORITHMS:
            raise ValueError(f"Invalid HMAC algorithm: {self.algorithm!r}. Allowed: {sorted(_ALLOWED_HMAC_ALGORITHMS)}")
        if self.signature_format not in _ALLOWED_SIGNATURE_FORMATS:
            raise ValueError(
                f"Invalid signature format: {self.signature_format!r}. Allowed: {sorted(_ALLOWED_SIGNATURE_FORMATS)}"
            )

        # Mandatory names: always length-checked.
        for label in ("timestamp_field", "signature_field"):
            text = getattr(self, label)
            if len(text) > _MAX_FIELD_NAME_LENGTH:
                raise ValueError(f"{label} too long: {len(text)} > {_MAX_FIELD_NAME_LENGTH}")

        # Optional names: only length-checked when set to a non-empty value.
        optional_limits = (
            ("nonce_field", _MAX_FIELD_NAME_LENGTH),
            ("key_header", _MAX_HEADER_NAME_LENGTH),
        )
        for label, limit in optional_limits:
            text = getattr(self, label)
            if text and len(text) > limit:
                raise ValueError(f"{label} too long: {len(text)} > {limit}")
[docs]
@dataclass(frozen=True, slots=True)
class MultipartConfig:
    """Per-endpoint settings for multipart uploads.

    Attributes:
        field_name: Form field name for the file (default: "file").
        content_type: Override MIME type (auto-detected from filename if None).

    Examples:
        >>> config = MultipartConfig(field_name="dataFile")
        >>> config.field_name
        'dataFile'
    """

    field_name: str = "file"
    content_type: str | None = None

    def __post_init__(self) -> None:
        """Validate multipart config values.

        Raises:
            ValueError: If ``field_name`` is empty or longer than
                ``_MAX_FIELD_NAME_LENGTH``.
        """
        length_ok = bool(self.field_name) and len(self.field_name) <= _MAX_FIELD_NAME_LENGTH
        if not length_ok:
            raise ValueError(f"field_name must be 1-{_MAX_FIELD_NAME_LENGTH} chars, got: {self.field_name!r}")
@dataclass(frozen=True, slots=True)
class SafeguardConfig:
"""Global safeguard configuration for dangerous HTTP methods.
Configures which HTTP methods require a safeguard (confirmation string)
to be defined on endpoints. This is a safety mechanism to prevent
accidental calls to destructive endpoints.
Attributes:
required_methods: HTTP methods that must have a safeguard configured.
Examples:
>>> config = SafeguardConfig()
>>> "DELETE" in config.required_methods
True
>>> config = SafeguardConfig(required_methods=frozenset({"DELETE"}))
>>> "PUT" in config.required_methods
False
"""
required_methods: frozenset[str] = field(default_factory=lambda: _DEFAULT_SAFEGUARD_METHODS)
def _extract_credentials_from_rapi(
data: dict[str, Any],
api_name: str,
file_path: Path,
) -> tuple[str | None, dict[str, Any]]:
"""Extract credentials configuration from RAPI file data.
Args:
data: Parsed YAML data.
api_name: Name of the API.
file_path: Path to the file (for resolving relative paths).
Returns:
Tuple of (credentials_ref, credentials_config).
"""
credentials_config: dict[str, Any] = {}
credentials_ref: str | None = None
if "credentials" not in data:
return None, {}
cred_data = data["credentials"]
if isinstance(cred_data, dict):
# Inline credentials definition
credentials_ref = f"_rapi_{api_name}_cred"
# Resolve relative paths in credentials (expand ~ first)
if "path" in cred_data:
cred_path = Path(cred_data["path"]).expanduser()
if cred_path.is_absolute():
# Already absolute (or was ~ expanded to absolute)
cred_data["path"] = str(cred_path)
else:
# Relative path: resolve against file location
cred_data["path"] = str(file_path.parent / cred_data["path"])
credentials_config[credentials_ref] = cred_data
elif isinstance(cred_data, str):
# Reference to existing credential
credentials_ref = cred_data
return credentials_ref, credentials_config
def _extract_auth_config(
data: dict[str, Any],
) -> tuple[str | None, HmacConfig | None]:
"""Extract auth type and HMAC config from RAPI file data.
Args:
data: Parsed YAML data.
Returns:
Tuple of (auth_type, HmacConfig or None).
"""
if "auth" not in data:
return None, None
auth_data = data["auth"]
if isinstance(auth_data, str):
return auth_data, None
if not isinstance(auth_data, dict):
return None, None
auth_type = auth_data.get("type")
# Parse HMAC config if auth type is hmac
hmac_config: HmacConfig | None = None
if auth_type == "hmac":
hmac_config = HmacConfig(
algorithm=auth_data.get("algorithm", "sha256"),
timestamp_field=auth_data.get("timestamp_field", "timestamp"),
nonce_field=auth_data.get("nonce_field"),
signature_field=auth_data.get("signature_field", "signature"),
signature_format=auth_data.get("signature_format", "hex"),
key_header=auth_data.get("key_header"),
sign_body=auth_data.get("sign_body", False),
)
return auth_type, hmac_config
def _merge_with_defaults(data: dict[str, Any], defaults: dict[str, Any] | None) -> dict[str, Any]:
"""Merge file data with defaults (file wins on conflict).
Args:
data: Configuration data from the file.
defaults: Default values to apply.
Returns:
Merged configuration with file values taking precedence.
"""
if not defaults:
return data
# Start with defaults, then overlay file data
merged = dict(defaults)
for key, value in data.items():
if key == "headers" and isinstance(value, dict) and isinstance(merged.get("headers"), dict):
# Merge headers dicts (file headers override default headers)
merged["headers"] = {**merged["headers"], **value}
elif key == "credentials" and isinstance(value, dict) and isinstance(merged.get("credentials"), dict):
# Merge credentials dicts (file credentials override default credentials)
merged["credentials"] = {**merged["credentials"], **value}
else:
# File value takes precedence
merged[key] = value
return merged
def _parse_rapi_file(
    path: Path,
    defaults: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Parse a ``*.rapi.yml`` file into internal config format.

    Converts the simplified format:

    ```yaml
    name: github
    base_url: "https://api.github.com"
    credentials:
      type: sops
      path: "./tokens/github.sops.json"
    auth:
      type: bearer
    endpoints:
      user:
        path: "/user"
    ```

    Into the internal format:

    ```python
    {
        "api": {
            "github": {
                "base_url": "...",
                "credentials": "_rapi_github_cred",
                "auth_type": "bearer",
                "endpoints": {...}
            }
        }
    }
    ```

    With defaults support, a minimal file can inherit from kstlib.conf.yml:

    ```yaml
    name: github
    endpoints:
      user:
        path: "/user"
    ```

    Processing order: load YAML, merge with ``defaults`` (file wins), expand
    environment variables, then resolve ``include`` patterns relative to the
    file's directory. Endpoints from included files are merged into this
    file's endpoints and override entries of the same name.

    An ``endpoints:`` key with a null value is treated as an empty mapping.

    Args:
        path: Path to the ``*.rapi.yml`` file.
        defaults: Default values inherited from kstlib.conf.yml rapi.defaults section.

    Returns:
        Tuple of (api_config, credentials_config).

    Raises:
        TypeError: If file format is invalid.
        ValueError: If required fields are missing.
    """
    import yaml

    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise TypeError(f"Invalid RAPI config format in {path}: expected dict")

    # Merge with defaults first (file wins on conflict).
    data = _merge_with_defaults(data, defaults)
    # Expand after the merge so values coming from defaults are expanded too.
    data = _expand_env_vars_recursive(data, source=str(path))

    # Extract API name (or derive from filename).
    api_name = data.get("name")
    if not api_name:
        api_name = path.stem.replace(".rapi", "")
        _log_trace("API name not specified, derived from filename: %s", api_name)

    base_url = data.get("base_url")
    if not base_url:
        raise ValueError(f"Missing 'base_url' in {path}")

    # YAML "endpoints:" with no value parses as None; without this guard the
    # len() and update() calls below would fail with an unhelpful TypeError.
    endpoints = data.get("endpoints")
    if endpoints is None:
        endpoints = {}

    credentials_ref, credentials_config = _extract_credentials_from_rapi(data, api_name, path)
    auth_type, hmac_config = _extract_auth_config(data)

    api_config: dict[str, Any] = {
        "api": {
            api_name: {
                "base_url": base_url,
                "credentials": credentials_ref,
                "auth_type": auth_type,
                "hmac_config": hmac_config,
                "headers": data.get("headers", {}),
                "endpoints": endpoints,
                "server": data.get("server"),
            }
        }
    }
    _log_trace(
        "Parsed %s: api=%s, %d endpoints, credentials=%s",
        path.name,
        api_name,
        len(endpoints),
        "inline" if credentials_ref and credentials_ref.startswith("_rapi_") else credentials_ref,
    )

    # Handle nested includes (relative to this file).
    # NOTE(review): nothing here guards against a file that includes itself
    # (directly or through a glob); confirm whether that can occur.
    include_patterns = data.get("include")
    if include_patterns:
        included_endpoints, included_creds = _resolve_rapi_includes(include_patterns, path.parent, defaults)
        endpoints.update(included_endpoints)
        credentials_config.update(included_creds)
        _log_trace(
            "Merged %d endpoints from includes into %s",
            len(included_endpoints),
            api_name,
        )

    return api_config, credentials_config
def _resolve_rapi_includes(
patterns: list[str] | str,
base_dir: Path,
defaults: dict[str, Any] | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Resolve include patterns relative to a rapi file.
Args:
patterns: Include pattern(s) relative to base_dir.
base_dir: Directory containing the parent rapi file.
defaults: Default values to pass to included files.
Returns:
Tuple of (merged_endpoints, merged_credentials).
"""
if isinstance(patterns, str):
patterns = [patterns]
merged_endpoints: dict[str, Any] = {}
merged_credentials: dict[str, Any] = {}
for pattern in patterns:
# Resolve relative path (remove leading ./)
clean_pattern = pattern.removeprefix("./")
resolved_path = base_dir / clean_pattern
# Support glob patterns or single file
matches = (
list(base_dir.glob(clean_pattern))
if "*" in clean_pattern
else ([resolved_path] if resolved_path.exists() else [])
)
for file_path in matches:
if not file_path.exists():
log.warning("Include file not found: %s", file_path)
continue
_log_trace("Including nested file: %s", file_path.name)
api_config, creds = _parse_rapi_file(file_path, defaults=defaults)
# Extract endpoints from the included file (ignore API name)
for api_data in api_config.get("api", {}).values():
endpoints = api_data.get("endpoints", {})
merged_endpoints.update(endpoints)
merged_credentials.update(creds)
return merged_endpoints, merged_credentials
def _check_endpoint_collisions(
api_name: str,
existing_api: dict[str, Any],
api_data: dict[str, Any],
path: Path,
ctx: tuple[dict[str, list[str]], bool],
) -> None:
"""Check for endpoint collisions and warn or raise.
Args:
api_name: Name of the API.
existing_api: Existing API data from previous files.
api_data: New API data from current file.
path: Current file path.
ctx: Tuple of (endpoint_sources dict, strict flag).
"""
endpoint_sources, strict = ctx
existing_endpoints = set(existing_api.get("endpoints", {}).keys())
new_endpoints = set(api_data.get("endpoints", {}).keys())
collisions = existing_endpoints & new_endpoints
if collisions:
log.warning(
"API '%s': %d endpoint collision(s) in '%s' (already loaded from previous files). "
"Tip: check for overlapping includes or duplicate API names. "
"Use rapi.strict: true to raise an error instead.",
api_name,
len(collisions),
path.name,
)
for ep_name in collisions:
full_ref = f"{api_name}.{ep_name}"
first_source = (
endpoint_sources.get(full_ref, ["unknown"])[0] if full_ref in endpoint_sources else "earlier file"
)
if full_ref not in endpoint_sources:
endpoint_sources[full_ref] = []
endpoint_sources[full_ref].append(str(path))
if strict:
raise EndpointCollisionError(full_ref, endpoint_sources[full_ref])
log.warning(
" - %s (first: %s, now: %s, overwriting)",
full_ref,
first_source,
path.name,
)
def _merge_api_endpoints(
    api_name: str,
    existing_api: dict[str, Any],
    api_data: dict[str, Any],
    path: Path,
) -> None:
    """Fold previously loaded endpoints into the new API data.

    ``api_data["endpoints"]`` is replaced by a new dict holding the existing
    endpoints overlaid with the new ones (new definitions win). A warning
    is always logged, since reaching this function means the API name was
    declared more than once.

    Args:
        api_name: Name of the API.
        existing_api: Existing API data from previous files.
        api_data: New API data from current file (modified in place).
        path: Current file path (for logging).
    """
    previous = existing_api.get("endpoints", {})
    incoming = api_data.get("endpoints", {})

    combined = dict(previous)
    combined.update(incoming)
    api_data["endpoints"] = combined

    log.warning(
        "API '%s' redefined in '%s': merging %d existing + %d new = %d endpoints. "
        "This usually means two files declare name: '%s'. "
        "Consider using include: in a single root file instead.",
        api_name,
        path.name,
        len(previous),
        len(incoming),
        len(combined),
        api_name,
    )
def _track_endpoint_sources(
api_name: str,
api_data: dict[str, Any],
path: Path,
endpoint_sources: dict[str, list[str]],
) -> None:
"""Track endpoint sources for debugging.
Args:
api_name: Name of the API.
api_data: API data from current file.
path: Current file path.
endpoint_sources: Tracking dict for endpoint sources.
"""
for ep_name in api_data.get("endpoints", {}):
full_ref = f"{api_name}.{ep_name}"
if full_ref not in endpoint_sources:
endpoint_sources[full_ref] = []
if str(path) not in endpoint_sources[full_ref]:
endpoint_sources[full_ref].append(str(path))
[docs]
@dataclass(frozen=True, slots=True)
class EndpointConfig:
"""Configuration for a single API endpoint.
Attributes:
name: Endpoint name (e.g., "get_ip").
api_name: Parent API name (e.g., "httpbin").
path: URL path template (e.g., "/delay/{seconds}").
method: HTTP method (GET, POST, PUT, DELETE, PATCH).
query: Default query parameters.
headers: Endpoint-level headers (merged with service headers).
body_template: Default body template for POST/PUT.
auth: Whether to apply API-level authentication to this endpoint.
Set to False for public endpoints that don't require auth.
description: Human-readable description of the endpoint.
server: Optional named server profile this endpoint should use.
Resolved against ``rapi.servers`` at request time. Overrides
any ``server:`` directive set at the file level on the parent
``ApiConfig``. Validated at config-load time when
``rapi.servers`` is present.
Examples:
>>> config = EndpointConfig(
... name="get_ip",
... api_name="httpbin",
... path="/ip",
... method="GET",
... )
>>> config.full_ref
'httpbin.get_ip'
"""
name: str
api_name: str
path: str
method: str = "GET"
query: dict[str, str | None] = field(default_factory=dict)
headers: dict[str, str] = field(default_factory=dict)
body_template: dict[str, Any] | None = None
auth: bool = True
safeguard: str | None = None
description: str | None = None
multipart: MultipartConfig | None = None
server: str | None = None
@property
def is_multipart(self) -> bool:
"""Check if endpoint is configured for multipart/form-data upload."""
ct = self.headers.get("Content-Type", "")
return "multipart/form-data" in ct.lower()
[docs]
def __post_init__(self) -> None:
"""Validate safeguard field (deep defense)."""
if self.safeguard is not None:
if len(self.safeguard) > _MAX_SAFEGUARD_LENGTH:
raise ValueError(f"safeguard too long: {len(self.safeguard)} > {_MAX_SAFEGUARD_LENGTH}")
if not _SAFEGUARD_PATTERN.match(self.safeguard):
raise ValueError(
f"safeguard contains invalid characters: {self.safeguard!r}. "
f"Allowed: A-Z, a-z, 0-9, _, -, space, {'{}'}, /"
)
@property
def full_ref(self) -> str:
"""Return full reference: api_name.endpoint_name."""
return f"{self.api_name}.{self.name}"
[docs]
def build_path(self, *args: Any, **kwargs: Any) -> str:
"""Build path with positional and keyword arguments.
Args:
*args: Positional arguments for {0}, {1}, etc.
**kwargs: Keyword arguments for {name} placeholders.
Returns:
Formatted path string.
Raises:
ValueError: If required parameters are missing.
Examples:
>>> config = EndpointConfig(
... name="delay",
... api_name="httpbin",
... path="/delay/{seconds}",
... )
>>> config.build_path(seconds=5)
'/delay/5'
>>> config.build_path(5)
'/delay/5'
"""
path = self.path
# Find all placeholders
placeholders = _PATH_PARAM_PATTERN.findall(path)
for placeholder in placeholders:
if placeholder.isdigit():
# Positional: {0}, {1}
idx = int(placeholder)
if idx < len(args):
value = str(args[idx])
_validate_path_param(placeholder, value)
path = path.replace(f"{{{placeholder}}}", value)
else:
raise ValueError(f"Missing positional argument {idx} for path {self.path}")
elif placeholder in kwargs:
# Named: {name}
value = str(kwargs[placeholder])
_validate_path_param(placeholder, value)
path = path.replace(f"{{{placeholder}}}", value)
elif len(args) > 0:
# Try to use first positional arg for first named placeholder
value = str(args[0])
_validate_path_param(placeholder, value)
path = path.replace(f"{{{placeholder}}}", value)
args = args[1:]
else:
raise ValueError(f"Missing parameter '{placeholder}' for path {self.path}")
return path
[docs]
def build_safeguard(self, *args: Any, **kwargs: Any) -> str | None:
"""Build safeguard string with variable substitution.
Substitutes ``{param}`` placeholders in the safeguard string with
provided arguments, similar to ``build_path``.
Args:
*args: Positional arguments for {0}, {1}, etc.
**kwargs: Keyword arguments for {name} placeholders.
Returns:
Substituted safeguard string, or None if no safeguard configured.
Examples:
>>> config = EndpointConfig(
... name="delete",
... api_name="test",
... path="/users/{userId}",
... method="DELETE",
... safeguard="DELETE USER {userId}",
... )
>>> config.build_safeguard(userId="abc123")
'DELETE USER abc123'
"""
if self.safeguard is None:
return None
result = self.safeguard
placeholders = _PATH_PARAM_PATTERN.findall(result)
for placeholder in placeholders:
if placeholder.isdigit():
idx = int(placeholder)
if idx < len(args):
result = result.replace(f"{{{placeholder}}}", str(args[idx]))
elif placeholder in kwargs:
result = result.replace(f"{{{placeholder}}}", str(kwargs[placeholder]))
elif len(args) > 0:
result = result.replace(f"{{{placeholder}}}", str(args[0]))
args = args[1:]
return result
[docs]
@dataclass(frozen=True, slots=True)
class ApiConfig:
"""Configuration for an API service.
Attributes:
name: API service name (e.g., "httpbin").
base_url: Base URL for the API.
credentials: Name of credential config to use.
auth_type: Authentication type (bearer, basic, api_key, hmac).
hmac_config: HMAC signing configuration (required when auth_type is hmac).
headers: Service-level headers (applied to all endpoints).
endpoints: Dictionary of endpoint configurations.
server: Optional named server profile this API file should use by
default. Resolved against ``rapi.servers`` at request time.
Acts as a fallback when individual endpoints do not declare
their own ``server:`` directive. Validated at config-load
time when ``rapi.servers`` is present.
Examples:
>>> api = ApiConfig(
... name="httpbin",
... base_url="https://httpbin.org",
... endpoints={},
... )
"""
name: str
base_url: str
credentials: str | None = None
auth_type: str | None = None
hmac_config: HmacConfig | None = None
headers: dict[str, str] = field(default_factory=dict)
endpoints: dict[str, EndpointConfig] = field(default_factory=dict)
server: str | None = None
def _validate_server_name(name: str) -> None:
"""Validate a server profile name.
Args:
name: Server profile name to validate.
Raises:
ValueError: If name is invalid.
"""
if not name:
raise ValueError("Server profile name must not be empty")
if len(name) > _MAX_SERVER_NAME_LENGTH:
raise ValueError(f"Server profile name too long: {len(name)} > {_MAX_SERVER_NAME_LENGTH}")
if not _SERVER_NAME_PATTERN.match(name):
raise ValueError(f"Invalid server profile name: {name!r}. Must match {_SERVER_NAME_PATTERN.pattern}")
def _validate_server_profile(name: str, profile: dict[str, Any]) -> None:
    """Validate a server profile dict for allowed keys and value types.

    Checks, in order: no keys outside ``_ALLOWED_SERVER_KEYS``; ``base_url``
    (when present) passes ``_validate_base_url``; ``auth`` (when present) is
    one of ``_ALLOWED_AUTH_TYPES``; ``headers`` (when present) is a dict
    whose header names do not exceed ``_MAX_HEADER_NAME_LENGTH``.

    Args:
        name: Server profile name (for error messages).
        profile: Server profile configuration dict.

    Raises:
        ValueError: If profile contains invalid keys or values.
    """
    unknown_keys = set(profile) - _ALLOWED_SERVER_KEYS
    if unknown_keys:
        raise ValueError(
            f"Server profile '{name}' has unknown keys: {sorted(unknown_keys)}. Allowed: {sorted(_ALLOWED_SERVER_KEYS)}"
        )

    base_url = profile.get("base_url")
    if base_url is not None:
        _validate_base_url(base_url, context=f"servers.{name}")

    auth = profile.get("auth")
    # The isinstance test comes first: an unhashable value such as a dict or
    # list would make the frozenset membership test raise TypeError instead
    # of the documented ValueError.
    if auth is not None and (not isinstance(auth, str) or auth not in _ALLOWED_AUTH_TYPES):
        raise ValueError(
            f"Server profile '{name}' has invalid auth type: {auth!r}. Allowed: {sorted(_ALLOWED_AUTH_TYPES)}"
        )

    headers = profile.get("headers")
    if headers is not None:
        if not isinstance(headers, dict):
            raise ValueError(f"Server profile '{name}' headers must be a dict")
        for hdr_name in headers:
            if len(str(hdr_name)) > _MAX_HEADER_NAME_LENGTH:
                raise ValueError(
                    f"Server profile '{name}' header name too long: {len(str(hdr_name))} > {_MAX_HEADER_NAME_LENGTH}"
                )
def _validate_base_url(url: str, *, context: str = "config") -> None:
"""Validate a base URL for allowed schemes.
Args:
url: URL string to validate.
context: Context label for error messages.
Raises:
ValueError: If URL scheme is not http or https.
"""
from urllib.parse import urlparse
if not url or "${" in url:
return
parsed = urlparse(url)
if parsed.scheme and parsed.scheme not in _ALLOWED_URL_SCHEMES:
raise ValueError(f"Invalid URL scheme in {context}: {parsed.scheme!r}. Allowed: {sorted(_ALLOWED_URL_SCHEMES)}")
[docs]
@dataclass(frozen=True, slots=True)
class ServerConfig:
"""Resolved server profile (defaults merged with server overrides).
Created by :meth:`RapiConfigManager.resolve_server` after merging
``rapi.defaults`` with a named ``rapi.servers.<name>`` profile.
Attributes:
name: Server profile name (or ``"defaults"`` for the fallback).
base_url: Base URL for the server.
credentials: Credentials configuration dict.
auth: Authentication type string (bearer, basic, api_key, hmac).
headers: Merged headers dict.
Examples:
>>> cfg = ServerConfig(
... name="source",
... base_url="https://viya-source.example.com",
... credentials={"type": "file", "path": "~/.sas/creds.json"},
... auth="bearer",
... headers={"Accept": "application/json"},
... )
"""
name: str
base_url: str
credentials: dict[str, Any] = field(default_factory=dict)
auth: str | None = None
headers: dict[str, str] = field(default_factory=dict)
[docs]
class RapiConfigManager:
"""Manage RAPI configuration and endpoint resolution.
Loads API and endpoint configurations from kstlib.conf.yml and provides
resolution methods supporting both full references (api.endpoint) and
short references (endpoint only, auto-resolved if unique).
Supports loading from:
- kstlib.conf.yml (default)
- External ``*.rapi.yml`` files (via from_file/from_files)
- Auto-discovery of ``*.rapi.yml`` in current directory (via discover)
Args:
rapi_config: The 'rapi' section from configuration.
credentials_config: Inline credentials extracted from ``*.rapi.yml`` files.
Examples:
>>> manager = RapiConfigManager({"api": {"httpbin": {"base_url": "..."}}})
>>> endpoint = manager.resolve("httpbin.get_ip") # doctest: +SKIP
>>> manager = RapiConfigManager.from_file("github.rapi.yml") # doctest: +SKIP
>>> manager = RapiConfigManager.discover() # doctest: +SKIP
"""
[docs]
def __init__(
    self,
    rapi_config: Mapping[str, Any] | None = None,
    credentials_config: Mapping[str, Any] | None = None,
    safeguard_config: SafeguardConfig | None = None,
    strict: bool = False,
) -> None:
    """Store the raw configuration and build the API index.

    Args:
        rapi_config: The 'rapi' section from configuration.
        credentials_config: Inline credentials from ``*.rapi.yml`` files
            (copied into a new dict).
        safeguard_config: Safeguard configuration (default: DELETE and PUT require safeguard).
        strict: If True, raise error on endpoint collisions. If False, warn and overwrite.
    """
    # Inputs, with empty/default fallbacks for anything falsy.
    self._config = rapi_config if rapi_config else {}
    self._credentials_config = {} if not credentials_config else dict(credentials_config)
    self._safeguard_config = safeguard_config if safeguard_config else SafeguardConfig()
    self._strict = strict

    # Derived state; every attribute exists before _load_apis() runs.
    self._apis: dict[str, ApiConfig] = {}
    self._endpoint_index: dict[str, list[str]] = {}  # endpoint_name -> [api_names]
    self._endpoint_sources: dict[str, str] = {}  # full_ref -> source file
    self._source_files: list[Path] = []  # loaded files, kept for debugging
    self._defaults: dict[str, Any] = {}  # rapi.defaults section
    self._servers: dict[str, dict[str, Any]] = {}  # rapi.servers.* profiles

    self._load_apis()
[docs]
@classmethod
def from_file(
cls,
path: str | Path,
base_dir: Path | None = None,
safeguard_config: SafeguardConfig | None = None,
defaults: dict[str, Any] | None = None,
strict: bool = False,
) -> RapiConfigManager:
"""Load configuration from a single ``*.rapi.yml`` file.
The file format is simplified compared to kstlib.conf.yml,
with top-level keys: name, base_url, credentials, auth, headers, endpoints.
Args:
path: Path to the ``*.rapi.yml`` file.
base_dir: Base directory for resolving relative paths in credentials.
safeguard_config: Safeguard configuration (default: DELETE and PUT require safeguard).
defaults: Default values inherited from kstlib.conf.yml rapi.defaults section.
strict: If True, raise error on endpoint collisions. If False, warn and overwrite.
Returns:
Configured RapiConfigManager instance.
Raises:
FileNotFoundError: If file does not exist.
ValueError: If file format is invalid.
Examples:
>>> manager = RapiConfigManager.from_file("github.rapi.yml") # doctest: +SKIP
"""
return cls.from_files(
[path], base_dir=base_dir, safeguard_config=safeguard_config, defaults=defaults, strict=strict
)
[docs]
@classmethod
def from_files(
cls,
paths: Sequence[str | Path],
base_dir: Path | None = None,
safeguard_config: SafeguardConfig | None = None,
defaults: dict[str, Any] | None = None,
strict: bool = False,
) -> RapiConfigManager:
"""Load configuration from multiple ``*.rapi.yml`` files.
Args:
paths: List of paths to ``*.rapi.yml`` files.
base_dir: Base directory for resolving relative paths.
safeguard_config: Safeguard configuration (default: DELETE and PUT require safeguard).
defaults: Default values inherited from kstlib.conf.yml rapi.defaults section.
Supports: base_url, credentials, auth, headers.
strict: If True, raise error on endpoint collisions. If False, warn and overwrite.
Returns:
Configured RapiConfigManager instance with merged configs.
Raises:
FileNotFoundError: If any file does not exist.
ValueError: If any file format is invalid.
EndpointCollisionError: If strict=True and endpoints collide.
Examples:
>>> manager = RapiConfigManager.from_files([
... "github.rapi.yml",
... "slack.rapi.yml",
... ]) # doctest: +SKIP
"""
merged_api_config: dict[str, Any] = {"api": {}}
merged_credentials: dict[str, Any] = {}
source_files: list[Path] = []
# Track endpoint sources: full_ref -> [source_files]
endpoint_sources: dict[str, list[str]] = {}
for file_path in paths:
path = Path(file_path)
if not path.is_absolute() and base_dir:
path = base_dir / path
if not path.exists():
raise FileNotFoundError(f"RAPI config file not found: {path}")
_log_trace("Loading RAPI config from: %s", path)
api_config, credentials = _parse_rapi_file(path, defaults=defaults)
# Merge API config with collision detection
collision_ctx = (endpoint_sources, strict)
for api_name, api_data in api_config.get("api", {}).items():
existing_api = merged_api_config["api"].get(api_name)
if existing_api:
_check_endpoint_collisions(api_name, existing_api, api_data, path, collision_ctx)
_merge_api_endpoints(api_name, existing_api, api_data, path)
_track_endpoint_sources(api_name, api_data, path, endpoint_sources)
merged_api_config["api"][api_name] = api_data
# Merge credentials
merged_credentials.update(credentials)
source_files.append(path)
manager = cls(merged_api_config, merged_credentials, safeguard_config, strict=strict)
manager._source_files = source_files
# Store endpoint sources for debugging
for full_ref, sources in endpoint_sources.items():
if sources:
manager._endpoint_sources[full_ref] = sources[0]
# Validate any server: directives in the loaded files. With no
# rapi.servers context here, references produce warnings only.
# When called via load_rapi_config, the parent manager re-validates
# against the merged servers section.
manager._validate_server_references()
return manager
[docs]
@classmethod
def discover(
    cls,
    directory: str | Path | None = None,
    pattern: str = "*.rapi.yml",
) -> RapiConfigManager:
    """Auto-discover and load ``*.rapi.yml`` files from a directory.

    Searches for files matching the pattern in the specified directory
    (defaults to the current working directory).

    The search directory is made absolute before globbing. ``from_files``
    prefixes ``base_dir`` to every relative path it receives, and glob
    results for a relative directory already contain that directory, so
    a call such as ``discover("./apis/")`` would otherwise look for
    ``apis/apis/<file>``. Matches are loaded in sorted order so that the
    merge and collision handling in ``from_files`` is reproducible.

    Args:
        directory: Directory to search in (default: current directory).
        pattern: Glob pattern for files (default: ``*.rapi.yml``).

    Returns:
        Configured RapiConfigManager instance.

    Raises:
        FileNotFoundError: If the directory does not exist or no
            matching files are found.

    Examples:
        >>> manager = RapiConfigManager.discover()  # doctest: +SKIP
        >>> manager = RapiConfigManager.discover("./apis/")  # doctest: +SKIP
    """
    search_dir = (Path(directory) if directory else Path.cwd()).resolve()
    if not search_dir.exists():
        raise FileNotFoundError(f"Directory not found: {search_dir}")

    # Sorted so the load order does not depend on filesystem enumeration.
    files = sorted(search_dir.glob(pattern))
    if not files:
        raise FileNotFoundError(f"No RAPI config files found matching '{pattern}' in {search_dir}")

    log.info("Discovered %d RAPI config file(s) in %s", len(files), search_dir)
    for f in files:
        _log_trace(" - %s", f.name)

    # Paths are absolute here, so base_dir is not prepended a second time.
    return cls.from_files(files, base_dir=search_dir)
@property
def credentials_config(self) -> dict[str, Any]:
    """Expose the inline credentials extracted from ``*.rapi.yml`` files.

    Returns:
        The manager's credentials mapping (the stored object, not a copy).
    """
    inline_credentials = self._credentials_config
    return inline_credentials
@property
def source_files(self) -> list[Path]:
    """Expose the configuration files this manager was loaded from.

    Returns:
        The stored list of Path objects (not a copy).
    """
    loaded_from = self._source_files
    return loaded_from
@property
def safeguard_config(self) -> SafeguardConfig:
    """Expose the safeguard configuration held by this manager.

    Returns:
        The stored SafeguardConfig instance.
    """
    active_safeguard = self._safeguard_config
    return active_safeguard
[docs]
def resolve_server(self, server_name: str | None = None) -> ServerConfig:
    """Build the ServerConfig for a named profile layered over the defaults.

    With ``server_name=None`` the result reflects ``rapi.defaults`` only
    and is named ``"defaults"``. Otherwise ``rapi.servers.<name>`` is
    deep-merged on top of a copy of the defaults (server values win).

    Args:
        server_name: Named server profile, or None for defaults.

    Returns:
        Resolved ServerConfig with merged values.

    Raises:
        ServerNotFoundError: If server_name is not in rapi.servers.

    Examples:
        >>> manager = load_rapi_config()  # doctest: +SKIP
        >>> server = manager.resolve_server("source")  # doctest: +SKIP
        >>> server.base_url  # doctest: +SKIP
        'https://viya-source.example.com'
    """
    import copy

    from kstlib.utils.dict import deep_merge

    # Work on copies so neither the stored defaults nor the stored
    # profile is handed to the merge or to the returned config.
    profile = copy.deepcopy(self._defaults)
    if server_name is None:
        resolved_name = "defaults"
    elif server_name in self._servers:
        profile = deep_merge(profile, copy.deepcopy(self._servers[server_name]))
        resolved_name = server_name
    else:
        raise ServerNotFoundError(server_name, available=list(self._servers))

    return ServerConfig(
        name=resolved_name,
        base_url=profile.get("base_url", ""),
        credentials=profile.get("credentials", {}),
        auth=profile.get("auth"),
        headers=profile.get("headers", {}),
    )
@property
def server_names(self) -> list[str]:
    """List the names of the configured server profiles.

    Returns:
        A new list holding the keys of the stored server profiles.
    """
    return [*self._servers]
[docs]
def resolve_effective_server(
    self,
    api_config: ApiConfig,
    endpoint_config: EndpointConfig,
    runtime_server: str | None = None,
) -> ServerConfig | None:
    """Resolve the effective server profile for a given request.

    Cascade priority (highest to lowest):

    1. ``runtime_server`` (e.g. CLI ``--server`` flag or ``call(server=...)``)
    2. ``endpoint_config.server`` (endpoint-level ``server:`` directive)
    3. ``api_config.server`` (file-level ``server:`` directive)
    4. None (caller falls back to the static ``api_config``)

    A level counts as "not set" when its value is falsy (``None`` or an
    empty string), at every level of the cascade including the last one.

    Args:
        api_config: API configuration for the called endpoint.
        endpoint_config: Endpoint configuration for the called endpoint.
        runtime_server: Optional runtime override (CLI flag or kwarg).

    Returns:
        Resolved ServerConfig if any cascade level provided a name,
        otherwise None (caller should use the static ApiConfig).

    Raises:
        ServerNotFoundError: If the resolved name does not exist in
            ``rapi.servers``.
    """
    name = runtime_server or endpoint_config.server or api_config.server
    # ``or`` already skips empty strings on the higher levels; test for
    # falsiness here too so an empty value on the last level is not
    # forwarded to resolve_server("") and reported as an unknown profile.
    if not name:
        return None
    return self.resolve_server(name)
def _validate_server_references(self) -> None:
    """Validate ``server:`` references in loaded ApiConfig/EndpointConfig.

    Walks every loaded API and endpoint and checks every non-None
    ``server:`` value against ``self._servers``:

    - If ``self._servers`` is empty: log a warning per reference (the
      user may add a ``servers:`` section later, this is not fatal).
    - If ``self._servers`` is non-empty but the name is unknown:
      raise :class:`ServerNotFoundError` (strict, fail at load time).

    Called by :meth:`from_files` on the manager it builds, and again by
    :func:`load_rapi_config` after inline APIs and included files have
    been merged and the server profiles stored.

    Raises:
        ServerNotFoundError: If a ``server:`` directive references an
            unknown server profile while ``rapi.servers`` is configured.
    """
    known = set(self._servers)

    def _check(label: str, server_name: str) -> None:
        if not known:
            # Permissive warning: no servers section, may be added later
            log.warning(
                "%s declares server: %r but rapi.servers section is absent. "
                "Add a servers: block to kstlib.conf.yml or remove the directive.",
                label,
                server_name,
            )
            return
        if server_name not in known:
            # Strict error: servers section is present but name is unknown
            raise ServerNotFoundError(server_name, available=sorted(known))

    for api_name, api_config in self._apis.items():
        if api_config.server is not None:
            _check(f"API '{api_name}' (file-level)", api_config.server)
        for ep_name, endpoint in api_config.endpoints.items():
            if endpoint.server is not None:
                _check(
                    f"Endpoint '{api_name}.{ep_name}' (endpoint-level)",
                    endpoint.server,
                )
def _load_apis(self) -> None:
    """Load API configurations from config.

    Reads the ``api`` section of ``self._config``, builds one ApiConfig
    (with its EndpointConfig entries) per valid API and registers every
    endpoint name in ``self._endpoint_index`` for short-reference lookup.

    Entries that are not mappings, and APIs without a ``base_url``, are
    skipped with a warning. Optional keys that are present but empty in
    YAML (which load as ``None``) are treated like absent keys.

    Raises:
        SafeguardMissingError: If an endpoint uses a method listed in the
            safeguard's ``required_methods`` without a ``safeguard`` value.
    """
    # ``or {}`` rather than a .get() default: a key written without a
    # value is present with value None, which the default does not cover.
    api_section = self._config.get("api") or {}
    for api_name, api_data in api_section.items():
        if not isinstance(api_data, dict):
            log.warning("Skipping invalid API config: %s", api_name)
            continue

        base_url = api_data.get("base_url", "")
        if not base_url:
            log.warning("API '%s' missing base_url, skipping", api_name)
            continue

        # Parse endpoints
        endpoints: dict[str, EndpointConfig] = {}
        endpoints_data = api_data.get("endpoints") or {}
        for ep_name, ep_data in endpoints_data.items():
            if not isinstance(ep_data, dict):
                log.warning("Skipping invalid endpoint: %s.%s", api_name, ep_name)
                continue

            method = (ep_data.get("method") or "GET").upper()
            safeguard = ep_data.get("safeguard")

            # Parse optional multipart config: a mapping customises it,
            # a bare ``true`` selects the MultipartConfig defaults.
            multipart_data = ep_data.get("multipart")
            multipart_config: MultipartConfig | None = None
            if isinstance(multipart_data, dict):
                multipart_config = MultipartConfig(
                    field_name=multipart_data.get("field_name", "file"),
                    content_type=multipart_data.get("content_type"),
                )
            elif multipart_data is True:
                multipart_config = MultipartConfig()

            endpoint = EndpointConfig(
                name=ep_name,
                api_name=api_name,
                path=ep_data.get("path", f"/{ep_name}"),
                method=method,
                query=dict(ep_data.get("query") or {}),
                headers=dict(ep_data.get("headers") or {}),
                body_template=ep_data.get("body"),
                auth=ep_data.get("auth", True),
                safeguard=safeguard,
                description=ep_data.get("description"),
                multipart=multipart_config,
                server=ep_data.get("server"),
            )

            # Validate safeguard requirement
            if method in self._safeguard_config.required_methods and safeguard is None:
                raise SafeguardMissingError(endpoint.full_ref, method)

            endpoints[ep_name] = endpoint

            # Index for short reference lookup
            self._endpoint_index.setdefault(ep_name, []).append(api_name)
            _log_trace("Loaded endpoint: %s.%s", api_name, ep_name)

        # Create API config
        api_config = ApiConfig(
            name=api_name,
            base_url=base_url.rstrip("/"),
            credentials=api_data.get("credentials"),
            auth_type=api_data.get("auth_type"),
            hmac_config=api_data.get("hmac_config"),
            headers=dict(api_data.get("headers") or {}),
            endpoints=endpoints,
            server=api_data.get("server"),
        )
        self._apis[api_name] = api_config
        log.debug("Loaded API: %s (%d endpoints)", api_name, len(endpoints))
def _merge_apis(
    self,
    other: RapiConfigManager,
    *,
    overwrite: bool = False,
) -> None:
    """Merge APIs from another manager into this one.

    For an API name that already exists here:

    - ``overwrite=False``: the existing API is kept and the conflict is
      delegated to ``_handle_api_conflict``.
    - ``overwrite=True``: the existing API is replaced. Its endpoints are
      first removed from the short-reference index and from the source
      tracking, so no entry is left pointing at an endpoint the
      replacement does not define.

    Credentials from ``other`` are added only under names not already
    present here, regardless of ``overwrite``.

    Args:
        other: Source manager to merge from.
        overwrite: If True, overwrite existing APIs. If False, skip conflicts.
    """
    for api_name, api_config in other.apis.items():
        previous = self._apis.get(api_name)
        if previous is not None:
            if not overwrite:
                self._handle_api_conflict(api_name, api_config, other)
                continue
            # Drop index/source entries of the API being replaced; the
            # calls below re-register whatever the replacement defines.
            for stale_name in previous.endpoints:
                owners = self._endpoint_index.get(stale_name)
                if owners and api_name in owners:
                    owners.remove(api_name)
                    if not owners:
                        del self._endpoint_index[stale_name]
                self._endpoint_sources.pop(f"{api_name}.{stale_name}", None)

        self._apis[api_name] = api_config
        self._update_endpoint_index(api_name, api_config)
        self._copy_endpoint_sources(api_name, api_config, other)

    # Merge credentials
    for cred_name, cred_config in other.credentials_config.items():
        if cred_name not in self._credentials_config:
            self._credentials_config[cred_name] = cred_config
def _handle_api_conflict(
    self,
    api_name: str,
    api_config: ApiConfig,
    other: RapiConfigManager,
) -> None:
    """Handle API name conflict during merge.

    The existing (inline) API is always kept. A warning is logged that
    either lists the colliding endpoint references or, when the two
    definitions share no endpoint name, points out the duplicate API.

    In strict mode, the first colliding endpoint in sorted order is
    reported through EndpointCollisionError, so the error does not depend
    on set iteration order.

    Args:
        api_name: Name of the conflicting API.
        api_config: The incoming API config.
        other: Source manager to merge from.

    Raises:
        EndpointCollisionError: If strict mode is enabled and at least
            one endpoint name is defined on both sides.
    """
    existing_endpoints = set(self._apis[api_name].endpoints.keys())
    new_endpoints = set(api_config.endpoints.keys())
    collisions = sorted(existing_endpoints & new_endpoints)

    if not collisions:
        log.warning(
            "API '%s' defined in both inline config and included files (no endpoint overlap, keeping inline). "
            "Consider removing one definition to avoid confusion.",
            api_name,
        )
        return

    log.warning(
        "API '%s': %d endpoint(s) from included files conflict with inline config (keeping inline). "
        "Colliding: %s",
        api_name,
        len(collisions),
        ", ".join(f"{api_name}.{ep}" for ep in collisions),
    )

    if self._strict:
        full_ref = f"{api_name}.{collisions[0]}"
        sources = ["inline config"]
        if full_ref in other._endpoint_sources:
            sources.append(other._endpoint_sources[full_ref])
        raise EndpointCollisionError(full_ref, sources)
def _update_endpoint_index(self, api_name: str, api_config: ApiConfig) -> None:
    """Register an API's endpoint names in the short-reference index.

    Each endpoint name maps to the list of API names defining it; an API
    name is appended only if it is not already listed.

    Args:
        api_name: Name of the API.
        api_config: The API configuration.
    """
    index = self._endpoint_index
    for endpoint_name in api_config.endpoints:
        owners = index.setdefault(endpoint_name, [])
        if api_name not in owners:
            owners.append(api_name)
def _copy_endpoint_sources(
    self,
    api_name: str,
    api_config: ApiConfig,
    other: RapiConfigManager,
) -> None:
    """Carry over the recorded source of each endpoint of an API.

    Only endpoints for which ``other`` has a recorded source are copied;
    an existing entry here for the same reference is replaced.

    Args:
        api_name: Name of the API.
        api_config: The API configuration.
        other: Source manager to copy from.
    """
    recorded = other._endpoint_sources
    for endpoint_name in api_config.endpoints:
        reference = f"{api_name}.{endpoint_name}"
        if reference in recorded:
            self._endpoint_sources[reference] = recorded[reference]
[docs]
def resolve(self, endpoint_ref: str) -> tuple[ApiConfig, EndpointConfig]:
    """Look up the configuration behind an endpoint reference.

    A reference containing a dot is treated as a full reference
    (``api.endpoint``); anything else is a short reference (endpoint
    name only), which resolves when the name is unique across all APIs.

    Args:
        endpoint_ref: Full reference (api.endpoint) or short (endpoint).

    Returns:
        Tuple of (ApiConfig, EndpointConfig).

    Raises:
        EndpointNotFoundError: If endpoint cannot be found.
        EndpointAmbiguousError: If short reference matches multiple APIs.

    Examples:
        >>> manager = RapiConfigManager({...})  # doctest: +SKIP
        >>> api, endpoint = manager.resolve("httpbin.get_ip")  # doctest: +SKIP
        >>> api, endpoint = manager.resolve("get_ip")  # doctest: +SKIP
    """
    _log_trace("Resolving endpoint reference: %s", endpoint_ref)
    is_full_reference = "." in endpoint_ref
    resolver = self._resolve_full if is_full_reference else self._resolve_short
    return resolver(endpoint_ref)
def _resolve_full(self, endpoint_ref: str) -> tuple[ApiConfig, EndpointConfig]:
    """Resolve a full ``api.endpoint`` reference.

    The reference is split on its first dot: the left part names the
    API, everything after it names the endpoint.

    Raises:
        EndpointNotFoundError: If the reference has no dot, the API is
            unknown, or the API does not define the endpoint.
    """
    api_name, separator, endpoint_name = endpoint_ref.partition(".")
    if not separator:
        raise EndpointNotFoundError(endpoint_ref, list(self._apis))

    if api_name not in self._apis:
        raise EndpointNotFoundError(
            endpoint_ref,
            list(self._apis),
        )

    owning_api = self._apis[api_name]
    if endpoint_name not in owning_api.endpoints:
        raise EndpointNotFoundError(
            endpoint_ref,
            [api_name],
        )

    found = owning_api.endpoints[endpoint_name]
    _log_trace("Resolved full reference: %s", found.full_ref)
    return owning_api, found
def _resolve_short(self, endpoint_name: str) -> tuple[ApiConfig, EndpointConfig]:
    """Resolve a short reference (endpoint name only).

    Succeeds when exactly one API is indexed for the endpoint name.

    Raises:
        EndpointNotFoundError: If the name is not in the endpoint index.
        EndpointAmbiguousError: If more than one API defines the name.
    """
    index = self._endpoint_index
    if endpoint_name not in index:
        raise EndpointNotFoundError(endpoint_name, list(self._apis))

    candidates = index[endpoint_name]
    if len(candidates) > 1:
        raise EndpointAmbiguousError(endpoint_name, candidates)

    owning_api = self._apis[candidates[0]]
    found = owning_api.endpoints[endpoint_name]
    _log_trace(
        "Resolved short reference '%s' to '%s'",
        endpoint_name,
        found.full_ref,
    )
    return owning_api, found
[docs]
def get_api(self, api_name: str) -> ApiConfig | None:
    """Fetch the configuration of a single API.

    Args:
        api_name: API service name.

    Returns:
        The matching ApiConfig, or None when no API has that name.
    """
    registry = self._apis
    return registry.get(api_name)
[docs]
def list_apis(self) -> list[str]:
    """Enumerate the names of all configured APIs.

    Returns:
        A new list of API names, in the order they are stored.
    """
    return [*self._apis]
@property
def apis(self) -> dict[str, ApiConfig]:
    """Expose every configured API.

    Returns:
        The stored mapping of API names to ApiConfig objects (not a copy).
    """
    configured = self._apis
    return configured
[docs]
def list_endpoints(self, api_name: str | None = None) -> list[str]:
    """List full endpoint references (``api.endpoint``).

    Args:
        api_name: Filter by API name (optional). A falsy value lists
            the endpoints of every API.

    Returns:
        List of full endpoint references; empty when ``api_name`` is
        given but does not match a configured API.
    """
    if not api_name:
        # All endpoints, grouped by API in storage order.
        return [f"{api.name}.{ep}" for api in self._apis.values() for ep in api.endpoints]

    selected = self._apis.get(api_name)
    if not selected:
        return []
    return [f"{api_name}.{ep}" for ep in selected.endpoints]
def _parse_safeguard_config(rapi_section: dict[str, Any]) -> SafeguardConfig:
    """Parse safeguard configuration from rapi section.

    ``safeguard.required_methods`` may be a list of HTTP method names or
    a single method name written as a scalar (``required_methods: DELETE``).
    Method names are upper-cased. When the ``safeguard`` section or its
    ``required_methods`` key is missing or empty-valued, a default
    SafeguardConfig is returned.

    Args:
        rapi_section: The 'rapi' section from configuration.

    Returns:
        SafeguardConfig instance.
    """
    safeguard_data = rapi_section.get("safeguard", {})
    if not safeguard_data:
        return SafeguardConfig()

    required_methods = safeguard_data.get("required_methods")
    if required_methods is None:
        return SafeguardConfig()

    # A bare string is itself iterable: without this, "DELETE" would be
    # split into single letters and the safeguard would never match.
    if isinstance(required_methods, str):
        required_methods = [required_methods]

    # Convert list to frozenset, uppercase all methods
    methods = frozenset(m.upper() for m in required_methods)
    return SafeguardConfig(required_methods=methods)
def _validate_defaults_section(defaults: dict[str, Any]) -> None:
    """Check the values of the rapi.defaults section.

    A non-empty string ``base_url`` is passed to ``_validate_base_url``;
    a truthy ``auth`` must be one of ``_ALLOWED_AUTH_TYPES``.

    Args:
        defaults: The defaults dict from rapi config.

    Raises:
        ValueError: If defaults contain invalid values.
    """
    base_url = defaults.get("base_url")
    if base_url and isinstance(base_url, str):
        _validate_base_url(base_url, context="rapi.defaults")

    auth_type = defaults.get("auth")
    if not auth_type or auth_type in _ALLOWED_AUTH_TYPES:
        return
    raise ValueError(
        f"Invalid auth type in rapi.defaults: {auth_type!r}. Allowed: {sorted(_ALLOWED_AUTH_TYPES)}"
    )
def _parse_servers_section(
    servers_raw: dict[str, Any] | None,
) -> dict[str, dict[str, Any]]:
    """Turn the raw rapi.servers section into validated server profiles.

    Profiles that are not mappings are skipped with a warning. Every
    other profile has its name validated, its values env-expanded and
    the expanded result validated before it is kept.

    Args:
        servers_raw: Raw servers dict from config (may be None).

    Returns:
        Validated and env-expanded server profiles dict; empty when the
        section is missing, empty, or not a dict.

    Raises:
        ValueError: If servers section exceeds limits or contains invalid data.
    """
    if not (servers_raw and isinstance(servers_raw, dict)):
        return {}

    profile_count = len(servers_raw)
    if profile_count > _MAX_SERVERS:
        raise ValueError(f"Too many server profiles: {profile_count} > {_MAX_SERVERS}")

    profiles: dict[str, dict[str, Any]] = {}
    for name, raw_profile in servers_raw.items():
        if isinstance(raw_profile, dict):
            _validate_server_name(name)
            resolved = _expand_env_vars_recursive(raw_profile)
            _validate_server_profile(name, resolved)
            profiles[name] = resolved
        else:
            log.warning("Skipping invalid server profile: %s", name)

    if profiles:
        log.debug("Found %d server profile(s): %s", len(profiles), list(profiles))
    return profiles
[docs]
def load_rapi_config() -> RapiConfigManager:
    """Load RAPI configuration from kstlib.conf.yml with include support.

    Supports including external ``*.rapi.yml`` files via glob patterns,
    and a ``defaults`` section that is inherited by included files:

    .. code-block:: yaml

        rapi:
          # Strict mode: error on endpoint collisions (default: false = warn only)
          strict: true
          # Defaults inherited by all included *.rapi.yml files
          defaults:
            base_url: "https://${VIYA_HOST}"
            credentials:
              type: file
              path: ~/.sas/credentials.json
              token_path: ".Default['access-token']"
            auth: bearer
            headers:
              Accept: application/json
          include:
            - "./apis/*.rapi.yml"
            - "~/.config/kstlib/*.rapi.yml"
          safeguard:
            required_methods:
              - DELETE
          api:
            httpbin:
              base_url: "https://httpbin.org"
              # ...

    With defaults, included files can be minimal:

    .. code-block:: yaml

        # annotations.rapi.yml
        name: annotations
        headers:
          Accept: application/vnd.sas.annotation+json
        endpoints:
          root:
            path: /annotations/
            method: GET

    Returns:
        Configured RapiConfigManager instance with merged configs.

    Examples:
        >>> manager = load_rapi_config()  # doctest: +SKIP
    """
    from kstlib.config import get_config

    config = get_config()
    rapi_section = dict(config.get("rapi", {}))  # type: ignore[no-untyped-call]
    log.debug("Loading RAPI config from kstlib.conf.yml")

    # The control keys are popped one by one so that what remains in
    # rapi_section is the inline configuration handed to the manager.
    strict = rapi_section.pop("strict", False)
    if strict:
        log.debug("Strict mode enabled: endpoint collisions will raise errors")

    defaults = rapi_section.pop("defaults", None)
    if defaults:
        # Sanitization invariant: log only the KEYS of rapi.defaults.
        # The values may hold resolved env-var content (a base_url with
        # inline credentials, headers with secrets) and must not be logged.
        log.debug("Found rapi.defaults section with keys: %s", list(defaults.keys()))
        _validate_defaults_section(defaults)

    # Named server profiles are optional.
    servers = _parse_servers_section(rapi_section.pop("servers", None))
    include_patterns = rapi_section.pop("include", None)
    safeguard_config = _parse_safeguard_config(rapi_section)

    # The inline configuration is loaded first; it wins over includes.
    manager = RapiConfigManager(rapi_section, safeguard_config=safeguard_config, strict=strict)

    # resolve_server() reads these two attributes.
    if defaults:
        manager._defaults = _expand_env_vars_recursive(defaults)
    if servers:
        manager._servers = servers

    included_files = _resolve_include_patterns(include_patterns) if include_patterns else []
    include_count = len(included_files)
    if included_files:
        log.info("Including %d external RAPI config file(s)", include_count)
        included_manager = RapiConfigManager.from_files(
            included_files,
            safeguard_config=safeguard_config,
            defaults=defaults,
            strict=strict,
        )
        # Merge included APIs (inline config takes precedence)
        manager._merge_apis(included_manager, overwrite=False)

    # With defaults, servers and includes in place, check every server:
    # directive: an unknown name is an error when servers: is configured,
    # and only a warning when the servers: section is absent.
    manager._validate_server_references()

    # One summary line for the whole load, so the effective endpoint
    # count is visible without reading the per-endpoint TRACE output.
    endpoint_total = sum(len(api.endpoints) for api in manager._apis.values())
    log.info(
        "Loaded %d endpoint(s) across %d API(s) from kstlib.conf.yml + %d include file(s)",
        endpoint_total,
        len(manager._apis),
        include_count,
    )
    return manager
def _resolve_include_patterns(patterns: list[str] | str) -> list[Path]:
    """Resolve include patterns to file paths.

    ``~`` is expanded in each pattern. Absolute patterns are globbed from
    the filesystem anchor, so wildcards may appear in any path component
    (``/etc/*/apis/*.rapi.yml``), not only in the file name. Relative
    patterns are globbed from the current working directory.

    Matches of one pattern are returned in sorted order, patterns keep
    the order in which they were given, and a path matched by several
    patterns is returned once (first occurrence wins).

    Args:
        patterns: Glob pattern or list of patterns.

    Returns:
        List of resolved file paths, without duplicates.
    """
    if isinstance(patterns, str):
        patterns = [patterns]

    # A dict is used as an insertion-ordered set of paths.
    found: dict[Path, None] = {}
    for pattern in patterns:
        expanded = Path(pattern).expanduser()
        if expanded.is_absolute():
            # Path.glob() only accepts relative patterns, so glob the
            # part below the anchor starting from the anchor itself.
            anchor = Path(expanded.anchor)
            matches = anchor.glob(str(expanded.relative_to(anchor)))
        else:
            matches = Path.cwd().glob(pattern)
        for match in sorted(matches):
            found.setdefault(match, None)
    return list(found)
# Public API of this module: the names exported by a star-import.
__all__ = [
    "ApiConfig",
    "EndpointConfig",
    "HmacConfig",
    "MultipartConfig",
    "RapiConfigManager",
    "SafeguardConfig",
    "ServerConfig",
    "load_rapi_config",
]