"""Graceful shutdown handler with prioritized cleanup callbacks."""
from __future__ import annotations
import asyncio
import contextlib
import inspect
import logging
import signal
import sys
import threading
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from typing_extensions import Self
from kstlib.limits import (
HARD_MAX_SHUTDOWN_TIMEOUT,
HARD_MIN_SHUTDOWN_TIMEOUT,
get_resilience_limits,
)
from kstlib.resilience.exceptions import ShutdownError
if TYPE_CHECKING:
import types
log = logging.getLogger(__name__)
# Type aliases for callbacks
SyncCallback = Callable[[], None]
AsyncCallback = Callable[[], Coroutine[Any, Any, None]]
Callback = SyncCallback | AsyncCallback
[docs]
@dataclass(frozen=True, slots=True)
class CleanupCallback:
"""Registered cleanup callback with metadata.
Attributes:
name: Unique identifier for the callback.
callback: The cleanup function (sync or async).
priority: Execution order (lower runs first, default 100).
timeout: Per-callback timeout in seconds (None = use global).
is_async: Whether the callback is async.
Examples:
>>> cb = CleanupCallback(
... name="db_close",
... callback=lambda: None,
... priority=50,
... timeout=5.0,
... is_async=False,
... )
>>> (cb.name, cb.priority, cb.timeout)
('db_close', 50, 5.0)
"""
name: str
callback: Callback
priority: int = 100
timeout: float | None = None
is_async: bool = False
[docs]
class GracefulShutdown:
"""Graceful shutdown handler with prioritized cleanup callbacks.
Manages orderly shutdown on SIGTERM/SIGINT with timeout enforcement.
Callbacks execute in priority order (lower = first).
Args:
timeout: Total timeout for all callbacks (default from config).
signals: Signals to handle (default: SIGTERM, SIGINT).
force_exit_code: Exit code when timeout exceeded (default: 1).
Examples:
Register callbacks with priority ordering:
>>> shutdown = GracefulShutdown(timeout=30)
>>> shutdown.register("cache", lambda: None, priority=20)
>>> shutdown.register("db", lambda: None, priority=10)
>>> [cb.name for cb in shutdown._get_sorted_callbacks()]
['db', 'cache']
Context manager usage (with signals):
>>> with GracefulShutdown() as shutdown: # doctest: +SKIP
... shutdown.register("cleanup", close_resources)
... run_application()
"""
# Signals not available on Windows
_UNIX_SIGNALS: tuple[signal.Signals, ...] = (signal.SIGTERM, signal.SIGINT)
_WINDOWS_SIGNALS: tuple[signal.Signals, ...] = (signal.SIGINT,)
[docs]
def __init__(
self,
*,
timeout: float | None = None,
signals: tuple[signal.Signals, ...] | None = None,
force_exit_code: int = 1,
) -> None:
"""Initialize graceful shutdown handler.
Args:
timeout: Total timeout for all callbacks. Uses config if None.
signals: Signals to handle. Auto-detects platform if None.
force_exit_code: Exit code when timeout exceeded.
"""
# Load timeout from config if not provided
if timeout is None:
limits = get_resilience_limits()
self._timeout = limits.shutdown_timeout
else:
# Clamp to hard limits
self._timeout = max(
HARD_MIN_SHUTDOWN_TIMEOUT,
min(timeout, HARD_MAX_SHUTDOWN_TIMEOUT),
)
# Auto-detect signals based on platform
if signals is None:
self._signals = self._WINDOWS_SIGNALS if sys.platform == "win32" else self._UNIX_SIGNALS
else:
self._signals = signals
self._force_exit_code = force_exit_code
# Callback registry
self._callbacks: dict[str, CleanupCallback] = {}
self._lock = threading.Lock()
# Shutdown state
self._shutting_down = False
self._shutdown_event = threading.Event()
# Original signal handlers (for restoration)
self._original_handlers: dict[signal.Signals, Any] = {}
self._installed = False
@property
def timeout(self) -> float:
"""Return the total shutdown timeout in seconds."""
return self._timeout
@property
def is_shutting_down(self) -> bool:
"""Return True if shutdown is in progress."""
return self._shutting_down
@property
def is_installed(self) -> bool:
"""Return True if signal handlers are installed."""
return self._installed
[docs]
def register(
self,
name: str,
callback: Callback,
*,
priority: int = 100,
timeout: float | None = None,
) -> None:
"""Register a cleanup callback.
Args:
name: Unique identifier for the callback.
callback: Cleanup function (sync or async).
priority: Execution order (lower runs first, default 100).
timeout: Per-callback timeout (None = use global).
Raises:
ShutdownError: If name already registered or shutdown in progress.
Examples:
>>> shutdown = GracefulShutdown()
>>> shutdown.register("db", lambda: print("closing db"), priority=10)
>>> "db" in [cb.name for cb in shutdown._callbacks.values()]
True
"""
with self._lock:
if self._shutting_down:
raise ShutdownError("Cannot register callback during shutdown")
if name in self._callbacks:
raise ShutdownError(f"Callback '{name}' already registered")
is_async = inspect.iscoroutinefunction(callback)
self._callbacks[name] = CleanupCallback(
name=name,
callback=callback,
priority=priority,
timeout=timeout,
is_async=is_async,
)
[docs]
def unregister(self, name: str) -> bool:
"""Unregister a cleanup callback.
Args:
name: Identifier of callback to remove.
Returns:
True if callback was removed, False if not found.
Examples:
>>> shutdown = GracefulShutdown()
>>> shutdown.register("test", lambda: None)
>>> shutdown.unregister("test")
True
>>> shutdown.unregister("nonexistent")
False
"""
with self._lock:
if name in self._callbacks:
del self._callbacks[name]
return True
return False
[docs]
def install(self) -> None:
"""Install signal handlers.
Raises:
ShutdownError: If handlers already installed.
"""
with self._lock:
if self._installed:
raise ShutdownError("Signal handlers already installed")
for sig in self._signals:
with contextlib.suppress(OSError, ValueError):
self._original_handlers[sig] = signal.signal(sig, self._signal_handler)
self._installed = True
[docs]
def uninstall(self) -> None:
"""Restore original signal handlers."""
with self._lock:
if not self._installed:
return
for sig, handler in self._original_handlers.items():
with contextlib.suppress(OSError, ValueError):
signal.signal(sig, handler)
self._original_handlers.clear()
self._installed = False
def _signal_handler(self, _signum: int, _frame: types.FrameType | None) -> None:
"""Handle incoming signal."""
self.trigger()
[docs]
def trigger(self) -> None:
"""Trigger shutdown programmatically.
Useful for testing or triggering shutdown from code.
Runs callbacks synchronously in priority order.
"""
with self._lock:
if self._shutting_down:
return
self._shutting_down = True
log.info("Shutdown requested")
self._shutdown_event.set()
self._run_callbacks_sync()
[docs]
async def atrigger(self) -> None:
"""Trigger shutdown programmatically (async version).
Runs callbacks asynchronously in priority order.
"""
with self._lock:
if self._shutting_down:
return
self._shutting_down = True
log.info("Shutdown requested")
self._shutdown_event.set()
await self._run_callbacks_async()
def _get_sorted_callbacks(self) -> list[CleanupCallback]:
"""Return callbacks sorted by priority (ascending)."""
with self._lock:
return sorted(self._callbacks.values(), key=lambda cb: cb.priority)
def _run_callbacks_sync(self) -> None:
"""Run all callbacks synchronously with timeout."""
callbacks = self._get_sorted_callbacks()
for cb in callbacks:
cb_timeout = cb.timeout if cb.timeout is not None else self._timeout
if cb.is_async:
# Run async callback in new event loop
async_callback = cast("AsyncCallback", cb.callback)
try:
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(asyncio.wait_for(async_callback(), timeout=cb_timeout))
finally:
loop.close()
except asyncio.TimeoutError:
log.warning("Shutdown callback '%s' timed out", cb.name)
except Exception: # pylint: disable=broad-exception-caught
# Intentional: shutdown must continue even if callback fails
log.warning("Shutdown callback '%s' failed", cb.name, exc_info=True)
else:
# Run sync callback with timeout via thread
# Wrap in suppress to prevent unhandled thread exceptions
def safe_callback(fn: SyncCallback) -> None:
with contextlib.suppress(Exception):
fn()
sync_callback = cast("SyncCallback", cb.callback)
thread = threading.Thread(target=safe_callback, args=(sync_callback,))
thread.start()
thread.join(timeout=cb_timeout)
# If thread still running after timeout, we continue anyway
async def _run_callbacks_async(self) -> None:
"""Run all callbacks asynchronously with timeout."""
callbacks = self._get_sorted_callbacks()
for cb in callbacks:
cb_timeout = cb.timeout if cb.timeout is not None else self._timeout
try:
if cb.is_async:
async_cb = cast("AsyncCallback", cb.callback)
await asyncio.wait_for(async_cb(), timeout=cb_timeout)
else:
# Run sync callback in executor
loop = asyncio.get_running_loop()
await asyncio.wait_for(
loop.run_in_executor(None, cb.callback),
timeout=cb_timeout,
)
except asyncio.TimeoutError:
log.warning("Shutdown callback '%s' timed out", cb.name)
except Exception: # pylint: disable=broad-exception-caught
# Intentional: shutdown must continue even if callback fails
log.warning("Shutdown callback '%s' failed", cb.name, exc_info=True)
[docs]
def wait(self, timeout: float | None = None) -> bool:
"""Wait for shutdown signal.
Args:
timeout: Maximum time to wait (None = wait forever).
Returns:
True if shutdown was triggered, False if timeout.
"""
return self._shutdown_event.wait(timeout=timeout)
[docs]
async def await_shutdown(self, timeout: float | None = None) -> bool:
"""Wait for shutdown signal (async version).
Args:
timeout: Maximum time to wait (None = wait forever).
Returns:
True if shutdown was triggered, False if timeout.
"""
# Use polling to avoid executor thread cleanup issues
# Note: We poll a threading.Event, not asyncio.Event, hence the loop
# The threading.Event is used to support both sync and async contexts
poll_interval = 0.05 # 50ms polling
if timeout is None:
# Wait forever (poll threading.Event from async context)
while not self._shutdown_event.is_set():
await asyncio.sleep(poll_interval)
return True
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
while not self._shutdown_event.is_set():
remaining = deadline - loop.time()
if remaining <= 0:
return False
await asyncio.sleep(min(poll_interval, remaining))
return True
[docs]
def __enter__(self) -> Self:
"""Enter sync context manager."""
self.install()
return self
[docs]
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
"""Exit sync context manager."""
self.uninstall()
if not self._shutting_down:
self.trigger()
[docs]
async def __aenter__(self) -> Self:
"""Enter async context manager."""
self.install()
return self
[docs]
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
"""Exit async context manager."""
self.uninstall()
if not self._shutting_down:
await self.atrigger()
__all__ = ["CleanupCallback", "GracefulShutdown"]