diff --git a/README.md b/README.md index 86250d4..9532df8 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ pip install taskiq taskiq-nats Here's a minimal setup example with a broker and one task. +### Default NATS broker. ```python import asyncio from taskiq_nats import NatsBroker, JetStreamBroker @@ -27,15 +28,47 @@ broker = NatsBroker( queue="random_queue_name", ) -# Or alternatively you can use a JetStream broker: -broker = JetStreamBroker( - [ + +@broker.task +async def my_lovely_task(): + print("I love taskiq") + + +async def main(): + await broker.startup() + + await my_lovely_task.kiq() + + await broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) + +``` +### NATS broker based on JetStream +```python +import asyncio +from taskiq_nats import ( + PushBasedJetStreamBroker, + PullBasedJetStreamBroker +) + +broker = PushBasedJetStreamBroker( + servers=[ "nats://nats1:4222", "nats://nats2:4222", ], - queue="random_queue_name", - subject="my-subj", - stream_name="my-stream" + queue="awesome_queue_name", +) + +# Or you can use pull based variant +broker = PullBasedJetStreamBroker( + servers=[ + "nats://nats1:4222", + "nats://nats2:4222", + ], + durable="awesome_durable_consumer_name", ) @@ -54,7 +87,6 @@ async def main(): if __name__ == "__main__": asyncio.run(main()) - ``` ## NatsBroker configuration @@ -68,3 +100,22 @@ Here's the constructor parameters: * `result_backend` - custom result backend. * `task_id_generator` - custom function to generate task ids. * Every other keyword argument will be sent to `nats.connect` function. + +## JetStreamBroker configuration +### Common +* `servers` - a single string or a list of strings with nats nodes addresses. +* `subject` - name of the subect that will be used to exchange tasks betwee workers and clients. +* `stream_name` - name of the stream where subjects will be located. +* `queue` - a single string or a list of strings with nats nodes addresses. +* `result_backend` - custom result backend. +* `task_id_generator` - custom function to generate task ids. +* `stream_config` - a config for stream. +* `consumer_config` - a config for consumer. + +### PushBasedJetStreamBroker +* `queue` - name of the queue. It's used to share messages between different consumers. + +### PullBasedJetStreamBroker +* `durable` - durable name of the consumer. It's used to share messages between different consumers. +* `pull_consume_batch` - maximum number of message that can be fetched each time. +* `pull_consume_timeout` - timeout for messages fetch. If there is no messages, we start fetching messages again. diff --git a/taskiq_nats/__init__.py b/taskiq_nats/__init__.py index 8b4c509..896b55c 100644 --- a/taskiq_nats/__init__.py +++ b/taskiq_nats/__init__.py @@ -5,6 +5,14 @@ uses NATS as a message queue. """ -from taskiq_nats.broker import JetStreamBroker, NatsBroker +from taskiq_nats.broker import ( + PushBasedJetStreamBroker, + PullBasedJetStreamBroker, + NatsBroker, +) -__all__ = ["NatsBroker", "JetStreamBroker"] +__all__ = [ + "NatsBroker", + "PushBasedJetStreamBroker", + "PullBasedJetStreamBroker", +] diff --git a/taskiq_nats/broker.py b/taskiq_nats/broker.py index 8882eb8..7f606e3 100644 --- a/taskiq_nats/broker.py +++ b/taskiq_nats/broker.py @@ -1,14 +1,22 @@ +from abc import ABC, abstractmethod from logging import getLogger -from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar, Union +from typing import Any, AsyncGenerator, Callable, Final, Generic, List, Optional, Self, TypeVar, Union from nats.aio.client import Client +from nats.aio.msg import Msg as NatsMessage from nats.js import JetStreamContext -from nats.js.api import StreamConfig +from nats.js.api import StreamConfig, ConsumerConfig +from nats.errors import TimeoutError as NatsTimeoutError from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage _T = TypeVar("_T") # noqa: WPS111 (Too short) +JetStreamConsumerType = TypeVar( + "JetStreamConsumerType", +) + + logger = getLogger("taskiq_nats") @@ -80,38 +88,54 @@ async def shutdown(self) -> None: await super().shutdown() -class JetStreamBroker(AsyncBroker): # noqa: WPS230 - """ - JetStream broker for taskiq. +class BaseJetStreamBroker(AsyncBroker, ABC, Generic[JetStreamConsumerType]): + """Base JetStream broker for taskiq. + + It has two subclasses - PullBasedJetStreamBroker + and PushBasedJetStreamBroker. - This broker creates a JetStream context - and uses it to send and receive messages. + These brokers create a JetStream context + and use it to send and receive messages. This is useful for systems where you need to be sure that messages are delivered to the workers. """ - def __init__( # noqa: WPS211 (too many args) - self, + def __init__( + self: Self, servers: Union[str, List[str]], - subject: str = "tasiq_tasks", - stream_name: str = "taskiq_jstream", + subject: str = "taskiq_tasks", + stream_name: str = "taskiq_jetstream", queue: Optional[str] = None, + durable: str = "taskiq_durable", result_backend: "Optional[AsyncResultBackend[_T]]" = None, task_id_generator: Optional[Callable[[], str]] = None, stream_config: Optional[StreamConfig] = None, + consumer_config: Optional[ConsumerConfig] = None, + pull_consume_batch: int = 1, + pull_consume_timeout: Optional[float] = None, **connection_kwargs: Any, ) -> None: super().__init__(result_backend, task_id_generator) self.servers = servers self.client: Client = Client() self.connection_kwargs = connection_kwargs - self.queue = queue self.subject = subject self.stream_name = stream_name self.js: JetStreamContext self.stream_config = stream_config or StreamConfig() + self.consumer_config = consumer_config + # Only for push based consumer + self.queue = queue + self.default_consumer_name: Final = "taskiq_consumer" + # Only for pull based consumer + self.durable = durable + self.pull_consume_batch = pull_consume_batch + self.pull_consume_timeout = pull_consume_timeout + + self.consumer: JetStreamConsumerType + async def startup(self) -> None: """ Startup event handler. @@ -127,7 +151,13 @@ async def startup(self) -> None: if not self.stream_config.subjects: self.stream_config.subjects = [self.subject] await self.js.add_stream(config=self.stream_config) - + await self._startup_consumer() + + async def shutdown(self) -> None: + """Close connections to NATS.""" + await self.client.close() + await super().shutdown() + async def kick(self, message: BrokerMessage) -> None: """ Send a message using NATS. @@ -139,21 +169,90 @@ async def kick(self, message: BrokerMessage) -> None: payload=message.message, headers=message.labels, ) + + @abstractmethod + async def _startup_consumer(self: Self) -> None: + """Create consumer.""" + + +class PushBasedJetStreamBroker( + BaseJetStreamBroker[JetStreamContext.PushSubscription], +): + """JetStream broker for push based message consumption. + + It's named `push` based because nats server push messages to + the consumer, not consumer requests them. + """ + + async def _startup_consumer(self: Self) -> None: + if not self.consumer_config: + self.consumer_config = ConsumerConfig( + name=self.default_consumer_name, + durable_name=self.default_consumer_name, + ) + self.consumer = await self.js.subscribe( + subject=self.subject, + queue=self.queue or "", + config=self.consumer_config, + ) + async def listen(self) -> AsyncGenerator[AckableMessage, None]: """ Start listen to new messages. :yield: incoming messages. """ - subscribe = await self.js.subscribe(self.subject, queue=self.queue or "") - async for message in subscribe.messages: + async for message in self.consumer.messages: yield AckableMessage( data=message.data, ack=message.ack, ) - async def shutdown(self) -> None: - """Close connections to NATS.""" - await self.client.close() - await super().shutdown() + +class PullBasedJetStreamBroker( + BaseJetStreamBroker[JetStreamContext.PullSubscription], +): + """JetStream broker for pull based message consumption. + + It's named `pull` based because consumer requests messages, + not NATS server sends them. + """ + + async def _startup_consumer(self: Self) -> None: + if not self.consumer_config: + self.consumer_config = ConsumerConfig( + durable_name=self.durable, + ) + + # We must use this method to create pull based consumer + # because consumer config won't change without it. + await self.js.add_consumer( + stream=self.stream_config.name or self.stream_name, + config=self.consumer_config, + ) + self.consumer = await self.js.pull_subscribe( + subject=self.subject, + durable=self.durable, + config=self.consumer_config, + ) + + async def listen(self) -> AsyncGenerator[AckableMessage, None]: + """ + Start listen to new messages. + + :yield: incoming messages. + """ + while True: + try: + nats_messages: List[NatsMessage] = await self.consumer.fetch( + batch=self.pull_consume_batch, + timeout=self.pull_consume_timeout, + ) + for nats_message in nats_messages: + yield AckableMessage( + data=nats_message.data, + ack=nats_message.ack, + ) + except NatsTimeoutError: + continue diff --git a/tests/conftest.py b/tests/conftest.py index 547758f..6d54852 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,9 @@ import os import uuid -from typing import List +from typing import AsyncGenerator, Final, List + +from nats.js import JetStreamContext +from nats import NATS import pytest @@ -38,3 +41,18 @@ def nats_urls() -> List[str]: """ urls = os.environ.get("NATS_URLS") or "nats://localhost:4222" return urls.split(",") + + +@pytest.fixture() +async def nats_jetstream(nats_urls: List[str]) -> AsyncGenerator[JetStreamContext, None]: + """Create and yield nats client and jetstream instances. + + :param nats_urls: urls to nats. + + :yields: NATS JetStream. + """ + nats: Final = NATS() + await nats.connect(servers=nats_urls) + jetstream: Final = nats.jetstream() + yield jetstream + await nats.close() diff --git a/tests/test_jetstream.py b/tests/test_jetstream.py new file mode 100644 index 0000000..ecfb175 --- /dev/null +++ b/tests/test_jetstream.py @@ -0,0 +1,93 @@ +import asyncio +import uuid +from typing import List + +import pytest +from taskiq import AckableMessage, BrokerMessage + +from taskiq_nats import PushBasedJetStreamBroker, PullBasedJetStreamBroker +from tests.utils import read_message + + +@pytest.mark.anyio +async def test_push_based_broker_success( + nats_urls: List[str], + nats_subject: str, +) -> None: + """ + Tests that PushBasedJetStreamBroker works. + + This function sends a message to JetStream + before starting to listen to it. + + It expects to receive the same message. + """ + broker = PushBasedJetStreamBroker( + servers=nats_urls, + subject=nats_subject, + queue=uuid.uuid4().hex, + stream_name=uuid.uuid4().hex, + ) + await broker.startup() + sent_message = BrokerMessage( + task_id=uuid.uuid4().hex, + task_name="meme", + message=b"some", + labels={}, + ) + await broker.kick(sent_message) + ackable_msg = await asyncio.wait_for(read_message(broker), 0.5) + assert isinstance(ackable_msg, AckableMessage) + assert ackable_msg.data == sent_message.message + ack = ackable_msg.ack() + if ack is not None: + await ack + await broker.js.delete_consumer( + stream=broker.stream_name, + consumer=broker.default_consumer_name, + ) + await broker.js.delete_stream( + broker.stream_name, + ) + await broker.shutdown() + + +@pytest.mark.anyio() +async def test_pull_based_broker_success( + nats_urls: List[str], + nats_subject: str, +) -> None: + """ + Tests that PullBasedJetStreamBroker works. + + This function sends a message to JetStream + before starting to listen to it. + + It expects to receive the same message. + """ + broker = PullBasedJetStreamBroker( + servers=nats_urls, + subject=nats_subject, + ) + await broker.startup() + sent_message = BrokerMessage( + task_id=uuid.uuid4().hex, + task_name="meme", + message=b"some", + labels={}, + ) + await broker.kick(sent_message) + ackable_msg = await asyncio.wait_for(read_message(broker), 0.5) + assert isinstance(ackable_msg, AckableMessage) + assert ackable_msg.data == sent_message.message + ack = ackable_msg.ack() + if ack is not None: + await ack + await broker.js.delete_consumer( + stream=broker.stream_name, + consumer=broker.default_consumer_name, + ) + await broker.js.delete_stream( + broker.stream_name, + ) + await broker.shutdown() \ No newline at end of file diff --git a/tests/test_jstream.py b/tests/test_jstream.py deleted file mode 100644 index ba6f930..0000000 --- a/tests/test_jstream.py +++ /dev/null @@ -1,42 +0,0 @@ -import asyncio -import uuid -from typing import List - -import pytest -from taskiq import AckableMessage, BrokerMessage - -from taskiq_nats import JetStreamBroker -from tests.utils import read_message - - -@pytest.mark.anyio -async def test_success(nats_urls: List[str], nats_subject: str) -> None: - """ - Tests that JetStream works. - - This function sends a message to JetStream - before starting to listen to it. - - It epexts to receive the same message. - """ - broker = JetStreamBroker( - servers=nats_urls, - subject=nats_subject, - queue=uuid.uuid4().hex, - stream_name=uuid.uuid4().hex, - ) - await broker.startup() - sent_message = BrokerMessage( - task_id=uuid.uuid4().hex, - task_name="meme", - message=b"some", - labels={}, - ) - await broker.kick(sent_message) - ackable_msg = await asyncio.wait_for(read_message(broker), 0.5) - assert isinstance(ackable_msg, AckableMessage) - assert ackable_msg.data == sent_message.message - ack = ackable_msg.ack() - if ack is not None: - await ack - await broker.shutdown()