Messenger Interface rework #357

Open

wants to merge 55 commits into base: dev

Commits (55)
70306f4  poetry.lock update (ZergLev, May 6, 2024)
af90b04  trying graceful termination (ZergLev, May 6, 2024)
9d7ecf0  graceful termination is done exclusively within the interface class (… (ZergLev, May 8, 2024)
0c377f1  mistake fixed (ZergLev, May 8, 2024)
8360976  sigint handling moved to pipeline + custom stop funcs added (ZergLev, May 13, 2024)
8476b16  feature works now (ZergLev, May 13, 2024)
4d83016  formatted code (ZergLev, May 13, 2024)
2e0dedc  fixed typo (ZergLev, May 13, 2024)
93b5359  interfaces enhanced with asyncio, changes to graceful termination added (ZergLev, May 16, 2024)
d759fba  added a check to signal handler (unfinished) (ZergLev, May 16, 2024)
01b0002  formatted with poetry (ZergLev, May 16, 2024)
52eb5bf  PollingMessengerInterface() overhaul draft (ZergLev, May 22, 2024)
88753a8  formatted with poetry (ZergLev, May 22, 2024)
5f98eda  refactor (ZergLev, May 22, 2024)
591bd8f  testing, but connect() doesn't work (ZergLev, May 23, 2024)
fd89b12  writing tests (ZergLev, May 24, 2024)
98a13b4  old unit-tests work with this now, a few mistakes fixed (ZergLev, May 27, 2024)
9889199  telegram bot works, but graceful termination apparently does not (ZergLev, May 27, 2024)
e956cc3  Trying the echo test (ZergLev, Jun 16, 2024)
e238432  Trying the echo test (ZergLev, Jun 16, 2024)
1f88e47  Trying the echo test (ZergLev, Jun 16, 2024)
97cec27  echo test draft (doesn't launch) (ZergLev, Jun 16, 2024)
e601ed6  first test works, several bug fixes (ZergLev, Jun 19, 2024)
7d8c68c  ContextLock() test added (ZergLev, Jun 19, 2024)
8ebf6ec  comments changed (ZergLev, Jun 19, 2024)
e366a61  debug output removed (ZergLev, Jun 19, 2024)
09fe10f  confusing comment removed (ZergLev, Jun 19, 2024)
0cecb80  comment changes (ZergLev, Jun 19, 2024)
1c24131  new tests moved to a separate file (ZergLev, Jun 20, 2024)
b49dac7  more tests added (ZergLev, Jun 20, 2024)
876ce8d  poll_timeout added + test changed (ZergLev, Jun 21, 2024)
f91910b  typo corrected (ZergLev, Jun 21, 2024)
585a37d  add siginthandler to async loop (RLKRo, Jun 21, 2024)
4180e17  fix test class (RLKRo, Jun 21, 2024)
0059bd3  adding worker timeouts and cleanup (ZergLev, Jun 26, 2024)
c3f18a7  Merge branch 'feat/graceful_termination' of https://github.com/ZergLe… (ZergLev, Jun 26, 2024)
65329f8  new _worker() seems to be working (it's awaited) (ZergLev, Jun 26, 2024)
30369f6  all tests but one working, can't call shutdown() (ZergLev, Jun 26, 2024)
ce9ac81  all tests working (ZergLev, Jun 26, 2024)
c765c18  ContextLock() moved to pipeline.py (ZergLev, Jul 1, 2024)
83ebe7f  formatted with poetry (ZergLev, Jul 1, 2024)
b99c4eb  Merge branch 'dev' into feat/graceful_termination (RLKRo, Aug 2, 2024)
8715c2f  Update tests/messengers/common/test_messenger_interface.py (ZergLev, Aug 2, 2024)
a829cf5  Update chatsky/messengers/common/interface.py (ZergLev, Aug 5, 2024)
16b049a  review changes started, bugs appeared (ZergLev, Aug 7, 2024)
a8607c6  Merge branch 'feat/graceful_termination' of https://github.com/ZergLe… (ZergLev, Aug 7, 2024)
5656437  moving graceful termination to pipeline, windows support added back i… (ZergLev, Aug 16, 2024)
cd16255  in the process of fixing bugs, docs partially added (ZergLev, Aug 16, 2024)
3282d18  new LongpollingMessengerInterface drafted + removing run_in_foregroun… (ZergLev, Aug 21, 2024)
b2140a5  Merge branch 'dev' into feat/graceful_termination (ZergLev, Aug 21, 2024)
6fde6c3  lint (ZergLev, Aug 21, 2024)
2529537  Merge branch 'feat/graceful_termination' of https://github.com/ZergLe… (ZergLev, Aug 21, 2024)
6334b8e  lint (ZergLev, Aug 21, 2024)
9a2381a  fully removed run_in_foreground, some changes to graceful termination (ZergLev, Aug 23, 2024)
b8c99d9  in the process of switching to BaseModel + other changes (ZergLev, Aug 23, 2024)
172 changes: 128 additions & 44 deletions chatsky/messengers/common/interface.py
@@ -11,25 +11,36 @@
import logging
from pathlib import Path
from tempfile import gettempdir
from typing import Optional, Any, List, Tuple, Hashable, TYPE_CHECKING, Type

from typing import Optional, Any, Hashable, TYPE_CHECKING, Type
from pydantic import BaseModel

if TYPE_CHECKING:
from chatsky.script import Context, Message
from chatsky.pipeline.types import PipelineRunnerFunction
from chatsky.messengers.common.types import PollingInterfaceLoopFunction
from chatsky.script.core.message import Attachment
from chatsky.pipeline.pipeline.pipeline import Pipeline

logger = logging.getLogger(__name__)


class MessengerInterface(abc.ABC):
class MessengerInterface(abc.ABC, BaseModel):
"""
Class that represents a message interface used for communication between pipeline and users.
It is responsible for connection between user and pipeline, as well as for request-response transactions.
"""

running: bool = True
"""Shows whether the interface is still accepting new requests."""
finished_working: bool = False
"""Shows whether the interface has finished processing all of the requests received."""

@abc.abstractmethod
async def connect(self, pipeline_runner: PipelineRunnerFunction):
async def connect(
self,
pipeline_runner: PipelineRunnerFunction,
):
"""
Method invoked when message interface is instantiated and connection is established.
May be used for sending an introduction message or displaying general bot information.
@@ -39,6 +50,15 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction):
"""
raise NotImplementedError

async def cleanup(self):
"""
A placeholder method for any cleanup code you want to be
called before shutting down the program.
You can redefine this method in your class.
Note you also need to call cleanup() of the parent class.
"""
pass


class MessengerInterfaceWithAttachments(MessengerInterface, abc.ABC):
"""
@@ -93,75 +113,135 @@ class PollingMessengerInterface(MessengerInterface):
"""
Polling message interface runs in a loop, constantly asking users for a new input.
"""
number_of_workers: int = 2
_request_queue = asyncio.Queue()
_worker_tasks = []

@abc.abstractmethod
def _request(self) -> List[Tuple[Message, Hashable]]:
async def _respond(self, ctx_id, last_response):
"""
Method used for sending users request for their input.
Method used for sending users responses for their last input.

:return: A list of tuples: user inputs and context ids (any user ids) associated with the inputs.
:param ctx_id: Context id, specifies the user id. Without multiple messenger interfaces it's basically a
redundant parameter, because this function is just a more complex `print(last_response)`. (Change before merge)
:param last_response: Latest response from the pipeline which should be relayed to the specified user.
"""
raise NotImplementedError

@abc.abstractmethod
def _respond(self, responses: List[Context]):
"""
Method used for sending users responses for their last input.
async def _process_request(self, ctx_id: Any, update: Message, pipeline_runner: PipelineRunnerFunction):
"""Process a new update for ctx."""
context = await pipeline_runner(update, ctx_id)
await self._respond(ctx_id, context.last_response)

:param responses: A list of contexts, representing dialogs with the users;
`last_response`, `id` and some dialog info can be extracted from there.
async def _worker_job(self, pipeline_runner: PipelineRunnerFunction, worker_timeout: float):
"""
raise NotImplementedError
Obtain Lock over the current context,
Process the update and send it.
"""
request = await self._request_queue.get()
if request is not None:
(ctx_id, update) = request
async with self.pipeline.context_lock[ctx_id]: # get exclusive access to this context among interfaces
await asyncio.wait_for(
self._process_request(ctx_id, update, pipeline_runner),
timeout=worker_timeout,
)
return False
else:
return True

def _on_exception(self, e: BaseException):
# This worker doesn't save the request and basically deletes it from the queue in case it can't process it.
# An option to save the request may be fitting? Maybe with an amount of retries.
async def _worker(self, pipeline_runner: PipelineRunnerFunction, worker_timeout: float):
while self.running or not self._request_queue.empty():
try:
no_more_jobs = self._worker_job(pipeline_runner, worker_timeout)
if no_more_jobs:
logger.info("Worker finished working - all remaining requests have been processed.")
# Polling_loop should give the required data on whether the stop signal was sent or if
# the loop() function gave 'False'.
break
except TimeoutError:
logger.info("worker couldn't process request in time. A request *may* have been lost.")

@abc.abstractmethod
async def _get_updates(self) -> list[tuple[Any, Message]]:
"""
Method that is called on polling cycle exceptions, in some cases it should show users the exception.
By default, it logs all exit exceptions to `info` log and all non-exit exceptions to `error`.
Obtain updates from another server

:param e: The exception.
Example:
self.bot.request_updates()
"""
if isinstance(e, Exception):
logger.error(f"Exception in {type(self).__name__} loop!", exc_info=e)
else:
logger.info(f"{type(self).__name__} has stopped polling.")

async def _polling_job(self, poll_timeout: float):
try:
received_updates = await asyncio.wait_for(self._get_updates(), timeout=poll_timeout)
if received_updates is not None:
for update in received_updates:
await self._request_queue.put(update)
except TimeoutError:
logger.debug("polling_job failed - timed out")

async def _polling_loop(
self,
pipeline_runner: PipelineRunnerFunction,
loop: PollingInterfaceLoopFunction = lambda: True,
poll_timeout: float = None,
timeout: float = 0,
):
"""
Method running the request - response cycle once.
"""
user_updates = self._request()
responses = [await pipeline_runner(request, ctx_id) for request, ctx_id in user_updates]
self._respond(responses)
await asyncio.sleep(timeout)
try:
while loop() and self.running:
await asyncio.shield(self._polling_job(poll_timeout)) # shield from cancellation
await asyncio.sleep(timeout)
finally:
self.running = False
# If loop() is somehow True after being False once, this logging will be wrong.
# But no user would want to break their own logging, right?
if loop() is False:
logger.info("polling_loop stopped working - the loop() condition was False")
else:
logger.info("polling_loop stopped working - the stop signal was received.")
# If there are no more jobs/stop signal received, a special 'None' request is
# sent to the queue (one for each worker), they shut down the workers.
# In case of more workers than two, change the number of 'None' requests to the new number of workers.
for i in range(self.number_of_workers):
self._request_queue.put_nowait(None)

async def connect(
self,
pipeline_runner: PipelineRunnerFunction,
loop: PollingInterfaceLoopFunction = lambda: True,
poll_timeout: float = None,
worker_timeout: float = None,
timeout: float = 0,
):
# Saving strong references to workers, so that they can be cleaned up properly.
# shield() creates a task just like create_task() according to docs.
# But for safety we have two task wrappers, I guess.
for i in range(self.number_of_workers):
task = asyncio.create_task(asyncio.shield(self._worker(pipeline_runner, worker_timeout)))
self._worker_tasks.append(task)
print("worker tasks:", self._worker_tasks)
await self._polling_loop(loop=loop, poll_timeout=poll_timeout, timeout=timeout)

# Maybe "worker_cleanup" instead of this function name?
async def cleanup(self):
"""
Blocks until all workers are done.
"""
Method, running a request - response cycle in a loop.
The looping behavior is determined by `loop` and `timeout`,
for most cases the loop itself shouldn't be overridden.
await super().cleanup()
await asyncio.wait(self._worker_tasks)

:param pipeline_runner: A function that should process user request and return context;
usually it's a :py:meth:`~chatsky.pipeline.pipeline.pipeline.Pipeline._run_pipeline` function.
:param loop: a function that determines whether polling should be continued;
called in each cycle, should return `True` to continue polling or `False` to stop.
:param timeout: a time interval between polls (in seconds).
def _on_exception(self, e: BaseException):
"""
while loop():
try:
await self._polling_loop(pipeline_runner, timeout)
Method that is called on polling cycle exceptions, in some cases it should show users the exception.
By default, it logs all exit exceptions to `info` log and all non-exit exceptions to `error`.

except BaseException as e:
self._on_exception(e)
break
:param e: The exception.
"""
if isinstance(e, Exception):
logger.error(f"Exception in {type(self).__name__} loop!", exc_info=e)
else:
logger.info(f"{type(self).__name__} has stopped polling.")


class CallbackMessengerInterface(MessengerInterface):
@@ -171,8 +251,12 @@ class CallbackMessengerInterface(MessengerInterface):

def __init__(self) -> None:
self._pipeline_runner: Optional[PipelineRunnerFunction] = None
super().__init__()

async def connect(self, pipeline_runner: PipelineRunnerFunction):
async def connect(
self,
pipeline_runner: PipelineRunnerFunction,
):
self._pipeline_runner = pipeline_runner

async def on_request_async(
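For orientation, here is a minimal sketch of a concrete interface written against the reworked PollingMessengerInterface API above. It is not part of the PR: the class name, the in-memory request list, and the positional Message(text) constructor are assumptions inferred from this diff (the console implementation below uses the same constructor) and may differ from the actual codebase.

from typing import Any, List, Tuple

from chatsky.messengers.common.interface import PollingMessengerInterface
from chatsky.script.core.message import Message


class InMemoryInterface(PollingMessengerInterface):
    """Hypothetical polling interface that feeds pre-recorded requests; useful for tests."""

    requests: List[Tuple[Any, str]] = []  # (ctx_id, text) pairs to feed into the queue

    async def _get_updates(self) -> List[Tuple[Any, Message]]:
        # The base class puts each (ctx_id, Message) tuple onto its internal request queue.
        updates = [(ctx_id, Message(text)) for ctx_id, text in self.requests]
        self.requests = []
        return updates

    async def _respond(self, ctx_id, last_response: Message):
        # Called by _process_request() once the pipeline has produced a response.
        print(f"{ctx_id}: {last_response.text}")

    async def cleanup(self):
        # Per the new contract, overrides should still call the parent cleanup,
        # which blocks until the worker tasks are done.
        await super().cleanup()

A pipeline would then run await interface.connect(pipeline_runner, poll_timeout=..., worker_timeout=...), which spawns number_of_workers worker tasks and keeps polling until loop() returns False or running is cleared.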
17 changes: 6 additions & 11 deletions chatsky/messengers/console.py
@@ -1,8 +1,6 @@
from typing import Any, Hashable, List, Optional, TextIO, Tuple
from uuid import uuid4
from chatsky.messengers.common.interface import PollingMessengerInterface
from chatsky.pipeline.types import PipelineRunnerFunction
from chatsky.script.core.context import Context
from chatsky.script.core.message import Message


@@ -12,9 +10,6 @@ class CLIMessengerInterface(PollingMessengerInterface):
This message interface can maintain dialog with one user at a time only.
"""

supported_request_attachment_types = set()
supported_response_attachment_types = set()

def __init__(
self,
intro: Optional[str] = None,
@@ -29,13 +24,13 @@ def __init__(
self._prompt_response: str = prompt_response
self._descriptor: Optional[TextIO] = out_descriptor

def _request(self) -> List[Tuple[Message, Any]]:
return [(Message(input(self._prompt_request)), self._ctx_id)]
async def _get_updates(self) -> List[Tuple[Any, Message]]:
return [(self._ctx_id, Message(input(self._prompt_request)))]

def _respond(self, responses: List[Context]):
print(f"{self._prompt_response}{responses[0].last_response.text}", file=self._descriptor)
async def _respond(self, ctx_id, last_response: Message):
print(f"{self._prompt_response}{last_response.text}", file=self._descriptor)

async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs):
async def connect(self, *args, **kwargs):
"""
The CLIProvider generates new dialog id used to user identification on each `connect` call.

@@ -46,4 +41,4 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs):
self._ctx_id = uuid4()
if self._intro is not None:
print(self._intro)
await super().connect(pipeline_runner, **kwargs)
await super().connect(*args, **kwargs)
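The shutdown scheme introduced in interface.py above works through sentinels: once polling stops, _polling_loop enqueues one None per worker, each worker exits when it dequeues one, and cleanup() then waits on the worker tasks. The following standalone asyncio sketch (plain Python, not Chatsky code) illustrates that pattern.

import asyncio

NUM_WORKERS = 2  # counterpart of PollingMessengerInterface.number_of_workers


async def worker(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        if item is None:  # sentinel: no more requests, shut this worker down
            break
        print("processed:", item)


async def main() -> None:
    queue = asyncio.Queue()
    workers = [asyncio.create_task(worker(queue)) for _ in range(NUM_WORKERS)]
    for update in ("request 1", "request 2", "request 3"):
        await queue.put(update)
    for _ in range(NUM_WORKERS):  # one sentinel per worker, as in _polling_loop()
        queue.put_nowait(None)
    await asyncio.wait(workers)  # mirrors what cleanup() does with _worker_tasks


asyncio.run(main())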
5 changes: 1 addition & 4 deletions chatsky/messengers/telegram/abstract.py
@@ -41,8 +41,7 @@
Update,
Message as TelegramMessage,
)
from telegram.ext import Application, ExtBot, MessageHandler, CallbackQueryHandler
from telegram.ext.filters import ALL
from telegram.ext import Application, ExtBot

telegram_available = True
except ImportError:
@@ -93,8 +92,6 @@ def __init__(self, token: str, attachments_directory: Optional[Path] = None) ->
raise ImportError("`python-telegram-bot` package is missing.\nTry to run `pip install chatsky[telegram]`.")

self.application = Application.builder().token(token).build()
self.application.add_handler(MessageHandler(ALL, self.on_message))
self.application.add_handler(CallbackQueryHandler(self.on_callback))

async def get_attachment_bytes(self, source: str) -> bytes:
file = await self.application.bot.get_file(source)
33 changes: 29 additions & 4 deletions chatsky/messengers/telegram/interface.py
@@ -8,17 +8,29 @@
from pathlib import Path
from typing import Any, Optional

from chatsky.script import Message

from chatsky.messengers.common import PollingMessengerInterface

from chatsky.pipeline.types import PipelineRunnerFunction

from .abstract import _AbstractTelegramInterface

try:
from telegram.ext import MessageHandler, CallbackQueryHandler
from telegram.ext.filters import ALL

telegram_available = True
except ImportError:
telegram_available = False

try:
from telegram import Update
except ImportError:
Update = Any


class LongpollingInterface(_AbstractTelegramInterface):
class LongpollingInterface(_AbstractTelegramInterface, PollingMessengerInterface):
"""
Telegram messenger interface, that requests Telegram API in a loop.

@@ -35,11 +47,22 @@ def __init__(
self.interval = interval
self.timeout = timeout

async def connect(self, pipeline_runner: PipelineRunnerFunction, *args, **kwargs):
await super().connect(pipeline_runner, *args, **kwargs)
self.application.run_polling(
async def _get_updates(self) -> list[tuple[Any, Message]]:
updates = self.application.bot.get_updates(
poll_interval=self.interval, timeout=self.timeout, allowed_updates=Update.ALL_TYPES
)
parsed_updates = []
for update in updates:
data_available = update.message is not None or update.callback_query is not None
if update.effective_chat is not None and data_available:
message = self.extract_message_from_telegram(update)
message.original_message = update
parsed_updates.append((update.effective_chat.id, message))
return parsed_updates

async def _respond(self, ctx_id, last_response):
if last_response is not None:
await self.cast_message_to_telegram_and_send(self.application.bot, ctx_id, last_response)


class WebhookInterface(_AbstractTelegramInterface):
@@ -59,6 +82,8 @@ def __init__(
super().__init__(token, attachments_directory)
self.listen = host
self.port = port
self.application.add_handler(MessageHandler(ALL, self.on_message))
self.application.add_handler(CallbackQueryHandler(self.on_callback))

async def connect(self, pipeline_runner: PipelineRunnerFunction, *args, **kwargs):
await super().connect(pipeline_runner, *args, **kwargs)
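For context, a rough sketch of how the reworked long-polling interface would be plugged into a bot. The script keywords, Pipeline.from_script, and its messenger_interface parameter are taken from the pre-rework Chatsky/DFF API and are assumptions here; they may not match this branch exactly.

import os

import chatsky.script.conditions as cnd
from chatsky.script import RESPONSE, TRANSITIONS, Message
from chatsky.pipeline import Pipeline
from chatsky.messengers.telegram.interface import LongpollingInterface

# A trivial one-node script; structure assumed from the pre-rework API.
script = {
    "flow": {
        "node": {
            RESPONSE: Message("pong"),
            TRANSITIONS: {("flow", "node"): cnd.true()},
        },
    },
}

pipeline = Pipeline.from_script(
    script,
    start_label=("flow", "node"),
    fallback_label=("flow", "node"),
    messenger_interface=LongpollingInterface(token=os.environ["TG_BOT_TOKEN"]),
)

pipeline.run()  # connect() then drives the polling loop and worker tasks until interrupted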