diff --git a/docs/examples/router/multiple_brokers.py b/docs/examples/router/multiple_brokers.py new file mode 100644 index 00000000..c845bbf0 --- /dev/null +++ b/docs/examples/router/multiple_brokers.py @@ -0,0 +1,98 @@ +"""Route one task through several brokers with a shared router.""" + +import asyncio + +from taskiq import Flow, InMemoryBroker, TaskiqRoute, TaskiqRouter + +router = TaskiqRouter() + +default_email_flow = Flow("emails.default") +priority_email_flow = Flow("emails.priority") +bulk_email_flow = Flow("emails.bulk") + +default_broker = InMemoryBroker( + router=router, + broker_name="default", + default_flow=default_email_flow, + await_inplace=True, +) +priority_broker = InMemoryBroker( + router=router, + broker_name="priority", + default_flow=priority_email_flow, + await_inplace=True, +) + + +@default_broker.task(task_name="examples.send_email", domain="notifications") +async def send_email(user_id: int, template: str) -> str: + """Pretend to render and send an email.""" + return f"{template} email sent to user {user_id}" + + +priority_route = router.route_task( + send_email, + broker=priority_broker, + flow=priority_email_flow, +) +priority_subscription = router.subscribe( + priority_broker, + priority_email_flow, + send_email, +) + + +def _format_route(task_name: str, route: TaskiqRoute) -> str: + """Return a readable route diagnostic for the example output.""" + flow_name = route.flow.name if route.flow is not None else "" + return f"{task_name} -> broker={route.broker_name}, flow={flow_name}" + + +def _format_listen_plan() -> str: + """Return flows that the priority broker should subscribe to.""" + flow_names = ", ".join(flow.name for flow in priority_broker.get_subscribed_flows()) + return f"priority listens to: {flow_names}" + + +async def _main() -> None: + await default_broker.startup() + await priority_broker.startup() + try: + direct_result = await send_email(7, "welcome") + + declared_route = router.resolve_route(send_email) + assert declared_route == priority_route + + routed_task = ( + await send_email.kicker() + .with_route(declared_route) + .kiq( + 7, + "welcome", + ) + ) + routed_result = await routed_task.wait_result(timeout=2) + + bulk_route = router.resolve_route( + send_email, + broker=default_broker, + flow=bulk_email_flow, + ) + bulk_task = await send_email.kicker().with_route(bulk_route).kiq(8, "digest") + bulk_result = await bulk_task.wait_result(timeout=2) + + print(f"Direct call: {direct_result}") + print(f"Router rule: {_format_route(send_email.task_name, priority_route)}") + print(f"Subscription tasks: {sorted(priority_subscription.task_names)}") + print(_format_listen_plan()) + print(f"Resolved route: {_format_route(send_email.task_name, declared_route)}") + print(f"Routed call: {routed_result.return_value}") + print(f"Override route: {_format_route(send_email.task_name, bulk_route)}") + print(f"Route override: {bulk_result.return_value}") + finally: + await priority_broker.shutdown() + await default_broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/docs/examples/router/shared_task_package.py b/docs/examples/router/shared_task_package.py new file mode 100644 index 00000000..0bbd6616 --- /dev/null +++ b/docs/examples/router/shared_task_package.py @@ -0,0 +1,100 @@ +"""Declare shared task definitions and bind them in the final application.""" + +import asyncio +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +from taskiq import ( + AsyncTaskiqDecoratedTask, + Flow, + InMemoryBroker, + TaskiqRouter, + task_builder, +) + + +@dataclass(frozen=True, slots=True) +class BillingQueue: + """Broker-specific flow that follows the shared flow protocol.""" + + name: str + priority: int + + def broker_options(self) -> Mapping[str, object]: + """Return options that a billing broker adapter can understand.""" + return { + "priority": self.priority, + } + + +class BillingTask(AsyncTaskiqDecoratedTask[Any, Any]): + """Custom task class shared by billing package tasks.""" + + def billing_name(self) -> str: + """Return a billing-specific task name.""" + return self.task_name + + +@task_builder("billing.calculate_total", base_cls=BillingTask, domain="billing") +async def calculate_total(price: int, quantity: int) -> int: + """Package-level task definition that is not bound to any broker.""" + return price * quantity + + +router = TaskiqRouter() +billing_flow = Flow("billing.tasks") +priority_billing_flow = BillingQueue(name="billing.priority", priority=10) + +billing_broker = InMemoryBroker( + router=router, + broker_name="billing", + default_flow=billing_flow, + await_inplace=True, +) + +registered_calculate_total = router.register_task( + calculate_total, + broker=billing_broker, + flow=billing_flow, +) +router.subscribe( + billing_broker, + billing_flow, + registered_calculate_total, +) + + +async def _main() -> None: + await billing_broker.startup() + try: + direct_result = await calculate_total.call(19, 3) + + priority_route = router.resolve_route( + registered_calculate_total, + broker=billing_broker, + flow=priority_billing_flow, + ) + prepared_task = ( + registered_calculate_total.kicker() + .with_route( + priority_route, + ) + .prepare(19, 3) + ) + + queued_task = await prepared_task.kiq() + queued_result = await queued_task.wait_result(timeout=2) + + print(f"Shared task direct call: {direct_result}") + print(f"Registered task class: {registered_calculate_total.billing_name()}") + listen_flow = billing_broker.get_subscribed_flows()[0] + print(f"Registered listen flow: {listen_flow.name}") + print(f"Prepared message: {prepared_task.message.task_name}") + print(f"Registered queued call: {queued_result.return_value}") + finally: + await billing_broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/docs/guide/architecture-overview.md b/docs/guide/architecture-overview.md index d4b825f1..f84437d2 100644 --- a/docs/guide/architecture-overview.md +++ b/docs/guide/architecture-overview.md @@ -82,6 +82,28 @@ asyncio.run(main()) ``` +## Router and flows + +Taskiq can use a `TaskiqRouter` to keep routing rules outside of broker +implementations. Brokers remain transport adapters, while the router owns task +registration, route resolution and flow subscriptions. +This section describes the `experiment/separate_broker` branch contract. The +old `@broker.task(...)`, `.kiq()`, labels, scheduler and result backend behavior +remain compatible, while router/flow APIs are additive review material for the +branch. + +`Flow` is a transport-neutral delivery address. Broker packages may provide +their own flow classes for queue, topic, subject or stream options, as long as +they implement the same flow protocol. The router deduplicates subscriptions by +flow name and rejects same-name flows with incompatible broker options. + +Routing and subscribing are separate responsibilities. `route_task(...)` +chooses the outbound broker and flow for task invocations. `subscribe(...)` +adds flows to a broker listen plan for flow-aware broker adapters. Worker task +lookup still uses `task_name`; flow does not select the Python task. + +Read more in the [Routing and flows](./routing-and-flows.md) section. + ## Messages Every message has labels. You can define labels diff --git a/docs/guide/routing-and-flows.md b/docs/guide/routing-and-flows.md new file mode 100644 index 00000000..12db2fc9 --- /dev/null +++ b/docs/guide/routing-and-flows.md @@ -0,0 +1,367 @@ +--- +title: Routing and flows +order: 10 +--- + +# Routing and flows + +::: warning Experimental branch +This page describes the routing and flow contract on the +`experiment/separate_broker` branch. The old task declaration API remains +compatible. Broker-specific packages may keep using their existing send/listen +implementation until they opt into flow-aware dispatch and listen-plan support. +::: + +Taskiq keeps the old task declaration API: + +```python +@broker.task(task_name="billing.charge") +async def charge(user_id: int) -> None: + ... +``` + +The broker is still the transport adapter, but routing decisions can be moved to +a `TaskiqRouter`. This is useful when one application owns several brokers or +when a task package declares tasks before the final application knows which +broker must handle them. + +The important responsibility split is: + +- Task declaration describes a Python callable, task name, labels and direct + call behavior. +- Kicker prepares a task invocation: task id, labels, args, kwargs and optional + route overrides. +- Router owns outbound routing policy: task invocation to broker and flow. +- Router also owns inbound listen-plan records for flow-aware brokers. +- Broker owns transport lifecycle: startup, shutdown, send, listen, ack/nack, + serializer/formatter, middleware and result backend integration. +- Worker still executes by `task_name`; flow does not select the Python task. + +## Router + +A router owns broker registration, task registration and routing rules. +One router can own several brokers. A broker has a single router owner; pass the +router when creating the broker instead of registering the same broker in +several routers. + +```python +from taskiq import Flow, InMemoryBroker, TaskiqRouter + +router = TaskiqRouter() + +default_broker = InMemoryBroker(router=router, broker_name="default") +priority_broker = InMemoryBroker(router=router, broker_name="priority") + +priority_route = router.route_task( + "billing.charge", + broker=priority_broker, + flow=Flow("billing.priority"), +) +``` + +If a broker is created without an explicit router, Taskiq creates a default +router for that broker. Use an explicit shared router when one worker or app +needs several brokers to follow the same routing policy. + +Pass broker objects to routing APIs. If broker names come from configuration, +resolve them explicitly first: + +```python +broker = router.get_broker("priority") +route = router.route_task("billing.charge", broker=broker) +``` + +This keeps routing errors local and avoids hidden string references. + +## Flow + +`Flow` is a transport-neutral delivery address. It is intentionally small. +The routing identity of a flow is its logical name: + +```python +from taskiq import Flow, FlowIdentity, get_flow_identity + +flow = Flow("billing.priority").with_options(priority=10) + +assert flow.identity == FlowIdentity("billing.priority") +assert get_flow_identity(flow) == flow.identity +``` + +Broker packages can expose their own flow objects when they need +transport-specific options: + +```python +from collections.abc import Mapping +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class RabbitQueue: + name: str + durable: bool = True + + def broker_options(self) -> Mapping[str, object]: + return {"durable": self.durable} +``` + +The core protocol requires `name` and `broker_options()`. Broker-specific +settings stay in the flow object and are returned through `broker_options()`; +they are not encoded into router task names or message labels. + +Flow identity and declaration options are separate: + +- Same broker, same flow name and same `broker_options()` means the listen plan + can deduplicate the flow and merge task names. +- Same broker and same flow name with different `broker_options()` is a + conflict, so Taskiq raises `ValueError` instead of silently choosing one + queue/topic declaration. +- Different brokers may use the same flow name independently. + +## Sending with a route + +`router.route_task()` returns a `TaskiqRoute`. You can pass that object to a +kicker when you want the route to be explicit at the call site. + +```python +route = router.resolve_route("billing.charge") +task = await charge.kicker().with_route(route).kiq(42) +``` + +For one-off overrides, resolve a route with the target broker and flow: + +```python +bulk_route = router.resolve_route( + "billing.charge", + broker=default_broker, + flow=Flow("billing.bulk"), +) + +task = await charge.kicker().with_route(bulk_route).kiq(42) +``` + +`with_broker(...)` remains available for shared-task style invocation. When a +kicker has both a route and a later broker override, the broker override clears +the explicit route. Taskiq then resolves the task for the override broker: it +uses that broker's task route flow when one exists, otherwise the broker +`default_flow`. + +`prepare(...)` creates a prepared invocation for later use. Prepared invocations +keep a snapshot of the route that was resolved at prepare time, so later changes +to the mutable kicker object or router route table do not change that prepared +send. + +Flow selection is resolved in this order: + +| Source | Effect | +| --- | --- | +| Explicit route | `.with_route(route)` uses that route's broker and flow. A later `.with_flow(flow)` updates the route flow, while a later `.with_broker(broker)` clears the explicit route. | +| Explicit broker | `.with_broker(broker)` sends through that broker. If `.with_flow(flow)` is also set, that flow wins; otherwise Taskiq uses the same-broker task route flow, then the broker `default_flow`. | +| Explicit flow | `.with_flow(flow)` without a broker override replaces the flow of the resolved route or broker default route. | +| Router task route | `router.route_task(task, broker=..., flow=...)` is the default outbound route for that task. | +| Broker default flow | `InMemoryBroker(..., default_flow=...)` and other brokers with `default_flow` provide the fallback flow for that broker. | + +## Subscriptions + +Routers also keep inbound flow subscriptions. Routing and subscribing are +separate operations: `route_task()` selects the outbound broker/flow for a task +invocation, while `subscribe()` adds a flow to the broker listen plan. + +```python +billing_flow = Flow("billing.priority") + +route = router.route_task( + "billing.charge", + broker=priority_broker, + flow=billing_flow, +) + +router.subscribe(priority_broker, billing_flow, "billing.charge") + +subscriptions = router.get_subscriptions(priority_broker) +flows = priority_broker.get_subscribed_flows() +``` + +`route_task(...)` does not subscribe by default. This is intentional: outbound +routing answers "where should this invocation be sent?", while subscribing +answers "which flows should this broker listen to?". + +The deprecated `route_task(..., subscribe=True)` shim still performs this +subscription when a flow is resolved, but new code should call `subscribe()` +directly. + +`TaskiqSubscription.task_names` is diagnostic listen-plan metadata. It records +which task names caused the broker to listen to a flow, but it does not make the +flow an inbound task router. Workers still execute messages by +`TaskiqMessage.task_name`. + +Existing brokers can keep implementing `listen()` as before. New flow-aware +brokers may use `get_subscribed_flows()` to subscribe to queues, topics, +subjects or streams while the routing rules stay in the router. + +`get_subscribed_flows()` returns the broker `default_flow` plus explicit +subscriptions. It deduplicates by flow identity and checks broker options for +conflicts. + +Router routes and subscriptions are mutable setup-time configuration. They are +not thread-safe runtime coordination primitives; configure them before starting +concurrent worker or client activity. + +## Scheduler and requeue + +`ScheduledTask` remains a transport-neutral invocation payload. It stores task +name, labels, args, kwargs, optional task id and schedule timing fields; it does +not store broker objects, broker names, route objects, flows or +transport-specific flow options. + +When a scheduled task is ready, `TaskiqScheduler` resolves the task route in the +scheduler process through the scheduler broker's router. Route changes made +before `on_ready(...)` runs affect that scheduled dispatch. If no router route +exists for the scheduled task, Taskiq keeps the old behavior and sends through +the scheduler broker. + +`with_broker(...)`, `with_route(...)` and `with_flow(...)` on a kicker are not +persisted schedule route metadata. They can affect `CreatedSchedule.kiq()`, +which is an immediate queued invocation helper, but they do not change the +stored schedule payload. + +`Context.requeue()` is current-broker sticky. It sends the same task message +through the broker currently executing the task, even if the task's default +route now points to another broker. Requeue flow selection is explicit +low-level override, same-broker task route flow, current broker `default_flow`, +then no flow. Taskiq does not currently preserve the inbound source flow because +flow provenance is not stored in `TaskiqMessage` or `Context`. + +## Shared task declarations + +Libraries can declare tasks without importing an application broker: + +```python +from taskiq import task_builder + + +@task_builder("billing.calculate_total", domain="billing") +async def calculate_total(price: int, quantity: int) -> int: + return price * quantity +``` + +The final application binds the task definition to its router and broker: + +```python +billing_flow = Flow("billing.tasks") + +registered = router.register_task( + calculate_total, + broker=billing_broker, + flow=billing_flow, +) + +router.subscribe( + billing_broker, + billing_flow, + registered, +) + +result = await registered.kiq(19, 3) +``` + +Task name and labels belong to the shared declaration. Register a +`TaskDefinition` without `task_name` or label overrides; Taskiq rejects those +overrides instead of silently ignoring them. + +The unbound task can still run locally: + +```python +total = await calculate_total.call(19, 3) +``` + +Shared task declarations can also provide a custom task class for one task: + +```python +from typing import Any + +from taskiq import AsyncTaskiqDecoratedTask, task_builder + + +class TracingTask(AsyncTaskiqDecoratedTask[Any, Any]): + def tracing_name(self) -> str: + return self.task_name + + +@task_builder("billing.traced_charge", base_cls=TracingTask) +async def traced_charge(user_id: int) -> None: + ... +``` + +When the final application registers this definition, Taskiq creates the bound +task using `TracingTask`. If `base_cls` is not provided, Taskiq uses the native +decorated task class. + +Custom task classes do not bypass broker middleware. Send and execute lifecycle +hooks are still owned by the selected broker and worker path. + +For low-level integrations, `TaskDefinition.message(...)` builds a +`TaskiqMessage` without binding the task. It uses the same argument and label +preparation contract as a normal `.kicker().prepare(...)` invocation. + +## Compatibility and migration notes + +The routing/flow branch keeps the old broker-first task API working: + +```python +@broker.task(task_name="billing.charge") +async def charge(user_id: int) -> None: + ... + +await charge.kiq(42) +await charge.kicker().with_labels(source="api").kiq(42) +``` + +Migration guidance for application code: + +- Keep using `@broker.task(...)` when one broker owns task declaration and send + path. No router code is required for this case. +- Use one shared `TaskiqRouter` when several brokers must share task registry + and routing policy. +- Pass broker objects to routing APIs. If configuration stores broker names, + resolve them with `router.get_broker(name)` before calling `route_task(...)`. +- Replace `route_task(..., subscribe=True)` usage with explicit + `router.subscribe(...)`. +- Do not use labels as a transport routing schema. Labels remain task metadata. +- Register a `TaskDefinition` without `task_name` or label overrides; plain + callable registration still supports those overrides. + +Migration guidance for broker packages: + +- Existing brokers that implement only `kick(message)` and `listen()` remain + valid. The default `kick_to_flow(message, flow=None)` calls `kick(message)`. +- Flow-aware brokers can override `kick_to_flow(...)` and map `flow.name` to a + queue, topic, subject or stream. +- Flow-aware listeners can call `get_subscribed_flows()` during listen setup to + add router-owned subscriptions while keeping old default listen behavior when + there are no explicit subscriptions. +- Broker-specific declaration settings belong in broker-specific flow objects + and `broker_options()`, not in task names, labels or scheduler payloads. +- Ack/nack behavior, serializer/formatter behavior, retry, middleware hooks and + result backend ownership remain broker responsibilities. + +Scheduler and requeue compatibility: + +- `ScheduledTask` does not store routes, broker names, flow objects or + broker-specific flow options in this iteration. +- Scheduled dispatch resolves the route late in the scheduler process through + the scheduler broker's router. +- `Context.requeue()` sends through the broker currently executing the task. It + does not re-resolve a task route across brokers. +- Taskiq does not preserve inbound source-flow provenance yet because that data + is not stored in `TaskiqMessage` or `Context`. + +## Examples + +See executable examples for the two main usage shapes: + +- [Multiple brokers and explicit routes](../examples/router/multiple_brokers.py) + shows one task routed through several brokers with explicit subscription + ownership. +- [Shared task package](../examples/router/shared_task_package.py) shows + `task_builder(...)`, a custom `base_cls`, later registration and a prepared + queued invocation. diff --git a/docs/guide/scheduling-tasks.md b/docs/guide/scheduling-tasks.md index 44e646f2..ab6ce848 100644 --- a/docs/guide/scheduling-tasks.md +++ b/docs/guide/scheduling-tasks.md @@ -159,13 +159,26 @@ If you want to pass additional labels, you can call these methods on the `Kicker ) ``` -::: warning Cool warning! +::: warning Routing warning -The `with_broker` method won't do anything in this case, since we have a broker assigned to each scheduler. +`with_broker`, `with_route` and `with_flow` are immediate invocation settings. +They are not stored in the `ScheduledTask` payload. Scheduled dispatch resolves +the route later in the scheduler process through the scheduler broker's router. +Use router configuration available to the scheduler process when a scheduled +task must go to a specific broker or flow. ::: +The scheduler keeps route data out of schedule sources on purpose. A stored +`ScheduledTask` remains a transport-neutral payload with task name, labels, +args, kwargs and schedule timing fields. If you change a router rule before a +schedule becomes ready, the scheduler uses the updated route at send time. If no +route exists, the scheduler sends through its own broker as before. + Each of these methods return you an instance of the `CreatedSchedule` class. This object has unique schedule ID and some helper methods. For example, you can use the `unschedule` method to remove the schedule from the source. +Calling `kiq()` on a `CreatedSchedule` is an immediate queued invocation helper; +it uses the original kicker settings for that immediate send and does not change +the stored schedule payload. ```python schedule = await my_task.schedule_by_time( diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 2414754f..6f328069 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -24,7 +24,9 @@ TaskiqError, TaskiqResultTimeoutError, ) +from taskiq.flow import Flow, FlowIdentity, FlowProtocol, get_flow_identity from taskiq.funcs import gather +from taskiq.kicker import PreparedKiq from taskiq.message import BrokerMessage, TaskiqMessage from taskiq.middlewares import ( PrometheusMiddleware, @@ -32,10 +34,12 @@ SmartRetryMiddleware, ) from taskiq.result import TaskiqResult +from taskiq.router import TaskiqRoute, TaskiqRouter, TaskiqSubscription from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.scheduler.scheduler import TaskiqScheduler from taskiq.state import TaskiqState from taskiq.task import AsyncTaskiqTask +from taskiq.task_builder import TaskDefinition, task_builder __version__ = version("taskiq") @@ -47,8 +51,12 @@ "AsyncTaskiqTask", "BrokerMessage", "Context", + "Flow", + "FlowIdentity", + "FlowProtocol", "InMemoryBroker", "NoResultError", + "PreparedKiq", "PrometheusMiddleware", "ResultGetError", "ResultIsReadyError", @@ -58,6 +66,7 @@ "SendTaskError", "SimpleRetryMiddleware", "SmartRetryMiddleware", + "TaskDefinition", "TaskiqDepends", "TaskiqError", "TaskiqEvents", @@ -66,10 +75,15 @@ "TaskiqMiddleware", "TaskiqResult", "TaskiqResultTimeoutError", + "TaskiqRoute", + "TaskiqRouter", "TaskiqScheduler", "TaskiqState", + "TaskiqSubscription", "ZeroMQBroker", "__version__", "async_shared_broker", "gather", + "get_flow_identity", + "task_builder", ] diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index ea2e86c0..06926827 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -24,11 +24,14 @@ from taskiq.decor import AsyncTaskiqDecoratedTask from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskBrokerMismatchError +from taskiq.flow import FlowProtocol from taskiq.formatters.proxy_formatter import ProxyFormatter from taskiq.message import BrokerMessage from taskiq.result_backends.dummy import DummyResultBackend +from taskiq.router import TaskiqRouter from taskiq.serializers.json_serializer import JSONSerializer from taskiq.state import TaskiqState +from taskiq.task_builder import TaskDefinition from taskiq.utils import maybe_awaitable from taskiq.warnings import TaskiqDeprecationWarning @@ -78,6 +81,10 @@ def __init__( self, result_backend: "AsyncResultBackend[_T] | None" = None, task_id_generator: Callable[[], str] | None = None, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: FlowProtocol | None = None, ) -> None: if result_backend is None: result_backend = DummyResultBackend() @@ -103,6 +110,13 @@ def __init__( self.serializer: TaskiqSerializer = JSONSerializer() self.formatter: TaskiqFormatter = ProxyFormatter(self) self.id_generator = task_id_generator + self.router = router or TaskiqRouter() + self.default_flow = default_flow + self.broker_name = self.router.set_broker( + self, + name=broker_name, + default_flow=default_flow, + ) self.local_task_registry: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} # Every event has a list of handlers. # Every handler is a function which takes state as a first argument. @@ -128,23 +142,29 @@ def find_task(self, task_name: str) -> AsyncTaskiqDecoratedTask[Any, Any] | None It searches task by name in dict of tasks that were registered for this broker directly. - If it fails, it checks global dict of all available tasks. + If it fails, it checks tasks registered in the broker's router, + then global dict of all available tasks. :param task_name: name of a task. :returns: found task or None. """ - return self.local_task_registry.get( - task_name, - ) or self.global_task_registry.get( - task_name, + return ( + self.local_task_registry.get( + task_name, + ) + or self.router.find_task(task_name) + or self.global_task_registry.get( + task_name, + ) ) def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: """ Method to fetch all tasks available in broker. - This method returns all tasks, globally and locally - available in broker. With local tasks having higher priority. + This method returns all tasks globally, through the broker's router + and locally available in broker. With local tasks having higher + priority. So, if you have two tasks with the same name, one registered in global registry and one registered @@ -152,7 +172,21 @@ def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: :return: dict of all tasks. Keys are task names, values are tasks. """ - return {**self.global_task_registry, **self.local_task_registry} + return { + **self.global_task_registry, + **self.router.get_all_tasks(), + **self.local_task_registry, + } + + def get_subscribed_flows(self) -> tuple[FlowProtocol, ...]: + """ + Return flows this broker should subscribe to. + + Existing brokers can keep their current `listen` implementation. New + flow-aware brokers may use this method to configure queue, topic, + stream or subject subscriptions from router-owned rules. + """ + return self.router.get_broker_flows(self) def add_dependency_context(self, new_ctx: dict[Any, Any]) -> None: """ @@ -237,6 +271,23 @@ async def kick( :param message: name of a task. """ + async def kick_to_flow( + self, + message: BrokerMessage, + flow: FlowProtocol | None = None, + ) -> None: + """ + Send message to a flow-aware broker. + + Existing brokers can keep implementing only `kick`. New brokers may + override this method and use `flow` to route to a concrete queue, topic, + stream or any other transport address. + + :param message: message to send. + :param flow: optional transport-neutral flow. + """ + await self.kick(message) + @abstractmethod def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: """ @@ -312,40 +363,12 @@ def make_decorated_task( def inner( func: Callable[_FuncParams, _ReturnType], ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: - nonlocal inner_task_name - if inner_task_name is None: - fmodule = func.__module__ - if fmodule == "__main__": # pragma: no cover - fmodule = ".".join( - os.path.normpath(sys.argv[0]) - .removesuffix(".py") - .split(os.path.sep), - ) - fname = func.__name__ - if fname == "": - fname = f"lambda_{uuid4().hex}" - inner_task_name = f"{fmodule}:{fname}" - wrapper = wraps(func) - - sign = get_type_hints(func) - return_type = None - if "return" in sign: - return_type = sign["return"] - - decorated_task = wrapper( - self.decorator_class( - broker=self, - original_func=func, - labels=inner_labels, - task_name=inner_task_name, - return_type=return_type, # type: ignore - ), + return self._decorate_task( + func, + task_name=inner_task_name, + labels=inner_labels, ) - self._register_task(decorated_task.task_name, decorated_task) # type: ignore - - return decorated_task # type: ignore - return inner if callable(task_name): @@ -362,7 +385,10 @@ def inner( def register_task( self, - func: Callable[_FuncParams, _ReturnType], + func: ( + Callable[_FuncParams, _ReturnType] + | TaskDefinition[_FuncParams, _ReturnType] + ), task_name: str | None = None, **labels: Any, ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: @@ -380,7 +406,101 @@ def register_task( :returns: registered task. """ - return self.task(task_name=task_name, **labels)(func) + if isinstance(func, TaskDefinition): + if task_name is not None or labels: + raise ValueError( + "TaskDefinition already defines task_name and labels. " + "Register it without task_name or label overrides.", + ) + return self.router.register_task( + func, + broker=self, + flow=self.default_flow, + ) + return self._decorate_task( + func, + task_name=task_name, + labels=labels or {}, + ) + + def bind_task_definition( + self, + task: TaskDefinition[_FuncParams, _ReturnType], + *, + register: bool = True, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Bind an unbound task definition to this broker.""" + task_cls = AsyncTaskiqDecoratedTask if task.base_cls is None else task.base_cls + return self._decorate_task( + task.original_func, + task_name=task.task_name, + labels=dict(task.labels), + base_cls=task_cls, + register=register, + ) + + def store_registered_task( + self, + task: AsyncTaskiqDecoratedTask[Any, Any], + ) -> None: + """ + Store a task after router-owned registration accepted it. + + Application code should use `task` or `register_task`. This method is + the broker/router integration boundary for cases where the router owns + registration ordering and the broker owns local task storage. + """ + self._store_task(task.task_name, task) + + def _decorate_task( + self, + func: Callable[_FuncParams, _ReturnType], + task_name: str | None, + labels: dict[str, Any], + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, + *, + register: bool = True, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Build and register a decorated task object.""" + real_task_name = task_name + if real_task_name is None: + fmodule = func.__module__ + if fmodule == "__main__": # pragma: no cover + fmodule = ".".join( + os.path.normpath(sys.argv[0]) + .removesuffix(".py") + .split(os.path.sep), + ) + fname = func.__name__ + if fname == "": + fname = f"lambda_{uuid4().hex}" + real_task_name = f"{fmodule}:{fname}" + + task_cls = self.decorator_class if base_cls is None else base_cls + if not isinstance(task_cls, type) or not issubclass( + task_cls, + AsyncTaskiqDecoratedTask, + ): + raise TypeError("base_cls must be a subclass of AsyncTaskiqDecoratedTask.") + + sign = get_type_hints(func) + return_type = None + if "return" in sign: + return_type = sign["return"] + + decorated_task = wraps(func)( + task_cls( + broker=self, + original_func=func, + labels=labels, + task_name=real_task_name, + return_type=return_type, # type: ignore + ), + ) + + if register: + self._register_task(decorated_task.task_name, decorated_task) # type: ignore + return decorated_task # type: ignore def on_event(self, *events: TaskiqEvents) -> Callable[[EventHandler], EventHandler]: """ @@ -530,6 +650,27 @@ def _register_task( :param task_name: Name of a task. :param task: Decorated task. """ + if task.broker != self: + raise TaskBrokerMismatchError(broker=task.broker) + self.router.register_task( + task, + broker=self, + flow=self.default_flow, + ) + self._store_task(task_name, task) + + def _store_task( + self, + task_name: str, + task: AsyncTaskiqDecoratedTask[Any, Any], + ) -> None: + """ + Store a decorated task in this broker. + + Router-managed binding uses this method after router state has accepted + the task, so shared task registration does not re-enter router + registration while it is already in progress. + """ if task.broker != self: raise TaskBrokerMismatchError(broker=task.broker) self.local_task_registry[task_name] = task diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 0a7cc98e..e63a97f6 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -9,8 +9,10 @@ from taskiq.depends.progress_tracker import TaskProgress from taskiq.events import TaskiqEvents from taskiq.exceptions import UnknownTaskError +from taskiq.flow import FlowProtocol from taskiq.message import BrokerMessage from taskiq.receiver import Receiver +from taskiq.router import TaskiqRouter from taskiq.utils import maybe_awaitable _ReturnType = TypeVar("_ReturnType") @@ -130,8 +132,16 @@ def __init__( max_async_tasks_jitter: int = 0, propagate_exceptions: bool = True, await_inplace: bool = False, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: FlowProtocol | None = None, ) -> None: - super().__init__() + super().__init__( + router=router, + broker_name=broker_name, + default_flow=default_flow, + ) self.result_backend: InmemoryResultBackend[Any] = InmemoryResultBackend( max_stored_results=max_stored_results, ) diff --git a/taskiq/context.py b/taskiq/context.py index f9f1d0ee..d29b8cb3 100644 --- a/taskiq/context.py +++ b/taskiq/context.py @@ -30,7 +30,7 @@ async def requeue(self) -> None: requeue_count = int(self.message.labels.get("X-Taskiq-requeue", 0)) requeue_count += 1 self.message.labels["X-Taskiq-requeue"] = str(requeue_count) - await self.broker.kick(self.broker.formatter.dumps(self.message)) + await self.broker.router.requeue(self.message, broker=self.broker) raise NoResultError def reject(self) -> None: diff --git a/taskiq/flow.py b/taskiq/flow.py new file mode 100644 index 00000000..1e5c06b1 --- /dev/null +++ b/taskiq/flow.py @@ -0,0 +1,73 @@ +from collections.abc import Mapping +from dataclasses import dataclass, field, replace +from types import MappingProxyType +from typing import Protocol, runtime_checkable + +__all__ = ("Flow", "FlowIdentity", "FlowProtocol", "get_flow_identity") + + +@dataclass(frozen=True, slots=True) +class FlowIdentity: + """Stable identity for one logical flow in routing and listen plans.""" + + name: str + + +@runtime_checkable +class FlowProtocol(Protocol): + """Transport-neutral flow contract accepted by routers and brokers.""" + + @property + def name(self) -> str: + """Return transport-neutral flow name.""" + ... + + def broker_options(self) -> Mapping[str, object]: + """Return options relevant for a concrete broker implementation.""" + ... + + +@dataclass(frozen=True, slots=True) +class Flow: + """Generic transport-neutral flow address. + + Broker packages can provide their own flow objects that implement + FlowProtocol, for example RabbitQueue, KafkaTopic, NatsSubject or RedisQueue. + This generic value object is intentionally small and works as a common + fallback for brokers that only need a named address with optional metadata. + """ + + name: str + options: Mapping[str, object] = field( + default_factory=dict, + compare=False, + hash=False, + ) + + def __post_init__(self) -> None: + """Freeze options to keep Flow a stable value object.""" + object.__setattr__(self, "options", MappingProxyType(dict(self.options))) + + def with_options(self, **options: object) -> "Flow": + """Return the same flow with additional generic options.""" + return replace(self, options={**self.options, **options}) + + @property + def identity(self) -> FlowIdentity: + """Return routing and subscription identity for this flow.""" + return get_flow_identity(self) + + def broker_options(self) -> Mapping[str, object]: + """Return transport options for broker-specific implementations.""" + return dict(self.options) + + +def get_flow_identity(flow: FlowProtocol) -> FlowIdentity: + """ + Return stable identity for a flow. + + Flow identity is intentionally transport-neutral and based on the logical + flow name. Broker-specific declaration options are validated separately by + the subscription plan. + """ + return FlowIdentity(name=flow.name) diff --git a/taskiq/kicker.py b/taskiq/kicker.py index dc113a7e..8de79f90 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from collections.abc import Coroutine -from dataclasses import asdict, is_dataclass +from dataclasses import replace from datetime import datetime, timedelta from logging import getLogger from types import CoroutineType @@ -9,17 +11,16 @@ Generic, ParamSpec, TypeVar, - Union, + cast, overload, ) -from pydantic import BaseModel - from taskiq.abc.middleware import TaskiqMiddleware -from taskiq.compat import model_dump +from taskiq.abc.result_backend import AsyncResultBackend from taskiq.exceptions import SendTaskError -from taskiq.labels import prepare_label -from taskiq.message import TaskiqMessage +from taskiq.flow import FlowProtocol +from taskiq.message import TaskiqMessage, _build_taskiq_message +from taskiq.router import TaskiqRoute, TaskiqRouter from taskiq.scheduler.created_schedule import CreatedSchedule from taskiq.scheduler.scheduled_task import CronSpec, ScheduledTask from taskiq.task import AsyncTaskiqTask @@ -36,13 +37,41 @@ logger = getLogger("taskiq") +class PreparedKiq(Generic[_ReturnType]): + """Prepared task invocation that can be sent later.""" + + def __init__( + self, + kicker: AsyncKicker[..., _ReturnType], + message: TaskiqMessage, + broker: AsyncBroker, + route: TaskiqRoute | None, + flow: FlowProtocol | None, + ) -> None: + self.kicker = kicker + self.message = message + self.broker = broker + self.route = route + self.flow = flow + + async def kiq(self) -> AsyncTaskiqTask[_ReturnType]: + """Send prepared invocation.""" + return await self.kicker.kiq_message( + self.message, + broker=self.broker, + route=self.route, + flow=self.flow, + use_current_route=False, + ) + + class AsyncKicker(Generic[_FuncParams, _ReturnType]): """Class that used to modify data before sending it to broker.""" def __init__( self, task_name: str, - broker: "AsyncBroker", + broker: AsyncBroker, labels: dict[str, Any], return_type: type[_ReturnType] | None = None, ) -> None: @@ -52,11 +81,14 @@ def __init__( self.custom_task_id: str | None = None self.custom_schedule_id: str | None = None self.return_type = return_type + self.route: TaskiqRoute | None = None + self.route_flow: FlowProtocol | None = None + self._broker_overridden = False def with_labels( self, **labels: str | float, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Update function's labels before sending. @@ -69,7 +101,7 @@ def with_labels( def with_task_id( self, task_id: str | None, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Set task_id for current execution. @@ -85,7 +117,7 @@ def with_task_id( def with_schedule_id( self, schedule_id: str, - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Set schedule_id for current execution. @@ -97,8 +129,8 @@ def with_schedule_id( def with_broker( self, - broker: "AsyncBroker", - ) -> "AsyncKicker[_FuncParams, _ReturnType]": + broker: AsyncBroker, + ) -> AsyncKicker[_FuncParams, _ReturnType]: """ Replace broker for the function. @@ -109,11 +141,64 @@ def with_broker( :return: Kicker with new broker. """ self.broker = broker + self.route = None + self.route_flow = None + self._broker_overridden = True + return self + + def with_flow( + self, + flow: FlowProtocol | None, + ) -> AsyncKicker[_FuncParams, _ReturnType]: + """ + Replace flow for the current invocation. + + :param flow: flow to send message to. + :return: Kicker with a route flow override. + """ + if self.route is not None: + self.route = replace(self.route, flow=flow) + self.route_flow = flow + return self + + def with_route( + self, + route: TaskiqRoute, + ) -> AsyncKicker[_FuncParams, _ReturnType]: + """ + Replace route for the current invocation. + + :param route: route to send message through. + :return: Kicker with a route override. + """ + self.route = route + self.route_flow = route.flow return self + def prepare( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> PreparedKiq[_ReturnType]: + """ + Prepare a task invocation without sending it. + + :param args: function's arguments. + :param kwargs: function's key word arguments. + :return: prepared task invocation. + """ + broker, route, flow = self._prepare_route_snapshot() + return PreparedKiq( + self, + self._prepare_message(*args, **kwargs), + broker=broker, + route=route, + flow=flow, + ) + @overload async def kiq( - self: "AsyncKicker[_FuncParams, CoroutineType[Any, Any, _T]]", + self: AsyncKicker[_FuncParams, CoroutineType[Any, Any, _T]], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_T]: # pragma: no cover @@ -121,7 +206,7 @@ async def kiq( @overload async def kiq( - self: "AsyncKicker[_FuncParams, Coroutine[Any, Any, _T]]", + self: AsyncKicker[_FuncParams, Coroutine[Any, Any, _T]], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_T]: # pragma: no cover @@ -129,7 +214,7 @@ async def kiq( @overload async def kiq( - self: "AsyncKicker[_FuncParams, _ReturnType]", + self: AsyncKicker[_FuncParams, _ReturnType], *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> AsyncTaskiqTask[_ReturnType]: # pragma: no cover @@ -156,29 +241,113 @@ async def kiq( logger.debug( f"Kicking {self.task_name} with args={args} and kwargs={kwargs}.", ) - message = self._prepare_message(*args, **kwargs) - for middleware in self.broker.middlewares: - if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: - message = await maybe_awaitable(middleware.pre_send(message)) + return await self.kiq_message(self._prepare_message(*args, **kwargs)) + + async def kiq_message( + self, + message: TaskiqMessage, + *, + broker: AsyncBroker | None = None, + route: TaskiqRoute | None = None, + flow: FlowProtocol | None = None, + use_current_route: bool = True, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send a prepared message.""" try: - await self.broker.kick(self.broker.formatter.dumps(message)) + target_broker = broker or self.broker + if use_current_route: + target_route = self.route if route is None else route + target_flow = self.route_flow if flow is None else flow + else: + target_route = route + target_flow = flow + if target_route is not None: + if broker is not None and broker is not target_route.broker: + raise ValueError("Pass either route or broker override.") + target_broker = target_route.broker + target_flow = None + router = getattr(target_broker, "router", None) + if isinstance(router, TaskiqRouter): + broker_override = None + if target_route is None and ( + broker is not None or self._broker_overridden + ): + broker_override = target_broker + return await router.kiq( + message, + route=target_route, + broker=broker_override, + flow=target_flow, + return_type=self.return_type, + ) + return await self._legacy_kiq(target_broker, message) except Exception as exc: raise SendTaskError from exc - for middleware in reversed(self.broker.middlewares): + async def _legacy_kiq( + self, + broker: AsyncBroker, + message: TaskiqMessage, + ) -> AsyncTaskiqTask[_ReturnType]: + """ + Send message through the pre-router broker path. + + This keeps middleware tests and external broker-like mocks compatible + while real AsyncBroker instances use TaskiqRouter. + """ + middlewares = getattr(broker, "middlewares", []) + if not isinstance(middlewares, list): + middlewares = [] + + for middleware in middlewares: + if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: + message = await maybe_awaitable(middleware.pre_send(message)) + + await broker.kick(broker.formatter.dumps(message)) + + for middleware in reversed(middlewares): if middleware.__class__.post_send != TaskiqMiddleware.post_send: await maybe_awaitable(middleware.post_send(message)) return AsyncTaskiqTask( task_id=message.task_id, - result_backend=self.broker.result_backend, - return_type=self.return_type, # type: ignore # (pyright issue) + result_backend=cast( + AsyncResultBackend[_ReturnType], + broker.result_backend, + ), + return_type=self.return_type, + ) + + def _prepare_route_snapshot( + self, + ) -> tuple[AsyncBroker, TaskiqRoute | None, FlowProtocol | None]: + """ + Resolve route state that a prepared invocation must keep. + + The returned tuple is `(broker, route, explicit_flow_override)`. + When a route is present, the route already carries its flow, so the + separate flow override is `None`. A separate flow is returned only for + legacy broker paths that cannot snapshot a `TaskiqRoute`. + """ + router = getattr(self.broker, "router", None) + if not isinstance(router, TaskiqRouter): + return self.broker, None, self.route_flow + + if self.route is not None: + return self.route.broker, self.route, None + + broker_override = self.broker if self._broker_overridden else None + route = router.resolve_route( + self.task_name, + broker=broker_override, + flow=self.route_flow, ) + return route.broker, route, None async def schedule_by_cron( self, - source: "ScheduleSource", - cron: Union[str, "CronSpec"], + source: ScheduleSource, + cron: str | CronSpec, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, ) -> CreatedSchedule[_ReturnType]: @@ -217,7 +386,7 @@ async def schedule_by_cron( async def schedule_by_interval( self, - source: "ScheduleSource", + source: ScheduleSource, interval: int | timedelta, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, @@ -249,7 +418,7 @@ async def schedule_by_interval( async def schedule_by_time( self, - source: "ScheduleSource", + source: ScheduleSource, time: datetime, *args: _FuncParams.args, **kwargs: _FuncParams.kwargs, @@ -278,27 +447,6 @@ async def schedule_by_time( await source.add_schedule(scheduled) return CreatedSchedule(self, source, scheduled) - @classmethod - def _prepare_arg(cls, arg: Any) -> Any: - """ - Parses argument if possible. - - This function is used to construct dicts - from pydantic models or dataclasses. - - :param arg: argument to format. - :return: Formatted argument. - """ - if isinstance(arg, BaseModel): - arg = model_dump(arg) - if is_dataclass(arg): - if isinstance(arg, type): - raise ValueError( - f"Cannot serialize types. The {arg} is not serializable.", - ) - arg = asdict(arg) - return arg - def _prepare_message( self, *args: Any, @@ -311,27 +459,14 @@ def _prepare_message( :param kwargs: function's kwargs. :return: constructed message. """ - formatted_args = [] - formatted_kwargs = {} - labels = {} - labels_types = {} - for arg in args: - formatted_args.append(self._prepare_arg(arg)) - for kwarg_name, kwarg_val in kwargs.items(): - formatted_kwargs[kwarg_name] = self._prepare_arg(kwarg_val) - - for label, label_val in self.labels.items(): - labels[label], labels_types[label] = prepare_label(label_val) - task_id = self.custom_task_id if task_id is None: task_id = self.broker.id_generator() - return TaskiqMessage( + return _build_taskiq_message( task_id=task_id, task_name=self.task_name, - labels=labels, - labels_types=labels_types, - args=formatted_args, - kwargs=formatted_kwargs, + labels=self.labels, + args=args, + kwargs=kwargs, ) diff --git a/taskiq/message.py b/taskiq/message.py index 39008eb8..00ed4791 100644 --- a/taskiq/message.py +++ b/taskiq/message.py @@ -1,8 +1,11 @@ +from collections.abc import Iterable, Mapping +from dataclasses import asdict, is_dataclass from typing import Any from pydantic import BaseModel -from taskiq.labels import parse_label +from taskiq.compat import model_dump +from taskiq.labels import parse_label, prepare_label class TaskiqMessage(BaseModel): @@ -35,6 +38,61 @@ def parse_labels(self) -> None: self.labels[label] = parse_label(self.labels[label], label_type) +def _prepare_message_arg(arg: Any) -> Any: + """ + Prepare invocation argument for message payloads. + + :param arg: argument to prepare. + :return: serializable argument representation. + """ + if isinstance(arg, BaseModel): + arg = model_dump(arg) + if is_dataclass(arg): + if isinstance(arg, type): + raise ValueError( + f"Cannot serialize types. The {arg} is not serializable.", + ) + arg = asdict(arg) + return arg + + +def _build_taskiq_message( + task_id: str, + task_name: str, + labels: Mapping[str, Any], + args: Iterable[Any], + kwargs: Mapping[str, Any], +) -> TaskiqMessage: + """ + Build a taskiq message using the common invocation contract. + + :param task_id: task id. + :param task_name: task name. + :param labels: task invocation labels. + :param args: task positional arguments. + :param kwargs: task keyword arguments. + :return: prepared taskiq message. + """ + prepared_labels: dict[str, Any] = {} + labels_types: dict[str, int] = {} + for label_name, label_value in labels.items(): + prepared_labels[label_name], labels_types[label_name] = prepare_label( + label_value, + ) + + return TaskiqMessage( + task_id=task_id, + task_name=task_name, + labels=prepared_labels, + labels_types=labels_types, + args=[_prepare_message_arg(arg) for arg in args], + kwargs={ + kwarg_name: _prepare_message_arg(kwarg_value) + for kwarg_name, kwarg_value in kwargs.items() + }, + ) + + class BrokerMessage(BaseModel): """Format of messages for brokers.""" diff --git a/taskiq/router.py b/taskiq/router.py new file mode 100644 index 00000000..8b526066 --- /dev/null +++ b/taskiq/router.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import warnings +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, overload + +from taskiq.flow import FlowProtocol +from taskiq.message import TaskiqMessage +from taskiq.routing import ( + BrokerRegistry, + RouterDispatcher, + RouteRegistry, + SubscriptionPlan, + TaskiqRoute, + TaskiqSubscription, + TaskRegistry, +) +from taskiq.routing.references import resolve_task_name +from taskiq.task import AsyncTaskiqTask +from taskiq.task_builder import TaskDefinition +from taskiq.warnings import TaskiqDeprecationWarning + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("TaskiqRoute", "TaskiqRouter", "TaskiqSubscription") + +TaskiqRoute.__module__ = __name__ +TaskiqSubscription.__module__ = __name__ + +_FuncParams = ParamSpec("_FuncParams") +_ReturnType = TypeVar("_ReturnType") + + +class TaskiqRouter: + """Facade for task registry, routing policy, subscriptions and dispatch.""" + + def __init__(self) -> None: + self._brokers = BrokerRegistry() + self._tasks = TaskRegistry() + self._routes = RouteRegistry(self._brokers) + self._subscriptions = SubscriptionPlan(self._brokers) + self._dispatcher = RouterDispatcher(self._routes) + + @property + def brokers(self) -> dict[str, AsyncBroker]: + """Return mutable broker registry for compatibility.""" + return self._brokers.brokers + + @property + def default_broker(self) -> AsyncBroker | None: + """Return default broker for compatibility.""" + return self._brokers.default_broker + + @default_broker.setter + def default_broker(self, broker: AsyncBroker | None) -> None: + if broker is not None: + self._brokers.resolve(broker) + self._brokers.default_broker = broker + + @property + def task_registry(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: + """Return mutable task registry for compatibility.""" + return self._tasks.tasks + + @property + def routes(self) -> dict[str, TaskiqRoute]: + """Return mutable route registry for compatibility.""" + return self._routes.routes + + @property + def subscriptions(self) -> list[TaskiqSubscription]: + """Return mutable subscription registry for compatibility.""" + return self._subscriptions.subscriptions + + @property + def default_broker_name(self) -> str | None: + """Return default broker name for compatibility and diagnostics.""" + return self._brokers.default_broker_name + + def set_broker( + self, + broker: AsyncBroker, + name: str | None = None, + default_flow: FlowProtocol | None = None, + ) -> str: + """Register broker as a transport in this router.""" + if getattr(broker, "router", self) is not self: + raise ValueError( + "Broker is attached to another router. " + "Pass router=... when creating the broker.", + ) + broker_name = self._brokers.register(broker, name=name) + if default_flow is not None: + broker.default_flow = default_flow + return broker_name + + def get_broker(self, name: str) -> AsyncBroker: + """Return a broker by registered name.""" + return self._brokers.get(name) + + def find_task( + self, + task_name: str, + ) -> AsyncTaskiqDecoratedTask[Any, Any] | None: + """Find a task by name.""" + return self._tasks.find(task_name) + + def get_all_tasks(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: + """Return all tasks registered in this router.""" + return self._tasks.get_all() + + def register_task( + self, + task: ( + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType] + | TaskDefinition[_FuncParams, _ReturnType] + ), + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + """Register a bound task or bind a task definition to a broker.""" + if isinstance(task, TaskDefinition): + target_broker = self._brokers.resolve(broker) + registered_task = target_broker.bind_task_definition( + task, + register=False, + ) + self._register_bound_task( + registered_task, + broker=target_broker, + flow=flow, + ) + target_broker.store_registered_task(registered_task) + return registered_task + + return self._register_bound_task(task, broker=broker, flow=flow) + + @overload + def task( + self, + task_name: Callable[_FuncParams, _ReturnType], + *, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + **labels: Any, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: ... + + @overload + def task( + self, + task_name: str | None = None, + *, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + **labels: Any, + ) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], + ]: ... + + def task( + self, + task_name: str | Callable[_FuncParams, _ReturnType] | None = None, + *, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + **labels: Any, + ) -> Any: + """Decorate and register a task through this router.""" + + def register( + func: Callable[_FuncParams, _ReturnType], + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + target_broker = self._brokers.resolve(broker) + real_task_name: str | None = None if callable(task_name) else task_name + task = target_broker.task(task_name=real_task_name, **labels)(func) + if flow is not None: + self.route_task(task, broker=target_broker, flow=flow) + return task + + if callable(task_name): + function = task_name + return register(function) + + return register + + def route_task( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + *, + subscribe: bool = False, + ) -> TaskiqRoute: + """Set default outbound route for a task.""" + task_name = resolve_task_name(task) + route = self._routes.set_route(task_name, broker=broker, flow=flow) + if subscribe: + warnings.warn( + "`route_task(..., subscribe=True)` is deprecated. " + "Use `router.subscribe(...)` to add inbound flow subscriptions.", + TaskiqDeprecationWarning, + stacklevel=2, + ) + if subscribe and route.flow is not None: + self.subscribe(route.broker, route.flow, task_name) + return route + + def resolve_route( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> TaskiqRoute: + """Resolve outbound route for a task invocation.""" + return self._routes.resolve_route(task, broker=broker, flow=flow) + + def subscribe( + self, + broker: AsyncBroker, + flow: FlowProtocol, + *tasks: str | AsyncTaskiqDecoratedTask[Any, Any], + ) -> TaskiqSubscription: + """Register an inbound flow subscription for a broker.""" + task_names = tuple(resolve_task_name(task) for task in tasks) + return self._subscriptions.subscribe( + broker, + flow, + task_names, + ) + + def get_subscriptions( + self, + broker: AsyncBroker | None = None, + ) -> tuple[TaskiqSubscription, ...]: + """Return registered inbound subscriptions.""" + return self._subscriptions.get(broker) + + def get_broker_flows(self, broker: AsyncBroker) -> tuple[FlowProtocol, ...]: + """Return flows a broker should subscribe to.""" + return self._subscriptions.get_broker_flows(broker) + + async def kiq( + self, + message: TaskiqMessage, + *, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + return_type: type[_ReturnType] | None = None, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send message through the resolved broker and flow.""" + return await self._dispatcher.kiq( + message, + route=route, + broker=broker, + flow=flow, + return_type=return_type, + ) + + async def requeue( + self, + message: TaskiqMessage, + *, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> None: + """Send an existing message again through the resolved route.""" + await self._dispatcher.requeue( + message, + route=route, + broker=broker, + flow=flow, + ) + + def _register_bound_task( + self, + task: AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: + route_broker = broker + if route_broker is None: + route_broker = getattr(task, "broker", None) + + route = None + if route_broker is not None or flow is not None: + route = self._routes.build_route(broker=route_broker, flow=flow) + + self._tasks.register(task) + if route is not None: + self._routes.set_resolved_route(task.task_name, route) + return task diff --git a/taskiq/routing/__init__.py b/taskiq/routing/__init__.py new file mode 100644 index 00000000..09dfe0dc --- /dev/null +++ b/taskiq/routing/__init__.py @@ -0,0 +1,17 @@ +"""Internal routing components used by the public router facade.""" + +from taskiq.routing.dispatcher import RouterDispatcher +from taskiq.routing.models import TaskiqRoute, TaskiqSubscription +from taskiq.routing.registries import BrokerRegistry, TaskRegistry +from taskiq.routing.routes import RouteRegistry +from taskiq.routing.subscriptions import SubscriptionPlan + +__all__ = ( + "BrokerRegistry", + "RouteRegistry", + "RouterDispatcher", + "SubscriptionPlan", + "TaskRegistry", + "TaskiqRoute", + "TaskiqSubscription", +) diff --git a/taskiq/routing/dispatcher.py b/taskiq/routing/dispatcher.py new file mode 100644 index 00000000..d5b1ad43 --- /dev/null +++ b/taskiq/routing/dispatcher.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar, cast + +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.flow import FlowProtocol +from taskiq.message import TaskiqMessage +from taskiq.routing.models import TaskiqRoute +from taskiq.routing.routes import RouteRegistry +from taskiq.task import AsyncTaskiqTask +from taskiq.utils import maybe_awaitable + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + +__all__ = ("RouterDispatcher",) + +_ReturnType = TypeVar("_ReturnType") + + +class RouterDispatcher: + """Send task messages through resolved router routes.""" + + def __init__(self, routes: RouteRegistry) -> None: + self.routes = routes + + async def kiq( + self, + message: TaskiqMessage, + *, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + return_type: type[_ReturnType] | None = None, + ) -> AsyncTaskiqTask[_ReturnType]: + """Send message through the resolved broker and flow.""" + target_route = self.routes.resolve_explicit_or_default_route( + message.task_name, + route=route, + broker=broker, + flow=flow, + ) + target_broker = target_route.broker + + for middleware in target_broker.middlewares: + if middleware.__class__.pre_send != TaskiqMiddleware.pre_send: + message = await maybe_awaitable(middleware.pre_send(message)) + broker_message = target_broker.formatter.dumps(message) + await target_broker.kick_to_flow(broker_message, target_route.flow) + + for middleware in reversed(target_broker.middlewares): + if middleware.__class__.post_send != TaskiqMiddleware.post_send: + await maybe_awaitable(middleware.post_send(message)) + + return AsyncTaskiqTask( + task_id=message.task_id, + result_backend=cast( + AsyncResultBackend[_ReturnType], + target_broker.result_backend, + ), + return_type=return_type, + ) + + async def requeue( + self, + message: TaskiqMessage, + *, + route: TaskiqRoute | None = None, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> None: + """Send an existing message again through the resolved route.""" + target_route = self.routes.resolve_explicit_or_default_route( + message.task_name, + route=route, + broker=broker, + flow=flow, + ) + target_broker = target_route.broker + await target_broker.kick_to_flow( + target_broker.formatter.dumps(message), + target_route.flow, + ) diff --git a/taskiq/routing/models.py b/taskiq/routing/models.py new file mode 100644 index 00000000..2446aadd --- /dev/null +++ b/taskiq/routing/models.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from taskiq.flow import FlowProtocol + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + +__all__ = ("TaskiqRoute", "TaskiqSubscription") + + +@dataclass(frozen=True, slots=True) +class TaskiqRoute: + """Resolved outbound route for a task invocation.""" + + broker: AsyncBroker + flow: FlowProtocol | None = None + + @property + def broker_name(self) -> str: + """Return registered broker name for diagnostics.""" + return self.broker.broker_name + + +@dataclass(frozen=True, slots=True) +class TaskiqSubscription: + """ + Inbound flow subscription owned by a router. + + `task_names` is diagnostic listen-plan metadata. It records which task + declarations caused a broker to listen to this flow; worker execution still + resolves inbound messages by `TaskiqMessage.task_name`. + """ + + broker: AsyncBroker + flow: FlowProtocol + task_names: frozenset[str] = field(default_factory=frozenset) + + @property + def broker_name(self) -> str: + """Return registered broker name for diagnostics.""" + return self.broker.broker_name diff --git a/taskiq/routing/references.py b/taskiq/routing/references.py new file mode 100644 index 00000000..f43270bc --- /dev/null +++ b/taskiq/routing/references.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("resolve_task_name",) + + +def resolve_task_name(task: str | AsyncTaskiqDecoratedTask[Any, Any]) -> str: + """Resolve a task reference to a task name.""" + if isinstance(task, str): + return task + task_name = getattr(task, "task_name", None) + if isinstance(task_name, str): + return task_name + raise TypeError("Route task must be a task name or decorated task.") diff --git a/taskiq/routing/registries.py b/taskiq/routing/registries.py new file mode 100644 index 00000000..9226cdaa --- /dev/null +++ b/taskiq/routing/registries.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskiq.flow import FlowProtocol + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("BrokerRegistry", "TaskRegistry") + + +class BrokerRegistry: + """Registry of brokers attached to one router.""" + + def __init__(self) -> None: + self.brokers: dict[str, AsyncBroker] = {} + self.default_broker: AsyncBroker | None = None + + @property + def default_broker_name(self) -> str | None: + """Return default broker name for compatibility and diagnostics.""" + if self.default_broker is None: + return None + return self.default_broker.broker_name + + def register(self, broker: AsyncBroker, name: str | None = None) -> str: + """Register broker as a transport in this registry.""" + broker_name = name or broker.__class__.__name__ + registered = self.brokers.get(broker_name) + if registered is not None and registered is not broker: + raise ValueError( + f"Broker name {broker_name!r} is already registered. " + "Please provide an explicit unique broker_name.", + ) + for registered_name, registered_broker in self.brokers.items(): + if registered_broker is broker and registered_name != broker_name: + raise ValueError( + f"Broker is already registered as {registered_name!r}.", + ) + + self.brokers[broker_name] = broker + if self.default_broker is None: + self.default_broker = broker + return broker_name + + def get(self, name: str) -> AsyncBroker: + """Return a broker by registered name.""" + try: + return self.brokers[name] + except KeyError as exc: + raise ValueError(f"Unknown broker {name!r}.") from exc + + def resolve(self, broker: AsyncBroker | None) -> AsyncBroker: + """Resolve an explicit broker or return the default broker.""" + if isinstance(broker, str): + raise TypeError( + "Broker string references are not accepted here. " + "Use router.get_broker(name) and pass the broker object.", + ) + + if broker is not None: + for registered_broker in self.brokers.values(): + if registered_broker is broker: + return registered_broker + raise ValueError("Broker is not registered in this router.") + + if self.default_broker is None: + raise ValueError("Router doesn't have registered brokers.") + return self.default_broker + + def default_flow(self, broker: AsyncBroker) -> FlowProtocol | None: + """Return broker-owned default flow.""" + return getattr(broker, "default_flow", None) + + +class TaskRegistry: + """Registry of tasks attached to one router.""" + + def __init__(self) -> None: + self.tasks: dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {} + + def register( + self, + task: AsyncTaskiqDecoratedTask[Any, Any], + ) -> None: + """Register a decorated task and reject name conflicts.""" + existing_task = self.tasks.get(task.task_name) + if existing_task is not None and existing_task is not task: + raise ValueError( + f"Task name {task.task_name!r} is already registered " + "in this router.", + ) + self.tasks[task.task_name] = task + + def find( + self, + task_name: str, + ) -> AsyncTaskiqDecoratedTask[Any, Any] | None: + """Find a task by name.""" + return self.tasks.get(task_name) + + def get_all(self) -> dict[str, AsyncTaskiqDecoratedTask[Any, Any]]: + """Return all tasks registered in this registry.""" + return dict(self.tasks) diff --git a/taskiq/routing/routes.py b/taskiq/routing/routes.py new file mode 100644 index 00000000..96a866f8 --- /dev/null +++ b/taskiq/routing/routes.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING, Any + +from taskiq.flow import FlowProtocol +from taskiq.routing.models import TaskiqRoute +from taskiq.routing.references import resolve_task_name +from taskiq.routing.registries import BrokerRegistry + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("RouteRegistry",) + + +class RouteRegistry: + """Outbound route policy for task invocations.""" + + def __init__(self, brokers: BrokerRegistry) -> None: + self.brokers = brokers + self.routes: dict[str, TaskiqRoute] = {} + + def build_route( + self, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> TaskiqRoute: + """Build a route without mutating route state.""" + target_broker = self.brokers.resolve(broker) + route_flow = flow + if route_flow is None: + route_flow = self.brokers.default_flow(target_broker) + return TaskiqRoute(broker=target_broker, flow=route_flow) + + def set_route( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> TaskiqRoute: + """Set default outbound route for a task.""" + task_name = resolve_task_name(task) + route = self.build_route(broker=broker, flow=flow) + self.routes[task_name] = route + return route + + def set_resolved_route(self, task_name: str, route: TaskiqRoute) -> TaskiqRoute: + """Store an already resolved route for a task.""" + self.routes[task_name] = route + return route + + def resolve_route( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> TaskiqRoute: + """Resolve outbound route for a task invocation.""" + task_name = resolve_task_name(task) + if broker is not None: + target_broker = self.brokers.resolve(broker) + route_flow = self._resolve_flow_for_broker_override( + task_name, + target_broker, + flow, + ) + return TaskiqRoute(broker=target_broker, flow=route_flow) + + route = self.routes.get(task_name) + if route is not None: + if flow is None: + return route + return replace(route, flow=flow) + + return self.build_route(flow=flow) + + def resolve_explicit_or_default_route( + self, + task: str | AsyncTaskiqDecoratedTask[Any, Any], + *, + route: TaskiqRoute | None, + broker: AsyncBroker | None, + flow: FlowProtocol | None, + ) -> TaskiqRoute: + """Resolve a route from explicit route or broker/flow overrides.""" + if route is not None: + if broker is not None or flow is not None: + raise ValueError("Pass either route or broker/flow overrides.") + self.brokers.resolve(route.broker) + return route + return self.resolve_route(task, broker=broker, flow=flow) + + def _resolve_flow_for_broker_override( + self, + task_name: str, + broker: AsyncBroker, + flow: FlowProtocol | None, + ) -> FlowProtocol | None: + if flow is not None: + return flow + registered_route = self.routes.get(task_name) + if registered_route is not None and registered_route.broker is broker: + return registered_route.flow + return self.brokers.default_flow(broker) diff --git a/taskiq/routing/subscriptions.py b/taskiq/routing/subscriptions.py new file mode 100644 index 00000000..7f38ba2d --- /dev/null +++ b/taskiq/routing/subscriptions.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import replace +from typing import TYPE_CHECKING + +from taskiq.flow import FlowProtocol, get_flow_identity +from taskiq.routing.models import TaskiqSubscription +from taskiq.routing.registries import BrokerRegistry + +if TYPE_CHECKING: # pragma: no cover + from taskiq.abc.broker import AsyncBroker + +__all__ = ("SubscriptionPlan",) + + +class SubscriptionPlan: + """ + Inbound listen-plan subscriptions for flow-aware brokers. + + This is mutable setup-time configuration, not a thread-safe runtime + coordination primitive. + """ + + def __init__(self, brokers: BrokerRegistry) -> None: + self.brokers = brokers + self.subscriptions: list[TaskiqSubscription] = [] + + def subscribe( + self, + broker: AsyncBroker, + flow: FlowProtocol, + task_names: Iterable[str], + ) -> TaskiqSubscription: + """Register an inbound flow subscription for a broker.""" + target_broker = self.brokers.resolve(broker) + resolved_task_names = frozenset(task_names) + default_flow = self.brokers.default_flow(target_broker) + if default_flow is not None: + self._ensure_compatible_flow( + broker=target_broker, + registered_flow=default_flow, + new_flow=flow, + ) + + for subscription_index, subscription in enumerate(self.subscriptions): + if subscription.broker is target_broker and self._is_same_flow( + subscription.flow, + flow, + ): + self._ensure_compatible_flow( + broker=target_broker, + registered_flow=subscription.flow, + new_flow=flow, + ) + updated = replace( + subscription, + task_names=subscription.task_names | resolved_task_names, + ) + self.subscriptions[subscription_index] = updated + return updated + + subscription = TaskiqSubscription( + broker=target_broker, + flow=flow, + task_names=resolved_task_names, + ) + self.subscriptions.append(subscription) + return subscription + + def get( + self, + broker: AsyncBroker | None = None, + ) -> tuple[TaskiqSubscription, ...]: + """Return registered inbound subscriptions.""" + if broker is None: + return tuple(self.subscriptions) + target_broker = self.brokers.resolve(broker) + return tuple( + subscription + for subscription in self.subscriptions + if subscription.broker is target_broker + ) + + def get_broker_flows(self, broker: AsyncBroker) -> tuple[FlowProtocol, ...]: + """Return flows a broker should subscribe to.""" + target_broker = self.brokers.resolve(broker) + flows = [subscription.flow for subscription in self.get(target_broker)] + default_flow = self.brokers.default_flow(target_broker) + if default_flow is not None and not self._contains_flow_identity( + flows, + default_flow, + ): + flows.insert(0, default_flow) + elif default_flow is not None: + for flow in flows: + if self._is_same_flow(flow, default_flow): + self._ensure_compatible_flow( + broker=target_broker, + registered_flow=flow, + new_flow=default_flow, + ) + break + return tuple(flows) + + def _contains_flow_identity( + self, + flows: Iterable[FlowProtocol], + flow: FlowProtocol, + ) -> bool: + return any( + self._is_same_flow(registered_flow, flow) for registered_flow in flows + ) + + def _is_same_flow( + self, + first_flow: FlowProtocol, + second_flow: FlowProtocol, + ) -> bool: + return get_flow_identity(first_flow) == get_flow_identity(second_flow) + + def _ensure_compatible_flow( + self, + *, + broker: AsyncBroker, + registered_flow: FlowProtocol, + new_flow: FlowProtocol, + ) -> None: + if not self._is_same_flow(registered_flow, new_flow): + return + + registered_options = dict(registered_flow.broker_options()) + new_options = dict(new_flow.broker_options()) + if registered_options == new_options: + return + + flow_identity = get_flow_identity(new_flow) + raise ValueError( + f"Flow {flow_identity.name!r} is already registered for broker " + f"{broker.broker_name!r} with different broker options.", + ) diff --git a/taskiq/scheduler/scheduler.py b/taskiq/scheduler/scheduler.py index 16677554..b0a1e00d 100644 --- a/taskiq/scheduler/scheduler.py +++ b/taskiq/scheduler/scheduler.py @@ -1,8 +1,9 @@ from logging import getLogger -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from taskiq.exceptions import ScheduledTaskCancelledError from taskiq.kicker import AsyncKicker +from taskiq.router import TaskiqRouter from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.utils import maybe_awaitable @@ -47,19 +48,34 @@ async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None: except ScheduledTaskCancelledError: logger.info("Scheduled task %s has been cancelled.", task.task_name) else: - await ( - AsyncKicker(task.task_name, self.broker, task.labels) - .with_labels( - schedule_id=task.schedule_id, - ) + kicker: AsyncKicker[Any, Any] = ( + AsyncKicker(task.task_name, self.broker, dict(task.labels)) + .with_labels(schedule_id=task.schedule_id) .with_task_id(task_id=task.task_id) - .kiq( - *task.args, - **task.kwargs, - ) ) + self._apply_scheduled_route(kicker, task.task_name) + await kicker.kiq(*task.args, **task.kwargs) await maybe_awaitable(source.post_send(task)) async def shutdown(self) -> None: """Shutdown the scheduler process.""" await self.broker.shutdown() + + def _apply_scheduled_route( + self, + kicker: AsyncKicker[Any, Any], + task_name: str, + ) -> None: + """ + Apply scheduler dispatch routing without changing schedule payloads. + + Registered router routes are resolved at send time. If there is no + router route for the scheduled task, keep the old scheduler behavior and + send through the scheduler broker instead of the router default broker. + """ + router = getattr(self.broker, "router", None) + if isinstance(router, TaskiqRouter) and task_name in router.routes: + kicker.with_route(router.resolve_route(task_name)) + return + + kicker.with_broker(self.broker) diff --git a/taskiq/task_builder.py b/taskiq/task_builder.py new file mode 100644 index 00000000..eeec6ee2 --- /dev/null +++ b/taskiq/task_builder.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import inspect +from collections.abc import Callable, Coroutine, Mapping +from dataclasses import dataclass, field +from types import CoroutineType, MappingProxyType +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, overload + +from taskiq.message import TaskiqMessage, _build_taskiq_message + +if TYPE_CHECKING: # pragma: no cover + from taskiq.decor import AsyncTaskiqDecoratedTask + +__all__ = ("TaskDefinition", "task_builder") + +_FuncParams = ParamSpec("_FuncParams") +_T = TypeVar("_T") +_ReturnType = TypeVar("_ReturnType") + + +@dataclass(frozen=True, slots=True) +class TaskDefinition(Generic[_FuncParams, _ReturnType]): + """Unbound task declaration that can be registered later.""" + + task_name: str + original_func: Callable[_FuncParams, _ReturnType] + labels: Mapping[str, Any] = field(default_factory=dict) + return_type: type[_ReturnType] | None = None + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None + + def __post_init__(self) -> None: + """Freeze labels to make task definitions stable declarations.""" + object.__setattr__(self, "labels", MappingProxyType(dict(self.labels))) + + def __call__( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _ReturnType: + """Call original function directly.""" + return self.original_func(*args, **kwargs) + + @overload + async def call( + self: TaskDefinition[_FuncParams, CoroutineType[Any, Any, _T]], + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _T: ... + + @overload + async def call( + self: TaskDefinition[_FuncParams, Coroutine[Any, Any, _T]], + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _T: ... + + @overload + async def call( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> _ReturnType: ... + + async def call( + self, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> Any: + """Execute original function in the current process.""" + result = self.original_func(*args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + def message( + self, + task_id: str, + *args: _FuncParams.args, + **kwargs: _FuncParams.kwargs, + ) -> TaskiqMessage: + """Build a TaskiqMessage without binding this definition to a router.""" + return _build_taskiq_message( + task_id=task_id, + task_name=self.task_name, + labels=self.labels, + args=args, + kwargs=kwargs, + ) + + +@overload +def task_builder( + task_name: Callable[_FuncParams, _ReturnType], + *, + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, + **labels: Any, +) -> TaskDefinition[_FuncParams, _ReturnType]: ... + + +@overload +def task_builder( + task_name: str | None = None, + *, + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, + **labels: Any, +) -> Callable[ + [Callable[_FuncParams, _ReturnType]], + TaskDefinition[_FuncParams, _ReturnType], +]: ... + + +def task_builder( + task_name: str | Callable[_FuncParams, _ReturnType] | None = None, + *, + base_cls: type[AsyncTaskiqDecoratedTask[Any, Any]] | None = None, + **labels: Any, +) -> Any: + """Build an unbound task definition. + + This decorator is intended for library/package tasks that should be + registered by the final application. + """ + + def build( + func: Callable[_FuncParams, _ReturnType], + ) -> TaskDefinition[_FuncParams, _ReturnType]: + real_task_name = task_name + if real_task_name is None or callable(real_task_name): + real_task_name = f"{func.__module__}:{func.__name__}" + return_type = None + signature = inspect.signature(func) + if signature.return_annotation is not inspect.Signature.empty: + return_type = signature.return_annotation + return TaskDefinition( + task_name=real_task_name, + original_func=func, + labels=dict(labels), + return_type=return_type, + base_cls=base_cls, + ) + + if callable(task_name): + function = task_name + task_name = None + return build(function) + + return build diff --git a/tests/routing/__init__.py b/tests/routing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/routing/models.py b/tests/routing/models.py new file mode 100644 index 00000000..e6a4c4f5 --- /dev/null +++ b/tests/routing/models.py @@ -0,0 +1,117 @@ +from collections.abc import AsyncGenerator, Mapping +from dataclasses import dataclass +from typing import Any + +from taskiq import AsyncTaskiqDecoratedTask, FlowProtocol, TaskiqRouter +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.message import BrokerMessage, TaskiqMessage +from taskiq.result import TaskiqResult + + +@dataclass(frozen=True, slots=True) +class BrokerQueue: + """Broker-specific flow used to prove protocol-based routing.""" + + name: str + durable: bool = True + + def broker_options(self) -> Mapping[str, object]: + """Return options for the target broker.""" + return { + "durable": self.durable, + } + + +class TracingTask(AsyncTaskiqDecoratedTask[Any, Any]): + """Custom task class used to prove task_builder base_cls binding.""" + + def tracing_name(self) -> str: + """Return a trace-friendly task name.""" + return self.task_name + + +class CountingRouter(TaskiqRouter): + """Router that counts task registration calls.""" + + def __init__(self) -> None: + self.register_task_calls = 0 + super().__init__() + + def register_task( + self, + task: Any, + broker: AsyncBroker | None = None, + flow: FlowProtocol | None = None, + ) -> Any: + """Count registration calls and delegate to the router.""" + self.register_task_calls += 1 + return super().register_task(task, broker=broker, flow=flow) + + +class OldStyleRecordingBroker(AsyncBroker): + """Broker that only implements the old required kick/listen API.""" + + def __init__( + self, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + ) -> None: + self.sent: list[BrokerMessage] = [] + super().__init__(router=router, broker_name=broker_name) + + async def kick(self, message: BrokerMessage) -> None: + """Record old-style send.""" + self.sent.append(message) + + async def listen(self) -> AsyncGenerator[bytes, None]: + """Old-style recording broker does not listen in tests.""" + if False: + yield b"" + + +class RecordingResultBackend(AsyncResultBackend[Any]): + """Result backend marker that records stored results.""" + + def __init__(self) -> None: + self.results: dict[str, TaskiqResult[Any]] = {} + + async def set_result(self, task_id: str, result: TaskiqResult[Any]) -> None: + """Store result by task id.""" + self.results[task_id] = result + + async def is_result_ready(self, task_id: str) -> bool: + """Return whether result exists.""" + return task_id in self.results + + async def get_result( + self, + task_id: str, + with_logs: bool = False, + ) -> TaskiqResult[Any]: + """Return stored result.""" + return self.results[task_id] + + +class RecordingMiddleware(TaskiqMiddleware): + """Middleware that records client-side send hooks.""" + + def __init__( + self, + name: str, + events: list[tuple[str, str, str]], + ) -> None: + super().__init__() + self.name = name + self.events = events + + def pre_send(self, message: TaskiqMessage) -> TaskiqMessage: + """Record pre-send call.""" + self.events.append((self.name, "pre_send", message.task_name)) + return message + + def post_send(self, message: TaskiqMessage) -> None: + """Record post-send call.""" + self.events.append((self.name, "post_send", message.task_name)) diff --git a/tests/routing/test_dispatch.py b/tests/routing/test_dispatch.py new file mode 100644 index 00000000..0317f4e3 --- /dev/null +++ b/tests/routing/test_dispatch.py @@ -0,0 +1,324 @@ +from taskiq import Flow, FlowProtocol, InMemoryBroker, TaskiqRouter +from tests.routing.models import ( + BrokerQueue, + OldStyleRecordingBroker, + RecordingMiddleware, + RecordingResultBackend, +) +from tests.utils import RecordingBroker + + +async def test_old_broker_task_api_registers_task_in_router() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + assert broker.find_task("demo.task") is demo_task + assert broker.router.find_task("demo.task") is demo_task + + await demo_task.kiq() + + assert broker.sent[0][0].task_name == "demo.task" + assert broker.sent[0][1] is None + + +async def test_broker_task_api_uses_default_flow() -> None: + flow = Flow("default") + broker = RecordingBroker(default_flow=flow) + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + await demo_task.kiq() + + assert broker.sent[0][0].task_name == "demo.task" + assert broker.sent[0][1] == flow + assert broker.get_subscribed_flows() == (flow,) + + +def test_router_set_broker_can_configure_default_flow() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("manual.default") + + router.set_broker(broker, name="broker", default_flow=flow) + route = router.resolve_route("demo.task", broker=broker) + + assert broker.default_flow == flow + assert route.flow == flow + assert broker.get_subscribed_flows() == (flow,) + + +async def test_router_can_route_task_to_another_broker_flow() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("events") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + route = router.route_task(demo_task, broker=target, flow=flow) + resolved_route = router.resolve_route(demo_task) + + await demo_task.kiq() + + assert route.broker is target + assert route.broker_name == "target" + assert resolved_route.broker is target + assert resolved_route.flow == flow + assert source.sent == [] + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_old_style_broker_uses_kick_to_flow_fallback() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = OldStyleRecordingBroker(router=router, broker_name="target") + flow = Flow("legacy") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=target, flow=flow) + + await demo_task.kiq() + + assert source.sent == [] + assert target.sent[0].task_name == "demo.task" + + +async def test_kicker_route_override_wins_over_registered_route() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @first.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=first, flow=first_flow) + + route = router.resolve_route(demo_task, broker=second, flow=second_flow) + + await demo_task.kicker().with_route(route).kiq() + + assert first.sent == [] + assert second.sent[0][1] == second_flow + + +async def test_kicker_broker_override_wins_over_registered_route() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second_flow = Flow("second") + second = RecordingBroker( + router=router, + broker_name="second", + default_flow=second_flow, + ) + first_flow = Flow("first") + + @first.task(task_name="demo.task") + async def demo_task() -> None: + return None + + first_route = router.route_task(demo_task, broker=first, flow=first_flow) + + await demo_task.kicker().with_route(first_route).with_broker(second).kiq() + + assert first.sent == [] + assert second.sent[0][1] == second_flow + + +async def test_router_uses_explicit_broker_lookup_for_config_names() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("compat") + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + broker = router.get_broker("target") + route = router.route_task("demo.task", broker=broker, flow=flow) + + await demo_task.kicker().with_route(route).kiq() + + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_router_accepts_broker_specific_flow_protocol() -> None: + broker = RecordingBroker(broker_name="transport") + flow = BrokerQueue(name="critical", durable=False) + + @broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + await demo_task.kicker().with_flow(flow).kiq() + + assert isinstance(flow, FlowProtocol) + assert broker.sent[0][1] is flow + assert flow.broker_options() == { + "durable": False, + } + + +async def test_kicker_preserves_labels_task_id_and_result_backend() -> None: + broker = RecordingBroker() + backend = RecordingResultBackend() + broker.with_id_generator(lambda: "generated-id") + broker.with_result_backend(backend) + + @broker.task(task_name="demo.task", task_label="declared") + async def demo_task() -> None: + return None + + sent_task = await demo_task.kicker().with_labels(call_label=42).kiq() + broker_message = broker.sent[0][0] + + assert sent_task.task_id == "generated-id" + assert sent_task.result_backend is backend + assert broker_message.task_id == "generated-id" + assert broker_message.labels["task_label"] == "declared" + assert broker_message.labels["call_label"] == "42" + + +async def test_kicker_can_prepare_invocation_for_later() -> None: + broker = RecordingBroker() + + @broker.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + prepared = demo_task.kicker().with_labels(trace_id="abc").prepare(1) + + assert prepared.message.task_name == "demo.task" + assert prepared.message.args == [1] + assert prepared.message.labels["trace_id"] == "abc" + + await prepared.kiq() + + assert broker.sent[0][0].task_id == prepared.message.task_id + + +async def test_prepared_invocation_keeps_route_snapshot() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @first.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + first_route = router.route_task(demo_task, broker=first, flow=first_flow) + second_route = router.resolve_route(demo_task, broker=second, flow=second_flow) + kicker = demo_task.kicker().with_route(first_route) + + prepared = kicker.prepare(1) + kicker.with_route(second_route) + + await prepared.kiq() + + assert first.sent[0][1] == first_flow + assert second.sent == [] + + +async def test_prepared_invocation_resolves_default_route_snapshot() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + + @source.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + router.route_task(demo_task, broker=first, flow=first_flow) + prepared = demo_task.kicker().prepare(1) + + router.route_task(demo_task, broker=second, flow=second_flow) + + await prepared.kiq() + + assert source.sent == [] + assert first.sent[0][1] == first_flow + assert second.sent == [] + + +async def test_router_task_decorator_can_choose_broker_and_flow() -> None: + router = TaskiqRouter() + target = RecordingBroker(router=router, broker_name="target") + flow = Flow("target-flow") + + @router.task("demo.task", broker=target, flow=flow) + async def demo_task() -> None: + return None + + await demo_task.kiq() + + assert target.sent[0][0].task_name == "demo.task" + assert target.sent[0][1] == flow + + +async def test_routed_dispatch_uses_target_middleware_and_result_backend() -> None: + router = TaskiqRouter() + source = RecordingBroker(router=router, broker_name="source") + target = RecordingBroker(router=router, broker_name="target") + source_events: list[tuple[str, str, str]] = [] + target_events: list[tuple[str, str, str]] = [] + target_backend = RecordingResultBackend() + + source.add_middlewares(RecordingMiddleware("source", source_events)) + target.add_middlewares(RecordingMiddleware("target", target_events)) + target.with_result_backend(target_backend) + + @source.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=target, flow=Flow("target")) + + sent_task = await demo_task.kiq() + + assert sent_task.result_backend is target_backend + assert source_events == [] + assert target_events == [ + ("target", "pre_send", "demo.task"), + ("target", "post_send", "demo.task"), + ] + + +async def test_worker_lookup_uses_task_name_not_flow() -> None: + broker = InMemoryBroker(await_inplace=True) + calls: list[str] = [] + + @broker.task(task_name="demo.first") + async def first_task() -> str: + calls.append("first") + return "first" + + @broker.task(task_name="demo.second") + async def second_task() -> str: + calls.append("second") + return "second" + + task = await first_task.kicker().with_flow(Flow(second_task.task_name)).kiq() + result = await task.wait_result(timeout=2) + + assert result.return_value == "first" + assert calls == ["first"] diff --git a/tests/routing/test_router_core.py b/tests/routing/test_router_core.py new file mode 100644 index 00000000..d2b92301 --- /dev/null +++ b/tests/routing/test_router_core.py @@ -0,0 +1,83 @@ +import pytest + +from taskiq import TaskiqRoute, TaskiqRouter, TaskiqSubscription, task_builder +from tests.routing.models import CountingRouter +from tests.utils import RecordingBroker + + +def test_broker_creates_default_router() -> None: + broker = RecordingBroker() + + assert broker.router.brokers[broker.broker_name] is broker + assert broker.router.default_broker_name == broker.broker_name + + +def test_public_router_models_keep_router_module() -> None: + assert TaskiqRoute.__module__ == "taskiq.router" + assert TaskiqSubscription.__module__ == "taskiq.router" + + +def test_router_rejects_duplicate_broker_names() -> None: + router = TaskiqRouter() + RecordingBroker(router=router, broker_name="broker") + + with pytest.raises(ValueError, match="already registered"): + RecordingBroker(router=router, broker_name="broker") + + +def test_router_rejects_broker_attached_to_another_router() -> None: + first_router = TaskiqRouter() + second_router = TaskiqRouter() + broker = RecordingBroker(router=first_router, broker_name="broker") + + with pytest.raises(ValueError, match="attached to another router"): + second_router.set_broker(broker, name="broker") + + +def test_router_default_broker_setter_rejects_foreign_broker() -> None: + first_router = TaskiqRouter() + second_router = TaskiqRouter() + broker = RecordingBroker(router=first_router, broker_name="broker") + + with pytest.raises(ValueError, match="not registered"): + second_router.default_broker = broker + + +def test_router_rejects_string_broker_references() -> None: + router = TaskiqRouter() + RecordingBroker(router=router, broker_name="broker") + + with pytest.raises(TypeError, match="Broker string references"): + router.route_task("demo.task", broker="broker") # type: ignore[arg-type] + + +def test_router_rejects_duplicate_task_names() -> None: + router = TaskiqRouter() + first = RecordingBroker(router=router, broker_name="first") + second = RecordingBroker(router=router, broker_name="second") + + @first.task(task_name="demo.task") + async def first_task() -> None: + return None + + with pytest.raises(ValueError, match="already registered"): + + @second.task(task_name="demo.task") + async def second_task() -> None: + return None + + +def test_task_definition_binding_does_not_reenter_router_registration() -> None: + router = CountingRouter() + broker = RecordingBroker(router=router) + + @task_builder("shared.once") + def shared_task() -> None: + return None + + registered = broker.register_task(shared_task) + + assert registered.task_name == "shared.once" + assert router.register_task_calls == 1 + assert broker.router.find_task("shared.once") is registered + assert broker.find_task("shared.once") is registered diff --git a/tests/routing/test_shared_tasks.py b/tests/routing/test_shared_tasks.py new file mode 100644 index 00000000..e7aa596c --- /dev/null +++ b/tests/routing/test_shared_tasks.py @@ -0,0 +1,116 @@ +import pytest + +from taskiq import AsyncTaskiqDecoratedTask, Flow, task_builder +from tests.routing.models import RecordingMiddleware, TracingTask +from tests.utils import RecordingBroker + + +async def test_task_builder_can_be_registered_later() -> None: + broker = RecordingBroker() + + @task_builder("shared.add", queue="shared") + def add(left: int, right: int) -> int: + return left + right + + assert await add.call(1, 2) == 3 + + registered = broker.register_task(add) + + assert type(registered) is AsyncTaskiqDecoratedTask + assert registered.task_name == "shared.add" + assert registered.labels == {"queue": "shared"} + assert broker.router.find_task("shared.add") is registered + + await registered.kiq(1, 2) + + assert broker.sent[0][0].task_name == "shared.add" + + +async def test_router_register_task_definition_binds_to_selected_broker() -> None: + source = RecordingBroker(broker_name="source") + target = RecordingBroker(router=source.router, broker_name="target") + flow = Flow("target") + + @task_builder("shared.routed") + async def shared_task() -> None: + return None + + registered = source.router.register_task(shared_task, broker=target, flow=flow) + + await registered.kiq() + + assert source.sent == [] + assert target.find_task("shared.routed") is registered + assert target.local_task_registry["shared.routed"] is registered + assert target.sent[0][0].task_name == "shared.routed" + assert target.sent[0][1] == flow + + +async def test_task_builder_can_use_custom_base_cls() -> None: + broker = RecordingBroker() + + @task_builder("shared.traced", base_cls=TracingTask, queue="shared") + async def traced(value: int) -> int: + return value + 1 + + registered = broker.register_task(traced) + + assert isinstance(registered, TracingTask) + assert registered.tracing_name() == "shared.traced" + assert registered.labels == {"queue": "shared"} + assert await traced.call(1) == 2 + + await registered.kiq(1) + + assert broker.sent[0][0].task_name == "shared.traced" + + +async def test_task_builder_custom_base_cls_uses_broker_middleware() -> None: + broker = RecordingBroker() + events: list[tuple[str, str, str]] = [] + broker.add_middlewares(RecordingMiddleware("broker", events)) + + @task_builder("shared.traced.middleware", base_cls=TracingTask) + async def traced(value: int) -> int: + return value + 1 + + registered = broker.register_task(traced) + + assert isinstance(registered, TracingTask) + + await registered.kiq(1) + + assert events == [ + ("broker", "pre_send", "shared.traced.middleware"), + ("broker", "post_send", "shared.traced.middleware"), + ] + + +def test_task_definition_default_flow_does_not_create_subscription() -> None: + flow = Flow("shared.default") + broker = RecordingBroker(default_flow=flow) + + @task_builder("shared.default") + def shared_task() -> None: + return None + + registered = broker.register_task(shared_task) + route = broker.router.resolve_route(registered) + + assert route.flow == flow + assert broker.router.get_subscriptions(broker) == () + assert broker.get_subscribed_flows() == (flow,) + + +def test_register_task_definition_rejects_overrides() -> None: + broker = RecordingBroker() + + @task_builder("shared.add", queue="shared") + def add(left: int, right: int) -> int: + return left + right + + with pytest.raises(ValueError, match="TaskDefinition already defines"): + broker.register_task(add, task_name="other.name") + + with pytest.raises(ValueError, match="TaskDefinition already defines"): + broker.register_task(add, queue="other") diff --git a/tests/routing/test_subscriptions.py b/tests/routing/test_subscriptions.py new file mode 100644 index 00000000..cc82603b --- /dev/null +++ b/tests/routing/test_subscriptions.py @@ -0,0 +1,186 @@ +import pytest + +from taskiq import ( + Flow, + FlowIdentity, + FlowProtocol, + TaskiqRouter, + get_flow_identity, +) +from taskiq.warnings import TaskiqDeprecationWarning +from tests.routing.models import BrokerQueue +from tests.utils import RecordingBroker + + +def test_flow_identity_uses_logical_name() -> None: + generic_flow = Flow("events").with_options(durable=True) + broker_flow = BrokerQueue(name="events", durable=True) + + assert generic_flow.identity == FlowIdentity("events") + assert get_flow_identity(broker_flow) == FlowIdentity("events") + assert generic_flow.broker_options() == {"durable": True} + + +def test_route_task_does_not_register_subscription_by_default() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("events") + + route = router.route_task("demo.task", broker=broker, flow=flow) + + assert route.flow == flow + assert router.get_subscriptions(broker) == () + assert broker.get_subscribed_flows() == () + + +def test_route_task_subscribe_true_is_deprecated_shim() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("events") + + with pytest.warns(TaskiqDeprecationWarning, match="router.subscribe"): + route = router.route_task( + "demo.task", + broker=broker, + flow=flow, + subscribe=True, + ) + + subscription = router.get_subscriptions(broker)[0] + + assert route.flow == flow + assert subscription.broker is broker + assert subscription.flow == flow + assert subscription.task_names == frozenset({"demo.task"}) + assert broker.get_subscribed_flows() == (flow,) + + +def test_route_update_does_not_mutate_existing_subscriptions() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + first_flow = Flow("events.first") + second_flow = Flow("events.second") + + router.route_task("demo.task", broker=broker, flow=first_flow) + router.subscribe(broker, first_flow, "demo.task") + router.route_task("demo.task", broker=broker, flow=second_flow) + + subscription = router.get_subscriptions(broker)[0] + route = router.resolve_route("demo.task") + + assert route.flow == second_flow + assert subscription.flow == first_flow + assert subscription.task_names == frozenset({"demo.task"}) + assert broker.get_subscribed_flows() == (first_flow,) + + +def test_subscribe_dedupes_flow_and_merges_task_names() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + flow = Flow("events") + + router.subscribe(broker, flow, "first.task") + subscription = router.subscribe(broker, flow, "second.task") + + assert router.get_subscriptions(broker) == (subscription,) + assert subscription.task_names == frozenset({"first.task", "second.task"}) + assert broker.get_subscribed_flows() == (flow,) + + +def test_subscribe_merges_same_identity_with_compatible_options() -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + generic_flow = Flow("events").with_options(durable=True) + broker_flow = BrokerQueue(name="events", durable=True) + + router.subscribe(broker, generic_flow, "first.task") + subscription = router.subscribe(broker, broker_flow, "second.task") + + assert router.get_subscriptions(broker) == (subscription,) + assert subscription.flow is generic_flow + assert subscription.task_names == frozenset({"first.task", "second.task"}) + assert broker.get_subscribed_flows() == (generic_flow,) + + +@pytest.mark.parametrize( + "conflicting_flow", + [ + pytest.param( + Flow("events").with_options(durable=False), + id="generic-flow", + ), + pytest.param( + BrokerQueue(name="events", durable=False), + id="broker-specific-flow", + ), + ], +) +def test_subscribe_rejects_same_identity_with_different_options( + conflicting_flow: FlowProtocol, +) -> None: + router = TaskiqRouter() + broker = RecordingBroker(router=router, broker_name="broker") + + router.subscribe(broker, Flow("events").with_options(durable=True), "first.task") + + with pytest.raises(ValueError, match="different broker options"): + router.subscribe( + broker, + conflicting_flow, + "second.task", + ) + + +def test_subscribe_rejects_default_flow_option_conflict() -> None: + router = TaskiqRouter() + default_flow = Flow("events").with_options(durable=True) + broker = RecordingBroker( + router=router, + broker_name="broker", + default_flow=default_flow, + ) + + with pytest.raises(ValueError, match="different broker options"): + router.subscribe( + broker, + BrokerQueue(name="events", durable=False), + "demo.task", + ) + + +def test_subscribe_allows_same_identity_on_different_brokers() -> None: + router = TaskiqRouter() + first_broker = RecordingBroker(router=router, broker_name="first") + second_broker = RecordingBroker(router=router, broker_name="second") + + first_subscription = router.subscribe( + first_broker, + Flow("events").with_options(durable=True), + "first.task", + ) + second_subscription = router.subscribe( + second_broker, + BrokerQueue(name="events", durable=False), + "second.task", + ) + + assert router.get_subscriptions(first_broker) == (first_subscription,) + assert router.get_subscriptions(second_broker) == (second_subscription,) + + +def test_subscribed_flows_dedupes_default_flow_by_identity() -> None: + router = TaskiqRouter() + default_flow = Flow("events").with_options(durable=True) + broker = RecordingBroker( + router=router, + broker_name="broker", + default_flow=default_flow, + ) + + subscription = router.subscribe( + broker, + BrokerQueue(name="events", durable=True), + "demo.task", + ) + + assert broker.get_subscribed_flows() == (subscription.flow,) diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index 05e6358a..bc37effc 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -1,8 +1,28 @@ +from datetime import datetime, timezone + +from taskiq import Flow, TaskiqRouter from taskiq.abc.schedule_source import ScheduleSource from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.compat import model_dump from taskiq.exceptions import ScheduledTaskCancelledError from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.scheduler.scheduler import TaskiqScheduler +from tests.utils import RecordingBroker + + +class RecordingScheduleSource(ScheduleSource): + """Schedule source that records dynamically added schedules.""" + + def __init__(self) -> None: + self.schedules: list[ScheduledTask] = [] + + async def get_schedules(self) -> list["ScheduledTask"]: + """Return stored schedules.""" + return self.schedules + + async def add_schedule(self, schedule: "ScheduledTask") -> None: + """Store a schedule.""" + self.schedules.append(schedule) class CancellingScheduleSource(ScheduleSource): @@ -31,3 +51,131 @@ async def test_scheduled_task_cancelled() -> None: ) await scheduler.on_ready(source, task) # error is caught + + +async def test_scheduler_resolves_route_at_send_time() -> None: + router = TaskiqRouter() + scheduler_broker = RecordingBroker(router=router, broker_name="scheduler") + first_broker = RecordingBroker(router=router, broker_name="first") + second_broker = RecordingBroker(router=router, broker_name="second") + first_flow = Flow("first") + second_flow = Flow("second") + source = RecordingScheduleSource() + scheduler = TaskiqScheduler(broker=scheduler_broker, sources=[source]) + + @scheduler_broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=first_broker, flow=first_flow) + scheduled_task = ScheduledTask( + task_name="demo.task", + labels={}, + args=[], + kwargs={}, + cron="* * * * *", + ) + router.route_task(demo_task, broker=second_broker, flow=second_flow) + + await scheduler.on_ready(source, scheduled_task) + + assert first_broker.sent == [] + assert second_broker.sent[0][0].task_name == "demo.task" + assert second_broker.sent[0][1] == second_flow + + +async def test_scheduler_without_route_uses_scheduler_broker() -> None: + router = TaskiqRouter() + default_broker = RecordingBroker(router=router, broker_name="default") + scheduler_flow = Flow("scheduler.default") + scheduler_broker = RecordingBroker( + router=router, + broker_name="scheduler", + default_flow=scheduler_flow, + ) + source = RecordingScheduleSource() + scheduler = TaskiqScheduler(broker=scheduler_broker, sources=[source]) + scheduled_task = ScheduledTask( + task_name="external.task", + labels={"source": "database"}, + args=[], + kwargs={}, + cron="* * * * *", + ) + + await scheduler.on_ready(source, scheduled_task) + + sent_message, sent_flow = scheduler_broker.sent[0] + assert default_broker.sent == [] + assert sent_message.task_name == "external.task" + assert sent_message.labels["source"] == "database" + assert sent_message.labels["schedule_id"] == scheduled_task.schedule_id + assert sent_flow == scheduler_flow + assert scheduled_task.labels == {"source": "database"} + + +async def test_scheduler_removed_route_uses_scheduler_broker() -> None: + router = TaskiqRouter() + scheduler_flow = Flow("scheduler.default") + scheduler_broker = RecordingBroker( + router=router, + broker_name="scheduler", + default_flow=scheduler_flow, + ) + target_broker = RecordingBroker(router=router, broker_name="target") + source = RecordingScheduleSource() + scheduler = TaskiqScheduler(broker=scheduler_broker, sources=[source]) + + @scheduler_broker.task(task_name="demo.task") + async def demo_task() -> None: + return None + + router.route_task(demo_task, broker=target_broker, flow=Flow("target")) + scheduled_task = ScheduledTask( + task_name="demo.task", + labels={}, + args=[], + kwargs={}, + cron="* * * * *", + ) + del router.routes["demo.task"] + + await scheduler.on_ready(source, scheduled_task) + + sent_message, sent_flow = scheduler_broker.sent[0] + assert target_broker.sent == [] + assert sent_message.task_name == "demo.task" + assert sent_flow == scheduler_flow + + +async def test_created_schedule_kiq_does_not_mutate_schedule_payload() -> None: + router = TaskiqRouter() + source_broker = RecordingBroker(router=router, broker_name="source") + target_broker = RecordingBroker(router=router, broker_name="target") + target_flow = Flow("target") + source = RecordingScheduleSource() + + @source_broker.task(task_name="demo.task") + async def demo_task(value: int) -> None: + return None + + route = router.resolve_route( + demo_task, + broker=target_broker, + flow=target_flow, + ) + schedule = ( + await demo_task.kicker() + .with_route(route) + .with_labels(trace_id="trace") + .schedule_by_time(source, datetime.now(timezone.utc), 1) + ) + schedule_payload = model_dump(schedule.task) + + await schedule.kiq() + + assert source.schedules == [schedule.task] + assert {"route", "flow", "broker", "broker_name"}.isdisjoint(schedule_payload) + assert schedule.task.labels == {"trace_id": "trace"} + assert target_broker.sent[0][0].task_name == "demo.task" + assert target_broker.sent[0][1] == target_flow diff --git a/tests/test_requeue.py b/tests/test_requeue.py index c451b6df..a23f98c3 100644 --- a/tests/test_requeue.py +++ b/tests/test_requeue.py @@ -1,4 +1,20 @@ -from taskiq import Context, InMemoryBroker, TaskiqDepends +import pytest + +from taskiq import Context, Flow, InMemoryBroker, TaskiqDepends, TaskiqRouter +from taskiq.exceptions import NoResultError +from taskiq.message import TaskiqMessage +from tests.utils import RecordingBroker + + +def build_message(task_name: str = "demo.task") -> TaskiqMessage: + """Build a message for context requeue tests.""" + return TaskiqMessage( + task_id="task-id", + task_name=task_name, + labels={}, + args=[], + kwargs={}, + ) async def test_requeue() -> None: @@ -46,3 +62,75 @@ async def task(_: None = TaskiqDepends(dep_func)) -> None: ) assert runs_count == 2 + + +async def test_requeue_preserves_current_broker_when_route_points_elsewhere() -> None: + router = TaskiqRouter() + current_flow = Flow("current.default") + current_broker = RecordingBroker( + router=router, + broker_name="current", + default_flow=current_flow, + ) + routed_broker = RecordingBroker(router=router, broker_name="routed") + routed_flow = Flow("routed") + message = build_message() + context = Context(message, current_broker) + + router.route_task(message.task_name, broker=routed_broker, flow=routed_flow) + + with pytest.raises(NoResultError): + await context.requeue() + + sent_message, sent_flow = current_broker.sent[0] + assert routed_broker.sent == [] + assert sent_message.task_name == message.task_name + assert sent_message.labels["X-Taskiq-requeue"] == "1" + assert sent_flow == current_flow + + +async def test_requeue_prefers_same_broker_route_flow() -> None: + router = TaskiqRouter() + default_flow = Flow("current.default") + route_flow = Flow("current.route") + current_broker = RecordingBroker( + router=router, + broker_name="current", + default_flow=default_flow, + ) + message = build_message() + context = Context(message, current_broker) + + router.route_task(message.task_name, broker=current_broker, flow=route_flow) + + with pytest.raises(NoResultError): + await context.requeue() + + assert current_broker.sent[0][1] == route_flow + + +async def test_requeue_without_route_or_default_flow_sends_without_flow() -> None: + broker = RecordingBroker() + message = build_message() + context = Context(message, broker) + + with pytest.raises(NoResultError): + await context.requeue() + + assert broker.sent[0][1] is None + + +async def test_router_requeue_accepts_explicit_flow_override() -> None: + router = TaskiqRouter() + default_flow = Flow("default") + explicit_flow = Flow("explicit") + broker = RecordingBroker( + router=router, + broker_name="broker", + default_flow=default_flow, + ) + message = build_message() + + await router.requeue(message, broker=broker, flow=explicit_flow) + + assert broker.sent[0][1] == explicit_flow diff --git a/tests/test_task_builder.py b/tests/test_task_builder.py new file mode 100644 index 00000000..0acf3066 --- /dev/null +++ b/tests/test_task_builder.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass +from typing import Any + +import pytest +from pydantic import BaseModel + +from taskiq import AsyncTaskiqDecoratedTask, InMemoryBroker, task_builder +from taskiq.labels import LabelType + + +class CustomBrokerTask(AsyncTaskiqDecoratedTask[Any, Any]): + """Custom broker-level task class used by old decorator tests.""" + + +@dataclass(frozen=True, slots=True) +class Payload: + """Payload dataclass used to check message argument preparation.""" + + value: int + + +class User(BaseModel): + """Payload model used to check pydantic argument preparation.""" + + name: str + + +async def test_task_definition_call_executes_sync_and_async_functions() -> None: + @task_builder("shared.add") + def add(left: int, right: int) -> int: + return left + right + + @task_builder("shared.multiply") + async def multiply(left: int, right: int) -> int: + return left * right + + assert add(1, 2) == 3 + assert await add.call(1, 2) == 3 + assert await multiply(2, 3) == 6 + assert await multiply.call(2, 3) == 6 + + +def test_task_definition_message_matches_registered_kicker_message() -> None: + @task_builder( + "shared.message", + priority=7, + enabled=True, + payload=b"abc", + ) + def process(payload: Payload, user: User) -> None: + return None + + broker = InMemoryBroker() + registered = broker.register_task(process) + + definition_message = process.message( + "task-id", + Payload(value=1), + User(name="Ada"), + ) + prepared_message = ( + registered.kicker() + .with_task_id("task-id") + .prepare( + Payload(value=1), + User(name="Ada"), + ) + .message + ) + + assert definition_message == prepared_message + assert definition_message.args == [{"value": 1}, {"name": "Ada"}] + assert definition_message.labels == { + "enabled": "True", + "payload": "YWJj", + "priority": "7", + } + assert definition_message.labels_types == { + "enabled": LabelType.BOOL.value, + "payload": LabelType.BYTES.value, + "priority": LabelType.INT.value, + } + + +def test_task_definition_message_rejects_dataclass_types() -> None: + @task_builder("shared.message") + def process(payload: Payload) -> None: + return None + + with pytest.raises(ValueError, match="Cannot serialize types"): + process.message("task-id", Payload) # type: ignore[arg-type] + + +def test_shared_task_without_base_cls_uses_native_task_class() -> None: + broker = InMemoryBroker() + broker.decorator_class = CustomBrokerTask + + @broker.task(task_name="old.decorator") + async def old_task() -> None: + return None + + @task_builder("shared.default") + async def shared_task() -> None: + return None + + registered = broker.register_task(shared_task) + + assert isinstance(old_task, CustomBrokerTask) + assert type(registered) is AsyncTaskiqDecoratedTask + + +@pytest.mark.parametrize( + "base_cls", + [ + pytest.param(object, id="non-task-type"), + pytest.param(0, id="falsey-runtime-value"), + ], +) +def test_invalid_task_definition_base_cls_fails_when_bound(base_cls: Any) -> None: + @task_builder("shared.invalid", base_cls=base_cls) + def invalid_task() -> None: + return None + + broker = InMemoryBroker() + + with pytest.raises( + TypeError, + match="base_cls must be a subclass of AsyncTaskiqDecoratedTask", + ): + broker.register_task(invalid_task) diff --git a/tests/utils.py b/tests/utils.py index c6c659d2..a90b59bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,45 @@ import asyncio from collections.abc import AsyncGenerator -from taskiq import AsyncBroker, BrokerMessage +from taskiq import AsyncBroker, BrokerMessage, FlowProtocol, TaskiqRouter from taskiq.acks import AckableMessage +class RecordingBroker(AsyncBroker): + """Broker for tests that records messages and selected flows.""" + + def __init__( + self, + *, + router: TaskiqRouter | None = None, + broker_name: str | None = None, + default_flow: FlowProtocol | None = None, + ) -> None: + self.sent: list[tuple[BrokerMessage, FlowProtocol | None]] = [] + super().__init__( + router=router, + broker_name=broker_name, + default_flow=default_flow, + ) + + async def kick(self, message: BrokerMessage) -> None: + """Record old-style send.""" + self.sent.append((message, None)) + + async def kick_to_flow( + self, + message: BrokerMessage, + flow: FlowProtocol | None = None, + ) -> None: + """Record flow-aware send.""" + self.sent.append((message, flow)) + + async def listen(self) -> AsyncGenerator[bytes, None]: + """Recording broker does not listen in tests.""" + if False: + yield b"" + + class AsyncQueueBroker(AsyncBroker): """ Broker for testing.