From a36ac96ff009bcd0942c1618a33d1c1abc22dff1 Mon Sep 17 00:00:00 2001 From: Mohamed Aouadhi Date: Fri, 25 Aug 2023 23:39:27 +0200 Subject: [PATCH] feat: Add /chain , gives memory to chat mode --- README.md | 2 -- bot_commands_macro/src/lib.rs | 8 ++--- homebot.service | 3 +- src/bot_commands.rs | 30 ++++++++++++++++--- src/lib/services/llm.rs | 18 ++++++++++- src/lib/telegram/bot.rs | 21 ++++--------- src/lib/types.rs | 56 +++++++++++++++++------------------ 7 files changed, 82 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 2dca836..1a6034d 100644 --- a/README.md +++ b/README.md @@ -93,8 +93,6 @@ curl -F "url=https://11.22.33.44/" -F "certificate=@YOURPUBLIC.pem" \ ### Chat mode Chat mode is simply the LLM request command (if provided) but without typing the command prefix each time, so once in the mode, you can chat with the llm just like any normal conversation. -**Currently the llm does not have memory, so each message is a conversation on its own**. - - To inform polybot of that, we just need to tell it about the command that `starts` the chat mode, the one that `exists` it and the `llm request` command. We do that by adding the boolean attributes in the `handler` procedural macro: ```rust diff --git a/bot_commands_macro/src/lib.rs b/bot_commands_macro/src/lib.rs index 6801b9b..2f26d82 100644 --- a/bot_commands_macro/src/lib.rs +++ b/bot_commands_macro/src/lib.rs @@ -92,11 +92,11 @@ pub fn bot_commands(_args: TokenStream, input: TokenStream) -> TokenStream { let struct_name = get_cmd_struct_name(&command_name); let state = if chat_start == &Some(true) { quote! { - user_tx.set_chat_mode(true).await; + user.set_chat_mode(true).await; } } else if chat_exit == &Some(true) { quote! { - user_tx.set_chat_mode(false).await; + user.set_chat_mode(false).await; } } else { quote!() @@ -105,9 +105,9 @@ pub fn bot_commands(_args: TokenStream, input: TokenStream) -> TokenStream { #[::async_trait::async_trait] impl ::polybot::types::BotCommandHandler for #struct_name { - async fn handle(&self, user_tx: ::tokio::sync::mpsc::Sender<::polybot::types::BotUserCommand>, args: String) -> String { + async fn handle(&self, user: ::polybot::types::SharedUser, args: String) -> String { #state - #func_name(user_tx, args).await + #func_name(user, args).await } } } diff --git a/homebot.service b/homebot.service index 5129dd1..912f3b6 100644 --- a/homebot.service +++ b/homebot.service @@ -7,7 +7,8 @@ Type=simple Restart=always RestartSec=1 User=${USER} -ExecStart= RUST_LOG=debug /home/${USER}//.cargo/bin/cargo run --release /home/${USER}/personal/homebot/Cargo.toml +Environment=OPENAI_API_KEY=enter_your_token_here +ExecStart= RUST_LOG=debug /home/${USER}/.cargo/bin/cargo run --release /home/${USER}/personal/homebot/Cargo.toml WorkingDirectory=/home/${USER}/personal/homebot [Install] diff --git a/src/bot_commands.rs b/src/bot_commands.rs index 4dfcae4..1166dc1 100644 --- a/src/bot_commands.rs +++ b/src/bot_commands.rs @@ -7,9 +7,8 @@ pub mod commands { use polybot::services::llm::{Agent, OpenAiModel}; use polybot::services::openmeteo::OpenMeteo; use polybot::types::{BotUserActions, WeatherProvider}; - use rand::Rng; - use polybot::utils::{get_affirmation, get_ip}; + use rand::Rng; #[handler(cmd = "/ip")] async fn ip(_user_tx: impl BotUserActions, _: String) -> String { @@ -42,7 +41,7 @@ pub mod commands { } } - #[handler(cmd = "/ask", llm_request = true)] + #[handler(cmd = "/ask")] async fn ask(_user_tx: impl BotUserActions, request: String) -> String { if request.is_empty() { return "Ask something!".to_string(); @@ -59,10 +58,33 @@ pub mod commands { } #[handler(cmd = "/chat", chat_start = true)] - async fn chat(_user_tx: impl BotUserActions, _: String) -> String { + async fn chat(user: impl BotUserActions, system_prompt: String) -> String { + let mut prompt = system_prompt.as_str(); + if prompt.is_empty() { + prompt = "You are an intelligent cat named Nami, you will answer all questions briefly, and always + maintain your character, and will meow from time to time"; + } + if user.reset_conversation_chain(prompt).await.is_err() { + return "Error during initializing the chat!".to_string(); + } "Let's chat!".to_string() } + /// Gives memory to your conversations in the chat mode + #[handler(cmd = "/chain", llm_request = true)] + async fn converse(user: impl BotUserActions, request: String) -> String { + let conversation = user.get_conversation().await; + + if let Ok(agent) = OpenAiModel::try_new() { + if let Ok(answer) = agent.conversation(&request, conversation).await { + return answer; + } + "Problem getting the agent response".to_string() + } else { + "Could not create the llm agent, check the API key".to_string() + } + } + #[handler(cmd = "/endchat", chat_exit = true)] async fn endchat(_user_tx: impl BotUserActions, _request: String) -> String { "See ya!".to_string() diff --git a/src/lib/services/llm.rs b/src/lib/services/llm.rs index eaaa899..e2531c5 100644 --- a/src/lib/services/llm.rs +++ b/src/lib/services/llm.rs @@ -1,12 +1,16 @@ +use std::sync::Arc; + use anyhow::{bail, Result}; use async_trait::async_trait; -use llm_chain::{executor, parameters, prompt}; +use llm_chain::{chains::conversation::Chain, executor, parameters, prompt, step::Step}; use llm_chain_openai::chatgpt::Executor; +use tokio::sync::Mutex; use tracing::debug; #[async_trait] pub trait Agent: Send + Sync { async fn request(&self, req: &str) -> Result; + async fn conversation(&self, req: &str, chain: Arc>) -> Result; async fn chain_requests(&self, steps: Vec<&str>) -> Result; async fn map_reduce_chain(&self, steps: Vec<&str>) -> Result; } @@ -54,4 +58,16 @@ impl Agent for OpenAiModel { async fn map_reduce_chain(&self, _steps: Vec<&str>) -> Result { todo!() } + + async fn conversation(&self, req: &str, chain: Arc>) -> Result { + let step = Step::for_prompt_template(prompt!(user: req)); + Ok(chain + .lock() + .await + .send_message(step, ¶meters!(), &self.executor) + .await? + .to_immediate() + .await? + .to_string()) + } } diff --git a/src/lib/telegram/bot.rs b/src/lib/telegram/bot.rs index af068e7..397e158 100644 --- a/src/lib/telegram/bot.rs +++ b/src/lib/telegram/bot.rs @@ -1,12 +1,12 @@ use std::collections::HashMap; use std::marker::PhantomData; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use crate::telegram::types::{Response, Update, Webhook}; use crate::types::{ - Bot, BotCommands, BotConfig, BotMessage, BotMessages, BotUser, BotUserActions, BotUserCommand, - CommandHashMap, SharedUsers, + Bot, BotCommands, BotConfig, BotMessage, BotMessages, BotUser, BotUserActions, CommandHashMap, + SharedUsers, }; use anyhow::{bail, Context, Ok, Result}; use async_trait::async_trait; @@ -108,11 +108,11 @@ impl Bot for TelegramBot { "Adding the user (id = {}), (name = {}).", user_id, user_name ); - users.insert(user_id, BotUser::new()); + users.insert(user_id, Arc::new(RwLock::new(BotUser::new()))); }; let text = msg.get_message(); - let user = users.get_mut(&user_id).unwrap(); + let mut user = Arc::clone(users.get_mut(&user_id).unwrap()); // update the user activity user.set_last_activity(chrono::Utc::now()).await; @@ -127,18 +127,9 @@ impl Bot for TelegramBot { argument = message.collect::>().join(" "); } debug!("Cmd: {:?}, Arg: {:?}", command, argument); - let (tx, mut rx) = tokio::sync::mpsc::channel::(32); answer = if let Some(bot_command) = self.command_list.get(command.unwrap()) { - let result = bot_command.handle(tx, argument).await; - if let Some(action) = rx.recv().await { - match action { - BotUserCommand::UpdateChatMode { chat_mode } => { - user.set_chat_mode(chat_mode).await; - } - } - } - result + bot_command.handle(user.clone(), argument).await } else { "Did not understand!".into() }; diff --git a/src/lib/types.rs b/src/lib/types.rs index ee85446..a32414c 100644 --- a/src/lib/types.rs +++ b/src/lib/types.rs @@ -3,7 +3,7 @@ use std::{ path::PathBuf, sync::{ atomic::{AtomicBool, Ordering}, - Arc, + Arc, RwLock, }, }; @@ -12,8 +12,9 @@ use anyhow::Result; use async_trait::async_trait; use chrono::{DateTime, Utc}; use enum_dispatch::enum_dispatch; +use llm_chain::chains::conversation::Chain; +use llm_chain::prompt; use serde::Deserialize; -use tokio::sync::mpsc::Sender; use tokio::sync::Mutex; #[derive(Deserialize, Debug, Clone)] @@ -51,7 +52,8 @@ pub enum BotMessages { Message, // Telegram messages } -pub type SharedUsers = Arc>>; +pub type SharedUser = Arc>; +pub type SharedUsers = Arc>>; pub type CommandHashMap = HashMap>; #[async_trait] @@ -75,13 +77,15 @@ pub trait BotCommands: Default + Send + Sync { #[async_trait] pub trait BotCommandHandler { - async fn handle(&self, user_tx: Sender, args: String) -> String; + async fn handle(&self, user: SharedUser, args: String) -> String; } #[derive(Default)] pub struct BotUser { chat_mode: AtomicBool, last_activity: DateTime, + // chain: Arc>, + chain: Arc>, } impl BotUser { @@ -96,50 +100,44 @@ pub trait BotUserActions { async fn get_last_activity(&self) -> DateTime; async fn set_chat_mode(&self, state: bool); async fn is_in_chat_mode(&self) -> bool; + async fn get_conversation(&self) -> Arc>; + async fn reset_conversation_chain(&self, system_prompt: &str) -> Result<()>; } #[async_trait] -impl BotUserActions for BotUser { +impl BotUserActions for SharedUser { async fn set_last_activity(&mut self, date: DateTime) { - self.last_activity = date; + self.write().expect("poisoned lock").last_activity = date; } async fn get_last_activity(&self) -> DateTime { - self.last_activity + self.read().expect("poisoned lock").last_activity } async fn set_chat_mode(&self, state: bool) { - self.chat_mode.store(state, Ordering::Relaxed); + self.write() + .expect("poisoned lock") + .chat_mode + .store(state, Ordering::Relaxed); } async fn is_in_chat_mode(&self) -> bool { - self.chat_mode.load(Ordering::Relaxed) - } -} - -#[async_trait] -impl BotUserActions for Sender { - async fn set_last_activity(&mut self, _date: DateTime) { - unimplemented!() - } - - async fn get_last_activity(&self) -> DateTime { - unimplemented!() + self.read() + .expect("poisoned lock") + .chat_mode + .load(Ordering::Relaxed) } - async fn set_chat_mode(&self, state: bool) { - self.send(BotUserCommand::UpdateChatMode { chat_mode: state }) - .await - .unwrap(); + async fn get_conversation(&self) -> Arc> { + self.read().expect("poisoned lock").chain.clone() } - async fn is_in_chat_mode(&self) -> bool { - unimplemented!() + async fn reset_conversation_chain(&self, system_prompt: &str) -> Result<()> { + self.write().expect("poisoned lock").chain = + Arc::new(Mutex::new(Chain::new(prompt!(system: system_prompt))?)); + Ok(()) } } -pub enum BotUserCommand { - UpdateChatMode { chat_mode: bool }, -} pub enum ForecastTime { Later(u32),