# Source code for kstlib.transform.chain

"""Transform chain engine for kstlib.transform.

Chains primitives (decode, decompress, parse, patch, serialize,
compress, encode) declared in YAML configuration. Supports preset
inheritance, replace/callable patching with blob/outer/all scopes,
composed patches with filters, and lossless round-trip via stored
envelopes.
"""

from __future__ import annotations

import fnmatch
import importlib
import logging
import threading
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from xml.etree.ElementTree import Element

if TYPE_CHECKING:
    from collections.abc import Callable, Mapping

from kstlib.transform.config import (
    ComposedPatchConfig,
    FilterConfig,
    PatchConfig,
    PrimitiveConfig,
    TransformChainConfig,
    TransformConfig,
    load_transform_config,
)
from kstlib.transform.exceptions import (
    CallableError,
    CallableImportError,
    PatchError,
    TransformChainError,
    TransformConfigError,
)
from kstlib.transform.primitives import (
    BACKWARD_DISPATCH,
    FORWARD_DISPATCH,
    json_serialize,
    xml_parse,
    xml_serialize,
)
from kstlib.transform.validators import (
    CALLABLE_TIMEOUT,
    MAX_VARIABLE_REFS,
    VARIABLE_PATTERN,
)

#: Shared ``PrimitiveConfig(name="xml")`` handed to ``xml_serialize`` and
#: ``xml_parse`` when this module converts between ``Element`` and ``str``
#: internally (replace patches, composed patches). Built once at import
#: time so it is not re-instantiated on every call.
_DEFAULT_XML_PRIM: PrimitiveConfig = PrimitiveConfig(name="xml")

#: Module-level logger, named after this module.
log = logging.getLogger(__name__)


def _log_trace(msg: str, *args: object) -> None:
    """Emit a record on the module logger at TRACE level.

    TRACE is a custom level (5, below DEBUG) provided by ``kstlib.logging``.

    Args:
        msg: ``%``-style format string.
        *args: Format arguments, forwarded untouched so that formatting
            stays lazy.

    """
    # NOTE(review): the level constant is imported at call time rather than
    # at module import -- presumably to avoid an import cycle; confirm.
    from kstlib.logging import TRACE_LEVEL as trace_level

    log.log(trace_level, msg, *args)


#: Modules that are ALWAYS rejected by callable patches, regardless of
#: the user-configured ``allowed_callable_modules`` whitelist.
#: These modules provide unrestricted OS or interpreter access and must
#: never be importable via a YAML-driven patch config.
#:
#: Entries are compared against the ROOT component of the dotted module
#: path of a callable target (the text before the first ``"."``), so an
#: entry such as ``"os"`` also rejects ``os.path``.
DANGEROUS_MODULES: frozenset[str] = frozenset(
    {
        "os",
        "sys",
        "subprocess",
        "builtins",
        "ctypes",
        "posix",
        "nt",
        "shutil",
        "importlib",
        "pickle",
        "marshal",
        "code",
        # NOTE(review): "compile" is a builtin function, not an importable
        # standard-library module name, so this entry looks inert --
        # confirm the intent (``py_compile`` / ``compileall``?).
        "compile",
        "__main__",
    }
)

# ``(previous_primitive, next_primitive)`` name pairs for which the forward
# pass decodes an intermediate ``bytes`` value to ``str`` (UTF-8) before
# invoking the next primitive: the output of ``zlib`` / ``bytes`` feeding
# a ``json`` or ``xml`` parser.
_AUTO_COERCE_PAIRS: frozenset[tuple[str, str]] = frozenset(
    {
        ("zlib", "json"),
        ("bytes", "json"),
        ("zlib", "xml"),
        ("bytes", "xml"),
    }
)

#: Default protected paths for :func:`replace_outer_uris`.
#:
#: Each entry is a dotted-path string with ``[*]`` matching any list
#: index. The default blacklist protects ``connectors[*].hints.xpath``
#: because SAS Viya stores BIRD XPath pointers there: patching them
#: would silently corrupt the wrapper-to-content coherence.
#:
#: The set is exposed as a public constant so callers can build on it.
#: Being a ``frozenset`` it cannot be mutated in place: either pass a
#: replacement set as ``protected_paths`` or pass the extra patterns as
#: ``additional_protected_paths`` to :func:`replace_outer_uris`, which
#: unions them with ``protected_paths``.
PROTECTED_OUTER_PATHS: frozenset[str] = frozenset(
    {
        "connectors[*].hints.xpath",
    }
)


