"""Atomic transform primitives for kstlib.transform.
Pure functions with no state. Each takes data + PrimitiveConfig
and returns transformed data. Input validation is performed before
any allocation or processing (fail fast).
"""
from __future__ import annotations
import base64 as b64_module
import binascii
import copy
import json as json_module
import logging
import re
import zlib as zlib_module
from typing import Any
from xml.etree.ElementTree import Element
from defusedxml import ElementTree as ET
from defusedxml.ElementTree import fromstring as _safe_xml_fromstring
from kstlib.transform.config import PrimitiveConfig
from kstlib.transform.exceptions import (
CompressError,
DecodeError,
DecompressError,
EncodeError,
ParseError,
SerializeError,
)
from kstlib.transform.validators import (
MAX_DECOMPRESSED_SIZE,
MAX_DECOMPRESSION_RATIO,
MAX_INPUT_SIZE,
MAX_JSON_SIZE,
MAX_XML_SIZE,
)
#: Module-level logger.
log: logging.Logger = logging.getLogger(__name__)

#: Matches every character that is NOT part of the standard base64 alphabet
#: (``A-Z``, ``a-z``, ``0-9``, ``+``, ``/`` and the ``=`` padding sign).
#: ``base64_decode`` deletes every match when running in lenient mode.
#: Compiled once at import time so the decode path never recompiles it.
_NON_BASE64_PATTERN: re.Pattern[str] = re.compile(r"[^A-Za-z0-9+/=]")
# ============================================================================
# base64
# ============================================================================
[docs]
def base64_decode(data: str, config: PrimitiveConfig) -> bytes:
    r"""Decode a base64 string to bytes.

    Two opt-in options make it usable with SAS Viya wire formats and other
    proprietary base64 variants:

    - ``strip_prefix``: a literal string removed from the start of the
      input when (and only when) the input begins with it. SAS Viya report
      blobs start with ``"TRUE###"`` (the ``TRUE`` part decodes to the
      3-byte SAS proprietary header and ``###`` is a separator that lenient
      base64 decoders skip).
    - ``strict``: ``True`` (default) decodes with ``validate=True`` and
      rejects any character outside the base64 alphabet. ``False`` first
      deletes every non-alphabet character (the de facto behavior of the
      stdlib ``base64.b64decode`` and most other tools).

    Args:
        data: Base64-encoded text. May carry the configured prefix and, in
            lenient mode, embedded whitespace or separators.
        config: Primitive config. Recognized options:

            * ``strip_prefix`` (str | None): literal prefix dropped before
              decoding. Default ``None``; documented maximum 32 chars (not
              checked by this function). When the input does not start with
              it the option is a no-op (it does NOT raise), so one chain can
              handle blobs that only sometimes carry the prefix.
            * ``strict`` (bool): ``True`` (default) rejects non-alphabet
              characters; ``False`` strips them silently before decoding.

    Returns:
        The decoded bytes. An empty ``data`` string yields ``b""``.

    Raises:
        DecodeError: If ``data`` is not a ``str``, is longer than
            ``MAX_INPUT_SIZE``, is reduced to nothing by pre-processing, or
            is not valid base64 after pre-processing.

    Examples:
        >>> base64_decode("SGVsbG8=", PrimitiveConfig(name="base64"))
        b'Hello'
        >>> # SAS Viya pattern: strip the proprietary "TRUE###" prefix
        >>> cfg = PrimitiveConfig(name="base64",
        ...                       options={"strip_prefix": "TRUE###", "strict": False})
        >>> base64_decode("TRUE###SGVsbG8=", cfg)
        b'Hello'
        >>> # strip_prefix is a no-op when the input does not start with it
        >>> cfg2 = PrimitiveConfig(name="base64", options={"strip_prefix": "TRUE###"})
        >>> base64_decode("SGVsbG8=", cfg2)
        b'Hello'
        >>> # Lenient mode strips embedded non-alphabet noise
        >>> cfg3 = PrimitiveConfig(name="base64", options={"strict": False})
        >>> base64_decode("SGVs###bG8=", cfg3)
        b'Hello'
    """
    if not isinstance(data, str):
        raise DecodeError(
            f"Expected str, got {type(data).__name__}",
            primitive_name="base64",
        )
    if len(data) > MAX_INPUT_SIZE:
        raise DecodeError(
            f"Input exceeds limit ({len(data):,} > {MAX_INPUT_SIZE:,})",
            primitive_name="base64",
        )

    options = config.options
    payload = data

    # Pre-processing 1: drop the literal prefix, but only if it is present.
    prefix = options.get("strip_prefix")
    if prefix and payload.startswith(prefix):
        payload = payload[len(prefix) :]

    # Pre-processing 2 (lenient mode only): delete non-alphabet characters.
    if not options.get("strict", True):
        payload = _NON_BASE64_PATTERN.sub("", payload)

    # An empty *original* input is legal and decodes to b"" (legacy
    # behavior); only complain when pre-processing consumed real input.
    if data and not payload:
        raise DecodeError(
            "Empty base64 data after pre-processing",
            primitive_name="base64",
        )

    try:
        return b64_module.b64decode(payload, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise DecodeError(
            f"Invalid base64 data (length: {len(payload)})",
            primitive_name="base64",
        ) from exc
[docs]
def base64_encode(data: bytes, config: PrimitiveConfig) -> str:
    """Encode bytes to a base64 string, optionally prefixed with a literal.

    The ``prefix`` option is the backward counterpart of
    ``base64_decode``'s ``strip_prefix``. The typical SAS Viya use case is
    ``"TRUE###"``: the forward chain strips it before decoding, and the
    backward chain re-prepends it after encoding so the wire format is
    preserved bit-for-bit.

    Args:
        data: Raw bytes to encode.
        config: Primitive config. Recognized options:

            * ``prefix`` (str | None): literal string prepended to the
              base64 result. Default ``None`` (no prefix); documented
              maximum 32 chars (not checked by this function).

    Returns:
        Standard base64 text, preceded by ``prefix`` when one is set.

    Raises:
        EncodeError: If ``data`` is not ``bytes``.

    Examples:
        >>> base64_encode(b"Hello", PrimitiveConfig(name="base64"))
        'SGVsbG8='
        >>> # Reattach the SAS Viya proprietary prefix
        >>> cfg = PrimitiveConfig(name="base64", options={"prefix": "TRUE###"})
        >>> base64_encode(b"Hello", cfg)
        'TRUE###SGVsbG8='
    """
    if isinstance(data, bytes):
        body = b64_module.b64encode(data).decode("ascii")
        marker: str | None = config.options.get("prefix")
        # Empty string and None both mean "no prefix".
        return f"{marker}{body}" if marker else body
    raise EncodeError(
        f"Expected bytes, got {type(data).__name__}",
        primitive_name="base64",
    )
# ============================================================================
# zlib
# ============================================================================
[docs]
def zlib_decompress(data: bytes, config: PrimitiveConfig) -> bytes:
    """Decompress zlib data, optionally skipping a leading header.

    Decompression is bounded: the decompressor is asked for at most
    ``MAX_DECOMPRESSED_SIZE + 1`` output bytes, so an oversized stream
    (zlib bomb) is rejected after producing at most one byte more than the
    limit, instead of being fully inflated in memory before the check.

    Args:
        data: Compressed bytes, possibly preceded by a proprietary header.
        config: Primitive config. Recognized options:

            * ``skip_bytes`` (int): number of leading bytes to drop before
              decompressing. Default ``0``. Must be a non-negative int no
              greater than ``len(data)``.

    Returns:
        Decompressed bytes. Bytes following the end of the zlib stream are
        ignored, as with ``zlib.decompress``.

    Raises:
        DecompressError: If ``data`` is not bytes, ``skip_bytes`` is
            invalid, the stream is corrupt, empty or truncated, or the
            output exceeds ``MAX_DECOMPRESSED_SIZE`` or
            ``MAX_DECOMPRESSION_RATIO``.

    Examples:
        >>> import zlib
        >>> compressed = zlib.compress(b"Hello")
        >>> zlib_decompress(compressed, PrimitiveConfig(name="zlib"))
        b'Hello'
    """
    if not isinstance(data, bytes):
        raise DecompressError(
            f"Expected bytes, got {type(data).__name__}",
            primitive_name="zlib",
        )
    skip_bytes = config.options.get("skip_bytes", 0)
    # A negative value would silently slice from the END of the buffer
    # (data[-n:]); a non-int would escape as a bare TypeError below.
    if not isinstance(skip_bytes, int) or skip_bytes < 0:
        raise DecompressError(
            f"skip_bytes must be a non-negative int, got {skip_bytes!r}",
            primitive_name="zlib",
        )
    if skip_bytes > len(data):
        raise DecompressError(
            f"skip_bytes ({skip_bytes}) exceeds data length ({len(data)})",
            primitive_name="zlib",
        )
    payload = data[skip_bytes:]

    decompressor = zlib_module.decompressobj()
    try:
        # Zlib bomb protection: cap the output at one byte over the limit so
        # an oversized stream is detected without inflating all of it.
        result = decompressor.decompress(payload, MAX_DECOMPRESSED_SIZE + 1)
    except zlib_module.error as exc:
        raise DecompressError(
            f"zlib decompression failed (input: {len(payload):,} bytes)",
            primitive_name="zlib",
        ) from exc

    if len(result) > MAX_DECOMPRESSED_SIZE:
        raise DecompressError(
            f"Decompressed size (>{MAX_DECOMPRESSED_SIZE:,}) exceeds limit "
            f"({MAX_DECOMPRESSED_SIZE:,})",
            primitive_name="zlib",
        )
    # zlib.decompress() rejects empty/truncated streams, but a decompressobj
    # silently returns partial output, so require end-of-stream explicitly.
    if not decompressor.eof:
        raise DecompressError(
            f"zlib decompression failed (input: {len(payload):,} bytes)",
            primitive_name="zlib",
        )

    # max(..., 1) guards against division by zero on an empty payload.
    compressed_size = max(len(payload), 1)
    ratio = len(result) / compressed_size
    if ratio > MAX_DECOMPRESSION_RATIO:
        raise DecompressError(
            f"Decompression ratio {ratio:.1f} exceeds limit ({MAX_DECOMPRESSION_RATIO})",
            primitive_name="zlib",
        )
    return result
[docs]
def zlib_compress(data: bytes, config: PrimitiveConfig) -> bytes:
    """Compress data with zlib, optionally prepending a header.

    All options are validated before compression starts, so a bad config
    fails fast without wasting a (possibly large) compression pass.

    Args:
        data: Raw bytes to compress.
        config: Primitive config. Recognized options:

            * ``prepend_bytes`` (str | None): hex string decoded with
              ``bytes.fromhex`` and prepended to the compressed bytes.
              Default: no header.
            * ``level`` (int): compression level passed to
              ``zlib.compress``. Range -1 to 9, where -1 means "use the
              Python zlib default level" (typically 6), 0 means no
              compression, and 9 means maximum compression. Default -1.
              Higher values produce smaller output but are slower.

    Returns:
        Compressed bytes with the optional header prefix.

    Raises:
        CompressError: If ``data`` is not bytes, ``prepend_bytes`` is not a
            valid hex string, ``level`` is not an int in -1..9, or
            compression fails.

    Examples:
        >>> result = zlib_compress(b"Hello", PrimitiveConfig(name="zlib"))
        >>> import zlib
        >>> zlib.decompress(result)
        b'Hello'
        >>> # Maximum compression level
        >>> cfg = PrimitiveConfig(name="zlib", options={"level": 9})
        >>> result9 = zlib_compress(b"Hello world " * 100, cfg)
        >>> zlib.decompress(result9) == b"Hello world " * 100
        True
    """
    if not isinstance(data, bytes):
        raise CompressError(
            f"Expected bytes, got {type(data).__name__}",
            primitive_name="zlib",
        )
    prepend_hex = config.options.get("prepend_bytes", "")
    level = config.options.get("level", -1)

    # Validate options up front (fail fast) rather than after compressing.
    header = b""
    if prepend_hex:
        try:
            header = bytes.fromhex(prepend_hex)
        except (TypeError, ValueError) as exc:  # TypeError: not a str
            raise CompressError(
                f"Invalid prepend_bytes hex: {prepend_hex!r}",
                primitive_name="zlib",
            ) from exc
    if not isinstance(level, int) or not -1 <= level <= 9:
        raise CompressError(
            f"Invalid zlib level: {level!r} (expected int in -1..9)",
            primitive_name="zlib",
        )

    try:
        compressed = zlib_module.compress(data, level)
    except zlib_module.error as exc:
        raise CompressError(
            f"zlib compression failed (input: {len(data):,} bytes)",
            primitive_name="zlib",
        ) from exc
    return header + compressed
# ============================================================================
# json
# ============================================================================
def _extract_by_path(data: Any, path: str) -> Any:
    """Walk nested dicts along a dot-notation path and return the leaf.

    Only dict nodes are traversed; list indices are not supported.

    Args:
        data: Parsed JSON value (normally a dict).
        path: Dot-separated key path, e.g. ``"transferableContent.content"``.

    Returns:
        The value stored at ``path``.

    Raises:
        ParseError: If a segment is missing or a node along the way is not
            a dict.
    """
    node = data
    for key in path.split("."):
        if isinstance(node, dict) and key in node:
            node = node[key]
            continue
        raise ParseError(
            f"Path '{path}' not found in parsed data",
            primitive_name="json",
        )
    return node
def _wrap_by_path(value: Any, path: str, envelope: dict[str, Any] | None) -> dict[str, Any]:
    """Place ``value`` at a dot-notation ``path`` inside a dict.

    With an ``envelope``, the result is a deep copy of it where only the
    leaf at ``path`` is replaced (lossless round-trip). The caller's
    envelope is never mutated. Without one, a minimal chain of nested dicts
    is built around ``value``.

    Args:
        value: Value to store at the leaf of ``path``.
        path: Dot-separated path, e.g. ``"transferableContent.content"``.
        envelope: Original envelope dict for lossless restoration, or
            ``None`` to build a fresh structure.

    Returns:
        Dict with ``value`` placed at ``path``.

    Raises:
        KeyError: If ``envelope`` lacks an intermediate segment of ``path``.
        TypeError: If a node along ``path`` in ``envelope`` is a list or
            scalar rather than a dict.
    """
    *parents, leaf = path.split(".")
    if envelope is None:
        root: dict[str, Any] = {}
        node = root
        for key in parents:
            node[key] = {}
            node = node[key]
    else:
        root = copy.deepcopy(envelope)
        node = root
        for key in parents:
            node = node[key]
    node[leaf] = value
    return root
[docs]
def json_parse(
    data: str | bytes,
    config: PrimitiveConfig,
) -> tuple[Any, dict[str, Any] | None]:
    """Parse a JSON string, optionally extracting a nested field.

    Returns a tuple of (value, envelope). When ``extract`` is used, the
    envelope is the full parsed document, kept for lossless backward
    restoration by ``json_serialize``. Otherwise the envelope is ``None``.

    Args:
        data: JSON text as ``str`` or ``bytes`` (bytes are decoded by
            ``json.loads``).
        config: Primitive config. Recognized options:

            * ``extract`` (str | None): dot-notation path of a nested field
              to return instead of the whole document. Only dict nodes are
              traversed. Default ``None``.

    Returns:
        Tuple of (extracted_or_full_value, original_envelope_or_None).

    Raises:
        ParseError: If ``data`` is not ``str``/``bytes``, exceeds
            ``MAX_JSON_SIZE``, is invalid or too deeply nested JSON, or the
            ``extract`` path is not found.

    Examples:
        >>> val, env = json_parse('{"a": 1}', PrimitiveConfig(name="json"))
        >>> val
        {'a': 1}
    """
    if not isinstance(data, str | bytes):
        raise ParseError(
            f"Expected str or bytes, got {type(data).__name__}",
            primitive_name="json",
        )
    size = len(data)
    if size > MAX_JSON_SIZE:
        raise ParseError(
            f"JSON input exceeds limit ({size:,} > {MAX_JSON_SIZE:,})",
            primitive_name="json",
        )
    try:
        parsed = json_module.loads(data)
    except (json_module.JSONDecodeError, UnicodeDecodeError, RecursionError) as exc:
        # RecursionError: pathological nesting (e.g. "[" * 100_000) fits well
        # under MAX_JSON_SIZE but exhausts the decoder's recursion limit.
        # Report it as a parse failure instead of leaking an interpreter error.
        raise ParseError(
            f"Invalid JSON (input length: {size:,})",
            primitive_name="json",
        ) from exc
    extract_path = config.options.get("extract")
    if extract_path:
        return _extract_by_path(parsed, extract_path), parsed
    return parsed, None
[docs]
def json_serialize(
    data: Any,
    config: PrimitiveConfig,
    *,
    envelope: dict[str, Any] | None = None,
) -> str:
    r"""Serialize a Python object to a JSON string.

    If the ``wrap`` path and an ``envelope`` are provided, the value is
    restored into the original envelope structure (lossless round-trip).

    Args:
        data: Python object to serialize.
        config: Primitive config. Recognized options:

            * ``wrap`` (str | None): dot-notation path used together with
              ``envelope`` to restore the value inside its original
              envelope structure. Default ``None``.
            * ``minify`` (bool): when ``True``, output uses compact
              ``separators=(",", ":")`` (no whitespace). When ``False``
              (default), uses Python's default separators
              ``(", ", ": ")``. Useful before zlib compression (denser
              input compresses better).
            * ``ensure_ascii`` (bool): when ``True``, escape every non-ASCII
              character as ``\uXXXX``. When ``False`` (default, **diverges
              from Python stdlib which is True**), non-ASCII characters are
              emitted verbatim. The kstlib default is ``False`` to preserve
              Unicode content (French, Japanese, etc.) without bloating the
              output.

        envelope: Original envelope for lossless restoration when ``wrap``
            is set.

    Returns:
        JSON string.

    Raises:
        SerializeError: If the object is not JSON-serializable, is nested
            too deeply, or ``envelope`` does not contain the ``wrap`` path.

    Examples:
        >>> json_serialize({"a": 1}, PrimitiveConfig(name="json"))
        '{"a": 1}'
        >>> # Minified output (no spaces after , and :)
        >>> cfg = PrimitiveConfig(name="json", options={"minify": True})
        >>> json_serialize({"a": 1, "b": 2}, cfg)
        '{"a":1,"b":2}'
        >>> # Preserve Unicode content (default behavior)
        >>> json_serialize({"k": "café"}, PrimitiveConfig(name="json"))
        '{"k": "café"}'
        >>> # Force ASCII escapes
        >>> cfg2 = PrimitiveConfig(name="json", options={"ensure_ascii": True})
        >>> json_serialize({"k": "café"}, cfg2)
        '{"k": "caf\\u00e9"}'
    """
    wrap_path = config.options.get("wrap")
    minify = config.options.get("minify", False)
    ensure_ascii = config.options.get("ensure_ascii", False)
    # None lets json.dumps pick its default (", ", ": ") separators.
    separators = (",", ":") if minify else None
    try:
        target = _wrap_by_path(data, wrap_path, envelope) if wrap_path else data
        return json_module.dumps(
            target,
            ensure_ascii=ensure_ascii,
            separators=separators,
        )
    except (TypeError, ValueError, KeyError, RecursionError) as exc:
        # KeyError: the envelope lacks an intermediate segment of ``wrap``
        # (raised by _wrap_by_path); previously this escaped unwrapped.
        # RecursionError: the object is nested deeper than dumps can handle.
        raise SerializeError(
            f"JSON serialization failed: {type(data).__name__}",
            primitive_name="json",
        ) from exc
# ============================================================================
# xml
# ============================================================================
[docs]
def xml_parse(data: str, config: PrimitiveConfig) -> Element:
    """Parse an XML string into an ElementTree Element.

    Parsing goes through ``defusedxml.ElementTree.fromstring`` for XXE
    protection. defusedxml raises ``EntitiesForbidden``, ``DTDForbidden``
    or ``ExternalReferenceForbidden`` on malicious payloads. Those, like
    every other parser failure, are re-raised as ``ParseError``.

    Args:
        data: XML document as text.
        config: Primitive config (no options are read).

    Returns:
        Root Element of the parsed document.

    Raises:
        ParseError: If ``data`` is not a ``str``, exceeds ``MAX_XML_SIZE``,
            is malformed, or is rejected as unsafe.

    Examples:
        >>> root = xml_parse("<root><a>1</a></root>", PrimitiveConfig(name="xml"))
        >>> root.tag
        'root'
    """
    if not isinstance(data, str):
        raise ParseError(
            f"Expected str, got {type(data).__name__}",
            primitive_name="xml",
        )
    length = len(data)
    if length > MAX_XML_SIZE:
        raise ParseError(
            f"XML input exceeds limit ({length:,} > {MAX_XML_SIZE:,})",
            primitive_name="xml",
        )
    try:
        root: Element = _safe_xml_fromstring(data)
    except Exception as exc:  # deliberately broad: every parser failure -> ParseError
        raise ParseError(
            f"Invalid XML (input length: {length:,})",
            primitive_name="xml",
        ) from exc
    return root
[docs]
def xml_serialize(data: Element, config: PrimitiveConfig) -> str:
    """Serialize an ElementTree Element to an XML string.

    Args:
        data: ElementTree root Element.
        config: Primitive config. Recognized options:

            * ``encoding`` (str): passed to ``ElementTree.tostring``.
              Default ``"unicode"``, which returns text directly. Any other
              codec makes ``tostring`` return bytes; those are decoded back
              to ``str`` with the same codec.

    Returns:
        XML string.

    Raises:
        SerializeError: If ``data`` is not an Element, the encoding is
            unknown, or serialization fails.

    Examples:
        >>> from xml.etree.ElementTree import Element
        >>> root = Element("root")
        >>> xml_serialize(root, PrimitiveConfig(name="xml"))
        '<root />'
    """
    if not isinstance(data, Element):
        raise SerializeError(
            f"Expected Element, got {type(data).__name__}",
            primitive_name="xml",
        )
    encoding = config.options.get("encoding", "unicode")
    try:
        result = ET.tostring(data, encoding=encoding)
        if not isinstance(result, str):  # pragma: no cover - non-text encoding yields bytes
            # Decode with the codec tostring() actually used (it falls back
            # to us-ascii for a falsy encoding). Always decoding as UTF-8
            # broke latin-1, UTF-16, etc. with an uncaught UnicodeDecodeError.
            result = result.decode(encoding or "us-ascii")
        return result
    except (TypeError, LookupError, UnicodeError) as exc:
        raise SerializeError(
            "XML serialization failed",
            primitive_name="xml",
        ) from exc
# ============================================================================
# bytes
# ============================================================================
[docs]
def bytes_decode(data: bytes, config: PrimitiveConfig) -> str:
    """Decode bytes to a string with a configurable text codec.

    Args:
        data: Raw bytes.
        config: Primitive config. Recognized options:

            * ``encoding`` (str): codec name passed to ``bytes.decode``.
              Default ``"utf-8"``.

    Returns:
        Decoded string.

    Raises:
        DecodeError: If ``data`` is not ``bytes``, the codec is unknown or
            not a text codec, or the bytes are invalid for that codec.

    Examples:
        >>> bytes_decode(b"Hello", PrimitiveConfig(name="bytes"))
        'Hello'
    """
    if isinstance(data, bytes):
        codec = config.options.get("encoding", "utf-8")
        try:
            text = data.decode(codec)
        except (UnicodeDecodeError, LookupError) as exc:
            raise DecodeError(
                f"bytes decode failed with encoding '{codec}' (input: {len(data):,} bytes)",
                primitive_name="bytes",
            ) from exc
        return text
    raise DecodeError(
        f"Expected bytes, got {type(data).__name__}",
        primitive_name="bytes",
    )
[docs]
def bytes_encode(data: str, config: PrimitiveConfig) -> bytes:
    """Encode a string to bytes with a configurable text codec.

    Args:
        data: String to encode.
        config: Primitive config. Recognized options:

            * ``encoding`` (str): codec name passed to ``str.encode``.
              Default ``"utf-8"``.

    Returns:
        Encoded bytes.

    Raises:
        EncodeError: If ``data`` is not a ``str``, the codec is unknown or
            not a text codec, or a character cannot be encoded.

    Examples:
        >>> bytes_encode("Hello", PrimitiveConfig(name="bytes"))
        b'Hello'
    """
    if isinstance(data, str):
        codec = config.options.get("encoding", "utf-8")
        try:
            raw = data.encode(codec)
        except (UnicodeEncodeError, LookupError) as exc:
            raise EncodeError(
                f"bytes encode failed with encoding '{codec}'",
                primitive_name="bytes",
            ) from exc
        return raw
    raise EncodeError(
        f"Expected str, got {type(data).__name__}",
        primitive_name="bytes",
    )
# ============================================================================
# Registry
# ============================================================================
#: Every primitive name paired with its (forward, backward) implementation.
#: Both dispatch tables below are derived from this single table, so their
#: key sets and key order cannot drift apart.
_PRIMITIVE_TABLE: tuple[tuple[str, Any, Any], ...] = (
    ("base64", base64_decode, base64_encode),
    ("zlib", zlib_decompress, zlib_compress),
    ("json", json_parse, json_serialize),
    ("xml", xml_parse, xml_serialize),
    ("bytes", bytes_decode, bytes_encode),
)

#: Forward primitive dispatch table (name -> decode/decompress/parse function).
FORWARD_DISPATCH: dict[str, Any] = {name: forward for name, forward, _ in _PRIMITIVE_TABLE}

#: Backward primitive dispatch table (name -> encode/compress/serialize function).
BACKWARD_DISPATCH: dict[str, Any] = {name: backward for name, _, backward in _PRIMITIVE_TABLE}

__all__ = [
    "BACKWARD_DISPATCH",
    "FORWARD_DISPATCH",
    "base64_decode",
    "base64_encode",
    "bytes_decode",
    "bytes_encode",
    "json_parse",
    "json_serialize",
    "xml_parse",
    "xml_serialize",
    "zlib_compress",
    "zlib_decompress",
]