# Source code for kstlib.mail.transports.smtp

"""SMTP transport backend with TRACE-level debugging.

When the logger level is set to TRACE, this transport logs detailed
information about the SMTP session including:
- Connection and EHLO exchange
- STARTTLS negotiation and SSL/TLS cipher details
- Authentication flow (credentials redacted)
- Message envelope (MAIL FROM, RCPT TO)

Enable trace logging via configuration:
    logger:
      preset: trace_mail  # Or set level: TRACE directly
"""

from __future__ import annotations

# pylint: disable=too-many-instance-attributes,too-many-arguments,too-few-public-methods
import io
import logging
import smtplib
import ssl
import sys
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from kstlib.logging import TRACE_LEVEL
from kstlib.mail.exceptions import MailTransportError
from kstlib.mail.transport import MailTransport

if TYPE_CHECKING:
    from collections.abc import Iterator
    from email.message import EmailMessage

# Module logger - uses kstlib hierarchy for config-driven trace.
# Every helper in this file emits through this logger at TRACE_LEVEL.
log = logging.getLogger(__name__)


@contextmanager
def _capture_smtp_debug() -> Iterator[io.StringIO]:
    """Capture smtplib debug output to a StringIO buffer.

    smtplib.set_debuglevel() writes to stderr. This context manager
    temporarily redirects stderr to capture the debug output.

    Yields:
        StringIO buffer containing captured debug output.

    """
    buffer = io.StringIO()
    old_stderr = sys.stderr
    try:
        sys.stderr = buffer
        yield buffer
    finally:
        sys.stderr = old_stderr


def _extract_cn_from_cert_field(field: Any) -> str | None:
    """Extract commonName from a certificate subject or issuer field.

    Certificate fields (subject, issuer) are nested tuples of RDNs (Relative
    Distinguished Names), each containing attribute tuples like ('commonName', 'value').

    Args:
        field: Nested tuple structure from peer_cert['subject'] or ['issuer'].

    Returns:
        The commonName value if found, None otherwise.

    """
    if not field or not isinstance(field, tuple):
        return None
    for rdn in field:
        if not isinstance(rdn, tuple):
            continue
        for attr in rdn:
            if isinstance(attr, tuple) and len(attr) >= 2 and attr[0] == "commonName":
                return str(attr[1])
    return None


def _extract_cipher_info(sock: ssl.SSLSocket) -> dict[str, Any]:
    """Extract cipher information from SSL socket.

    Args:
        sock: The SSL socket after handshake.

    Returns:
        Dictionary with cipher_name, cipher_protocol, cipher_bits.

    """
    info: dict[str, Any] = {}
    try:
        cipher = sock.cipher()
        if cipher:
            info["cipher_name"] = cipher[0]
            info["cipher_protocol"] = cipher[1]
            info["cipher_bits"] = cipher[2]
    except Exception:  # pylint: disable=broad-exception-caught
        pass
    return info


def _extract_cert_info(sock: ssl.SSLSocket) -> dict[str, Any]:
    """Extract peer certificate information from SSL socket.

    Args:
        sock: The SSL socket after handshake.

    Returns:
        Dictionary with peer_cn, issuer_cn, valid_from, valid_until.

    """
    info: dict[str, Any] = {}
    try:
        peer_cert = sock.getpeercert()
        if peer_cert:
            peer_cn = _extract_cn_from_cert_field(peer_cert.get("subject"))
            if peer_cn:
                info["peer_cn"] = peer_cn
            issuer_cn = _extract_cn_from_cert_field(peer_cert.get("issuer"))
            if issuer_cn:
                info["issuer_cn"] = issuer_cn
            if "notBefore" in peer_cert:
                info["valid_from"] = peer_cert["notBefore"]
            if "notAfter" in peer_cert:
                info["valid_until"] = peer_cert["notAfter"]
    except Exception:  # pylint: disable=broad-exception-caught
        pass
    return info


def _extract_ssl_info(sock: ssl.SSLSocket | None) -> dict[str, Any]:
    """Extract SSL/TLS session information for trace logging.

    Args:
        sock: The SSL socket after handshake.

    Returns:
        Dictionary with SSL session details (version, cipher, peer cert).

    """
    if sock is None:
        return {}

    info: dict[str, Any] = {}

    try:
        info["version"] = sock.version()
    except Exception:  # pylint: disable=broad-exception-caught
        info["version"] = "unknown"

    info.update(_extract_cipher_info(sock))
    info.update(_extract_cert_info(sock))

    return info


