From 0bd0f1709e94a30e29c9812129581ca407de2981 Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Sun, 17 May 2026 00:12:57 +0200 Subject: [PATCH 1/5] experiment --- docs/examples/router/multiple_brokers.py | 65 +++++ docs/examples/router/shared_task_package.py | 48 ++++ taskiq/__init__.py | 11 + taskiq/abc/broker.py | 63 ++++- taskiq/brokers/inmemory_broker.py | 12 +- taskiq/context.py | 2 +- taskiq/flow.py | 58 +++++ taskiq/kicker.py | 142 +++++++++-- taskiq/router.py | 268 ++++++++++++++++++++ taskiq/task_builder.py | 107 ++++++++ tests/test_router.py | 164 ++++++++++++ 11 files changed, 910 insertions(+), 30 deletions(-) create mode 100644 docs/examples/router/multiple_brokers.py create mode 100644 docs/examples/router/shared_task_package.py create mode 100644 taskiq/flow.py create mode 100644 taskiq/router.py create mode 100644 taskiq/task_builder.py create mode 100644 tests/test_router.py diff --git a/docs/examples/router/multiple_brokers.py b/docs/examples/router/multiple_brokers.py new file mode 100644 index 00000000..c1d5c1aa --- /dev/null +++ b/docs/examples/router/multiple_brokers.py @@ -0,0 +1,65 @@ +"""Route one task through several brokers with a shared router.""" + +import asyncio + +from taskiq import Flow, InMemoryBroker, TaskiqRouter + +router = TaskiqRouter() + +default_email_flow = Flow.queue("emails.default") +priority_email_flow = Flow.queue("emails.priority") +bulk_email_flow = Flow.queue("emails.bulk") + +default_broker = InMemoryBroker( + router=router, + broker_name="default", + default_flow=default_email_flow, + await_inplace=True, +) +priority_broker = InMemoryBroker( + router=router, + broker_name="priority", + default_flow=priority_email_flow, + await_inplace=True, +) + + +@default_broker.task(task_name="examples.send_email", domain="notifications") +async def send_email(user_id: int, template: str) -> str: + """Pretend to render and send an email.""" + return f"{template} email sent to user {user_id}" + + +router.route_task( + send_email.task_name, + broker="priority", + flow=priority_email_flow, +) + + +async def _main() -> None: + await default_broker.startup() + await priority_broker.startup() + try: + direct_result = await send_email(7, "welcome") + + routed_task = await send_email.kiq(7, "welcome") + routed_result = await routed_task.wait_result(timeout=2) + + bulk_task = await send_email.kicker().with_route( + "default", + bulk_email_flow, + ).kiq(8, "digest") + bulk_result = await bulk_task.wait_result(timeout=2) + + print(f"Direct call: {direct_result}") + print(f"Default route: {router.resolve_route(send_email.task_name)}") + print(f"Routed call: {routed_result.return_value}") + print(f"Route override: {bulk_result.return_value}") + finally: + await priority_broker.shutdown() + await default_broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/docs/examples/router/shared_task_package.py b/docs/examples/router/shared_task_package.py new file mode 100644 index 00000000..0614679e --- /dev/null +++ b/docs/examples/router/shared_task_package.py @@ -0,0 +1,48 @@ +"""Declare shared task definitions and bind them in the final application.""" + +import asyncio + +from taskiq import Flow, InMemoryBroker, TaskiqRouter, task_builder + + +@task_builder("billing.calculate_total", domain="billing") +async def calculate_total(price: int, quantity: int) -> int: + """Package-level task definition that is not bound to any broker.""" + return price * quantity + + +router = TaskiqRouter() +billing_flow = Flow.queue("billing.tasks") +priority_billing_flow = Flow.queue("billing.priority") + +billing_broker = InMemoryBroker( + router=router, + broker_name="billing", + default_flow=billing_flow, + await_inplace=True, +) + +registered_calculate_total = billing_broker.register_task(calculate_total) + + +async def _main() -> None: + await billing_broker.startup() + try: + direct_result = await calculate_total.call(19, 3) + + prepared_task = registered_calculate_total.kicker().with_flow( + priority_billing_flow, + ).prepare(19, 3) + + queued_task = await prepared_task.kiq() + queued_result = await queued_task.wait_result(timeout=2) + + print(f"Shared task direct call: {direct_result}") + print(f"Prepared message: {prepared_task.message.task_name}") + print(f"Registered queued call: {queued_result.return_value}") + finally: + await billing_broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 2414754f..856beb3c 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -24,7 +24,9 @@ TaskiqError, TaskiqResultTimeoutError, ) +from taskiq.flow import Flow, FlowKind from taskiq.funcs import gather +from taskiq.kicker import PreparedKiq from taskiq.message import BrokerMessage, TaskiqMessage from taskiq.middlewares import ( PrometheusMiddleware, @@ -32,10 +34,12 @@ SmartRetryMiddleware, ) from taskiq.result import TaskiqResult +from taskiq.router import TaskiqRoute, TaskiqRouter from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.scheduler.scheduler import TaskiqScheduler from taskiq.state import TaskiqState from taskiq.task import AsyncTaskiqTask +from taskiq.task_builder import TaskDefinition, task_builder __version__ = version("taskiq") @@ -47,8 +51,11 @@ "AsyncTaskiqTask", "BrokerMessage", "Context", + "Flow", + "FlowKind", "InMemoryBroker", "NoResultError", + "PreparedKiq", "PrometheusMiddleware", "ResultGetError", "ResultIsReadyError", @@ -58,6 +65,7 @@ "SendTaskError", "SimpleRetryMiddleware", "SmartRetryMiddleware", + "TaskDefinition", "TaskiqDepends", "TaskiqError", "TaskiqEvents", @@ -66,10 +74,13 @@ "TaskiqMiddleware", "TaskiqResult", "TaskiqResultTimeoutError", + "TaskiqRoute", + "TaskiqRouter", "TaskiqScheduler", "TaskiqState", "ZeroMQBroker", "__version__", "async_shared_broker", "gather", + "task_builder", ] diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index ea2e86c0..71cd1c79 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -24,11 +24,14 @@ from taskiq.decor import AsyncTaskiqDecoratedTask from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskBrokerMismatchError +from taskiq.flow import Flow from taskiq.formatters.proxy_formatter import ProxyFormatter from taskiq.message import BrokerMessage from taskiq.result_backends.dummy import DummyResultBackend +from taskiq.router import TaskiqRouter from taskiq.serializers.json_serializer import JSONSerializer from taskiq.state import TaskiqState +from taskiq.task_builder import TaskDefinition from taskiq.utils import maybe_awaitable from taskiq.warnings import TaskiqDeprecationWarning @@ -78,6 +81,10 @@ def __init__( self, result_backend: "AsyncResultBackend[_T] | None" = None, task_id_generator: Callable[[], str] | None = None, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: Flow | None = None, ) -> None: if result_backend is None: result_backend = DummyResultBackend() @@ -103,6 +110,13 @@ def __init__( self.serializer: TaskiqSerializer = JSONSerializer() self.formatter: TaskiqFormatter = ProxyFormatter(self) self.id_generator = task_id_generator + self.router = router or TaskiqRouter() + self.default_flow = default_flow + self.broker_name = self.router.set_broker( + self, + name=broker_name, + default_flow=default_flow, + ) self.local_task_registry: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} # Every event has a list of handlers. # Every handler is a function which takes state as a first argument. @@ -133,10 +147,14 @@ def find_task(self, task_name: str) -> AsyncTaskiqDecoratedTask[Any, Any] | None :param task_name: name of a task. :returns: found task or None. """ - return self.local_task_registry.get( - task_name, - ) or self.global_task_registry.get( - task_name, + return ( + self.local_task_registry.get( + task_name, + ) + or self.router.find_task(task_name) + or self.global_task_registry.get( + task_name, + ) ) def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: @@ -152,7 +170,11 @@ def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: :return: dict of all tasks. Keys are task names, values are tasks. """ - return {**self.global_task_registry, **self.local_task_registry} + return { + **self.global_task_registry, + **self.router.get_all_tasks(), + **self.local_task_registry, + } def add_dependency_context(self, new_ctx: dict[Any, Any]) -> None: """ @@ -237,6 +259,23 @@ async def kick( :param message: name of a task. """ + async def kick_to_flow( + self, + message: BrokerMessage, + flow: Flow | None = None, + ) -> None: + """ + Send message to a flow-aware broker. + + Existing brokers can keep implementing only `kick`. New brokers may + override this method and use `flow` to route to a concrete queue, topic, + stream or any other transport address. + + :param message: message to send. + :param flow: optional transport-neutral flow. + """ + await self.kick(message) + @abstractmethod def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: """ @@ -362,7 +401,8 @@ def inner( def register_task( self, - func: Callable[_FuncParams, _ReturnType], + func: Callable[_FuncParams, _ReturnType] + | TaskDefinition[_FuncParams, _ReturnType], task_name: str | None = None, **labels: Any, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: @@ -380,6 +420,12 @@ def register_task( :returns: registered task. """ + if isinstance(func, TaskDefinition): + return self.router.register_task( + func, + broker=self, + flow=self.default_flow, + ) return self.task(task_name=task_name, **labels)(func) def on_event(self, *events: TaskiqEvents) -> Callable[[EventHandler], EventHandler]: @@ -533,6 +579,11 @@ def _register_task( if task.broker != self: raise TaskBrokerMismatchError(broker=task.broker) self.local_task_registry[task_name] = task + self.router.register_task( + task, + broker=self, + flow=self.default_flow, + ) async def __aenter__(self) -> None: """Starts the broker as ctx manager.""" diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 0a7cc98e..f341d4dc 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -9,8 +9,10 @@ from taskiq.depends.progress_tracker import TaskProgress from taskiq.events import TaskiqEvents from taskiq.exceptions import UnknownTaskError +from taskiq.flow import Flow from taskiq.message import BrokerMessage from taskiq.receiver import Receiver +from taskiq.router import TaskiqRouter from taskiq.utils import maybe_awaitable _ReturnType = TypeVar("_ReturnType") @@ -130,8 +132,16 @@ def __init__( max_async_tasks_jitter: int = 0, propagate_exceptions: bool = True, await_inplace: bool = False, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: Flow | None = None, ) -> None: - super().__init__() + super().__init__( + router=router, + broker_name=broker_name, + default_flow=default_flow, + ) self.result_backend: InmemoryResultBackend[Any] = InmemoryResultBackend( max_stored_results=max_stored_results, ) diff --git a/taskiq/context.py b/taskiq/context.py index f9f1d0ee..d29b8cb3 100644 --- a/taskiq/context.py +++ b/taskiq/context.py @@ -30,7 +30,7 @@ async def requeue(self) -> None: requeue_count = int(self.message.labels.get("X-Taskiq-requeue", 0)) requeue_count += 1 self.message.labels["X-Taskiq-requeue"] = str(requeue_count) - await self.broker.kick(self.broker.formatter.dumps(self.message)) + await self.broker.router.requeue(self.message, broker=self.broker) raise NoResultError def reject(self) -> None: diff --git a/taskiq/flow.py b/taskiq/flow.py new file mode 100644 index 00000000..5fde2b46 --- /dev/null +++ b/taskiq/flow.py @@ -0,0 +1,58 @@ +import enum +from dataclasses import dataclass, field, replace +from typing import Any + +__all__ = ("Flow", "FlowKind") + + +@enum.unique +class FlowKind(str, enum.Enum): + """Transport-neutral flow shape.""" + + QUEUE = "queue" + TOPIC = "topic" + STREAM = "stream" + + +@dataclass(frozen=True, slots=True) +class Flow: + """Transport-neutral publish or subscribe address. + + Plain flows are intentionally generic. Every broker may interpret a flow + using its own defaults: queue name, topic, stream, channel, list key, or any + other transport address. + + Broker packages can subclass this value object to expose transport-specific + details while still accepting plain Flow instances. + """ + + name: str + kind: FlowKind = FlowKind.QUEUE + options: dict[str, Any] = field( + default_factory=dict, + compare=False, + hash=False, + ) + + @classmethod + def queue(cls, name: str, **options: Any) -> "Flow": + """Create a queue-like flow.""" + return cls(name=name, kind=FlowKind.QUEUE, options=options) + + @classmethod + def topic(cls, name: str, **options: Any) -> "Flow": + """Create a topic-like flow.""" + return cls(name=name, kind=FlowKind.TOPIC, options=options) + + @classmethod + def stream(cls, name: str, **options: Any) -> "Flow": + """Create a stream-like flow.""" + return cls(name=name, kind=FlowKind.STREAM, options=options) + + def with_options(self, **options: Any) -> "Flow": + """Return the same flow with additional generic options.""" + return replace(self, options={**self.options, **options}) + + def broker_options(self, broker_name: str) -> dict[str, Any]: + """Return transport options for broker-specific implementations.""" + return dict(self.options) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index dc113a7e..96ee9f88 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Coroutine from dataclasses import asdict, is_dataclass from datetime import datetime, timedelta @@ -9,17 +11,20 @@ Generic, ParamSpec, TypeVar, - Union, + cast, overload, ) from pydantic import BaseModel from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend from taskiq.compat import model_dump from taskiq.exceptions import SendTaskError +from taskiq.flow import Flow from taskiq.labels import prepare_label from taskiq.message import TaskiqMessage +from taskiq.router import TaskiqRouter from taskiq.scheduler.created_schedule import CreatedSchedule from taskiq.scheduler.scheduled_task import CronSpec, ScheduledTask from taskiq.task import AsyncTaskiqTask @@ -36,13 +41,29 @@ logger = getLogger("taskiq") +class PreparedKiq(Generic[_ReturnType]): + """Prepared task invocation that can be sent later.""" + + def __init__( + self, + kicker: AsyncKicker[..., _ReturnType], + message: TaskiqMessage, + ) -> None: + self.kicker = kicker + self.message = message + + async def kiq(self) -> AsyncTaskiqTask[_ReturnType]: + """Send prepared invocation.""" + return await self.kicker.kiq_message(self.message) + + class AsyncKicker(Generic[_FuncParams, _ReturnType]): """Class that used to modify data before sending it to broker.""" def __init__( self, task_name: str, - broker: "AsyncBroker", + broker: AsyncBroker, labels: dict[str, Any], return_type: type[_ReturnType] | None = None, ) -> None: @@ -52,11 +73,13 @@ def __init__( self.custom_task_id: str | None = None self.custom_schedule_id: str | None = None self.return_type = return_type + self.route_broker: AsyncBroker | str | None = None + self.route_flow: Flow | None = None def with_labels( self, **labels: str | float, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Update function's labels before sending. @@ -69,7 +92,7 @@ def with_labels( def with_task_id( self, task_id: str | None, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Set task_id for current execution. @@ -85,7 +108,7 @@ def with_task_id( def with_schedule_id( self, schedule_id: str, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Set schedule_id for current execution. @@ -97,8 +120,8 @@ def with_schedule_id( def with_broker( self, - broker: "AsyncBroker", - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + broker: AsyncBroker, + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Replace broker for the function. @@ -109,11 +132,55 @@ def with_broker( :return: Kicker with new broker. """ self.broker = broker + self.route_broker = broker + return self + + def with_flow( + self, + flow: Flow | None, + ) -> AsyncKicker[_FuncParams, _ReturnType]: + """ + Replace flow for the current invocation. + + :param flow: flow to send message to. + :return: Kicker with a route flow override. + """ + self.route_flow = flow return self + def with_route( + self, + broker: AsyncBroker | str, + flow: Flow | None, + ) -> AsyncKicker[_FuncParams, _ReturnType]: + """ + Replace broker and flow for the current invocation. + + :param broker: broker instance or broker name. + :param flow: flow to send message to. + :return: Kicker with a route override. + """ + self.route_broker = broker + self.route_flow = flow + return self + + def prepare( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> PreparedKiq[_ReturnType]: + """ + Prepare a task invocation without sending it. + + :param args: function's arguments. + :param kwargs: function's key word arguments. + :return: prepared task invocation. + """ + return PreparedKiq(self, self._prepare_message(*args, **kwargs)) + @overload async def kiq( - self: "AsyncKicker[_FuncParams, CoroutineType[Any, Any, _T]]", + self: AsyncKicker[_FuncParams, CoroutineType[Any, Any, _T]], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_T]: # pragma: no cover @@ -121,7 +188,7 @@ async def kiq( @overload async def kiq( - self: "AsyncKicker[_FuncParams, Coroutine[Any, Any, _T]]", + self: AsyncKicker[_FuncParams, Coroutine[Any, Any, _T]], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_T]: # pragma: no cover @@ -129,7 +196,7 @@ async def kiq( @overload async def kiq( - self: "AsyncKicker[_FuncParams, _ReturnType]", + self: AsyncKicker[_FuncParams, _ReturnType], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_ReturnType]: # pragma: no cover @@ -156,29 +223,60 @@ async def kiq( logger.debug( f"Kicking {self.task_name} with args={args} and kwargs={kwargs}.", ) - message = self._prepare_message(*args, **kwargs) - for middleware in self.broker.middlewares: - if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: - message = await maybe_awaitable(middleware.pre_send(message)) + return await self.kiq_message(self._prepare_message(*args, **kwargs)) + + async def kiq_message( + self, + message: TaskiqMessage, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send a prepared message.""" try: - await self.broker.kick(self.broker.formatter.dumps(message)) + router = getattr(self.broker, "router", None) + if isinstance(router, TaskiqRouter): + return await router.kiq( + message, + broker=self.route_broker, + flow=self.route_flow, + return_type=self.return_type, + ) + return await self._legacy_kiq(message) except Exception as exc: raise SendTaskError from exc - for middleware in reversed(self.broker.middlewares): + async def _legacy_kiq(self, message: TaskiqMessage) -> AsyncTaskiqTask[_ReturnType]: + """ + Send message through the pre-router broker path. + + This keeps middleware tests and external broker-like mocks compatible + while real AsyncBroker instances use TaskiqRouter. + """ + middlewares = getattr(self.broker, "middlewares", []) + if not isinstance(middlewares, list): + middlewares = [] + + for middleware in middlewares: + if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: + message = await maybe_awaitable(middleware.pre_send(message)) + + await self.broker.kick(self.broker.formatter.dumps(message)) + + for middleware in reversed(middlewares): if middleware.__class__.post_send != TaskiqMiddleware.post_send: await maybe_awaitable(middleware.post_send(message)) return AsyncTaskiqTask( task_id=message.task_id, - result_backend=self.broker.result_backend, - return_type=self.return_type, # type: ignore # (pyright issue) + result_backend=cast( + AsyncResultBackend[_ReturnType], + self.broker.result_backend, + ), + return_type=self.return_type, ) async def schedule_by_cron( self, - source: "ScheduleSource", - cron: Union[str, "CronSpec"], + source: ScheduleSource, + cron: str | CronSpec, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> CreatedSchedule[_ReturnType]: @@ -217,7 +315,7 @@ async def schedule_by_cron( async def schedule_by_interval( self, - source: "ScheduleSource", + source: ScheduleSource, interval: int | timedelta, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, @@ -249,7 +347,7 @@ async def schedule_by_interval( async def schedule_by_time( self, - source: "ScheduleSource", + source: ScheduleSource, time: datetime, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, diff --git a/taskiq/router.py b/taskiq/router.py new file mode 100644 index 00000000..07129022 --- /dev/null +++ b/taskiq/router.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from logging import getLogger +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload + +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.flow import Flow +from taskiq.message import TaskiqMessage +from taskiq.task import AsyncTaskiqTask +from taskiq.task_builder import TaskDefinition +from taskiq.utils import maybe_awaitable + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("TaskiqRoute", "TaskiqRouter") + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + +logger = getLogger("taskiq.router") + + +@dataclass(frozen=True, slots=True) +class TaskiqRoute: + """Resolved outbound route for a task invocation.""" + + broker_name: str + flow: Flow | None = None + + +class TaskiqRouter: + """Registry and routing layer shared by one or more brokers.""" + + def __init__(self) -> None: + self.brokers: dict[str, AsyncBroker] = {} + self.default_broker_name: str | None = None + self.task_registry: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} + self.routes: dict[str, TaskiqRoute] = {} + + def set_broker( + self, + broker: AsyncBroker, + name: str | None = None, + default_flow: Flow | None = None, + ) -> str: + """Register broker as a transport in this router.""" + broker_name = name or broker.__class__.__name__ + registered = self.brokers.get(broker_name) + if registered is not None and registered is not broker: + raise ValueError( + f"Broker name {broker_name!r} is already registered. " + "Please provide an explicit unique broker_name.", + ) + self.brokers[broker_name] = broker + if self.default_broker_name is None: + self.default_broker_name = broker_name + return broker_name + + def find_task( + self, + task_name: str, + ) -> AsyncTaskiqDecoratedTask[Any, Any] | None: + """Find a task by name.""" + return self.task_registry.get(task_name) + + def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: + """Return all tasks registered in this router.""" + return dict(self.task_registry) + + def register_task( + self, + task: ( + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType] + | TaskDefinition[_FuncParams, _ReturnType] + ), + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Register a bound task or bind a task definition to a broker.""" + if isinstance(task, TaskDefinition): + target_broker = self._resolve_broker(broker) + registered_task = target_broker.register_task( + task.original_func, + task_name=task.task_name, + **task.labels, + ) + if flow is not None: + self.route_task(task.task_name, broker=target_broker, flow=flow) + return registered_task + + self.task_registry[task.task_name] = task + route_broker: AsyncBroker | str | None = broker + if route_broker is None: + route_broker = getattr(task, "broker", None) + if route_broker is not None or flow is not None: + self.route_task(task.task_name, broker=route_broker, flow=flow) + return task + + @overload + def task( + self, + task_name: Callable[_FuncParams, _ReturnType], + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + **labels: Any, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: ... + + @overload + def task( + self, + task_name: str | None = None, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + **labels: Any, + ) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], + ]: ... + + def task( + self, + task_name: str | Callable[_FuncParams, _ReturnType] | None = None, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + **labels: Any, + ) -> Any: + """Decorate and register a task through this router.""" + + def register( + func: Callable[_FuncParams, _ReturnType], + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + target_broker = self._resolve_broker(broker) + real_task_name = task_name if not callable(task_name) else None + task = target_broker.task(task_name=real_task_name, **labels)(func) + if flow is not None: + self.route_task(task.task_name, broker=target_broker, flow=flow) + return task + + if callable(task_name): + function = task_name + task_name = None + return register(function) + + return register + + def route_task( + self, + task_name: str, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> TaskiqRoute: + """Set default outbound route for a task.""" + broker_name = self._resolve_broker_name(broker) + route = TaskiqRoute(broker_name=broker_name, flow=flow) + self.routes[task_name] = route + return route + + def resolve_route( + self, + task_name: str, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> TaskiqRoute: + """Resolve outbound route for a task invocation.""" + if broker is not None: + broker_name = self._resolve_broker_name(broker) + route_flow = flow + if route_flow is None: + route_flow = self._broker_default_flow(broker_name) + return TaskiqRoute( + broker_name=broker_name, + flow=route_flow, + ) + + route = self.routes.get(task_name) + if route is not None: + if flow is None: + return route + return TaskiqRoute(broker_name=route.broker_name, flow=flow) + + broker_name = self._resolve_broker_name(None) + route_flow = flow + if route_flow is None: + route_flow = self._broker_default_flow(broker_name) + return TaskiqRoute( + broker_name=broker_name, + flow=route_flow, + ) + + async def kiq( + self, + message: TaskiqMessage, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + return_type: type[_ReturnType] | None = None, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send message through the resolved broker and flow.""" + route = self.resolve_route(message.task_name, broker=broker, flow=flow) + target_broker = self.brokers[route.broker_name] + + for middleware in target_broker.middlewares: + if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: + message = await maybe_awaitable(middleware.pre_send(message)) + broker_message = target_broker.formatter.dumps(message) + await target_broker.kick_to_flow(broker_message, route.flow) + + for middleware in reversed(target_broker.middlewares): + if middleware.__class__.post_send != TaskiqMiddleware.post_send: + await maybe_awaitable(middleware.post_send(message)) + + return AsyncTaskiqTask( + task_id=message.task_id, + result_backend=cast( + AsyncResultBackend[_ReturnType], + target_broker.result_backend, + ), + return_type=return_type, + ) + + async def requeue( + self, + message: TaskiqMessage, + *, + broker: AsyncBroker | str | None = None, + flow: Flow | None = None, + ) -> None: + """Send an existing message again through the resolved route.""" + route = self.resolve_route(message.task_name, broker=broker, flow=flow) + target_broker = self.brokers[route.broker_name] + await target_broker.kick_to_flow( + target_broker.formatter.dumps(message), + route.flow, + ) + + def _resolve_broker(self, broker: AsyncBroker | str | None) -> AsyncBroker: + broker_name = self._resolve_broker_name(broker) + return self.brokers[broker_name] + + def _resolve_broker_name(self, broker: AsyncBroker | str | None) -> str: + if isinstance(broker, str): + if broker not in self.brokers: + raise ValueError(f"Unknown broker {broker!r}.") + return broker + + if broker is not None: + broker_name = getattr(broker, "broker_name", None) + if broker_name is not None and broker_name in self.brokers: + return broker_name + for registered_name, registered_broker in self.brokers.items(): + if registered_broker is broker: + return registered_name + raise ValueError("Broker is not registered in this router.") + + if self.default_broker_name is None: + raise ValueError("Router doesn't have registered brokers.") + return self.default_broker_name + + def _broker_default_flow(self, broker_name: str) -> Flow | None: + return getattr(self.brokers[broker_name], "default_flow", None) diff --git a/taskiq/task_builder.py b/taskiq/task_builder.py new file mode 100644 index 00000000..d2d3b0d1 --- /dev/null +++ b/taskiq/task_builder.py @@ -0,0 +1,107 @@ +import inspect +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, Generic, ParamSpec, TypeVar, overload + +from taskiq.message import TaskiqMessage + +__all__ = ("TaskDefinition", "task_builder") + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + + +@dataclass(frozen=True, slots=True) +class TaskDefinition(Generic[_FuncParams, _ReturnType]): + """Unbound task declaration that can be registered later.""" + + task_name: str + original_func: Callable[_FuncParams, _ReturnType] + labels: dict[str, Any] = field(default_factory=dict) + return_type: type[_ReturnType] | None = None + + def __call__( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _ReturnType: + """Call original function directly.""" + return self.original_func(*args, **kwargs) + + async def call( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _ReturnType: + """Execute original function in the current process.""" + result = self.original_func(*args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + def message( + self, + task_id: str, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> TaskiqMessage: + """Build a TaskiqMessage without binding this definition to a router.""" + return TaskiqMessage( + task_id=task_id, + task_name=self.task_name, + labels=dict(self.labels), + args=list(args), + kwargs=dict(kwargs), + ) + + +@overload +def task_builder( + task_name: Callable[_FuncParams, _ReturnType], + **labels: Any, +) -> TaskDefinition[_FuncParams, _ReturnType]: ... + + +@overload +def task_builder( + task_name: str | None = None, + **labels: Any, +) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + TaskDefinition[_FuncParams, _ReturnType], +]: ... + + +def task_builder( + task_name: str | Callable[_FuncParams, _ReturnType] | None = None, + **labels: Any, +) -> Any: + """Build an unbound task definition. + + This decorator is intended for library/package tasks that should be + registered by the final application. + """ + + def build( + func: Callable[_FuncParams, _ReturnType], + ) -> TaskDefinition[_FuncParams, _ReturnType]: + real_task_name = task_name + if real_task_name is None or callable(real_task_name): + real_task_name = f"{func.__module__}:{func.__name__}" + return_type = None + signature = inspect.signature(func) + if signature.return_annotation is not inspect.Signature.empty: + return_type = signature.return_annotation + return TaskDefinition( + task_name=real_task_name, + original_func=func, + labels=dict(labels), + return_type=return_type, + ) + + if callable(task_name): + function = task_name + task_name = None + return build(function) + + return build diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 00000000..00b1883d --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,164 @@ +from collections.abc import AsyncGenerator + +import pytest + +from taskiq import Flow, TaskiqRouter, task_builder +from taskiq.abc.broker import AsyncBroker +from taskiq.message import BrokerMessage + + +class RecordingBroker(AsyncBroker): + """Broker that records sent messages and flows.""" + + def __init__( + self, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: Flow | None = None, + ) -> None: + self.sent: list[tuple[BrokerMessage, Flow | None]] = [] + super().__init__( + router=router, + broker_name=broker_name, + default_flow=default_flow, + ) + + async def kick(self, message: BrokerMessage) -> None: + """Record old-style send.""" + self.sent.append((message, None)) + + async def kick_to_flow( + self, + message: BrokerMessage, + flow: Flow | None = None, + ) -> None: + """Record flow-aware send.""" + self.sent.append((message, flow)) + + async def listen(self) -> AsyncGenerator[bytes, None]: + """Recording broker doesn't listen in these tests.""" + if False: + yield b"" + + +def test_broker_creates_default_router() -> None: + broker = RecordingBroker() + + assert broker.router.brokers[broker.broker_name] is broker + assert broker.router.default_broker_name == broker.broker_name + + +async def test_old_broker_task_api_registers_task_in_router() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + assert broker.find_task("demo.task") is demo_task + assert broker.router.find_task("demo.task") is demo_task + + await demo_task.kiq() + + assert broker.sent[0][0].task_name == "demo.task" + assert broker.sent[0][1] is None + + +async def test_router_can_route_task_to_another_broker_flow() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("events") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task("demo.task", broker="target", flow=flow) + + await demo_task.kiq() + + assert source.sent == [] + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_kicker_route_override_wins_over_registered_route() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @first.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task("demo.task", broker="first", flow=first_flow) + + await demo_task.kicker().with_route("second", second_flow).kiq() + + assert first.sent == [] + assert second.sent[0][1] == second_flow + + +async def test_kicker_can_prepare_invocation_for_later() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + prepared = demo_task.kicker().with_labels(trace_id="abc").prepare(1) + + assert prepared.message.task_name == "demo.task" + assert prepared.message.args == [1] + assert prepared.message.labels["trace_id"] == "abc" + + await prepared.kiq() + + assert broker.sent[0][0].task_id == prepared.message.task_id + + +async def test_task_builder_can_be_registered_later() -> None: + broker = RecordingBroker() + + @task_builder("shared.add", queue="shared") + def add(left: int, right: int) -> int: + return left + right + + assert await add.call(1, 2) == 3 + + registered = broker.register_task(add) + + assert registered.task_name == "shared.add" + assert registered.labels == {"queue": "shared"} + assert broker.router.find_task("shared.add") is registered + + await registered.kiq(1, 2) + + assert broker.sent[0][0].task_name == "shared.add" + + +async def test_router_task_decorator_can_choose_broker_and_flow() -> None: + router = TaskiqRouter() + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("target-flow") + + @router.task("demo.task", broker="target", flow=flow) + async def demo_task() -> None: + return None + + await demo_task.kiq() + + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +def test_router_rejects_duplicate_broker_names() -> None: + router = TaskiqRouter() + RecordingBroker(router=router, broker_name="broker") + + with pytest.raises(ValueError, match="already registered"): + RecordingBroker(router=router, broker_name="broker") From 8805f3d5d1ba1a139f5279d4dd4ba24a6f02f733 Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Sat, 23 May 2026 22:37:37 +0200 Subject: [PATCH 2/5] experiment --- docs/examples/router/multiple_brokers.py | 26 +++--- docs/examples/router/shared_task_package.py | 38 +++++++-- taskiq/__init__.py | 4 +- taskiq/abc/broker.py | 12 +-- taskiq/brokers/inmemory_broker.py | 4 +- taskiq/flow.py | 54 +++++------- taskiq/kicker.py | 10 ++- taskiq/router.py | 93 +++++++++++++-------- tests/test_router.py | 72 ++++++++++++++-- 9 files changed, 202 insertions(+), 111 deletions(-) diff --git a/docs/examples/router/multiple_brokers.py b/docs/examples/router/multiple_brokers.py index c1d5c1aa..7d6977ae 100644 --- a/docs/examples/router/multiple_brokers.py +++ b/docs/examples/router/multiple_brokers.py @@ -6,9 +6,9 @@ router = TaskiqRouter() -default_email_flow = Flow.queue("emails.default") -priority_email_flow = Flow.queue("emails.priority") -bulk_email_flow = Flow.queue("emails.bulk") +default_email_flow = Flow("emails.default") +priority_email_flow = Flow("emails.priority") +bulk_email_flow = Flow("emails.bulk") default_broker = InMemoryBroker( router=router, @@ -30,9 +30,9 @@ async def send_email(user_id: int, template: str) -> str: return f"{template} email sent to user {user_id}" -router.route_task( - send_email.task_name, - broker="priority", +priority_route = router.route_task( + send_email, + broker=priority_broker, flow=priority_email_flow, ) @@ -46,14 +46,18 @@ async def _main() -> None: routed_task = await send_email.kiq(7, "welcome") routed_result = await routed_task.wait_result(timeout=2) - bulk_task = await send_email.kicker().with_route( - "default", - bulk_email_flow, - ).kiq(8, "digest") + bulk_task = ( + await send_email.kicker() + .with_route( + default_broker, + bulk_email_flow, + ) + .kiq(8, "digest") + ) bulk_result = await bulk_task.wait_result(timeout=2) print(f"Direct call: {direct_result}") - print(f"Default route: {router.resolve_route(send_email.task_name)}") + print(f"Declared route: {priority_route.broker_name}") print(f"Routed call: {routed_result.return_value}") print(f"Route override: {bulk_result.return_value}") finally: diff --git a/docs/examples/router/shared_task_package.py b/docs/examples/router/shared_task_package.py index 0614679e..3a9e4902 100644 --- a/docs/examples/router/shared_task_package.py +++ b/docs/examples/router/shared_task_package.py @@ -1,10 +1,27 @@ """Declare shared task definitions and bind them in the final application.""" import asyncio +from collections.abc import Mapping +from dataclasses import dataclass from taskiq import Flow, InMemoryBroker, TaskiqRouter, task_builder +@dataclass(frozen=True, slots=True) +class BillingQueue: + """Broker-specific flow that follows the shared flow protocol.""" + + name: str + priority: int + + def broker_options(self, broker_name: str) -> Mapping[str, object]: + """Return options that a billing broker adapter can understand.""" + return { + "broker": broker_name, + "priority": self.priority, + } + + @task_builder("billing.calculate_total", domain="billing") async def calculate_total(price: int, quantity: int) -> int: """Package-level task definition that is not bound to any broker.""" @@ -12,8 +29,8 @@ async def calculate_total(price: int, quantity: int) -> int: router = TaskiqRouter() -billing_flow = Flow.queue("billing.tasks") -priority_billing_flow = Flow.queue("billing.priority") +billing_flow = Flow("billing.tasks") +priority_billing_flow = BillingQueue(name="billing.priority", priority=10) billing_broker = InMemoryBroker( router=router, @@ -22,7 +39,11 @@ async def calculate_total(price: int, quantity: int) -> int: await_inplace=True, ) -registered_calculate_total = billing_broker.register_task(calculate_total) +registered_calculate_total = router.register_task( + calculate_total, + broker=billing_broker, + flow=billing_flow, +) async def _main() -> None: @@ -30,9 +51,14 @@ async def _main() -> None: try: direct_result = await calculate_total.call(19, 3) - prepared_task = registered_calculate_total.kicker().with_flow( - priority_billing_flow, - ).prepare(19, 3) + prepared_task = ( + registered_calculate_total.kicker() + .with_route( + billing_broker, + priority_billing_flow, + ) + .prepare(19, 3) + ) queued_task = await prepared_task.kiq() queued_result = await queued_task.wait_result(timeout=2) diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 856beb3c..04d82920 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -24,7 +24,7 @@ TaskiqError, TaskiqResultTimeoutError, ) -from taskiq.flow import Flow, FlowKind +from taskiq.flow import Flow, FlowProtocol from taskiq.funcs import gather from taskiq.kicker import PreparedKiq from taskiq.message import BrokerMessage, TaskiqMessage @@ -52,7 +52,7 @@ "BrokerMessage", "Context", "Flow", - "FlowKind", + "FlowProtocol", "InMemoryBroker", "NoResultError", "PreparedKiq", diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 71cd1c79..42a1817b 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -24,7 +24,7 @@ from taskiq.decor import AsyncTaskiqDecoratedTask from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskBrokerMismatchError -from taskiq.flow import Flow +from taskiq.flow import FlowProtocol from taskiq.formatters.proxy_formatter import ProxyFormatter from taskiq.message import BrokerMessage from taskiq.result_backends.dummy import DummyResultBackend @@ -84,7 +84,7 @@ def __init__( *, router: TaskiqRouter | None = None, broker_name: str | None = None, - default_flow: Flow | None = None, + default_flow: FlowProtocol | None = None, ) -> None: if result_backend is None: result_backend = DummyResultBackend() @@ -262,7 +262,7 @@ async def kick( async def kick_to_flow( self, message: BrokerMessage, - flow: Flow | None = None, + flow: FlowProtocol | None = None, ) -> None: """ Send message to a flow-aware broker. @@ -401,8 +401,10 @@ def inner( def register_task( self, - func: Callable[_FuncParams, _ReturnType] - | TaskDefinition[_FuncParams, _ReturnType], + func: ( + Callable[_FuncParams, _ReturnType] + | TaskDefinition[_FuncParams, _ReturnType] + ), task_name: str | None = None, **labels: Any, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index f341d4dc..e63a97f6 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -9,7 +9,7 @@ from taskiq.depends.progress_tracker import TaskProgress from taskiq.events import TaskiqEvents from taskiq.exceptions import UnknownTaskError -from taskiq.flow import Flow +from taskiq.flow import FlowProtocol from taskiq.message import BrokerMessage from taskiq.receiver import Receiver from taskiq.router import TaskiqRouter @@ -135,7 +135,7 @@ def __init__( *, router: TaskiqRouter | None = None, broker_name: str | None = None, - default_flow: Flow | None = None, + default_flow: FlowProtocol | None = None, ) -> None: super().__init__( router=router, diff --git a/taskiq/flow.py b/taskiq/flow.py index 5fde2b46..8754492e 100644 --- a/taskiq/flow.py +++ b/taskiq/flow.py @@ -1,58 +1,42 @@ -import enum +from collections.abc import Mapping from dataclasses import dataclass, field, replace -from typing import Any +from typing import Protocol, runtime_checkable -__all__ = ("Flow", "FlowKind") +__all__ = ("Flow", "FlowProtocol") -@enum.unique -class FlowKind(str, enum.Enum): - """Transport-neutral flow shape.""" +@runtime_checkable +class FlowProtocol(Protocol): + """Transport-neutral flow contract accepted by routers and brokers.""" - QUEUE = "queue" - TOPIC = "topic" - STREAM = "stream" + name: str + + def broker_options(self, broker_name: str) -> Mapping[str, object]: + """Return options relevant for a concrete broker implementation.""" + ... @dataclass(frozen=True, slots=True) class Flow: - """Transport-neutral publish or subscribe address. - - Plain flows are intentionally generic. Every broker may interpret a flow - using its own defaults: queue name, topic, stream, channel, list key, or any - other transport address. + """Generic transport-neutral flow address. - Broker packages can subclass this value object to expose transport-specific - details while still accepting plain Flow instances. + Broker packages can provide their own flow objects that implement + FlowProtocol, for example RabbitQueue, KafkaTopic, NatsSubject or RedisQueue. + This generic value object is intentionally small and works as a common + fallback for brokers that only need a named address with optional metadata. """ name: str - kind: FlowKind = FlowKind.QUEUE - options: dict[str, Any] = field( + options: Mapping[str, object] = field( default_factory=dict, compare=False, hash=False, ) - @classmethod - def queue(cls, name: str, **options: Any) -> "Flow": - """Create a queue-like flow.""" - return cls(name=name, kind=FlowKind.QUEUE, options=options) - - @classmethod - def topic(cls, name: str, **options: Any) -> "Flow": - """Create a topic-like flow.""" - return cls(name=name, kind=FlowKind.TOPIC, options=options) - - @classmethod - def stream(cls, name: str, **options: Any) -> "Flow": - """Create a stream-like flow.""" - return cls(name=name, kind=FlowKind.STREAM, options=options) - - def with_options(self, **options: Any) -> "Flow": + def with_options(self, **options: object) -> "Flow": """Return the same flow with additional generic options.""" return replace(self, options={**self.options, **options}) - def broker_options(self, broker_name: str) -> dict[str, Any]: + def broker_options(self, broker_name: str) -> Mapping[str, object]: """Return transport options for broker-specific implementations.""" return dict(self.options) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index 96ee9f88..d0182e4c 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -21,7 +21,7 @@ from taskiq.abc.result_backend import AsyncResultBackend from taskiq.compat import model_dump from taskiq.exceptions import SendTaskError -from taskiq.flow import Flow +from taskiq.flow import FlowProtocol from taskiq.labels import prepare_label from taskiq.message import TaskiqMessage from taskiq.router import TaskiqRouter @@ -74,7 +74,7 @@ def __init__( self.custom_schedule_id: str | None = None self.return_type = return_type self.route_broker: AsyncBroker | str | None = None - self.route_flow: Flow | None = None + self.route_flow: FlowProtocol | None = None def with_labels( self, @@ -137,7 +137,7 @@ def with_broker( def with_flow( self, - flow: Flow | None, + flow: FlowProtocol | None, ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Replace flow for the current invocation. @@ -151,12 +151,14 @@ def with_flow( def with_route( self, broker: AsyncBroker | str, - flow: Flow | None, + flow: FlowProtocol | None, ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Replace broker and flow for the current invocation. :param broker: broker instance or broker name. + Broker instances are preferred; names are kept for configuration + and backward-compatible lookup. :param flow: flow to send message to. :return: Kicker with a route override. """ diff --git a/taskiq/router.py b/taskiq/router.py index 07129022..9f770f90 100644 --- a/taskiq/router.py +++ b/taskiq/router.py @@ -7,7 +7,7 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.abc.result_backend import AsyncResultBackend -from taskiq.flow import Flow +from taskiq.flow import FlowProtocol from taskiq.message import TaskiqMessage from taskiq.task import AsyncTaskiqTask from taskiq.task_builder import TaskDefinition @@ -29,8 +29,13 @@ class TaskiqRoute: """Resolved outbound route for a task invocation.""" - broker_name: str - flow: Flow | None = None + broker: AsyncBroker + flow: FlowProtocol | None = None + + @property + def broker_name(self) -> str: + """Return registered broker name for compatibility and diagnostics.""" + return self.broker.broker_name class TaskiqRouter: @@ -46,7 +51,7 @@ def set_broker( self, broker: AsyncBroker, name: str | None = None, - default_flow: Flow | None = None, + default_flow: FlowProtocol | None = None, ) -> str: """Register broker as a transport in this router.""" broker_name = name or broker.__class__.__name__ @@ -79,7 +84,7 @@ def register_task( | TaskDefinition[_FuncParams, _ReturnType] ), broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: """Register a bound task or bind a task definition to a broker.""" if isinstance(task, TaskDefinition): @@ -107,7 +112,7 @@ def task( task_name: Callable[_FuncParams, _ReturnType], *, broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, **labels: Any, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: ... @@ -117,7 +122,7 @@ def task( task_name: str | None = None, *, broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, **labels: Any, ) -> Callable[ [Callable[_FuncParams, _ReturnType]], @@ -129,7 +134,7 @@ def task( task_name: str | Callable[_FuncParams, _ReturnType] | None = None, *, broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, **labels: Any, ) -> Any: """Decorate and register a task through this router.""" @@ -153,30 +158,32 @@ def register( def route_task( self, - task_name: str, + task: str | AsyncTaskiqDecoratedTask[Any, Any], broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, ) -> TaskiqRoute: """Set default outbound route for a task.""" - broker_name = self._resolve_broker_name(broker) - route = TaskiqRoute(broker_name=broker_name, flow=flow) + task_name = self._resolve_task_name(task) + target_broker = self._resolve_broker(broker) + route = TaskiqRoute(broker=target_broker, flow=flow) self.routes[task_name] = route return route def resolve_route( self, - task_name: str, + task: str | AsyncTaskiqDecoratedTask[Any, Any], broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, ) -> TaskiqRoute: """Resolve outbound route for a task invocation.""" + task_name = self._resolve_task_name(task) if broker is not None: - broker_name = self._resolve_broker_name(broker) + target_broker = self._resolve_broker(broker) route_flow = flow if route_flow is None: - route_flow = self._broker_default_flow(broker_name) + route_flow = self._broker_default_flow(target_broker) return TaskiqRoute( - broker_name=broker_name, + broker=target_broker, flow=route_flow, ) @@ -184,14 +191,14 @@ def resolve_route( if route is not None: if flow is None: return route - return TaskiqRoute(broker_name=route.broker_name, flow=flow) + return TaskiqRoute(broker=route.broker, flow=flow) - broker_name = self._resolve_broker_name(None) + target_broker = self._resolve_broker(None) route_flow = flow if route_flow is None: - route_flow = self._broker_default_flow(broker_name) + route_flow = self._broker_default_flow(target_broker) return TaskiqRoute( - broker_name=broker_name, + broker=target_broker, flow=route_flow, ) @@ -200,12 +207,12 @@ async def kiq( message: TaskiqMessage, *, broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, return_type: type[_ReturnType] | None = None, ) -> AsyncTaskiqTask[_ReturnType]: """Send message through the resolved broker and flow.""" route = self.resolve_route(message.task_name, broker=broker, flow=flow) - target_broker = self.brokers[route.broker_name] + target_broker = route.broker for middleware in target_broker.middlewares: if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: @@ -231,38 +238,50 @@ async def requeue( message: TaskiqMessage, *, broker: AsyncBroker | str | None = None, - flow: Flow | None = None, + flow: FlowProtocol | None = None, ) -> None: """Send an existing message again through the resolved route.""" route = self.resolve_route(message.task_name, broker=broker, flow=flow) - target_broker = self.brokers[route.broker_name] + target_broker = route.broker await target_broker.kick_to_flow( target_broker.formatter.dumps(message), route.flow, ) def _resolve_broker(self, broker: AsyncBroker | str | None) -> AsyncBroker: - broker_name = self._resolve_broker_name(broker) - return self.brokers[broker_name] - - def _resolve_broker_name(self, broker: AsyncBroker | str | None) -> str: if isinstance(broker, str): if broker not in self.brokers: raise ValueError(f"Unknown broker {broker!r}.") - return broker + return self.brokers[broker] if broker is not None: broker_name = getattr(broker, "broker_name", None) - if broker_name is not None and broker_name in self.brokers: - return broker_name - for registered_name, registered_broker in self.brokers.items(): + if isinstance(broker_name, str): + registered_broker = self.brokers.get(broker_name) + if registered_broker is broker: + return broker + for registered_broker in self.brokers.values(): if registered_broker is broker: - return registered_name + return registered_broker raise ValueError("Broker is not registered in this router.") if self.default_broker_name is None: raise ValueError("Router doesn't have registered brokers.") - return self.default_broker_name + return self.brokers[self.default_broker_name] + + def _resolve_broker_name(self, broker: AsyncBroker | str | None) -> str: + return self._resolve_broker(broker).broker_name + + def _resolve_task_name( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + ) -> str: + if isinstance(task, str): + return task + task_name = getattr(task, "task_name", None) + if isinstance(task_name, str): + return task_name + raise TypeError("Route task must be a task name or decorated task.") - def _broker_default_flow(self, broker_name: str) -> Flow | None: - return getattr(self.brokers[broker_name], "default_flow", None) + def _broker_default_flow(self, broker: AsyncBroker) -> FlowProtocol | None: + return getattr(broker, "default_flow", None) diff --git a/tests/test_router.py b/tests/test_router.py index 00b1883d..f20d7223 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -1,12 +1,28 @@ -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping +from dataclasses import dataclass import pytest -from taskiq import Flow, TaskiqRouter, task_builder +from taskiq import Flow, FlowProtocol, TaskiqRouter, task_builder from taskiq.abc.broker import AsyncBroker from taskiq.message import BrokerMessage +@dataclass(frozen=True, slots=True) +class BrokerQueue: + """Broker-specific flow used to prove protocol-based routing.""" + + name: str + durable: bool = True + + def broker_options(self, broker_name: str) -> Mapping[str, object]: + """Return options for the target broker.""" + return { + "broker": broker_name, + "durable": self.durable, + } + + class RecordingBroker(AsyncBroker): """Broker that records sent messages and flows.""" @@ -15,9 +31,9 @@ def __init__( *, router: TaskiqRouter | None = None, broker_name: str | None = None, - default_flow: Flow | None = None, + default_flow: FlowProtocol | None = None, ) -> None: - self.sent: list[tuple[BrokerMessage, Flow | None]] = [] + self.sent: list[tuple[BrokerMessage, FlowProtocol | None]] = [] super().__init__( router=router, broker_name=broker_name, @@ -31,7 +47,7 @@ async def kick(self, message: BrokerMessage) -> None: async def kick_to_flow( self, message: BrokerMessage, - flow: Flow | None = None, + flow: FlowProtocol | None = None, ) -> None: """Record flow-aware send.""" self.sent.append((message, flow)) @@ -75,10 +91,12 @@ async def test_router_can_route_task_to_another_broker_flow() -> None: async def demo_task() -> None: return None - router.route_task("demo.task", broker="target", flow=flow) + route = router.route_task(demo_task, broker=target, flow=flow) await demo_task.kiq() + assert route.broker is target + assert route.broker_name == "target" assert source.sent == [] assert target.sent[0][0].task_name == "demo.task" assert target.sent[0][1] == flow @@ -95,14 +113,50 @@ async def test_kicker_route_override_wins_over_registered_route() -> None: async def demo_task() -> None: return None - router.route_task("demo.task", broker="first", flow=first_flow) + router.route_task(demo_task, broker=first, flow=first_flow) - await demo_task.kicker().with_route("second", second_flow).kiq() + await demo_task.kicker().with_route(second, second_flow).kiq() assert first.sent == [] assert second.sent[0][1] == second_flow +async def test_router_keeps_string_broker_lookup_for_compatibility() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("compat") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task("demo.task", broker="target", flow=flow) + + await demo_task.kicker().with_route("target", flow).kiq() + + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_router_accepts_broker_specific_flow_protocol() -> None: + broker = RecordingBroker(broker_name="transport") + flow = BrokerQueue(name="critical", durable=False) + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + await demo_task.kicker().with_flow(flow).kiq() + + assert isinstance(flow, FlowProtocol) + assert broker.sent[0][1] is flow + assert flow.broker_options("transport") == { + "broker": "transport", + "durable": False, + } + + async def test_kicker_can_prepare_invocation_for_later() -> None: broker = RecordingBroker() @@ -146,7 +200,7 @@ async def test_router_task_decorator_can_choose_broker_and_flow() -> None: target = RecordingBroker(router=router, broker_name="target") flow = Flow("target-flow") - @router.task("demo.task", broker="target", flow=flow) + @router.task("demo.task", broker=target, flow=flow) async def demo_task() -> None: return None From bd93e5bfe605c79be88bd29667c7f7bbde7ed586 Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Sat, 6 Jun 2026 16:11:15 +0200 Subject: [PATCH 3/5] feat(routing): separate task routing from broker transport Introduce router-owned task routing, flow subscriptions, and shared task declarations while preserving the existing broker-first task API. Add transport-neutral Flow contracts, explicit route/subscription semantics, multi-broker router ownership, task_builder base_cls support, prepared invocation route snapshots, and scheduler/requeue compatibility behavior. Document migration guidance and add executable routing examples for multiple brokers and shared task packages. --- docs/examples/router/multiple_brokers.py | 47 ++- docs/examples/router/shared_task_package.py | 40 ++- docs/guide/architecture-overview.md | 22 ++ docs/guide/routing-and-flows.md | 345 ++++++++++++++++++++ docs/guide/scheduling-tasks.md | 17 +- taskiq/__init__.py | 7 +- taskiq/abc/broker.py | 141 ++++++-- taskiq/flow.py | 39 ++- taskiq/kicker.py | 160 +++++---- taskiq/message.py | 60 +++- taskiq/router.py | 299 +++++++++-------- taskiq/routing/__init__.py | 17 + taskiq/routing/dispatcher.py | 84 +++++ taskiq/routing/models.py | 38 +++ taskiq/routing/references.py | 18 + taskiq/routing/registries.py | 106 ++++++ taskiq/routing/routes.py | 106 ++++++ taskiq/routing/subscriptions.py | 136 ++++++++ taskiq/scheduler/scheduler.py | 36 +- taskiq/task_builder.py | 58 +++- tests/routing/__init__.py | 0 tests/routing/models.py | 117 +++++++ tests/routing/test_dispatch.py | 324 ++++++++++++++++++ tests/routing/test_router_core.py | 74 +++++ tests/routing/test_shared_tasks.py | 94 ++++++ tests/routing/test_subscriptions.py | 186 +++++++++++ tests/scheduler/test_scheduler.py | 114 +++++++ tests/test_requeue.py | 90 ++++- tests/test_router.py | 218 ------------- tests/test_task_builder.py | 130 ++++++++ tests/utils.py | 37 ++- 31 files changed, 2652 insertions(+), 508 deletions(-) create mode 100644 docs/guide/routing-and-flows.md create mode 100644 taskiq/routing/__init__.py create mode 100644 taskiq/routing/dispatcher.py create mode 100644 taskiq/routing/models.py create mode 100644 taskiq/routing/references.py create mode 100644 taskiq/routing/registries.py create mode 100644 taskiq/routing/routes.py create mode 100644 taskiq/routing/subscriptions.py create mode 100644 tests/routing/__init__.py create mode 100644 tests/routing/models.py create mode 100644 tests/routing/test_dispatch.py create mode 100644 tests/routing/test_router_core.py create mode 100644 tests/routing/test_shared_tasks.py create mode 100644 tests/routing/test_subscriptions.py delete mode 100644 tests/test_router.py create mode 100644 tests/test_task_builder.py diff --git a/docs/examples/router/multiple_brokers.py b/docs/examples/router/multiple_brokers.py index 7d6977ae..c845bbf0 100644 --- a/docs/examples/router/multiple_brokers.py +++ b/docs/examples/router/multiple_brokers.py @@ -2,7 +2,7 @@ import asyncio -from taskiq import Flow, InMemoryBroker, TaskiqRouter +from taskiq import Flow, InMemoryBroker, TaskiqRoute, TaskiqRouter router = TaskiqRouter() @@ -35,6 +35,23 @@ async def send_email(user_id: int, template: str) -> str: broker=priority_broker, flow=priority_email_flow, ) +priority_subscription = router.subscribe( + priority_broker, + priority_email_flow, + send_email, +) + + +def _format_route(task_name: str, route: TaskiqRoute) -> str: + """Return a readable route diagnostic for the example output.""" + flow_name = route.flow.name if route.flow is not None else "" + return f"{task_name} -> broker={route.broker_name}, flow={flow_name}" + + +def _format_listen_plan() -> str: + """Return flows that the priority broker should subscribe to.""" + flow_names = ", ".join(flow.name for flow in priority_broker.get_subscribed_flows()) + return f"priority listens to: {flow_names}" async def _main() -> None: @@ -43,22 +60,34 @@ async def _main() -> None: try: direct_result = await send_email(7, "welcome") - routed_task = await send_email.kiq(7, "welcome") - routed_result = await routed_task.wait_result(timeout=2) + declared_route = router.resolve_route(send_email) + assert declared_route == priority_route - bulk_task = ( + routed_task = ( await send_email.kicker() - .with_route( - default_broker, - bulk_email_flow, + .with_route(declared_route) + .kiq( + 7, + "welcome", ) - .kiq(8, "digest") ) + routed_result = await routed_task.wait_result(timeout=2) + + bulk_route = router.resolve_route( + send_email, + broker=default_broker, + flow=bulk_email_flow, + ) + bulk_task = await send_email.kicker().with_route(bulk_route).kiq(8, "digest") bulk_result = await bulk_task.wait_result(timeout=2) print(f"Direct call: {direct_result}") - print(f"Declared route: {priority_route.broker_name}") + print(f"Router rule: {_format_route(send_email.task_name, priority_route)}") + print(f"Subscription tasks: {sorted(priority_subscription.task_names)}") + print(_format_listen_plan()) + print(f"Resolved route: {_format_route(send_email.task_name, declared_route)}") print(f"Routed call: {routed_result.return_value}") + print(f"Override route: {_format_route(send_email.task_name, bulk_route)}") print(f"Route override: {bulk_result.return_value}") finally: await priority_broker.shutdown() diff --git a/docs/examples/router/shared_task_package.py b/docs/examples/router/shared_task_package.py index 3a9e4902..0bbd6616 100644 --- a/docs/examples/router/shared_task_package.py +++ b/docs/examples/router/shared_task_package.py @@ -3,8 +3,15 @@ import asyncio from collections.abc import Mapping from dataclasses import dataclass - -from taskiq import Flow, InMemoryBroker, TaskiqRouter, task_builder +from typing import Any + +from taskiq import ( + AsyncTaskiqDecoratedTask, + Flow, + InMemoryBroker, + TaskiqRouter, + task_builder, +) @dataclass(frozen=True, slots=True) @@ -14,15 +21,22 @@ class BillingQueue: name: str priority: int - def broker_options(self, broker_name: str) -> Mapping[str, object]: + def broker_options(self) -> Mapping[str, object]: """Return options that a billing broker adapter can understand.""" return { - "broker": broker_name, "priority": self.priority, } -@task_builder("billing.calculate_total", domain="billing") +class BillingTask(AsyncTaskiqDecoratedTask[Any, Any]): + """Custom task class shared by billing package tasks.""" + + def billing_name(self) -> str: + """Return a billing-specific task name.""" + return self.task_name + + +@task_builder("billing.calculate_total", base_cls=BillingTask, domain="billing") async def calculate_total(price: int, quantity: int) -> int: """Package-level task definition that is not bound to any broker.""" return price * quantity @@ -44,6 +58,11 @@ async def calculate_total(price: int, quantity: int) -> int: broker=billing_broker, flow=billing_flow, ) +router.subscribe( + billing_broker, + billing_flow, + registered_calculate_total, +) async def _main() -> None: @@ -51,11 +70,15 @@ async def _main() -> None: try: direct_result = await calculate_total.call(19, 3) + priority_route = router.resolve_route( + registered_calculate_total, + broker=billing_broker, + flow=priority_billing_flow, + ) prepared_task = ( registered_calculate_total.kicker() .with_route( - billing_broker, - priority_billing_flow, + priority_route, ) .prepare(19, 3) ) @@ -64,6 +87,9 @@ async def _main() -> None: queued_result = await queued_task.wait_result(timeout=2) print(f"Shared task direct call: {direct_result}") + print(f"Registered task class: {registered_calculate_total.billing_name()}") + listen_flow = billing_broker.get_subscribed_flows()[0] + print(f"Registered listen flow: {listen_flow.name}") print(f"Prepared message: {prepared_task.message.task_name}") print(f"Registered queued call: {queued_result.return_value}") finally: diff --git a/docs/guide/architecture-overview.md b/docs/guide/architecture-overview.md index d4b825f1..f84437d2 100644 --- a/docs/guide/architecture-overview.md +++ b/docs/guide/architecture-overview.md @@ -82,6 +82,28 @@ asyncio.run(main()) ``` +## Router and flows + +Taskiq can use a `TaskiqRouter` to keep routing rules outside of broker +implementations. Brokers remain transport adapters, while the router owns task +registration, route resolution and flow subscriptions. +This section describes the `experiment/separate_broker` branch contract. The +old `@broker.task(...)`, `.kiq()`, labels, scheduler and result backend behavior +remain compatible, while router/flow APIs are additive review material for the +branch. + +`Flow` is a transport-neutral delivery address. Broker packages may provide +their own flow classes for queue, topic, subject or stream options, as long as +they implement the same flow protocol. The router deduplicates subscriptions by +flow name and rejects same-name flows with incompatible broker options. + +Routing and subscribing are separate responsibilities. `route_task(...)` +chooses the outbound broker and flow for task invocations. `subscribe(...)` +adds flows to a broker listen plan for flow-aware broker adapters. Worker task +lookup still uses `task_name`; flow does not select the Python task. + +Read more in the [Routing and flows](./routing-and-flows.md) section. + ## Messages Every message has labels. You can define labels diff --git a/docs/guide/routing-and-flows.md b/docs/guide/routing-and-flows.md new file mode 100644 index 00000000..208d2502 --- /dev/null +++ b/docs/guide/routing-and-flows.md @@ -0,0 +1,345 @@ +--- +title: Routing and flows +order: 10 +--- + +# Routing and flows + +::: warning Experimental branch +This page describes the routing and flow contract on the +`experiment/separate_broker` branch. The old task declaration API remains +compatible. Broker-specific packages may keep using their existing send/listen +implementation until they opt into flow-aware dispatch and listen-plan support. +::: + +Taskiq keeps the old task declaration API: + +```python +@broker.task(task_name="billing.charge") +async def charge(user_id: int) -> None: + ... +``` + +The broker is still the transport adapter, but routing decisions can be moved to +a `TaskiqRouter`. This is useful when one application owns several brokers or +when a task package declares tasks before the final application knows which +broker must handle them. + +The important responsibility split is: + +- Task declaration describes a Python callable, task name, labels and direct + call behavior. +- Kicker prepares a task invocation: task id, labels, args, kwargs and optional + route overrides. +- Router owns outbound routing policy: task invocation to broker and flow. +- Router also owns inbound listen-plan records for flow-aware brokers. +- Broker owns transport lifecycle: startup, shutdown, send, listen, ack/nack, + serializer/formatter, middleware and result backend integration. +- Worker still executes by `task_name`; flow does not select the Python task. + +## Router + +A router owns broker registration, task registration and routing rules. +One router can own several brokers. A broker has a single router owner; pass the +router when creating the broker instead of registering the same broker in +several routers. + +```python +from taskiq import Flow, InMemoryBroker, TaskiqRouter + +router = TaskiqRouter() + +default_broker = InMemoryBroker(router=router, broker_name="default") +priority_broker = InMemoryBroker(router=router, broker_name="priority") + +priority_route = router.route_task( + "billing.charge", + broker=priority_broker, + flow=Flow("billing.priority"), +) +``` + +If a broker is created without an explicit router, Taskiq creates a default +router for that broker. Use an explicit shared router when one worker or app +needs several brokers to follow the same routing policy. + +Pass broker objects to routing APIs. If broker names come from configuration, +resolve them explicitly first: + +```python +broker = router.get_broker("priority") +route = router.route_task("billing.charge", broker=broker) +``` + +This keeps routing errors local and avoids hidden string references. + +## Flow + +`Flow` is a transport-neutral delivery address. It is intentionally small. +The routing identity of a flow is its logical name: + +```python +from taskiq import Flow, FlowIdentity, get_flow_identity + +flow = Flow("billing.priority").with_options(priority=10) + +assert flow.identity == FlowIdentity("billing.priority") +assert get_flow_identity(flow) == flow.identity +``` + +Broker packages can expose their own flow objects when they need +transport-specific options: + +```python +from collections.abc import Mapping +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class RabbitQueue: + name: str + durable: bool = True + + def broker_options(self) -> Mapping[str, object]: + return {"durable": self.durable} +``` + +The core protocol requires `name` and `broker_options()`. Broker-specific +settings stay in the flow object and are returned through `broker_options()`; +they are not encoded into router task names or message labels. + +Flow identity and declaration options are separate: + +- Same broker, same flow name and same `broker_options()` means the listen plan + can deduplicate the flow and merge task names. +- Same broker and same flow name with different `broker_options()` is a + conflict, so Taskiq raises `ValueError` instead of silently choosing one + queue/topic declaration. +- Different brokers may use the same flow name independently. + +## Sending with a route + +`router.route_task()` returns a `TaskiqRoute`. You can pass that object to a +kicker when you want the route to be explicit at the call site. + +```python +route = router.resolve_route("billing.charge") +task = await charge.kicker().with_route(route).kiq(42) +``` + +For one-off overrides, resolve a route with the target broker and flow: + +```python +bulk_route = router.resolve_route( + "billing.charge", + broker=default_broker, + flow=Flow("billing.bulk"), +) + +task = await charge.kicker().with_route(bulk_route).kiq(42) +``` + +`with_broker(...)` remains available for shared-task style invocation. When a +kicker has both a route and a later broker override, the broker override clears +the explicit route. Taskiq then resolves the task for the override broker: it +uses that broker's task route flow when one exists, otherwise the broker +`default_flow`. + +`prepare(...)` creates a prepared invocation for later use. Prepared invocations +keep a snapshot of the route that was resolved at prepare time, so later changes +to the mutable kicker object or router route table do not change that prepared +send. + +## Subscriptions + +Routers also keep inbound flow subscriptions. Routing and subscribing are +separate operations: `route_task()` selects the outbound broker/flow for a task +invocation, while `subscribe()` adds a flow to the broker listen plan. + +```python +billing_flow = Flow("billing.priority") + +route = router.route_task( + "billing.charge", + broker=priority_broker, + flow=billing_flow, +) + +router.subscribe(priority_broker, billing_flow, "billing.charge") + +subscriptions = router.get_subscriptions(priority_broker) +flows = priority_broker.get_subscribed_flows() +``` + +`route_task(...)` does not subscribe by default. This is intentional: outbound +routing answers "where should this invocation be sent?", while subscribing +answers "which flows should this broker listen to?". + +The deprecated `route_task(..., subscribe=True)` shim still performs this +subscription when a flow is resolved, but new code should call `subscribe()` +directly. + +Existing brokers can keep implementing `listen()` as before. New flow-aware +brokers may use `get_subscribed_flows()` to subscribe to queues, topics, +subjects or streams while the routing rules stay in the router. + +`get_subscribed_flows()` returns the broker `default_flow` plus explicit +subscriptions. It deduplicates by flow identity and checks broker options for +conflicts. + +## Scheduler and requeue + +`ScheduledTask` remains a transport-neutral invocation payload. It stores task +name, labels, args, kwargs, optional task id and schedule timing fields; it does +not store broker objects, broker names, route objects, flows or +transport-specific flow options. + +When a scheduled task is ready, `TaskiqScheduler` resolves the task route in the +scheduler process through the scheduler broker's router. Route changes made +before `on_ready(...)` runs affect that scheduled dispatch. If no router route +exists for the scheduled task, Taskiq keeps the old behavior and sends through +the scheduler broker. + +`with_broker(...)`, `with_route(...)` and `with_flow(...)` on a kicker are not +persisted schedule route metadata. They can affect `CreatedSchedule.kiq()`, +which is an immediate queued invocation helper, but they do not change the +stored schedule payload. + +`Context.requeue()` is current-broker sticky. It sends the same task message +through the broker currently executing the task, even if the task's default +route now points to another broker. Requeue flow selection is explicit +low-level override, same-broker task route flow, current broker `default_flow`, +then no flow. Taskiq does not currently preserve the inbound source flow because +flow provenance is not stored in `TaskiqMessage` or `Context`. + +## Shared task declarations + +Libraries can declare tasks without importing an application broker: + +```python +from taskiq import task_builder + + +@task_builder("billing.calculate_total", domain="billing") +async def calculate_total(price: int, quantity: int) -> int: + return price * quantity +``` + +The final application binds the task definition to its router and broker: + +```python +billing_flow = Flow("billing.tasks") + +registered = router.register_task( + calculate_total, + broker=billing_broker, + flow=billing_flow, +) + +router.subscribe( + billing_broker, + billing_flow, + registered, +) + +result = await registered.kiq(19, 3) +``` + +Task name and labels belong to the shared declaration. Register a +`TaskDefinition` without `task_name` or label overrides; Taskiq rejects those +overrides instead of silently ignoring them. + +The unbound task can still run locally: + +```python +total = await calculate_total.call(19, 3) +``` + +Shared task declarations can also provide a custom task class for one task: + +```python +from typing import Any + +from taskiq import AsyncTaskiqDecoratedTask, task_builder + + +class TracingTask(AsyncTaskiqDecoratedTask[Any, Any]): + def tracing_name(self) -> str: + return self.task_name + + +@task_builder("billing.traced_charge", base_cls=TracingTask) +async def traced_charge(user_id: int) -> None: + ... +``` + +When the final application registers this definition, Taskiq creates the bound +task using `TracingTask`. If `base_cls` is not provided, Taskiq uses the native +decorated task class. + +For low-level integrations, `TaskDefinition.message(...)` builds a +`TaskiqMessage` without binding the task. It uses the same argument and label +preparation contract as a normal `.kicker().prepare(...)` invocation. + +## Compatibility and migration notes + +The routing/flow branch keeps the old broker-first task API working: + +```python +@broker.task(task_name="billing.charge") +async def charge(user_id: int) -> None: + ... + +await charge.kiq(42) +await charge.kicker().with_labels(source="api").kiq(42) +``` + +Migration guidance for application code: + +- Keep using `@broker.task(...)` when one broker owns task declaration and send + path. No router code is required for this case. +- Use one shared `TaskiqRouter` when several brokers must share task registry + and routing policy. +- Pass broker objects to routing APIs. If configuration stores broker names, + resolve them with `router.get_broker(name)` before calling `route_task(...)`. +- Replace `route_task(..., subscribe=True)` usage with explicit + `router.subscribe(...)`. +- Do not use labels as a transport routing schema. Labels remain task metadata. +- Register a `TaskDefinition` without `task_name` or label overrides; plain + callable registration still supports those overrides. + +Migration guidance for broker packages: + +- Existing brokers that implement only `kick(message)` and `listen()` remain + valid. The default `kick_to_flow(message, flow=None)` calls `kick(message)`. +- Flow-aware brokers can override `kick_to_flow(...)` and map `flow.name` to a + queue, topic, subject or stream. +- Flow-aware listeners can call `get_subscribed_flows()` during listen setup to + add router-owned subscriptions while keeping old default listen behavior when + there are no explicit subscriptions. +- Broker-specific declaration settings belong in broker-specific flow objects + and `broker_options()`, not in task names, labels or scheduler payloads. +- Ack/nack behavior, serializer/formatter behavior, retry, middleware hooks and + result backend ownership remain broker responsibilities. + +Scheduler and requeue compatibility: + +- `ScheduledTask` does not store routes, broker names, flow objects or + broker-specific flow options in this iteration. +- Scheduled dispatch resolves the route late in the scheduler process through + the scheduler broker's router. +- `Context.requeue()` sends through the broker currently executing the task. It + does not re-resolve a task route across brokers. +- Taskiq does not preserve inbound source-flow provenance yet because that data + is not stored in `TaskiqMessage` or `Context`. + +## Examples + +See executable examples for the two main usage shapes: + +- [Multiple brokers and explicit routes](../examples/router/multiple_brokers.py) + shows one task routed through several brokers with explicit subscription + ownership. +- [Shared task package](../examples/router/shared_task_package.py) shows + `task_builder(...)`, a custom `base_cls`, later registration and a prepared + queued invocation. diff --git a/docs/guide/scheduling-tasks.md b/docs/guide/scheduling-tasks.md index 44e646f2..ab6ce848 100644 --- a/docs/guide/scheduling-tasks.md +++ b/docs/guide/scheduling-tasks.md @@ -159,13 +159,26 @@ If you want to pass additional labels, you can call these methods on the `Kicker ) ``` -::: warning Cool warning! +::: warning Routing warning -The `with_broker` method won't do anything in this case, since we have a broker assigned to each scheduler. +`with_broker`, `with_route` and `with_flow` are immediate invocation settings. +They are not stored in the `ScheduledTask` payload. Scheduled dispatch resolves +the route later in the scheduler process through the scheduler broker's router. +Use router configuration available to the scheduler process when a scheduled +task must go to a specific broker or flow. ::: +The scheduler keeps route data out of schedule sources on purpose. A stored +`ScheduledTask` remains a transport-neutral payload with task name, labels, +args, kwargs and schedule timing fields. If you change a router rule before a +schedule becomes ready, the scheduler uses the updated route at send time. If no +route exists, the scheduler sends through its own broker as before. + Each of these methods return you an instance of the `CreatedSchedule` class. This object has unique schedule ID and some helper methods. For example, you can use the `unschedule` method to remove the schedule from the source. +Calling `kiq()` on a `CreatedSchedule` is an immediate queued invocation helper; +it uses the original kicker settings for that immediate send and does not change +the stored schedule payload. ```python schedule = await my_task.schedule_by_time( diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 04d82920..6f328069 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -24,7 +24,7 @@ TaskiqError, TaskiqResultTimeoutError, ) -from taskiq.flow import Flow, FlowProtocol +from taskiq.flow import Flow, FlowIdentity, FlowProtocol, get_flow_identity from taskiq.funcs import gather from taskiq.kicker import PreparedKiq from taskiq.message import BrokerMessage, TaskiqMessage @@ -34,7 +34,7 @@ SmartRetryMiddleware, ) from taskiq.result import TaskiqResult -from taskiq.router import TaskiqRoute, TaskiqRouter +from taskiq.router import TaskiqRoute, TaskiqRouter, TaskiqSubscription from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.scheduler.scheduler import TaskiqScheduler from taskiq.state import TaskiqState @@ -52,6 +52,7 @@ "BrokerMessage", "Context", "Flow", + "FlowIdentity", "FlowProtocol", "InMemoryBroker", "NoResultError", @@ -78,9 +79,11 @@ "TaskiqRouter", "TaskiqScheduler", "TaskiqState", + "TaskiqSubscription", "ZeroMQBroker", "__version__", "async_shared_broker", "gather", + "get_flow_identity", "task_builder", ] diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 42a1817b..a25a98a6 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -176,6 +176,16 @@ def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: **self.local_task_registry, } + def get_subscribed_flows(self) -> tuple[FlowProtocol, ...]: + """ + Return flows this broker should subscribe to. + + Existing brokers can keep their current `listen` implementation. New + flow-aware brokers may use this method to configure queue, topic, + stream or subject subscriptions from router-owned rules. + """ + return self.router.get_broker_flows(self) + def add_dependency_context(self, new_ctx: dict[Any, Any]) -> None: """ Add first-level dependencies. @@ -351,40 +361,12 @@ def make_decorated_task( def inner( func: Callable[_FuncParams, _ReturnType], ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: - nonlocal inner_task_name - if inner_task_name is None: - fmodule = func.__module__ - if fmodule == "__main__": # pragma: no cover - fmodule = ".".join( - os.path.normpath(sys.argv[0]) - .removesuffix(".py") - .split(os.path.sep), - ) - fname = func.__name__ - if fname == "": - fname = f"lambda_{uuid4().hex}" - inner_task_name = f"{fmodule}:{fname}" - wrapper = wraps(func) - - sign = get_type_hints(func) - return_type = None - if "return" in sign: - return_type = sign["return"] - - decorated_task = wrapper( - self.decorator_class( - broker=self, - original_func=func, - labels=inner_labels, - task_name=inner_task_name, - return_type=return_type, # type: ignore - ), + return self._decorate_task( + func, + task_name=inner_task_name, + labels=inner_labels, ) - self._register_task(decorated_task.task_name, decorated_task) # type: ignore - - return decorated_task # type: ignore - return inner if callable(task_name): @@ -423,12 +405,87 @@ def register_task( :returns: registered task. """ if isinstance(func, TaskDefinition): + if task_name is not None or labels: + raise ValueError( + "TaskDefinition already defines task_name and labels. " + "Register it without task_name or label overrides.", + ) return self.router.register_task( func, broker=self, flow=self.default_flow, ) - return self.task(task_name=task_name, **labels)(func) + return self._decorate_task( + func, + task_name=task_name, + labels=labels or {}, + ) + + def bind_task_definition( + self, + task: TaskDefinition[_FuncParams, _ReturnType], + *, + register: bool = True, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Bind an unbound task definition to this broker.""" + task_cls = AsyncTaskiqDecoratedTask if task.base_cls is None else task.base_cls + return self._decorate_task( + task.original_func, + task_name=task.task_name, + labels=dict(task.labels), + base_cls=task_cls, + register=register, + ) + + def _decorate_task( + self, + func: Callable[_FuncParams, _ReturnType], + task_name: str | None, + labels: dict[str, Any], + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, + *, + register: bool = True, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Build and register a decorated task object.""" + real_task_name = task_name + if real_task_name is None: + fmodule = func.__module__ + if fmodule == "__main__": # pragma: no cover + fmodule = ".".join( + os.path.normpath(sys.argv[0]) + .removesuffix(".py") + .split(os.path.sep), + ) + fname = func.__name__ + if fname == "": + fname = f"lambda_{uuid4().hex}" + real_task_name = f"{fmodule}:{fname}" + + task_cls = self.decorator_class if base_cls is None else base_cls + if not isinstance(task_cls, type) or not issubclass( + task_cls, + AsyncTaskiqDecoratedTask, + ): + raise TypeError("base_cls must be a subclass of AsyncTaskiqDecoratedTask.") + + sign = get_type_hints(func) + return_type = None + if "return" in sign: + return_type = sign["return"] + + decorated_task = wraps(func)( + task_cls( + broker=self, + original_func=func, + labels=labels, + task_name=real_task_name, + return_type=return_type, # type: ignore + ), + ) + + if register: + self._register_task(decorated_task.task_name, decorated_task) # type: ignore + return decorated_task # type: ignore def on_event(self, *events: TaskiqEvents) -> Callable[[EventHandler], EventHandler]: """ @@ -580,12 +637,28 @@ def _register_task( """ if task.broker != self: raise TaskBrokerMismatchError(broker=task.broker) - self.local_task_registry[task_name] = task self.router.register_task( task, broker=self, flow=self.default_flow, ) + self._store_task(task_name, task) + + def _store_task( + self, + task_name: str, + task: AsyncTaskiqDecoratedTask[Any, Any], + ) -> None: + """ + Store a decorated task in this broker. + + Router-managed binding uses this method after router state has accepted + the task, so shared task registration does not re-enter router + registration while it is already in progress. + """ + if task.broker != self: + raise TaskBrokerMismatchError(broker=task.broker) + self.local_task_registry[task_name] = task async def __aenter__(self) -> None: """Starts the broker as ctx manager.""" diff --git a/taskiq/flow.py b/taskiq/flow.py index 8754492e..1e5c06b1 100644 --- a/taskiq/flow.py +++ b/taskiq/flow.py @@ -1,17 +1,28 @@ from collections.abc import Mapping from dataclasses import dataclass, field, replace +from types import MappingProxyType from typing import Protocol, runtime_checkable -__all__ = ("Flow", "FlowProtocol") +__all__ = ("Flow", "FlowIdentity", "FlowProtocol", "get_flow_identity") + + +@dataclass(frozen=True, slots=True) +class FlowIdentity: + """Stable identity for one logical flow in routing and listen plans.""" + + name: str @runtime_checkable class FlowProtocol(Protocol): """Transport-neutral flow contract accepted by routers and brokers.""" - name: str + @property + def name(self) -> str: + """Return transport-neutral flow name.""" + ... - def broker_options(self, broker_name: str) -> Mapping[str, object]: + def broker_options(self) -> Mapping[str, object]: """Return options relevant for a concrete broker implementation.""" ... @@ -33,10 +44,30 @@ class Flow: hash=False, ) + def __post_init__(self) -> None: + """Freeze options to keep Flow a stable value object.""" + object.__setattr__(self, "options", MappingProxyType(dict(self.options))) + def with_options(self, **options: object) -> "Flow": """Return the same flow with additional generic options.""" return replace(self, options={**self.options, **options}) - def broker_options(self, broker_name: str) -> Mapping[str, object]: + @property + def identity(self) -> FlowIdentity: + """Return routing and subscription identity for this flow.""" + return get_flow_identity(self) + + def broker_options(self) -> Mapping[str, object]: """Return transport options for broker-specific implementations.""" return dict(self.options) + + +def get_flow_identity(flow: FlowProtocol) -> FlowIdentity: + """ + Return stable identity for a flow. + + Flow identity is intentionally transport-neutral and based on the logical + flow name. Broker-specific declaration options are validated separately by + the subscription plan. + """ + return FlowIdentity(name=flow.name) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index d0182e4c..8debca3e 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Coroutine -from dataclasses import asdict, is_dataclass +from dataclasses import replace from datetime import datetime, timedelta from logging import getLogger from types import CoroutineType @@ -15,16 +15,12 @@ overload, ) -from pydantic import BaseModel - from taskiq.abc.middleware import TaskiqMiddleware from taskiq.abc.result_backend import AsyncResultBackend -from taskiq.compat import model_dump from taskiq.exceptions import SendTaskError from taskiq.flow import FlowProtocol -from taskiq.labels import prepare_label -from taskiq.message import TaskiqMessage -from taskiq.router import TaskiqRouter +from taskiq.message import TaskiqMessage, _build_taskiq_message +from taskiq.router import TaskiqRoute, TaskiqRouter from taskiq.scheduler.created_schedule import CreatedSchedule from taskiq.scheduler.scheduled_task import CronSpec, ScheduledTask from taskiq.task import AsyncTaskiqTask @@ -48,13 +44,25 @@ def __init__( self, kicker: AsyncKicker[..., _ReturnType], message: TaskiqMessage, + broker: AsyncBroker, + route: TaskiqRoute | None, + flow: FlowProtocol | None, ) -> None: self.kicker = kicker self.message = message + self.broker = broker + self.route = route + self.flow = flow async def kiq(self) -> AsyncTaskiqTask[_ReturnType]: """Send prepared invocation.""" - return await self.kicker.kiq_message(self.message) + return await self.kicker.kiq_message( + self.message, + broker=self.broker, + route=self.route, + flow=self.flow, + use_current_route=False, + ) class AsyncKicker(Generic[_FuncParams, _ReturnType]): @@ -73,8 +81,9 @@ def __init__( self.custom_task_id: str | None = None self.custom_schedule_id: str | None = None self.return_type = return_type - self.route_broker: AsyncBroker | str | None = None + self.route: TaskiqRoute | None = None self.route_flow: FlowProtocol | None = None + self._broker_overridden = False def with_labels( self, @@ -132,7 +141,9 @@ def with_broker( :return: Kicker with new broker. """ self.broker = broker - self.route_broker = broker + self.route = None + self.route_flow = None + self._broker_overridden = True return self def with_flow( @@ -145,25 +156,23 @@ def with_flow( :param flow: flow to send message to. :return: Kicker with a route flow override. """ + if self.route is not None: + self.route = replace(self.route, flow=flow) self.route_flow = flow return self def with_route( self, - broker: AsyncBroker | str, - flow: FlowProtocol | None, + route: TaskiqRoute, ) -> AsyncKicker[_FuncParams, _ReturnType]: """ - Replace broker and flow for the current invocation. + Replace route for the current invocation. - :param broker: broker instance or broker name. - Broker instances are preferred; names are kept for configuration - and backward-compatible lookup. - :param flow: flow to send message to. + :param route: route to send message through. :return: Kicker with a route override. """ - self.route_broker = broker - self.route_flow = flow + self.route = route + self.route_flow = route.flow return self def prepare( @@ -178,7 +187,14 @@ def prepare( :param kwargs: function's key word arguments. :return: prepared task invocation. """ - return PreparedKiq(self, self._prepare_message(*args, **kwargs)) + broker, route, flow = self._prepare_route_snapshot() + return PreparedKiq( + self, + self._prepare_message(*args, **kwargs), + broker=broker, + route=route, + flow=flow, + ) @overload async def kiq( @@ -230,29 +246,56 @@ async def kiq( async def kiq_message( self, message: TaskiqMessage, + *, + broker: AsyncBroker | None = None, + route: TaskiqRoute | None = None, + flow: FlowProtocol | None = None, + use_current_route: bool = True, ) -> AsyncTaskiqTask[_ReturnType]: """Send a prepared message.""" try: - router = getattr(self.broker, "router", None) + target_broker = broker or self.broker + if use_current_route: + target_route = self.route if route is None else route + target_flow = self.route_flow if flow is None else flow + else: + target_route = route + target_flow = flow + if target_route is not None: + if broker is not None and broker is not target_route.broker: + raise ValueError("Pass either route or broker override.") + target_broker = target_route.broker + target_flow = None + router = getattr(target_broker, "router", None) if isinstance(router, TaskiqRouter): + broker_override = None + if target_route is None and ( + broker is not None or self._broker_overridden + ): + broker_override = target_broker return await router.kiq( message, - broker=self.route_broker, - flow=self.route_flow, + route=target_route, + broker=broker_override, + flow=target_flow, return_type=self.return_type, ) - return await self._legacy_kiq(message) + return await self._legacy_kiq(target_broker, message) except Exception as exc: raise SendTaskError from exc - async def _legacy_kiq(self, message: TaskiqMessage) -> AsyncTaskiqTask[_ReturnType]: + async def _legacy_kiq( + self, + broker: AsyncBroker, + message: TaskiqMessage, + ) -> AsyncTaskiqTask[_ReturnType]: """ Send message through the pre-router broker path. This keeps middleware tests and external broker-like mocks compatible while real AsyncBroker instances use TaskiqRouter. """ - middlewares = getattr(self.broker, "middlewares", []) + middlewares = getattr(broker, "middlewares", []) if not isinstance(middlewares, list): middlewares = [] @@ -260,7 +303,7 @@ async def _legacy_kiq(self, message: TaskiqMessage) -> AsyncTaskiqTask[_ReturnTy if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: message = await maybe_awaitable(middleware.pre_send(message)) - await self.broker.kick(self.broker.formatter.dumps(message)) + await broker.kick(broker.formatter.dumps(message)) for middleware in reversed(middlewares): if middleware.__class__.post_send != TaskiqMiddleware.post_send: @@ -270,11 +313,30 @@ async def _legacy_kiq(self, message: TaskiqMessage) -> AsyncTaskiqTask[_ReturnTy task_id=message.task_id, result_backend=cast( AsyncResultBackend[_ReturnType], - self.broker.result_backend, + broker.result_backend, ), return_type=self.return_type, ) + def _prepare_route_snapshot( + self, + ) -> tuple[AsyncBroker, TaskiqRoute | None, FlowProtocol | None]: + """Resolve the route that a prepared invocation must keep.""" + router = getattr(self.broker, "router", None) + if not isinstance(router, TaskiqRouter): + return self.broker, None, self.route_flow + + if self.route is not None: + return self.route.broker, self.route, None + + broker_override = self.broker if self._broker_overridden else None + route = router.resolve_route( + self.task_name, + broker=broker_override, + flow=self.route_flow, + ) + return route.broker, route, None + async def schedule_by_cron( self, source: ScheduleSource, @@ -378,27 +440,6 @@ async def schedule_by_time( await source.add_schedule(scheduled) return CreatedSchedule(self, source, scheduled) - @classmethod - def _prepare_arg(cls, arg: Any) -> Any: - """ - Parses argument if possible. - - This function is used to construct dicts - from pydantic models or dataclasses. - - :param arg: argument to format. - :return: Formatted argument. - """ - if isinstance(arg, BaseModel): - arg = model_dump(arg) - if is_dataclass(arg): - if isinstance(arg, type): - raise ValueError( - f"Cannot serialize types. The {arg} is not serializable.", - ) - arg = asdict(arg) - return arg - def _prepare_message( self, *args: Any, @@ -411,27 +452,14 @@ def _prepare_message( :param kwargs: function's kwargs. :return: constructed message. """ - formatted_args = [] - formatted_kwargs = {} - labels = {} - labels_types = {} - for arg in args: - formatted_args.append(self._prepare_arg(arg)) - for kwarg_name, kwarg_val in kwargs.items(): - formatted_kwargs[kwarg_name] = self._prepare_arg(kwarg_val) - - for label, label_val in self.labels.items(): - labels[label], labels_types[label] = prepare_label(label_val) - task_id = self.custom_task_id if task_id is None: task_id = self.broker.id_generator() - return TaskiqMessage( + return _build_taskiq_message( task_id=task_id, task_name=self.task_name, - labels=labels, - labels_types=labels_types, - args=formatted_args, - kwargs=formatted_kwargs, + labels=self.labels, + args=args, + kwargs=kwargs, ) diff --git a/taskiq/message.py b/taskiq/message.py index 39008eb8..00ed4791 100644 --- a/taskiq/message.py +++ b/taskiq/message.py @@ -1,8 +1,11 @@ +from collections.abc import Iterable, Mapping +from dataclasses import asdict, is_dataclass from typing import Any from pydantic import BaseModel -from taskiq.labels import parse_label +from taskiq.compat import model_dump +from taskiq.labels import parse_label, prepare_label class TaskiqMessage(BaseModel): @@ -35,6 +38,61 @@ def parse_labels(self) -> None: self.labels[label] = parse_label(self.labels[label], label_type) +def _prepare_message_arg(arg: Any) -> Any: + """ + Prepare invocation argument for message payloads. + + :param arg: argument to prepare. + :return: serializable argument representation. + """ + if isinstance(arg, BaseModel): + arg = model_dump(arg) + if is_dataclass(arg): + if isinstance(arg, type): + raise ValueError( + f"Cannot serialize types. The {arg} is not serializable.", + ) + arg = asdict(arg) + return arg + + +def _build_taskiq_message( + task_id: str, + task_name: str, + labels: Mapping[str, Any], + args: Iterable[Any], + kwargs: Mapping[str, Any], +) -> TaskiqMessage: + """ + Build a taskiq message using the common invocation contract. + + :param task_id: task id. + :param task_name: task name. + :param labels: task invocation labels. + :param args: task positional arguments. + :param kwargs: task keyword arguments. + :return: prepared taskiq message. + """ + prepared_labels: dict[str, Any] = {} + labels_types: dict[str, int] = {} + for label_name, label_value in labels.items(): + prepared_labels[label_name], labels_types[label_name] = prepare_label( + label_value, + ) + + return TaskiqMessage( + task_id=task_id, + task_name=task_name, + labels=prepared_labels, + labels_types=labels_types, + args=[_prepare_message_arg(arg) for arg in args], + kwargs={ + kwarg_name: _prepare_message_arg(kwarg_value) + for kwarg_name, kwarg_value in kwargs.items() + }, + ) + + class BrokerMessage(BaseModel): """Format of messages for brokers.""" diff --git a/taskiq/router.py b/taskiq/router.py index 9f770f90..b62d16d3 100644 --- a/taskiq/router.py +++ b/taskiq/router.py @@ -1,51 +1,81 @@ from __future__ import annotations +import warnings from collections.abc import Callable -from dataclasses import dataclass -from logging import getLogger -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, overload -from taskiq.abc.middleware import TaskiqMiddleware -from taskiq.abc.result_backend import AsyncResultBackend from taskiq.flow import FlowProtocol from taskiq.message import TaskiqMessage +from taskiq.routing import ( + BrokerRegistry, + RouterDispatcher, + RouteRegistry, + SubscriptionPlan, + TaskiqRoute, + TaskiqSubscription, + TaskRegistry, +) +from taskiq.routing.references import resolve_task_name from taskiq.task import AsyncTaskiqTask from taskiq.task_builder import TaskDefinition -from taskiq.utils import maybe_awaitable +from taskiq.warnings import TaskiqDeprecationWarning if TYPE_CHECKING: # pragma: no cover from taskiq.abc.broker import AsyncBroker from taskiq.decor import AsyncTaskiqDecoratedTask -__all__ = ("TaskiqRoute", "TaskiqRouter") +__all__ = ("TaskiqRoute", "TaskiqRouter", "TaskiqSubscription") + +TaskiqRoute.__module__ = __name__ +TaskiqSubscription.__module__ = __name__ _FuncParams = ParamSpec("_FuncParams") _ReturnType = TypeVar("_ReturnType") -logger = getLogger("taskiq.router") +class TaskiqRouter: + """Facade for task registry, routing policy, subscriptions and dispatch.""" -@dataclass(frozen=True, slots=True) -class TaskiqRoute: - """Resolved outbound route for a task invocation.""" + def __init__(self) -> None: + self._brokers = BrokerRegistry() + self._tasks = TaskRegistry() + self._routes = RouteRegistry(self._brokers) + self._subscriptions = SubscriptionPlan(self._brokers) + self._dispatcher = RouterDispatcher(self._routes) - broker: AsyncBroker - flow: FlowProtocol | None = None + @property + def brokers(self) -> dict[str, AsyncBroker]: + """Return mutable broker registry for compatibility.""" + return self._brokers.brokers @property - def broker_name(self) -> str: - """Return registered broker name for compatibility and diagnostics.""" - return self.broker.broker_name + def default_broker(self) -> AsyncBroker | None: + """Return default broker for compatibility.""" + return self._brokers.default_broker + @default_broker.setter + def default_broker(self, broker: AsyncBroker | None) -> None: + self._brokers.default_broker = broker -class TaskiqRouter: - """Registry and routing layer shared by one or more brokers.""" + @property + def task_registry(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: + """Return mutable task registry for compatibility.""" + return self._tasks.tasks - def __init__(self) -> None: - self.brokers: dict[str, AsyncBroker] = {} - self.default_broker_name: str | None = None - self.task_registry: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} - self.routes: dict[str, TaskiqRoute] = {} + @property + def routes(self) -> dict[str, TaskiqRoute]: + """Return mutable route registry for compatibility.""" + return self._routes.routes + + @property + def subscriptions(self) -> list[TaskiqSubscription]: + """Return mutable subscription registry for compatibility.""" + return self._subscriptions.subscriptions + + @property + def default_broker_name(self) -> str | None: + """Return default broker name for compatibility and diagnostics.""" + return self._brokers.default_broker_name def set_broker( self, @@ -54,28 +84,30 @@ def set_broker( default_flow: FlowProtocol | None = None, ) -> str: """Register broker as a transport in this router.""" - broker_name = name or broker.__class__.__name__ - registered = self.brokers.get(broker_name) - if registered is not None and registered is not broker: + if getattr(broker, "router", self) is not self: raise ValueError( - f"Broker name {broker_name!r} is already registered. " - "Please provide an explicit unique broker_name.", + "Broker is attached to another router. " + "Pass router=... when creating the broker.", ) - self.brokers[broker_name] = broker - if self.default_broker_name is None: - self.default_broker_name = broker_name + broker_name = self._brokers.register(broker, name=name) + if default_flow is not None: + broker.default_flow = default_flow return broker_name + def get_broker(self, name: str) -> AsyncBroker: + """Return a broker by registered name.""" + return self._brokers.get(name) + def find_task( self, task_name: str, ) -> AsyncTaskiqDecoratedTask[Any, Any] | None: """Find a task by name.""" - return self.task_registry.get(task_name) + return self._tasks.find(task_name) def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: """Return all tasks registered in this router.""" - return dict(self.task_registry) + return self._tasks.get_all() def register_task( self, @@ -83,35 +115,36 @@ def register_task( AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType] | TaskDefinition[_FuncParams, _ReturnType] ), - broker: AsyncBroker | str | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: """Register a bound task or bind a task definition to a broker.""" if isinstance(task, TaskDefinition): - target_broker = self._resolve_broker(broker) - registered_task = target_broker.register_task( - task.original_func, - task_name=task.task_name, - **task.labels, + target_broker = self._brokers.resolve(broker) + registered_task = target_broker.bind_task_definition( + task, + register=False, + ) + self._register_bound_task( + registered_task, + broker=target_broker, + flow=flow, + ) + # Router and broker share this internal registration boundary. + target_broker._store_task( # noqa: SLF001 + registered_task.task_name, + registered_task, ) - if flow is not None: - self.route_task(task.task_name, broker=target_broker, flow=flow) return registered_task - self.task_registry[task.task_name] = task - route_broker: AsyncBroker | str | None = broker - if route_broker is None: - route_broker = getattr(task, "broker", None) - if route_broker is not None or flow is not None: - self.route_task(task.task_name, broker=route_broker, flow=flow) - return task + return self._register_bound_task(task, broker=broker, flow=flow) @overload def task( self, task_name: Callable[_FuncParams, _ReturnType], *, - broker: AsyncBroker | str | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, **labels: Any, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: ... @@ -121,7 +154,7 @@ def task( self, task_name: str | None = None, *, - broker: AsyncBroker | str | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, **labels: Any, ) -> Callable[ @@ -133,7 +166,7 @@ def task( self, task_name: str | Callable[_FuncParams, _ReturnType] | None = None, *, - broker: AsyncBroker | str | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, **labels: Any, ) -> Any: @@ -142,16 +175,15 @@ def task( def register( func: Callable[_FuncParams, _ReturnType], ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: - target_broker = self._resolve_broker(broker) + target_broker = self._brokers.resolve(broker) real_task_name = task_name if not callable(task_name) else None task = target_broker.task(task_name=real_task_name, **labels)(func) if flow is not None: - self.route_task(task.task_name, broker=target_broker, flow=flow) + self.route_task(task, broker=target_broker, flow=flow) return task if callable(task_name): function = task_name - task_name = None return register(function) return register @@ -159,77 +191,74 @@ def register( def route_task( self, task: str | AsyncTaskiqDecoratedTask[Any, Any], - broker: AsyncBroker | str | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, + *, + subscribe: bool = False, ) -> TaskiqRoute: """Set default outbound route for a task.""" - task_name = self._resolve_task_name(task) - target_broker = self._resolve_broker(broker) - route = TaskiqRoute(broker=target_broker, flow=flow) - self.routes[task_name] = route + task_name = resolve_task_name(task) + route = self._routes.set_route(task_name, broker=broker, flow=flow) + if subscribe: + warnings.warn( + "`route_task(..., subscribe=True)` is deprecated. " + "Use `router.subscribe(...)` to add inbound flow subscriptions.", + TaskiqDeprecationWarning, + stacklevel=2, + ) + if subscribe and route.flow is not None: + self.subscribe(route.broker, route.flow, task_name) return route def resolve_route( self, task: str | AsyncTaskiqDecoratedTask[Any, Any], - broker: AsyncBroker | str | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, ) -> TaskiqRoute: """Resolve outbound route for a task invocation.""" - task_name = self._resolve_task_name(task) - if broker is not None: - target_broker = self._resolve_broker(broker) - route_flow = flow - if route_flow is None: - route_flow = self._broker_default_flow(target_broker) - return TaskiqRoute( - broker=target_broker, - flow=route_flow, - ) + return self._routes.resolve_route(task, broker=broker, flow=flow) - route = self.routes.get(task_name) - if route is not None: - if flow is None: - return route - return TaskiqRoute(broker=route.broker, flow=flow) - - target_broker = self._resolve_broker(None) - route_flow = flow - if route_flow is None: - route_flow = self._broker_default_flow(target_broker) - return TaskiqRoute( - broker=target_broker, - flow=route_flow, + def subscribe( + self, + broker: AsyncBroker, + flow: FlowProtocol, + *tasks: str | AsyncTaskiqDecoratedTask[Any, Any], + ) -> TaskiqSubscription: + """Register an inbound flow subscription for a broker.""" + task_names = tuple(resolve_task_name(task) for task in tasks) + return self._subscriptions.subscribe( + broker, + flow, + task_names, ) + def get_subscriptions( + self, + broker: AsyncBroker | None = None, + ) -> tuple[TaskiqSubscription, ...]: + """Return registered inbound subscriptions.""" + return self._subscriptions.get(broker) + + def get_broker_flows(self, broker: AsyncBroker) -> tuple[FlowProtocol, ...]: + """Return flows a broker should subscribe to.""" + return self._subscriptions.get_broker_flows(broker) + async def kiq( self, message: TaskiqMessage, *, - broker: AsyncBroker | str | None = None, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, return_type: type[_ReturnType] | None = None, ) -> AsyncTaskiqTask[_ReturnType]: """Send message through the resolved broker and flow.""" - route = self.resolve_route(message.task_name, broker=broker, flow=flow) - target_broker = route.broker - - for middleware in target_broker.middlewares: - if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: - message = await maybe_awaitable(middleware.pre_send(message)) - broker_message = target_broker.formatter.dumps(message) - await target_broker.kick_to_flow(broker_message, route.flow) - - for middleware in reversed(target_broker.middlewares): - if middleware.__class__.post_send != TaskiqMiddleware.post_send: - await maybe_awaitable(middleware.post_send(message)) - - return AsyncTaskiqTask( - task_id=message.task_id, - result_backend=cast( - AsyncResultBackend[_ReturnType], - target_broker.result_backend, - ), + return await self._dispatcher.kiq( + message, + route=route, + broker=broker, + flow=flow, return_type=return_type, ) @@ -237,51 +266,33 @@ async def requeue( self, message: TaskiqMessage, *, - broker: AsyncBroker | str | None = None, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, flow: FlowProtocol | None = None, ) -> None: """Send an existing message again through the resolved route.""" - route = self.resolve_route(message.task_name, broker=broker, flow=flow) - target_broker = route.broker - await target_broker.kick_to_flow( - target_broker.formatter.dumps(message), - route.flow, + await self._dispatcher.requeue( + message, + route=route, + broker=broker, + flow=flow, ) - def _resolve_broker(self, broker: AsyncBroker | str | None) -> AsyncBroker: - if isinstance(broker, str): - if broker not in self.brokers: - raise ValueError(f"Unknown broker {broker!r}.") - return self.brokers[broker] - - if broker is not None: - broker_name = getattr(broker, "broker_name", None) - if isinstance(broker_name, str): - registered_broker = self.brokers.get(broker_name) - if registered_broker is broker: - return broker - for registered_broker in self.brokers.values(): - if registered_broker is broker: - return registered_broker - raise ValueError("Broker is not registered in this router.") - - if self.default_broker_name is None: - raise ValueError("Router doesn't have registered brokers.") - return self.brokers[self.default_broker_name] - - def _resolve_broker_name(self, broker: AsyncBroker | str | None) -> str: - return self._resolve_broker(broker).broker_name - - def _resolve_task_name( + def _register_bound_task( self, - task: str | AsyncTaskiqDecoratedTask[Any, Any], - ) -> str: - if isinstance(task, str): - return task - task_name = getattr(task, "task_name", None) - if isinstance(task_name, str): - return task_name - raise TypeError("Route task must be a task name or decorated task.") + task: AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + route_broker = broker + if route_broker is None: + route_broker = getattr(task, "broker", None) + + route = None + if route_broker is not None or flow is not None: + route = self._routes.build_route(broker=route_broker, flow=flow) - def _broker_default_flow(self, broker: AsyncBroker) -> FlowProtocol | None: - return getattr(broker, "default_flow", None) + self._tasks.register(task) + if route is not None: + self._routes.set_resolved_route(task.task_name, route) + return task diff --git a/taskiq/routing/__init__.py b/taskiq/routing/__init__.py new file mode 100644 index 00000000..09dfe0dc --- /dev/null +++ b/taskiq/routing/__init__.py @@ -0,0 +1,17 @@ +"""Internal routing components used by the public router facade.""" + +from taskiq.routing.dispatcher import RouterDispatcher +from taskiq.routing.models import TaskiqRoute, TaskiqSubscription +from taskiq.routing.registries import BrokerRegistry, TaskRegistry +from taskiq.routing.routes import RouteRegistry +from taskiq.routing.subscriptions import SubscriptionPlan + +__all__ = ( + "BrokerRegistry", + "RouteRegistry", + "RouterDispatcher", + "SubscriptionPlan", + "TaskRegistry", + "TaskiqRoute", + "TaskiqSubscription", +) diff --git a/taskiq/routing/dispatcher.py b/taskiq/routing/dispatcher.py new file mode 100644 index 00000000..d5b1ad43 --- /dev/null +++ b/taskiq/routing/dispatcher.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar, cast + +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.flow import FlowProtocol +from taskiq.message import TaskiqMessage +from taskiq.routing.models import TaskiqRoute +from taskiq.routing.routes import RouteRegistry +from taskiq.task import AsyncTaskiqTask +from taskiq.utils import maybe_awaitable + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + +__all__ = ("RouterDispatcher",) + +_ReturnType = TypeVar("_ReturnType") + + +class RouterDispatcher: + """Send task messages through resolved router routes.""" + + def __init__(self, routes: RouteRegistry) -> None: + self.routes = routes + + async def kiq( + self, + message: TaskiqMessage, + *, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + return_type: type[_ReturnType] | None = None, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send message through the resolved broker and flow.""" + target_route = self.routes.resolve_explicit_or_default_route( + message.task_name, + route=route, + broker=broker, + flow=flow, + ) + target_broker = target_route.broker + + for middleware in target_broker.middlewares: + if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: + message = await maybe_awaitable(middleware.pre_send(message)) + broker_message = target_broker.formatter.dumps(message) + await target_broker.kick_to_flow(broker_message, target_route.flow) + + for middleware in reversed(target_broker.middlewares): + if middleware.__class__.post_send != TaskiqMiddleware.post_send: + await maybe_awaitable(middleware.post_send(message)) + + return AsyncTaskiqTask( + task_id=message.task_id, + result_backend=cast( + AsyncResultBackend[_ReturnType], + target_broker.result_backend, + ), + return_type=return_type, + ) + + async def requeue( + self, + message: TaskiqMessage, + *, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> None: + """Send an existing message again through the resolved route.""" + target_route = self.routes.resolve_explicit_or_default_route( + message.task_name, + route=route, + broker=broker, + flow=flow, + ) + target_broker = target_route.broker + await target_broker.kick_to_flow( + target_broker.formatter.dumps(message), + target_route.flow, + ) diff --git a/taskiq/routing/models.py b/taskiq/routing/models.py new file mode 100644 index 00000000..439ae7da --- /dev/null +++ b/taskiq/routing/models.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from taskiq.flow import FlowProtocol + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + +__all__ = ("TaskiqRoute", "TaskiqSubscription") + + +@dataclass(frozen=True, slots=True) +class TaskiqRoute: + """Resolved outbound route for a task invocation.""" + + broker: AsyncBroker + flow: FlowProtocol | None = None + + @property + def broker_name(self) -> str: + """Return registered broker name for diagnostics.""" + return self.broker.broker_name + + +@dataclass(frozen=True, slots=True) +class TaskiqSubscription: + """Inbound flow subscription owned by a router.""" + + broker: AsyncBroker + flow: FlowProtocol + task_names: frozenset[str] = field(default_factory=frozenset) + + @property + def broker_name(self) -> str: + """Return registered broker name for diagnostics.""" + return self.broker.broker_name diff --git a/taskiq/routing/references.py b/taskiq/routing/references.py new file mode 100644 index 00000000..f43270bc --- /dev/null +++ b/taskiq/routing/references.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("resolve_task_name",) + + +def resolve_task_name(task: str | AsyncTaskiqDecoratedTask[Any, Any]) -> str: + """Resolve a task reference to a task name.""" + if isinstance(task, str): + return task + task_name = getattr(task, "task_name", None) + if isinstance(task_name, str): + return task_name + raise TypeError("Route task must be a task name or decorated task.") diff --git a/taskiq/routing/registries.py b/taskiq/routing/registries.py new file mode 100644 index 00000000..9226cdaa --- /dev/null +++ b/taskiq/routing/registries.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskiq.flow import FlowProtocol + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("BrokerRegistry", "TaskRegistry") + + +class BrokerRegistry: + """Registry of brokers attached to one router.""" + + def __init__(self) -> None: + self.brokers: dict[str, AsyncBroker] = {} + self.default_broker: AsyncBroker | None = None + + @property + def default_broker_name(self) -> str | None: + """Return default broker name for compatibility and diagnostics.""" + if self.default_broker is None: + return None + return self.default_broker.broker_name + + def register(self, broker: AsyncBroker, name: str | None = None) -> str: + """Register broker as a transport in this registry.""" + broker_name = name or broker.__class__.__name__ + registered = self.brokers.get(broker_name) + if registered is not None and registered is not broker: + raise ValueError( + f"Broker name {broker_name!r} is already registered. " + "Please provide an explicit unique broker_name.", + ) + for registered_name, registered_broker in self.brokers.items(): + if registered_broker is broker and registered_name != broker_name: + raise ValueError( + f"Broker is already registered as {registered_name!r}.", + ) + + self.brokers[broker_name] = broker + if self.default_broker is None: + self.default_broker = broker + return broker_name + + def get(self, name: str) -> AsyncBroker: + """Return a broker by registered name.""" + try: + return self.brokers[name] + except KeyError as exc: + raise ValueError(f"Unknown broker {name!r}.") from exc + + def resolve(self, broker: AsyncBroker | None) -> AsyncBroker: + """Resolve an explicit broker or return the default broker.""" + if isinstance(broker, str): + raise TypeError( + "Broker string references are not accepted here. " + "Use router.get_broker(name) and pass the broker object.", + ) + + if broker is not None: + for registered_broker in self.brokers.values(): + if registered_broker is broker: + return registered_broker + raise ValueError("Broker is not registered in this router.") + + if self.default_broker is None: + raise ValueError("Router doesn't have registered brokers.") + return self.default_broker + + def default_flow(self, broker: AsyncBroker) -> FlowProtocol | None: + """Return broker-owned default flow.""" + return getattr(broker, "default_flow", None) + + +class TaskRegistry: + """Registry of tasks attached to one router.""" + + def __init__(self) -> None: + self.tasks: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} + + def register( + self, + task: AsyncTaskiqDecoratedTask[Any, Any], + ) -> None: + """Register a decorated task and reject name conflicts.""" + existing_task = self.tasks.get(task.task_name) + if existing_task is not None and existing_task is not task: + raise ValueError( + f"Task name {task.task_name!r} is already registered " + "in this router.", + ) + self.tasks[task.task_name] = task + + def find( + self, + task_name: str, + ) -> AsyncTaskiqDecoratedTask[Any, Any] | None: + """Find a task by name.""" + return self.tasks.get(task_name) + + def get_all(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: + """Return all tasks registered in this registry.""" + return dict(self.tasks) diff --git a/taskiq/routing/routes.py b/taskiq/routing/routes.py new file mode 100644 index 00000000..96a866f8 --- /dev/null +++ b/taskiq/routing/routes.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING, Any + +from taskiq.flow import FlowProtocol +from taskiq.routing.models import TaskiqRoute +from taskiq.routing.references import resolve_task_name +from taskiq.routing.registries import BrokerRegistry + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("RouteRegistry",) + + +class RouteRegistry: + """Outbound route policy for task invocations.""" + + def __init__(self, brokers: BrokerRegistry) -> None: + self.brokers = brokers + self.routes: dict[str, TaskiqRoute] = {} + + def build_route( + self, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> TaskiqRoute: + """Build a route without mutating route state.""" + target_broker = self.brokers.resolve(broker) + route_flow = flow + if route_flow is None: + route_flow = self.brokers.default_flow(target_broker) + return TaskiqRoute(broker=target_broker, flow=route_flow) + + def set_route( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> TaskiqRoute: + """Set default outbound route for a task.""" + task_name = resolve_task_name(task) + route = self.build_route(broker=broker, flow=flow) + self.routes[task_name] = route + return route + + def set_resolved_route(self, task_name: str, route: TaskiqRoute) -> TaskiqRoute: + """Store an already resolved route for a task.""" + self.routes[task_name] = route + return route + + def resolve_route( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> TaskiqRoute: + """Resolve outbound route for a task invocation.""" + task_name = resolve_task_name(task) + if broker is not None: + target_broker = self.brokers.resolve(broker) + route_flow = self._resolve_flow_for_broker_override( + task_name, + target_broker, + flow, + ) + return TaskiqRoute(broker=target_broker, flow=route_flow) + + route = self.routes.get(task_name) + if route is not None: + if flow is None: + return route + return replace(route, flow=flow) + + return self.build_route(flow=flow) + + def resolve_explicit_or_default_route( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + *, + route: TaskiqRoute | None, + broker: AsyncBroker | None, + flow: FlowProtocol | None, + ) -> TaskiqRoute: + """Resolve a route from explicit route or broker/flow overrides.""" + if route is not None: + if broker is not None or flow is not None: + raise ValueError("Pass either route or broker/flow overrides.") + self.brokers.resolve(route.broker) + return route + return self.resolve_route(task, broker=broker, flow=flow) + + def _resolve_flow_for_broker_override( + self, + task_name: str, + broker: AsyncBroker, + flow: FlowProtocol | None, + ) -> FlowProtocol | None: + if flow is not None: + return flow + registered_route = self.routes.get(task_name) + if registered_route is not None and registered_route.broker is broker: + return registered_route.flow + return self.brokers.default_flow(broker) diff --git a/taskiq/routing/subscriptions.py b/taskiq/routing/subscriptions.py new file mode 100644 index 00000000..2c6c6b23 --- /dev/null +++ b/taskiq/routing/subscriptions.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import replace +from typing import TYPE_CHECKING + +from taskiq.flow import FlowProtocol, get_flow_identity +from taskiq.routing.models import TaskiqSubscription +from taskiq.routing.registries import BrokerRegistry + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + +__all__ = ("SubscriptionPlan",) + + +class SubscriptionPlan: + """Inbound listen-plan subscriptions for flow-aware brokers.""" + + def __init__(self, brokers: BrokerRegistry) -> None: + self.brokers = brokers + self.subscriptions: list[TaskiqSubscription] = [] + + def subscribe( + self, + broker: AsyncBroker, + flow: FlowProtocol, + task_names: Iterable[str], + ) -> TaskiqSubscription: + """Register an inbound flow subscription for a broker.""" + target_broker = self.brokers.resolve(broker) + resolved_task_names = frozenset(task_names) + default_flow = self.brokers.default_flow(target_broker) + if default_flow is not None: + self._ensure_compatible_flow( + broker=target_broker, + registered_flow=default_flow, + new_flow=flow, + ) + + for subscription_index, subscription in enumerate(self.subscriptions): + if subscription.broker is target_broker and self._is_same_flow( + subscription.flow, + flow, + ): + self._ensure_compatible_flow( + broker=target_broker, + registered_flow=subscription.flow, + new_flow=flow, + ) + updated = replace( + subscription, + task_names=subscription.task_names | resolved_task_names, + ) + self.subscriptions[subscription_index] = updated + return updated + + subscription = TaskiqSubscription( + broker=target_broker, + flow=flow, + task_names=resolved_task_names, + ) + self.subscriptions.append(subscription) + return subscription + + def get( + self, + broker: AsyncBroker | None = None, + ) -> tuple[TaskiqSubscription, ...]: + """Return registered inbound subscriptions.""" + if broker is None: + return tuple(self.subscriptions) + target_broker = self.brokers.resolve(broker) + return tuple( + subscription + for subscription in self.subscriptions + if subscription.broker is target_broker + ) + + def get_broker_flows(self, broker: AsyncBroker) -> tuple[FlowProtocol, ...]: + """Return flows a broker should subscribe to.""" + target_broker = self.brokers.resolve(broker) + flows = [subscription.flow for subscription in self.get(target_broker)] + default_flow = self.brokers.default_flow(target_broker) + if default_flow is not None and not self._contains_flow_identity( + flows, + default_flow, + ): + flows.insert(0, default_flow) + elif default_flow is not None: + for flow in flows: + if self._is_same_flow(flow, default_flow): + self._ensure_compatible_flow( + broker=target_broker, + registered_flow=flow, + new_flow=default_flow, + ) + break + return tuple(flows) + + def _contains_flow_identity( + self, + flows: Iterable[FlowProtocol], + flow: FlowProtocol, + ) -> bool: + return any( + self._is_same_flow(registered_flow, flow) for registered_flow in flows + ) + + def _is_same_flow( + self, + first_flow: FlowProtocol, + second_flow: FlowProtocol, + ) -> bool: + return get_flow_identity(first_flow) == get_flow_identity(second_flow) + + def _ensure_compatible_flow( + self, + *, + broker: AsyncBroker, + registered_flow: FlowProtocol, + new_flow: FlowProtocol, + ) -> None: + if not self._is_same_flow(registered_flow, new_flow): + return + + registered_options = dict(registered_flow.broker_options()) + new_options = dict(new_flow.broker_options()) + if registered_options == new_options: + return + + flow_identity = get_flow_identity(new_flow) + raise ValueError( + f"Flow {flow_identity.name!r} is already registered for broker " + f"{broker.broker_name!r} with different broker options.", + ) diff --git a/taskiq/scheduler/scheduler.py b/taskiq/scheduler/scheduler.py index 16677554..b0a1e00d 100644 --- a/taskiq/scheduler/scheduler.py +++ b/taskiq/scheduler/scheduler.py @@ -1,8 +1,9 @@ from logging import getLogger -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from taskiq.exceptions import ScheduledTaskCancelledError from taskiq.kicker import AsyncKicker +from taskiq.router import TaskiqRouter from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.utils import maybe_awaitable @@ -47,19 +48,34 @@ async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None: except ScheduledTaskCancelledError: logger.info("Scheduled task %s has been cancelled.", task.task_name) else: - await ( - AsyncKicker(task.task_name, self.broker, task.labels) - .with_labels( - schedule_id=task.schedule_id, - ) + kicker: AsyncKicker[Any, Any] = ( + AsyncKicker(task.task_name, self.broker, dict(task.labels)) + .with_labels(schedule_id=task.schedule_id) .with_task_id(task_id=task.task_id) - .kiq( - *task.args, - **task.kwargs, - ) ) + self._apply_scheduled_route(kicker, task.task_name) + await kicker.kiq(*task.args, **task.kwargs) await maybe_awaitable(source.post_send(task)) async def shutdown(self) -> None: """Shutdown the scheduler process.""" await self.broker.shutdown() + + def _apply_scheduled_route( + self, + kicker: AsyncKicker[Any, Any], + task_name: str, + ) -> None: + """ + Apply scheduler dispatch routing without changing schedule payloads. + + Registered router routes are resolved at send time. If there is no + router route for the scheduled task, keep the old scheduler behavior and + send through the scheduler broker instead of the router default broker. + """ + router = getattr(self.broker, "router", None) + if isinstance(router, TaskiqRouter) and task_name in router.routes: + kicker.with_route(router.resolve_route(task_name)) + return + + kicker.with_broker(self.broker) diff --git a/taskiq/task_builder.py b/taskiq/task_builder.py index d2d3b0d1..eeec6ee2 100644 --- a/taskiq/task_builder.py +++ b/taskiq/task_builder.py @@ -1,13 +1,20 @@ +from __future__ import annotations + import inspect -from collections.abc import Callable +from collections.abc import Callable, Coroutine, Mapping from dataclasses import dataclass, field -from typing import Any, Generic, ParamSpec, TypeVar, overload +from types import CoroutineType, MappingProxyType +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, overload + +from taskiq.message import TaskiqMessage, _build_taskiq_message -from taskiq.message import TaskiqMessage +if TYPE_CHECKING: # pragma: no cover + from taskiq.decor import AsyncTaskiqDecoratedTask __all__ = ("TaskDefinition", "task_builder") _FuncParams = ParamSpec("_FuncParams") +_T = TypeVar("_T") _ReturnType = TypeVar("_ReturnType") @@ -17,8 +24,13 @@ class TaskDefinition(Generic[_FuncParams, _ReturnType]): task_name: str original_func: Callable[_FuncParams, _ReturnType] - labels: dict[str, Any] = field(default_factory=dict) + labels: Mapping[str, Any] = field(default_factory=dict) return_type: type[_ReturnType] | None = None + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None + + def __post_init__(self) -> None: + """Freeze labels to make task definitions stable declarations.""" + object.__setattr__(self, "labels", MappingProxyType(dict(self.labels))) def __call__( self, @@ -28,11 +40,32 @@ def __call__( """Call original function directly.""" return self.original_func(*args, **kwargs) + @overload + async def call( + self: TaskDefinition[_FuncParams, CoroutineType[Any, Any, _T]], + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _T: ... + + @overload + async def call( + self: TaskDefinition[_FuncParams, Coroutine[Any, Any, _T]], + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _T: ... + + @overload async def call( self, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, - ) -> _ReturnType: + ) -> _ReturnType: ... + + async def call( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> Any: """Execute original function in the current process.""" result = self.original_func(*args, **kwargs) if inspect.isawaitable(result): @@ -46,18 +79,20 @@ def message( **kwargs: _FuncParams.kwargs, ) -> TaskiqMessage: """Build a TaskiqMessage without binding this definition to a router.""" - return TaskiqMessage( + return _build_taskiq_message( task_id=task_id, task_name=self.task_name, - labels=dict(self.labels), - args=list(args), - kwargs=dict(kwargs), + labels=self.labels, + args=args, + kwargs=kwargs, ) @overload def task_builder( task_name: Callable[_FuncParams, _ReturnType], + *, + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, **labels: Any, ) -> TaskDefinition[_FuncParams, _ReturnType]: ... @@ -65,6 +100,8 @@ def task_builder( @overload def task_builder( task_name: str | None = None, + *, + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, **labels: Any, ) -> Callable[ [Callable[_FuncParams, _ReturnType]], @@ -74,6 +111,8 @@ def task_builder( def task_builder( task_name: str | Callable[_FuncParams, _ReturnType] | None = None, + *, + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, **labels: Any, ) -> Any: """Build an unbound task definition. @@ -97,6 +136,7 @@ def build( original_func=func, labels=dict(labels), return_type=return_type, + base_cls=base_cls, ) if callable(task_name): diff --git a/tests/routing/__init__.py b/tests/routing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/routing/models.py b/tests/routing/models.py new file mode 100644 index 00000000..e6a4c4f5 --- /dev/null +++ b/tests/routing/models.py @@ -0,0 +1,117 @@ +from collections.abc import AsyncGenerator, Mapping +from dataclasses import dataclass +from typing import Any + +from taskiq import AsyncTaskiqDecoratedTask, FlowProtocol, TaskiqRouter +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.message import BrokerMessage, TaskiqMessage +from taskiq.result import TaskiqResult + + +@dataclass(frozen=True, slots=True) +class BrokerQueue: + """Broker-specific flow used to prove protocol-based routing.""" + + name: str + durable: bool = True + + def broker_options(self) -> Mapping[str, object]: + """Return options for the target broker.""" + return { + "durable": self.durable, + } + + +class TracingTask(AsyncTaskiqDecoratedTask[Any, Any]): + """Custom task class used to prove task_builder base_cls binding.""" + + def tracing_name(self) -> str: + """Return a trace-friendly task name.""" + return self.task_name + + +class CountingRouter(TaskiqRouter): + """Router that counts task registration calls.""" + + def __init__(self) -> None: + self.register_task_calls = 0 + super().__init__() + + def register_task( + self, + task: Any, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> Any: + """Count registration calls and delegate to the router.""" + self.register_task_calls += 1 + return super().register_task(task, broker=broker, flow=flow) + + +class OldStyleRecordingBroker(AsyncBroker): + """Broker that only implements the old required kick/listen API.""" + + def __init__( + self, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + ) -> None: + self.sent: list[BrokerMessage] = [] + super().__init__(router=router, broker_name=broker_name) + + async def kick(self, message: BrokerMessage) -> None: + """Record old-style send.""" + self.sent.append(message) + + async def listen(self) -> AsyncGenerator[bytes, None]: + """Old-style recording broker does not listen in tests.""" + if False: + yield b"" + + +class RecordingResultBackend(AsyncResultBackend[Any]): + """Result backend marker that records stored results.""" + + def __init__(self) -> None: + self.results: dict[str, TaskiqResult[Any]] = {} + + async def set_result(self, task_id: str, result: TaskiqResult[Any]) -> None: + """Store result by task id.""" + self.results[task_id] = result + + async def is_result_ready(self, task_id: str) -> bool: + """Return whether result exists.""" + return task_id in self.results + + async def get_result( + self, + task_id: str, + with_logs: bool = False, + ) -> TaskiqResult[Any]: + """Return stored result.""" + return self.results[task_id] + + +class RecordingMiddleware(TaskiqMiddleware): + """Middleware that records client-side send hooks.""" + + def __init__( + self, + name: str, + events: list[tuple[str, str, str]], + ) -> None: + super().__init__() + self.name = name + self.events = events + + def pre_send(self, message: TaskiqMessage) -> TaskiqMessage: + """Record pre-send call.""" + self.events.append((self.name, "pre_send", message.task_name)) + return message + + def post_send(self, message: TaskiqMessage) -> None: + """Record post-send call.""" + self.events.append((self.name, "post_send", message.task_name)) diff --git a/tests/routing/test_dispatch.py b/tests/routing/test_dispatch.py new file mode 100644 index 00000000..0317f4e3 --- /dev/null +++ b/tests/routing/test_dispatch.py @@ -0,0 +1,324 @@ +from taskiq import Flow, FlowProtocol, InMemoryBroker, TaskiqRouter +from tests.routing.models import ( + BrokerQueue, + OldStyleRecordingBroker, + RecordingMiddleware, + RecordingResultBackend, +) +from tests.utils import RecordingBroker + + +async def test_old_broker_task_api_registers_task_in_router() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + assert broker.find_task("demo.task") is demo_task + assert broker.router.find_task("demo.task") is demo_task + + await demo_task.kiq() + + assert broker.sent[0][0].task_name == "demo.task" + assert broker.sent[0][1] is None + + +async def test_broker_task_api_uses_default_flow() -> None: + flow = Flow("default") + broker = RecordingBroker(default_flow=flow) + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + await demo_task.kiq() + + assert broker.sent[0][0].task_name == "demo.task" + assert broker.sent[0][1] == flow + assert broker.get_subscribed_flows() == (flow,) + + +def test_router_set_broker_can_configure_default_flow() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("manual.default") + + router.set_broker(broker, name="broker", default_flow=flow) + route = router.resolve_route("demo.task", broker=broker) + + assert broker.default_flow == flow + assert route.flow == flow + assert broker.get_subscribed_flows() == (flow,) + + +async def test_router_can_route_task_to_another_broker_flow() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("events") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + route = router.route_task(demo_task, broker=target, flow=flow) + resolved_route = router.resolve_route(demo_task) + + await demo_task.kiq() + + assert route.broker is target + assert route.broker_name == "target" + assert resolved_route.broker is target + assert resolved_route.flow == flow + assert source.sent == [] + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_old_style_broker_uses_kick_to_flow_fallback() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = OldStyleRecordingBroker(router=router, broker_name="target") + flow = Flow("legacy") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=target, flow=flow) + + await demo_task.kiq() + + assert source.sent == [] + assert target.sent[0].task_name == "demo.task" + + +async def test_kicker_route_override_wins_over_registered_route() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @first.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=first, flow=first_flow) + + route = router.resolve_route(demo_task, broker=second, flow=second_flow) + + await demo_task.kicker().with_route(route).kiq() + + assert first.sent == [] + assert second.sent[0][1] == second_flow + + +async def test_kicker_broker_override_wins_over_registered_route() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second_flow = Flow("second") + second = RecordingBroker( + router=router, + broker_name="second", + default_flow=second_flow, + ) + first_flow = Flow("first") + + @first.task(task_name="demo.task") + async def demo_task() -> None: + return None + + first_route = router.route_task(demo_task, broker=first, flow=first_flow) + + await demo_task.kicker().with_route(first_route).with_broker(second).kiq() + + assert first.sent == [] + assert second.sent[0][1] == second_flow + + +async def test_router_uses_explicit_broker_lookup_for_config_names() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("compat") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + broker = router.get_broker("target") + route = router.route_task("demo.task", broker=broker, flow=flow) + + await demo_task.kicker().with_route(route).kiq() + + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_router_accepts_broker_specific_flow_protocol() -> None: + broker = RecordingBroker(broker_name="transport") + flow = BrokerQueue(name="critical", durable=False) + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + await demo_task.kicker().with_flow(flow).kiq() + + assert isinstance(flow, FlowProtocol) + assert broker.sent[0][1] is flow + assert flow.broker_options() == { + "durable": False, + } + + +async def test_kicker_preserves_labels_task_id_and_result_backend() -> None: + broker = RecordingBroker() + backend = RecordingResultBackend() + broker.with_id_generator(lambda: "generated-id") + broker.with_result_backend(backend) + + @broker.task(task_name="demo.task", task_label="declared") + async def demo_task() -> None: + return None + + sent_task = await demo_task.kicker().with_labels(call_label=42).kiq() + broker_message = broker.sent[0][0] + + assert sent_task.task_id == "generated-id" + assert sent_task.result_backend is backend + assert broker_message.task_id == "generated-id" + assert broker_message.labels["task_label"] == "declared" + assert broker_message.labels["call_label"] == "42" + + +async def test_kicker_can_prepare_invocation_for_later() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + prepared = demo_task.kicker().with_labels(trace_id="abc").prepare(1) + + assert prepared.message.task_name == "demo.task" + assert prepared.message.args == [1] + assert prepared.message.labels["trace_id"] == "abc" + + await prepared.kiq() + + assert broker.sent[0][0].task_id == prepared.message.task_id + + +async def test_prepared_invocation_keeps_route_snapshot() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @first.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + first_route = router.route_task(demo_task, broker=first, flow=first_flow) + second_route = router.resolve_route(demo_task, broker=second, flow=second_flow) + kicker = demo_task.kicker().with_route(first_route) + + prepared = kicker.prepare(1) + kicker.with_route(second_route) + + await prepared.kiq() + + assert first.sent[0][1] == first_flow + assert second.sent == [] + + +async def test_prepared_invocation_resolves_default_route_snapshot() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @source.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + router.route_task(demo_task, broker=first, flow=first_flow) + prepared = demo_task.kicker().prepare(1) + + router.route_task(demo_task, broker=second, flow=second_flow) + + await prepared.kiq() + + assert source.sent == [] + assert first.sent[0][1] == first_flow + assert second.sent == [] + + +async def test_router_task_decorator_can_choose_broker_and_flow() -> None: + router = TaskiqRouter() + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("target-flow") + + @router.task("demo.task", broker=target, flow=flow) + async def demo_task() -> None: + return None + + await demo_task.kiq() + + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_routed_dispatch_uses_target_middleware_and_result_backend() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + source_events: list[tuple[str, str, str]] = [] + target_events: list[tuple[str, str, str]] = [] + target_backend = RecordingResultBackend() + + source.add_middlewares(RecordingMiddleware("source", source_events)) + target.add_middlewares(RecordingMiddleware("target", target_events)) + target.with_result_backend(target_backend) + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=target, flow=Flow("target")) + + sent_task = await demo_task.kiq() + + assert sent_task.result_backend is target_backend + assert source_events == [] + assert target_events == [ + ("target", "pre_send", "demo.task"), + ("target", "post_send", "demo.task"), + ] + + +async def test_worker_lookup_uses_task_name_not_flow() -> None: + broker = InMemoryBroker(await_inplace=True) + calls: list[str] = [] + + @broker.task(task_name="demo.first") + async def first_task() -> str: + calls.append("first") + return "first" + + @broker.task(task_name="demo.second") + async def second_task() -> str: + calls.append("second") + return "second" + + task = await first_task.kicker().with_flow(Flow(second_task.task_name)).kiq() + result = await task.wait_result(timeout=2) + + assert result.return_value == "first" + assert calls == ["first"] diff --git a/tests/routing/test_router_core.py b/tests/routing/test_router_core.py new file mode 100644 index 00000000..8f9f359b --- /dev/null +++ b/tests/routing/test_router_core.py @@ -0,0 +1,74 @@ +import pytest + +from taskiq import TaskiqRoute, TaskiqRouter, TaskiqSubscription, task_builder +from tests.routing.models import CountingRouter +from tests.utils import RecordingBroker + + +def test_broker_creates_default_router() -> None: + broker = RecordingBroker() + + assert broker.router.brokers[broker.broker_name] is broker + assert broker.router.default_broker_name == broker.broker_name + + +def test_public_router_models_keep_router_module() -> None: + assert TaskiqRoute.__module__ == "taskiq.router" + assert TaskiqSubscription.__module__ == "taskiq.router" + + +def test_router_rejects_duplicate_broker_names() -> None: + router = TaskiqRouter() + RecordingBroker(router=router, broker_name="broker") + + with pytest.raises(ValueError, match="already registered"): + RecordingBroker(router=router, broker_name="broker") + + +def test_router_rejects_broker_attached_to_another_router() -> None: + first_router = TaskiqRouter() + second_router = TaskiqRouter() + broker = RecordingBroker(router=first_router, broker_name="broker") + + with pytest.raises(ValueError, match="attached to another router"): + second_router.set_broker(broker, name="broker") + + +def test_router_rejects_string_broker_references() -> None: + router = TaskiqRouter() + RecordingBroker(router=router, broker_name="broker") + + with pytest.raises(TypeError, match="Broker string references"): + router.route_task("demo.task", broker="broker") # type: ignore[arg-type] + + +def test_router_rejects_duplicate_task_names() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + + @first.task(task_name="demo.task") + async def first_task() -> None: + return None + + with pytest.raises(ValueError, match="already registered"): + + @second.task(task_name="demo.task") + async def second_task() -> None: + return None + + +def test_task_definition_binding_does_not_reenter_router_registration() -> None: + router = CountingRouter() + broker = RecordingBroker(router=router) + + @task_builder("shared.once") + def shared_task() -> None: + return None + + registered = broker.register_task(shared_task) + + assert registered.task_name == "shared.once" + assert router.register_task_calls == 1 + assert broker.router.find_task("shared.once") is registered + assert broker.find_task("shared.once") is registered diff --git a/tests/routing/test_shared_tasks.py b/tests/routing/test_shared_tasks.py new file mode 100644 index 00000000..4efab51d --- /dev/null +++ b/tests/routing/test_shared_tasks.py @@ -0,0 +1,94 @@ +import pytest + +from taskiq import AsyncTaskiqDecoratedTask, Flow, task_builder +from tests.routing.models import TracingTask +from tests.utils import RecordingBroker + + +async def test_task_builder_can_be_registered_later() -> None: + broker = RecordingBroker() + + @task_builder("shared.add", queue="shared") + def add(left: int, right: int) -> int: + return left + right + + assert await add.call(1, 2) == 3 + + registered = broker.register_task(add) + + assert type(registered) is AsyncTaskiqDecoratedTask + assert registered.task_name == "shared.add" + assert registered.labels == {"queue": "shared"} + assert broker.router.find_task("shared.add") is registered + + await registered.kiq(1, 2) + + assert broker.sent[0][0].task_name == "shared.add" + + +async def test_router_register_task_definition_binds_to_selected_broker() -> None: + source = RecordingBroker(broker_name="source") + target = RecordingBroker(router=source.router, broker_name="target") + flow = Flow("target") + + @task_builder("shared.routed") + async def shared_task() -> None: + return None + + registered = source.router.register_task(shared_task, broker=target, flow=flow) + + await registered.kiq() + + assert source.sent == [] + assert target.find_task("shared.routed") is registered + assert target.sent[0][0].task_name == "shared.routed" + assert target.sent[0][1] == flow + + +async def test_task_builder_can_use_custom_base_cls() -> None: + broker = RecordingBroker() + + @task_builder("shared.traced", base_cls=TracingTask, queue="shared") + async def traced(value: int) -> int: + return value + 1 + + registered = broker.register_task(traced) + + assert isinstance(registered, TracingTask) + assert registered.tracing_name() == "shared.traced" + assert registered.labels == {"queue": "shared"} + assert await traced.call(1) == 2 + + await registered.kiq(1) + + assert broker.sent[0][0].task_name == "shared.traced" + + +def test_task_definition_default_flow_does_not_create_subscription() -> None: + flow = Flow("shared.default") + broker = RecordingBroker(default_flow=flow) + + @task_builder("shared.default") + def shared_task() -> None: + return None + + registered = broker.register_task(shared_task) + route = broker.router.resolve_route(registered) + + assert route.flow == flow + assert broker.router.get_subscriptions(broker) == () + assert broker.get_subscribed_flows() == (flow,) + + +def test_register_task_definition_rejects_overrides() -> None: + broker = RecordingBroker() + + @task_builder("shared.add", queue="shared") + def add(left: int, right: int) -> int: + return left + right + + with pytest.raises(ValueError, match="TaskDefinition already defines"): + broker.register_task(add, task_name="other.name") + + with pytest.raises(ValueError, match="TaskDefinition already defines"): + broker.register_task(add, queue="other") diff --git a/tests/routing/test_subscriptions.py b/tests/routing/test_subscriptions.py new file mode 100644 index 00000000..cc82603b --- /dev/null +++ b/tests/routing/test_subscriptions.py @@ -0,0 +1,186 @@ +import pytest + +from taskiq import ( + Flow, + FlowIdentity, + FlowProtocol, + TaskiqRouter, + get_flow_identity, +) +from taskiq.warnings import TaskiqDeprecationWarning +from tests.routing.models import BrokerQueue +from tests.utils import RecordingBroker + + +def test_flow_identity_uses_logical_name() -> None: + generic_flow = Flow("events").with_options(durable=True) + broker_flow = BrokerQueue(name="events", durable=True) + + assert generic_flow.identity == FlowIdentity("events") + assert get_flow_identity(broker_flow) == FlowIdentity("events") + assert generic_flow.broker_options() == {"durable": True} + + +def test_route_task_does_not_register_subscription_by_default() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("events") + + route = router.route_task("demo.task", broker=broker, flow=flow) + + assert route.flow == flow + assert router.get_subscriptions(broker) == () + assert broker.get_subscribed_flows() == () + + +def test_route_task_subscribe_true_is_deprecated_shim() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("events") + + with pytest.warns(TaskiqDeprecationWarning, match="router.subscribe"): + route = router.route_task( + "demo.task", + broker=broker, + flow=flow, + subscribe=True, + ) + + subscription = router.get_subscriptions(broker)[0] + + assert route.flow == flow + assert subscription.broker is broker + assert subscription.flow == flow + assert subscription.task_names == frozenset({"demo.task"}) + assert broker.get_subscribed_flows() == (flow,) + + +def test_route_update_does_not_mutate_existing_subscriptions() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + first_flow = Flow("events.first") + second_flow = Flow("events.second") + + router.route_task("demo.task", broker=broker, flow=first_flow) + router.subscribe(broker, first_flow, "demo.task") + router.route_task("demo.task", broker=broker, flow=second_flow) + + subscription = router.get_subscriptions(broker)[0] + route = router.resolve_route("demo.task") + + assert route.flow == second_flow + assert subscription.flow == first_flow + assert subscription.task_names == frozenset({"demo.task"}) + assert broker.get_subscribed_flows() == (first_flow,) + + +def test_subscribe_dedupes_flow_and_merges_task_names() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("events") + + router.subscribe(broker, flow, "first.task") + subscription = router.subscribe(broker, flow, "second.task") + + assert router.get_subscriptions(broker) == (subscription,) + assert subscription.task_names == frozenset({"first.task", "second.task"}) + assert broker.get_subscribed_flows() == (flow,) + + +def test_subscribe_merges_same_identity_with_compatible_options() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + generic_flow = Flow("events").with_options(durable=True) + broker_flow = BrokerQueue(name="events", durable=True) + + router.subscribe(broker, generic_flow, "first.task") + subscription = router.subscribe(broker, broker_flow, "second.task") + + assert router.get_subscriptions(broker) == (subscription,) + assert subscription.flow is generic_flow + assert subscription.task_names == frozenset({"first.task", "second.task"}) + assert broker.get_subscribed_flows() == (generic_flow,) + + +@pytest.mark.parametrize( + "conflicting_flow", + [ + pytest.param( + Flow("events").with_options(durable=False), + id="generic-flow", + ), + pytest.param( + BrokerQueue(name="events", durable=False), + id="broker-specific-flow", + ), + ], +) +def test_subscribe_rejects_same_identity_with_different_options( + conflicting_flow: FlowProtocol, +) -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + + router.subscribe(broker, Flow("events").with_options(durable=True), "first.task") + + with pytest.raises(ValueError, match="different broker options"): + router.subscribe( + broker, + conflicting_flow, + "second.task", + ) + + +def test_subscribe_rejects_default_flow_option_conflict() -> None: + router = TaskiqRouter() + default_flow = Flow("events").with_options(durable=True) + broker = RecordingBroker( + router=router, + broker_name="broker", + default_flow=default_flow, + ) + + with pytest.raises(ValueError, match="different broker options"): + router.subscribe( + broker, + BrokerQueue(name="events", durable=False), + "demo.task", + ) + + +def test_subscribe_allows_same_identity_on_different_brokers() -> None: + router = TaskiqRouter() + first_broker = RecordingBroker(router=router, broker_name="first") + second_broker = RecordingBroker(router=router, broker_name="second") + + first_subscription = router.subscribe( + first_broker, + Flow("events").with_options(durable=True), + "first.task", + ) + second_subscription = router.subscribe( + second_broker, + BrokerQueue(name="events", durable=False), + "second.task", + ) + + assert router.get_subscriptions(first_broker) == (first_subscription,) + assert router.get_subscriptions(second_broker) == (second_subscription,) + + +def test_subscribed_flows_dedupes_default_flow_by_identity() -> None: + router = TaskiqRouter() + default_flow = Flow("events").with_options(durable=True) + broker = RecordingBroker( + router=router, + broker_name="broker", + default_flow=default_flow, + ) + + subscription = router.subscribe( + broker, + BrokerQueue(name="events", durable=True), + "demo.task", + ) + + assert broker.get_subscribed_flows() == (subscription.flow,) diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index 05e6358a..eeb9d35c 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -1,8 +1,28 @@ +from datetime import datetime, timezone + +from taskiq import Flow, TaskiqRouter from taskiq.abc.schedule_source import ScheduleSource from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.compat import model_dump from taskiq.exceptions import ScheduledTaskCancelledError from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.scheduler.scheduler import TaskiqScheduler +from tests.utils import RecordingBroker + + +class RecordingScheduleSource(ScheduleSource): + """Schedule source that records dynamically added schedules.""" + + def __init__(self) -> None: + self.schedules: list[ScheduledTask] = [] + + async def get_schedules(self) -> list["ScheduledTask"]: + """Return stored schedules.""" + return self.schedules + + async def add_schedule(self, schedule: "ScheduledTask") -> None: + """Store a schedule.""" + self.schedules.append(schedule) class CancellingScheduleSource(ScheduleSource): @@ -31,3 +51,97 @@ async def test_scheduled_task_cancelled() -> None: ) await scheduler.on_ready(source, task) # error is caught + + +async def test_scheduler_resolves_route_at_send_time() -> None: + router = TaskiqRouter() + scheduler_broker = RecordingBroker(router=router, broker_name="scheduler") + first_broker = RecordingBroker(router=router, broker_name="first") + second_broker = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + source = RecordingScheduleSource() + scheduler = TaskiqScheduler(broker=scheduler_broker, sources=[source]) + + @scheduler_broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=first_broker, flow=first_flow) + scheduled_task = ScheduledTask( + task_name="demo.task", + labels={}, + args=[], + kwargs={}, + cron="* * * * *", + ) + router.route_task(demo_task, broker=second_broker, flow=second_flow) + + await scheduler.on_ready(source, scheduled_task) + + assert first_broker.sent == [] + assert second_broker.sent[0][0].task_name == "demo.task" + assert second_broker.sent[0][1] == second_flow + + +async def test_scheduler_without_route_uses_scheduler_broker() -> None: + router = TaskiqRouter() + default_broker = RecordingBroker(router=router, broker_name="default") + scheduler_flow = Flow("scheduler.default") + scheduler_broker = RecordingBroker( + router=router, + broker_name="scheduler", + default_flow=scheduler_flow, + ) + source = RecordingScheduleSource() + scheduler = TaskiqScheduler(broker=scheduler_broker, sources=[source]) + scheduled_task = ScheduledTask( + task_name="external.task", + labels={"source": "database"}, + args=[], + kwargs={}, + cron="* * * * *", + ) + + await scheduler.on_ready(source, scheduled_task) + + sent_message, sent_flow = scheduler_broker.sent[0] + assert default_broker.sent == [] + assert sent_message.task_name == "external.task" + assert sent_message.labels["source"] == "database" + assert sent_message.labels["schedule_id"] == scheduled_task.schedule_id + assert sent_flow == scheduler_flow + assert scheduled_task.labels == {"source": "database"} + + +async def test_created_schedule_kiq_does_not_mutate_schedule_payload() -> None: + router = TaskiqRouter() + source_broker = RecordingBroker(router=router, broker_name="source") + target_broker = RecordingBroker(router=router, broker_name="target") + target_flow = Flow("target") + source = RecordingScheduleSource() + + @source_broker.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + route = router.resolve_route( + demo_task, + broker=target_broker, + flow=target_flow, + ) + schedule = ( + await demo_task.kicker() + .with_route(route) + .with_labels(trace_id="trace") + .schedule_by_time(source, datetime.now(timezone.utc), 1) + ) + schedule_payload = model_dump(schedule.task) + + await schedule.kiq() + + assert source.schedules == [schedule.task] + assert {"route", "flow", "broker", "broker_name"}.isdisjoint(schedule_payload) + assert schedule.task.labels == {"trace_id": "trace"} + assert target_broker.sent[0][0].task_name == "demo.task" + assert target_broker.sent[0][1] == target_flow diff --git a/tests/test_requeue.py b/tests/test_requeue.py index c451b6df..a23f98c3 100644 --- a/tests/test_requeue.py +++ b/tests/test_requeue.py @@ -1,4 +1,20 @@ -from taskiq import Context, InMemoryBroker, TaskiqDepends +import pytest + +from taskiq import Context, Flow, InMemoryBroker, TaskiqDepends, TaskiqRouter +from taskiq.exceptions import NoResultError +from taskiq.message import TaskiqMessage +from tests.utils import RecordingBroker + + +def build_message(task_name: str = "demo.task") -> TaskiqMessage: + """Build a message for context requeue tests.""" + return TaskiqMessage( + task_id="task-id", + task_name=task_name, + labels={}, + args=[], + kwargs={}, + ) async def test_requeue() -> None: @@ -46,3 +62,75 @@ async def task(_: None = TaskiqDepends(dep_func)) -> None: ) assert runs_count == 2 + + +async def test_requeue_preserves_current_broker_when_route_points_elsewhere() -> None: + router = TaskiqRouter() + current_flow = Flow("current.default") + current_broker = RecordingBroker( + router=router, + broker_name="current", + default_flow=current_flow, + ) + routed_broker = RecordingBroker(router=router, broker_name="routed") + routed_flow = Flow("routed") + message = build_message() + context = Context(message, current_broker) + + router.route_task(message.task_name, broker=routed_broker, flow=routed_flow) + + with pytest.raises(NoResultError): + await context.requeue() + + sent_message, sent_flow = current_broker.sent[0] + assert routed_broker.sent == [] + assert sent_message.task_name == message.task_name + assert sent_message.labels["X-Taskiq-requeue"] == "1" + assert sent_flow == current_flow + + +async def test_requeue_prefers_same_broker_route_flow() -> None: + router = TaskiqRouter() + default_flow = Flow("current.default") + route_flow = Flow("current.route") + current_broker = RecordingBroker( + router=router, + broker_name="current", + default_flow=default_flow, + ) + message = build_message() + context = Context(message, current_broker) + + router.route_task(message.task_name, broker=current_broker, flow=route_flow) + + with pytest.raises(NoResultError): + await context.requeue() + + assert current_broker.sent[0][1] == route_flow + + +async def test_requeue_without_route_or_default_flow_sends_without_flow() -> None: + broker = RecordingBroker() + message = build_message() + context = Context(message, broker) + + with pytest.raises(NoResultError): + await context.requeue() + + assert broker.sent[0][1] is None + + +async def test_router_requeue_accepts_explicit_flow_override() -> None: + router = TaskiqRouter() + default_flow = Flow("default") + explicit_flow = Flow("explicit") + broker = RecordingBroker( + router=router, + broker_name="broker", + default_flow=default_flow, + ) + message = build_message() + + await router.requeue(message, broker=broker, flow=explicit_flow) + + assert broker.sent[0][1] == explicit_flow diff --git a/tests/test_router.py b/tests/test_router.py deleted file mode 100644 index f20d7223..00000000 --- a/tests/test_router.py +++ /dev/null @@ -1,218 +0,0 @@ -from collections.abc import AsyncGenerator, Mapping -from dataclasses import dataclass - -import pytest - -from taskiq import Flow, FlowProtocol, TaskiqRouter, task_builder -from taskiq.abc.broker import AsyncBroker -from taskiq.message import BrokerMessage - - -@dataclass(frozen=True, slots=True) -class BrokerQueue: - """Broker-specific flow used to prove protocol-based routing.""" - - name: str - durable: bool = True - - def broker_options(self, broker_name: str) -> Mapping[str, object]: - """Return options for the target broker.""" - return { - "broker": broker_name, - "durable": self.durable, - } - - -class RecordingBroker(AsyncBroker): - """Broker that records sent messages and flows.""" - - def __init__( - self, - *, - router: TaskiqRouter | None = None, - broker_name: str | None = None, - default_flow: FlowProtocol | None = None, - ) -> None: - self.sent: list[tuple[BrokerMessage, FlowProtocol | None]] = [] - super().__init__( - router=router, - broker_name=broker_name, - default_flow=default_flow, - ) - - async def kick(self, message: BrokerMessage) -> None: - """Record old-style send.""" - self.sent.append((message, None)) - - async def kick_to_flow( - self, - message: BrokerMessage, - flow: FlowProtocol | None = None, - ) -> None: - """Record flow-aware send.""" - self.sent.append((message, flow)) - - async def listen(self) -> AsyncGenerator[bytes, None]: - """Recording broker doesn't listen in these tests.""" - if False: - yield b"" - - -def test_broker_creates_default_router() -> None: - broker = RecordingBroker() - - assert broker.router.brokers[broker.broker_name] is broker - assert broker.router.default_broker_name == broker.broker_name - - -async def test_old_broker_task_api_registers_task_in_router() -> None: - broker = RecordingBroker() - - @broker.task(task_name="demo.task") - async def demo_task() -> None: - return None - - assert broker.find_task("demo.task") is demo_task - assert broker.router.find_task("demo.task") is demo_task - - await demo_task.kiq() - - assert broker.sent[0][0].task_name == "demo.task" - assert broker.sent[0][1] is None - - -async def test_router_can_route_task_to_another_broker_flow() -> None: - router = TaskiqRouter() - source = RecordingBroker(router=router, broker_name="source") - target = RecordingBroker(router=router, broker_name="target") - flow = Flow("events") - - @source.task(task_name="demo.task") - async def demo_task() -> None: - return None - - route = router.route_task(demo_task, broker=target, flow=flow) - - await demo_task.kiq() - - assert route.broker is target - assert route.broker_name == "target" - assert source.sent == [] - assert target.sent[0][0].task_name == "demo.task" - assert target.sent[0][1] == flow - - -async def test_kicker_route_override_wins_over_registered_route() -> None: - router = TaskiqRouter() - first = RecordingBroker(router=router, broker_name="first") - second = RecordingBroker(router=router, broker_name="second") - first_flow = Flow("first") - second_flow = Flow("second") - - @first.task(task_name="demo.task") - async def demo_task() -> None: - return None - - router.route_task(demo_task, broker=first, flow=first_flow) - - await demo_task.kicker().with_route(second, second_flow).kiq() - - assert first.sent == [] - assert second.sent[0][1] == second_flow - - -async def test_router_keeps_string_broker_lookup_for_compatibility() -> None: - router = TaskiqRouter() - source = RecordingBroker(router=router, broker_name="source") - target = RecordingBroker(router=router, broker_name="target") - flow = Flow("compat") - - @source.task(task_name="demo.task") - async def demo_task() -> None: - return None - - router.route_task("demo.task", broker="target", flow=flow) - - await demo_task.kicker().with_route("target", flow).kiq() - - assert target.sent[0][0].task_name == "demo.task" - assert target.sent[0][1] == flow - - -async def test_router_accepts_broker_specific_flow_protocol() -> None: - broker = RecordingBroker(broker_name="transport") - flow = BrokerQueue(name="critical", durable=False) - - @broker.task(task_name="demo.task") - async def demo_task() -> None: - return None - - await demo_task.kicker().with_flow(flow).kiq() - - assert isinstance(flow, FlowProtocol) - assert broker.sent[0][1] is flow - assert flow.broker_options("transport") == { - "broker": "transport", - "durable": False, - } - - -async def test_kicker_can_prepare_invocation_for_later() -> None: - broker = RecordingBroker() - - @broker.task(task_name="demo.task") - async def demo_task(value: int) -> None: - return None - - prepared = demo_task.kicker().with_labels(trace_id="abc").prepare(1) - - assert prepared.message.task_name == "demo.task" - assert prepared.message.args == [1] - assert prepared.message.labels["trace_id"] == "abc" - - await prepared.kiq() - - assert broker.sent[0][0].task_id == prepared.message.task_id - - -async def test_task_builder_can_be_registered_later() -> None: - broker = RecordingBroker() - - @task_builder("shared.add", queue="shared") - def add(left: int, right: int) -> int: - return left + right - - assert await add.call(1, 2) == 3 - - registered = broker.register_task(add) - - assert registered.task_name == "shared.add" - assert registered.labels == {"queue": "shared"} - assert broker.router.find_task("shared.add") is registered - - await registered.kiq(1, 2) - - assert broker.sent[0][0].task_name == "shared.add" - - -async def test_router_task_decorator_can_choose_broker_and_flow() -> None: - router = TaskiqRouter() - target = RecordingBroker(router=router, broker_name="target") - flow = Flow("target-flow") - - @router.task("demo.task", broker=target, flow=flow) - async def demo_task() -> None: - return None - - await demo_task.kiq() - - assert target.sent[0][0].task_name == "demo.task" - assert target.sent[0][1] == flow - - -def test_router_rejects_duplicate_broker_names() -> None: - router = TaskiqRouter() - RecordingBroker(router=router, broker_name="broker") - - with pytest.raises(ValueError, match="already registered"): - RecordingBroker(router=router, broker_name="broker") diff --git a/tests/test_task_builder.py b/tests/test_task_builder.py new file mode 100644 index 00000000..0acf3066 --- /dev/null +++ b/tests/test_task_builder.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass +from typing import Any + +import pytest +from pydantic import BaseModel + +from taskiq import AsyncTaskiqDecoratedTask, InMemoryBroker, task_builder +from taskiq.labels import LabelType + + +class CustomBrokerTask(AsyncTaskiqDecoratedTask[Any, Any]): + """Custom broker-level task class used by old decorator tests.""" + + +@dataclass(frozen=True, slots=True) +class Payload: + """Payload dataclass used to check message argument preparation.""" + + value: int + + +class User(BaseModel): + """Payload model used to check pydantic argument preparation.""" + + name: str + + +async def test_task_definition_call_executes_sync_and_async_functions() -> None: + @task_builder("shared.add") + def add(left: int, right: int) -> int: + return left + right + + @task_builder("shared.multiply") + async def multiply(left: int, right: int) -> int: + return left * right + + assert add(1, 2) == 3 + assert await add.call(1, 2) == 3 + assert await multiply(2, 3) == 6 + assert await multiply.call(2, 3) == 6 + + +def test_task_definition_message_matches_registered_kicker_message() -> None: + @task_builder( + "shared.message", + priority=7, + enabled=True, + payload=b"abc", + ) + def process(payload: Payload, user: User) -> None: + return None + + broker = InMemoryBroker() + registered = broker.register_task(process) + + definition_message = process.message( + "task-id", + Payload(value=1), + User(name="Ada"), + ) + prepared_message = ( + registered.kicker() + .with_task_id("task-id") + .prepare( + Payload(value=1), + User(name="Ada"), + ) + .message + ) + + assert definition_message == prepared_message + assert definition_message.args == [{"value": 1}, {"name": "Ada"}] + assert definition_message.labels == { + "enabled": "True", + "payload": "YWJj", + "priority": "7", + } + assert definition_message.labels_types == { + "enabled": LabelType.BOOL.value, + "payload": LabelType.BYTES.value, + "priority": LabelType.INT.value, + } + + +def test_task_definition_message_rejects_dataclass_types() -> None: + @task_builder("shared.message") + def process(payload: Payload) -> None: + return None + + with pytest.raises(ValueError, match="Cannot serialize types"): + process.message("task-id", Payload) # type: ignore[arg-type] + + +def test_shared_task_without_base_cls_uses_native_task_class() -> None: + broker = InMemoryBroker() + broker.decorator_class = CustomBrokerTask + + @broker.task(task_name="old.decorator") + async def old_task() -> None: + return None + + @task_builder("shared.default") + async def shared_task() -> None: + return None + + registered = broker.register_task(shared_task) + + assert isinstance(old_task, CustomBrokerTask) + assert type(registered) is AsyncTaskiqDecoratedTask + + +@pytest.mark.parametrize( + "base_cls", + [ + pytest.param(object, id="non-task-type"), + pytest.param(0, id="falsey-runtime-value"), + ], +) +def test_invalid_task_definition_base_cls_fails_when_bound(base_cls: Any) -> None: + @task_builder("shared.invalid", base_cls=base_cls) + def invalid_task() -> None: + return None + + broker = InMemoryBroker() + + with pytest.raises( + TypeError, + match="base_cls must be a subclass of AsyncTaskiqDecoratedTask", + ): + broker.register_task(invalid_task) diff --git a/tests/utils.py b/tests/utils.py index c6c659d2..a90b59bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,45 @@ import asyncio from collections.abc import AsyncGenerator -from taskiq import AsyncBroker, BrokerMessage +from taskiq import AsyncBroker, BrokerMessage, FlowProtocol, TaskiqRouter from taskiq.acks import AckableMessage +class RecordingBroker(AsyncBroker): + """Broker for tests that records messages and selected flows.""" + + def __init__( + self, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: FlowProtocol | None = None, + ) -> None: + self.sent: list[tuple[BrokerMessage, FlowProtocol | None]] = [] + super().__init__( + router=router, + broker_name=broker_name, + default_flow=default_flow, + ) + + async def kick(self, message: BrokerMessage) -> None: + """Record old-style send.""" + self.sent.append((message, None)) + + async def kick_to_flow( + self, + message: BrokerMessage, + flow: FlowProtocol | None = None, + ) -> None: + """Record flow-aware send.""" + self.sent.append((message, flow)) + + async def listen(self) -> AsyncGenerator[bytes, None]: + """Recording broker does not listen in tests.""" + if False: + yield b"" + + class AsyncQueueBroker(AsyncBroker): """ Broker for testing. From 91d57dfca407f309a3575e6692d2ad0541666e1f Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Sat, 6 Jun 2026 22:12:17 +0200 Subject: [PATCH 4/5] feat(routing): separate task routing from broker transport Introduce Flow, TaskiqRouter, route and subscription contracts to move task dispatch decisions out of brokers while keeping brokers as transport adapters. Preserve backward-compatible --- taskiq/abc/broker.py | 8 +++++--- taskiq/router.py | 2 ++ tests/routing/test_router_core.py | 9 +++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index a25a98a6..9405edbb 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -142,7 +142,8 @@ def find_task(self, task_name: str) -> AsyncTaskiqDecoratedTask[Any, Any] | None It searches task by name in dict of tasks that were registered for this broker directly. - If it fails, it checks global dict of all available tasks. + If it fails, it checks tasks registered in the broker's router, + then global dict of all available tasks. :param task_name: name of a task. :returns: found task or None. @@ -161,8 +162,9 @@ def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: """ Method to fetch all tasks available in broker. - This method returns all tasks, globally and locally - available in broker. With local tasks having higher priority. + This method returns all tasks globally, through the broker's router + and locally available in broker. With local tasks having higher + priority. So, if you have two tasks with the same name, one registered in global registry and one registered diff --git a/taskiq/router.py b/taskiq/router.py index b62d16d3..fa054a4a 100644 --- a/taskiq/router.py +++ b/taskiq/router.py @@ -55,6 +55,8 @@ def default_broker(self) -> AsyncBroker | None: @default_broker.setter def default_broker(self, broker: AsyncBroker | None) -> None: + if broker is not None: + self._brokers.resolve(broker) self._brokers.default_broker = broker @property diff --git a/tests/routing/test_router_core.py b/tests/routing/test_router_core.py index 8f9f359b..d2b92301 100644 --- a/tests/routing/test_router_core.py +++ b/tests/routing/test_router_core.py @@ -34,6 +34,15 @@ def test_router_rejects_broker_attached_to_another_router() -> None: second_router.set_broker(broker, name="broker") +def test_router_default_broker_setter_rejects_foreign_broker() -> None: + first_router = TaskiqRouter() + second_router = TaskiqRouter() + broker = RecordingBroker(router=first_router, broker_name="broker") + + with pytest.raises(ValueError, match="not registered"): + second_router.default_broker = broker + + def test_router_rejects_string_broker_references() -> None: router = TaskiqRouter() RecordingBroker(router=router, broker_name="broker") From 426187e407376461be3e970773324806ad9e740f Mon Sep 17 00:00:00 2001 From: "sergei.romanchuk" Date: Sun, 7 Jun 2026 11:13:00 +0200 Subject: [PATCH 5/5] fix(routing): tighten router contracts after review Remove direct router access to broker private task storage and add an explicit broker-owned storage boundary for router-managed task registration. Clarify prepared route snapshots, subscription task-name semantics and setup-time router mutation contracts. Extend routing docs with flow resolution precedence and custom task base class middleware behavior. Add regression coverage for shared task local registration, custom base_cls middleware hooks and scheduler fallback when a route is removed before dispatch. --- docs/guide/routing-and-flows.md | 22 +++++++++++++++++++ taskiq/abc/broker.py | 13 ++++++++++++ taskiq/kicker.py | 9 +++++++- taskiq/router.py | 8 ++----- taskiq/routing/models.py | 8 ++++++- taskiq/routing/subscriptions.py | 7 +++++- tests/routing/test_shared_tasks.py | 24 ++++++++++++++++++++- tests/scheduler/test_scheduler.py | 34 ++++++++++++++++++++++++++++++ 8 files changed, 115 insertions(+), 10 deletions(-) diff --git a/docs/guide/routing-and-flows.md b/docs/guide/routing-and-flows.md index 208d2502..12db2fc9 100644 --- a/docs/guide/routing-and-flows.md +++ b/docs/guide/routing-and-flows.md @@ -150,6 +150,16 @@ keep a snapshot of the route that was resolved at prepare time, so later changes to the mutable kicker object or router route table do not change that prepared send. +Flow selection is resolved in this order: + +| Source | Effect | +| --- | --- | +| Explicit route | `.with_route(route)` uses that route's broker and flow. A later `.with_flow(flow)` updates the route flow, while a later `.with_broker(broker)` clears the explicit route. | +| Explicit broker | `.with_broker(broker)` sends through that broker. If `.with_flow(flow)` is also set, that flow wins; otherwise Taskiq uses the same-broker task route flow, then the broker `default_flow`. | +| Explicit flow | `.with_flow(flow)` without a broker override replaces the flow of the resolved route or broker default route. | +| Router task route | `router.route_task(task, broker=..., flow=...)` is the default outbound route for that task. | +| Broker default flow | `InMemoryBroker(..., default_flow=...)` and other brokers with `default_flow` provide the fallback flow for that broker. | + ## Subscriptions Routers also keep inbound flow subscriptions. Routing and subscribing are @@ -179,6 +189,11 @@ The deprecated `route_task(..., subscribe=True)` shim still performs this subscription when a flow is resolved, but new code should call `subscribe()` directly. +`TaskiqSubscription.task_names` is diagnostic listen-plan metadata. It records +which task names caused the broker to listen to a flow, but it does not make the +flow an inbound task router. Workers still execute messages by +`TaskiqMessage.task_name`. + Existing brokers can keep implementing `listen()` as before. New flow-aware brokers may use `get_subscribed_flows()` to subscribe to queues, topics, subjects or streams while the routing rules stay in the router. @@ -187,6 +202,10 @@ subjects or streams while the routing rules stay in the router. subscriptions. It deduplicates by flow identity and checks broker options for conflicts. +Router routes and subscriptions are mutable setup-time configuration. They are +not thread-safe runtime coordination primitives; configure them before starting +concurrent worker or client activity. + ## Scheduler and requeue `ScheduledTask` remains a transport-neutral invocation payload. It stores task @@ -277,6 +296,9 @@ When the final application registers this definition, Taskiq creates the bound task using `TracingTask`. If `base_cls` is not provided, Taskiq uses the native decorated task class. +Custom task classes do not bypass broker middleware. Send and execute lifecycle +hooks are still owned by the selected broker and worker path. + For low-level integrations, `TaskDefinition.message(...)` builds a `TaskiqMessage` without binding the task. It uses the same argument and label preparation contract as a normal `.kicker().prepare(...)` invocation. diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 9405edbb..06926827 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -439,6 +439,19 @@ def bind_task_definition( register=register, ) + def store_registered_task( + self, + task: AsyncTaskiqDecoratedTask[Any, Any], + ) -> None: + """ + Store a task after router-owned registration accepted it. + + Application code should use `task` or `register_task`. This method is + the broker/router integration boundary for cases where the router owns + registration ordering and the broker owns local task storage. + """ + self._store_task(task.task_name, task) + def _decorate_task( self, func: Callable[_FuncParams, _ReturnType], diff --git a/taskiq/kicker.py b/taskiq/kicker.py index 8debca3e..8de79f90 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -321,7 +321,14 @@ async def _legacy_kiq( def _prepare_route_snapshot( self, ) -> tuple[AsyncBroker, TaskiqRoute | None, FlowProtocol | None]: - """Resolve the route that a prepared invocation must keep.""" + """ + Resolve route state that a prepared invocation must keep. + + The returned tuple is `(broker, route, explicit_flow_override)`. + When a route is present, the route already carries its flow, so the + separate flow override is `None`. A separate flow is returned only for + legacy broker paths that cannot snapshot a `TaskiqRoute`. + """ router = getattr(self.broker, "router", None) if not isinstance(router, TaskiqRouter): return self.broker, None, self.route_flow diff --git a/taskiq/router.py b/taskiq/router.py index fa054a4a..8b526066 100644 --- a/taskiq/router.py +++ b/taskiq/router.py @@ -132,11 +132,7 @@ def register_task( broker=target_broker, flow=flow, ) - # Router and broker share this internal registration boundary. - target_broker._store_task( # noqa: SLF001 - registered_task.task_name, - registered_task, - ) + target_broker.store_registered_task(registered_task) return registered_task return self._register_bound_task(task, broker=broker, flow=flow) @@ -178,7 +174,7 @@ def register( func: Callable[_FuncParams, _ReturnType], ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: target_broker = self._brokers.resolve(broker) - real_task_name = task_name if not callable(task_name) else None + real_task_name: str | None = None if callable(task_name) else task_name task = target_broker.task(task_name=real_task_name, **labels)(func) if flow is not None: self.route_task(task, broker=target_broker, flow=flow) diff --git a/taskiq/routing/models.py b/taskiq/routing/models.py index 439ae7da..2446aadd 100644 --- a/taskiq/routing/models.py +++ b/taskiq/routing/models.py @@ -26,7 +26,13 @@ def broker_name(self) -> str: @dataclass(frozen=True, slots=True) class TaskiqSubscription: - """Inbound flow subscription owned by a router.""" + """ + Inbound flow subscription owned by a router. + + `task_names` is diagnostic listen-plan metadata. It records which task + declarations caused a broker to listen to this flow; worker execution still + resolves inbound messages by `TaskiqMessage.task_name`. + """ broker: AsyncBroker flow: FlowProtocol diff --git a/taskiq/routing/subscriptions.py b/taskiq/routing/subscriptions.py index 2c6c6b23..7f38ba2d 100644 --- a/taskiq/routing/subscriptions.py +++ b/taskiq/routing/subscriptions.py @@ -15,7 +15,12 @@ class SubscriptionPlan: - """Inbound listen-plan subscriptions for flow-aware brokers.""" + """ + Inbound listen-plan subscriptions for flow-aware brokers. + + This is mutable setup-time configuration, not a thread-safe runtime + coordination primitive. + """ def __init__(self, brokers: BrokerRegistry) -> None: self.brokers = brokers diff --git a/tests/routing/test_shared_tasks.py b/tests/routing/test_shared_tasks.py index 4efab51d..e7aa596c 100644 --- a/tests/routing/test_shared_tasks.py +++ b/tests/routing/test_shared_tasks.py @@ -1,7 +1,7 @@ import pytest from taskiq import AsyncTaskiqDecoratedTask, Flow, task_builder -from tests.routing.models import TracingTask +from tests.routing.models import RecordingMiddleware, TracingTask from tests.utils import RecordingBroker @@ -41,6 +41,7 @@ async def shared_task() -> None: assert source.sent == [] assert target.find_task("shared.routed") is registered + assert target.local_task_registry["shared.routed"] is registered assert target.sent[0][0].task_name == "shared.routed" assert target.sent[0][1] == flow @@ -64,6 +65,27 @@ async def traced(value: int) -> int: assert broker.sent[0][0].task_name == "shared.traced" +async def test_task_builder_custom_base_cls_uses_broker_middleware() -> None: + broker = RecordingBroker() + events: list[tuple[str, str, str]] = [] + broker.add_middlewares(RecordingMiddleware("broker", events)) + + @task_builder("shared.traced.middleware", base_cls=TracingTask) + async def traced(value: int) -> int: + return value + 1 + + registered = broker.register_task(traced) + + assert isinstance(registered, TracingTask) + + await registered.kiq(1) + + assert events == [ + ("broker", "pre_send", "shared.traced.middleware"), + ("broker", "post_send", "shared.traced.middleware"), + ] + + def test_task_definition_default_flow_does_not_create_subscription() -> None: flow = Flow("shared.default") broker = RecordingBroker(default_flow=flow) diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index eeb9d35c..bc37effc 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -114,6 +114,40 @@ async def test_scheduler_without_route_uses_scheduler_broker() -> None: assert scheduled_task.labels == {"source": "database"} +async def test_scheduler_removed_route_uses_scheduler_broker() -> None: + router = TaskiqRouter() + scheduler_flow = Flow("scheduler.default") + scheduler_broker = RecordingBroker( + router=router, + broker_name="scheduler", + default_flow=scheduler_flow, + ) + target_broker = RecordingBroker(router=router, broker_name="target") + source = RecordingScheduleSource() + scheduler = TaskiqScheduler(broker=scheduler_broker, sources=[source]) + + @scheduler_broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=target_broker, flow=Flow("target")) + scheduled_task = ScheduledTask( + task_name="demo.task", + labels={}, + args=[], + kwargs={}, + cron="* * * * *", + ) + del router.routes["demo.task"] + + await scheduler.on_ready(source, scheduled_task) + + sent_message, sent_flow = scheduler_broker.sent[0] + assert target_broker.sent == [] + assert sent_message.task_name == "demo.task" + assert sent_flow == scheduler_flow + + async def test_created_schedule_kiq_does_not_mutate_schedule_payload() -> None: router = TaskiqRouter() source_broker = RecordingBroker(router=router, broker_name="source")