def _parse_path_pattern(pattern: str) -> tuple[str, ...]:
    """Turn a dotted path pattern into a tuple of structural segments.

    The pattern is split on ``"."``; inside each dotted piece every
    ``[*]`` marker becomes a standalone ``"*"`` segment, and the literal
    text around the markers becomes ordinary segments. Empty fragments
    (for example from a doubled dot) are dropped.

    Args:
        pattern: Dotted path such as ``"connectors[*].hints.xpath"``.

    Returns:
        Segments in order, with ``"*"`` standing for "any list index".

    Examples:
        >>> _parse_path_pattern("connectors[*].hints.xpath")
        ('connectors', '*', 'hints', 'xpath')
        >>> _parse_path_pattern("a.b.c")
        ('a', 'b', 'c')
        >>> _parse_path_pattern("items[*]")
        ('items', '*')

    """
    segments: list[str] = []
    for dotted_piece in pattern.split("."):
        # Splitting on the marker leaves only the literal fragments; every
        # boundary between two consecutive fragments is one ``[*]``.
        for position, fragment in enumerate(dotted_piece.split("[*]")):
            if position:
                segments.append("*")
            if fragment:
                segments.append(fragment)
    return tuple(segments)


# Pre-parsed form of PROTECTED_OUTER_PATHS: every pattern is converted once,
# at import time, into the segment tuple produced by _parse_path_pattern.
# replace_outer_uris uses this set directly when it is called with the
# default protected paths, so the defaults are not re-parsed on every call.
_PROTECTED_OUTER_PATHS_PARSED: frozenset[tuple[str, ...]] = frozenset(
    _parse_path_pattern(p) for p in PROTECTED_OUTER_PATHS
)


def _replace_in_string(value: str, replace_map: Mapping[str, str]) -> str:
    """Apply every ``(old, new)`` pair from ``replace_map`` to ``value``.

    Pairs are applied sequentially in the mapping's iteration order, so a
    later pair sees the output of the earlier ones.

    Args:
        value: Text to patch.
        replace_map: Mapping of ``old -> new`` substrings. Entries whose
            ``old`` key is empty are ignored.

    Returns:
        The patched text (equal to ``value`` when nothing matched).

    """
    result = value
    for old, replacement in replace_map.items():
        # ``str.replace("", x)`` inserts ``x`` between every character and
        # at both ends. An empty search key can never be a meaningful
        # patch, so skip it instead of silently corrupting the payload.
        if not old:
            continue
        # No containment pre-check needed: replace() is a no-op when
        # ``old`` does not occur.
        result = result.replace(old, replacement)
    return result


#: Maximum recursion depth for :func:`_walk_node`. Protects against
#: pathological nesting in attacker-controlled metadata.
#:
#: The root node is walked at depth 0 and the guard is
#: ``depth > _MAX_WALK_DEPTH``, so nodes at depth 0 through 64 inclusive
#: are accepted and anything nested deeper raises ``PatchError``.
_MAX_WALK_DEPTH: int = 64


def _walk_node(
    node: Any,
    path: tuple[str, ...],
    replace_map: Mapping[str, str],
    protected: frozenset[tuple[str, ...]],
    state: dict[str, int],
    *,
    depth: int = 0,
) -> Any:
    """Walk a JSON-like tree depth-first, patching string values in place.

    Dicts and lists are mutated in place (each value/item is reassigned
    with the result of the recursive call); dict keys are never patched.
    A node whose ``path`` is in ``protected`` is returned untouched,
    together with everything below it.

    Args:
        node: Current node in the JSON tree.
        path: Segments leading to this node. Dict keys appear as
            ``str(key)``; every list index appears as ``"*"`` so the path
            can be compared directly with parsed protected patterns.
        replace_map: Substring replacements to apply.
        protected: Set of parsed path patterns to skip.
        state: Mutable dict whose ``"counter"`` entry is incremented for
            every string value that actually changed.
        depth: Current recursion depth (for stack overflow protection).

    Returns:
        The (possibly modified) node. Containers are returned as the same
        object; strings are returned as the patched value.

    Raises:
        PatchError: If recursion depth exceeds ``_MAX_WALK_DEPTH``.

    """
    if depth > _MAX_WALK_DEPTH:
        raise PatchError(
            f"replace_outer_uris recursion depth exceeds {_MAX_WALK_DEPTH}",
            primitive_name="patch",
        )

    # Protected location: leave this node and its whole subtree alone.
    if path in protected:
        return node

    # Leaf case first: only strings are ever rewritten.
    if isinstance(node, str):
        patched = _replace_in_string(node, replace_map)
        if patched != node:
            state["counter"] += 1
        return patched

    child_depth = depth + 1
    if isinstance(node, dict):
        # Reassigning existing keys does not resize the dict, so this is
        # safe while iterating over items().
        for key, child in node.items():
            node[key] = _walk_node(child, (*path, str(key)), replace_map, protected, state, depth=child_depth)
    elif isinstance(node, list):
        # All items of a list share one path: the index is wildcarded.
        item_path = (*path, "*")
        for index, child in enumerate(node):
            node[index] = _walk_node(child, item_path, replace_map, protected, state, depth=child_depth)

    # Containers (now patched) and non-string scalars are returned as-is.
    return node


