Skip to content

Commit

Permalink
Added PushBasedJetStreamBroker and PullBasedJetStreamBroker
Browse files Browse the repository at this point in the history
Signed-off-by: chandr-andr (Kiselev Aleksandr) <[email protected]>
  • Loading branch information
chandr-andr committed Oct 11, 2023
1 parent 359214f commit 2dff8b5
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 71 deletions.
65 changes: 58 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pip install taskiq taskiq-nats

Here's a minimal setup example with a broker and one task.

### Default NATS broker.
```python
import asyncio
from taskiq_nats import NatsBroker, JetStreamBroker
Expand All @@ -27,15 +28,47 @@ broker = NatsBroker(
queue="random_queue_name",
)

# Or alternatively you can use a JetStream broker:
broker = JetStreamBroker(
[

@broker.task
async def my_lovely_task():
print("I love taskiq")


async def main():
await broker.startup()

await my_lovely_task.kiq()

await broker.shutdown()


if __name__ == "__main__":
asyncio.run(main())

```
### NATS broker based on JetStream
```python
import asyncio
from taskiq_nats import (
PushBasedJetStreamBroker,
PullBasedJetStreamBroker
)

broker = PushBasedJetStreamBroker(
servers=[
"nats://nats1:4222",
"nats://nats2:4222",
],
queue="random_queue_name",
subject="my-subj",
stream_name="my-stream"
queue="awesome_queue_name",
)

# Or you can use pull based variant
broker = PullBasedJetStreamBroker(
servers=[
"nats://nats1:4222",
"nats://nats2:4222",
],
durable="awesome_durable_consumer_name",
)


Expand All @@ -54,7 +87,6 @@ async def main():

if __name__ == "__main__":
asyncio.run(main())

```

## NatsBroker configuration
Expand All @@ -68,3 +100,22 @@ Here's the constructor parameters:
* `result_backend` - custom result backend.
* `task_id_generator` - custom function to generate task ids.
* Every other keyword argument will be sent to `nats.connect` function.

## JetStreamBroker configuration
### Common
* `servers` - a single string or a list of strings with nats nodes addresses.
* `subject` - name of the subect that will be used to exchange tasks betwee workers and clients.
* `stream_name` - name of the stream where subjects will be located.
* `queue` - a single string or a list of strings with nats nodes addresses.
* `result_backend` - custom result backend.
* `task_id_generator` - custom function to generate task ids.
* `stream_config` - a config for stream.
* `consumer_config` - a config for consumer.

### PushBasedJetStreamBroker
* `queue` - name of the queue. It's used to share messages between different consumers.

### PullBasedJetStreamBroker
* `durable` - durable name of the consumer. It's used to share messages between different consumers.
* `pull_consume_batch` - maximum number of message that can be fetched each time.
* `pull_consume_timeout` - timeout for messages fetch. If there is no messages, we start fetching messages again.
12 changes: 10 additions & 2 deletions taskiq_nats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
uses NATS as a message queue.
"""

from taskiq_nats.broker import JetStreamBroker, NatsBroker
from taskiq_nats.broker import (
PushBasedJetStreamBroker,
PullBasedJetStreamBroker,
NatsBroker,
)

__all__ = ["NatsBroker", "JetStreamBroker"]
__all__ = [
"NatsBroker",
"PushBasedJetStreamBroker",
"PullBasedJetStreamBroker",
]
137 changes: 118 additions & 19 deletions taskiq_nats/broker.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar, Union
from typing import Any, AsyncGenerator, Callable, Final, Generic, List, Optional, Self, TypeVar, Union

from nats.aio.client import Client
from nats.aio.msg import Msg as NatsMessage
from nats.js import JetStreamContext
from nats.js.api import StreamConfig
from nats.js.api import StreamConfig, ConsumerConfig
from nats.errors import TimeoutError as NatsTimeoutError
from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage

_T = TypeVar("_T") # noqa: WPS111 (Too short)


JetStreamConsumerType = TypeVar(
"JetStreamConsumerType",
)


logger = getLogger("taskiq_nats")


Expand Down Expand Up @@ -80,38 +88,54 @@ async def shutdown(self) -> None:
await super().shutdown()


class JetStreamBroker(AsyncBroker): # noqa: WPS230
"""
JetStream broker for taskiq.
class BaseJetStreamBroker(AsyncBroker, ABC, Generic[JetStreamConsumerType]):
"""Base JetStream broker for taskiq.
It has two subclasses - PullBasedJetStreamBroker
and PushBasedJetStreamBroker.
This broker creates a JetStream context
and uses it to send and receive messages.
These brokers create a JetStream context
and use it to send and receive messages.
This is useful for systems where you need to
be sure that messages are delivered to the workers.
"""

