# Source code for roost._core.wait

"""``wait_for_async`` — block until a job reaches a terminal state.

Backed by ``LISTEN roost_updated`` so it reacts as fast as the trigger
fires; falls back to polling at ``poll_interval`` for resilience.
"""

from __future__ import annotations

import asyncio
import contextlib
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar

from roost._core.notify import CHANNEL_UPDATED
from roost._core.repo import init_connection
from roost._core.states import TERMINAL_STATES

if TYPE_CHECKING:  # pragma: no cover
    import asyncpg


class JobTimeoutError(TimeoutError):
    """Signals that ``wait_for_async`` gave up before the job finished.

    Because it subclasses the builtin :class:`TimeoutError`, a generic
    ``except TimeoutError`` handler also catches it.
    """

    # Error code string exposed on the class for callers to match on.
    code: ClassVar[str] = "roost.job-timeout"
class JobFailed(RuntimeError):
    """Raised when the awaited job ended in ``discarded`` or ``cancelled``.

    Attributes:
        job_id: Id of the job that failed.
        state: The terminal state the job ended in.
        errors: The job's recorded error entries (``[]`` when none were given).
    """

    code: ClassVar[str] = "roost.job-failed"

    def __init__(self, job_id: int, state: str, errors: list[dict[str, Any]] | None = None):
        self.job_id = job_id
        self.state = state
        self.errors = errors or []
        last = self._last_error(self.errors)
        super().__init__(f"job {job_id} ended in state {state!r}: {last}")

    @staticmethod
    def _last_error(errors: list[Any]) -> object:
        """Return the message of the most recent error entry, or a placeholder."""
        if not errors:
            return "(no error recorded)"
        last = errors[-1]
        if isinstance(last, dict):
            # Entries are normally ``{"error": ...}`` dicts; a malformed entry
            # must not turn constructing the exception into a KeyError, so fall
            # back to showing the whole entry.
            return last.get("error", last)
        return last

    def __reduce__(self):
        # BaseException pickles as ``cls(*self.args)`` (i.e. just the formatted
        # message), which does not fit this ``__init__``; rebuild from fields.
        return (type(self), (self.job_id, self.state, self.errors))
@dataclass(frozen=True)
class JobOutcome:
    """Immutable record of a job that has reached a terminal state.

    This is the value ``wait_for_async`` hands back to its caller.
    """

    # Job id (the ``id`` column of the job row).
    id: int
    # Name of the terminal state the job ended in.
    state: str
    # Value of the row's ``result`` column, or ``None``.
    result: Any | None
    # Recorded error entries; presumably oldest first — TODO confirm.
    errors: list[dict[str, Any]]