def replace_outer_uris(
    obj: Any,
    replace_map: Mapping[str, str],
    *,
    protected_paths: frozenset[str] = PROTECTED_OUTER_PATHS,
    additional_protected_paths: frozenset[str] | None = None,
) -> int:
    """Recursively patch string values in a JSON-like object, in place.

    Walks the object tree and applies ``str.replace(old, new)`` for every
    entry of ``replace_map`` to every string value, skipping any path that
    matches ``protected_paths``. The object is mutated in place.

    Path syntax: each protected path is a dotted string. ``[*]`` matches
    any list index. Dict keys are matched literally. For example,
    ``"connectors[*].hints.xpath"`` matches
    ``obj["connectors"][i]["hints"]["xpath"]`` for every ``i``.

    .. note::

        Keys containing a literal ``"."`` cannot be expressed in the
        dotted-path syntax and are therefore not protectable via
        ``protected_paths``. This is a known limitation.

    This helper is meant to be called from caller code that knows about
    the wrapper structure (e.g. SAS Viya transfer packages where the BIRD
    XML lives inside an encoded blob but ``connectors[].uri`` and
    ``connectors[].hints.orig-uri`` live in the outer JSON wrapper). Use
    it together with ``PatchConfig(scope='outer')`` or ``scope='all'``.

    Args:
        obj: JSON-like nested structure (dict, list, str, int, ...).
            Dicts and lists are mutated in place. The walker's return
            value is discarded here, so a bare top-level string is counted
            when it would change but cannot itself be updated.
        replace_map: Mapping of ``old -> new`` substring replacements.
            Replacements are applied in iteration order.
        protected_paths: Dotted-path patterns that must NOT be patched.
            Defaults to :data:`PROTECTED_OUTER_PATHS`.
        additional_protected_paths: Extra patterns merged with
            ``protected_paths``. Provides an additive API so callers can
            extend the defaults without accidentally wiping them.

    Returns:
        Total number of string values that were modified.

    Examples:
        >>> wrapper = {"connectors": [{"uri": "library=CASUSER", "hints": {"xpath": "/foo/CASUSER"}}]}
        >>> replace_outer_uris(wrapper, {"CASUSER": "PUBLIC"})
        1
        >>> wrapper["connectors"][0]["uri"]
        'library=PUBLIC'
        >>> wrapper["connectors"][0]["hints"]["xpath"]
        '/foo/CASUSER'

    """
    if not replace_map:
        return 0

    # The pre-parsed cache is only valid for the untouched default set
    # (identity check) with nothing merged in.
    uses_defaults = protected_paths is PROTECTED_OUTER_PATHS and not additional_protected_paths
    if uses_defaults:
        parsed_paths = _PROTECTED_OUTER_PATHS_PARSED
    else:
        patterns = protected_paths
        if additional_protected_paths:
            patterns = protected_paths | additional_protected_paths
        parsed_paths = frozenset(_parse_path_pattern(item) for item in patterns)

    tally = {"counter": 0}
    _walk_node(obj, (), replace_map, parsed_paths, tally)
    return tally["counter"]
