"""``@job`` and ``@cron`` decorators backed by module-level registries."""
from __future__ import annotations
import functools
import inspect
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, TypeVar, cast
from pydantic import BaseModel
from roost._core.cron import DEFAULT_REGISTRY, CronEntry
# Type variable bound to callables so the decorators below can be annotated
# as returning the same callable type they were given.
F = TypeVar("F", bound=Callable[..., Any])
[docs]
@dataclass(frozen=True)
class TaskDefaults:
    """Immutable bundle of per-task settings declared through ``@job(...)``.

    Every field defaults to ``None``, meaning "no default declared".

    Enqueue defaults (``queue``, ``priority``, ``max_attempts``, ``tags``,
    ``timeout_seconds``) are merged into ``enqueue`` calls; a keyword passed
    explicitly by the caller always takes precedence.

    Throttling limits (``rate_per_minute``, ``max_concurrency``) are not
    applied at enqueue time. Workers read them from the registry and pass
    them into the fetch SQL, so they are enforced when tasks are fetched.
    """

    # --- enqueue defaults -------------------------------------------------
    queue: str | None = None
    priority: int | None = None
    max_attempts: int | None = None
    tags: tuple[str, ...] | None = None
    timeout_seconds: int | None = None

    # --- throttling limits (enforced by the worker at fetch time) ---------
    rate_per_minute: int | None = None
    max_concurrency: int | None = None
@dataclass(frozen=True)
class HandlerSpec:
    """Immutable record describing one registered task handler."""

    # Task name the handler is registered under.
    name: str
    # Callable invoked for the task.
    func: Callable[..., Any]
    # Whether ``func`` is a coroutine function.
    is_async: bool
    # Optional Pydantic model associated with the task's arguments.
    args_model: type[BaseModel] | None = None
    # Per-task defaults; a fresh empty ``TaskDefaults`` when none were given.
    defaults: TaskDefaults = field(default_factory=TaskDefaults)
[docs]
class HandlerRegistry:
    """In-memory mapping from task name to its :class:`HandlerSpec`."""

    def __init__(self) -> None:
        self._handlers: dict[str, HandlerSpec] = {}

    def register(
        self,
        name: str,
        func: Callable[..., Any],
        *,
        args_model: type[BaseModel] | None = None,
        defaults: TaskDefaults | None = None,
    ) -> None:
        """Store ``func`` as the handler for ``name``.

        Registering the same callable again under the same name is allowed
        and replaces the stored spec with one built from this call's
        ``args_model`` and ``defaults``.

        Raises:
            ValueError: if ``name`` is already bound to a different callable.
        """
        previous = self._handlers.get(name)
        if previous is not None and previous.func is not func:
            raise ValueError(f"task '{name}' is already registered to a different function")

        if not defaults:
            defaults = TaskDefaults()
        spec = HandlerSpec(
            name=name,
            func=func,
            is_async=inspect.iscoroutinefunction(func),
            args_model=args_model,
            defaults=defaults,
        )
        self._handlers[name] = spec

    def get(self, name: str) -> HandlerSpec | None:
        """Return the spec registered under ``name``, or ``None`` if absent."""
        return self._handlers.get(name)

    def specs(self) -> list[HandlerSpec]:
        """Return all registered :class:`HandlerSpec` objects ordered by name.

        Handy for admin UIs or manifest generation::

            for spec in roost_handlers.specs():
                print(spec.name, spec.defaults.queue, spec.args_model)
        """
        by_name = sorted(self._handlers.items(), key=lambda item: item[0])
        return [spec for _name, spec in by_name]

    def names(self) -> list[str]:
        """Return the registered task names in sorted order."""
        return sorted(self._handlers.keys())

    def clear(self) -> None:
        """Remove every registration."""
        self._handlers.clear()