async def wait_for_async(
    dsn: str,
    job_id: int,
    *,
    timeout: float | None = 30.0,
    poll_interval: float = 1.0,
    raise_on_failure: bool = True,
) -> JobOutcome:
    """Wait until ``job_id`` reaches a terminal state and return its row.

    Two dedicated connections are opened for the duration of the call: one
    listens on ``CHANNEL_UPDATED`` so a notification carrying ``job_id`` wakes
    the waiter immediately, the other re-reads the job row. The row is also
    re-read at least every ``poll_interval`` seconds, so a missed notification
    only delays the result.

    Pass ``raise_on_failure=False`` to receive the :class:`JobOutcome` even
    when the job ended in ``discarded`` / ``cancelled`` (default: raise
    :class:`JobFailed`). Pass ``timeout=None`` to wait indefinitely.

    Args:
        dsn: Connection string passed to ``asyncpg.connect``.
        job_id: Id of the ``roost.jobs`` row to watch.
        timeout: Seconds to wait, measured from the start of the call
            (connection setup counts toward it); ``None`` waits forever.
        poll_interval: Maximum seconds between two reads of the row.
        raise_on_failure: Whether a ``discarded`` / ``cancelled`` job raises.

    Returns:
        The job's :class:`JobOutcome` once its state is in ``TERMINAL_STATES``.

    Raises:
        JobTimeoutError: ``timeout`` elapsed first. A job id with no row is
            treated as "not terminal yet", so it also ends up here.
        JobFailed: The job failed and ``raise_on_failure`` is true.
    """
    import asyncpg

    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout
    wakeup = asyncio.Event()

    def _on_update(_conn: object, _pid: int, _channel: str, payload: str) -> None:
        # The payload is parsed as a job id; ignore non-integer payloads and
        # notifications about other jobs.
        try:
            updated = int(payload)
        except (TypeError, ValueError):
            return
        if updated == job_id:
            wakeup.set()

    listen_conn: asyncpg.Connection | None = None
    poll_conn: asyncpg.Connection | None = None
    try:
        # Open and initialise connections inside ``try`` so that a failure on
        # the second connect (or in init_connection) still closes the first.
        listen_conn = await asyncpg.connect(dsn)
        poll_conn = await asyncpg.connect(dsn)
        await init_connection(listen_conn)
        await init_connection(poll_conn)
        await listen_conn.add_listener(CHANNEL_UPDATED, _on_update)

        # Initial check — the job may already be terminal. Done after the
        # listener is registered so an update in between is not missed.
        outcome = await _check_terminal(poll_conn, job_id)
        if outcome is not None:
            return _maybe_raise(outcome, raise_on_failure)

        while True:
            if deadline is None:
                wait_for = poll_interval
            else:
                remaining = max(0.0, deadline - loop.time())
                wait_for = min(remaining, poll_interval)
            with contextlib.suppress(asyncio.TimeoutError):
                await asyncio.wait_for(wakeup.wait(), timeout=wait_for)
            # Clear before re-reading: a notification arriving during the
            # query re-arms the event for the next iteration.
            wakeup.clear()
            outcome = await _check_terminal(poll_conn, job_id)
            if outcome is not None:
                return _maybe_raise(outcome, raise_on_failure)
            if deadline is not None and loop.time() >= deadline:
                raise JobTimeoutError(f"job {job_id} did not finish within {timeout}s")
    finally:
        # Best-effort cleanup of whatever was actually opened; errors here
        # must not mask the result or the original exception.
        if listen_conn is not None:
            with contextlib.suppress(Exception):
                await listen_conn.remove_listener(CHANNEL_UPDATED, _on_update)
            with contextlib.suppress(Exception):
                await listen_conn.close()
        if poll_conn is not None:
            with contextlib.suppress(Exception):
                await poll_conn.close()


async def _check_terminal(conn: asyncpg.Connection, job_id: int) -> JobOutcome | None:
    """Return the job's :class:`JobOutcome` if it is terminal, else ``None``.

    ``None`` is also returned when no ``roost.jobs`` row has ``id = job_id``.
    """
    row = await conn.fetchrow("SELECT id, state, result, errors FROM roost.jobs WHERE id = $1", job_id)
    if row is None:
        return None
    state = row["state"]
    if state not in TERMINAL_STATES:
        return None
    return JobOutcome(
        id=int(row["id"]),
        state=str(state),
        result=row["result"],
        # NOTE(review): assumes ``errors`` arrives decoded as a list of dicts
        # (presumably via a codec registered in init_connection) — confirm.
        errors=list(row["errors"] or []),
    )


# Terminal states that ``_maybe_raise`` reports as failures.
_FAILURE_STATES = frozenset({"discarded", "cancelled"})


def _maybe_raise(outcome: JobOutcome, raise_on_failure: bool) -> JobOutcome:
    """Return ``outcome``, raising :class:`JobFailed` for a failed job if requested."""
    if raise_on_failure and outcome.state in _FAILURE_STATES:
        raise JobFailed(outcome.id, outcome.state, outcome.errors)
    return outcome


__all__ = ["JobFailed", "JobOutcome", "JobTimeoutError", "wait_for_async"]