def __init__( # noqa: WPS211 (too many args)
self,
def __init__(
self: Self,
servers: Union[str, List[str]],
subject: str = "tasiq_tasks",
stream_name: str = "taskiq_jstream",
subject: str = "taskiq_tasks",
stream_name: str = "taskiq_jetstream",
queue: Optional[str] = None,
durable: str = "taskiq_durable",
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: Optional[Callable[[], str]] = None,
stream_config: Optional[StreamConfig] = None,
consumer_config: Optional[ConsumerConfig] = None,
pull_consume_batch: int = 1,
pull_consume_timeout: Optional[float] = None,
**connection_kwargs: Any,
) -> None:
super().__init__(result_backend, task_id_generator)
self.servers = servers
self.client: Client = Client()
self.connection_kwargs = connection_kwargs
self.queue = queue
self.subject = subject
self.stream_name = stream_name
self.js: JetStreamContext
self.stream_config = stream_config or StreamConfig()
self.consumer_config = consumer_config

# Only for push based consumer
self.queue = queue
self.default_consumer_name: Final = "taskiq_consumer"
# Only for pull based consumer
self.durable = durable
self.pull_consume_batch = pull_consume_batch
self.pull_consume_timeout = pull_consume_timeout

self.consumer: JetStreamConsumerType

async def startup(self) -> None:
"""
Startup event handler.
Expand All @@ -127,7 +151,13 @@ async def startup(self) -> None:
if not self.stream_config.subjects:
self.stream_config.subjects = [self.subject]
await self.js.add_stream(config=self.stream_config)

await self._startup_consumer()

async def shutdown(self) -> None:
"""Close connections to NATS."""
await self.client.close()
await super().shutdown()

async def kick(self, message: BrokerMessage) -> None:
"""
Send a message using NATS.
Expand All @@ -139,21 +169,90 @@ async def kick(self, message: BrokerMessage) -> None:
payload=message.message,
headers=message.labels,
)

@abstractmethod
async def _startup_consumer(self: Self) -> None:
"""Create consumer."""


class PushBasedJetStreamBroker(
BaseJetStreamBroker[JetStreamContext.PushSubscription],
):
"""JetStream broker for push based message consumption.
It's named `push` based because nats server push messages to
the consumer, not consumer requests them.
"""

async def _startup_consumer(self: Self) -> None:
if not self.consumer_config:
self.consumer_config = ConsumerConfig(
name=self.default_consumer_name,
durable_name=self.default_consumer_name,
)

self.consumer = await self.js.subscribe(
subject=self.subject,
queue=self.queue or "",
config=self.consumer_config,
)

async def listen(self) -> AsyncGenerator[AckableMessage, None]:
"""
Start listen to new messages.
:yield: incoming messages.
"""
subscribe = await self.js.subscribe(self.subject, queue=self.queue or "")
async for message in subscribe.messages:
async for message in self.consumer.messages:
yield AckableMessage(
data=message.data,
ack=message.ack,
)

async def shutdown(self) -> None:
"""Close connections to NATS."""
await self.client.close()
await super().shutdown()

class PullBasedJetStreamBroker(
BaseJetStreamBroker[JetStreamContext.PullSubscription],
):
"""JetStream broker for pull based message consumption.
It's named `pull` based because consumer requests messages,
not NATS server sends them.
"""

async def _startup_consumer(self: Self) -> None:
if not self.consumer_config:
self.consumer_config = ConsumerConfig(
durable_name=self.durable,
)

# We must use this method to create pull based consumer
# because consumer config won't change without it.
await self.js.add_consumer(
stream=self.stream_config.name or self.stream_name,
config=self.consumer_config,
)
self.consumer = await self.js.pull_subscribe(
subject=self.subject,
durable=self.durable,
config=self.consumer_config,
)

async def listen(self) -> AsyncGenerator[AckableMessage, None]:
"""
Start listen to new messages.
:yield: incoming messages.
"""
while True:
try:
nats_messages: List[NatsMessage] = await self.consumer.fetch(
batch=self.pull_consume_batch,
timeout=self.pull_consume_timeout,
)
for nats_message in nats_messages:
yield AckableMessage(
data=nats_message.data,
ack=nats_message.ack,
)
except NatsTimeoutError:
continue
20 changes: 19 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import uuid
from typing import List
from typing import AsyncGenerator, Final, List

from nats.js import JetStreamContext
from nats import NATS

import pytest

Expand Down Expand Up @@ -38,3 +41,18 @@ def nats_urls() -> List[str]:
"""
urls = os.environ.get("NATS_URLS") or "nats://localhost:4222"
return urls.split(",")


@pytest.fixture()
async def nats_jetstream(nats_urls: List[str]) -> AsyncGenerator[JetStreamContext, None]:
"""Create and yield nats client and jetstream instances.
:param nats_urls: urls to nats.
:yields: NATS JetStream.
"""
nats: Final = NATS()
await nats.connect(servers=nats_urls)
jetstream: Final = nats.jetstream()
yield jetstream
await nats.close()
Loading

0 comments on commit 2dff8b5

Please sign in to comment.