# Module-level registry that ``@job`` and ``@cron`` fall back to when the
# caller does not supply a registry of their own.
DEFAULT_HANDLERS = HandlerRegistry()
[docs]
def job(
    name: str,
    *,
    args_model: type[BaseModel] | None = None,
    queue: str | None = None,
    priority: int | None = None,
    max_attempts: int | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    timeout_seconds: int | None = None,
    rate_per_minute: int | None = None,
    max_concurrency: int | None = None,
    registry: HandlerRegistry | None = None,
) -> Callable[[F], F]:
    """Decorator factory that registers a function as the handler for ``name``.

    The keyword options are captured in a :class:`TaskDefaults` stored with
    the registration. ``queue``, ``priority``, ``max_attempts``, ``tags`` and
    ``timeout_seconds`` act as enqueue defaults that an explicit kwarg at
    enqueue time overrides. ``tags`` is stored as a tuple.

    Without ``args_model`` the decorated function is registered and returned
    as-is, so it can still be called directly in tests.

    With ``args_model`` (a Pydantic model) the function is wrapped so that
    incoming keyword arguments are validated before the call. That wrapper,
    which accepts keyword arguments only, is what gets registered and
    returned.

    ``registry`` selects the registry to write to; the module-level
    ``DEFAULT_HANDLERS`` is used when it is omitted.

    In both cases the returned callable carries ``__roost_task_name__``.
    """
    target = registry or DEFAULT_HANDLERS
    normalized_tags = None if tags is None else tuple(tags)
    defaults = TaskDefaults(
        queue=queue,
        priority=priority,
        max_attempts=max_attempts,
        tags=normalized_tags,
        timeout_seconds=timeout_seconds,
        rate_per_minute=rate_per_minute,
        max_concurrency=max_concurrency,
    )

    def _decorate(func: F) -> F:
        # Pick the callable to expose: the raw function, or a validating
        # wrapper around it when a model was supplied.
        handler: Callable[..., Any] = (
            func if args_model is None else _wrap_with_validation(func, args_model)
        )
        target.register(name, handler, args_model=args_model, defaults=defaults)
        handler.__roost_task_name__ = name  # type: ignore[attr-defined]
        return cast(F, handler)

    return _decorate
def _wrap_with_validation(func: F, model: type[BaseModel]) -> Callable[..., Any]:
    """Return a keyword-only wrapper that validates its kwargs through ``model``.

    The wrapper instantiates ``model`` from the incoming keyword arguments,
    then calls ``func`` with the result of ``model_dump()``. A coroutine
    function gets an ``async`` wrapper; anything else gets a plain one.
    ``functools.wraps`` copies the wrapped function's metadata onto it.
    """
    if not inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        def _sync_wrapper(**kwargs: Any) -> Any:
            payload = model(**kwargs).model_dump()
            return func(**payload)

        return _sync_wrapper

    @functools.wraps(func)
    async def _async_wrapper(**kwargs: Any) -> Any:
        payload = model(**kwargs).model_dump()
        return await func(**payload)

    return _async_wrapper
[docs]
def cron(
    expression: str,
    *,
    name: str | None = None,
    queue: str = "default",
    args: dict[str, Any] | None = None,
    priority: int = 0,
    max_attempts: int = 20,
    timezone: str | None = None,
    handler_registry: HandlerRegistry | None = None,
) -> Callable[[F], F]:
    """Register a function as a cron handler under ``expression``.

    The task name is resolved in this order: the explicit ``name``, then a
    ``__roost_task_name__`` already present on the function (e.g. set by
    ``@job``), then the function's ``__name__``. The same name is used for
    both the cron entry and the task it enqueues.

    The handler is registered in ``handler_registry`` (``DEFAULT_HANDLERS``
    when omitted). If that name is already registered to this very callable,
    the existing registration is left untouched so that defaults and
    ``args_model`` declared by a stacked ``@job`` are preserved.

    The schedule itself is always recorded in ``DEFAULT_REGISTRY`` as a
    ``CronEntry`` carrying a copy of ``args`` plus ``queue``, ``priority``
    and ``max_attempts``.

    ``timezone`` accepts an IANA name (``"America/Los_Angeles"``,
    ``"Europe/Berlin"``). Defaults to UTC. The cron expression is then
    interpreted in that local timezone, including DST.

    Raises:
        ValueError: if the resolved name is already registered to a
            different callable.
    """
    handler_target = handler_registry or DEFAULT_HANDLERS

    def _decorate(func: F) -> F:
        resolved_name = name or getattr(func, "__roost_task_name__", None) or func.__name__

        # Re-registering the same callable would replace its HandlerSpec with
        # one that has empty defaults and no args_model, discarding whatever
        # @job declared. Only register when the name is new; a conflicting
        # callable still goes through register() so it raises as before.
        existing = handler_target.get(resolved_name)
        if existing is None or existing.func is not func:
            handler_target.register(resolved_name, func)

        func.__roost_task_name__ = resolved_name  # type: ignore[attr-defined]
        DEFAULT_REGISTRY.register(
            CronEntry(
                name=resolved_name,
                expression=expression,
                task=resolved_name,
                # Copy so later mutation of the caller's dict cannot leak in.
                args=dict(args or {}),
                queue=queue,
                priority=priority,
                max_attempts=max_attempts,
                timezone_name=timezone,
            )
        )
        return func

    return _decorate
def task_name(func: Callable[..., Any]) -> str:
    """Return the task name recorded on ``func`` via ``__roost_task_name__``.

    Raises:
        ValueError: if the attribute is missing or ``None``.
    """
    registered = getattr(func, "__roost_task_name__", None)
    if registered is not None:
        return cast(str, registered)
    raise ValueError(f"function {func!r} is not a registered Roost task — did you forget @job?")