"""Heartbeat mechanism for process liveness signaling."""
from __future__ import annotations
import asyncio
import contextlib
import json
import logging
import os
import socket
import threading
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from typing_extensions import Self
from kstlib.limits import (
HARD_MAX_HEARTBEAT_INTERVAL,
HARD_MIN_HEARTBEAT_INTERVAL,
clamp_with_limits,
get_resilience_limits,
)
from kstlib.resilience.exceptions import HeartbeatError
log = logging.getLogger(__name__)
if TYPE_CHECKING:
import types
# Type aliases for callbacks
OnAlertCallback = Callable[[str, str, Mapping[str, Any]], Awaitable[None] | None]
@runtime_checkable
class HeartbeatTarget(Protocol):
"""Protocol for objects that can be monitored by Heartbeat.
Any object implementing `is_dead` property can be used as a target.
This allows Heartbeat to detect when a monitored component has failed.
Examples:
>>> class MyWebSocket: # doctest: +SKIP
... @property
... def is_dead(self) -> bool:
... return not self.connected
"""
@property
def is_dead(self) -> bool:
"""Check if the target is dead and needs restart."""
...
[docs]
@dataclass(frozen=True, slots=True)
class HeartbeatState:
"""Represents the state written to the heartbeat file.
Attributes:
timestamp: Last heartbeat time (ISO 8601 UTC).
pid: Process ID.
hostname: Machine hostname.
metadata: Optional application-specific data.
Examples:
>>> state = HeartbeatState(
... timestamp="2026-01-12T10:00:00+00:00",
... pid=1234,
... hostname="myhost",
... )
>>> state.pid
1234
"""
timestamp: str
pid: int
hostname: str
metadata: dict[str, Any] = field(default_factory=dict)
[docs]
def to_dict(self) -> dict[str, Any]:
"""Serialize to JSON-compatible dictionary.
Returns:
Dictionary representation of the heartbeat state.
"""
return {
"timestamp": self.timestamp,
"pid": self.pid,
"hostname": self.hostname,
"metadata": self.metadata,
}
[docs]
@classmethod
def from_dict(cls, data: dict[str, Any]) -> HeartbeatState:
"""Deserialize from dictionary.
Args:
data: Dictionary with heartbeat state fields.
Returns:
HeartbeatState instance.
Raises:
KeyError: If required fields are missing.
"""
return cls(
timestamp=data["timestamp"],
pid=data["pid"],
hostname=data["hostname"],
metadata=data.get("metadata", {}),
)
[docs]
class Heartbeat:
"""Periodic signal to indicate the process is alive.
Writes timestamp to a JSON state file at configurable intervals.
Supports both sync and async context managers.
Args:
state_file: Path to the heartbeat state file. If None, no file is written
(useful when using on_beat callback for state management).
interval: Seconds between heartbeats (default from config or 10s).
on_missed_beat: Callback invoked when a beat write fails.
on_alert: Callback for alerting (channel, message, context).
target: Optional object with `is_dead` property to monitor.
on_target_dead: Callback invoked when target is detected as dead.
on_beat: Callback invoked after each successful beat. Can be sync or async.
Use this to delegate state writing to an external component.
metadata: Optional dict included in each heartbeat.
Examples:
Sync context manager:
>>> with Heartbeat("/tmp/bot.heartbeat") as hb: # doctest: +SKIP
... do_work()
Async context manager:
>>> async with Heartbeat("/tmp/bot.heartbeat") as hb: # doctest: +SKIP
... await do_async_work()
Check if a process is alive:
>>> Heartbeat.is_alive("/tmp/bot.heartbeat", max_age_seconds=30) # doctest: +SKIP
True
Monitor a WebSocket:
>>> hb = Heartbeat( # doctest: +SKIP
... "/tmp/bot.heartbeat",
... target=ws_manager,
... on_target_dead=lambda: restart_ws(),
... )
"""
[docs]
def __init__(
self,
state_file: str | Path | None = None,
*,
interval: float | None = None,
on_missed_beat: Callable[[Exception], None] | None = None,
on_alert: OnAlertCallback | None = None,
target: HeartbeatTarget | None = None,
on_target_dead: Callable[[], Awaitable[None] | None] | None = None,
on_beat: Callable[[], Awaitable[None] | None] | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Initialize heartbeat.
Args:
state_file: Path to the heartbeat state file. If None, no file is written.
interval: Seconds between heartbeats. Uses config default if None.
on_missed_beat: Callback invoked when a beat write fails.
on_alert: Callback for alerting (channel, message, context).
target: Optional object with `is_dead` property to monitor.
on_target_dead: Callback invoked when target is detected as dead.
on_beat: Callback invoked after each successful beat.
metadata: Optional dict included in each heartbeat.
"""
self._state_file = Path(state_file) if state_file else None
self._on_missed_beat = on_missed_beat
self._on_alert = on_alert
self._target = target
self._on_target_dead = on_target_dead
self._on_beat = on_beat
self._metadata = metadata or {}
# Load interval from config if not provided, or clamp user value
limits = get_resilience_limits()
self._interval = (
limits.heartbeat_interval
if interval is None
else clamp_with_limits(interval, HARD_MIN_HEARTBEAT_INTERVAL, HARD_MAX_HEARTBEAT_INTERVAL)
)
# Threading state
self._running = False
self._thread: threading.Thread | None = None
self._lock = threading.Lock()
self._stop_event = threading.Event()
self._shutdown_requested = False
# Async state
self._async_task: asyncio.Task[None] | None = None
@property
def interval(self) -> float:
"""Return the heartbeat interval in seconds."""
return self._interval
@property
def state_file(self) -> Path | None:
"""Return the path to the state file, or None if not configured."""
return self._state_file
@property
def is_shutdown(self) -> bool:
"""Check if shutdown has been requested."""
return self._shutdown_requested
@property
def target(self) -> HeartbeatTarget | None:
"""Return the monitored target, if any."""
return self._target
[docs]
def shutdown(self) -> None:
"""Signal shutdown and stop gracefully.
Sets the shutdown flag which can be checked by external code
to know that we're shutting down intentionally.
"""
log.info("Heartbeat shutdown requested")
self._shutdown_requested = True
self.stop()
[docs]
async def ashutdown(self) -> None:
"""Signal shutdown and stop gracefully (async version)."""
log.info("Heartbeat shutdown requested")
self._shutdown_requested = True
await self.astop()
[docs]
def start(self) -> None:
"""Start the heartbeat background thread.
Raises:
HeartbeatError: If heartbeat is already running.
"""
with self._lock:
if self._running:
raise HeartbeatError("Heartbeat is already running")
self._running = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
[docs]
def stop(self) -> None:
"""Stop the heartbeat and clean up.
Safe to call multiple times or if not started.
"""
with self._lock:
if not self._running:
return
self._running = False
self._stop_event.set()
if self._thread is not None:
self._thread.join(timeout=self._interval + 1.0)
self._thread = None
[docs]
def beat(self) -> None:
"""Write a heartbeat immediately (manual trigger).
If state_file is configured, writes to file.
If on_beat callback is configured, it will be invoked by the loop (not here).
Raises:
HeartbeatError: If state file is configured and cannot be written.
"""
# Skip file write if no state_file configured
if self._state_file is None:
return
state = HeartbeatState(
timestamp=datetime.now(timezone.utc).isoformat(),
pid=os.getpid(),
hostname=socket.gethostname(),
metadata=self._metadata,
)
try:
# Ensure parent directory exists with proper permissions
self._state_file.parent.mkdir(parents=True, exist_ok=True, mode=0o755)
# Write atomically using temp file
temp_file = self._state_file.with_suffix(".tmp")
temp_file.write_text(json.dumps(state.to_dict(), indent=2))
temp_file.replace(self._state_file)
except OSError as exc:
raise HeartbeatError(f"Failed to write heartbeat: {exc}") from exc
def _run_loop(self) -> None:
"""Background thread loop that writes heartbeats and checks target."""
while not self._stop_event.wait(timeout=self._interval):
if self._shutdown_requested:
break
try:
self.beat()
# Invoke on_beat callback after successful beat
if self._on_beat is not None:
with contextlib.suppress(Exception):
result = self._on_beat()
# Note: Cannot await in sync thread, result is ignored if coroutine
if asyncio.iscoroutine(result):
result.close()
except Exception as exc: # pylint: disable=broad-exception-caught
if self._on_missed_beat is not None:
with contextlib.suppress(Exception):
self._on_missed_beat(exc)
# Check target if provided (sync version cannot use async callbacks)
if self._target is not None and self._target.is_dead and self._on_target_dead is not None:
with contextlib.suppress(Exception):
result = self._on_target_dead()
# Note: Cannot await in sync thread, result is ignored if coroutine
if asyncio.iscoroutine(result):
# Close the coroutine to avoid warning
result.close()
[docs]
async def astart(self) -> None:
"""Start the heartbeat using asyncio (async version).
Raises:
HeartbeatError: If heartbeat is already running.
"""
with self._lock:
if self._running:
raise HeartbeatError("Heartbeat is already running")
self._running = True
self._async_task = asyncio.create_task(self._async_loop())
[docs]
async def astop(self) -> None:
"""Stop the heartbeat (async version).
Safe to call multiple times or if not started.
"""
with self._lock:
if not self._running:
return
self._running = False
if self._async_task is not None:
self._async_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._async_task
self._async_task = None
async def _invoke_callback_async(
self,
callback: Callable[[], Awaitable[None] | None] | None,
) -> None:
"""Invoke a callback that may be sync or async."""
if callback is not None:
try:
result = callback()
if asyncio.iscoroutine(result):
await result
except Exception as exc:
log.warning("Callback failed: %s", exc)
async def _check_target_async(self) -> None:
"""Check target health and invoke callbacks if dead."""
if self._target is None or not self._target.is_dead:
return
# Send alert if callback provided
if self._on_alert is not None:
with contextlib.suppress(Exception):
alert_result = self._on_alert(
"heartbeat",
"Target is dead, triggering recovery",
{"target": str(type(self._target).__name__)},
)
if asyncio.iscoroutine(alert_result):
await alert_result
# Invoke on_target_dead callback
await self._invoke_callback_async(self._on_target_dead)
async def _async_loop(self) -> None:
"""Async loop that writes heartbeats and monitors target."""
log.debug("Heartbeat async loop started (interval=%.1fs)", self._interval)
while self._running and not self._shutdown_requested:
try:
# Run beat in executor to avoid blocking
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self.beat)
# Invoke on_beat callback after successful beat
if self._on_beat is not None:
log.debug("Invoking on_beat callback")
await self._invoke_callback_async(self._on_beat)
except Exception as exc: # pylint: disable=broad-exception-caught
log.warning("Heartbeat beat failed: %s", exc)
if self._on_missed_beat is not None:
with contextlib.suppress(Exception):
self._on_missed_beat(exc)
await self._check_target_async()
await asyncio.sleep(self._interval)
[docs]
@staticmethod
def read_state(state_file: str | Path) -> HeartbeatState | None:
"""Read and parse an existing heartbeat state file.
Args:
state_file: Path to heartbeat file.
Returns:
HeartbeatState if file exists and is valid, None otherwise.
Examples:
>>> state = Heartbeat.read_state("/tmp/bot.heartbeat") # doctest: +SKIP
>>> if state: # doctest: +SKIP
... print(f"Last beat: {state.timestamp}")
"""
path = Path(state_file)
if not path.exists():
return None
try:
data = json.loads(path.read_text())
return HeartbeatState.from_dict(data)
except (json.JSONDecodeError, KeyError, OSError):
return None
[docs]
@staticmethod
def is_alive(state_file: str | Path, max_age_seconds: float = 30.0) -> bool:
"""Check if a process is alive based on its heartbeat.
Args:
state_file: Path to heartbeat file.
max_age_seconds: Maximum age before considering process dead.
Returns:
True if heartbeat exists and is recent enough.
Examples:
>>> Heartbeat.is_alive("/tmp/bot.heartbeat", max_age_seconds=30) # doctest: +SKIP
True
"""
state = Heartbeat.read_state(state_file)
if state is None:
return False
try:
beat_time = datetime.fromisoformat(state.timestamp)
age = (datetime.now(timezone.utc) - beat_time).total_seconds()
return age <= max_age_seconds
except (ValueError, TypeError):
return False
[docs]
def __enter__(self) -> Self:
"""Enter sync context manager."""
self.start()
return self
[docs]
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
"""Exit sync context manager."""
self.stop()
[docs]
async def __aenter__(self) -> Self:
"""Enter async context manager."""
await self.astart()
return self
[docs]
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
"""Exit async context manager."""
await self.astop()
__all__ = ["Heartbeat", "HeartbeatState", "HeartbeatTarget", "OnAlertCallback"]