Skip to content

Commit

Permalink
Lint fixes
Browse files Browse the repository at this point in the history
Signed-off-by: chandr-andr (Kiselev Aleksandr) <[email protected]>
  • Loading branch information
chandr-andr committed Oct 11, 2023
1 parent 2dff8b5 commit d400760
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Testing taskiq-redis
name: Testing taskiq-nats

on: pull_request

Expand Down
4 changes: 2 additions & 2 deletions taskiq_nats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
"""

from taskiq_nats.broker import (
PushBasedJetStreamBroker,
PullBasedJetStreamBroker,
NatsBroker,
PullBasedJetStreamBroker,
PushBasedJetStreamBroker,
)

__all__ = [
Expand Down
152 changes: 78 additions & 74 deletions taskiq_nats/broker.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import typing
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Any, AsyncGenerator, Callable, Final, Generic, List, Optional, Self, TypeVar, Union

from nats.aio.client import Client
from nats.aio.msg import Msg as NatsMessage
from nats.js import JetStreamContext
from nats.js.api import StreamConfig, ConsumerConfig
from nats.errors import TimeoutError as NatsTimeoutError
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, StreamConfig
from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage

_T = TypeVar("_T") # noqa: WPS111 (Too short)
_T = typing.TypeVar("_T") # noqa: WPS111 (Too short)


JetStreamConsumerType = TypeVar(
JetStreamConsumerType = typing.TypeVar(
"JetStreamConsumerType",
)

Expand All @@ -37,12 +37,12 @@ class NatsBroker(AsyncBroker):

def __init__( # noqa: WPS211 (too many args)
self,
servers: Union[str, List[str]],
servers: typing.Union[str, typing.List[str]],
subject: str = "taskiq_tasks",
queue: Optional[str] = None,
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: Optional[Callable[[], str]] = None,
**connection_kwargs: Any,
queue: typing.Optional[str] = None,
result_backend: "typing.Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: typing.Optional[typing.Callable[[], str],] = None,
**connection_kwargs: typing.Any,
) -> None:
super().__init__(result_backend, task_id_generator)
self.servers = servers
Expand Down Expand Up @@ -72,7 +72,7 @@ async def kick(self, message: BrokerMessage) -> None:
headers=message.labels,
)

async def listen(self) -> AsyncGenerator[bytes, None]:
async def listen(self) -> typing.AsyncGenerator[bytes, None]:
"""
Start listen to new messages.
Expand All @@ -88,9 +88,13 @@ async def shutdown(self) -> None:
await super().shutdown()


class BaseJetStreamBroker(AsyncBroker, ABC, Generic[JetStreamConsumerType]):
class BaseJetStreamBroker( # noqa: WPS230 (too many attrs)
AsyncBroker,
ABC,
typing.Generic[JetStreamConsumerType],
):
"""Base JetStream broker for taskiq.
It has two subclasses - PullBasedJetStreamBroker
and PushBasedJetStreamBroker.
Expand All @@ -101,41 +105,41 @@ class BaseJetStreamBroker(AsyncBroker, ABC, Generic[JetStreamConsumerType]):
be sure that messages are delivered to the workers.
"""

def __init__(
self: Self,
servers: Union[str, List[str]],
def __init__( # noqa: WPS211 (too many args)
self: typing.Self,
servers: typing.Union[str, typing.List[str]],
subject: str = "taskiq_tasks",
stream_name: str = "taskiq_jetstream",
queue: Optional[str] = None,
queue: typing.Optional[str] = None,
durable: str = "taskiq_durable",
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: Optional[Callable[[], str]] = None,
stream_config: Optional[StreamConfig] = None,
consumer_config: Optional[ConsumerConfig] = None,
result_backend: "typing.Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: typing.Optional[typing.Callable[[], str]] = None,
stream_config: typing.Optional[StreamConfig] = None,
consumer_config: typing.Optional[ConsumerConfig] = None,
pull_consume_batch: int = 1,
pull_consume_timeout: Optional[float] = None,
**connection_kwargs: Any,
pull_consume_timeout: typing.Optional[float] = None,
**connection_kwargs: typing.Any,
) -> None:
super().__init__(result_backend, task_id_generator)
self.servers = servers
self.client: Client = Client()
self.connection_kwargs = connection_kwargs
self.subject = subject
self.stream_name = stream_name
self.servers: typing.Final = servers
self.client: typing.Final = Client()
self.connection_kwargs: typing.Final = connection_kwargs
self.subject: typing.Final = subject
self.stream_name: typing.Final = stream_name
self.js: JetStreamContext
self.stream_config = stream_config or StreamConfig()
self.consumer_config = consumer_config

# Only for push based consumer
self.queue = queue
self.default_consumer_name: Final = "taskiq_consumer"
self.queue: typing.Final = queue
self.default_consumer_name: typing.Final = "taskiq_consumer"
# Only for pull based consumer
self.durable = durable
self.pull_consume_batch = pull_consume_batch
self.pull_consume_timeout = pull_consume_timeout
self.durable: typing.Final = durable
self.pull_consume_batch: typing.Final = pull_consume_batch
self.pull_consume_timeout: typing.Final = pull_consume_timeout

self.consumer: JetStreamConsumerType

async def startup(self) -> None:
"""
Startup event handler.
Expand All @@ -152,12 +156,12 @@ async def startup(self) -> None:
self.stream_config.subjects = [self.subject]
await self.js.add_stream(config=self.stream_config)
await self._startup_consumer()

async def shutdown(self) -> None:
"""Close connections to NATS."""
await self.client.close()
await super().shutdown()

async def kick(self, message: BrokerMessage) -> None:
"""
Send a message using NATS.
Expand All @@ -169,22 +173,34 @@ async def kick(self, message: BrokerMessage) -> None:
payload=message.message,
headers=message.labels,
)

