diff --git a/besser/bot/core/entity/entity.py b/besser/bot/core/entity/entity.py index 5ab805d..5f06885 100644 --- a/besser/bot/core/entity/entity.py +++ b/besser/bot/core/entity/entity.py @@ -74,6 +74,7 @@ def to_json(self) -> dict: """ entity_json = { 'base_entity': self.base_entity, + 'description': self.description, 'entries': [] } if not self.base_entity: diff --git a/besser/bot/core/message.py b/besser/bot/core/message.py index be4f4c8..803f926 100644 --- a/besser/bot/core/message.py +++ b/besser/bot/core/message.py @@ -7,6 +7,8 @@ class MessageType(Enum): """Enumeration of the different message types in :class:`Message`.""" STR = 'str' + MARKDOWN = 'markdown' + HTML = 'html' FILE = 'file' IMAGE = 'image' DATAFRAME = 'dataframe' diff --git a/besser/bot/core/processors/user_adaptation_processor.py b/besser/bot/core/processors/user_adaptation_processor.py new file mode 100644 index 0000000..2f41be9 --- /dev/null +++ b/besser/bot/core/processors/user_adaptation_processor.py @@ -0,0 +1,72 @@ +from besser.bot.core.processors.processor import Processor +from besser.bot.nlp.llm.llm import LLM +from besser.bot.core.bot import Bot +from besser.bot.core.session import Session +from besser.bot.nlp.nlp_engine import NLPEngine + + +class UserAdaptationProcessor(Processor): + """The UserAdaptationProcessor takes into account the user's profile and adapts the bot's responses to fit the + profile. The goal is to improve the user experience. + + This processor leverages LLMs to adapt the messages given a user profile. For static profiles, the adaptation will be + done once. If the profile changes, the adaptation will be triggered again. + + Args: + bot (Bot): The bot the processor belongs to + llm_name (str): the name of the LLM to use. + context (str): additional context to improve the adaptation. Should include information about the bot itself + and the task it should accomplish + + Attributes: + bot (Bot): The bot the processor belongs to + _llm_name (str): the name of the LLM to use. + _context (str): additional context to improve the adaptation. Should include information about the bot itself + and the task it should accomplish + _user_model (dict): dictionary containing the user models + """ + def __init__(self, bot: 'Bot', llm_name: str, context: str = None): + super().__init__(bot=bot, bot_messages=True) + self._llm_name: str = llm_name + self._nlp_engine: 'NLPEngine' = bot.nlp_engine + self._user_model: dict = {} + if context: + self._context = context + else: + self._context = "You are a chatbot." + +# TODO: add capability to improve/change the context prompt + def process(self, session: 'Session', message: str) -> str: + """Method to process a message and adapt its content based on a given user model. + + The stored user model will be fetched and sent as part of the context. + + Args: + session (Session): the current session + message (str): the message to be processed + + Returns: + str: the processed message + """ + llm: LLM = self._nlp_engine._llms[self._llm_name] + user_context = f"{self._context}\n\ + You are capable of adapting your predefined answers based on a given user profile.\ + Your goal is to improve the user experience by adapting the messages to the different attributes of the user\ + profile as well as possible, taking all the attributes into account.\ + You are free to adapt the messages in any way you like.\ + The user should be able to relate to the adapted messages.\
This is the user's profile\n \ {str(self._user_model[session.id])}" + prompt = f"You need to adapt this message: {message}\n Only respond with the adapted message!" + llm_response: str = llm.predict(prompt, session=session, system_message=user_context) + return llm_response + + def add_user_model(self, session: 'Session', user_model: dict) -> None: + """Method to store the user model internally. + + The stored model will be used to adapt the bot's messages for this session. + + Args: + session (Session): the current session + user_model (dict): the user model of a given user + """ + self._user_model[session.id] = user_model diff --git a/besser/bot/core/state.py b/besser/bot/core/state.py index 9daf7e2..1982184 100644 --- a/besser/bot/core/state.py +++ b/besser/bot/core/state.py @@ -196,6 +196,8 @@ def go_to(self, dest: 'State') -> None: Args: dest (State): the destination state """ + if dest not in self._bot.states: + raise StateNotFound(self._bot, dest) if self.transitions: raise ConflictingAutoTransitionError(self._bot, self) self.transitions.append(Transition(name=self._t_name(), source=self, dest=dest, event=auto, event_params={})) @@ -265,6 +267,11 @@ def when_variable_matches_operation_go_to( target (Any): the target value to which will be used in the operation with the stored value dest (State): the destination state """ + if dest not in self._bot.states: + raise StateNotFound(self._bot, dest) + for transition in self.transitions: + if transition.is_auto(): + raise ConflictingAutoTransitionError(self._bot, self) event_params = {'var_name': var_name, 'operation': operation, 'target': target} self.transitions.append(Transition(name=self._t_name(), source=self, dest=dest, event=variable_matches_operation, diff --git a/besser/bot/nlp/llm/llm.py b/besser/bot/nlp/llm/llm.py index 2b1b88e..6de32b6 100644 --- a/besser/bot/nlp/llm/llm.py +++ b/besser/bot/nlp/llm/llm.py @@ -22,18 +22,26 @@ class LLM(ABC): nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot the LLM belongs to name (str): the LLM name parameters (dict): the LLM parameters + global_context (str): the global context to be provided to the LLM for each request Attributes: _nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot the LLM belongs to name (str): the LLM name parameters (dict): the LLM parameters + _global_context (str): the global context to be provided to the LLM for each request + _user_context (dict): per-session aggregation of the user-specific context elements, provided to the LLM for each request + _user_contexts (dict): per-session dictionary containing the different context elements making up the user's context """ - def __init__(self, nlp_engine: 'NLPEngine', name: str, parameters: dict): + def __init__(self, nlp_engine: 'NLPEngine', name: str, parameters: dict, global_context: str = None): self._nlp_engine: 'NLPEngine' = nlp_engine self.name: str = name self.parameters: dict = parameters self._nlp_engine._llms[name] = self + self._global_context: str = global_context + self._user_context: dict = {} + self._user_contexts: dict = {} def set_parameters(self, parameters: dict) -> None: """Set the LLM parameters. @@ -49,20 +57,23 @@ def initialize(self) -> None: pass @abstractmethod - def predict(self, message: str, parameters: dict = None) -> str: + def predict(self, message: str, parameters: dict = None, session: 'Session' = None, + system_message: str = None) -> str: """Make a prediction, i.e., generate an output.
Args: message (Any): the LLM input text + session (Session): the ongoing session, can be None if no context needs to be applied parameters (dict): the LLM parameters to use in the prediction. If none is provided, the default LLM parameters will be used + system_message (str): system message to give high priority context to the LLM Returns: str: the LLM output """ pass - def chat(self, session: 'Session', parameters: dict = None) -> str: + def chat(self, session: 'Session', parameters: dict = None, system_message: str = None) -> str: """Make a prediction, i.e., generate an output. This function can provide the chat history to the LLM for the output generation, simulating a conversation or @@ -71,7 +82,8 @@ Args: session (Session): the user session parameters (dict): the LLM parameters. If none is provided, the RAG's default value will be used - + system_message (str): system message to give high priority context to the LLM + Returns: str: the LLM output """ @@ -100,3 +112,38 @@ """ logging.warning(f'Intent Classification not implemented in {self.__class__.__name__}') return [] + + def add_user_context(self, session: 'Session', context: str, context_name: str) -> None: + """Add user-specific context. + + Args: + session (Session): the ongoing session + context (str): the user-specific context + context_name (str): the key given to the specific user context + """ + if session.id not in self._user_contexts: + self._user_contexts[session.id] = {} + self._user_contexts[session.id][context_name] = context + context_message = "" + for context_element in self._user_contexts[session.id]: + context_message = context_message + self._user_contexts[session.id][context_element] + "\n" + self._user_context[session.id] = context_message + + def remove_user_context(self, session: 'Session', context_name: str) -> None: + """Remove user-specific context. + + Args: + session (Session): the ongoing session + context_name (str): the key given to the specific user context + """ + if session.id not in self._user_contexts or context_name not in self._user_contexts[session.id]: + return + else: + self._user_contexts[session.id].pop(context_name) + context_message = "" + for context_element in self._user_contexts[session.id]: + context_message = context_message + self._user_contexts[session.id][context_element] + "\n" + if context_message != "": + self._user_context[session.id] = context_message + else: + self._user_context.pop(session.id) diff --git a/besser/bot/nlp/llm/llm_huggingface.py b/besser/bot/nlp/llm/llm_huggingface.py index c210092..a9d9d06 100644 --- a/besser/bot/nlp/llm/llm_huggingface.py +++ b/besser/bot/nlp/llm/llm_huggingface.py @@ -27,6 +27,8 @@ class LLMHuggingFace(LLM): num_previous_messages (int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0). Necessary a connection to :class:`~besser.bot.db.monitoring_db.MonitoringDB`. + global_context (str): the global context to be provided to the LLM for each request + Attributes: _nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot the LLM belongs to name (str): the LLM name parameters (dict): the LLM parameters num_previous_messages (int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0). Necessary a connection to :class:`~besser.bot.db.monitoring_db.MonitoringDB`.
+ _global_context (str): the global context to be provided to the LLM for each request + _user_context (dict): user specific context to be provided to the LLM for each request """ - def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1): - super().__init__(bot.nlp_engine, name, parameters) + def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1, + global_context: str = None): + super().__init__(bot.nlp_engine, name, parameters, global_context) self.pipe = None self.num_previous_messages: int = num_previous_messages @@ -53,18 +58,34 @@ def set_num_previous_messages(self, num_previous_messages: int) -> None: def initialize(self) -> None: self.pipe = pipeline("text-generation", model=self.name) - def predict(self, message: str, parameters: dict = None) -> str: + def predict(self, message: str, parameters: dict = None, session: 'Session' = None, + system_message: str = None) -> str: if not parameters: parameters = self.parameters - outputs = self.pipe([{'role': 'user', 'content': message}], return_full_text=False, **parameters) + context_messages = [] + if self._global_context: + context_messages.append({'role': 'system', 'content': f"{self._global_context}\n"}) + if session and session.id in self._user_context: + context_messages.append({'role': 'system', 'content': f"{self._user_context[session.id]}\n"}) + if system_message: + context_messages.append({'role': 'system', 'content': f"{system_message}\n"}) + messages = merge_llm_consecutive_messages(context_messages + [{'role': 'user', 'content': message}]) + outputs = self.pipe(messages, return_full_text=False, **parameters) answer = outputs[0]['generated_text'] return answer - def chat(self, session: 'Session', parameters: dict = None) -> str: + def chat(self, session: 'Session', parameters: dict = None, system_message: str = None) -> str: if not parameters: parameters = self.parameters if self.num_previous_messages <= 0: raise ValueError('The number of previous messages to send to the LLM must be > 0') + context_messages = [] + if self._global_context: + context_messages.append({'role': 'system', 'content': f"{self._global_context}\n"}) + if session and session.id in self._user_context: + context_messages.append({'role': 'system', 'content': f"{self._user_context[session.id]}\n"}) + if system_message: + context_messages.append({'role': 'system', 'content': f"{system_message}\n"}) chat_history: list[Message] = session.get_chat_history(n=self.num_previous_messages) messages = [ {'role': 'user' if message.is_user else 'assistant', 'content': message.content} for message in chat_history ] if not messages: messages.append({'role': 'user', 'content': session.message}) - messages = merge_llm_consecutive_messages(messages) + messages = merge_llm_consecutive_messages(context_messages + messages) outputs = self.pipe(messages, return_full_text=False, **parameters) answer = outputs[0]['generated_text'] return answer diff --git a/besser/bot/nlp/llm/llm_huggingface_api.py b/besser/bot/nlp/llm/llm_huggingface_api.py index 935d57e..bd57c78 100644 --- a/besser/bot/nlp/llm/llm_huggingface_api.py +++ b/besser/bot/nlp/llm/llm_huggingface_api.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from besser.bot.core.bot import Bot + from besser.bot.core.session import Session from besser.bot.nlp.intent_classifier.llm_intent_classifier import LLMIntentClassifier @@ -25,6 +26,7 @@ class LLMHuggingFaceAPI(LLM): parameters (dict): the LLM parameters num_previous_messages
(int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0) + global_context (str): the global context to be provided to the LLM for each request Attributes: _nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot the LLM belongs to @@ -32,10 +34,13 @@ class LLMHuggingFaceAPI(LLM): parameters (dict): the LLM parameters num_previous_messages (int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0) + _global_context (str): the global context to be provided to the LLM for each request + _user_context (dict): user specific context to be provided to the LLM for each request """ - def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1): - super().__init__(bot.nlp_engine, name, parameters) + def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1, + global_context: str = None): + super().__init__(bot.nlp_engine, name, parameters, global_context=global_context) self.num_previous_messages: int = num_previous_messages def set_model(self, name: str) -> None: @@ -57,7 +62,7 @@ def set_num_previous_messages(self, num_previous_messages: int) -> None: def initialize(self) -> None: pass - def predict(self, message: str, parameters: dict = None) -> str: + def predict(self, message: str, parameters: dict = None, session: 'Session' = None, system_message: str = None) -> str: """Make a prediction, i.e., generate an output. Runs the `Text Generation Inference API task @@ -67,7 +72,8 @@ def predict(self, message: str, parameters: dict = None) -> str: message (Any): the LLM input text parameters (dict): the LLM parameters to use in the prediction. If none is provided, the default LLM parameters will be used - + system_message (str): system message to give high priority context to the LLM + Returns: str: the LLM output """ @@ -78,6 +84,15 @@ def predict(self, message: str, parameters: dict = None) -> str: parameters['return_full_text'] = False headers = {"Authorization": f"Bearer {self._nlp_engine.get_property(nlp.HF_API_KEY)}"} api_url = F"https://api-inference.huggingface.co/models/{self.name}" + context_messages = "" + if self._global_context: + context_messages = f"{self._global_context}\n" + if session and session.id in self._user_context: + context_messages = context_messages + f"{self._user_context[session.id]}\n" + if system_message: + context_messages = context_messages + f"{system_message}\n" + if context_messages != "": + message = context_messages + message payload = {"inputs": message, "parameters": parameters} response = requests.post(api_url, headers=headers, json=payload) return response.json()[0]['generated_text'] diff --git a/besser/bot/nlp/llm/llm_openai_api.py b/besser/bot/nlp/llm/llm_openai_api.py index f5fbc87..90b5a5e 100644 --- a/besser/bot/nlp/llm/llm_openai_api.py +++ b/besser/bot/nlp/llm/llm_openai_api.py @@ -24,6 +24,7 @@ class LLMOpenAI(LLM): num_previous_messages (int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0). Necessary a connection to :class:`~besser.bot.db.monitoring_db.MonitoringDB`. 
+ global_context (str): the global context to be provided to the LLM for each request Attributes: _nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot the LLM belongs to @@ -32,10 +33,13 @@ class LLMOpenAI(LLM): num_previous_messages (int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0). Necessary a connection to :class:`~besser.bot.db.monitoring_db.MonitoringDB`. + _global_context (str): the global context to be provided to the LLM for each request + _user_context (dict): user specific context to be provided to the LLM for each request """ - def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1): - super().__init__(bot.nlp_engine, name, parameters) + def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1, + global_context: str = None): + super().__init__(bot.nlp_engine, name, parameters, global_context=global_context) self.client: OpenAI = None self.num_previous_messages: int = num_previous_messages @@ -58,19 +62,25 @@ def set_num_previous_messages(self, num_previous_messages: int) -> None: def initialize(self) -> None: self.client = OpenAI(api_key=self._nlp_engine.get_property(nlp.OPENAI_API_KEY)) - def predict(self, message: str, parameters: dict = None) -> str: + def predict(self, message: str, parameters: dict = None, session: 'Session' = None, system_message: str = None) -> str: + messages = [] + if self._global_context: + messages.append({"role": "system", "content": self._global_context}) + if session and session.id in self._user_context: + messages.append({"role": "system", "content": self._user_context[session.id]}) + if system_message: + messages.append({"role": "system", "content": system_message}) + messages.append({"role": "user", "content": message}) if not parameters: parameters = self.parameters response = self.client.chat.completions.create( model=self.name, - messages=[ - {"role": "user", "content": message} - ], + messages=messages, **parameters, ) return response.choices[0].message.content - def chat(self, session: 'Session', parameters: dict = None) -> str: + def chat(self, session: 'Session', parameters: dict = None, system_message: str = None) -> str: if not parameters: parameters = self.parameters if self.num_previous_messages <= 0: @@ -83,9 +93,16 @@ def chat(self, session: 'Session', parameters: dict = None) -> str: ] if not messages: messages.append({'role': 'user', 'content': session.message}) + context_messages = [] + if self._global_context: + context_messages.append({"role": "system", "content": self._global_context}) + if session and session.id in self._user_context: + context_messages.append({"role": "system", "content": self._user_context[session.id]}) + if system_message: + context_messages.append({"role": "system", "content": system_message}) response = self.client.chat.completions.create( model=self.name, - messages=messages, + messages=context_messages + messages, **parameters, ) return response.choices[0].message.content diff --git a/besser/bot/nlp/llm/llm_replicate_api.py b/besser/bot/nlp/llm/llm_replicate_api.py index 0cfbe28..8e8c92d 100644 --- a/besser/bot/nlp/llm/llm_replicate_api.py +++ b/besser/bot/nlp/llm/llm_replicate_api.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from besser.bot.core.bot import Bot + from besser.bot.core.session import Session from besser.bot.nlp.intent_classifier.llm_intent_classifier import LLMIntentClassifier @@ -22,6 +23,7 @@ class LLMReplicate(LLM): 
parameters (dict): the LLM parameters num_previous_messages (int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0) + global_context (str): the global context to be provided to the LLM for each request Attributes: _nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot the LLM belongs to @@ -29,10 +31,13 @@ class LLMReplicate(LLM): parameters (dict): the LLM parameters num_previous_messages (int): for the chat functionality, the number of previous messages of the conversation to add to the prompt context (must be > 0) + _global_context (str): the global context to be provided to the LLM for each request + _user_context (dict): user specific context to be provided to the LLM for each request """ - def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1): - super().__init__(bot.nlp_engine, name, parameters) + def __init__(self, bot: 'Bot', name: str, parameters: dict, num_previous_messages: int = 1, + global_context: str = None): + super().__init__(bot.nlp_engine, name, parameters, global_context=global_context) self.num_previous_messages: int = num_previous_messages def set_model(self, name: str) -> None: @@ -55,11 +60,20 @@ def initialize(self) -> None: if 'REPLICATE_API_TOKEN' not in os.environ: os.environ['REPLICATE_API_TOKEN'] = self._nlp_engine.get_property(nlp.REPLICATE_API_KEY) - def predict(self, message: str, parameters: dict = None) -> str: + def predict(self, message: str, parameters: dict = None, session: 'Session' = None, system_message: str = None) -> str: if not parameters: parameters = self.parameters.copy() else: parameters = parameters.copy() + context_messages = "" + if self._global_context: + context_messages = f"{self._global_context}\n" + if session and session.id in self._user_context: + context_messages = context_messages + f"{self._user_context[session.id]}\n" + if system_message: + context_messages = context_messages + f"{system_message}\n" + if context_messages != "": + message = context_messages + message parameters['prompt'] = message answer = replicate.run( self.name, diff --git a/besser/bot/platforms/payload.py b/besser/bot/platforms/payload.py index 4afa894..bc858ad 100644 --- a/besser/bot/platforms/payload.py +++ b/besser/bot/platforms/payload.py @@ -23,6 +23,14 @@ class PayloadAction(Enum): BOT_REPLY_STR = 'bot_reply_str' """PayloadAction: Indicates that the payload's purpose is to send a bot reply containing a :class:`str` object.""" + BOT_REPLY_MARKDOWN = 'bot_reply_markdown' + """PayloadAction: Indicates that the payload's purpose is to send a bot reply containing a :class:`str` object + in Markdown format.""" + + BOT_REPLY_HTML = 'bot_reply_html' + """PayloadAction: Indicates that the payload's purpose is to send a bot reply containing a :class:`str` object + in HTML format.""" + BOT_REPLY_FILE = 'bot_reply_file' """PayloadAction: Indicates that the payload's purpose is to send a bot reply containing a :class:`file.File` object.""" @@ -52,6 +60,10 @@ class PayloadAction(Enum): """ BOT_REPLY_RAG = 'bot_reply_rag' + """PayloadAction: Indicates that the payload's purpose is to send a bot reply containing a RAG (Retrieval Augmented + Generation) answer, which contains an LLM-generated answer and a set of documents the LLM used as context + (see :class:`besser.bot.nlp.rag.rag.RAGMessage`). 
+ """ class Payload: diff --git a/besser/bot/platforms/websocket/chat_widget/css/style.css b/besser/bot/platforms/websocket/chat_widget/css/style.css new file mode 100644 index 0000000..404127b --- /dev/null +++ b/besser/bot/platforms/websocket/chat_widget/css/style.css @@ -0,0 +1,199 @@ +/* Chat widget styling */ +#chat-window { + position: fixed; + bottom: 90px; + right: 20px; + width: 400px; + height: 600px; + max-height: 90vh; + max-width: 90vw; + border: 1px solid #ccc; + border-radius: 8px; + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); + overflow: hidden; + display: flex; + flex-direction: column; + font-family: Arial, sans-serif; + opacity: 0; + transform: translateY(50px); + transition: opacity 0.3s ease, transform 0.3s ease; + visibility: hidden; +} + +#chat-window.visible { + opacity: 1; + transform: translateY(0); + visibility: visible; +} + +#chat-header { + background-color: #0078D7; + color: #fff; + padding: 10px; + text-align: center; + cursor: pointer; +} + +#chat-messages { + flex: 1; + padding: 10px; + overflow-y: auto; + background-color: #f9f9f9; +} + +.user-message { + padding: 8px; + margin-top: 8px; + border-radius: 4px; + width: fit-content; + margin-left: auto; + background-color: #e1ffc7; +} + +.bot-message { + padding: 8px; + margin-top: 8px; + border-radius: 4px; + width: fit-content; + margin-right: auto; + background-color: #eee; +} + +.markdown-message { +} + +.html-message { +} + +.image-message { + max-width: 90%; + height: auto; + display: block; +} + +.dataframe-message { + border-collapse: collapse; +} + +.dataframe-message th, .dataframe-message td { + border: 1px solid #c4c4c4; + padding: 8px; + text-align: left; +} + +.dataframe-message th { + background-color: #e1e1e1; + font-weight: bold; +} + +.options-message { + background: none; + width: 90%; +} + +.location-message { + width: 95%; + height: 300px; +} + +.plotly-message { + background: none; + width: 95%; + height: 300px; +} + +.button { + background-color: #0078D7; + color: #fff; + font-weight: bold; + text-decoration: none; + padding: 8px; + margin-right: 10px; + margin-top: 20px; + border-radius: 16px; + cursor: pointer; + overflow-wrap: break-word; + line-height: 2.5; +} + +/* Full-screen modal style */ +.plotly-fullscreen-modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background-color: rgba(0, 0, 0, 0.8); + z-index: 9999; + justify-content: center; + align-items: center; +} + +/* Full-screen chart container */ +.plotly-fullscreen-chart { + width: 90%; + height: 90%; + background-color: #fff; +} + +#chat-input { + display: flex; + border-top: 1px solid #eeeeee; +} + +#chat-input input { + flex: 1; + padding: 10px; + border: none; + outline: none; +} + +#chat-input button { + background-color: #0078D7; + color: #fff; + padding: 10px; + border: none; + cursor: pointer; +} + +#chat-input button:hover { + background-color: #005fa3; +} + +/* Circle button styling */ +#circle-button { + position: fixed; + bottom: 20px; + right: 20px; + width: 50px; + height: 50px; + border-radius: 50%; + transition: transform 0.5s ease; +} + +#circle-button img { + width: 50px;/* Adjust size as needed */ + height: 50px; + top: 20px; + left: 20px; + box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3); + cursor: pointer; + border-radius: 50%; +} + +/* Rotated class that triggers the spin */ +.spin { + transform: rotate(360deg); +} + + /* Typing indicator styling */ +#typing-indicator { + display: none; + padding: 8px; + width: fit-content; +} + +#typing-indicator img { + width: 
40px; /* Adjust the size as needed */ +} \ No newline at end of file diff --git a/besser/bot/platforms/websocket/chat_widget/data/args.json b/besser/bot/platforms/websocket/chat_widget/data/args.json new file mode 100644 index 0000000..4147aa2 --- /dev/null +++ b/besser/bot/platforms/websocket/chat_widget/data/args.json @@ -0,0 +1,9 @@ +{ + "userName": "JohnDoe", + "chatbotName": "AmazingBot", + "themeColor": "#34a4bd", + "wsAddress": "ws://localhost:8765", + "messageInputPlaceHolder": "Write something...", + "icon": "img/bot_logo.jpeg", + "typingAnimation": "img/typing_dots.gif" +} diff --git a/besser/bot/platforms/websocket/chat_widget/img/bot_logo.jpeg b/besser/bot/platforms/websocket/chat_widget/img/bot_logo.jpeg new file mode 100644 index 0000000..7caed0d Binary files /dev/null and b/besser/bot/platforms/websocket/chat_widget/img/bot_logo.jpeg differ diff --git a/besser/bot/platforms/websocket/chat_widget/img/typing_dots.gif b/besser/bot/platforms/websocket/chat_widget/img/typing_dots.gif new file mode 100644 index 0000000..c52060c Binary files /dev/null and b/besser/bot/platforms/websocket/chat_widget/img/typing_dots.gif differ diff --git a/besser/bot/platforms/websocket/chat_widget/index.html b/besser/bot/platforms/websocket/chat_widget/index.html new file mode 100644 index 0000000..0e7b444 --- /dev/null +++ b/besser/bot/platforms/websocket/chat_widget/index.html @@ -0,0 +1,30 @@ + + + + + + Chatbot Widget + + + + + + + + + +
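Note: the chat widget above is plain static HTML/CSS/JS, so trying it locally only requires serving the chat_widget directory over HTTP while a bot with the WebSocket platform is running at the address configured in data/args.json (ws://localhost:8765 by default). A minimal sketch using only the Python standard library; the port is arbitrary and not part of this change:

import http.server
import socketserver

PORT = 8000  # any free port; then open http://localhost:8000 in a browser
DIRECTORY = 'besser/bot/platforms/websocket/chat_widget'

class WidgetHandler(http.server.SimpleHTTPRequestHandler):
    # Serve files relative to the widget directory instead of the current working directory
    def __init__(self, *args, **kwargs):
        super().__init__(*args, directory=DIRECTORY, **kwargs)

with socketserver.TCPServer(('', PORT), WidgetHandler) as httpd:
    httpd.serve_forever()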
+ + + diff --git a/besser/bot/platforms/websocket/chat_widget/js/script.js b/besser/bot/platforms/websocket/chat_widget/js/script.js new file mode 100644 index 0000000..aeeeef4 --- /dev/null +++ b/besser/bot/platforms/websocket/chat_widget/js/script.js @@ -0,0 +1,411 @@ +let ws; +let config; + +function renderChatWidget(args) { + config = args + // Read values from the args dictionary + const userName = config.userName || "Guest"; + const chatbotName = config.chatbotName || "Chatbot"; + const themeColor = config.themeColor || "#2ecc71"; + const wsAddress = config.wsAddress || "ws://localhost:8765"; + const messageInputPlaceHolder = config.messageInputPlaceHolder || "Type a message..."; + const icon = config.icon || "https://www.drupal.org/files/project-images/xatkit.png"; + const typingAnimation = config.typingAnimation || "https://c.tenor.com/EZc7Xubv14AAAAAC/tenor.gif"; + + // Define the HTML structure + const container = document.getElementById('chat-widget'); + const chatWidgetHTML = ` + +
+
${chatbotName}
+
+ +
+ + +
+
+ + +
+ Chat Icon +
+ `; + // Insert the HTML structure into the container + container.innerHTML = chatWidgetHTML; + + ws = new WebSocket(wsAddress); + + const messageInput = document.getElementById('message-input'); + const circleButton = document.getElementById('circle-button'); + + ws.onopen = () => console.log('Connected to WebSocket server'); + ws.onclose = () => console.log('Disconnected from WebSocket server'); + ws.onerror = (error) => console.error('WebSocket error:', error); + + ws.onmessage = (event) => { + try { + const payload = JSON.parse(event.data); + displayMessage(payload, 'bot-message'); + } catch (error) { + console.error('Error parsing message:', error); + } + }; + + // Add a click event to toggle the spin class + circleButton.addEventListener('click', () => { + circleButton.classList.toggle('spin'); + }); + + messageInput.addEventListener('keydown', (event) => { + if (event.key === 'Enter') { + event.preventDefault(); + sendMessage(); + } + }); +} + +function sendMessage() { + const messageInput = document.getElementById('message-input'); + const message = messageInput.value; + if (message) { + const payload = { + action: 'user_message', + message: message, + }; + displayMessage(payload, 'user-message'); + ws.send(JSON.stringify(payload)); + messageInput.value = ''; + //showTypingIndicator(); + } +} + +// Function to generate a html message +function getMessageHtml(message) { + const messageElement = document.createElement('div'); + messageElement.innerHTML = message; + messageElement.classList.add('html-message'); + return messageElement +} + +// Function to generate a markdown message +function getMessageMarkdown(message) { + const messageElement = document.createElement('div'); + messageElement.innerHTML = marked.parse(message); + messageElement.classList.add('markdown-message'); + return messageElement +} + +// Function to generate a string message +function getMessageStr(message) { + const messageElement = document.createElement('p'); + messageElement.textContent = message; + messageElement.classList.add('str-message'); + return messageElement; +} + +// Function to generate an image message +function getMessageImage(message) { + // Create an img element + const imageElement = document.createElement('img'); + // Set the src attribute to the base64 string + imageElement.src = `data:image/jpeg;base64,${message}`; + imageElement.classList.add('image-message'); + return imageElement; +} + +// Function to generate a dataframe (table) message +function getMessageDataframe(message) { + // Parse the JSON string to an object + const data = JSON.parse(message); + + // Create a table element + const table = document.createElement('table'); + table.classList.add('dataframe-message'); + + // Create the header row using the DataFrame column names + const headerRow = document.createElement('tr'); + Object.keys(data).forEach(column => { + const th = document.createElement('th'); + th.textContent = column; + headerRow.appendChild(th); + }); + table.appendChild(headerRow); + + // Determine the number of rows by checking the keys of the first column's data + const numRows = Object.keys(data[Object.keys(data)[0]]).length; + + // Populate table rows with DataFrame values + for (let i = 0; i < numRows; i++) { + const row = document.createElement('tr'); + Object.keys(data).forEach(column => { + const cell = document.createElement('td'); + cell.textContent = data[column][i.toString()]; // Access the value at the current index (as a string) + row.appendChild(cell); + }); + table.appendChild(row); + } + + 
return table; +} + +// Function to generate a file (download button) message +function getMessageFile(message) { + const { name, type, base64 } = message; // Expected JSON structure in the message + + // Create a downloadable link element + const downloadLink = document.createElement('a'); + downloadLink.href = `data:${type};base64,${base64}`; + downloadLink.download = name; + downloadLink.textContent = `Download ${name}`; + downloadLink.classList.add('button'); + downloadLink.style.backgroundColor = config.themeColor; + return downloadLink; +} + +// Function to generate an options (buttons) message +function getMessageOptions(message) { + // Create a container to hold the buttons + const data = JSON.parse(message); + const optionsContainer = document.createElement('div'); + optionsContainer.classList.add('options-message'); + + // Iterate over each key-value pair in the options dictionary + for (const [key, value] of Object.entries(data)) { + // Create a button for each option + const button = document.createElement('a'); + button.textContent = value; + button.style.backgroundColor = config.themeColor; + button.classList.add('button'); + + // Set up click event to handle button selection + button.addEventListener('click', () => { + const messageInput = document.getElementById('message-input'); + messageInput.value = button.textContent; + sendMessage() + }); + + optionsContainer.appendChild(button); + } + return optionsContainer; +} + +// Function to generate a RAG message +function getMessageRAG(message) { + const { answer, docs, llm_name, question } = message; + + // Create the main message element + const messageElement = document.createElement('p'); + messageElement.textContent = `🔮 ${answer}`; + messageElement.classList.add('str-message'); + + // Create the "Details" clickable text + const detailsLink = document.createElement('span'); + detailsLink.textContent = " [Details]"; + detailsLink.style.color = "blue"; + detailsLink.style.cursor = "pointer"; + + // Create the expandable details section + const detailsSection = document.createElement('div'); + detailsSection.style.display = 'none'; // Hidden by default + + const introText = document.createElement('p'); + introText.innerHTML = `This answer has been generated by an LLM: ${llm_name}
+ It received the following documents as input to come up with a relevant answer:`; + detailsSection.appendChild(introText); + + docs.forEach((doc, i) => { + const docLabel = document.createElement('strong'); + docLabel.textContent = `Document ${i + 1}/${docs.length}`; + detailsSection.appendChild(docLabel); + + const docList = document.createElement('ul'); + + const sourceItem = document.createElement('li'); + sourceItem.textContent = `Source: ${doc.metadata.source}`; + docList.appendChild(sourceItem); + + const pageItem = document.createElement('li'); + pageItem.textContent = `Page: ${doc.metadata.page}`; + docList.appendChild(pageItem); + + const contentItem = document.createElement('li'); + contentItem.textContent = doc.content; + contentItem.classList.add('str-message'); + docList.appendChild(contentItem); + + detailsSection.appendChild(docList); + }); + + // Toggle visibility when "Details" is clicked + detailsLink.addEventListener('click', () => { + if (detailsSection.style.display === 'none') { + detailsSection.style.display = 'block'; + detailsLink.textContent = " [Hide Details]"; + } else { + detailsSection.style.display = 'none'; + detailsLink.textContent = " [Details]"; + } + }); + + messageElement.appendChild(detailsLink); + messageElement.appendChild(detailsSection); + + return messageElement; +} + +// Function to generate a location (map) message +function getMessageLocation(message) { + // Create a container for the map + const mapContainer = document.createElement('div'); + mapContainer.classList.add('location-message'); + + // Initialize the map after the container is added to the DOM + setTimeout(() => { + const map = L.map(mapContainer).setView([message.latitude, message.longitude], 13); + + // Set up the map tile layer (using OpenStreetMap tiles) + L.tileLayer('https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png', { + maxZoom: 19, + attribution: '© OpenStreetMap contributors' + }).addTo(map); + + // Add a marker at the specified location + L.marker([message.latitude, message.longitude]).addTo(map) + .bindPopup(`${message.latitude}, ${message.longitude}`); + }, 0); + + return mapContainer; +} + +// Function to generate a Plotly (chart) message +function getMessagePlotly(message) { + // TODO: Colors are always black + const chartWidth = 350; + const chartHeight = 275; + // Create the main chart container + const chartContainer = document.createElement('div'); + chartContainer.classList.add('plotly-message'); + + // Parse JSON data for the chart + const chartJSON = JSON.parse(message); + chartJSON.layout = chartJSON.layout || {}; + chartJSON.layout.autosize = true; + chartJSON.layout.width = chartWidth; + chartJSON.layout.height = chartHeight; + chartJSON.layout.responsive = true; + // Render the chart in the main container + Plotly.newPlot(chartContainer, chartJSON.data, chartJSON.layout); + + // Create a button to trigger full-screen view + const fullScreenButton = document.createElement('a'); + fullScreenButton.classList.add('button'); + fullScreenButton.textContent = 'View Full Screen'; + fullScreenButton.style.backgroundColor = config.themeColor; + chartContainer.appendChild(fullScreenButton); + + // Full-screen modal setup + const fullscreenModal = document.createElement('div'); + fullscreenModal.classList.add('plotly-fullscreen-modal'); + document.body.appendChild(fullscreenModal); + + // Full-screen chart container + const fullscreenChartContainer = document.createElement('div'); + fullscreenChartContainer.classList.add('plotly-fullscreen-chart'); + 
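+ // The modal is appended to document.body so it overlays the whole page; the chart is re-rendered at ~90% of the viewport when the modal opens and purged again when it closes.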
fullscreenModal.appendChild(fullscreenChartContainer); + + // Open full-screen view when clicking the button + fullScreenButton.addEventListener('click', () => { + // Update layout size for full screen + chartJSON.layout.width = window.innerWidth * 0.9; // 90% of window width + chartJSON.layout.height = window.innerHeight * 0.9; // 90% of window height + + // Render the chart in full-screen container + Plotly.newPlot(fullscreenChartContainer, chartJSON.data, chartJSON.layout); + fullscreenModal.style.display = 'flex'; + }); + + // Close full-screen view when clicking outside the chart + fullscreenModal.addEventListener('click', (e) => { + if (e.target === fullscreenModal) { + chartJSON.layout.width = chartWidth; + chartJSON.layout.height = chartHeight; + Plotly.purge(fullscreenChartContainer); // Clear the full-screen chart + fullscreenModal.style.display = 'none'; + } + }); + + return chartContainer; +} + +function displayMessage(payload, className) { + let messageElement; + const chatMessages = document.getElementById('chat-messages'); + // hideTypingIndicator(); + if (['bot_reply_str', 'user_message'].includes(payload.action) && payload.message) { + messageElement = getMessageStr(payload.message); + } + else if (payload.action === 'bot_reply_options' && payload.message) { + messageElement = getMessageOptions(payload.message); + } + else if (payload.action === 'bot_reply_markdown' && payload.message) { + messageElement = getMessageMarkdown(payload.message); + } + else if (payload.action === 'bot_reply_html' && payload.message) { + messageElement = getMessageHtml(payload.message); + } + else if (payload.action === 'bot_reply_image' && payload.message) { + messageElement = getMessageImage(payload.message); + } + else if (payload.action === 'bot_reply_file' && payload.message) { + messageElement = getMessageFile(payload.message); + } + else if (payload.action === 'bot_reply_dataframe' && payload.message) { + messageElement = getMessageDataframe(payload.message); + } + else if (payload.action === 'bot_reply_rag' && payload.message) { + messageElement = getMessageRAG(payload.message); + } + else if (payload.action === 'bot_reply_location' && payload.message) { + messageElement = getMessageLocation(payload.message); + } + else if (payload.action === 'bot_reply_plotly' && payload.message) { + messageElement = getMessagePlotly(payload.message); + } + else { + console.warn('Received unknown message format:', payload); + return; // avoid dereferencing an undefined messageElement below + } + messageElement.classList.add(className); + chatMessages.appendChild(messageElement); + chatMessages.scrollTop = chatMessages.scrollHeight; // Scroll to latest message +} + +function toggleChatWindow() { + const chatWindow = document.getElementById('chat-window'); + if (chatWindow.classList.contains('visible')) { + chatWindow.classList.remove('visible'); + setTimeout(() => { + chatWindow.style.visibility = 'hidden'; + }, 300); // Delay for the fade-out transition + } else { + chatWindow.style.visibility = 'visible'; + chatWindow.classList.add('visible'); + } +} + +// Show typing indicator +function showTypingIndicator() { + const typingIndicator = document.getElementById('typing-indicator'); + typingIndicator.style.display = 'block'; +} + +// Hide typing indicator +function hideTypingIndicator() { + const typingIndicator = document.getElementById('typing-indicator'); + typingIndicator.style.display = 'none'; +} + diff --git a/besser/bot/platforms/websocket/streamlit_ui.py b/besser/bot/platforms/websocket/streamlit_ui.py deleted file mode 100644 index 230309d..0000000 ---
a/besser/bot/platforms/websocket/streamlit_ui.py +++ /dev/null @@ -1,294 +0,0 @@ -import base64 -import json -import queue -import sys -import threading -import time -from datetime import datetime - -import pandas as pd -import plotly -import streamlit as st -import websocket -from audio_recorder_streamlit import audio_recorder -from streamlit.runtime import Runtime -from streamlit.runtime.app_session import AppSession -from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx -from streamlit.web import cli as stcli - -from besser.bot.core.file import File -from besser.bot.core.message import Message, MessageType -from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder - -# Time interval to check if a streamlit session is still active, in seconds -SESSION_MONITORING_INTERVAL = 1 - - -def get_streamlit_session() -> AppSession or None: - session_id = get_script_run_ctx().session_id - runtime: Runtime = Runtime.instance() - return next(( - s.session - for s in runtime._session_mgr.list_sessions() - if s.session.id == session_id - ), None) - - -def session_monitoring(interval: int): - runtime: Runtime = Runtime.instance() - session = get_streamlit_session() - while True: - time.sleep(interval) - if not runtime.is_active_session(session.id): - runtime.close_session(session.id) - session.session_state['websocket'].close() - break - - -def main(): - try: - # We get the websocket host and port from the script arguments - bot_name = sys.argv[1] - except Exception as e: - # If they are not provided, we use default values - bot_name = 'Chatbot Demo' - st.header(bot_name) - st.markdown("[Github](https://github.com/BESSER-PEARL/BESSER-Bot-Framework)") - # User input component. Must be declared before history writing - user_input = st.chat_input("What is up?") - - def on_message(ws, payload_str): - # https://github.com/streamlit/streamlit/issues/2838 - streamlit_session = get_streamlit_session() - payload: Payload = Payload.decode(payload_str) - if payload.action == PayloadAction.BOT_REPLY_STR.value: - content = payload.message - t = MessageType.STR - elif payload.action == PayloadAction.BOT_REPLY_FILE.value: - content = payload.message - t = MessageType.FILE - elif payload.action == PayloadAction.BOT_REPLY_DF.value: - content = pd.read_json(payload.message) - t = MessageType.DATAFRAME - elif payload.action == PayloadAction.BOT_REPLY_PLOTLY.value: - content = plotly.io.from_json(payload.message) - t = MessageType.PLOTLY - elif payload.action == PayloadAction.BOT_REPLY_LOCATION.value: - content = { - 'latitude': [payload.message['latitude']], - 'longitude': [payload.message['longitude']] - } - t = MessageType.LOCATION - elif payload.action == PayloadAction.BOT_REPLY_OPTIONS.value: - t = MessageType.OPTIONS - d = json.loads(payload.message) - content = [] - for button in d.values(): - content.append(button) - elif payload.action == PayloadAction.BOT_REPLY_RAG.value: - t = MessageType.RAG_ANSWER - content = payload.message - message = Message(t=t, content=content, is_user=False, timestamp=datetime.now()) - streamlit_session._session_state['queue'].put(message) - streamlit_session._handle_rerun_script_request() - - def on_error(ws, error): - pass - - def on_open(ws): - pass - - def on_close(ws, close_status_code, close_msg): - pass - - def on_ping(ws, data): - pass - - def on_pong(ws, data): - pass - - user_type = { - 0: 'assistant', - 1: 'user' - } - - if 'history' not in st.session_state: - st.session_state['history'] = [] - - if 'queue' not in 
st.session_state: - st.session_state['queue'] = queue.Queue() - - if 'websocket' not in st.session_state: - try: - # We get the websocket host and port from the script arguments - host = sys.argv[2] - port = sys.argv[3] - except Exception as e: - # If they are not provided, we use default values - host = 'localhost' - port = '8765' - ws = websocket.WebSocketApp(f"ws://{host}:{port}/", - on_open=on_open, - on_message=on_message, - on_error=on_error, - on_close=on_close, - on_ping=on_ping, - on_pong=on_pong) - websocket_thread = threading.Thread(target=ws.run_forever) - add_script_run_ctx(websocket_thread) - websocket_thread.start() - st.session_state['websocket'] = ws - - if 'session_monitoring' not in st.session_state: - session_monitoring_thread = threading.Thread(target=session_monitoring, - kwargs={'interval': SESSION_MONITORING_INTERVAL}) - add_script_run_ctx(session_monitoring_thread) - session_monitoring_thread.start() - st.session_state['session_monitoring'] = session_monitoring_thread - - ws = st.session_state['websocket'] - - with st.sidebar: - - if reset_button := st.button(label="Reset bot"): - st.session_state['history'] = [] - st.session_state['queue'] = queue.Queue() - payload = Payload(action=PayloadAction.RESET) - ws.send(json.dumps(payload, cls=PayloadEncoder)) - - if voice_bytes := audio_recorder(text=None, pause_threshold=2): - if 'last_voice_message' not in st.session_state or st.session_state['last_voice_message'] != voice_bytes: - st.session_state['last_voice_message'] = voice_bytes - # Encode the audio bytes to a base64 string - voice_message = Message(t=MessageType.AUDIO, content=voice_bytes, is_user=True, timestamp=datetime.now()) - st.session_state.history.append(voice_message) - voice_base64 = base64.b64encode(voice_bytes).decode('utf-8') - payload = Payload(action=PayloadAction.USER_VOICE, message=voice_base64) - try: - ws.send(json.dumps(payload, cls=PayloadEncoder)) - except Exception as e: - st.error('Your message could not be sent. The connection is already closed') - if uploaded_file := st.file_uploader("Choose a file", accept_multiple_files=False): - if 'last_file' not in st.session_state or st.session_state['last_file'] != uploaded_file: - st.session_state['last_file'] = uploaded_file - bytes_data = uploaded_file.read() - file_object = File(file_base64=base64.b64encode(bytes_data).decode('utf-8'), file_name=uploaded_file.name, file_type=uploaded_file.type) - payload = Payload(action=PayloadAction.USER_FILE, message=file_object.get_json_string()) - file_message = Message(t=MessageType.FILE, content=file_object.to_dict(), is_user=True, timestamp=datetime.now()) - st.session_state.history.append(file_message) - try: - ws.send(json.dumps(payload, cls=PayloadEncoder)) - except Exception as e: - st.error('Your message could not be sent. 
The connection is already closed') - for message in st.session_state['history']: - with st.chat_message(user_type[message.is_user]): - if message.type == MessageType.AUDIO: - st.audio(message.content, format="audio/wav") - elif message.type == MessageType.FILE: - file: File = File.from_dict(message.content) - file_name = file.name - file_type = file.type - file_data = base64.b64decode(file.base64.encode('utf-8')) - st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type, - key=file_name + str(time.time())) - elif message.type == MessageType.LOCATION: - st.map(message.content) - elif message.type == MessageType.RAG_ANSWER: - # TODO: Avoid duplicate in history and queue - st.write(f'🔮 {message.content["answer"]}') - with st.expander('Details'): - st.write(f'This answer has been generated by an LLM: **{message.content["llm_name"]}**') - st.write(f'It received the following documents as input to come up with a relevant answer:') - if 'docs' in message.content: - for i, doc in enumerate(message.content['docs']): - st.write(f'**Document {i + 1}/{len(message.content["docs"])}**') - st.write(f'- **Source:** {doc["metadata"]["source"]}') - st.write(f'- **Page:** {doc["metadata"]["page"]}') - st.write(f'- **Content:** {doc["content"]}') - else: - st.write(message.content) - - first_message = True - while not st.session_state['queue'].empty(): - with st.chat_message("assistant"): - message = st.session_state['queue'].get() - if hasattr(message, '__len__'): - t = len(message.content) / 1000 * 3 - else: - t = 2 - if t > 3: - t = 3 - elif t < 1 and first_message: - t = 1 - first_message = False - if message.type == MessageType.OPTIONS: - st.session_state['buttons'] = message.content - elif message.type == MessageType.FILE: - st.session_state['history'].append(message) - with st.spinner(''): - time.sleep(t) - file: File = File.from_dict(message.content) - file_name = file.name - file_type = file.type - file_data = base64.b64decode(file.base64.encode('utf-8')) - st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type, - key=file_name + str(time.time())) - elif message.type == MessageType.LOCATION: - st.session_state['history'].append(message) - st.map(message.content) - elif message.type == MessageType.RAG_ANSWER: - st.session_state['history'].append(message) - st.write(f'🔮 {message.content["answer"]}') - with st.expander('Details'): - st.write(f'This answer has been generated by an LLM: **{message.content["llm_name"]}**') - st.write(f'It received the following documents as input to come up with a relevant answer:') - if 'docs' in message.content: - for i, doc in enumerate(message.content['docs']): - st.write(f'**Document {i + 1}/{len(message.content["docs"])}**') - st.write(f'- **Source:** {doc["metadata"]["source"]}') - st.write(f'- **Page:** {doc["metadata"]["page"]}') - st.write(f'- **Content:** {doc["content"]}') - elif message.type == MessageType.STR: - st.session_state['history'].append(message) - with st.spinner(''): - time.sleep(t) - st.write(message.content) - - if 'buttons' in st.session_state: - buttons = st.session_state['buttons'] - cols = st.columns(1) - for i, option in enumerate(buttons): - if cols[0].button(option): - with st.chat_message("user"): - st.write(option) - message = Message(t=MessageType.STR, content=option, is_user=True, timestamp=datetime.now()) - st.session_state.history.append(message) - payload = Payload(action=PayloadAction.USER_MESSAGE, - message=option) - 
ws.send(json.dumps(payload, cls=PayloadEncoder)) - del st.session_state['buttons'] - break - - if user_input: - if 'buttons' in st.session_state: - del st.session_state['buttons'] - with st.chat_message("user"): - st.write(user_input) - message = Message(t=MessageType.STR, content=user_input, is_user=True, timestamp=datetime.now()) - st.session_state.history.append(message) - payload = Payload(action=PayloadAction.USER_MESSAGE, - message=user_input) - try: - ws.send(json.dumps(payload, cls=PayloadEncoder)) - except Exception as e: - st.error('Your message could not be sent. The connection is already closed') - - st.stop() - - -if __name__ == "__main__": - if st.runtime.exists(): - main() - else: - sys.argv = ["streamlit", "run", sys.argv[0]] - sys.exit(stcli.main()) diff --git a/besser/bot/platforms/websocket/streamlit_ui/__init__.py b/besser/bot/platforms/websocket/streamlit_ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/besser/bot/platforms/websocket/streamlit_ui/chat.py b/besser/bot/platforms/websocket/streamlit_ui/chat.py new file mode 100644 index 0000000..1cf6f64 --- /dev/null +++ b/besser/bot/platforms/websocket/streamlit_ui/chat.py @@ -0,0 +1,101 @@ +import base64 +import json +import time +from datetime import datetime + +import streamlit as st + +from besser.bot.core.file import File +from besser.bot.core.message import Message, MessageType +from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder +from besser.bot.platforms.websocket.streamlit_ui.vars import TYPING_TIME, BUTTONS, HISTORY, QUEUE, WEBSOCKET, ASSISTANT, \ + USER + +user_type = { + 0: ASSISTANT, + 1: USER +} + + +def stream_text(text: str): + def stream_callback(): + for word in text.split(" "): + yield word + " " + time.sleep(TYPING_TIME) + return stream_callback + + +def write_or_stream(content, stream: bool): + if stream: + st.write_stream(stream_text(content)) + else: + st.write(content) + + +def write_message(message: Message, key_count: int, stream: bool = False): + key = f'message_{key_count}' + with st.chat_message(user_type[message.is_user]): + if message.type == MessageType.AUDIO: + st.audio(message.content, format="audio/wav") + + elif message.type == MessageType.FILE: + file: File = File.from_dict(message.content) + file_name = file.name + file_type = file.type + file_data = base64.b64decode(file.base64.encode('utf-8')) + st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type, key=key) + + elif message.type == MessageType.IMAGE: + st.image(message.content) + + elif message.type == MessageType.OPTIONS: + def send_option(): + option = st.session_state[key] + message = Message(t=MessageType.STR, content=option, is_user=True, timestamp=datetime.now()) + st.session_state.history.append(message) + payload = Payload(action=PayloadAction.USER_MESSAGE, message=option) + ws = st.session_state[WEBSOCKET] + ws.send(json.dumps(payload, cls=PayloadEncoder)) + + st.pills(label='Choose an option', options=message.content, selection_mode='single', on_change=send_option, key=key) + + elif message.type == MessageType.LOCATION: + st.map(message.content) + + elif message.type == MessageType.HTML: + st.html(message.content) + + elif message.type == MessageType.DATAFRAME: + st.dataframe(message.content, key=key) + + elif message.type == MessageType.PLOTLY: + st.plotly_chart(message.content, key=key) + + elif message.type == MessageType.RAG_ANSWER: + # TODO: Add stream text + write_or_stream(f'🔮 {message.content["answer"]}', stream) + 
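+ # The answer itself streams word by word via write_or_stream; the source documents are listed with plain st.write calls inside the expander below.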
with st.expander('Details'): + write_or_stream(f'This answer has been generated by an LLM: **{message.content["llm_name"]}**', stream) + write_or_stream(f'It received the following documents as input to come up with a relevant answer:', stream) + if 'docs' in message.content: + for i, doc in enumerate(message.content['docs']): + st.write(f'**Document {i + 1}/{len(message.content["docs"])}**') + st.write(f'- **Source:** {doc["metadata"]["source"]}') + st.write(f'- **Page:** {doc["metadata"]["page"]}') + st.write(f'- **Content:** {doc["content"]}') + + elif message.type in [MessageType.STR, MessageType.MARKDOWN]: + write_or_stream(message.content, stream) + + +def load_chat(): + key_count = 0 + for message in st.session_state[HISTORY]: + write_message(message, key_count, stream=False) + key_count += 1 + + while not st.session_state[QUEUE].empty(): + message = st.session_state[QUEUE].get() + st.session_state[HISTORY].append(message) + write_message(message, key_count, stream=True) + key_count += 1 diff --git a/besser/bot/platforms/websocket/streamlit_ui/initialization.py b/besser/bot/platforms/websocket/streamlit_ui/initialization.py new file mode 100644 index 0000000..1df47b2 --- /dev/null +++ b/besser/bot/platforms/websocket/streamlit_ui/initialization.py @@ -0,0 +1,57 @@ +import queue +import sys +import threading + +import streamlit as st +import websocket +from streamlit.runtime.scriptrunner_utils.script_run_context import add_script_run_ctx + +from besser.bot.platforms.websocket.streamlit_ui.session_management import session_monitoring +from besser.bot.platforms.websocket.streamlit_ui.vars import SESSION_MONITORING_INTERVAL, SUBMIT_TEXT, HISTORY, QUEUE, \ + WEBSOCKET, SESSION_MONITORING, SUBMIT_AUDIO, SUBMIT_FILE +from besser.bot.platforms.websocket.streamlit_ui.websocket_callbacks import on_open, on_error, on_message, on_close, on_ping, on_pong + + +def initialize(): + if SUBMIT_TEXT not in st.session_state: + st.session_state[SUBMIT_TEXT] = False + + if SUBMIT_AUDIO not in st.session_state: + st.session_state[SUBMIT_AUDIO] = False + + if SUBMIT_FILE not in st.session_state: + st.session_state[SUBMIT_FILE] = False + + if HISTORY not in st.session_state: + st.session_state[HISTORY] = [] + + if QUEUE not in st.session_state: + st.session_state[QUEUE] = queue.Queue() + + if WEBSOCKET not in st.session_state: + try: + # We get the websocket host and port from the script arguments + host = sys.argv[2] + port = sys.argv[3] + except Exception as e: + # If they are not provided, we use default values + host = 'localhost' + port = '8765' + ws = websocket.WebSocketApp(f"ws://{host}:{port}/", + on_open=on_open, + on_message=on_message, + on_error=on_error, + on_close=on_close, + on_ping=on_ping, + on_pong=on_pong) + websocket_thread = threading.Thread(target=ws.run_forever) + add_script_run_ctx(websocket_thread) + websocket_thread.start() + st.session_state[WEBSOCKET] = ws + + if SESSION_MONITORING not in st.session_state: + session_monitoring_thread = threading.Thread(target=session_monitoring, + kwargs={'interval': SESSION_MONITORING_INTERVAL}) + add_script_run_ctx(session_monitoring_thread) + session_monitoring_thread.start() + st.session_state[SESSION_MONITORING] = session_monitoring_thread diff --git a/besser/bot/platforms/websocket/streamlit_ui/message_input.py b/besser/bot/platforms/websocket/streamlit_ui/message_input.py new file mode 100644 index 0000000..2405ee9 --- /dev/null +++ b/besser/bot/platforms/websocket/streamlit_ui/message_input.py @@ -0,0 +1,32 @@ +import json +from datetime 
+
+import streamlit as st
+
+from besser.bot.core.message import Message, MessageType
+from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder
+from besser.bot.platforms.websocket.streamlit_ui.vars import BUTTONS, SUBMIT_TEXT, WEBSOCKET, USER
+
+
+def message_input():
+    def submit_text():
+        # Necessary callback due to a bug introduced in Streamlit 1.27.0 (https://github.com/streamlit/streamlit/issues/7629)
+        # It was fixed for st.rerun, but it does not work with _handle_rerun_script_request
+        st.session_state[SUBMIT_TEXT] = True
+
+    user_input = st.chat_input("What is up?", on_submit=submit_text)
+    if st.session_state[SUBMIT_TEXT]:
+        st.session_state[SUBMIT_TEXT] = False
+        if BUTTONS in st.session_state:
+            del st.session_state[BUTTONS]
+        with st.chat_message(USER):
+            st.write(user_input)
+        message = Message(t=MessageType.STR, content=user_input, is_user=True, timestamp=datetime.now())
+        st.session_state.history.append(message)
+        payload = Payload(action=PayloadAction.USER_MESSAGE,
+                          message=user_input)
+        try:
+            ws = st.session_state[WEBSOCKET]
+            ws.send(json.dumps(payload, cls=PayloadEncoder))
+        except Exception as e:
+            st.error('Your message could not be sent. The connection is already closed')
diff --git a/besser/bot/platforms/websocket/streamlit_ui/session_management.py b/besser/bot/platforms/websocket/streamlit_ui/session_management.py
new file mode 100644
index 0000000..4d80d31
--- /dev/null
+++ b/besser/bot/platforms/websocket/streamlit_ui/session_management.py
@@ -0,0 +1,28 @@
+import time
+from typing import Optional
+
+from streamlit.runtime import Runtime
+from streamlit.runtime.app_session import AppSession
+from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
+
+from besser.bot.platforms.websocket.streamlit_ui.vars import WEBSOCKET
+
+
+def get_streamlit_session() -> Optional[AppSession]:
+    session_id = get_script_run_ctx().session_id
+    runtime: Runtime = Runtime.instance()
+    return next((
+        s.session
+        for s in runtime._session_mgr.list_sessions()
+        if s.session.id == session_id
+    ), None)
+
+
+def session_monitoring(interval: int):
+    runtime: Runtime = Runtime.instance()
+    session = get_streamlit_session()
+    while True:
+        time.sleep(interval)
+        if not runtime.is_active_session(session.id):
+            runtime.close_session(session.id)
+            session.session_state[WEBSOCKET].close()
+            break
diff --git a/besser/bot/platforms/websocket/streamlit_ui/sidebar.py b/besser/bot/platforms/websocket/streamlit_ui/sidebar.py
new file mode 100644
index 0000000..9aca390
--- /dev/null
+++ b/besser/bot/platforms/websocket/streamlit_ui/sidebar.py
@@ -0,0 +1,61 @@
+import base64
+import json
+import queue
+from datetime import datetime
+
+import streamlit as st
+
+from besser.bot.core.file import File
+from besser.bot.core.message import MessageType, Message
+from besser.bot.platforms.payload import PayloadEncoder, PayloadAction, Payload
+from besser.bot.platforms.websocket.streamlit_ui.vars import WEBSOCKET, HISTORY, QUEUE, SUBMIT_AUDIO, SUBMIT_FILE
+
+
+def sidebar():
+    ws = st.session_state[WEBSOCKET]
+
+    with st.sidebar:
+
+        if reset_button := st.button(label="Reset bot"):
+            st.session_state[HISTORY] = []
+            st.session_state[QUEUE] = queue.Queue()
+            payload = Payload(action=PayloadAction.RESET)
+            ws.send(json.dumps(payload, cls=PayloadEncoder))
+
+        def submit_audio():
+            # Necessary callback due to a bug introduced in Streamlit 1.27.0 (https://github.com/streamlit/streamlit/issues/7629)
+            # It was fixed for st.rerun, but it does not work with _handle_rerun_script_request
+            st.session_state[SUBMIT_AUDIO] = True
+
+        voice_bytes_io = st.audio_input(label='Say something', on_change=submit_audio)
+        if st.session_state[SUBMIT_AUDIO]:
+            st.session_state[SUBMIT_AUDIO] = False
+            voice_bytes = voice_bytes_io.read()
+            # Encode the audio bytes to a base64 string
+            voice_message = Message(t=MessageType.AUDIO, content=voice_bytes, is_user=True, timestamp=datetime.now())
+            st.session_state.history.append(voice_message)
+            voice_base64 = base64.b64encode(voice_bytes).decode('utf-8')
+            payload = Payload(action=PayloadAction.USER_VOICE, message=voice_base64)
+            try:
+                ws.send(json.dumps(payload, cls=PayloadEncoder))
+            except Exception as e:
+                st.error('Your message could not be sent. The connection is already closed')
+
+        def submit_file():
+            # Necessary callback due to a bug introduced in Streamlit 1.27.0 (https://github.com/streamlit/streamlit/issues/7629)
+            # It was fixed for st.rerun, but it does not work with _handle_rerun_script_request
+            st.session_state[SUBMIT_FILE] = True
+
+        uploaded_file = st.file_uploader("Choose a file", accept_multiple_files=False, on_change=submit_file)
+        if st.session_state[SUBMIT_FILE]:
+            st.session_state[SUBMIT_FILE] = False
+            bytes_data = uploaded_file.read()
+            file_object = File(file_base64=base64.b64encode(bytes_data).decode('utf-8'), file_name=uploaded_file.name,
+                               file_type=uploaded_file.type)
+            payload = Payload(action=PayloadAction.USER_FILE, message=file_object.get_json_string())
+            file_message = Message(t=MessageType.FILE, content=file_object.to_dict(), is_user=True,
+                                   timestamp=datetime.now())
+            st.session_state.history.append(file_message)
+            try:
+                ws.send(json.dumps(payload, cls=PayloadEncoder))
+            except Exception as e:
+                st.error('Your message could not be sent. The connection is already closed')
diff --git a/besser/bot/platforms/websocket/streamlit_ui/streamlit_ui.py b/besser/bot/platforms/websocket/streamlit_ui/streamlit_ui.py
new file mode 100644
index 0000000..2281753
--- /dev/null
+++ b/besser/bot/platforms/websocket/streamlit_ui/streamlit_ui.py
@@ -0,0 +1,36 @@
+import sys
+# sys.path.append("/Path/to/directory/bot-framework") # Replace with your directory path
+
+import streamlit as st
+from streamlit.runtime import Runtime
+from streamlit.web import cli as stcli
+
+from besser.bot.platforms.websocket.streamlit_ui.chat import load_chat
+from besser.bot.platforms.websocket.streamlit_ui.initialization import initialize
+from besser.bot.platforms.websocket.streamlit_ui.message_input import message_input
+from besser.bot.platforms.websocket.streamlit_ui.sidebar import sidebar
+
+
+def main():
+    try:
+        # We get the bot name from the script arguments
+        bot_name = sys.argv[1]
+    except Exception as e:
+        # If it is not provided, we use a default value
+        bot_name = 'Chatbot Demo'
+    st.header(bot_name)
+    st.markdown("[GitHub](https://github.com/BESSER-PEARL/BESSER-Bot-Framework)")
+
+    initialize()
+    sidebar()
+    load_chat()
+    message_input()
+    st.stop()
+
+
+if __name__ == "__main__":
+    if st.runtime.exists():
+        main()
+    else:
+        sys.argv = ["streamlit", "run", sys.argv[0]]
+        sys.exit(stcli.main())
diff --git a/besser/bot/platforms/websocket/streamlit_ui/vars.py b/besser/bot/platforms/websocket/streamlit_ui/vars.py
new file mode 100644
index 0000000..ec42da1
--- /dev/null
+++ b/besser/bot/platforms/websocket/streamlit_ui/vars.py
@@ -0,0 +1,18 @@
+# Streamlit session_state keys
+ASSISTANT = 'assistant'
+BUTTONS = 'buttons'
+HISTORY = 'history'
+QUEUE = 'queue'
+SESSION_MONITORING = 'session_monitoring'
+SUBMIT_FILE = 'submit_file'
+SUBMIT_TEXT = 'submit_text'
+SUBMIT_AUDIO = 'submit_audio'
+USER = 'user'
+WEBSOCKET = 'websocket'
+
+# Time interval to check if a streamlit session is still active, in seconds
+SESSION_MONITORING_INTERVAL = 1
+
+# New bot messages are printed with a typing effect. This is the time between words being printed, in seconds
+TYPING_TIME = 0.05
+
diff --git a/besser/bot/platforms/websocket/streamlit_ui/websocket_callbacks.py b/besser/bot/platforms/websocket/streamlit_ui/websocket_callbacks.py
new file mode 100644
index 0000000..40fea2c
--- /dev/null
+++ b/besser/bot/platforms/websocket/streamlit_ui/websocket_callbacks.py
@@ -0,0 +1,85 @@
+import base64
+import json
+from datetime import datetime
+from io import StringIO
+
+import cv2
+import numpy as np
+import pandas as pd
+import plotly
+
+from besser.bot.core.message import MessageType, Message
+from besser.bot.platforms.payload import PayloadAction, Payload
+from besser.bot.platforms.websocket.streamlit_ui.session_management import get_streamlit_session
+from besser.bot.platforms.websocket.streamlit_ui.vars import QUEUE
+
+
+def on_message(ws, payload_str):
+    # https://github.com/streamlit/streamlit/issues/2838
+    streamlit_session = get_streamlit_session()
+    payload: Payload = Payload.decode(payload_str)
+    content = None
+    if payload.action == PayloadAction.BOT_REPLY_STR.value:
+        content = payload.message
+        t = MessageType.STR
+    elif payload.action == PayloadAction.BOT_REPLY_MARKDOWN.value:
+        content = payload.message
+        t = MessageType.MARKDOWN
+    elif payload.action == PayloadAction.BOT_REPLY_HTML.value:
+        content = payload.message
+        t = MessageType.HTML
+    elif payload.action == PayloadAction.BOT_REPLY_FILE.value:
+        content = payload.message
+        t = MessageType.FILE
+    elif payload.action == PayloadAction.BOT_REPLY_IMAGE.value:
+        decoded_data = base64.b64decode(payload.message)  # Decode base64 back to bytes
+        np_data = np.frombuffer(decoded_data, np.uint8)  # Convert bytes to numpy array
+        img = cv2.imdecode(np_data, cv2.IMREAD_COLOR)  # Decode numpy array back to image
+        content = img
+        t = MessageType.IMAGE
+    elif payload.action == PayloadAction.BOT_REPLY_DF.value:
+        content = pd.read_json(StringIO(payload.message))
+        t = MessageType.DATAFRAME
+    elif payload.action == PayloadAction.BOT_REPLY_PLOTLY.value:
+        content = plotly.io.from_json(payload.message)
+        t = MessageType.PLOTLY
+    elif payload.action == PayloadAction.BOT_REPLY_LOCATION.value:
+        content = {
+            'latitude': [payload.message['latitude']],
+            'longitude': [payload.message['longitude']]
+        }
+        t = MessageType.LOCATION
+    elif payload.action == PayloadAction.BOT_REPLY_OPTIONS.value:
+        t = MessageType.OPTIONS
+        d = json.loads(payload.message)
+        content = []
+        for button in d.values():
+            content.append(button)
+    elif payload.action == PayloadAction.BOT_REPLY_RAG.value:
+        t = MessageType.RAG_ANSWER
+        content = payload.message
+    if content is not None:
+        message = Message(t=t, content=content, is_user=False, timestamp=datetime.now())
+        streamlit_session._session_state[QUEUE].put(message)
+
+    streamlit_session._handle_rerun_script_request()
+
+
+def on_error(ws, error):
+    pass
+
+
+def on_open(ws):
+    pass
+
+
+def on_close(ws, close_status_code, close_msg):
+    pass
+
+
+def on_ping(ws, data):
+    pass
+
+
+def on_pong(ws, data):
+    pass
\ No newline at end of file
diff --git a/besser/bot/platforms/websocket/websocket_platform.py b/besser/bot/platforms/websocket/websocket_platform.py
index e2d5456..01c8e3f 100644
--- a/besser/bot/platforms/websocket/websocket_platform.py
+++ b/besser/bot/platforms/websocket/websocket_platform.py
@@ -5,6 +5,8 @@
 import os
 from datetime import datetime
+import cv2
+import numpy as np
 import plotly
 import subprocess
 import threading
@@ -21,7 +23,7 @@ from besser.bot.platforms import websocket
 from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder
 from besser.bot.platforms.platform import Platform
-from besser.bot.platforms.websocket import streamlit_ui
+from besser.bot.platforms.websocket.streamlit_ui import streamlit_ui
 from besser.bot.core.file import File
 
 if TYPE_CHECKING:
@@ -36,7 +38,7 @@ class WebSocketPlatform(Platform):
     bidirectional communication between server and client (i.e. sending and receiving messages).
 
     Note:
-        We provide a UI (:doc:`streamlit_ui`) implementing a WebSocket client to communicate with the bot, though you
+        We provide different interfaces implementing a WebSocket client to communicate with the bot, though you
        can use or create your own UI as long as it has a WebSocket client that connects to the bot's WebSocket
        server.
 
     Args:
@@ -154,6 +156,34 @@ def reply(self, session: Session, message: str) -> None:
         payload = Payload(action=PayloadAction.BOT_REPLY_STR,
                           message=message)
         self._send(session.id, payload)
+
+    def reply_markdown(self, session: Session, message: str) -> None:
+        """Send a bot reply to a specific user, containing text in Markdown format.
+
+        Args:
+            session (Session): the user session
+            message (str): the message in Markdown format to send to the user
+        """
+        if session.platform is not self:
+            raise PlatformMismatchError(self, session)
+        session.save_message(Message(t=MessageType.MARKDOWN, content=message, is_user=False, timestamp=datetime.now()))
+        payload = Payload(action=PayloadAction.BOT_REPLY_MARKDOWN,
+                          message=message)
+        self._send(session.id, payload)
+
+    def reply_html(self, session: Session, message: str) -> None:
+        """Send a bot reply to a specific user, containing text in HTML format.
+
+        Args:
+            session (Session): the user session
+            message (str): the message in HTML format to send to the user
+        """
+        if session.platform is not self:
+            raise PlatformMismatchError(self, session)
+        session.save_message(Message(t=MessageType.HTML, content=message, is_user=False, timestamp=datetime.now()))
+        payload = Payload(action=PayloadAction.BOT_REPLY_HTML,
+                          message=message)
+        self._send(session.id, payload)
 
     def reply_file(self, session: Session, file: File) -> None:
         """Send a file reply to a specific user
@@ -169,6 +199,25 @@ def reply_file(self, session: Session, file: File) -> None:
                           message=file.to_dict())
         self._send(session.id, payload)
 
+    def reply_image(self, session: Session, img: np.ndarray) -> None:
+        """Send an image reply to a specific user.
+
+        Before being sent, the image is encoded as JPEG and then as a base64 string. The client must know this in
+        order to decode the image.
+
+        Args:
+            session (Session): the user session
+            img (np.ndarray): the image to send
+        """
+        if session.platform is not self:
+            raise PlatformMismatchError(self, session)
+        retval, buffer = cv2.imencode('.jpg', img)  # Encode as JPEG
+        base64_img = base64.b64encode(buffer).decode('utf-8')
+        session.save_message(Message(t=MessageType.IMAGE, content=base64_img, is_user=False, timestamp=datetime.now()))
+        payload = Payload(action=PayloadAction.BOT_REPLY_IMAGE,
+                          message=base64_img)
+        self._send(session.id, payload)
+
     def reply_dataframe(self, session: Session, df: DataFrame) -> None:
         """Send a DataFrame bot reply, i.e. a table, to a specific user.
diff --git a/docs/source/api/core.rst b/docs/source/api/core.rst
index 8dd1de5..5587157 100644
--- a/docs/source/api/core.rst
+++ b/docs/source/api/core.rst
@@ -17,3 +17,4 @@ core
    core/intent_parameter
    core/language_detection_processor
    core/processor
+   core/user_adaptation_processor
diff --git a/docs/source/api/core/user_adaptation_processor.rst b/docs/source/api/core/user_adaptation_processor.rst
new file mode 100644
index 0000000..40897d3
--- /dev/null
+++ b/docs/source/api/core/user_adaptation_processor.rst
@@ -0,0 +1,8 @@
+user_adaptation_processor
+=========================
+
+.. automodule:: besser.bot.core.processors.user_adaptation_processor
+   :members:
+   :private-members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/source/api/platforms.rst b/docs/source/api/platforms.rst
index 9dc04eb..973e61e 100644
--- a/docs/source/api/platforms.rst
+++ b/docs/source/api/platforms.rst
@@ -7,5 +7,4 @@ platforms
    platforms/payload
    platforms/platform
    platforms/telegram_platform
-   platforms/streamlit_ui
    platforms/websocket_platform
diff --git a/docs/source/api/platforms/streamlit_ui.rst b/docs/source/api/platforms/streamlit_ui.rst
deleted file mode 100644
index be6c00e..0000000
--- a/docs/source/api/platforms/streamlit_ui.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-streamlit_ui
-============
-
-A sample User Interface powered by `Streamlit `_ implementing a WebSocket
-client that connects to the bot WebSocket server.
-
-.. literalinclude:: ../../../../besser/bot/platforms/websocket/streamlit_ui.py
-   :language: python
-   :linenos:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index c268b04..6cf0cd2 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -89,10 +89,18 @@ def generate_api_rst_files(preffix, dir, output_dir):
     'db/sidebar.py',
     'db/table_overview.py',
     'db/utils.py',
+    'platforms/chat.py',
+    'platforms/initialization.py',
+    'platforms/message_input.py',
+    'platforms/session_management.py',
+    'platforms/sidebar.py',
+    'platforms/streamlit_ui.py',
+    'platforms/vars.py',
+    'platforms/websocket_callbacks.py',
+
 ]
 api_excluded_files = [
     # Files that for which we won't automatically generate .rst files and WILL appear in the toctree
-    'platforms/streamlit_ui.py',
 ]
 api_excluded_files.extend(api_excluded_files_toctree)
 os.makedirs(output_dir, exist_ok=True)
diff --git a/docs/source/img/chat_widget_demo.gif b/docs/source/img/chat_widget_demo.gif
new file mode 100644
index 0000000..2cfc3b9
Binary files /dev/null and b/docs/source/img/chat_widget_demo.gif differ
diff --git a/docs/source/img/websocket_demo.gif b/docs/source/img/streamlit_ui_demo.gif
similarity index 100%
rename from docs/source/img/websocket_demo.gif
rename to docs/source/img/streamlit_ui_demo.gif
diff --git a/docs/source/release_notes/v1.5.0.rst b/docs/source/release_notes/v1.5.0.rst
new file mode 100644
index 0000000..60b761d
--- /dev/null
+++ b/docs/source/release_notes/v1.5.0.rst
@@ -0,0 +1,17 @@
+Version 1.5.0
+=============
+
+New Features
+-------------
+
+- Markdown and HTML replies in WebSocketPlatform
+- Chat Widget UI
+- New Processor: UserAdaptationProcessor
+- Reply image in WebSocketPlatform
+- LLM Global and User context
+
+Improvements
+-------------
+
+- Upgrade Streamlit to v1.40.0
+- Typing effect on bot messages in the Streamlit UI
diff --git a/docs/source/wiki/core/processors.rst b/docs/source/wiki/core/processors.rst
index b705d1f..8aa3ce6 100644
--- a/docs/source/wiki/core/processors.rst
+++ b/docs/source/wiki/core/processors.rst
@@ -106,6 +106,19 @@ When processed, the recognized language will be
 stored as a session variable in session.get('detected_lang')
 
+.. _user-adaptation-processor:
+
+UserAdaptationProcessor
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The :class:`UserAdaptationProcessor `
+attempts to adapt the bot's responses based on the user's profile. The user profile can be added using the following call:
+
+.. code:: python
+
+    processor.add_user_model(session, user_model)
+
+
 API References
 --------------
diff --git a/docs/source/wiki/nlp/llm.rst b/docs/source/wiki/nlp/llm.rst
index ad4d40c..88845fd 100644
--- a/docs/source/wiki/nlp/llm.rst
+++ b/docs/source/wiki/nlp/llm.rst
@@ -50,6 +50,43 @@ This LLM can be used within any bot state (in both the body and the fallback bod
 There are plenty of possibilities to take advantage of LLMs in a chatbot. The previous is a very simple use case, but
 we can do more advanced tasks through prompt engineering.
 
+Adding context information to an LLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To improve or customize the LLM's behavior, it is also possible to add context information to an LLM.
+Here, we differentiate between a global context and user-specific context information.
+The global context will be applied to every LLM prediction, for any user.
+The user-specific context is only applied to a specific user and can contain personal information to personalize the LLM's behavior.
+
+Here is an example where we extend the previous LLMOpenAI instance:
+
+.. code:: python
+
+    # Adding this global_context will cause the LLM to answer only in English.
+    gpt = LLMOpenAI(bot=bot, name='gpt-4o-mini', global_context='You only speak English.')
+
+Let's now suppose we have access to the user's name while executing the body of the current state:
+
+.. code:: python
+
+    def answer_body(session: Session):
+        user_name = session.message
+        # For this specific session, the LLM will always be given the "context" string as additional information.
+        gpt.add_user_context(session=session, context=f"The user is called {user_name}", context_name="user_name_context")
+        answer = gpt.predict(session.message, session=session)  # Predicts the output for the given input (the user message)
+        # It's also possible to remove context elements from the user-specific context.
+        gpt.remove_user_context(session=session, context_name="user_name_context")
+        session.reply(answer)
+
+It is also possible to add context information only for a specific prompt:
+
+.. code:: python
+
+    def answer_body(session: Session):
+        user_name = session.message
+        answer = gpt.predict(session.message, session=session, system_message=f'Start your response using the name of the user, which is {user_name}')
+        session.reply(answer)
+
 Available LLMs
 --------------
@@ -69,12 +106,19 @@
 These are the currently available LLM wrappers in BBF:
 
 - :class:`~besser.bot.nlp.llm.llm_huggingface_api.LLMHuggingFaceAPI`: For HuggingFace LLMs, through its `Inference API `_
 - :class:`~besser.bot.nlp.llm.llm_replicate_api.LLMReplicate`: For `Replicate `_ LLMs, through its API
+.. note::
+
+    Models from HuggingFace or Replicate may expect a specific prompting or context format. Be sure to carefully read
+    each model's guidelines to get optimal results.
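+
+All these wrappers implement the shared :class:`LLM` interface (e.g. :meth:`predict` and :meth:`set_parameters`), so
+they can be configured and queried uniformly. A minimal sketch, reusing the LLMOpenAI instance from the examples
+above (the parameter names and values are illustrative only, as each provider supports its own set):
+
+.. code:: python
+
+    # Hypothetical parameter values: check your provider's documentation for the supported options
+    gpt.set_parameters({'temperature': 0.4, 'max_tokens': 100})
+    answer = gpt.predict('Tell me a joke about chatbots')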
+
+
 API References
 --------------
 
 - Bot: :class:`besser.bot.core.bot.Bot`
 - LLM: :class:`besser.bot.nlp.llm.llm.LLM`
 - LLM.predict(): :meth:`besser.bot.nlp.llm.llm.LLM.predict`
+- LLM.add_user_context(): :meth:`besser.bot.nlp.llm.llm.LLM.add_user_context`
+- LLM.remove_user_context(): :meth:`besser.bot.nlp.llm.llm.LLM.remove_user_context`
 - LLMHuggingFace: :class:`besser.bot.nlp.llm.llm_huggingface.LLMHuggingFace`
 - LLMHuggingFaceAPI: :class:`besser.bot.nlp.llm.llm_huggingface_api.LLMHuggingFaceAPI`
 - LLMOpenAI: :class:`besser.bot.nlp.llm.llm_openai_api.LLMOpenAI`
diff --git a/docs/source/wiki/platforms/websocket_platform.rst b/docs/source/wiki/platforms/websocket_platform.rst
index 85929e8..8c84132 100644
--- a/docs/source/wiki/platforms/websocket_platform.rst
+++ b/docs/source/wiki/platforms/websocket_platform.rst
@@ -12,14 +12,21 @@ The next figure shows how this connection works:
 
 .. figure:: ../../img/websocket_diagram.png
    :alt: Intent diagram
 
-   The bot we just created, has 2 states linked by an intent.
+   Figure illustrating the WebSocket protocol.
 
-We provide a UI (:doc:`../../api/platforms/streamlit_ui`) implementing a WebSocket client to communicate with the bot,
-though you can use or create your own UI as long as it has a WebSocket client that connects to the bot's WebSocket
-server. This is how our chatbot UI looks like:
+User Interface
+--------------
 
-.. figure:: ../../img/websocket_demo.gif
-   :alt: WebSocket UI demo
+BBF comes with some User Interfaces (WebSocket clients) to use the WebSocket platform.
+
+Of course, you are free to use or create your own UI as long as it has a WebSocket client that connects to the bot's WebSocket server.
+
+.. toctree::
+   :maxdepth: 1
+
+   websocket_platform/streamlit_ui
+   websocket_platform/chat_widget
+
+(Their source code can be found in the besser.bot.platforms.websocket package.)
 
 .. note::
 
@@ -48,6 +55,30 @@ After that, you can use the platform to send different kinds of messages to the
 
    websocket_platform.reply(session, 'Hello!')
 
+- Text messages in `Markdown `_ format:
+
+.. code:: python
+
+    websocket_platform.reply_markdown(session, """
+    # Welcome to the chatbot experience
+    ## Section 1
+    - one
+    - two
+    """)
+
+- Text messages in `HTML `_ format:
+
+.. code:: python
+
+    websocket_platform.reply_html(session, """
+    <h1>Title</h1>
+    <ul>
+      <li>Apples</li>
+      <li>Bananas</li>
+      <li>Cherries</li>
+    </ul>
+ """) + - Pandas `DataFrames `_: .. code:: python diff --git a/docs/source/wiki/platforms/websocket_platform/chat_widget.rst b/docs/source/wiki/platforms/websocket_platform/chat_widget.rst new file mode 100644 index 0000000..3a0d69f --- /dev/null +++ b/docs/source/wiki/platforms/websocket_platform/chat_widget.rst @@ -0,0 +1,38 @@ +Chat widget +=========== + +The chat widget UI allows to integrate a chatbot in any webpage. It is located in a window corner, expanded/hidden when clicking on an icon. + +This is how our chatbot UI looks like: + +.. figure:: ../../../img/chat_widget_demo.gif + :alt: Chat Widget demo + :scale: 70% + +Parameters +---------- + +The file data/args.json contains parameters you can set to customize the chat widget (websocket address, bot icon, colors, ...) + +How to use it +------------- + +Just go to the chat_widget directory and open the **index.html** file. + +.. note:: + + The parameters can only be read from the JSON file when running the interface from a server, not from the file system. + + You can create a simple server by running the following in the chat widget directory: + + .. code:: bash + + python -m http.server + + This will serve your files at http://localhost:8000 + + If you want to run it from the file system, you will have to hardcode the parameters (instead of loading them from + an external file, just write your desired values in the renderChatWidget function in the js/script.js file) + +To integrate the chat widget in a real webpage, just copy the content in index.html into the html of your webpage. +Make sure to include the other directories in the webpage dependencies, as they contain the chat widget code. diff --git a/docs/source/wiki/platforms/websocket_platform/streamlit_ui.rst b/docs/source/wiki/platforms/websocket_platform/streamlit_ui.rst new file mode 100644 index 0000000..381d26d --- /dev/null +++ b/docs/source/wiki/platforms/websocket_platform/streamlit_ui.rst @@ -0,0 +1,26 @@ +Streamlit UI +============ + +We provide a Streamlit UI implementing a WebSocket client to communicate with the bot. + +This is how our chatbot UI looks like: + +.. figure:: ../../../img/streamlit_ui_demo.gif + :alt: WebSocket UI demo + +How to use it +------------- + +You can run it directly from the bot, by setting it in the websocket_platform: + +.. code:: python + + bot = Bot('example_bot') + ... + websocket_platform = bot.use_websocket_platform(use_ui=True) + +Or you can also run it separately. Just open a terminal on the streamlit UI directory, and run: + +.. code:: bash + + streamlit run --server.address localhost --server.port 5000 streamlit_ui.py bot_name localhost 8765 diff --git a/requirements.txt b/requirements.txt index 5634f61..dd5a04b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -audio-recorder-streamlit==0.0.8 chromadb==0.5.4 dateparser==1.1.8 keras==2.14.0 @@ -22,7 +21,7 @@ spacy==3.7.2 SpeechRecognition==3.10.0 # spellux @ git+https://github.com/Aran30/spellux # Not available in PyPi. Install manually with `pip install git+https://github.com/Aran30/spellux.git` sqlalchemy==2.0.29 -streamlit==1.27.2 +streamlit==1.40.0 streamlit-antd-components==0.3.2 tensorflow==2.14.0 text2num==2.5.0 diff --git a/setup.cfg b/setup.cfg index 341a321..9bc3c02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = besser-bot-framework -version = 1.4.0 +version = 1.5.0 author = Luxembourg Institute of Science and Technology description = BESSER Bot Framework (BBF) long_description = file: README.md