Source code for roost._core.cron
"""Cluster-wide cron scheduler.
A single advisory lock guarantees only one scheduler is active per database
at a time — workers can be horizontally scaled without double-firing crons.
"""
from __future__ import annotations
import asyncio
import contextlib
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
import structlog
from croniter import croniter
try: # 3.9+ stdlib; we require 3.10+ so always present
from zoneinfo import ZoneInfo
except ImportError: # pragma: no cover
ZoneInfo = None # type: ignore[assignment,misc]
from roost._core import repo
if TYPE_CHECKING: # pragma: no cover
import asyncpg
_log = structlog.get_logger(__name__)
# A stable arbitrary 64-bit constant for the advisory lock so multiple
# Roost-using applications on the same database don't collide.
ADVISORY_LOCK_KEY = 0x52_4F_4F_53_54_43_52_4E # 'ROOSTCRN'
[docs]
@dataclass(frozen=True)
class CronEntry:
name: str
expression: str
task: str
args: dict[str, Any] = field(default_factory=dict)
queue: str = "default"
priority: int = 0
max_attempts: int = 20
timezone_name: str | None = None # IANA name; e.g. "America/Los_Angeles". None == UTC.
def _tz(self) -> Any:
if self.timezone_name is None or ZoneInfo is None:
return timezone.utc
return ZoneInfo(self.timezone_name)
def _localise(self, when: datetime) -> datetime:
"""Render ``when`` in the entry's local timezone for cron evaluation."""
if when.tzinfo is None:
when = when.replace(tzinfo=timezone.utc)
return when.astimezone(self._tz())
[docs]
def next_after(self, now: datetime) -> datetime:
itr = croniter(self.expression, self._localise(now))
nxt = itr.get_next(datetime)
if nxt.tzinfo is None:
nxt = nxt.replace(tzinfo=self._tz())
return nxt.astimezone(timezone.utc)
[docs]
def previous_or_at(self, now: datetime) -> datetime:
itr = croniter(self.expression, self._localise(now))
prev = itr.get_prev(datetime)
if prev.tzinfo is None:
prev = prev.replace(tzinfo=self._tz())
return prev.astimezone(timezone.utc)
class CronRegistry:
def __init__(self) -> None:
self._entries: dict[str, CronEntry] = {}
def register(self, entry: CronEntry) -> None:
if entry.name in self._entries:
existing = self._entries[entry.name]
if existing == entry:
return
raise ValueError(f"cron name '{entry.name}' is already registered")
self._entries[entry.name] = entry
def all(self) -> list[CronEntry]:
return list(self._entries.values())
# Module-global default registry. The decorators import this directly.
DEFAULT_REGISTRY = CronRegistry()
async def run_scheduler(
pool: asyncpg.Pool,
registry: CronRegistry | None = None,
*,
interval_seconds: float = 60.0,
stop_event: asyncio.Event | None = None,
dsn: str | None = None,
) -> None:
"""Long-running coroutine. Wakes every ``interval_seconds`` and enqueues
any cron entries whose previous run is overdue.
Holds the advisory lock on a dedicated connection (so we don't tie up a
pool slot for the lifetime of the worker), and acquires fresh pool
connections per tick for the actual work. If ``dsn`` is omitted the lock
connection is borrowed from the pool — convenient but it costs a slot.
"""
import asyncpg
reg = registry or DEFAULT_REGISTRY
stop_event = stop_event or asyncio.Event()
lock_conn: asyncpg.Connection | None = None
borrowed = False
if dsn is not None:
lock_conn = await asyncpg.connect(dsn)
await repo.init_connection(lock_conn)
else:
# Fallback: borrow from the pool. Caller pays a slot.
lock_conn = await pool.acquire()
borrowed = True
try:
if lock_conn is None:
return
if not await repo.cron_try_lock_async(lock_conn, ADVISORY_LOCK_KEY):
_log.info("cron.skip.lock_held")
return
_log.info("cron.lock_acquired")
try:
while not stop_event.is_set():
try:
async with pool.acquire() as work_conn:
await _run_once(work_conn, reg)
except Exception as exc: # pragma: no cover — defensive
_log.warning("cron.tick_failed", error=str(exc))
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
finally:
with contextlib.suppress(Exception):
await repo.cron_unlock_async(lock_conn, ADVISORY_LOCK_KEY)
_log.info("cron.lock_released")
finally:
if lock_conn is not None:
if borrowed:
with contextlib.suppress(Exception):
await pool.release(lock_conn)
else:
with contextlib.suppress(Exception):
await lock_conn.close()
async def _run_once(conn: asyncpg.Connection, registry: CronRegistry) -> None:
now = datetime.now(tz=timezone.utc)
for entry in registry.all():
due_at = entry.previous_or_at(now)
# croniter.get_prev returns naive sometimes; coerce to UTC.
if due_at.tzinfo is None:
due_at = due_at.replace(tzinfo=timezone.utc)
if due_at > now:
continue
try:
should_enqueue = await repo.cron_should_run_async(conn, entry.name, due_at)
except Exception as exc: # pragma: no cover — defensive
_log.warning("cron.claim_failed", name=entry.name, error=str(exc))
continue
if not should_enqueue:
continue
try:
await repo.enqueue_async(
conn,
task=entry.task,
args=entry.args,
queue=entry.queue,
priority=entry.priority,
max_attempts=entry.max_attempts,
scheduled_at=now,
unique_key=f"cron:{entry.name}:{int(due_at.timestamp())}",
)
_log.info("cron.enqueued", name=entry.name, due_at=due_at.isoformat())
except Exception as exc: # pragma: no cover — defensive
_log.warning("cron.enqueue_failed", name=entry.name, error=str(exc))