@dataclass(slots=True) class _ChainContext: """Internal state kept during a forward->backward round-trip. Stores JSON envelopes extracted during forward for lossless restoration during backward. """ json_envelopes: dict[str, Any] = field(default_factory=dict) def _auto_reverse(forward: tuple[PrimitiveConfig, ...]) -> tuple[PrimitiveConfig, ...]: """Generate backward chain by reversing forward primitives. Args: forward: Forward primitive chain. Returns: Reversed backward chain with swapped options. Raises: TransformConfigError: If a primitive cannot be auto-reversed. """ backward: list[PrimitiveConfig] = [] for prim in reversed(forward): match prim.name: case "base64": backward.append(PrimitiveConfig(name="base64")) case "bytes": backward.append(PrimitiveConfig(name="bytes", options=dict(prim.options))) case "xml": backward.append(PrimitiveConfig(name="xml", options=dict(prim.options))) case "zlib": if prim.options.get("skip_bytes"): raise TransformConfigError( "zlib with skip_bytes requires explicit backward with prepend_bytes. " "Auto-reverse cannot recover the skipped header bytes." ) backward.append(PrimitiveConfig(name="zlib")) case "json": opts: dict[str, Any] = {} if extract := prim.options.get("extract"): opts["wrap"] = extract backward.append(PrimitiveConfig(name="json", options=opts)) case _: raise TransformConfigError(f"Cannot auto-reverse primitive: {prim.name!r}") return tuple(backward) def _resolve_preset( chain_config: TransformChainConfig, all_chains: dict[str, TransformChainConfig], ) -> TransformChainConfig: """Resolve preset inheritance, returning a standalone chain config. Args: chain_config: Chain config with preset reference. all_chains: All available chain configs. Returns: New TransformChainConfig with forward/backward from preset and patch/composed_patch from child. 
""" preset = all_chains[chain_config.preset] # type: ignore[index] return TransformChainConfig( name=chain_config.name, forward=preset.forward, backward=preset.backward, patch=chain_config.patch, composed_patch=chain_config.composed_patch, preset=None, ) def _matches_filter(metadata: Mapping[str, Any], filter_config: FilterConfig) -> bool: """Check if object metadata matches a FilterConfig. All fields are ANDed: an object matches only if every non-wildcard field of the filter matches the corresponding metadata key. Args: metadata: Object metadata dict (typically with ``content_type`` and ``name`` keys). Missing keys default to empty string. filter_config: Filter to evaluate. Returns: True if metadata matches the filter, False otherwise. Examples: >>> _matches_filter({"content_type": "report", "name": "R220_X"}, ... FilterConfig(content_type="report", name="R220_*")) True >>> _matches_filter({"content_type": "folder", "name": "R220_X"}, ... FilterConfig(content_type="report")) False """ content_type = str(metadata.get("content_type", "")) name = str(metadata.get("name", "")) if filter_config.content_type not in ("*", content_type): return False return filter_config.name == "*" or fnmatch.fnmatchcase(name, filter_config.name) def _resolve_variables( args: dict[str, Any], context: Mapping[str, Any], ) -> dict[str, Any]: """Replace {{variable}} patterns in args values from context. Args: args: Arguments dict with possible {{variable}} references. context: Variable store for resolution. Returns: New dict with variables resolved. Raises: TransformChainError: If variable not found in context. 
""" resolved: dict[str, Any] = {} total_refs = 0 for key, raw_value in args.items(): result_value = raw_value if isinstance(result_value, str) and "{{" in result_value: matches = VARIABLE_PATTERN.findall(result_value) total_refs += len(matches) if total_refs > MAX_VARIABLE_REFS: raise TransformChainError(f"Too many variable references ({total_refs} > {MAX_VARIABLE_REFS})") # Validate all variables exist before substituting for var_name in matches: if var_name not in context: raise TransformChainError(f"Variable '{{{{{var_name}}}}}' not found in context") # Single-pass substitution via re.sub callback def _replace_var(m: Any) -> str: return str(context[m.group(1)]) result_value = VARIABLE_PATTERN.sub(_replace_var, result_value) resolved[key] = result_value return resolved
class TransformChain:
    """Execute a chain of transform primitives with optional patching.

    .. warning::

        ``TransformChain`` instances are **not reentrant**. Do not call
        ``transform()`` concurrently from multiple threads on the same
        instance. Each call resets internal state (``_chain_context``).
        Create one instance per thread if concurrent execution is needed.

    Args:
        config: Resolved chain configuration (no preset references).
        context: Optional external context for {{variable}} resolution.

    Examples:
        >>> from kstlib.transform.config import PrimitiveConfig, TransformChainConfig
        >>> chain = TransformChain(TransformChainConfig(
        ...     name="test",
        ...     forward=(PrimitiveConfig(name="base64"),),
        ... ))
        >>> chain.forward("SGVsbG8=")
        b'Hello'

    """

    def __init__(
        self,
        config: TransformChainConfig,
        *,
        context: Mapping[str, Any] | None = None,
        transform_config: TransformConfig | None = None,
        allowed_modules: frozenset[str] | None = None,
    ) -> None:
        """Initialize TransformChain.

        Args:
            config: Resolved chain configuration.
            context: External context for variable resolution.
            transform_config: Top-level config used to resolve chain
                references in ``composed_patch``. Required when
                ``config.composed_patch`` is set.
            allowed_modules: Whitelist of allowed callable module prefixes.
                When ``None`` (direct construction without ``from_config``),
                any callable patch is rejected (fail-closed). Pass an
                explicit frozenset to allow specific modules.

        Raises:
            TransformConfigError: If ``composed_patch`` is set but
                ``transform_config`` was not provided, or if the backward
                chain has to be derived and a primitive cannot be
                auto-reversed.

        """
        self._config = config
        self._context = context or {}
        self._transform_config = transform_config
        self._allowed_modules = allowed_modules
        self._chain_context = _ChainContext()
        # Guards callable patches: taken when a callable starts and given
        # back by its worker thread once the callable has really finished.
        self._callable_lock = threading.Lock()

        if config.composed_patch is not None and transform_config is None:
            raise TransformConfigError(
                f"Chain '{config.name}': composed_patch requires a transform_config "
                f"to resolve chain references. Use TransformChain.from_config()."
            )

        # Resolve backward if not explicit
        if config.backward is None:
            self._backward = _auto_reverse(config.forward)
        else:
            self._backward = config.backward

        log.debug(
            "TransformChain '%s': %d forward, %d backward, patch=%s, composed=%s",
            config.name,
            len(config.forward),
            len(self._backward),
            "yes" if config.patch else "no",
            "yes" if config.composed_patch else "no",
        )

    @classmethod
    def from_config(
        cls,
        name: str,
        transform_config: TransformConfig,
        *,
        context: Mapping[str, Any] | None = None,
    ) -> TransformChain:
        """Create a TransformChain from a named config entry.

        Resolves presets and returns a ready-to-use chain. The callable
        whitelist is taken from ``transform_config.allowed_callable_modules``.

        Args:
            name: Chain name to look up in transform_config.
            transform_config: Top-level TransformConfig.
            context: External context for variable resolution.

        Returns:
            Configured TransformChain.

        Raises:
            TransformChainError: If chain name not found.

        """
        if name not in transform_config.chains:
            available = sorted(transform_config.chains)
            raise TransformChainError(
                f"Chain '{name}' not found. Available: {available}",
                chain_name=name,
            )

        chain_config = transform_config.chains[name]

        # Resolve preset if needed
        if chain_config.preset is not None:
            chain_config = _resolve_preset(chain_config, transform_config.chains)

        return cls(
            chain_config,
            context=context,
            transform_config=transform_config,
            allowed_modules=transform_config.allowed_callable_modules,
        )

    def forward(self, data: Any) -> Any:
        """Apply forward primitives in order.

        Resets the stored round-trip context, then runs every forward
        primitive. JSON envelopes returned by ``json`` primitives that
        declare an ``extract`` option are stored for ``backward``.

        Args:
            data: Input data (typically base64 string).

        Returns:
            Decoded/parsed data ready for patching.

        Raises:
            TransformChainError: If any primitive fails.

        """
        self._chain_context = _ChainContext()
        current = data

        for i, prim in enumerate(self._config.forward):
            # Per-primitive trace (firehose-level detail). Sizes / timings
            # are intentionally not measured here : data may be a non-sized
            # object (Element tree) and computing len() defensively across
            # types would add complexity for little diagnostic value.
            _log_trace(
                "Chain '%s' forward[%d]: %s",
                self._config.name,
                i,
                prim.name,
            )

            # Auto-coerce bytes -> str if needed
            if i > 0 and isinstance(current, bytes):
                prev_name = self._config.forward[i - 1].name
                if (prev_name, prim.name) in _AUTO_COERCE_PAIRS:
                    current = current.decode("utf-8")

            func = FORWARD_DISPATCH[prim.name]

            if prim.name == "json":
                # json_parse returns (value, envelope)
                value, envelope = func(current, prim)
                extract_path = prim.options.get("extract")
                if extract_path and envelope is not None:
                    self._chain_context.json_envelopes[extract_path] = envelope
                current = value
            else:
                current = func(current, prim)

        return current

    def backward(self, data: Any) -> Any:
        """Apply backward primitives in order.

        Uses stored envelopes from forward for lossless JSON restoration.

        Args:
            data: Data to re-encode (typically patched XML string or Element).

        Returns:
            Re-encoded data (same format as original input).

        Raises:
            TransformChainError: If any primitive fails.

        """
        current = data

        for i, prim in enumerate(self._backward):
            _log_trace(
                "Chain '%s' backward[%d]: %s",
                self._config.name,
                i,
                prim.name,
            )

            if prim.name == "json":
                # JSON is serialized directly so the envelope captured by
                # forward() can be passed along.
                wrap_path = prim.options.get("wrap")
                envelope = self._chain_context.json_envelopes.get(wrap_path or "")
                current = json_serialize(current, prim, envelope=envelope)
                continue

            # Auto-coerce str -> bytes if needed
            if isinstance(current, str) and prim.name in ("zlib", "bytes"):
                current = current.encode("utf-8")
            current = BACKWARD_DISPATCH[prim.name](current, prim)

        return current

    def patch(
        self,
        data: Any,
        *,
        metadata: Mapping[str, Any] | None = None,
    ) -> Any:
        """Apply patch to decoded data.

        If the chain has an inline ``patch``, applies it directly. If the
        chain has a ``composed_patch``, applies the global patches then the
        targeted patches whose filter matches ``metadata`` (in declaration
        order, last applied wins). A chain with neither returns ``data``
        unchanged.

        Args:
            data: Decoded data from forward.
            metadata: Object metadata. Used for filter matching in composed
                patches (typical keys: ``content_type``, ``name``). May also
                carry ``"outer"`` referencing the JSON wrapper to mutate
                when the patch declares ``scope: outer`` or ``scope: all``.

        Returns:
            Patched data.

        Raises:
            PatchError: If patching fails.
            CallableError: If a callable patch raises.
            CallableImportError: If a callable cannot be imported.

        """
        if self._config.composed_patch is not None:
            return self._apply_composed_patch(data, metadata or {})

        if self._config.patch is None:
            return data

        return self._apply_patch_config(data, self._config.patch, metadata)

    def transform(
        self,
        data: Any,
        *,
        metadata: Mapping[str, Any] | None = None,
    ) -> Any:
        """Full round-trip: forward -> patch -> backward.

        This is the main entry point for most use cases.

        Args:
            data: Raw input data.
            metadata: Object metadata, forwarded to :meth:`patch`. Used for
                filter matching in composed patches; for replace patches
                (inline or composed) with ``scope: outer`` or
                ``scope: all`` it must carry the wrapper under ``"outer"``.

        Returns:
            Transformed data (same format as input).

        Raises:
            TransformChainError: If any stage fails.

        """
        import time

        log.debug("Chain '%s': starting transform", self._config.name)
        start = time.monotonic()
        decoded = self.forward(data)
        patched = self.patch(decoded, metadata=metadata)
        result = self.backward(patched)
        log.debug(
            "Chain '%s': transform complete (took=%.3fs)",
            self._config.name,
            time.monotonic() - start,
        )
        return result

    def _apply_patch_config(
        self,
        data: Any,
        patch_config: PatchConfig,
        metadata: Mapping[str, Any] | None = None,
    ) -> Any:
        """Dispatch a single PatchConfig to replace or callable.

        ``replace`` takes precedence over ``callable`` when both are set.

        Args:
            data: Data to patch.
            patch_config: Patch configuration to apply.
            metadata: Object metadata. Required (with key ``"outer"``)
                when ``patch_config.scope`` is ``"outer"`` or ``"all"``.
                Not passed to callable patches.

        Returns:
            Patched data (or unchanged if patch is a no-op).

        """
        if patch_config.replace is not None:
            return self._apply_replace_with_scope(data, patch_config, metadata)
        if patch_config.callable is not None:
            return self._apply_callable(data, patch_config)
        return data

    def _apply_composed_patch(
        self,
        data: Any,
        metadata: Mapping[str, Any],
    ) -> Any:
        """Apply global_patches then matching targeted_patches in order.

        When the decoded data is an XML ``Element``, the method serializes
        it to a string once before the patch loop and re-parses once after
        all patches have been applied. This avoids K full XML round-trips
        when K patches match (performance fix). As a consequence, callable
        patches referenced from a composed patch receive the serialized
        string, not the ``Element``.

        Args:
            data: Decoded data to patch.
            metadata: Object metadata used for filter evaluation. Also
                threaded to ``_apply_patch_config`` for ``scope: outer``
                and ``scope: all`` patches that need ``metadata['outer']``.

        Returns:
            Patched data after the full cascade.

        """
        assert self._config.composed_patch is not None
        assert self._transform_config is not None  # checked in __init__
        composed: ComposedPatchConfig = self._config.composed_patch
        all_chains = self._transform_config.chains

        # Optimization: serialize Element to string once before the
        # patch loop so that N replace patches only do string ops.
        # Re-parse to Element once at the end.
        was_element = isinstance(data, Element)
        result: Any = xml_serialize(data, _DEFAULT_XML_PRIM) if was_element else data

        # 1. global_patches - always applied
        for ref_name in composed.global_patches:
            ref_chain = all_chains[ref_name]
            if ref_chain.patch is not None:
                log.debug(
                    "Chain '%s': applying global patch '%s'",
                    self._config.name,
                    ref_name,
                )
                result = self._apply_patch_config(result, ref_chain.patch, metadata)

        # 2. targeted_patches - applied if filter matches
        for targeted in composed.targeted_patches:
            if not _matches_filter(metadata, targeted.filter):
                continue
            for ref_name in targeted.patches:
                ref_chain = all_chains[ref_name]
                if ref_chain.patch is not None:
                    log.debug(
                        "Chain '%s': applying targeted patch '%s' (matched filter)",
                        self._config.name,
                        ref_name,
                    )
                    result = self._apply_patch_config(result, ref_chain.patch, metadata)

        # Re-parse to Element if the original data was Element and
        # the result is still a string (callable patches may return
        # a different type).
        if was_element and isinstance(result, str):
            result = xml_parse(result, _DEFAULT_XML_PRIM)

        return result

    def _apply_replace_with_scope(
        self,
        data: Any,
        patch: PatchConfig,
        metadata: Mapping[str, Any] | None,
    ) -> Any:
        """Apply ``patch.replace`` according to ``patch.scope``.

        - ``scope='blob'``: patch the decoded data only (str or Element).
        - ``scope='outer'``: patch ``metadata['outer']`` in place via
          :func:`replace_outer_uris`. Data flows through unchanged.
        - ``scope='all'``: do both, outer first then blob.

        Args:
            data: Decoded data from forward (only relevant for blob/all).
            patch: PatchConfig with a non-None ``replace`` mapping.
            metadata: Object metadata. Must contain ``"outer"`` when
                scope is ``outer`` or ``all``.

        Returns:
            Patched data (unchanged if scope is ``outer`` only).

        Raises:
            PatchError: If scope requires ``metadata['outer']`` but it is
                missing, or if data type is unsupported for ``blob`` /
                ``all`` scopes.

        """
        assert patch.replace is not None
        scope = patch.scope

        if scope in ("outer", "all"):
            outer = (metadata or {}).get("outer")
            if outer is None:
                raise PatchError(
                    f"PatchConfig with scope={scope!r} requires "
                    f"metadata['outer'] to be set on chain.transform() "
                    f"or chain.patch()",
                    primitive_name="patch",
                    chain_name=self._config.name,
                )
            replace_outer_uris(outer, patch.replace)

        if scope in ("blob", "all"):
            return self._apply_replace_to_data(data, patch.replace)

        # scope == "outer": data is unchanged, side effect on outer only.
        return data

    def _apply_replace_to_data(
        self,
        data: Any,
        replace_map: Mapping[str, str],
    ) -> Any:
        """Apply string replacement to decoded blob data (str or Element).

        An ``Element`` is serialized, patched as text and parsed back, so
        a new ``Element`` is returned rather than the input being mutated.

        Raises:
            PatchError: If ``data`` is neither ``str`` nor ``Element``.

        """
        if isinstance(data, str):
            return _replace_in_string(data, replace_map)
        if isinstance(data, Element):
            xml_str = xml_serialize(data, _DEFAULT_XML_PRIM)
            xml_str = _replace_in_string(xml_str, replace_map)
            return xml_parse(xml_str, _DEFAULT_XML_PRIM)
        raise PatchError(
            f"Cannot apply replace to {type(data).__name__}; expected str or Element",
            primitive_name="patch",
            chain_name=self._config.name,
        )

    def _apply_callable(self, data: Any, patch: PatchConfig) -> Any:
        """Import and call a patch function with a hard wall-clock timeout.

        ``patch.callable`` must have the form ``module.path:function``.

        Security enforcement (defense-in-depth):

        1. If ``allowed_modules`` was not provided at construction time
           (direct ``TransformChain(...)`` without ``from_config``), ALL
           callable patches are rejected (fail-closed).
        2. The module is checked against :data:`DANGEROUS_MODULES`
           regardless of the whitelist.
        3. The module must match the ``allowed_modules`` whitelist.

        Raises:
            CallableImportError: If the target is malformed or cannot be
                imported / resolved.
            CallableError: If a security check rejects the target, or the
                callable fails or times out.

        """
        assert patch.callable is not None
        target = patch.callable
        module_path, _, func_name = target.rpartition(":")
        if not module_path or not func_name:
            raise CallableImportError(target, chain_name=self._config.name)

        # Fail-closed: no whitelist means no callable allowed
        if self._allowed_modules is None:
            raise CallableError(
                target,
                "Callable patches require allowed_callable_modules config. "
                "Use TransformChain.from_config() or pass allowed_modules explicitly.",
                chain_name=self._config.name,
            )

        # Hardcoded blacklist: always rejected
        root_module = module_path.split(".")[0]
        if root_module in DANGEROUS_MODULES:
            raise CallableError(
                target,
                f"Module '{module_path}' is in DANGEROUS_MODULES blacklist and cannot be used as a callable patch",
                chain_name=self._config.name,
            )

        # Whitelist check (same logic as validate_callable_module)
        if not any(
            module_path == allowed or module_path.startswith(f"{allowed}.") for allowed in self._allowed_modules
        ):
            raise CallableError(
                target,
                f"Module '{module_path}' is not in allowed_callable_modules",
                chain_name=self._config.name,
            )

        try:
            module = importlib.import_module(module_path)
            func = getattr(module, func_name)
        except (ImportError, AttributeError) as exc:
            raise CallableImportError(target, chain_name=self._config.name) from exc

        # Post-import check: verify the resolved module matches the
        # expected path. Protects against sys.modules tampering where
        # a malicious entry could redirect an import to a different module.
        actual_name = getattr(module, "__name__", "")
        if actual_name != module_path:
            raise CallableError(
                target,
                f"Module identity mismatch: expected '{module_path}', "
                f"got '{actual_name}' (possible sys.modules tampering)",
                chain_name=self._config.name,
            )

        resolved_args = _resolve_variables(patch.args, self._context)
        return self._run_callable_with_timeout(func, data, resolved_args, target)

    def _run_callable_with_timeout(
        self,
        fn: Callable[..., Any],
        data: Any,
        resolved_args: dict[str, Any],
        target_label: str,
    ) -> Any:
        """Run a callable in a daemon thread with a hard wall-clock timeout.

        The timeout enforcement uses ``threading.Thread(daemon=True)`` plus
        ``Thread.join(timeout)``. This is the only practical way to add
        wall-clock cancellation to arbitrary Python callables without
        relying on ``signal`` (main-thread only) or ``multiprocessing``
        (too heavy and incompatible with shared in-memory state).

        The timeout value is
        :data:`kstlib.transform.validators.CALLABLE_TIMEOUT` (30 seconds by
        default). Tests can patch the module-level constant with
        ``monkeypatch`` to use a smaller value.

        Limitations (honest disclosure):

        1. **Orphaned threads on timeout** keep running until natural
           completion or process exit. Python has no public API to kill a
           thread cleanly. The worker is marked ``daemon=True`` so it does
           not block process exit, but it still consumes resources (CPU,
           memory, file handles) until it finishes on its own.
        2. **Blocking syscalls cannot be interrupted** mid-flight. If the
           callable is stuck in ``time.sleep``, ``socket.recv``, a long
           CPython bytecode loop without GIL release, or any blocking OS
           call, the timeout fires (the join returns) but the worker keeps
           running. The timeout is wall-clock only; it does not cancel the
           work itself.
        3. **Resources held by a timed-out callable may leak** until the
           process exits. Open files, network connections, and locks held
           by the orphaned thread will not be released early.
        4. **Mutable shared state can race**: a timed-out callable that
           kept writing to shared structures may continue mutating them
           after the timeout error has been raised in the caller.
        5. **The instance stays locked while an orphan runs**: the
           callable guard is released by the worker thread itself, so
           after a timeout every further callable patch on this instance
           is rejected until the orphaned callable finishes.

        Use callables that are short, idempotent, side-effect-free, and do
        not hold external resources. For long-running or risky operations,
        prefer running them out-of-process via a pipeline shell step with
        its own timeout supervisor.

        Args:
            fn: Resolved callable (already imported).
            data: First positional argument passed to ``fn``.
            resolved_args: Keyword arguments passed to ``fn`` after
                ``{{variable}}`` resolution.
            target_label: ``module:function`` string used in error
                messages and the worker thread name.

        Returns:
            The callable's return value.

        Raises:
            CallableError: If the callable raises an exception, if it does
                not complete within ``CALLABLE_TIMEOUT`` seconds, or if a
                previous callable on this instance is still running.

        """
        # Reject reentrant calls: if a previous callable timed out and
        # its orphan thread is still accessing _chain_context, running
        # another callable risks data races.
        guard = self._callable_lock
        if not guard.acquire(blocking=False):
            raise CallableError(
                target_label,
                "Reentrant callable call rejected (previous call still in progress)",
                chain_name=self._config.name,
            )

        result_box: list[Any] = []
        error_box: list[BaseException] = []

        def _worker() -> None:
            try:
                result_box.append(fn(data, **resolved_args))
            except BaseException as exc:
                error_box.append(exc)
            finally:
                # The guard is released here, by the worker, and not by
                # the caller: releasing it when the caller gives up would
                # reopen the instance while the orphan is still running.
                # A ``threading.Lock`` may be released from any thread.
                guard.release()

        worker = threading.Thread(
            target=_worker,
            name=f"transform-callable-{target_label}",
            daemon=True,
        )
        try:
            worker.start()
        except BaseException:
            # The worker never ran, so it will never release the guard.
            guard.release()
            raise

        worker.join(CALLABLE_TIMEOUT)

        if worker.is_alive():
            # Timed out: the worker will keep running in the background
            # until it terminates on its own (or until process exit, since
            # the thread is marked daemon=True).
            log.warning(
                "Callable '%s' timed out after %.1fs (orphaned thread continues)",
                target_label,
                CALLABLE_TIMEOUT,
            )
            raise CallableError(
                target_label,
                f"timed out after {CALLABLE_TIMEOUT}s",
                chain_name=self._config.name,
            )

        if error_box:
            raise CallableError(
                target_label,
                str(error_box[0]),
                chain_name=self._config.name,
            ) from error_box[0]

        return result_box[0]
# ============================================================================
# Module-level convenience function
# ============================================================================
def transform(
    data: Any,
    chain_name: str,
    config: TransformConfig | None = None,
    context: dict[str, Any] | None = None,
    *,
    metadata: Mapping[str, Any] | None = None,
) -> Any:
    """Apply a named transform chain to data.

    Convenience function for use in CallableStep pipelines. Loads config
    from kstlib.conf.yml if not provided. A fresh ``TransformChain`` is
    built on every call.

    Args:
        data: Raw input data.
        chain_name: Name of the transform chain to apply.
        config: Transform config (loads from kstlib.conf.yml if None).
        context: Variables for {{variable}} resolution in callable args.
        metadata: Object metadata for filter matching in composed patches
            (typical keys: ``content_type``, ``name``).

    Returns:
        Transformed data.

    Examples:
        >>> transform("SGVsbG8=", "my_chain")  # doctest: +SKIP

    """
    effective_config = load_transform_config() if config is None else config
    runner = TransformChain.from_config(chain_name, effective_config, context=context)
    return runner.transform(data, metadata=metadata)
# Public API of this module.
__all__ = [
    "DANGEROUS_MODULES",
    "PROTECTED_OUTER_PATHS",
    "TransformChain",
    "replace_outer_uris",
    "transform",
]