def _log_ssl_info(client: smtplib.SMTP | smtplib.SMTP_SSL, protocol_label: str) -> None:
    """Log SSL/TLS session information at TRACE level.

    Args:
        client: The SMTP client with an SSL socket.
        protocol_label: Label for the protocol (e.g., "SSL" or "TLS").

    """
    ssl_sock = getattr(client, "sock", None)
    if ssl_sock is None or not hasattr(ssl_sock, "version"):
        return

    ssl_info = _extract_ssl_info(ssl_sock)
    if not ssl_info:
        return

    log.log(
        TRACE_LEVEL,
        "[SMTP] %s: %s, cipher=%s (%d bits)",
        protocol_label,
        ssl_info.get("version", "unknown"),
        ssl_info.get("cipher_name", "unknown"),
        ssl_info.get("cipher_bits", 0),
    )
    if "peer_cn" in ssl_info:
        log.log(
            TRACE_LEVEL,
            "[SMTP] %s peer: CN=%s, issuer=%s",
            protocol_label,
            ssl_info.get("peer_cn", "unknown"),
            ssl_info.get("issuer_cn", "unknown"),
        )


def _log_smtp_debug_output(buffer: io.StringIO) -> None:
    """Parse and log captured smtplib debug output at TRACE level.

    Args:
        buffer: StringIO containing captured debug output.

    """
    if not log.isEnabledFor(TRACE_LEVEL):
        return

    content = buffer.getvalue()
    if not content:
        return

    # Patterns that may contain credentials in SMTP debug output
    _sensitive_patterns = ("AUTH ", "AUTH=", "ATRN ", "PASS ")

    for line in content.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        # smtplib debug format: "send: 'EHLO ...'" or "reply: retcode (...);"
        if line.startswith("send:"):
            payload = line[5:].strip().strip("'\"")
            # Redact AUTH commands that contain base64-encoded credentials
            upper = payload.upper()
            if any(upper.startswith(p) for p in _sensitive_patterns):
                log.log(TRACE_LEVEL, "[SMTP] >>> %s ***REDACTED***", payload.split()[0])
            else:
                log.log(TRACE_LEVEL, "[SMTP] >>> %s", payload)
        elif line.startswith("reply:"):
            log.log(TRACE_LEVEL, "[SMTP] <<< %s", line[6:].strip())
        else:
            log.log(TRACE_LEVEL, "[SMTP] %s", line)