@abstractmethod
async def _startup_consumer(self: Self) -> None:
async def _startup_consumer(self: typing.Self) -> None:
"""Create consumer."""


class PushBasedJetStreamBroker(
BaseJetStreamBroker[JetStreamContext.PushSubscription],
):
"""JetStream broker for push based message consumption.
It's named `push` based because nats server push messages to
the consumer, not consumer requests them.
"""

async def _startup_consumer(self: Self) -> None:
async def listen(self) -> typing.AsyncGenerator[AckableMessage, None]:
"""
Start listen to new messages.
:yield: incoming messages.
"""
async for message in self.consumer.messages:
yield AckableMessage(
data=message.data,
ack=message.ack,
)

async def _startup_consumer(self: typing.Self) -> None:
if not self.consumer_config:
self.consumer_config = ConsumerConfig(
name=self.default_consumer_name,
Expand All @@ -196,56 +212,26 @@ async def _startup_consumer(self: Self) -> None:
queue=self.queue or "",
config=self.consumer_config,
)

async def listen(self) -> AsyncGenerator[AckableMessage, None]:
"""
Start listen to new messages.
:yield: incoming messages.
"""
async for message in self.consumer.messages:
yield AckableMessage(
data=message.data,
ack=message.ack,
)


class PullBasedJetStreamBroker(
BaseJetStreamBroker[JetStreamContext.PullSubscription],
):
"""JetStream broker for pull based message consumption.
It's named `pull` based because consumer requests messages,
not NATS server sends them.
"""

async def _startup_consumer(self: Self) -> None:
if not self.consumer_config:
self.consumer_config = ConsumerConfig(
durable_name=self.durable,
)

# We must use this method to create pull based consumer
# because consumer config won't change without it.
await self.js.add_consumer(
stream=self.stream_config.name or self.stream_name,
config=self.consumer_config,
)
self.consumer = await self.js.pull_subscribe(
subject=self.subject,
durable=self.durable,
config=self.consumer_config,
)

async def listen(self) -> AsyncGenerator[AckableMessage, None]:
async def listen(self) -> typing.AsyncGenerator[AckableMessage, None]:
"""
Start listen to new messages.
:yield: incoming messages.
"""
while True:
while True: # noqa: WPS327
try:
nats_messages: List[NatsMessage] = await self.consumer.fetch(
nats_messages: typing.List[NatsMessage] = await self.consumer.fetch(
batch=self.pull_consume_batch,
timeout=self.pull_consume_timeout,
)
Expand All @@ -256,3 +242,21 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]:
)
except NatsTimeoutError:
continue

async def _startup_consumer(self: typing.Self) -> None:
if not self.consumer_config:
self.consumer_config = ConsumerConfig(
durable_name=self.durable,
)

# We must use this method to create pull based consumer
# because consumer config won't change without it.
await self.js.add_consumer(
stream=self.stream_config.name or self.stream_name,
config=self.consumer_config,
)
self.consumer = await self.js.pull_subscribe(
subject=self.subject,
durable=self.durable,
config=self.consumer_config,
)
11 changes: 6 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import uuid
from typing import AsyncGenerator, Final, List

from nats.js import JetStreamContext
from nats import NATS

import pytest
from nats import NATS
from nats.js import JetStreamContext


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -44,9 +43,11 @@ def nats_urls() -> List[str]:


@pytest.fixture()
async def nats_jetstream(nats_urls: List[str]) -> AsyncGenerator[JetStreamContext, None]:
async def nats_jetstream(
nats_urls: List[str], # noqa: WPS442
) -> AsyncGenerator[JetStreamContext, None]:
"""Create and yield nats client and jetstream instances.
:param nats_urls: urls to nats.
:yields: NATS JetStream.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_jetstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import pytest
from taskiq import AckableMessage, BrokerMessage

from taskiq_nats import PushBasedJetStreamBroker, PullBasedJetStreamBroker
from taskiq_nats import PullBasedJetStreamBroker, PushBasedJetStreamBroker
from tests.utils import read_message


@pytest.mark.anyio
async def test_push_based_broker_success(
async def test_push_based_broker_success( # noqa: WPS217 (too many await)
nats_urls: List[str],
nats_subject: str,
) -> None:
Expand Down Expand Up @@ -53,7 +53,7 @@ async def test_push_based_broker_success(


@pytest.mark.anyio()
async def test_pull_based_broker_success(
async def test_pull_based_broker_success( # noqa: WPS217 (too many await)
nats_urls: List[str],
nats_subject: str,
) -> None:
Expand Down Expand Up @@ -90,4 +90,4 @@ async def test_pull_based_broker_success(
await broker.js.delete_stream(
broker.stream_name,
)
await broker.shutdown()
await broker.shutdown()

0 comments on commit d400760

Please sign in to comment.