Skip to content

Commit

Permalink
Feat/origin_interface_id_for_message (#398)
Browse files Browse the repository at this point in the history
# Description

ID string added to messenger interface constructor.
This string is added to all the messages received via this interface.
The first message origin is the origin of the context it is related to.

---------

Co-authored-by: Roman Zlobin <[email protected]>
  • Loading branch information
pseusys and RLKRo authored Nov 13, 2024
1 parent e5e286c commit ff6b93c
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 82 deletions.
7 changes: 7 additions & 0 deletions chatsky/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ class Context(BaseModel):
It is meant to be used by the framework only. Accessing it may result in pipeline breakage.
"""

origin_interface: Optional[str] = Field(default=None)
"""
Name of the interface that produced the first request in this context.
"""

@classmethod
def init(cls, start_label: AbsoluteNodeLabelInitTypes, id: Optional[Union[UUID, int, str]] = None):
"""Initialize new context from ``start_label`` and, optionally, context ``id``."""
Expand All @@ -165,6 +170,8 @@ def add_request(self, request: MessageInitTypes):
request_message = Message.model_validate(request)
if len(self.requests) == 0:
self.requests[1] = request_message
if request_message.origin is not None:
self.origin_interface = request_message.origin.interface
else:
last_index = get_last_index(self.requests)
self.requests[last_index + 1] = request_message
Expand Down
66 changes: 41 additions & 25 deletions chatsky/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import uuid
import abc

from pydantic import Field, FilePath, HttpUrl, model_validator, field_validator, field_serializer
from pydantic import BaseModel, Field, FilePath, HttpUrl, model_validator, field_validator, field_serializer
from pydantic_core import Url

from chatsky.utils.devel import (
Expand Down Expand Up @@ -257,6 +257,43 @@ class MediaGroup(Attachment):
chatsky_attachment_type: Literal["media_group"] = "media_group"


class Origin(BaseModel):
"""
Denotes the origin of the message.
"""

message: Optional[Any] = None
"""
Original data that the message is created from.
E.g. telegram update.
"""
interface: Optional[str] = None
"""
Name of the interface that produced the message.
"""

@field_serializer("message", when_used="json")
def pickle_serialize_message(self, value):
"""
Cast :py:attr:`message` to string via pickle.
Allows storing arbitrary data in this field when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("message", mode="before")
@classmethod
def pickle_validate_message(cls, value):
"""
Restore :py:attr:`message` after being processed with
:py:meth:`pickle_serialize_message`.
"""
if value is not None:
return pickle_validator(value)
return value


class Message(DataModel):
"""
Class representing a message and contains several
Expand Down Expand Up @@ -292,7 +329,7 @@ class level variables to store message information.
] = None
annotations: Optional[Dict[str, Any]] = None
misc: Optional[Dict[str, Any]] = None
original_message: Optional[Any] = None
origin: Optional[Origin] = None

def __init__( # this allows initializing Message with string as positional argument
self,
Expand Down Expand Up @@ -320,15 +357,15 @@ def __init__( # this allows initializing Message with string as positional argu
] = None,
annotations: Optional[Dict[str, Any]] = None,
misc: Optional[Dict[str, Any]] = None,
original_message: Optional[Any] = None,
origin: Optional[Origin] = None,
**kwargs,
):
super().__init__(
text=text,
attachments=attachments,
annotations=annotations,
misc=misc,
original_message=original_message,
origin=origin,
**kwargs,
)

Expand All @@ -350,27 +387,6 @@ def pickle_validate_dicts(cls, value):
return json_pickle_validator(value)
return value

@field_serializer("original_message", when_used="json")
def pickle_serialize_original_message(self, value):
"""
Cast :py:attr:`original_message` to string via pickle.
Allows storing arbitrary data in this field when using context storages.
"""
if value is not None:
return pickle_serializer(value)
return value

@field_validator("original_message", mode="before")
@classmethod
def pickle_validate_original_message(cls, value):
"""
Restore :py:attr:`original_message` after being processed with
:py:meth:`pickle_serialize_original_message`.
"""
if value is not None:
return pickle_validator(value)
return value

def __str__(self) -> str:
return " ".join([f"{key}='{value}'" for key, value in self.model_dump(exclude_none=True).items()])

Expand Down
5 changes: 5 additions & 0 deletions chatsky/messengers/common/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class MessengerInterface(abc.ABC):
It is responsible for connection between user and pipeline, as well as for request-response transactions.
"""

def __init__(self) -> None:
self.id = type(self).__name__

@abc.abstractmethod
async def connect(self, pipeline_runner: PipelineRunnerFunction):
"""
Expand Down Expand Up @@ -61,6 +64,7 @@ class MessengerInterfaceWithAttachments(MessengerInterface, abc.ABC):
"""

def __init__(self, attachments_directory: Optional[Path] = None) -> None:
super().__init__()
tempdir = gettempdir()
if attachments_directory is not None and not str(attachments_directory.absolute()).startswith(tempdir):
self.attachments_directory = attachments_directory
Expand Down Expand Up @@ -170,6 +174,7 @@ class CallbackMessengerInterface(MessengerInterface):
"""

def __init__(self) -> None:
super().__init__()
self._pipeline_runner: Optional[PipelineRunnerFunction] = None

async def connect(self, pipeline_runner: PipelineRunnerFunction):
Expand Down
3 changes: 2 additions & 1 deletion chatsky/messengers/telegram/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Invoice,
Location,
Message,
Origin,
Poll,
PollOption,
Sticker,
Expand Down Expand Up @@ -627,7 +628,7 @@ async def _on_event(self, update: Update, _: Any, create_message: Callable[[Upda
data_available = update.message is not None or update.callback_query is not None
if update.effective_chat is not None and data_available:
message = create_message(update)
message.original_message = update
message.origin = Origin.model_construct(message=update, interface=self.id)
resp = await self._pipeline_runner(message, update.effective_chat.id)
if resp.last_response is not None:
await self.cast_message_to_telegram_and_send(
Expand Down
12 changes: 10 additions & 2 deletions chatsky/messengers/telegram/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ class LongpollingInterface(_AbstractTelegramInterface):
"""

def __init__(
self, token: str, attachments_directory: Optional[Path] = None, interval: int = 2, timeout: int = 20
self,
token: str,
attachments_directory: Optional[Path] = None,
interval: int = 2,
timeout: int = 20,
) -> None:
super().__init__(token, attachments_directory)
self.interval = interval
Expand All @@ -54,7 +58,11 @@ class WebhookInterface(_AbstractTelegramInterface):
"""

def __init__(
self, token: str, attachments_directory: Optional[Path] = None, host: str = "localhost", port: int = 844
self,
token: str,
attachments_directory: Optional[Path] = None,
host: str = "localhost",
port: int = 844,
):
super().__init__(token, attachments_directory)
self.listen = host
Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Location,
DataAttachment,
Message,
Origin,
Poll,
PollOption,
Sticker,
Expand Down Expand Up @@ -93,7 +94,7 @@ def test_attachment_serialize(self, attachment: DataAttachment):
def test_field_serializable(self, random_original_message: UnserializableObject):
message = Message(text="sample message")
message.misc = {"answer": 42, "unserializable": random_original_message}
message.original_message = random_original_message
message.origin = Origin.model_construct(message=random_original_message)
message.some_extra_field = random_original_message
message.other_extra_field = {"unserializable": random_original_message}
serialized = message.model_dump_json()
Expand Down
Loading

0 comments on commit ff6b93c

Please sign in to comment.