[docs] @dataclass(frozen=True, slots=True) class SMTPCredentials: """SMTP authentication bundle.""" username: str password: str | None = None
[docs] def __repr__(self) -> str: """Redact password from repr output.""" return f"SMTPCredentials(username={self.username!r}, password={'***' if self.password else None!r})"
@dataclass(frozen=True, slots=True) class SMTPSecurity: """SMTP security preferences.""" use_ssl: bool = False use_starttls: bool = True ssl_context: ssl.SSLContext | None = None
class SMTPTransport(MailTransport):
    """Deliver messages using the standard SMTP protocol."""

    def __init__(
        self,
        host: str,
        port: int = 587,
        *,
        credentials: SMTPCredentials | None = None,
        security: SMTPSecurity | None = None,
        timeout: float | None = None,
    ) -> None:
        """Configure connection parameters for the SMTP backend.

        Args:
            host: SMTP server host name.
            port: SMTP server port.
            credentials: Optional login bundle; no authentication when None.
            security: TLS preferences; defaults to ``SMTPSecurity()``.
            timeout: Socket timeout handed to smtplib as-is.

        """
        self._host = host
        self._port = port
        self._timeout = timeout
        self._username = credentials.username if credentials else None
        self._password = credentials.password if credentials else None

        effective_security = security or SMTPSecurity()
        self._use_ssl = effective_security.use_ssl
        # STARTTLS makes no sense on a connection that is already wrapped in TLS.
        self._use_starttls = effective_security.use_starttls if not effective_security.use_ssl else False
        self._ssl_context = effective_security.ssl_context or ssl.create_default_context()

    def _open_client(self) -> smtplib.SMTP:
        """Create and connect the smtplib client matching the security settings.

        Returns:
            A connected ``smtplib.SMTP_SSL`` when implicit TLS is requested,
            otherwise a connected ``smtplib.SMTP``.

        """
        client_kwargs: dict[str, Any] = {"host": self._host, "port": self._port, "timeout": self._timeout}
        if self._use_ssl:
            # Without an explicit context smtplib falls back to its own default
            # instead of the context configured (or created) in __init__.
            return smtplib.SMTP_SSL(context=self._ssl_context, **client_kwargs)
        return smtplib.SMTP(**client_kwargs)

    def _upgrade_to_tls(self, client: smtplib.SMTP, trace: bool) -> None:
        """Upgrade connection to TLS via STARTTLS if supported.

        Args:
            client: Active SMTP client.
            trace: Whether TRACE logging is enabled.

        """
        if not self._use_starttls:
            return
        if not client.has_extn("STARTTLS"):
            # NOTE(review): the session continues unencrypted when the server
            # does not advertise STARTTLS - confirm this fallback is intended.
            if trace:
                log.log(TRACE_LEVEL, "[SMTP] STARTTLS not offered by server; continuing without TLS")
            return

        if trace:
            log.log(TRACE_LEVEL, "[SMTP] Upgrading to TLS via STARTTLS")
        client.starttls(context=self._ssl_context)
        # Re-issue EHLO on the encrypted channel.
        client.ehlo()
        if trace:
            _log_ssl_info(client, "TLS")

    def _authenticate(self, client: smtplib.SMTP, trace: bool) -> None:
        """Authenticate with the SMTP server if credentials are configured.

        Args:
            client: Active SMTP client.
            trace: Whether TRACE logging is enabled.

        """
        if not self._username:
            return

        if trace:
            log.log(TRACE_LEVEL, "[SMTP] Authenticating as: %s", self._username)
        client.login(self._username, self._password or "")
        if trace:
            log.log(TRACE_LEVEL, "[SMTP] Authentication successful")

    def _trace_envelope(self, message: EmailMessage) -> None:
        """Log message envelope details at TRACE level.

        Args:
            message: Email message to log.

        """
        log.log(TRACE_LEVEL, "[SMTP] MAIL FROM: %s", message.get("From", "unknown"))
        log.log(TRACE_LEVEL, "[SMTP] RCPT TO: %s", message.get("To", "unknown"))
        log.log(TRACE_LEVEL, "[SMTP] Subject: %s", message.get("Subject", "(no subject)"))

    def _run_session(self, message: EmailMessage, trace: bool) -> None:
        """Open a connection, secure it, log in and hand over ``message``.

        Args:
            message: Email message to deliver.
            trace: Whether TRACE logging is enabled.

        """
        with self._open_client() as client:
            if trace:
                client.set_debuglevel(2)
            client.ehlo()
            if trace and self._use_ssl:
                _log_ssl_info(client, "SSL")
            self._upgrade_to_tls(client, trace)
            self._authenticate(client, trace)
            if trace:
                self._trace_envelope(message)
            client.send_message(message)
            if trace:
                log.log(TRACE_LEVEL, "[SMTP] Message sent successfully")

    def _run_traced_session(self, message: EmailMessage) -> None:
        """Run the session with stderr captured and replay smtplib's debug output.

        The captured output is logged after stderr has been restored, and it
        is logged even when the session raises.

        Args:
            message: Email message to deliver.

        """
        debug_buffer: io.StringIO | None = None
        try:
            with _capture_smtp_debug() as debug_buffer:
                self._run_session(message, True)
        finally:
            if debug_buffer is not None:
                _log_smtp_debug_output(debug_buffer)

    def send(self, message: EmailMessage) -> None:
        """Send *message* through the configured SMTP server.

        When TRACE logging is enabled, detailed session information is logged
        including SMTP commands, SSL/TLS handshake details, and message
        envelope. stderr is only redirected in that case.

        Args:
            message: Email message to deliver.

        Raises:
            MailTransportError: When smtplib raises an ``SMTPException``.

        """
        trace = log.isEnabledFor(TRACE_LEVEL)
        if trace:
            protocol = "SMTP_SSL" if self._use_ssl else "SMTP"
            log.log(TRACE_LEVEL, "[SMTP] Connecting to %s:%d (%s)", self._host, self._port, protocol)

        try:
            if trace:
                self._run_traced_session(message)
            else:
                self._run_session(message, False)
        except smtplib.SMTPException as exc:  # pragma: no cover - network dependent
            if trace:
                log.log(TRACE_LEVEL, "[SMTP] Error: %s", exc)
            raise MailTransportError(str(exc)) from exc
# Names exported on star-import; the underscore-prefixed helpers stay private.
__all__ = ["SMTPCredentials", "SMTPSecurity", "SMTPTransport"]