"""Worker loop: fetch ready jobs, dispatch handlers, retry on failure.
This module is the most critical part of the runtime. The design goals:
* No double-processing. ``FOR UPDATE SKIP LOCKED`` makes the row claim
atomic; concurrency control lives in SQL, not the application.
* No silent loss. Jobs only leave ``executing`` via ``completed``,
``retryable``, ``discarded``, ``cancelled``, or via the orphan reaper.
* Crash-tolerant. SIGKILL'd workers leave their rows in ``executing``;
the reaper running in any peer worker drains them after a configurable
staleness window.
* Resilient to transient Postgres failures. Network blips and listener
drops are caught, logged, and retried with backoff — they never kill
the worker process.
"""
from __future__ import annotations
import asyncio
import contextlib
import inspect
import os
import signal
import socket
import traceback
import uuid
from collections.abc import Callable, Iterable
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any
import structlog
from roost import observability
from roost._core import repo
from roost._core.cron import run_scheduler
from roost._core.notify import CHANNEL_CANCEL_REQUESTED, CHANNEL_INSERTED
from roost._core.retry import BackoffStrategy, resolve
from roost.decorators import DEFAULT_HANDLERS, HandlerRegistry, HandlerSpec
from roost.exceptions import SnoozeJob, UnknownTaskError
from roost.hooks import Hooks, call_after, call_before
if TYPE_CHECKING: # pragma: no cover
import asyncpg
# Module-level structlog logger; events are emitted as an event name plus key/value fields.
_log = structlog.get_logger(__name__)
# Worker implementation.
class Worker:
    """A single-process Roost worker.

    Multiple workers can run against the same database — concurrency control
    is enforced by ``FOR UPDATE SKIP LOCKED`` at the SQL level.

    :meth:`run` is the long-lived entry point: it opens a connection pool,
    starts the background loops (insert listener, cancel listener, heartbeat,
    orphan reaper, optional archive loop, optional cron scheduler) and then
    fetches and dispatches jobs until :meth:`request_stop` is called.
    :meth:`run_once` drains the currently available jobs and returns.

    Two events coordinate the loops:

    * ``_stop`` — set once on shutdown; every loop checks it.
    * ``_wakeup`` — set by the insert listener (and by ``request_stop``) to cut
      the *main loop's* idle sleep short. Only the main loop consumes it, so a
      burst of inserts does not trigger extra heartbeat/reaper/archive passes
      and no other loop can clear a wakeup meant for the main loop.
    """

    def __init__(
        self,
        dsn: str,
        *,
        queues: Iterable[str] = ("default",),
        concurrency: int = 4,
        prefetch: int | None = None,
        poll_interval: float = 1.0,
        retry_strategy: BackoffStrategy | None = None,
        registry: HandlerRegistry | None = None,
        run_cron: bool = True,
        heartbeat_interval: float = 15.0,
        orphan_reaper_interval: float = 30.0,
        orphan_stale_after: float = 5 * 60.0,
        shutdown_timeout: float = 30.0,
        listen_reconnect_delay: float = 1.0,
        error_cap: int = 20,
        archive_after_seconds: float | None = None,
        archive_interval: float = 60.0,
        result_ttl_seconds: float | None = None,
        startup_max_retries: int = 30,
        startup_retry_delay: float = 1.0,
        hooks: Hooks | None = None,
    ) -> None:
        """Configure the worker; no I/O happens until ``run``/``run_once``.

        Args:
            dsn: Postgres DSN used for the pool and the dedicated LISTEN
                connections.
            queues: Queue names to fetch from. Must not be empty.
            concurrency: Maximum number of job tasks in flight at once. The
                pool is sized ``concurrency + 4``.
            prefetch: Upper bound on jobs claimed per fetch; defaults to
                ``concurrency``.
            poll_interval: Seconds the main loop idles after an empty fetch
                unless woken by an insert notification.
            retry_strategy: Passed through ``resolve``; the resolved strategy
                is called with the attempt number to get the retry delay in
                seconds.
            registry: Handler registry; ``DEFAULT_HANDLERS`` when ``None``.
            run_cron: Whether ``run`` also starts the cron scheduler.
            heartbeat_interval: Seconds between heartbeats. Also feeds the
                worker-GC staleness window (``max(4 * interval, 60)``).
            orphan_reaper_interval: Seconds between reaper passes.
            orphan_stale_after: Staleness window (seconds) handed to the
                orphan reaper.
            shutdown_timeout: Seconds to wait for in-flight jobs on shutdown
                before cancelling them.
            listen_reconnect_delay: Seconds between liveness checks of a
                LISTEN connection; the reconnect backoff is five times this,
                capped at 30 seconds.
            error_cap: Forwarded to the repo as ``error_cap`` when recording
                failures (clamped to at least 1).
            archive_after_seconds: When set, terminal jobs older than this are
                archived periodically.
            archive_interval: Seconds between archive passes.
            result_ttl_seconds: When set, results older than this are cleared
                periodically.
            startup_max_retries: Attempts at opening the pool before giving up
                (clamped to at least 1).
            startup_retry_delay: Initial delay (seconds) between pool attempts.
            hooks: Optional before/after job hooks.

        Raises:
            ValueError: If ``queues`` is empty or ``concurrency < 1``.
        """
        self.dsn = dsn
        self.queues = list(queues)
        if not self.queues:
            raise ValueError("queues must not be empty")
        if concurrency < 1:
            raise ValueError("concurrency must be >= 1")
        self.concurrency = concurrency
        self.prefetch = prefetch if prefetch is not None else concurrency
        self.poll_interval = poll_interval
        self.retry_strategy = resolve(retry_strategy)
        # Explicit ``is not None``: a caller-supplied registry that happens to
        # be falsy (e.g. still empty) must not be swapped for the default one.
        self.registry = registry if registry is not None else DEFAULT_HANDLERS
        self.run_cron = run_cron
        self.heartbeat_interval = heartbeat_interval
        self.orphan_reaper_interval = orphan_reaper_interval
        self.orphan_stale_after = orphan_stale_after
        self.shutdown_timeout = shutdown_timeout
        self.listen_reconnect_delay = listen_reconnect_delay
        self.error_cap = max(1, int(error_cap))
        self.archive_after_seconds = archive_after_seconds
        self.archive_interval = archive_interval
        self.result_ttl_seconds = result_ttl_seconds
        self.startup_max_retries = max(1, int(startup_max_retries))
        self.startup_retry_delay = startup_retry_delay
        self.hooks = hooks
        # hostname-pid-random suffix.
        self.id = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
        self._stop = asyncio.Event()
        self._wakeup = asyncio.Event()
        self._inflight: set[asyncio.Task[None]] = set()
        # Map job_id -> running asyncio Task so cancel-requests can find it.
        self._running: dict[int, asyncio.Task[None]] = {}

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    async def run(self) -> None:
        """Run the worker until :meth:`request_stop` is called.

        Opens the pool (retrying while Postgres is unavailable), starts the
        background loops, then runs the main fetch/dispatch loop. On exit —
        normal or exceptional — it sets the stop event, cancels the background
        tasks, drains in-flight jobs, deregisters the worker and closes the
        pool.
        """
        pool = await self._open_pool_with_retry()
        background: list[asyncio.Task[None]] = []
        try:
            background.append(asyncio.create_task(self._listen_loop(), name="roost-listener"))
            background.append(asyncio.create_task(self._cancel_listen_loop(), name="roost-cancel"))
            background.append(asyncio.create_task(self._heartbeat_loop(pool), name="roost-heartbeat"))
            background.append(asyncio.create_task(self._reaper_loop(pool), name="roost-reaper"))
            if self.archive_after_seconds is not None or self.result_ttl_seconds is not None:
                background.append(asyncio.create_task(self._archive_loop(pool), name="roost-archive"))
            if self.run_cron:
                background.append(
                    asyncio.create_task(
                        run_scheduler(pool, stop_event=self._stop, dsn=self.dsn),
                        name="roost-cron",
                    )
                )
            await self._main_loop(pool)
        finally:
            self._stop.set()
            for task in background:
                task.cancel()
            for task in background:
                # Background failures must not mask the shutdown sequence.
                with contextlib.suppress(asyncio.CancelledError, Exception):
                    await task
            await self._drain_inflight()
            # Best effort: the database may already be unreachable.
            with contextlib.suppress(Exception):
                async with pool.acquire() as conn:
                    await repo.deregister_worker_async(conn, self.id)
            await pool.close()
            _log.info("worker.stopped", id=self.id)

    async def run_once(self) -> int:
        """Drain every currently-available job and exit. Returns count processed.

        Skips the LISTEN connections, the cron scheduler, the archive loop and
        the orphan reaper — this is intended for one-shot invocations (CI
        smokes, cron-style runs of ``roost run --once``, programmatic test
        helpers).

        A single heartbeat is written at start-up (with ``mode: once`` in its
        metadata); the worker still claims rows via ``FOR UPDATE SKIP LOCKED``
        so it's safe to run alongside other workers.
        """
        pool = await self._open_pool_with_retry()
        processed = 0
        try:
            async with pool.acquire() as conn:
                await repo.heartbeat_async(
                    conn,
                    worker_id=self.id,
                    hostname=socket.gethostname(),
                    pid=os.getpid(),
                    queues=self.queues,
                    metadata={"mode": "once", "concurrency": self.concurrency},
                )
            while True:
                await self._promote_retryable(pool)
                picked = await self._fetch_batch(pool)
                if picked == 0:
                    break
                processed += picked
                # Wait for the inflight tasks from this batch before checking again.
                if self._inflight:
                    await asyncio.gather(*self._inflight, return_exceptions=True)
        finally:
            # No-op on the normal path (every batch was awaited above); if the
            # loop was interrupted, don't close the pool under running jobs.
            await self._drain_inflight()
            with contextlib.suppress(Exception):
                async with pool.acquire() as conn:
                    await repo.deregister_worker_async(conn, self.id)
            await pool.close()
            _log.info("worker.run_once.done", processed=processed, id=self.id)
        return processed

    def request_stop(self) -> None:
        """Ask the worker to shut down; safe to call from a signal handler."""
        self._stop.set()
        # Also wake the main loop so it notices the stop flag immediately.
        self._wakeup.set()

    def install_signal_handlers(self, loop: asyncio.AbstractEventLoop) -> None:
        """Route SIGTERM and SIGINT to :meth:`request_stop`.

        Falls back to ``signal.signal`` where the loop does not implement
        ``add_signal_handler``.
        """
        for sig in (signal.SIGTERM, signal.SIGINT):
            try:
                loop.add_signal_handler(sig, self.request_stop)
            except NotImplementedError:  # pragma: no cover — Windows
                signal.signal(sig, lambda *_: self.request_stop())

    # ------------------------------------------------------------------
    # main loop
    # ------------------------------------------------------------------
    async def _main_loop(self, pool: asyncpg.Pool) -> None:
        """Promote retryable jobs and fetch batches until stopped.

        An empty fetch idles for ``poll_interval`` (or until woken). A failed
        cycle is logged and retried with exponential backoff from 1 s up to
        30 s; the backoff resets after the next successful cycle.
        """
        backoff = 0.0
        while not self._stop.is_set():
            try:
                await self._promote_retryable(pool)
                picked = await self._fetch_batch(pool)
                backoff = 0.0
                if picked == 0:
                    await self._sleep_or_wakeup(self.poll_interval)
            except Exception as exc:  # pragma: no cover — defensive
                backoff = min(max(backoff * 2, 1.0), 30.0)
                _log.warning("worker.cycle_failed", error=str(exc), backoff=backoff)
                await self._sleep_or_wakeup(backoff)

    async def _promote_retryable(self, pool: asyncpg.Pool) -> None:
        """Run the repo's retryable-promotion step on a pooled connection."""
        async with pool.acquire() as conn:
            await repo.promote_retryable_async(conn)

    def _task_limits(self) -> dict[str, tuple[int | None, int | None]]:
        """Snapshot per-task throttling from the registered handler defaults.

        Returns a mapping ``task name -> (rate_per_minute, max_concurrency)``
        containing only tasks that set at least one of the two limits.
        """
        out: dict[str, tuple[int | None, int | None]] = {}
        for name in self.registry.names():
            spec = self.registry.get(name)
            if spec is None:
                continue
            rate = spec.defaults.rate_per_minute
            conc = spec.defaults.max_concurrency
            if rate is not None or conc is not None:
                out[spec.name] = (rate, conc)
        return out

    async def _fetch_batch(self, pool: asyncpg.Pool) -> int:
        """Claim up to ``min(prefetch, free slots)`` jobs and start a task for each.

        Returns the number of jobs claimed. When the worker is already at
        ``concurrency`` it pauses briefly and returns 0 without touching the
        database.
        """
        free_slots = self.concurrency - len(self._inflight)
        if free_slots <= 0:
            await asyncio.sleep(0.05)
            return 0
        batch_size = min(self.prefetch, free_slots)
        task_limits = self._task_limits()
        async with pool.acquire() as conn, conn.transaction():
            jobs = await repo.fetch_available_async(
                conn,
                self.queues,
                batch_size,
                task_limits=task_limits,
                worker_id=self.id,
            )
        for job in jobs:
            task = asyncio.create_task(self._dispatch(pool, job), name=f"roost-job-{job.id}")
            self._inflight.add(task)
            self._running[job.id] = task

            # ``jid`` is bound as a default so each callback keeps its own job
            # id rather than the loop variable's final value.
            def _cleanup(t: asyncio.Task[None], jid: int = job.id) -> None:
                self._inflight.discard(t)
                self._running.pop(jid, None)

            task.add_done_callback(_cleanup)
        return len(jobs)

    async def _drain_inflight(self) -> None:
        """Wait up to ``shutdown_timeout`` for in-flight jobs, then cancel the rest.

        ``asyncio.wait`` is used rather than ``wait_for(gather(...))``: on
        timeout ``wait_for`` cancels the gather — and through it every job task
        — before raising, so the explicit cancel step would never see a running
        task and would always log ``cancelled=0``.
        """
        # Snapshot: done-callbacks mutate ``_inflight`` while we wait.
        tasks = set(self._inflight)
        if not tasks:
            return
        _log.info(
            "worker.draining",
            inflight=len(tasks),
            timeout=self.shutdown_timeout,
        )
        _, still_running = await asyncio.wait(tasks, timeout=self.shutdown_timeout)
        if still_running:
            _log.warning(
                "worker.drain_timeout",
                cancelled=len(still_running),
                timeout=self.shutdown_timeout,
            )
            for t in still_running:
                t.cancel()
            # Give cancelled jobs the chance to record their failure/cancel.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await asyncio.gather(*still_running, return_exceptions=True)
        # Mark any task exception as retrieved so the loop does not log
        # "Task exception was never retrieved" at shutdown.
        for t in tasks:
            if t.done() and not t.cancelled():
                t.exception()

    # ------------------------------------------------------------------
    # dispatch
    # ------------------------------------------------------------------
    async def _dispatch(self, pool: asyncpg.Pool, job: Any) -> None:
        """Run one claimed job through hooks, handler and outcome recording.

        Outcomes:

        * handler returns — the job is marked completed with its result;
        * handler raises ``SnoozeJob`` — the job is snoozed for
          ``snooze.seconds``;
        * anything else, including task cancellation and an unknown task name —
          delegated to :meth:`_handle_failure`.

        The after-hook always runs, receiving the result or the error.
        """
        spec = self.registry.get(job.task)
        labels = {"queue": job.queue, "task": job.task}
        started = asyncio.get_running_loop().time()
        ctx: dict[str, Any] = {}
        result: Any = None
        error: BaseException | None = None
        try:
            await self._safely_call_before(job, ctx)
            if spec is None:
                raise UnknownTaskError(f"no handler registered for task '{job.task}'")
            result = await self._invoke(spec, job)
            async with pool.acquire() as conn:
                await repo.mark_completed_async(conn, job.id, result=result)
            duration = asyncio.get_running_loop().time() - started
            observability.JOBS_COMPLETED.labels(**labels).inc()
            observability.JOB_DURATION.labels(**labels).observe(duration)
            _log.info(
                "job.completed",
                id=job.id,
                task=job.task,
                attempt=job.attempt,
                duration=round(duration, 4),
            )
        except SnoozeJob as snooze:
            error = snooze
            when = datetime.now(tz=timezone.utc) + timedelta(seconds=snooze.seconds)
            async with pool.acquire() as conn:
                await repo.snooze_async(conn, job.id, when)
            _log.info("job.snoozed", id=job.id, task=job.task, seconds=snooze.seconds)
        except BaseException as exc:  # noqa: BLE001 — surfaced into errors[]
            # BaseException on purpose: asyncio.CancelledError must reach
            # _handle_failure so a cancel request can be finalized.
            error = exc
            await self._handle_failure(pool, job, exc)
        finally:
            await self._safely_call_after(job, result=result, error=error, ctx=ctx)

    async def _safely_call_before(self, job: Any, ctx: dict[str, Any]) -> None:
        """Run the before-hook, if any; a failing hook is logged, not raised."""
        if self.hooks is None:
            return
        try:
            await call_before(self.hooks, job, ctx)
        except Exception as exc:
            _log.warning("hooks.before_failed", id=job.id, error=str(exc))

    async def _safely_call_after(
        self, job: Any, *, result: Any, error: BaseException | None, ctx: dict[str, Any]
    ) -> None:
        """Run the after-hook, if any; a failing hook is logged, not raised."""
        if self.hooks is None:
            return
        try:
            await call_after(self.hooks, job, result=result, error=error, ctx=ctx)
        except Exception as exc:
            _log.warning("hooks.after_failed", id=job.id, error=str(exc))

    @staticmethod
    async def _invoke(spec: HandlerSpec, job: Any) -> Any:
        """Call the handler for ``job`` inside a tracing span and return its result.

        Async handlers are awaited directly; sync handlers run in a worker
        thread via ``asyncio.to_thread``. A positive ``job.timeout_seconds``
        bounds the call with ``asyncio.wait_for``. Note that a timed-out sync
        handler keeps running in its thread — threads cannot be cancelled —
        even though the job is recorded as failed.
        """
        # Strip the trace carrier off args so it's not forwarded to the handler.
        args = dict(job.args or {})
        # NOTE(review): the first element returned by strip_trace_context is
        # discarded, so this relies on the helper removing the carrier from
        # ``args`` in place — confirm against observability.strip_trace_context.
        _, carrier = observability.strip_trace_context(args)
        timeout = job.timeout_seconds

        async def _run() -> Any:
            with observability.job_span(
                f"job:{job.task}",
                {
                    "roost.job.id": job.id,
                    "roost.job.queue": job.queue,
                    "roost.job.task": job.task,
                    "roost.job.attempt": job.attempt,
                },
                carrier,
            ):
                if spec.is_async:
                    return await spec.func(**args)
                return await asyncio.to_thread(_call_sync_handler, spec.func, args)

        if timeout and timeout > 0:
            return await asyncio.wait_for(_run(), timeout=float(timeout))
        return await _run()

    async def _handle_failure(self, pool: asyncpg.Pool, job: Any, exc: BaseException) -> None:
        """Record a failed attempt: finalize a cancel, discard, or schedule a retry.

        * Cancellation with ``cancel_requested`` set on the row — the cancel is
          finalized. A cancellation without that flag (e.g. a shutdown drain
          timeout) is treated like any other failed attempt.
        * ``job.attempt >= job.max_attempts`` — the job is discarded.
        * Otherwise — a retry is scheduled after the strategy's delay.

        Errors raised while recording are logged and swallowed.
        """
        next_attempt = job.attempt
        error_payload = {
            "attempt": next_attempt,
            "at": datetime.now(tz=timezone.utc).isoformat(),
            "error": f"{type(exc).__name__}: {exc}",
            "trace": "".join(traceback.format_exception(exc)).strip(),
        }
        try:
            async with pool.acquire() as conn:
                if isinstance(exc, asyncio.CancelledError):
                    # Re-fetch cancel_requested in case the row updated after dispatch.
                    requested = await conn.fetchval(
                        "SELECT cancel_requested FROM roost.jobs WHERE id = $1", job.id
                    )
                    if requested:
                        await repo.finalize_cancel_async(conn, job.id)
                        observability.JOBS_FAILED.labels(
                            queue=job.queue, task=job.task, outcome="cancelled"
                        ).inc()
                        _log.info("job.cancelled", id=job.id, task=job.task)
                        return
                if next_attempt >= job.max_attempts:
                    await repo.mark_discarded_async(conn, job.id, error_payload, error_cap=self.error_cap)
                    observability.JOBS_FAILED.labels(
                        queue=job.queue, task=job.task, outcome="discarded"
                    ).inc()
                    _log.warning(
                        "job.discarded",
                        id=job.id,
                        task=job.task,
                        attempt=next_attempt,
                        error=error_payload["error"],
                    )
                else:
                    delay = float(self.retry_strategy(next_attempt))
                    when = datetime.now(tz=timezone.utc) + timedelta(seconds=delay)
                    await repo.mark_retryable_async(
                        conn, job.id, when, error_payload, error_cap=self.error_cap
                    )
                    observability.JOBS_FAILED.labels(
                        queue=job.queue, task=job.task, outcome="retryable"
                    ).inc()
                    _log.info(
                        "job.retry_scheduled",
                        id=job.id,
                        task=job.task,
                        attempt=next_attempt,
                        delay=delay,
                        error=error_payload["error"],
                    )
        except Exception as inner:  # pragma: no cover — defensive
            _log.error(
                "job.failure_record_failed",
                id=job.id,
                task=job.task,
                error=str(inner),
            )

    # ------------------------------------------------------------------
    # background loops
    # ------------------------------------------------------------------
    async def _listen_loop(self) -> None:
        """Maintain a LISTEN connection. Reconnects on drop.

        A notification whose payload is one of this worker's queue names wakes
        the main loop.
        """
        import asyncpg

        def _handler(_conn: object, _pid: int, _channel: str, payload: str) -> None:
            if payload in self.queues:
                self._wakeup.set()

        while not self._stop.is_set():
            conn: asyncpg.Connection | None = None
            try:
                conn = await asyncpg.connect(self.dsn)
                await repo.init_connection(conn)
                await conn.add_listener(CHANNEL_INSERTED, _handler)
                _log.info("listener.connected")
                # Hold the connection until shutdown or it dies.
                while not self._stop.is_set():
                    if conn.is_closed():
                        raise ConnectionError("listen connection closed")
                    await asyncio.sleep(self.listen_reconnect_delay)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                _log.warning("listener.error", error=str(exc))
                await self._sleep_until_stop(min(self.listen_reconnect_delay * 5, 30.0))
            finally:
                if conn is not None:
                    with contextlib.suppress(Exception):
                        await conn.remove_listener(CHANNEL_INSERTED, _handler)
                    with contextlib.suppress(Exception):
                        await conn.close()

    async def _cancel_listen_loop(self) -> None:
        """Cancel in-flight handlers when ``cancel_requested`` flips to true.

        The notification payload is the job id; payloads that do not parse as
        an integer, and ids not running on this worker, are ignored.
        """
        import asyncpg

        loop = asyncio.get_running_loop()

        def _handler(_conn: object, _pid: int, _channel: str, payload: str) -> None:
            try:
                jid = int(payload)
            except (TypeError, ValueError):
                return
            task = self._running.get(jid)
            if task is not None and not task.done():
                _log.info("job.cancel_signaled", id=jid)
                loop.call_soon_threadsafe(task.cancel)

        while not self._stop.is_set():
            conn: asyncpg.Connection | None = None
            try:
                conn = await asyncpg.connect(self.dsn)
                await repo.init_connection(conn)
                await conn.add_listener(CHANNEL_CANCEL_REQUESTED, _handler)
                while not self._stop.is_set():
                    if conn.is_closed():
                        raise ConnectionError("cancel-listen connection closed")
                    await asyncio.sleep(self.listen_reconnect_delay)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                _log.warning("cancel_listener.error", error=str(exc))
                await self._sleep_until_stop(min(self.listen_reconnect_delay * 5, 30.0))
            finally:
                if conn is not None:
                    with contextlib.suppress(Exception):
                        await conn.remove_listener(CHANNEL_CANCEL_REQUESTED, _handler)
                    with contextlib.suppress(Exception):
                        await conn.close()

    async def _heartbeat_loop(self, pool: asyncpg.Pool) -> None:
        """Write a heartbeat row every ``heartbeat_interval`` seconds.

        Failures are logged and the loop carries on.
        """
        hostname = socket.gethostname()
        pid = os.getpid()
        while not self._stop.is_set():
            try:
                async with pool.acquire() as conn:
                    await repo.heartbeat_async(
                        conn,
                        worker_id=self.id,
                        hostname=hostname,
                        pid=pid,
                        queues=self.queues,
                        metadata={
                            "concurrency": self.concurrency,
                            "inflight": len(self._inflight),
                        },
                    )
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                _log.warning("worker.heartbeat_failed", error=str(exc))
            await self._sleep_until_stop(self.heartbeat_interval)

    async def _open_pool_with_retry(self) -> asyncpg.Pool:
        """Open the pool, retrying with backoff if Postgres isn't ready yet.

        Critical in containerized deploys where the worker may boot before
        the Postgres container's healthcheck flips green.

        The delay starts at ``startup_retry_delay`` and grows as
        ``delay * 1.5 + 0.5``, capped at 30 seconds.

        Raises:
            Exception: Whatever the final ``asyncpg.create_pool`` attempt
                raised, once ``startup_max_retries`` attempts are exhausted.
        """
        import asyncpg

        delay = max(0.0, self.startup_retry_delay)
        for attempt in range(1, self.startup_max_retries + 1):
            try:
                return await asyncpg.create_pool(
                    self.dsn,
                    min_size=1,
                    max_size=self.concurrency + 4,
                    init=repo.init_connection,
                )
            except Exception as exc:
                _log.warning(
                    "worker.startup_pool_failed",
                    attempt=attempt,
                    max_retries=self.startup_max_retries,
                    error=str(exc),
                )
                if attempt >= self.startup_max_retries:
                    # Out of attempts: surface the last error unchanged.
                    raise
            await asyncio.sleep(min(delay, 30.0))
            delay = min(delay * 1.5 + 0.5, 30.0)
        # Only reachable if ``startup_max_retries`` was set below 1 after
        # construction (``__init__`` clamps it to >= 1).
        raise RuntimeError("startup_max_retries must be >= 1")

    async def _archive_loop(self, pool: asyncpg.Pool) -> None:
        """Periodic terminal-job archive plus optional result-TTL clear."""
        if self.archive_after_seconds is None and self.result_ttl_seconds is None:
            return
        while not self._stop.is_set():
            try:
                async with pool.acquire() as conn:
                    if self.archive_after_seconds is not None:
                        moved = await repo.archive_terminal_async(
                            conn, older_than_seconds=self.archive_after_seconds
                        )
                        if moved:
                            _log.info("worker.archived", count=moved)
                    if self.result_ttl_seconds is not None:
                        cleared = await repo.clear_old_results_async(
                            conn, older_than_seconds=self.result_ttl_seconds
                        )
                        if cleared:
                            _log.info("worker.results_cleared", count=cleared)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                _log.warning("worker.archive_failed", error=str(exc))
            await self._sleep_until_stop(self.archive_interval)

    async def _reaper_loop(self, pool: asyncpg.Pool) -> None:
        """Periodically reap orphaned jobs, GC stale workers, cancel blocked dependents.

        Runs every ``orphan_reaper_interval`` seconds; failures are logged and
        the loop carries on.
        """
        while not self._stop.is_set():
            try:
                async with pool.acquire() as conn:
                    reaped = await repo.reap_orphans_async(conn, stale_after_seconds=self.orphan_stale_after)
                    gced = await repo.gc_workers_async(
                        conn, stale_after_seconds=max(self.heartbeat_interval * 4, 60.0)
                    )
                    blocked = await repo.cancel_blocked_dependents_async(conn)
                if reaped:
                    _log.warning(
                        "worker.reaped_orphans",
                        count=len(reaped),
                        ids=[i for i, _ in reaped],
                    )
                if gced:
                    _log.info("worker.gc_workers", count=gced)
                if blocked:
                    _log.info("worker.blocked_dependents_cancelled", count=len(blocked), ids=blocked)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                _log.warning("worker.reaper_failed", error=str(exc))
            await self._sleep_until_stop(self.orphan_reaper_interval)

    async def _sleep_or_wakeup(self, seconds: float) -> None:
        """Sleep up to ``seconds``, returning early if the wakeup event is set.

        The event is cleared on the way out, so this must only be used by the
        single consumer of wakeups (the main loop); periodic loops use
        :meth:`_sleep_until_stop` instead. Returns immediately when
        ``seconds <= 0``.
        """
        if seconds <= 0:
            return
        try:
            await asyncio.wait_for(self._wakeup.wait(), timeout=seconds)
        except asyncio.TimeoutError:
            pass
        finally:
            self._wakeup.clear()

    async def _sleep_until_stop(self, seconds: float) -> None:
        """Sleep up to ``seconds``, returning early only when stop is requested.

        Unlike :meth:`_sleep_or_wakeup` this ignores (and never clears) the
        wakeup event, so job-insert notifications neither trigger extra
        periodic passes nor get swallowed before the main loop sees them.
        Returns immediately when ``seconds <= 0``.
        """
        if seconds <= 0:
            return
        with contextlib.suppress(asyncio.TimeoutError):
            await asyncio.wait_for(self._stop.wait(), timeout=seconds)
def _call_sync_handler(func: Callable[..., Any], args: dict[str, Any]) -> Any:
    """Call a synchronous handler with ``args`` as keyword arguments.

    Args:
        func: The handler registered as synchronous.
        args: Keyword arguments for the handler.

    Returns:
        Whatever ``func`` returns.

    Raises:
        TypeError: If ``func`` returns an awaitable, which means an async
            handler was registered as a sync one.
    """
    result = func(**args)
    if inspect.isawaitable(result):
        if inspect.iscoroutine(result):
            # Close the coroutine we are about to drop so it does not emit a
            # "coroutine ... was never awaited" RuntimeWarning when collected.
            result.close()
        # Not every callable has __qualname__ (e.g. functools.partial objects).
        name = getattr(func, "__qualname__", None) or repr(func)
        raise TypeError(
            f"sync handler {name} returned an awaitable — use `async def` for async handlers"
        )
    return result
# Public API of this module: only the Worker class is exported.
__all__ = ["Worker"]