From dd4ae79345575cba97bc6592505388f193ab4f18 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Tue, 12 Sep 2023 15:50:19 +0200 Subject: [PATCH 01/26] Basic renaming --- .../shinkai_message_schemas.rs | 20 +++++----- .../shinkai_utils/shinkai_message_builder.rs | 32 +++++++++------- .../shinkai_job_wrapper.rs | 24 ++++++++---- .../shinkai_message_builder_wrapper.rs | 35 +++++++++++++---- src/managers/job_manager.rs | 38 ++++++++++++------- tests/toolkit_tests.rs | 3 +- 6 files changed, 98 insertions(+), 54 deletions(-) diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs index d24cb61d3..1edf9ccfb 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs @@ -1,10 +1,10 @@ use std::fmt; -use serde::{Deserialize, Serialize, Serializer, Deserializer}; -use serde_json::Result; use regex::Regex; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::Result; -use crate::schemas::{inbox_name::InboxName, shinkai_name::ShinkaiName, agents::serialized_agent::SerializedAgent}; +use crate::schemas::{agents::serialized_agent::SerializedAgent, inbox_name::InboxName, shinkai_name::ShinkaiName}; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub enum MessageSchemaType { @@ -17,7 +17,7 @@ pub enum MessageSchemaType { APIReadUpToTimeRequest, APIAddAgentRequest, TextContent, - Empty + Empty, } impl MessageSchemaType { @@ -95,7 +95,7 @@ impl JobScope { } #[derive(Serialize, Deserialize, Debug, Clone)] -pub struct JobCreation { +pub struct JobCreationInfo { pub scope: JobScope, } @@ -214,9 +214,9 @@ pub struct RegistrationCodeRequest { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum IdentityPermissions { - Admin, // can create and delete other profiles + Admin, // can create and delete other profiles Standard, // can add / remove devices - None, // none of the above + None, // none of the above } impl IdentityPermissions { @@ -272,7 +272,7 @@ impl Serialize for RegistrationCodeType { RegistrationCodeType::Device(device_name) => { let s = format!("device:{}", device_name); serializer.serialize_str(&s) - }, + } RegistrationCodeType::Profile => serializer.serialize_str("profile"), } } @@ -289,7 +289,7 @@ impl<'de> Deserialize<'de> for RegistrationCodeType { Some(&"device") => { let device_name = parts.get(1).unwrap_or(&"main"); Ok(RegistrationCodeType::Device(device_name.to_string())) - }, + } Some(&"profile") => Ok(RegistrationCodeType::Profile), _ => Err(serde::de::Error::custom("Unexpected variant")), } @@ -303,4 +303,4 @@ impl fmt::Display for RegistrationCodeType { RegistrationCodeType::Profile => write!(f, "profile"), } } -} \ No newline at end of file +} diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder.rs index 6cf0d4cf3..f8ff8a7b7 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_message_builder.rs @@ -15,7 +15,7 @@ use crate::{ }, shinkai_message_schemas::{ APIAddAgentRequest, APIGetMessagesFromInboxRequest, APIReadUpToTimeRequest, 
IdentityPermissions, - JobCreation, JobMessage, JobScope, MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType, + JobCreationInfo, JobMessage, JobScope, MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType, }, }, shinkai_utils::{ @@ -394,7 +394,7 @@ impl ShinkaiMessageBuilder { node_receiver: ProfileName, node_receiver_subidentity: ProfileName, ) -> Result { - let job_creation = JobCreation { scope }; + let job_creation = JobCreationInfo { scope }; let body = serde_json::to_string(&job_creation).map_err(|_| "Failed to serialize job creation to JSON")?; ShinkaiMessageBuilder::new(my_encryption_secret_key, my_signature_secret_key, receiver_public_key) @@ -461,18 +461,22 @@ impl ShinkaiMessageBuilder { // Use for placeholder. These messages *are not* encrypted so it's not required let (placeholder_encryption_sk, placeholder_encryption_pk) = unsafe_deterministic_encryption_keypair(0); - ShinkaiMessageBuilder::new(placeholder_encryption_sk, my_signature_secret_key, placeholder_encryption_pk) - .message_raw_content(body) - .internal_metadata_with_schema( - "".to_string(), - "".to_string(), - inbox, - MessageSchemaType::JobMessageSchema, - EncryptionMethod::None, - ) - .body_encryption(EncryptionMethod::None) - .external_metadata(node_receiver, node_sender) - .build() + ShinkaiMessageBuilder::new( + placeholder_encryption_sk, + my_signature_secret_key, + placeholder_encryption_pk, + ) + .message_raw_content(body) + .internal_metadata_with_schema( + "".to_string(), + "".to_string(), + inbox, + MessageSchemaType::JobMessageSchema, + EncryptionMethod::None, + ) + .body_encryption(EncryptionMethod::None) + .external_metadata(node_receiver, node_sender) + .build() } pub fn terminate_message( diff --git a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_job_wrapper.rs b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_job_wrapper.rs index ee1cd1e94..a4b932609 100644 --- a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_job_wrapper.rs +++ b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_job_wrapper.rs @@ -1,6 +1,9 @@ use serde::{Deserialize, Serialize}; use serde_wasm_bindgen; -use shinkai_message_primitives::{shinkai_message::shinkai_message_schemas::{JobScope, JobCreation, JobMessage}, schemas::inbox_name::InboxName}; +use shinkai_message_primitives::{ + schemas::inbox_name::InboxName, + shinkai_message::shinkai_message_schemas::{JobCreationInfo, JobMessage, JobScope}, +}; use wasm_bindgen::prelude::*; use crate::shinkai_wasm_wrappers::shinkai_wasm_error::ShinkaiWasmError; @@ -36,7 +39,7 @@ impl JobScopeWrapper { #[wasm_bindgen] #[derive(Serialize, Deserialize, Clone, Debug)] pub struct JobCreationWrapper { - inner: JobCreation, + inner: JobCreationInfo, } #[wasm_bindgen] @@ -44,7 +47,7 @@ impl JobCreationWrapper { #[wasm_bindgen(constructor)] pub fn new(scope_js: &JsValue) -> Result { let scope: JobScope = serde_wasm_bindgen::from_value(scope_js.clone())?; - let job_creation = JobCreation { scope }; + let job_creation = JobCreationInfo { scope }; Ok(JobCreationWrapper { inner: job_creation }) } @@ -66,13 +69,13 @@ impl JobCreationWrapper { #[wasm_bindgen] pub fn from_json_str(s: &str) -> Result { - let deserialized: JobCreation = serde_json::from_str(s).map_err(|e| JsValue::from_str(&e.to_string()))?; + let deserialized: JobCreationInfo = serde_json::from_str(s).map_err(|e| JsValue::from_str(&e.to_string()))?; Ok(JobCreationWrapper { inner: deserialized }) } #[wasm_bindgen] pub fn 
from_jsvalue(js_value: &JsValue) -> Result { - let deserialized: JobCreation = serde_wasm_bindgen::from_value(js_value.clone())?; + let deserialized: JobCreationInfo = serde_wasm_bindgen::from_value(js_value.clone())?; Ok(JobCreationWrapper { inner: deserialized }) } @@ -81,7 +84,9 @@ impl JobCreationWrapper { let buckets: Vec = Vec::new(); let documents: Vec = Vec::new(); let job_scope = JobScope::new(Some(buckets), Some(documents)); - Ok(JobCreationWrapper { inner: JobCreation { scope: job_scope } }) + Ok(JobCreationWrapper { + inner: JobCreationInfo { scope: job_scope }, + }) } } @@ -126,7 +131,10 @@ impl JobMessageWrapper { #[wasm_bindgen(js_name = fromStrings)] pub fn from_strings(job_id: &str, content: &str) -> JobMessageWrapper { - let job_message = JobMessage { job_id: job_id.to_string(), content: content.to_string() }; + let job_message = JobMessage { + job_id: job_id.to_string(), + content: content.to_string(), + }; JobMessageWrapper { inner: job_message } } -} \ No newline at end of file +} diff --git a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_message_builder_wrapper.rs b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_message_builder_wrapper.rs index 9bf659ae5..eb6e48106 100644 --- a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_message_builder_wrapper.rs +++ b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/shinkai_message_builder_wrapper.rs @@ -1,11 +1,29 @@ -use crate::shinkai_wasm_wrappers::{shinkai_message_wrapper::ShinkaiMessageWrapper, wasm_shinkai_message::SerdeWasmMethods, shinkai_wasm_error::{WasmErrorWrapper, ShinkaiWasmError}}; +use crate::shinkai_wasm_wrappers::{ + shinkai_message_wrapper::ShinkaiMessageWrapper, + shinkai_wasm_error::{ShinkaiWasmError, WasmErrorWrapper}, + wasm_shinkai_message::SerdeWasmMethods, +}; use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey}; use js_sys::Uint8Array; use serde::{Deserialize, Serialize}; -use shinkai_message_primitives::{shinkai_utils::{encryption::{string_to_encryption_static_key, string_to_encryption_public_key, encryption_public_key_to_string, EncryptionMethod}, signatures::{string_to_signature_secret_key, signature_public_key_to_string}, shinkai_message_builder::{ShinkaiMessageBuilder, ProfileName}}, shinkai_message::shinkai_message_schemas::{IdentityPermissions, RegistrationCodeType, RegistrationCodeRequest, MessageSchemaType, APIGetMessagesFromInboxRequest, APIAddAgentRequest, APIReadUpToTimeRequest, JobScope, JobCreation, JobMessage}, schemas::{registration_code::RegistrationCode, inbox_name::InboxName, agents::serialized_agent::SerializedAgent}}; +use serde_wasm_bindgen::{from_value, to_value}; +use shinkai_message_primitives::{ + schemas::{agents::serialized_agent::SerializedAgent, inbox_name::InboxName, registration_code::RegistrationCode}, + shinkai_message::shinkai_message_schemas::{ + APIAddAgentRequest, APIGetMessagesFromInboxRequest, APIReadUpToTimeRequest, IdentityPermissions, + JobCreationInfo, JobMessage, JobScope, MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType, + }, + shinkai_utils::{ + encryption::{ + encryption_public_key_to_string, string_to_encryption_public_key, string_to_encryption_static_key, + EncryptionMethod, + }, + shinkai_message_builder::{ProfileName, ShinkaiMessageBuilder}, + signatures::{signature_public_key_to_string, string_to_signature_secret_key}, + }, +}; use wasm_bindgen::prelude::*; use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as 
EncryptionStaticKey}; -use serde_wasm_bindgen::{from_value, to_value}; #[wasm_bindgen] pub struct ShinkaiMessageBuilderWrapper { @@ -240,7 +258,8 @@ impl ShinkaiMessageBuilderWrapper { match builder.build() { Ok(shinkai_message) => { let js_value = shinkai_message.to_jsvalue().map_err(WasmErrorWrapper)?; - Ok(ShinkaiMessageWrapper::from_jsvalue(&js_value).map_err(|e| WasmErrorWrapper::new(ShinkaiWasmError::from(e)))?) + Ok(ShinkaiMessageWrapper::from_jsvalue(&js_value) + .map_err(|e| WasmErrorWrapper::new(ShinkaiWasmError::from(e)))?) } Err(e) => Err(JsValue::from_str(&e.to_string())), } @@ -255,9 +274,9 @@ impl ShinkaiMessageBuilderWrapper { pub fn build_to_jsvalue(&mut self) -> Result { if let Some(ref builder) = self.inner { match builder.build() { - Ok(shinkai_message) => { - shinkai_message.to_jsvalue().map_err(|e| JsValue::from_str(&e.to_string())) - } + Ok(shinkai_message) => shinkai_message + .to_jsvalue() + .map_err(|e| JsValue::from_str(&e.to_string())), Err(e) => Err(JsValue::from_str(e)), } } else { @@ -649,7 +668,7 @@ impl ShinkaiMessageBuilderWrapper { ) -> Result { let scope: JobScope = serde_wasm_bindgen::from_value(scope).map_err(|e| JsValue::from_str(&e.to_string()))?; - let job_creation = JobCreation { scope }; + let job_creation = JobCreationInfo { scope }; let body = serde_json::to_string(&job_creation).map_err(|e| JsValue::from_str(&e.to_string()))?; let mut builder = diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index d4994d62a..bc4479b28 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -9,7 +9,9 @@ use shinkai_message_primitives::{ }, shinkai_message::{ shinkai_message::{MessageBody, MessageData, ShinkaiMessage}, - shinkai_message_schemas::{JobCreation, JobMessage, JobPreMessage, JobRecipient, JobScope, MessageSchemaType}, + shinkai_message_schemas::{ + JobCreationInfo, JobMessage, JobPreMessage, JobRecipient, JobScope, MessageSchemaType, + }, }, shinkai_utils::{shinkai_message_builder::ShinkaiMessageBuilder, signatures::clone_signature_secret_key}, }; @@ -122,7 +124,7 @@ impl JobManager { let agent_manager = Arc::clone(&self.agent_manager); let receiver = Arc::clone(&self.job_manager_receiver); let node_profile_name_clone = self.node_profile_name.clone(); - let identity_secret_key_clone = clone_signature_secret_key(&self.identity_secret_key); + let identity_secret_key_clone = clone_signature_secret_key(&self.identity_secret_key); tokio::spawn(async move { while let Some((messages, job_id)) = receiver.lock().await.recv().await { for message in messages { @@ -130,14 +132,17 @@ impl JobManager { let shinkai_message_result = ShinkaiMessageBuilder::job_message_from_agent( job_id.clone(), - message.clone().content, + message.content.clone(), clone_signature_secret_key(&identity_secret_key_clone), - node_profile_name_clone.clone().to_string(), - node_profile_name_clone.clone().to_string(), + node_profile_name_clone.to_string(), + node_profile_name_clone.to_string(), ); if let Ok(shinkai_message) = shinkai_message_result { - if let Err(err) = agent_manager.handle_pre_message_schema(message, job_id.clone(), shinkai_message).await { + if let Err(err) = agent_manager + .handle_pre_message_schema(message, job_id.clone(), shinkai_message) + .await + { eprintln!("Error while handling pre message schema: {:?}", err); } } else if let Err(err) = shinkai_message_result { @@ -206,6 +211,7 @@ impl AgentManager { job_manager } + /// Checks that the provided ShinkaiMessage is an unencrypted job message pub fn is_job_message(&mut 
self, message: ShinkaiMessage) -> bool { match &message.body { MessageBody::Unencrypted(body) => match &body.message_data { @@ -221,7 +227,7 @@ impl AgentManager { pub async fn handle_job_creation_schema( &mut self, - job_creation: JobCreation, + job_creation: JobCreationInfo, agent_id: &String, ) -> Result { let job_id = format!("jobid_{}", uuid::Uuid::new_v4()); @@ -292,11 +298,14 @@ impl AgentManager { &mut self, pre_message: JobPreMessage, job_id: String, - shinkai_message: ShinkaiMessage + shinkai_message: ShinkaiMessage, ) -> Result { println!("handle_pre_message_schema> pre_message: {:?}", pre_message); - - self.db.lock().await.add_message_to_job_inbox(job_id.as_str(), &shinkai_message)?; + + self.db + .lock() + .await + .add_message_to_job_inbox(job_id.as_str(), &shinkai_message)?; Ok(String::new()) } @@ -311,7 +320,7 @@ impl AgentManager { let agent_name = ShinkaiName::from_shinkai_message_using_recipient_subidentity(&message)?; let agent_id = agent_name.get_agent_name().ok_or(JobManagerError::AgentNotFound)?; - let job_creation: JobCreation = serde_json::from_str(&data.message_raw_content) + let job_creation: JobCreationInfo = serde_json::from_str(&data.message_raw_content) .map_err(|_| JobManagerError::ContentParseFailed)?; self.handle_job_creation_schema(job_creation, &agent_id).await } @@ -324,7 +333,8 @@ impl AgentManager { let pre_message: JobPreMessage = serde_json::from_str(&data.message_raw_content) .map_err(|_| JobManagerError::ContentParseFailed)?; // TODO: we should be able to extract the job_id from the inbox - self.handle_pre_message_schema(pre_message, "".to_string(), message).await + self.handle_pre_message_schema(pre_message, "".to_string(), message) + .await } _ => { // Handle Empty message type if needed, or return an error if it's not a valid job message @@ -436,7 +446,9 @@ impl fmt::Display for JobManagerError { match self { JobManagerError::NotAJobMessage => write!(f, "Message is not a job message"), JobManagerError::JobNotFound => write!(f, "Job not found"), - JobManagerError::JobCreationDeserializationFailed => write!(f, "Failed to deserialize JobCreation message"), + JobManagerError::JobCreationDeserializationFailed => { + write!(f, "Failed to deserialize JobCreationInfo message") + } JobManagerError::JobMessageDeserializationFailed => write!(f, "Failed to deserialize JobMessage"), JobManagerError::JobPreMessageDeserializationFailed => write!(f, "Failed to deserialize JobPreMessage"), JobManagerError::MessageTypeParseFailed => write!(f, "Could not parse message type"), diff --git a/tests/toolkit_tests.rs b/tests/toolkit_tests.rs index 5753d849a..2987a92ae 100644 --- a/tests/toolkit_tests.rs +++ b/tests/toolkit_tests.rs @@ -200,7 +200,8 @@ fn test_tool_router_and_toolkit_flow() { // A fake test which purposefully fails so that we can generate embeddings // for all existing rust tools and print them into console (so we can copy-paste) -// and hard-code them in rust_tools.rs +// and hard-code them in rust_tools.rs. 
+// Temporary solution // #[test] // fn generate_rust_tool_embeddings() { // setup(); From e1bb094645016968b1cee87cd616bfd249f0187e Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Tue, 12 Sep 2023 19:09:39 +0200 Subject: [PATCH 02/26] Fixed context doubling last message --- src/managers/job_manager.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index bc4479b28..ebb93d28c 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -294,6 +294,7 @@ impl AgentManager { } } + /// Adds pre-message to job inbox pub async fn handle_pre_message_schema( &mut self, pre_message: JobPreMessage, @@ -356,8 +357,8 @@ impl AgentManager { // Append current time as ISO8601 to step history let time_with_comment = format!("{}: {}", "Current datetime in RFC3339", Utc::now().to_rfc3339()); - let full_job = { self.db.lock().await.get_job(job.job_id()).unwrap() }; let job_id = job.job_id().to_string(); + let full_job = { self.db.lock().await.get_job(&job_id).unwrap() }; let mut context = full_job.step_history.clone(); context.push(time_with_comment); println!("decision_phase> context: {:?}", context); @@ -376,11 +377,7 @@ impl AgentManager { Some(agent) => { // Create a new async task where the agent's execute method will run // Note: agent execute run in a separate thread - let last_message = full_job - .step_history - .last() - .ok_or(JobManagerError::ContentParseFailed)? - .clone(); + let last_message = context.pop().ok_or(JobManagerError::ContentParseFailed)?.clone(); tokio::spawn(async move { let mut agent = agent.lock().await; agent.execute(last_message.to_string(), context, job_id).await; From fbf056b447082855e755c62561052dae565b04b5 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Tue, 12 Sep 2023 20:04:20 +0200 Subject: [PATCH 03/26] Proper last message impl --- src/managers/job_manager.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index ebb93d28c..f785fb3f0 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -225,6 +225,7 @@ impl AgentManager { } } + /// Processes a job creation message pub async fn handle_job_creation_schema( &mut self, job_creation: JobCreationInfo, @@ -274,6 +275,7 @@ impl AgentManager { } } + /// Processes a job message and starts the decision phase pub async fn handle_job_message_schema( &mut self, message: ShinkaiMessage, @@ -357,12 +359,16 @@ impl AgentManager { // Append current time as ISO8601 to step history let time_with_comment = format!("{}: {}", "Current datetime in RFC3339", Utc::now().to_rfc3339()); + // Prepare context/latest message let job_id = job.job_id().to_string(); let full_job = { self.db.lock().await.get_job(&job_id).unwrap() }; let mut context = full_job.step_history.clone(); + let last_message = context.pop().ok_or(JobManagerError::ContentParseFailed)?.clone(); context.push(time_with_comment); println!("decision_phase> context: {:?}", context); + println!("decision_phase> last message: {:?}", last_message); + // Acquire Agent let agent_id = full_job.parent_agent_id; let mut agent_found = None; for agent in &self.agents { @@ -373,11 +379,11 @@ impl AgentManager { } } + // Execute LLM inferencing let response = match agent_found { Some(agent) => { // Create a new async task where the agent's execute method will run // Note: agent execute run in a 
separate thread - let last_message = context.pop().ok_or(JobManagerError::ContentParseFailed)?.clone(); tokio::spawn(async move { let mut agent = agent.lock().await; agent.execute(last_message.to_string(), context, job_id).await; From 08c79cba2684b5334df8572fdcb4dc990affad1d Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Tue, 12 Sep 2023 22:03:37 +0200 Subject: [PATCH 04/26] Implemented initial prompt appending logic --- src/db/db_jobs.rs | 6 ++++-- src/managers/job_manager.rs | 30 +++++++++++++++++--------- src/utils/job_prompts.rs | 42 +++++++++++++++++++++++++++++++++++++ src/utils/mod.rs | 5 +++-- 4 files changed, 69 insertions(+), 14 deletions(-) create mode 100644 src/utils/job_prompts.rs diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index 5c10e19cc..c6aef1b08 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -280,14 +280,16 @@ impl ShinkaiDB { Ok(()) } - pub fn add_step_history(&self, job_id: String, step: String) -> Result<(), ShinkaiDBError> { + /// Adds a step to a job's step history + pub fn add_step_history(&self, job_id: String, step_output: String) -> Result<(), ShinkaiDBError> { let cf_name = format!("{}_step_history", &job_id); let cf_handle = self .db .cf_handle(&cf_name) .ok_or(ShinkaiDBError::ProfileNameNonExistent(cf_name))?; let current_time = ShinkaiTime::generate_time_now(); - self.db.put_cf(cf_handle, current_time.as_bytes(), step.as_bytes())?; + self.db + .put_cf(cf_handle, current_time.as_bytes(), step_output.as_bytes())?; Ok(()) } diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index f785fb3f0..9bccb8540 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -1,4 +1,5 @@ use crate::db::{db_errors::ShinkaiDBError, ShinkaiDB}; +use crate::utils::job_prompts::JOB_INIT_PROMPT; use chrono::Utc; use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey}; use reqwest::Identity; @@ -352,21 +353,24 @@ impl AgentManager { } } + // When a new message is supplied to the job, the decision phase of the new step begins running + // (with its existing step history as context) which triggers calling the Agent's LLM. async fn decision_phase(&self, job: &dyn JobLike) -> Result<(), Box> { - // When a new message is supplied to the job, the decision phase of the new step begins running - // (with its existing step history as context) which triggers calling the Agent's LLM. 
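// A minimal illustrative sketch of the pattern this decision-phase code settles on,
// assuming `step_history` is a `Vec<String>` as on `Job`; the helper name below is
// hypothetical and not part of the patch itself.
fn split_context_and_last_message(mut step_history: Vec<String>) -> Option<(Vec<String>, String)> {
    // The newest entry becomes the message that triggered this step; everything
    // before it is handed to the agent as prior context, so the latest message
    // is not doubled inside the context window.
    let last_message = step_history.pop()?;
    Some((step_history, last_message))
}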
- - // Append current time as ISO8601 to step history - let time_with_comment = format!("{}: {}", "Current datetime in RFC3339", Utc::now().to_rfc3339()); - - // Prepare context/latest message + // Fetch the job let job_id = job.job_id().to_string(); let full_job = { self.db.lock().await.get_job(&job_id).unwrap() }; + + // Fetch context, if this is the first message of the job (new job just created), prefill step history with the default initial prompt let mut context = full_job.step_history.clone(); + if context.len() == 1 { + context.insert(0, JOB_INIT_PROMPT.clone()); + self.db + .lock() + .await + .add_step_history(job_id.clone(), JOB_INIT_PROMPT.clone())?; + } + let last_message = context.pop().ok_or(JobManagerError::ContentParseFailed)?.clone(); - context.push(time_with_comment); - println!("decision_phase> context: {:?}", context); - println!("decision_phase> last message: {:?}", last_message); // Acquire Agent let agent_id = full_job.parent_agent_id; @@ -379,6 +383,12 @@ impl AgentManager { } } + // Append current time as ISO8601 to step history + let time_with_comment = format!("{}: {}", "Current datetime ", Utc::now().to_rfc3339()); + context.push(time_with_comment); + println!("decision_phase> context: {:?}", context); + println!("decision_phase> last message: {:?}", last_message); + // Execute LLM inferencing let response = match agent_found { Some(agent) => { diff --git a/src/utils/job_prompts.rs b/src/utils/job_prompts.rs new file mode 100644 index 000000000..5b6b9d3ac --- /dev/null +++ b/src/utils/job_prompts.rs @@ -0,0 +1,42 @@ +use lazy_static::lazy_static; + +lazy_static! { + pub static ref JOB_INIT_PROMPT: String = String::from( + r#"You are an agent who is currently running a job which receives task requests as messages, and outputs new messages that have tool calls which will be executed. When you respond, you must specify a list of one or more messages, with only the messages returned, no explanation or any other text. But make sure to explain in the content of the message the included results of the tools, and include the sub variable like `$1`. + +Note, you can send a message back to yourself recursively by inputting the job id in the userid field. Use this when you need to use two tools consecutively. 
+ +Here is a grammar you must respond with and nothing else: + + +root ::= message [] root | message +message ::= '{' '"' 'to' '"' ':' '"' userid '"' ',' '"' 'content' '"' ':' '"' content_var '"' ',' '"' 'tool-calls' '"' ':' '[' tool_call_list ']' '}' +userid ::= '[@][@][a-zA-Z0-9._]+[.shinkai]' +content_var ::= content_text tool_output content_text +content_text ::= string +char ::= string +tool_output ::= '$' tool_id +tool_call_list ::= tool_call ',' tool_call_list | tool_call +tool_call ::= '"' tool_id '"' ':' '{' tool_specific '}' +tool_id ::= number +tool_specific ::= tool_1 | tool_2 | tool_3 +tool_1 ::= '{' '"' 'Weather' '"' ':' '"' city '"' '}' +tool_2 ::= '{' '"' 'DateTime' '"' ':' '"' string '"' '}' +tool_3 ::= '{' '"' 'Vector Search' '"' ':' '"' query '"' '}' +city ::= string +query ::= string +string ::= '[a-zA-Z0-9 ,.?!_]*' +number ::= '[1-9][0-9]*' + +My user id: @@bob.shinkai +Job id: 2as23gas3y68aje + +Task: + +- fetch the time and weather in Vancouver, sending it back to me +- fetch the time and weather in Vancouver, sending it back to my friend @@alice.shinkai +- search my vector database for "My Home Town" to find where my home town was, and fetch the weather for it after + +```json"# + ); +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index e4414f950..d8a430c81 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,7 +1,8 @@ pub mod args; +pub mod cli; pub mod environment; +pub mod job_prompts; pub mod keys; +pub mod logging_helpers; pub mod printer; -pub mod cli; pub mod qr_code_setup; -pub mod logging_helpers; \ No newline at end of file From eb96263332dc840096ce3c2b4727ec1c26e96592 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Wed, 13 Sep 2023 20:55:11 +0200 Subject: [PATCH 05/26] Initial work on tool prompts --- src/utils/job_prompts.rs | 78 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/src/utils/job_prompts.rs b/src/utils/job_prompts.rs index 5b6b9d3ac..41bf3c0a0 100644 --- a/src/utils/job_prompts.rs +++ b/src/utils/job_prompts.rs @@ -1,6 +1,84 @@ use lazy_static::lazy_static; lazy_static! { + static ref task_bootstrap_prompt: String = String::from( + r#" + You are an assistant running in a system who only has access to a series of tools and your own knowledge. The user has asked you: + + `What is the weather like today in New York?` + + + If it is a task not pertaining to recent/current knowledge and you can respond respond directly without any external help, respond using the following EBNF and absolutely nothing else: + + `"{" "answer" ":" string "}"` + + If you do not have the ability to respond correctly yourself, it is your goal is to find the final tool that will provide you with the capabilities you need. + Search to find tools which you can use, respond using the following EBNF and absolutely nothing else: + + "{" ("tool-search" ":" string) "}" + + Only respond with an answer if you are not using any tools. Make sure the response matches the EBNF and includes absolutely nothing else. + + ```json + "# + ); + static ref tool_selection_prompt: String = String::from( + r#" + + You are an assistant running in a system who only has access to a series of tools and your own knowledge. The user has asked the system: + + `What is the weather like today in New York?` + + Here are up to 10 of the most relevant tools available: + 1. Name: Weather Fetch - Description: Requests weather via an API given a city name. + 2. 
Name: Country Population - Description: Provides population numbers given a country name. + 3. Name: HTTP GET - Description: Issues an http get request to a specified URL. Note: Only fetch URLs from user's input or from output of other tools. + + It is your goal to select the final tool that will enable the system to accomplish the user's task. The system may end up needing to chain multiple tools to acquire all needed info/data, but the goal right now is to find the final tool. + Select the name of the tool from the list above that fulfill this, respond using the following EBNF and absolutely nothing else: + + "{" ("tool" ":" string) "}" + + If none of the tools match explain what the issue is by responding using the following EBNF and absolutely nothing else: + + "{" ("error" ":" string) "}" + + + ```json + + + + "# + ); + static ref tool_ebnf_prompt: String = String::from( + r#" + + You are an assistant running in a system who only has access to a series of tools and your own knowledge. The user has asked the system: + + `What is the weather like today in New York?` + + The system has selected the following tool to be used: + + Name: Weather Fetch + Description: Requests weather via an API given a city name. + Tool Input EBNF: ... + Tool Output EBNF: ... + + Your goal is to decide whether you have all of the information you need to fill out the Tool Input EBNF. + + If all of the data/information to use the tool is available, respond using the following EBNF and absolutely nothing else: + + "{" ("prepared" ":" true) "}" + + If you need to acquire more information in order to use this tool (ex. user's personal data, related facts, info from external APIs, etc.) then you will need to search for other tools that provide you with this data by responding using the following EBNF and absolutely nothing else: + + "{" ("tool-search" ":" string) "}" + + ```json + + + "# + ); pub static ref JOB_INIT_PROMPT: String = String::from( r#"You are an agent who is currently running a job which receives task requests as messages, and outputs new messages that have tool calls which will be executed. When you respond, you must specify a list of one or more messages, with only the messages returned, no explanation or any other text. But make sure to explain in the content of the message the included results of the tools, and include the sub variable like `$1`. From b56d570e73ce527e1208842e0b02ce1b12d5ceca Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Wed, 13 Sep 2023 22:51:57 +0200 Subject: [PATCH 06/26] Specced out job step plan/execution flow --- src/utils/job_prompts.rs | 145 ++++++++++++++++++++++++++------------- 1 file changed, 96 insertions(+), 49 deletions(-) diff --git a/src/utils/job_prompts.rs b/src/utils/job_prompts.rs index 41bf3c0a0..a41769df0 100644 --- a/src/utils/job_prompts.rs +++ b/src/utils/job_prompts.rs @@ -1,16 +1,89 @@ use lazy_static::lazy_static; +// +// Core Job Step Flow +// +// Note this will all happen within a single Job step. We will probably end up summarizing the context/results from previous steps into the step history to be included as the base initial context for new steps. +// +// 0. User submits an initial message/request to their AI Agent. +// 1. An initial bootstrap plan is created based on the initial request from the user. +// +// 2. We enter into "analysis phase". +// 3a. 
Iterating starting from the first point in the plan, we ask the LLM true/false if it can provide an answer given it's personal knowledge + current context. +// 3b. If it can then we mark this analysis step as "prepared" and go back to 3a for the next bootstrap plan task. +// 3c. If not we tell the LLM to search for tools that would work for this task. +// 4a. We return a list of tools to it, and ask it to either select one, or return an error message. +// 4b. If it returns an error message, it means the plan can not be completed/Agent has failed, and we exit/ send message back to user with the error message. +// 4c. If it chooses one, we fetch the tool info including the input EBNF. +// 5a. We now show the input EBNF to the LLM, and ask it whether or not it has all the needed knowledge + potential data in the current context to be able to use the tool. (In either case after the LLM chooses) +// 5b. The LLM says it has all the needed info, then we add the tool's name/input EBNF to the current context, and either go back to 3a for the next bootstrap plan task if the task is now finished/prepared, or go to 6 if this tool was searched for to find an input for another tool. +// 5c. The LLM doesn't have all the info it needs, so it performs another tool search and we go back to 4a. +// 6. After resolving 4-5 for the new tool search, the new tool's input EBNF has been added into the context window, which will allow us to go back to 5a for the original tool, which enables the LLM to now state it has all the info it needs (marking the analysis step as prepared), thus going back to 3a for the next top level task. +// 7. After iterating through all the bootstrap plan tasks and analyzing them, we have created an "execution plan" that specifies all tool calls which will need to be made. +// +// 8. We now move to the "execution phase". +// 9. Using the execution plan, we move forward alternating between inferencing the LLM and executing a tool, as dictated by the plan. +// 10. To start we inference the LLM with the user's initial request text + the input EBNF of the first tool, and tell the LLM to fill out the input EBNF with real data. +// 11. The input JSON is taken and the tool is called/executed, with the results being added into the context. +// 12. With the tool executed, we now inference the LLM with just the context + the input EBNF of the next tool that it needs to fill out (we can skip user's request text). +// 13. We iterate through the entire execution plan (looping back/forth between 11/12) and arrive at the end with a context filled with all relevant data needed to answer the user's initial request. +// 14. We inference the LLM one last time, providing it just the context + list of executed tools, and telling it to respond to the user's request by using/summarizing the results. +// +// +// +// + lazy_static! { - static ref task_bootstrap_prompt: String = String::from( + static ref top_level_plan_bootstrap_prompt: String = String::from( r#" - You are an assistant running in a system who only has access to a series of tools and your own knowledge. The user has asked you: + + You are an assistant running in a system who only has access to a series of tools and your own knowledge to accomplish any task. + + The user has asked the system: `What is the weather like today in New York?` +Create a plan that the system will need to take in order to fulfill the user's task. 
Make sure to make separate steps for any sub-task where data, computation, or API access may need to happen from different sources. + +Keep each step in the plan extremely concise/high level comprising of a single sentence each. Do not mention anything optional, nothing about error checking or logging or displaying data. Anything related to parsing/formatting can be merged together into a single step. Any calls to APIs, including parsing the resulting data from the API, should be considered as a single step. + +Respond using the following EBNF and absolutely nothing else: +"{" "plan" ":" "[" string ("," string)* "]" "}" + +"# + ); + + +// Output: +// { +// "plan": [ +// "Retrieve the current date and time for New York.", +// "Query a weather API for New York's current weather using the obtained date and time.", +// "Parse the weather data to extract the current weather conditions." +// ] +// } + + + +// Example ebnf of weather fetch output for testing +// weather-fetch-output ::= "{" "city" ":" text "," "weather-description" ":" text "," "tool" ": \"Weather Fetch\", "}" text ::= [a-zA-Z0-9_]+ + + static ref task_bootstrap_prompt: String = String::from( + r#" + You are an assistant running in a system who only has access to a series of tools, your own knowledge, and the current context of acquired info includes: + + ``` + Datetime ::= 4DIGIT "-" 2DIGIT "-" 2DIGIT "T" 2DIGIT ":" 2DIGIT ":" 2DIGIT + ``` + + The current task at hand is: + + `Query a weather API for New York's current weather using the obtained date and time.` + If it is a task not pertaining to recent/current knowledge and you can respond respond directly without any external help, respond using the following EBNF and absolutely nothing else: - `"{" "answer" ":" string "}"` + `"{" "prepared" ":" true "}"` If you do not have the ability to respond correctly yourself, it is your goal is to find the final tool that will provide you with the capabilities you need. Search to find tools which you can use, respond using the following EBNF and absolutely nothing else: @@ -25,9 +98,15 @@ lazy_static! { static ref tool_selection_prompt: String = String::from( r#" - You are an assistant running in a system who only has access to a series of tools and your own knowledge. The user has asked the system: + You are an assistant running in a system who only has access to a series of tools, your own knowledge, and the current context of acquired info includes: - `What is the weather like today in New York?` + ``` + Datetime: 2023-09-13T14:30:00 + ``` + + The current task at hand is: + + `Query a weather API for New York's current weather using the obtained date and time.` Here are up to 10 of the most relevant tools available: 1. Name: Weather Fetch - Description: Requests weather via an API given a city name. @@ -53,18 +132,24 @@ lazy_static! { static ref tool_ebnf_prompt: String = String::from( r#" - You are an assistant running in a system who only has access to a series of tools and your own knowledge. 
The user has asked the system: + You are an assistant running in a system who only has access to a series of tools, your own knowledge, and the current context of acquired info includes: - `What is the weather like today in New York?` + ``` + Datetime: 2023-09-13T14:30:00 + ``` + + The current task at hand is: + + `Query a weather API for New York's current weather using the obtained date and time.` The system has selected the following tool to be used: - Name: Weather Fetch + Tool Name: Weather Fetch + Toolkit Name: weather-toolkit Description: Requests weather via an API given a city name. - Tool Input EBNF: ... - Tool Output EBNF: ... + Tool Input EBNF: "{" "city" ":" text "," "datetime" ":" text "," "tool" ": \"Weather Fetch\"," "toolkit" ": \"weather-toolkit\" }" text ::= [a-zA-Z0-9_]+ - Your goal is to decide whether you have all of the information you need to fill out the Tool Input EBNF. + Your goal is to decide whether for each field in the Tool Input EBNF, you have been provided all the needed data to fill it out fully. If all of the data/information to use the tool is available, respond using the following EBNF and absolutely nothing else: @@ -79,42 +164,4 @@ lazy_static! { "# ); - pub static ref JOB_INIT_PROMPT: String = String::from( - r#"You are an agent who is currently running a job which receives task requests as messages, and outputs new messages that have tool calls which will be executed. When you respond, you must specify a list of one or more messages, with only the messages returned, no explanation or any other text. But make sure to explain in the content of the message the included results of the tools, and include the sub variable like `$1`. - -Note, you can send a message back to yourself recursively by inputting the job id in the userid field. Use this when you need to use two tools consecutively. 
- -Here is a grammar you must respond with and nothing else: - - -root ::= message [] root | message -message ::= '{' '"' 'to' '"' ':' '"' userid '"' ',' '"' 'content' '"' ':' '"' content_var '"' ',' '"' 'tool-calls' '"' ':' '[' tool_call_list ']' '}' -userid ::= '[@][@][a-zA-Z0-9._]+[.shinkai]' -content_var ::= content_text tool_output content_text -content_text ::= string -char ::= string -tool_output ::= '$' tool_id -tool_call_list ::= tool_call ',' tool_call_list | tool_call -tool_call ::= '"' tool_id '"' ':' '{' tool_specific '}' -tool_id ::= number -tool_specific ::= tool_1 | tool_2 | tool_3 -tool_1 ::= '{' '"' 'Weather' '"' ':' '"' city '"' '}' -tool_2 ::= '{' '"' 'DateTime' '"' ':' '"' string '"' '}' -tool_3 ::= '{' '"' 'Vector Search' '"' ':' '"' query '"' '}' -city ::= string -query ::= string -string ::= '[a-zA-Z0-9 ,.?!_]*' -number ::= '[1-9][0-9]*' - -My user id: @@bob.shinkai -Job id: 2as23gas3y68aje - -Task: - -- fetch the time and weather in Vancouver, sending it back to me -- fetch the time and weather in Vancouver, sending it back to my friend @@alice.shinkai -- search my vector database for "My Home Town" to find where my home town was, and fetch the weather for it after - -```json"# - ); } From ba3c96e39b0bfdaa2b1f7c1611d7017c3f0c6f1a Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Wed, 13 Sep 2023 23:27:01 +0200 Subject: [PATCH 07/26] Added small fix --- src/utils/job_prompts.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils/job_prompts.rs b/src/utils/job_prompts.rs index a41769df0..cb8893a52 100644 --- a/src/utils/job_prompts.rs +++ b/src/utils/job_prompts.rs @@ -13,7 +13,7 @@ use lazy_static::lazy_static; // 3b. If it can then we mark this analysis step as "prepared" and go back to 3a for the next bootstrap plan task. // 3c. If not we tell the LLM to search for tools that would work for this task. // 4a. We return a list of tools to it, and ask it to either select one, or return an error message. -// 4b. If it returns an error message, it means the plan can not be completed/Agent has failed, and we exit/ send message back to user with the error message. +// 4b. If it returns an error message, it means the plan can not be completed/Agent has failed, and we exit/send message back to user with the error message (15). // 4c. If it chooses one, we fetch the tool info including the input EBNF. // 5a. We now show the input EBNF to the LLM, and ask it whether or not it has all the needed knowledge + potential data in the current context to be able to use the tool. (In either case after the LLM chooses) // 5b. The LLM says it has all the needed info, then we add the tool's name/input EBNF to the current context, and either go back to 3a for the next bootstrap plan task if the task is now finished/prepared, or go to 6 if this tool was searched for to find an input for another tool. @@ -28,6 +28,7 @@ use lazy_static::lazy_static; // 12. With the tool executed, we now inference the LLM with just the context + the input EBNF of the next tool that it needs to fill out (we can skip user's request text). // 13. We iterate through the entire execution plan (looping back/forth between 11/12) and arrive at the end with a context filled with all relevant data needed to answer the user's initial request. // 14. We inference the LLM one last time, providing it just the context + list of executed tools, and telling it to respond to the user's request by using/summarizing the results. +// 15. 
We add a Shinkai message into the job's inbox with the LLM's response, allowing the user to see the result. // // // From 272b67be78836601f2d437b33ed635936b6a7381 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 14:18:59 +0200 Subject: [PATCH 08/26] Improved prompt gen setup --- src/tools/argument.rs | 5 +++-- src/utils/job_prompts.rs | 28 ++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/tools/argument.rs b/src/tools/argument.rs index b94359a4e..9b3d436c0 100644 --- a/src/tools/argument.rs +++ b/src/tools/argument.rs @@ -52,7 +52,8 @@ impl ToolArgument { let name = &input_arg.name; let ebnf = input_arg.labled_ebnf(); - ebnf_result.push_str(&format!(r#""{}": {}, "#, name, name)); + // ebnf_result.push_str(&format!(r#""{}": {}, "#, name, name)); + ebnf_result.push_str(&format!(r#""tool": {}, "#, name)); // Add descriptions to argument definitions if set to true if add_arg_descriptions { @@ -66,7 +67,7 @@ impl ToolArgument { } // Add the toolkit name to the required inputs for the tool - ebnf_result.push_str(&format!(r#""{}": {}, "#, "toolkit_name", toolkit_name)); + ebnf_result.push_str(&format!(r#""{}": {}, "#, "toolkit", toolkit_name)); ebnf_result.push_str("}\n"); ebnf_result.push_str(&ebnf_arg_definitions); diff --git a/src/utils/job_prompts.rs b/src/utils/job_prompts.rs index cb8893a52..3ab6d9b83 100644 --- a/src/utils/job_prompts.rs +++ b/src/utils/job_prompts.rs @@ -35,14 +35,14 @@ use lazy_static::lazy_static; // lazy_static! { - static ref top_level_plan_bootstrap_prompt: String = String::from( + static ref bootstrap_plan__prompt: String = String::from( r#" You are an assistant running in a system who only has access to a series of tools and your own knowledge to accomplish any task. The user has asked the system: - `What is the weather like today in New York?` + `{}` Create a plan that the system will need to take in order to fulfill the user's task. Make sure to make separate steps for any sub-task where data, computation, or API access may need to happen from different sources. @@ -166,3 +166,27 @@ Respond using the following EBNF and absolutely nothing else: "# ); } + +pub struct PromptGenerator {} + +impl PromptGenerator { + pub fn bootstrap_plan_prompt(job_task: String) -> String { + format!( + r#" + You are an assistant running in a system who only has access to a series of tools and your own knowledge to accomplish any task. + + The user has asked the system: + + `{}` + + Create a plan that the system will need to take in order to fulfill the user's task. Make sure to make separate steps for any sub-task where data, computation, or API access may need to happen from different sources. + + Keep each step in the plan extremely concise/high level comprising of a single sentence each. Do not mention anything optional, nothing about error checking or logging or displaying data. Anything related to parsing/formatting can be merged together into a single step. Any calls to APIs, including parsing the resulting data from the API, should be considered as a single step. 
+ + Respond using the following EBNF and absolutely nothing else: + "{{" "plan" ":" "[" string ("," string)* "]" "}}" + "#, + job_task + ) + } +} From 7ac48c0122f1672c0f67ab531db78f461527aec6 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:24:07 +0200 Subject: [PATCH 09/26] Laid out job message processing logic --- src/db/db_jobs.rs | 80 ++++++++++++++++++++++++++++++------- src/managers/job_manager.rs | 79 +++++++++++++++++++++++------------- 2 files changed, 116 insertions(+), 43 deletions(-) diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index c6aef1b08..b242c7b7b 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -123,8 +123,15 @@ impl ShinkaiDB { } pub fn get_job(&self, job_id: &str) -> Result { - let (scope, is_finished, datetime_created, parent_agent_id, conversation_inbox, step_history) = - self.get_job_data(job_id, true)?; + let ( + scope, + is_finished, + datetime_created, + parent_agent_id, + conversation_inbox, + step_history, + unprocessed_messges, + ) = self.get_job_data(job_id, true)?; // Construct the job let job = Job { @@ -141,7 +148,7 @@ impl ShinkaiDB { } pub fn get_job_like(&self, job_id: &str) -> Result, ShinkaiDBError> { - let (scope, is_finished, datetime_created, parent_agent_id, conversation_inbox, _) = + let (scope, is_finished, datetime_created, parent_agent_id, conversation_inbox, _, unprocessed_messages) = self.get_job_data(job_id, false)?; // Construct the job @@ -158,15 +165,30 @@ impl ShinkaiDB { Ok(Box::new(job)) } + /// Fetches data for a specific Job from the DB fn get_job_data( &self, job_id: &str, fetch_step_history: bool, - ) -> Result<(JobScope, bool, String, String, InboxName, Option>), ShinkaiDBError> { + ) -> Result< + ( + JobScope, + bool, + String, + String, + InboxName, + Option>, + Vec, + ), + ShinkaiDBError, + > { + // Define cf names for all data we need to fetch let cf_job_id_name = format!("jobtopic_{}", job_id); let cf_job_id_scope_name = format!("{}_scope", job_id); let cf_job_id_step_history_name = format!("{}_step_history", job_id); + let cf_job_id_unprocessed_messages_name = format!("{}_unprocessed_messages", &job_id); + // Get the needed cf handles let cf_job_id_scope = self .db .cf_handle(&cf_job_id_scope_name) @@ -175,7 +197,18 @@ impl ShinkaiDB { .db .cf_handle(&cf_job_id_name) .ok_or(ShinkaiDBError::ColumnFamilyNotFound(cf_job_id_name))?; - + let cf_job_id_step_history = self + .db + .cf_handle(&cf_job_id_step_history_name) + .ok_or(ShinkaiDBError::ColumnFamilyNotFound(cf_job_id_step_history_name))?; + let cf_job_id_unprocessed_messages = + self.db + .cf_handle(&cf_job_id_unprocessed_messages_name) + .ok_or(ShinkaiDBError::ColumnFamilyNotFound( + cf_job_id_unprocessed_messages_name, + ))?; + + // Begin fetching the data from the DB let scope_value = self .db .get_cf(cf_job_id_scope, job_id)? @@ -188,11 +221,6 @@ impl ShinkaiDB { .ok_or(ShinkaiDBError::DataNotFound)?; let is_finished = std::str::from_utf8(&is_finished_value)?.to_string() == "true"; - let cf_job_id_step_history = self - .db - .cf_handle(&cf_job_id_step_history_name) - .ok_or(ShinkaiDBError::ColumnFamilyNotFound(cf_job_id_step_history_name))?; - let datetime_created_value = self .db .get_cf(cf_job_id, JobInfo::DatetimeCreated.to_str().as_bytes())? 
@@ -207,7 +235,6 @@ impl ShinkaiDB { let mut conversation_inbox: Option = None; let mut step_history: Option> = if fetch_step_history { Some(Vec::new()) } else { None }; - let conversation_inbox_value = self .db .get_cf(cf_job_id, JobInfo::ConversationInboxName.to_str().as_bytes())? @@ -215,6 +242,8 @@ impl ShinkaiDB { let inbox_name = std::str::from_utf8(&conversation_inbox_value)?.to_string(); conversation_inbox = Some(InboxName::new(inbox_name)?); + // Reads all of the step history by iterating + let mut step_history: Option> = if fetch_step_history { Some(Vec::new()) } else { None }; if let Some(ref mut step_history) = step_history { let iter = self.db.iterator_cf(cf_job_id_step_history, IteratorMode::Start); for item in iter { @@ -224,6 +253,15 @@ impl ShinkaiDB { } } + // Reads all of the unprocessed messages by iterating + let mut unprocessed_messages: Vec = Vec::new(); + let iter = self.db.iterator_cf(cf_job_id_unprocessed_messages, IteratorMode::Start); + for item in iter { + let (_key, value) = item.map_err(|e| ShinkaiDBError::RocksDBError(e))?; + let message = std::str::from_utf8(&value)?.to_string(); + unprocessed_messages.push(message); + } + Ok(( scope, is_finished, @@ -231,6 +269,7 @@ impl ShinkaiDB { parent_agent_id, conversation_inbox.unwrap(), step_history, + unprocessed_messages, )) } @@ -280,16 +319,27 @@ impl ShinkaiDB { Ok(()) } - /// Adds a step to a job's step history - pub fn add_step_history(&self, job_id: String, step_output: String) -> Result<(), ShinkaiDBError> { + /// Adds a message to a job's unprocessed messages list + pub fn add_to_unprocessed_messages_list(&self, job_id: String, message: String) -> Result<(), ShinkaiDBError> { + let cf_name = format!("{}_unprocessed_messages", &job_id); + let cf_handle = self + .db + .cf_handle(&cf_name) + .ok_or(ShinkaiDBError::ProfileNameNonExistent(cf_name))?; + let current_time = ShinkaiTime::generate_time_now(); + self.db.put_cf(cf_handle, current_time.as_bytes(), message.as_bytes())?; + Ok(()) + } + + /// Adds a String to a job's step history + pub fn add_step_history(&self, job_id: String, content: String) -> Result<(), ShinkaiDBError> { let cf_name = format!("{}_step_history", &job_id); let cf_handle = self .db .cf_handle(&cf_name) .ok_or(ShinkaiDBError::ProfileNameNonExistent(cf_name))?; let current_time = ShinkaiTime::generate_time_now(); - self.db - .put_cf(cf_handle, current_time.as_bytes(), step_output.as_bytes())?; + self.db.put_cf(cf_handle, current_time.as_bytes(), content.as_bytes())?; Ok(()) } diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index 9bccb8540..3cbcf935e 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -1,5 +1,5 @@ use crate::db::{db_errors::ShinkaiDBError, ShinkaiDB}; -use crate::utils::job_prompts::JOB_INIT_PROMPT; +use crate::utils::job_prompts::PromptGenerator; use chrono::Utc; use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey}; use reqwest::Identity; @@ -287,10 +287,28 @@ impl AgentManager { let mut shinkai_db = self.db.lock().await; println!("handle_job_message_schema> job_message: {:?}", job_message); shinkai_db.add_message_to_job_inbox(&job_message.job_id.clone(), &message)?; - shinkai_db.add_step_history(job.job_id().to_string(), job_message.content.clone())?; + // shinkai_db.add_step_history(job.job_id().to_string(), job_message.content.clone())?; + + // + // Todo: Implement unprocessed messages logic + // If current unprocessed message count >= 1, then simply add unprocessed message and return 
success. + // However if unprocessed message count == 0, then: + // 0. You add the unprocessed message to the list in the DB + // 1. Start a while loop where every time you fetch the unprocessed messages for the job from the DB and check if there's >= 1 + // 2. You read the first/front unprocessed message (not pop from the back) + // 3. You start analysis phase to generate the analysis plan. + // 4. You then take the analysis plan and process the execution phase. + // 5. Once execution phase succeeds, you then delete the message from the unprocessed list in the DB + // and take the result and append it both to the Job inbox and step history + // 6. As we're in a while loop, go back to 1, meaning any new unprocessed messages added while the step was happening are now processed sequentially + + // + // let current_unprocessed_message_count = ... + shinkai_db.add_to_unprocessed_messages_list(job.job_id().to_string(), job_message.content.clone())?; + std::mem::drop(shinkai_db); // require to avoid deadlock - let _ = self.decision_phase(&**job).await?; + // let _ = self.decision_phase(&**job).await?; return Ok(job_message.job_id.clone()); } else { return Err(JobManagerError::JobNotFound); @@ -362,18 +380,18 @@ impl AgentManager { // Fetch context, if this is the first message of the job (new job just created), prefill step history with the default initial prompt let mut context = full_job.step_history.clone(); - if context.len() == 1 { - context.insert(0, JOB_INIT_PROMPT.clone()); - self.db - .lock() - .await - .add_step_history(job_id.clone(), JOB_INIT_PROMPT.clone())?; - } + // if context.len() == 1 { + // context.insert(0, JOB_INIT_PROMPT.clone()); + // self.db + // .lock() + // .await + // .add_step_history(job_id.clone(), JOB_INIT_PROMPT.clone())?; + // } let last_message = context.pop().ok_or(JobManagerError::ContentParseFailed)?.clone(); // Acquire Agent - let agent_id = full_job.parent_agent_id; + let agent_id = full_job.parent_agent_id.clone(); let mut agent_found = None; for agent in &self.agents { let locked_agent = agent.lock().await; @@ -383,32 +401,37 @@ impl AgentManager { } } + match agent_found { + Some(agent) => self.decision_iteration(full_job, context, last_message, agent).await, + None => Err(Box::new(JobManagerError::AgentNotFound)), + } + } + + async fn decision_iteration( + &self, + job: Job, + mut context: Vec, + last_message: String, + agent: Arc>, + ) -> Result<(), Box> { // Append current time as ISO8601 to step history let time_with_comment = format!("{}: {}", "Current datetime ", Utc::now().to_rfc3339()); context.push(time_with_comment); - println!("decision_phase> context: {:?}", context); - println!("decision_phase> last message: {:?}", last_message); + println!("decision_iteration> context: {:?}", context); + println!("decision_iteration> last message: {:?}", last_message); // Execute LLM inferencing - let response = match agent_found { - Some(agent) => { - // Create a new async task where the agent's execute method will run - // Note: agent execute run in a separate thread - tokio::spawn(async move { - let mut agent = agent.lock().await; - agent.execute(last_message.to_string(), context, job_id).await; - }) - .await?; - Ok(()) - } - None => Err(Box::new(JobManagerError::AgentNotFound)), - }; - println!("decision_phase> response: {:?}", response); + let response = tokio::spawn(async move { + let mut agent = agent.lock().await; + agent.execute(last_message, context, job.job_id().to_string()).await; + }) + .await?; + println!("decision_iteration> response: {:?}", 
response); // TODO: update this fn so it allows for recursion // let is_valid = self.is_decision_phase_output_valid().await; // if is_valid == false { - // self.decision_phase(job).await?; + // self.decision_iteration(job, context, last_message, agent).await?; // } // The expected output from the LLM is one or more `Premessage`s (a message that potentially From 64c8d9493276e357d01a0cbddb0c74811e538322 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:50:13 +0200 Subject: [PATCH 10/26] Added unprocessed messages to Job --- src/db/db_errors.rs | 2 ++ src/db/db_jobs.rs | 17 ++++++++++++----- src/managers/job_manager.rs | 18 +++++++++--------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/db/db_errors.rs b/src/db/db_errors.rs index 6558aeb86..0ce1ea09d 100644 --- a/src/db/db_errors.rs +++ b/src/db/db_errors.rs @@ -48,6 +48,7 @@ pub enum ShinkaiDBError { ToolError(ToolError), MessageEncodingError(String), ShinkaiMessageError(String), + JobAlreadyExists(String), } impl fmt::Display for ShinkaiDBError { @@ -104,6 +105,7 @@ impl fmt::Display for ShinkaiDBError { ShinkaiDBError::DeviceNameNonExistent(e) => write!(f, "Device name does not exist: {}", e), ShinkaiDBError::MessageEncodingError(e) => write!(f, "Message encoding error: {}", e), ShinkaiDBError::ShinkaiMessageError(e) => write!(f, "ShinkaiMessage error: {}", e), + ShinkaiDBError::JobAlreadyExists(e) => write!(f, "Job attempted to be created, but already exists: {}", e), } } } diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index b242c7b7b..da606f95f 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -51,33 +51,35 @@ impl ShinkaiDB { let cf_conversation_inbox_name = format!("job_inbox::{}::false", &job_id); let cf_job_id_perms_name = format!("job_inbox::{}::false_perms", &job_id); let cf_job_id_unread_list_name = format!("job_inbox::{}::false_unread_list", &job_id); + let cf_job_id_unprocessed_messages_name = format!("{}_unprocessed_messages", &job_id); - // Check that the profile name exists in ProfilesIdentityKey, ProfilesEncryptionKey and ProfilesIdentityType + // Check that the cf handles exist, and create them if self.db.cf_handle(&cf_job_id_scope_name).is_some() || self.db.cf_handle(&cf_job_id_step_history_name).is_some() || self.db.cf_handle(&cf_job_id_name).is_some() || self.db.cf_handle(&cf_conversation_inbox_name).is_some() || self.db.cf_handle(&cf_job_id_perms_name).is_some() || self.db.cf_handle(&cf_job_id_unread_list_name).is_some() + || self.db.cf_handle(&cf_job_id_unprocessed_messages_name).is_some() { - return Err(ShinkaiDBError::ProfileNameAlreadyExists); + return Err(ShinkaiDBError::JobAlreadyExists(cf_job_id_name.to_string())); } if self.db.cf_handle(&cf_agent_id_name).is_none() { self.db.create_cf(&cf_agent_id_name, &cf_opts)?; } - self.db.create_cf(&cf_job_id_name, &cf_opts)?; self.db.create_cf(&cf_job_id_scope_name, &cf_opts)?; self.db.create_cf(&cf_job_id_step_history_name, &cf_opts)?; self.db.create_cf(&cf_conversation_inbox_name, &cf_opts)?; self.db.create_cf(&cf_job_id_perms_name, &cf_opts)?; self.db.create_cf(&cf_job_id_unread_list_name, &cf_opts)?; + self.db.create_cf(&cf_job_id_unprocessed_messages_name, &cf_opts)?; // Start a write batch let mut batch = WriteBatch::default(); - // Generate time now used as a key. it should be safe because it's generated here so it shouldn't be duplicated (presumably) + // Generate time currently, used as a key. 
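The numbered comments above sketch how unprocessed messages are meant to be drained sequentially. A minimal illustration of that plan (not the shipped implementation) follows, written as if inside the manager's impl block: `analysis_phase` and `execution_phase` are hypothetical stand-ins for the real phases, and it assumes the unprocessed-message DB accessors introduced in this patch series (`add_to_unprocessed_messages_list`, `get_unprocessed_messages`, `remove_oldest_unprocessed_message`, `add_step_history`) are reachable from the manager.

```rust
// Illustrative sketch of the commented plan (steps 0-6), inside the manager impl.
async fn process_unprocessed_messages(
    &self,
    job_id: &str,
    incoming: &str,
) -> Result<(), JobManagerError> {
    // 0. Persist the new message first so nothing is lost if a step fails.
    {
        let mut db = self.db.lock().await;
        db.add_to_unprocessed_messages_list(job_id.to_string(), incoming.to_string())?;
    } // drop the lock so it is not held across a whole step

    // 1. Keep going while at least one message is pending.
    loop {
        let pending = {
            let db = self.db.lock().await;
            db.get_unprocessed_messages(job_id)?
        };
        // 2. Read the oldest (front) message without removing it yet.
        let front = match pending.first() {
            Some(m) => m.clone(),
            None => break,
        };

        // 3./4. Hypothetical helpers: analysis produces a plan, execution runs it.
        let plan = self.analysis_phase(job_id, &front).await?;
        let result = self.execution_phase(job_id, plan).await?;

        // 5. Only after success: remove the message and record the result
        //    (appending to the job inbox would happen here as well).
        let mut db = self.db.lock().await;
        db.remove_oldest_unprocessed_message(job_id)?;
        db.add_step_history(job_id.to_string(), result)?;
        // 6. Loop again; messages queued while the step ran are handled next.
    }
    Ok(())
}
```

Deleting the message only after the execution phase succeeds is what keeps the queue robust: at worst a message is re-processed after a crash, never silently lost.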
It should be safe because it's generated here so it shouldn't be duplicated (presumably) let current_time = ShinkaiTime::generate_time_now(); let scope_bytes = scope.to_bytes()?; @@ -122,6 +124,7 @@ impl ShinkaiDB { Ok(()) } + /// Fetches a job from the DB pub fn get_job(&self, job_id: &str) -> Result { let ( scope, @@ -130,7 +133,7 @@ impl ShinkaiDB { parent_agent_id, conversation_inbox, step_history, - unprocessed_messges, + unprocessed_messages, ) = self.get_job_data(job_id, true)?; // Construct the job @@ -142,11 +145,13 @@ impl ShinkaiDB { scope, conversation_inbox_name: conversation_inbox, step_history: step_history.unwrap_or_else(Vec::new), + unprocessed_messages, }; Ok(job) } + /// Fetches a job from the DB as a Box pub fn get_job_like(&self, job_id: &str) -> Result, ShinkaiDBError> { let (scope, is_finished, datetime_created, parent_agent_id, conversation_inbox, _, unprocessed_messages) = self.get_job_data(job_id, false)?; @@ -160,6 +165,7 @@ impl ShinkaiDB { scope, conversation_inbox_name: conversation_inbox, step_history: Vec::new(), // Empty step history for JobLike + unprocessed_messages, }; Ok(Box::new(job)) @@ -273,6 +279,7 @@ impl ShinkaiDB { )) } + /// Fetches all jobs pub fn get_all_jobs(&self) -> Result>, ShinkaiDBError> { let cf_handle = self .db diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index 3cbcf935e..f22bb8581 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -36,23 +36,23 @@ pub trait JobLike: Send + Sync { #[derive(Clone, Debug)] pub struct Job { - // based on uuid + // Based on uuid pub job_id: String, // Format: "20230702T20533481346" or Utc::now().format("%Y%m%dT%H%M%S%f").to_string(); pub datetime_created: String, - // determines if the job is finished or not + // Marks if the job is finished or not pub is_finished: bool, - // identity of the parent agent. We just use a full identity name for simplicity + // Identity of the parent agent. 
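The split between `get_job` (full `Job`, step history included) and `get_job_like` (a `Box<dyn JobLike>` whose step history is left empty) lets callers that only need metadata skip loading history. A small sketch of code written purely against the `JobLike` accessors; the helper name is made up, and the import path depends on where in the series you are (`managers::job_manager` at this point, `agent::job` after the later rearrange):

```rust
use shinkai_node::agent::job::JobLike; // earlier patches: shinkai_node::managers::job_manager::JobLike

// Works for a full `Job` as well as the lightweight boxed view returned by
// `get_job_like`, because it only relies on the trait's accessor methods.
fn print_job_summary(job: &dyn JobLike) {
    println!(
        "job {} (agent {}, finished: {}) created {}",
        job.job_id(),
        job.parent_agent_id(),
        job.is_finished(),
        job.datetime_created(),
    );
}

// Example call site, assuming a ShinkaiDB handle is available:
// let job_like = db.get_job_like(&job_id)?;
// print_job_summary(&*job_like);
```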
We just use a full identity name for simplicity pub parent_agent_id: String, - // what storage buckets and/or documents are accessible to the LLM via vector search - // and/or direct querying based off bucket name/key + // What VectorResources the Job has access to when performing vector searches pub scope: JobScope, - // an inbox where messages to the agent from the user and messages from the agent are stored, + // An inbox where messages to the agent from the user and messages from the agent are stored, // enabling each job to have a classical chat/conversation UI pub conversation_inbox_name: InboxName, - // A step history (an ordered list of all messages submitted to the LLM which triggered a step to execute, - // including everything in the conversation inbox + any messages from the agent recursively calling itself or otherwise) + // The job's step history (an ordered list of all prompts/outputs from LLM inferencing when processing steps) pub step_history: Vec, + // An ordered list of the latest messages sent to the job which are yet to be processed + pub unprocessed_messages: Vec, } impl JobLike for Job { @@ -287,7 +287,7 @@ impl AgentManager { let mut shinkai_db = self.db.lock().await; println!("handle_job_message_schema> job_message: {:?}", job_message); shinkai_db.add_message_to_job_inbox(&job_message.job_id.clone(), &message)?; - // shinkai_db.add_step_history(job.job_id().to_string(), job_message.content.clone())?; + shinkai_db.add_step_history(job.job_id().to_string(), job_message.content.clone())?; // // Todo: Implement unprocessed messages logic From 929d071be100df9ed0f0b493cc99550bb0600e25 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 17:07:13 +0200 Subject: [PATCH 11/26] Commented out soon to be deprecated parts of tests --- tests/agent_integration_tests.rs | 57 ++++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/tests/agent_integration_tests.rs b/tests/agent_integration_tests.rs index 5b0e0c38e..7e3884cd4 100644 --- a/tests/agent_integration_tests.rs +++ b/tests/agent_integration_tests.rs @@ -226,15 +226,19 @@ fn node_agent_registration() { .await .unwrap(); let node2_last_messages = res2_receiver.recv().await.unwrap().expect("Failed to receive messages"); - println!("### node2_last_messages: {:?}", node2_last_messages); - let shinkai_message_content_agent = node2_last_messages[1].get_message_content().unwrap(); - let message_content_agent: JobMessage = serde_json::from_str(&shinkai_message_content_agent).unwrap(); - assert_eq!( - message_content_agent.content, - "\n\nHello there, how may I assist you today?".to_string() - ); - assert!(node2_last_messages.len() == 2); + // + // TODO: Write new tests for checking the job results + // + // println!("### node2_last_messages: {:?}", node2_last_messages); + // let shinkai_message_content_agent = node2_last_messages[1].get_message_content().unwrap(); + // let message_content_agent: JobMessage = serde_json::from_str(&shinkai_message_content_agent).unwrap(); + + // assert_eq!( + // message_content_agent.content, + // "\n\nHello there, how may I assist you today?".to_string() + // ); + // assert!(node2_last_messages.len() == 2); } { // Check Profile inboxes (to confirm job's there) @@ -302,11 +306,15 @@ fn node_agent_registration() { .unwrap(); let node2_last_messages = res2_receiver.recv().await.unwrap().expect("Failed to receive messages"); println!("### node2_last_messages: {:?}", node2_last_messages); - let 
shinkai_message_content_agent = node2_last_messages[2].get_message_content().unwrap(); - let message_content_agent: JobMessage = serde_json::from_str(&shinkai_message_content_agent).unwrap(); - assert_eq!(message_content_agent.content, message.to_string()); - assert!(node2_last_messages.len() == 3); + // + // TODO: Write new tests for checking the job results + // + // let shinkai_message_content_agent = node2_last_messages[2].get_message_content().unwrap(); + // let message_content_agent: JobMessage = serde_json::from_str(&shinkai_message_content_agent).unwrap(); + + // assert_eq!(message_content_agent.content, message.to_string()); + // assert!(node2_last_messages.len() == 3); let offset = format!( "{}:::{}", @@ -336,11 +344,15 @@ fn node_agent_registration() { .unwrap(); let node2_last_messages = res2_receiver.recv().await.unwrap().expect("Failed to receive messages"); println!("### node2_last_messages unread pagination: {:?}", node2_last_messages); - let shinkai_message_content_agent = node2_last_messages[0].get_message_content().unwrap(); - let message_content_agent: JobMessage = serde_json::from_str(&shinkai_message_content_agent).unwrap(); - assert!(node2_last_messages.len() == 2); - assert_eq!(message_content_agent.content, message.to_string()); + // + // TODO: Write new tests for checking the job results + // + // let shinkai_message_content_agent = node2_last_messages[0].get_message_content().unwrap(); + // let message_content_agent: JobMessage = serde_json::from_str(&shinkai_message_content_agent).unwrap(); + + // assert!(node2_last_messages.len() == 2); + // assert_eq!(message_content_agent.content, message.to_string()); // we mark read until the offset let read_msg = ShinkaiMessageBuilder::read_up_to_time( @@ -394,7 +406,11 @@ fn node_agent_registration() { "### unread after cleaning node2_last_messages len: {:?}", node2_last_messages.len() ); - assert!(node2_last_messages.len() == 2); + + // + // TODO: Write new tests for checking the job results + // + // assert!(node2_last_messages.len() == 2); } { // Send a scheduled message @@ -411,7 +427,12 @@ fn node_agent_registration() { .body_encryption(EncryptionMethod::DiffieHellmanChaChaPoly1305) .external_metadata_with_schedule(node1_identity_name.clone().to_string(), sender, future_time_2_secs) .message_raw_content(message.clone()) - .internal_metadata_with_inbox("".to_string(), "".to_string(), inbox_name.to_string(), EncryptionMethod::None) + .internal_metadata_with_inbox( + "".to_string(), + "".to_string(), + inbox_name.to_string(), + EncryptionMethod::None, + ) .build(); } }); From 0100d9dda909e415c068f387aff08c5166b5d285 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 17:11:33 +0200 Subject: [PATCH 12/26] Fixed toolkit ebnf test --- src/tools/argument.rs | 2 +- tests/toolkit_tests.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tools/argument.rs b/src/tools/argument.rs index 9b3d436c0..3ac2d62f4 100644 --- a/src/tools/argument.rs +++ b/src/tools/argument.rs @@ -53,7 +53,7 @@ impl ToolArgument { let ebnf = input_arg.labled_ebnf(); // ebnf_result.push_str(&format!(r#""{}": {}, "#, name, name)); - ebnf_result.push_str(&format!(r#""tool": {}, "#, name)); + ebnf_result.push_str(&format!(r#""{}": {}, "#, name, name)); // Add descriptions to argument definitions if set to true if add_arg_descriptions { diff --git a/tests/toolkit_tests.rs b/tests/toolkit_tests.rs index 2987a92ae..f3a5161fa 100644 --- a/tests/toolkit_tests.rs +++ 
b/tests/toolkit_tests.rs @@ -49,7 +49,7 @@ fn test_default_js_toolkit_json_parsing() { assert_eq!(toolkit.name, "Google Calendar Toolkit"); assert_eq!( toolkit.tools[0].ebnf_inputs(false).replace("\n", ""), - r#"{"calendar_id": calendar_id, "text": text, "send_updates": send_updates, "toolkit_name": Google Calendar Toolkit, }calendar_id :== ([a-zA-Z0-9_]+)?text :== ([a-zA-Z0-9_]+)send_updates :== ("all" | "externalOnly" | "none")?"# + r#"{"calendar_id": calendar_id, "text": text, "send_updates": send_updates, "toolkit": Google Calendar Toolkit, }calendar_id :== ([a-zA-Z0-9_]+)?text :== ([a-zA-Z0-9_]+)send_updates :== ("all" | "externalOnly" | "none")?"# ); assert_eq!(toolkit.header_definitions.len(), 4); From 827af8e89a008a18aaa81c1baeb8790c2b7b8900 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 19:43:45 +0200 Subject: [PATCH 13/26] Implemented iterator_cf_pb --- src/db/db.rs | 39 +++++++++++++++++++++++++++++++++++++-- src/db/db_jobs.rs | 2 ++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/db/db.rs b/src/db/db.rs index 895992d73..4727d53ec 100644 --- a/src/db/db.rs +++ b/src/db/db.rs @@ -1,5 +1,9 @@ use chrono::{DateTime, Utc}; -use rocksdb::{AsColumnFamilyRef, ColumnFamily, ColumnFamilyDescriptor, Error, IteratorMode, Options, WriteBatch, DB}; +use rocksdb::{ + AsColumnFamilyRef, ColumnFamily, ColumnFamilyDescriptor, DBCommon, DBIteratorWithThreadMode, Error, IteratorMode, + Options, SingleThreaded, WriteBatch, DB, +}; + use shinkai_message_primitives::{ schemas::{shinkai_name::ShinkaiName, shinkai_time::ShinkaiTime}, shinkai_message::shinkai_message::ShinkaiMessage, @@ -54,7 +58,7 @@ impl Topic { Self::VectorResources => "resources", Self::Agents => "agents", Self::Toolkits => "toolkits", - Self::MessagesToRetry => "mesages_to_retry" + Self::MessagesToRetry => "mesages_to_retry", } } } @@ -186,6 +190,31 @@ impl ShinkaiDB { self.get_cf(topic, new_key) } + /// Iterates over the provided column family + pub fn iterator_cf<'a>( + &'a self, + cf: &impl AsColumnFamilyRef, + ) -> Result, ShinkaiDBError> { + Ok(self.db.iterator_cf(cf, IteratorMode::Start)) + } + + /// Iterates over the provided column family profile-bounded, meaning that + /// we filter out all keys in the iterator which are not profile-bounded to the + /// correct profile, before returning the iterator. 
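To show the intent behind `iterator_cf_pb` (defined just below): keys written with `generate_profile_bound_key` carry the profile name as a prefix, so the profile-bounded iterator only yields that profile's entries. A hypothetical helper, written as if inside `impl ShinkaiDB`, illustrating how a caller might consume it:

```rust
// Illustrative only: collect every value stored for one profile in a given
// column family, relying on the prefix filter applied by iterator_cf_pb.
pub fn get_all_values_for_profile(
    &self,
    topic_name: &str,
    profile: &ShinkaiName,
) -> Result<Vec<Vec<u8>>, ShinkaiDBError> {
    let cf = self
        .db
        .cf_handle(topic_name)
        .ok_or(ShinkaiDBError::ColumnFamilyNotFound(topic_name.to_string()))?;

    let mut values = Vec::new();
    // Keys were written via generate_profile_bound_key, so only entries whose
    // key starts with this profile's name are returned by the iterator.
    for item in self.iterator_cf_pb(cf, profile)? {
        let (_key, value) = item.map_err(ShinkaiDBError::RocksDBError)?;
        values.push(value.to_vec());
    }
    Ok(values)
}
```

Because the scoping happens through the key prefix, one shared column family can hold entries for every profile while each read stays bounded to a single profile.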
+ pub fn iterator_cf_pb<'a>( + &'a self, + cf: &impl AsColumnFamilyRef, + profile: &ShinkaiName, + ) -> Result, Box<[u8]>), rocksdb::Error>> + 'a, ShinkaiDBError> { + let profile_prefix = ShinkaiDB::get_profile_name(profile)?.into_bytes(); + let iter = self.db.iterator_cf(cf, IteratorMode::Start); + let filtered_iter = iter.filter(move |result| match result { + Ok((key, _)) => key.starts_with(&profile_prefix), + Err(_) => false, + }); + Ok(filtered_iter) + } + /// Saves the value inside of the key at the provided column family pub fn put_cf(&self, cf: &impl AsColumnFamilyRef, key: K, value: V) -> Result<(), ShinkaiDBError> where @@ -239,6 +268,12 @@ impl ShinkaiDB { self.write(pb_batch.write_batch) } + /// Validates if the key has the provided profile name properly prepended to it + pub fn validate_profile_bound_key(key: &str, profile: &ShinkaiName) -> Result { + let profile_name = ShinkaiDB::get_profile_name(profile)?; + Ok(key.starts_with(&profile_name)) + } + /// Prepends the profile name to the provided key to make it "profile bound" pub fn generate_profile_bound_key(key: &str, profile: &ShinkaiName) -> Result { let mut prof_name = ShinkaiDB::get_profile_name(profile)?; diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index da606f95f..863d12073 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -36,6 +36,7 @@ impl JobInfo { } } +// TODO: Replace all db writes with the profile-bound interface impl ShinkaiDB { pub fn create_new_job(&mut self, job_id: String, agent_id: String, scope: JobScope) -> Result<(), ShinkaiDBError> { // Create Options for ColumnFamily @@ -297,6 +298,7 @@ impl ShinkaiDB { Ok(jobs) } + /// Fetches all jobs under a specific Agent pub fn get_agent_jobs(&self, agent_id: String) -> Result>, ShinkaiDBError> { let cf_name = format!("agentid_{}", &agent_id); let cf_handle = self From 8219c673c9231a826385dc1a47906561bd0282e6 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 21:16:42 +0200 Subject: [PATCH 14/26] Implemented DB unprocessed message methods --- src/db/db_jobs.rs | 78 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 14 deletions(-) diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index 863d12073..48c0d6487 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -193,7 +193,6 @@ impl ShinkaiDB { let cf_job_id_name = format!("jobtopic_{}", job_id); let cf_job_id_scope_name = format!("{}_scope", job_id); let cf_job_id_step_history_name = format!("{}_step_history", job_id); - let cf_job_id_unprocessed_messages_name = format!("{}_unprocessed_messages", &job_id); // Get the needed cf handles let cf_job_id_scope = self @@ -208,12 +207,6 @@ impl ShinkaiDB { .db .cf_handle(&cf_job_id_step_history_name) .ok_or(ShinkaiDBError::ColumnFamilyNotFound(cf_job_id_step_history_name))?; - let cf_job_id_unprocessed_messages = - self.db - .cf_handle(&cf_job_id_unprocessed_messages_name) - .ok_or(ShinkaiDBError::ColumnFamilyNotFound( - cf_job_id_unprocessed_messages_name, - ))?; // Begin fetching the data from the DB let scope_value = self @@ -261,13 +254,7 @@ impl ShinkaiDB { } // Reads all of the unprocessed messages by iterating - let mut unprocessed_messages: Vec = Vec::new(); - let iter = self.db.iterator_cf(cf_job_id_unprocessed_messages, IteratorMode::Start); - for item in iter { - let (_key, value) = item.map_err(|e| ShinkaiDBError::RocksDBError(e))?; - let message = std::str::from_utf8(&value)?.to_string(); - unprocessed_messages.push(message); - } + let 
unprocessed_messages = self.get_unprocessed_messages(job_id)?; Ok(( scope, @@ -316,6 +303,69 @@ impl ShinkaiDB { Ok(jobs) } + /// Fetches all unprocessed messages for a specific Job from the DB + fn get_unprocessed_messages(&self, job_id: &str) -> Result, ShinkaiDBError> { + // Get the iterator + let iter = self.get_unprocessed_messages_iterator(job_id)?; + + // Reads all of the unprocessed messages by iterating + let mut unprocessed_messages: Vec = Vec::new(); + for item in iter { + let (_key, value) = item.map_err(|e| ShinkaiDBError::RocksDBError(e))?; + let message = std::str::from_utf8(&value)?.to_string(); + unprocessed_messages.push(message); + } + + Ok(unprocessed_messages) + } + + /// Fetches an iterator over all unprocessed messages for a specific Job from the DB + fn get_unprocessed_messages_iterator<'a>( + &'a self, + job_id: &str, + ) -> Result, Box<[u8]>), rocksdb::Error>> + 'a, ShinkaiDBError> { + // Get the needed cf handle + let cf_job_id_unprocessed_messages = self._get_unprocessed_messages_handle(job_id)?; + + // Get the iterator + let iter = self.db.iterator_cf(cf_job_id_unprocessed_messages, IteratorMode::Start); + + Ok(iter) + } + + /// Removes the oldest unprocessed message for a specific Job from the DB + pub fn remove_oldest_unprocessed_message(&self, job_id: &str) -> Result<(), ShinkaiDBError> { + // Get the needed cf handle + let cf_job_id_unprocessed_messages = self._get_unprocessed_messages_handle(job_id)?; + + // Get the iterator + let mut iter = self.get_unprocessed_messages_iterator(job_id)?; + + // Get the oldest message (first item in the iterator) + if let Some(Ok((key, _))) = iter.next() { + // Remove the oldest message from the DB + self.db.delete_cf(cf_job_id_unprocessed_messages, key)?; + } + + Ok(()) + } + + /// Fetches the column family handle for unprocessed messages of a specific Job + fn _get_unprocessed_messages_handle(&self, job_id: &str) -> Result<&rocksdb::ColumnFamily, ShinkaiDBError> { + let cf_job_id_unprocessed_messages_name = format!("{}_unprocessed_messages", job_id); + + // Get the needed cf handle + let cf_job_id_unprocessed_messages = + self.db + .cf_handle(&cf_job_id_unprocessed_messages_name) + .ok_or(ShinkaiDBError::ColumnFamilyNotFound( + cf_job_id_unprocessed_messages_name, + ))?; + + Ok(cf_job_id_unprocessed_messages) + } + + /// Updates the Job to being finished pub fn update_job_to_finished(&self, job_id: String) -> Result<(), ShinkaiDBError> { let cf_name = format!("jobtopic_{}", &job_id); let cf_handle = self From 2b85a5dd11032bad64dba7ae06670f21acc8a136 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 22:25:53 +0200 Subject: [PATCH 15/26] Specced out execution steps --- src/job/execution_steps.rs | 35 +++++++++++++++++++++++++++++++++++ src/job/mod.rs | 1 + src/lib.rs | 1 + src/managers/job_manager.rs | 4 ++-- src/utils/job_prompts.rs | 4 ++-- 5 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 src/job/execution_steps.rs create mode 100644 src/job/mod.rs diff --git a/src/job/execution_steps.rs b/src/job/execution_steps.rs new file mode 100644 index 000000000..fa7e2bcf2 --- /dev/null +++ b/src/job/execution_steps.rs @@ -0,0 +1,35 @@ +use crate::tools::router::ShinkaiTool; +use serde_json::Value as JsonValue; + +// 1. We start with all execution plans filling the context and saving the user's message with an InitialExecutionStep +// 2. We then iterate through the rest of the steps. +// 3. 
If they're Inference steps, we just record the task message from the original bootstrap plan, and get a name for the output which we will assign before adding the result into the context. +// 4. If they're Tool steps, we execute the tool and use the output_name before adding the result to the context +// 5. Once all execution steps have been processed, we inference the LLM one last time, providing it the whole context + the user's initial message, and tell it to respond to the user using the context. +// 6. We then save the final execution context (eventually adding summarization/pruning) as the Job's persistent context, save all of the prompts/responses from the LLM in the step history, and add a ShinkaiMessage into the Job inbox with the final response. + +/// Initial data to be used for execution, including filling up the context +pub struct InitialExecutionStep { + initial_context: Option, + user_message: String, +} + +/// An execution step that the LLM decided it could perform without any tools. +pub struct InferenceExecutionStep { + plan_task_message: String, + output_name: String, +} + +/// An execution step that requires executing a ShinkaiTool. +/// Of note `output_name` is used to label the output of the tool with an alternate name +/// before adding the results into the execution context +pub struct ToolExecutionStep { + tool: ShinkaiTool, + output_name: String, +} + +pub enum ExecutionStep { + Initial(InitialExecutionStep), + Inference(InferenceExecutionStep), + Tool(ToolExecutionStep), +} diff --git a/src/job/mod.rs b/src/job/mod.rs new file mode 100644 index 000000000..6c7b6e521 --- /dev/null +++ b/src/job/mod.rs @@ -0,0 +1 @@ +pub mod execution_steps; diff --git a/src/lib.rs b/src/lib.rs index f87a5a9a0..805980f4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod db; +pub mod job; pub mod managers; pub mod network; pub mod resources; diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index f22bb8581..a5e378046 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -296,8 +296,8 @@ impl AgentManager { // 0. You add the unprocessed message to the list in the DB // 1. Start a while loop where every time you fetch the unprocessed messages for the job from the DB and check if there's >= 1 // 2. You read the first/front unprocessed message (not pop from the back) - // 3. You start analysis phase to generate the analysis plan. - // 4. You then take the analysis plan and process the execution phase. + // 3. You start analysis phase to generate the execution plan. + // 4. You then take the execution plan and process the execution phase. // 5. Once execution phase succeeds, you then delete the message from the unprocessed list in the DB // and take the result and append it both to the Job inbox and step history // 6. As we're in a while loop, go back to 1, meaning any new unprocessed messages added while the step was happening are now processed sequentially diff --git a/src/utils/job_prompts.rs b/src/utils/job_prompts.rs index 3ab6d9b83..2da836338 100644 --- a/src/utils/job_prompts.rs +++ b/src/utils/job_prompts.rs @@ -23,7 +23,7 @@ use lazy_static::lazy_static; // // 8. We now move to the "execution phase". // 9. Using the execution plan, we move forward alternating between inferencing the LLM and executing a tool, as dictated by the plan. -// 10. To start we inference the LLM with the user's initial request text + the input EBNF of the first tool, and tell the LLM to fill out the input EBNF with real data. +// 10. 
To start we inference the LLM with the first step in the plan + the input EBNF of the first tool, and tell the LLM to fill out the input EBNF with real data. // 11. The input JSON is taken and the tool is called/executed, with the results being added into the context. // 12. With the tool executed, we now inference the LLM with just the context + the input EBNF of the next tool that it needs to fill out (we can skip user's request text). // 13. We iterate through the entire execution plan (looping back/forth between 11/12) and arrive at the end with a context filled with all relevant data needed to answer the user's initial request. @@ -35,7 +35,7 @@ use lazy_static::lazy_static; // lazy_static! { - static ref bootstrap_plan__prompt: String = String::from( + static ref bootstrap_plan_prompt: String = String::from( r#" You are an assistant running in a system who only has access to a series of tools and your own knowledge to accomplish any task. From 63f5165e2e78ab9bdff9de5e0b66cac7c25d5005 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 22:40:55 +0200 Subject: [PATCH 16/26] Abstracted Job into separate files --- src/db/db.rs | 1 - src/db/db_jobs.rs | 7 +--- src/job/job.rs | 60 ++++++++++++++++++++++++++++ src/job/mod.rs | 1 + src/main.rs | 1 + src/managers/job_manager.rs | 79 +++---------------------------------- tests/job_manager_tests.rs | 11 +++++- 7 files changed, 79 insertions(+), 81 deletions(-) create mode 100644 src/job/job.rs diff --git a/src/db/db.rs b/src/db/db.rs index 4727d53ec..a9bb6d017 100644 --- a/src/db/db.rs +++ b/src/db/db.rs @@ -3,7 +3,6 @@ use rocksdb::{ AsColumnFamilyRef, ColumnFamily, ColumnFamilyDescriptor, DBCommon, DBIteratorWithThreadMode, Error, IteratorMode, Options, SingleThreaded, WriteBatch, DB, }; - use shinkai_message_primitives::{ schemas::{shinkai_name::ShinkaiName, shinkai_time::ShinkaiTime}, shinkai_message::shinkai_message::ShinkaiMessage, diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index 48c0d6487..161164b99 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -1,12 +1,9 @@ use super::{db::Topic, db_errors::ShinkaiDBError, ShinkaiDB}; -use crate::managers::job_manager::{Job, JobLike}; -use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey}; -use rand::RngCore; -use rocksdb::{Error, IteratorMode, Options, WriteBatch}; +use crate::job::job::{Job, JobLike}; +use rocksdb::{IteratorMode, Options, WriteBatch}; use shinkai_message_primitives::schemas::{inbox_name::InboxName, shinkai_time::ShinkaiTime}; use shinkai_message_primitives::shinkai_message::shinkai_message::ShinkaiMessage; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobScope; -use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey}; enum JobInfo { IsFinished, diff --git a/src/job/job.rs b/src/job/job.rs new file mode 100644 index 000000000..1c80e7d22 --- /dev/null +++ b/src/job/job.rs @@ -0,0 +1,60 @@ +use shinkai_message_primitives::{schemas::inbox_name::InboxName, shinkai_message::shinkai_message_schemas::JobScope}; + +pub type JobId = String; + +pub trait JobLike: Send + Sync { + fn job_id(&self) -> &str; + fn datetime_created(&self) -> &str; + fn is_finished(&self) -> bool; + fn parent_agent_id(&self) -> &str; + fn scope(&self) -> &JobScope; + fn conversation_inbox_name(&self) -> &InboxName; +} + +// Todo: Add a persistent_context: String +#[derive(Clone, Debug)] +pub struct Job { + // Based on uuid + pub job_id: String, + 
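The execution-phase flow outlined in the job_prompts comments above (fill the context, alternate LLM inference and tool calls, then answer from the accumulated context) maps naturally onto the `ExecutionStep` types from `execution_steps.rs`. A rough sketch of walking such a plan, written as if in that module so the step fields are visible; `infer_llm` and `run_tool` are hypothetical stand-ins, and the error variant used is the `MissingInitialStepInExecutionPlan` added later in this series:

```rust
use crate::agent::error::AgentError; // location after the later error rearrange

// Hypothetical stand-ins for real LLM inference and tool execution.
fn infer_llm(task: &str, _context: &[String]) -> String {
    format!("<llm output for: {}>", task)
}
fn run_tool(_tool: &ShinkaiTool, _context: &[String]) -> String {
    "<tool output>".to_string()
}

/// Sketch: the initial step fills the context, each later step's result is
/// labeled with its `output_name`, and a final inference over the whole
/// context produces the user-facing answer.
pub fn walk_execution_plan(plan: Vec<ExecutionStep>) -> Result<String, AgentError> {
    let mut steps = plan.into_iter();

    // The plan must start with an InitialExecutionStep.
    let mut context: Vec<String> = match steps.next() {
        Some(ExecutionStep::Initial(init)) => {
            let mut ctx = Vec::new();
            if let Some(json) = init.initial_context {
                ctx.push(json.to_string());
            }
            ctx.push(init.user_message);
            ctx
        }
        _ => return Err(AgentError::MissingInitialStepInExecutionPlan),
    };

    for step in steps {
        match step {
            ExecutionStep::Inference(inf) => {
                let out = infer_llm(&inf.plan_task_message, &context);
                context.push(format!("{}: {}", inf.output_name, out));
            }
            ExecutionStep::Tool(tool_step) => {
                let out = run_tool(&tool_step.tool, &context);
                context.push(format!("{}: {}", tool_step.output_name, out));
            }
            // Only expected once, at the front of the plan.
            ExecutionStep::Initial(_) => {}
        }
    }

    Ok(infer_llm(
        "Respond to the user using the accumulated context",
        &context,
    ))
}
```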
// Format: "20230702T20533481346" or Utc::now().format("%Y%m%dT%H%M%S%f").to_string(); + pub datetime_created: String, + // Marks if the job is finished or not + pub is_finished: bool, + // Identity of the parent agent. We just use a full identity name for simplicity + pub parent_agent_id: String, + // What VectorResources the Job has access to when performing vector searches + pub scope: JobScope, + // An inbox where messages to the agent from the user and messages from the agent are stored, + // enabling each job to have a classical chat/conversation UI + pub conversation_inbox_name: InboxName, + // The job's step history (an ordered list of all prompts/outputs from LLM inferencing when processing steps) + pub step_history: Vec, + // An ordered list of the latest messages sent to the job which are yet to be processed + pub unprocessed_messages: Vec, +} + +impl JobLike for Job { + fn job_id(&self) -> &str { + &self.job_id + } + + fn datetime_created(&self) -> &str { + &self.datetime_created + } + + fn is_finished(&self) -> bool { + self.is_finished + } + + fn parent_agent_id(&self) -> &str { + &self.parent_agent_id + } + + fn scope(&self) -> &JobScope { + &self.scope + } + + fn conversation_inbox_name(&self) -> &InboxName { + &self.conversation_inbox_name + } +} diff --git a/src/job/mod.rs b/src/job/mod.rs index 6c7b6e521..ee0371cba 100644 --- a/src/job/mod.rs +++ b/src/job/mod.rs @@ -1 +1,2 @@ pub mod execution_steps; +pub mod job; diff --git a/src/main.rs b/src/main.rs index 41c02c2f5..a2e7ad9ef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,7 @@ use tokio::runtime::Runtime; use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey}; mod db; +mod job; mod managers; mod network; mod resources; diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index a5e378046..767fed831 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -1,87 +1,20 @@ +use super::{agent::Agent, IdentityManager}; use crate::db::{db_errors::ShinkaiDBError, ShinkaiDB}; -use crate::utils::job_prompts::PromptGenerator; +use crate::job::job::{Job, JobId, JobLike}; use chrono::Utc; -use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey}; -use reqwest::Identity; +use ed25519_dalek::SecretKey as SignatureStaticKey; use shinkai_message_primitives::{ - schemas::{ - inbox_name::InboxName, - shinkai_name::{ShinkaiName, ShinkaiNameError}, - }, + schemas::shinkai_name::{ShinkaiName, ShinkaiNameError}, shinkai_message::{ shinkai_message::{MessageBody, MessageData, ShinkaiMessage}, - shinkai_message_schemas::{ - JobCreationInfo, JobMessage, JobPreMessage, JobRecipient, JobScope, MessageSchemaType, - }, + shinkai_message_schemas::{JobCreationInfo, JobMessage, JobPreMessage, MessageSchemaType}, }, shinkai_utils::{shinkai_message_builder::ShinkaiMessageBuilder, signatures::clone_signature_secret_key}, }; +use std::fmt; use std::result::Result::Ok; use std::{collections::HashMap, error::Error, sync::Arc}; -use std::{fmt, thread}; use tokio::sync::{mpsc, Mutex}; -use warp::path::full; -use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey}; - -use super::{agent::Agent, IdentityManager}; - -pub trait JobLike: Send + Sync { - fn job_id(&self) -> &str; - fn datetime_created(&self) -> &str; - fn is_finished(&self) -> bool; - fn parent_agent_id(&self) -> &str; - fn scope(&self) -> &JobScope; - fn conversation_inbox_name(&self) -> &InboxName; -} - -#[derive(Clone, Debug)] -pub struct Job { - // 
Based on uuid - pub job_id: String, - // Format: "20230702T20533481346" or Utc::now().format("%Y%m%dT%H%M%S%f").to_string(); - pub datetime_created: String, - // Marks if the job is finished or not - pub is_finished: bool, - // Identity of the parent agent. We just use a full identity name for simplicity - pub parent_agent_id: String, - // What VectorResources the Job has access to when performing vector searches - pub scope: JobScope, - // An inbox where messages to the agent from the user and messages from the agent are stored, - // enabling each job to have a classical chat/conversation UI - pub conversation_inbox_name: InboxName, - // The job's step history (an ordered list of all prompts/outputs from LLM inferencing when processing steps) - pub step_history: Vec, - // An ordered list of the latest messages sent to the job which are yet to be processed - pub unprocessed_messages: Vec, -} - -impl JobLike for Job { - fn job_id(&self) -> &str { - &self.job_id - } - - fn datetime_created(&self) -> &str { - &self.datetime_created - } - - fn is_finished(&self) -> bool { - self.is_finished - } - - fn parent_agent_id(&self) -> &str { - &self.parent_agent_id - } - - fn scope(&self) -> &JobScope { - &self.scope - } - - fn conversation_inbox_name(&self) -> &InboxName { - &self.conversation_inbox_name - } -} - -type JobId = String; pub struct JobManager { pub agent_manager: Arc>, diff --git a/tests/job_manager_tests.rs b/tests/job_manager_tests.rs index f3b198fad..4af0d6d95 100644 --- a/tests/job_manager_tests.rs +++ b/tests/job_manager_tests.rs @@ -32,11 +32,12 @@ mod tests { utils::hash_string, }, }; + use shinkai_node::job::job::{Job, JobId, JobLike}; use shinkai_node::{ db::ShinkaiDB, managers::{ identity_manager, - job_manager::{AgentManager, JobLike, JobManager}, + job_manager::{AgentManager, JobManager}, }, }; use std::collections::HashMap; @@ -130,7 +131,13 @@ mod tests { } // Create JobManager - let mut job_manager = JobManager::new(db_arc.clone(), identity_manager, clone_signature_secret_key(&node1_identity_sk), node_profile_name.clone()).await; + let mut job_manager = JobManager::new( + db_arc.clone(), + identity_manager, + clone_signature_secret_key(&node1_identity_sk), + node_profile_name.clone(), + ) + .await; // Create a JobCreationMessage ShinkaiMessage let scope = JobScope { From 1e7e5218035c8de2f33ffde3f9a1b2b8974b43af Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Thu, 14 Sep 2023 22:52:20 +0200 Subject: [PATCH 17/26] Large rearrange to start cleanup of agents/jobs --- src/{managers => agent}/agent.rs | 11 ++++-- .../agent_to_serialization.rs | 0 src/{job => agent}/execution_steps.rs | 0 src/{job => agent}/job.rs | 0 src/agent/mod.rs | 5 +++ src/{managers => agent}/providers/mod.rs | 0 src/{managers => agent}/providers/openai.rs | 33 +++++++++------- src/agent/providers/sleep_api.rs | 39 +++++++++++++++++++ src/db/db_jobs.rs | 2 +- src/job/mod.rs | 2 - src/lib.rs | 2 +- src/main.rs | 2 +- src/managers/job_manager.rs | 5 ++- src/managers/mod.rs | 3 -- src/managers/providers/sleep_api.rs | 32 --------------- tests/agent_integration_tests.rs | 2 +- tests/db_job_tests.rs | 21 ++++++---- tests/job_manager_tests.rs | 2 +- tests/node_retrying_tests.rs | 15 +++---- 19 files changed, 99 insertions(+), 77 deletions(-) rename src/{managers => agent}/agent.rs (97%) rename src/{managers => agent}/agent_to_serialization.rs (100%) rename src/{job => agent}/execution_steps.rs (100%) rename src/{job => agent}/job.rs (100%) create mode 100644 
src/agent/mod.rs rename src/{managers => agent}/providers/mod.rs (100%) rename src/{managers => agent}/providers/openai.rs (79%) create mode 100644 src/agent/providers/sleep_api.rs delete mode 100644 src/job/mod.rs delete mode 100644 src/managers/providers/sleep_api.rs diff --git a/src/managers/agent.rs b/src/agent/agent.rs similarity index 97% rename from src/managers/agent.rs rename to src/agent/agent.rs index bac8bac3f..f1ca24c86 100644 --- a/src/managers/agent.rs +++ b/src/agent/agent.rs @@ -1,4 +1,4 @@ -use crate::managers::providers::Provider; +use super::providers::Provider; use reqwest::Client; use serde::{Deserialize, Serialize}; use shinkai_message_primitives::{ @@ -145,7 +145,10 @@ impl Agent { } impl Agent { - pub fn from_serialized_agent(serialized_agent: SerializedAgent, sender: mpsc::Sender<(Vec, String)>) -> Self { + pub fn from_serialized_agent( + serialized_agent: SerializedAgent, + sender: mpsc::Sender<(Vec, String)>, + ) -> Self { Self::new( serialized_agent.id, serialized_agent.full_identity_name, @@ -236,7 +239,9 @@ mod tests { ); tokio::spawn(async move { - agent.execute("Test".to_string(), context, "some_job_1".to_string()).await; + agent + .execute("Test".to_string(), context, "some_job_1".to_string()) + .await; }); let val = tokio::time::timeout(std::time::Duration::from_millis(600), rx.recv()).await; diff --git a/src/managers/agent_to_serialization.rs b/src/agent/agent_to_serialization.rs similarity index 100% rename from src/managers/agent_to_serialization.rs rename to src/agent/agent_to_serialization.rs diff --git a/src/job/execution_steps.rs b/src/agent/execution_steps.rs similarity index 100% rename from src/job/execution_steps.rs rename to src/agent/execution_steps.rs diff --git a/src/job/job.rs b/src/agent/job.rs similarity index 100% rename from src/job/job.rs rename to src/agent/job.rs diff --git a/src/agent/mod.rs b/src/agent/mod.rs new file mode 100644 index 000000000..b90782ad1 --- /dev/null +++ b/src/agent/mod.rs @@ -0,0 +1,5 @@ +pub mod agent; +pub mod agent_to_serialization; +pub mod execution_steps; +pub mod job; +pub mod providers; diff --git a/src/managers/providers/mod.rs b/src/agent/providers/mod.rs similarity index 100% rename from src/managers/providers/mod.rs rename to src/agent/providers/mod.rs diff --git a/src/managers/providers/openai.rs b/src/agent/providers/openai.rs similarity index 79% rename from src/managers/providers/openai.rs rename to src/agent/providers/openai.rs index c9aa2eb86..f1914caf0 100644 --- a/src/managers/providers/openai.rs +++ b/src/agent/providers/openai.rs @@ -1,13 +1,14 @@ +use super::AgentError; +use super::Provider; +use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json; -use shinkai_message_primitives::{shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, schemas::agents::serialized_agent::OpenAI}; +use shinkai_message_primitives::{ + schemas::agents::serialized_agent::OpenAI, + shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, +}; use std::error::Error; -use async_trait::async_trait; - -use crate::managers::agent::AgentError; - -use super::Provider; #[derive(Debug, Deserialize)] pub struct Response { @@ -51,15 +52,19 @@ impl Provider for OpenAI { } fn extract_content(response: &Self::Response) -> Vec { - response.choices.iter().map(|choice| { - JobPreMessage { - tool_calls: Vec::new(), // TODO: You might want to replace this with actual values - content: choice.message.content.clone(), - recipient: 
JobRecipient::SelfNode, // TODO: This is a placeholder. You should replace this with the actual recipient. - } - }).collect() + response + .choices + .iter() + .map(|choice| { + JobPreMessage { + tool_calls: Vec::new(), // TODO: You might want to replace this with actual values + content: choice.message.content.clone(), + recipient: JobRecipient::SelfNode, // TODO: This is a placeholder. You should replace this with the actual recipient. + } + }) + .collect() } - + async fn call_api( &self, client: &Client, diff --git a/src/agent/providers/sleep_api.rs b/src/agent/providers/sleep_api.rs new file mode 100644 index 000000000..a64bd3187 --- /dev/null +++ b/src/agent/providers/sleep_api.rs @@ -0,0 +1,39 @@ +use super::AgentError; +use super::Provider; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use shinkai_message_primitives::{ + schemas::agents::serialized_agent::SleepAPI, + shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, +}; +use tokio::time::Duration; + +#[async_trait] +impl Provider for SleepAPI { + type Response = (); // Empty tuple as a stand-in for no data + + fn parse_response(_: &str) -> Result> { + Ok(()) + } + + fn extract_content(_: &Self::Response) -> Vec { + vec![JobPreMessage { + tool_calls: Vec::new(), + content: "OK".to_string(), + recipient: JobRecipient::SelfNode, + }] + } + + async fn call_api( + &self, + _: &Client, + _: Option<&String>, + _: Option<&String>, + _: &str, + _: Vec, + ) -> Result, AgentError> { + tokio::time::sleep(Duration::from_millis(500)).await; + Ok(Self::extract_content(&())) + } +} diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index 161164b99..9101188a8 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -1,5 +1,5 @@ use super::{db::Topic, db_errors::ShinkaiDBError, ShinkaiDB}; -use crate::job::job::{Job, JobLike}; +use crate::agent::job::{Job, JobLike}; use rocksdb::{IteratorMode, Options, WriteBatch}; use shinkai_message_primitives::schemas::{inbox_name::InboxName, shinkai_time::ShinkaiTime}; use shinkai_message_primitives::shinkai_message::shinkai_message::ShinkaiMessage; diff --git a/src/job/mod.rs b/src/job/mod.rs deleted file mode 100644 index ee0371cba..000000000 --- a/src/job/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod execution_steps; -pub mod job; diff --git a/src/lib.rs b/src/lib.rs index 805980f4f..acde140b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ +pub mod agent; pub mod db; -pub mod job; pub mod managers; pub mod network; pub mod resources; diff --git a/src/main.rs b/src/main.rs index a2e7ad9ef..0dc55edba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,7 +23,7 @@ use tokio::runtime::Runtime; use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey}; mod db; -mod job; +mod agent; mod managers; mod network; mod resources; diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index 767fed831..4744fc236 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -1,6 +1,7 @@ -use super::{agent::Agent, IdentityManager}; +use super::IdentityManager; +use crate::agent::agent::Agent; +use crate::agent::job::{Job, JobId, JobLike}; use crate::db::{db_errors::ShinkaiDBError, ShinkaiDB}; -use crate::job::job::{Job, JobId, JobLike}; use chrono::Utc; use ed25519_dalek::SecretKey as SignatureStaticKey; use shinkai_message_primitives::{ diff --git a/src/managers/mod.rs b/src/managers/mod.rs index 3ec2a48b2..a1d12ce81 100644 --- a/src/managers/mod.rs +++ b/src/managers/mod.rs @@ -2,6 
+2,3 @@ pub mod identity_manager; pub use identity_manager::IdentityManager; pub mod identity_network_manager; pub mod job_manager; -pub mod agent; -pub mod agent_to_serialization; -pub mod providers; \ No newline at end of file diff --git a/src/managers/providers/sleep_api.rs b/src/managers/providers/sleep_api.rs deleted file mode 100644 index a93c67daf..000000000 --- a/src/managers/providers/sleep_api.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::{managers::agent::AgentError}; - -use super::Provider; -use async_trait::async_trait; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use shinkai_message_primitives::{shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, schemas::agents::serialized_agent::SleepAPI}; -use tokio::time::Duration; - -#[async_trait] -impl Provider for SleepAPI { - type Response = (); // Empty tuple as a stand-in for no data - - fn parse_response(_: &str) -> Result> { - Ok(()) - } - - fn extract_content(_: &Self::Response) -> Vec { - vec![ - JobPreMessage { - tool_calls: Vec::new(), - content: "OK".to_string(), - recipient: JobRecipient::SelfNode, - } - ] - } - - async fn call_api(&self, _: &Client, _: Option<&String>, _: Option<&String>, _: &str, _: Vec) -> Result, AgentError> { - tokio::time::sleep(Duration::from_millis(500)).await; - Ok(Self::extract_content(&())) - } -} diff --git a/tests/agent_integration_tests.rs b/tests/agent_integration_tests.rs index 7e3884cd4..106c05359 100644 --- a/tests/agent_integration_tests.rs +++ b/tests/agent_integration_tests.rs @@ -12,7 +12,7 @@ use shinkai_message_primitives::shinkai_utils::signatures::{ clone_signature_secret_key, unsafe_deterministic_signature_keypair, }; use shinkai_message_primitives::shinkai_utils::utils::hash_string; -use shinkai_node::managers::agent; +use shinkai_node::agent::agent; use shinkai_node::network::node::NodeCommand; use shinkai_node::network::node_api::APIError; use shinkai_node::network::Node; diff --git a/tests/db_job_tests.rs b/tests/db_job_tests.rs index a1de05421..3d911b62d 100644 --- a/tests/db_job_tests.rs +++ b/tests/db_job_tests.rs @@ -3,9 +3,7 @@ use std::{fs, path::Path}; use async_std::task; use rocksdb::{Error, Options, WriteBatch}; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobScope; -use shinkai_node::{ - db::ShinkaiDB, -}; +use shinkai_node::db::ShinkaiDB; fn create_new_job(db: &mut ShinkaiDB, job_id: String, agent_id: String, scope: JobScope) { match db.create_new_job(job_id, agent_id, scope) { @@ -23,8 +21,11 @@ fn setup() { mod tests { use std::collections::HashSet; - use shinkai_message_primitives::{shinkai_message::shinkai_message_schemas::JobScope, shinkai_utils::utils::hash_string, schemas::inbox_name::InboxName}; - use shinkai_node::{db::db_errors::ShinkaiDBError, managers::agent}; + use shinkai_message_primitives::{ + schemas::inbox_name::InboxName, shinkai_message::shinkai_message_schemas::JobScope, + shinkai_utils::utils::hash_string, + }; + use shinkai_node::{agent::agent, db::db_errors::ShinkaiDBError}; use super::*; @@ -148,7 +149,10 @@ mod tests { match shinkai_db.get_job(&job_id) { Ok(_) => panic!("Expected an error when getting a non-existent job"), - Err(e) => assert_eq!(e, ShinkaiDBError::ColumnFamilyNotFound("non_existent_job_scope".to_string())), + Err(e) => assert_eq!( + e, + ShinkaiDBError::ColumnFamilyNotFound("non_existent_job_scope".to_string()) + ), } } @@ -184,7 +188,10 @@ mod tests { match shinkai_db.update_job_to_finished(job_id.clone()) { Ok(_) => panic!("Expected an error when updating 
a non-existent job"), - Err(e) => assert_eq!(e, ShinkaiDBError::ProfileNameNonExistent(format!("jobtopic_{}", job_id))), + Err(e) => assert_eq!( + e, + ShinkaiDBError::ProfileNameNonExistent(format!("jobtopic_{}", job_id)) + ), } } diff --git a/tests/job_manager_tests.rs b/tests/job_manager_tests.rs index 4af0d6d95..f9fe68279 100644 --- a/tests/job_manager_tests.rs +++ b/tests/job_manager_tests.rs @@ -32,7 +32,7 @@ mod tests { utils::hash_string, }, }; - use shinkai_node::job::job::{Job, JobId, JobLike}; + use shinkai_node::agent::job::{Job, JobId, JobLike}; use shinkai_node::{ db::ShinkaiDB, managers::{ diff --git a/tests/node_retrying_tests.rs b/tests/node_retrying_tests.rs index 889e99758..373e6194c 100644 --- a/tests/node_retrying_tests.rs +++ b/tests/node_retrying_tests.rs @@ -4,14 +4,14 @@ use shinkai_message_primitives::schemas::inbox_name::InboxName; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{JobMessage, MessageSchemaType}; use shinkai_message_primitives::shinkai_utils::encryption::{ - clone_static_secret_key, unsafe_deterministic_encryption_keypair, EncryptionMethod, encryption_public_key_to_string, + clone_static_secret_key, encryption_public_key_to_string, unsafe_deterministic_encryption_keypair, EncryptionMethod, }; use shinkai_message_primitives::shinkai_utils::shinkai_message_builder::ShinkaiMessageBuilder; use shinkai_message_primitives::shinkai_utils::signatures::{ clone_signature_secret_key, unsafe_deterministic_signature_keypair, }; use shinkai_message_primitives::shinkai_utils::utils::hash_string; -use shinkai_node::managers::agent; +use shinkai_node::agent::agent; use shinkai_node::network::node::NodeCommand; use shinkai_node::network::node_api::APIError; use shinkai_node::network::Node; @@ -154,8 +154,8 @@ fn node_retrying_test() { .await; } - // Send message from Node 2 subidentity to Node 1 - { + // Send message from Node 2 subidentity to Node 1 + { eprintln!("\n\n### Sending message from a node 2 profile to node 1 profile\n\n"); let message_content = "test body content".to_string(); @@ -183,10 +183,7 @@ fn node_retrying_test() { eprintln!("\n\n unchanged message: {:?}", unchanged_message); // Shutdown Node 1 - node1_commands_sender - .send(NodeCommand::Shutdown) - .await - .unwrap(); + node1_commands_sender.send(NodeCommand::Shutdown).await.unwrap(); let (res_send_msg_sender, res_send_msg_receiver): ( async_channel::Sender>, @@ -216,7 +213,7 @@ fn node_retrying_test() { .await .unwrap(); let node2_last_messages = res2_receiver.recv().await.unwrap(); - } + } }); let _ = tokio::try_join!(node1_handler, node2_handler, interactions_handler).unwrap(); From 38357535f9743017974b5ef096172f3a458d92e5 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:03:08 +0200 Subject: [PATCH 18/26] Rearranged agent/manager errors --- src/agent/agent.rs | 35 +-------------- src/agent/error.rs | 33 +++++++++++++++ src/{utils => agent}/job_prompts.rs | 0 src/agent/mod.rs | 2 + src/agent/providers/mod.rs | 12 ++++-- src/managers/error.rs | 66 +++++++++++++++++++++++++++++ src/managers/identity_manager.rs | 15 ++----- src/managers/job_manager.rs | 66 +---------------------------- src/managers/mod.rs | 1 + src/network/node.rs | 3 +- src/network/node_error.rs | 10 +++-- src/utils/mod.rs | 1 - 12 files changed, 125 insertions(+), 119 deletions(-) create mode 100644 src/agent/error.rs rename src/{utils => agent}/job_prompts.rs (100%) 
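The patch below moves `AgentError` into `src/agent/error.rs`; the practical payoff of keeping its `From<reqwest::Error>` impl is that provider code can propagate HTTP failures with `?` while still returning the crate's own error type. A minimal usage sketch, mirroring the `Option<&String>` url style of the `Provider` trait; the function and what it fetches are made up for illustration:

```rust
use crate::agent::error::AgentError; // as seen from inside the crate
use reqwest::Client;

// Sketch only: not part of the codebase.
async fn fetch_raw(client: &Client, url: Option<&String>) -> Result<String, AgentError> {
    // Missing configuration maps onto the dedicated variant...
    let url = url.ok_or(AgentError::UrlNotSet)?;
    // ...while any reqwest failure converts automatically through `From`.
    let body = client.get(url).send().await?.text().await?;
    Ok(body)
}
```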
create mode 100644 src/managers/error.rs diff --git a/src/agent/agent.rs b/src/agent/agent.rs index f1ca24c86..1cf7ee963 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -1,6 +1,6 @@ +use super::error::AgentError; use super::providers::Provider; use reqwest::Client; -use serde::{Deserialize, Serialize}; use shinkai_message_primitives::{ schemas::{ agents::serialized_agent::{AgentAPIModel, SerializedAgent}, @@ -8,7 +8,6 @@ use shinkai_message_primitives::{ }, shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, }; -use std::fmt; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; @@ -164,38 +163,6 @@ impl Agent { } } -pub enum AgentError { - UrlNotSet, - ApiKeyNotSet, - ReqwestError(reqwest::Error), -} - -impl fmt::Display for AgentError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - AgentError::UrlNotSet => write!(f, "URL is not set"), - AgentError::ApiKeyNotSet => write!(f, "API Key not set"), - AgentError::ReqwestError(err) => write!(f, "Reqwest error: {}", err), - } - } -} - -impl fmt::Debug for AgentError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - AgentError::UrlNotSet => f.debug_tuple("UrlNotSet").finish(), - AgentError::ApiKeyNotSet => f.debug_tuple("ApiKeyNotSet").finish(), - AgentError::ReqwestError(err) => f.debug_tuple("ReqwestError").field(err).finish(), - } - } -} - -impl From for AgentError { - fn from(err: reqwest::Error) -> AgentError { - AgentError::ReqwestError(err) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/agent/error.rs b/src/agent/error.rs new file mode 100644 index 000000000..abbe46527 --- /dev/null +++ b/src/agent/error.rs @@ -0,0 +1,33 @@ +use std::fmt; + +pub enum AgentError { + UrlNotSet, + ApiKeyNotSet, + ReqwestError(reqwest::Error), +} + +impl fmt::Display for AgentError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AgentError::UrlNotSet => write!(f, "URL is not set"), + AgentError::ApiKeyNotSet => write!(f, "API Key not set"), + AgentError::ReqwestError(err) => write!(f, "Reqwest error: {}", err), + } + } +} + +impl fmt::Debug for AgentError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AgentError::UrlNotSet => f.debug_tuple("UrlNotSet").finish(), + AgentError::ApiKeyNotSet => f.debug_tuple("ApiKeyNotSet").finish(), + AgentError::ReqwestError(err) => f.debug_tuple("ReqwestError").field(err).finish(), + } + } +} + +impl From for AgentError { + fn from(err: reqwest::Error) -> AgentError { + AgentError::ReqwestError(err) + } +} diff --git a/src/utils/job_prompts.rs b/src/agent/job_prompts.rs similarity index 100% rename from src/utils/job_prompts.rs rename to src/agent/job_prompts.rs diff --git a/src/agent/mod.rs b/src/agent/mod.rs index b90782ad1..396fb8ba2 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,7 @@ pub mod agent; pub mod agent_to_serialization; +pub mod error; pub mod execution_steps; pub mod job; +pub mod job_prompts; pub mod providers; diff --git a/src/agent/providers/mod.rs b/src/agent/providers/mod.rs index 51c71f9d3..b1a008479 100644 --- a/src/agent/providers/mod.rs +++ b/src/agent/providers/mod.rs @@ -1,7 +1,7 @@ use reqwest::Client; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobPreMessage; -use super::agent::AgentError; +use super::error::AgentError; use async_trait::async_trait; #[async_trait] @@ -9,9 +9,15 @@ pub trait Provider { type Response; fn parse_response(response_body: &str) -> Result>; fn 
extract_content(response: &Self::Response) -> Vec; - async fn call_api(&self, client: &Client, url: Option<&String>, api_key: Option<&String>, content: &str, step_history: Vec) -> Result, AgentError>; + async fn call_api( + &self, + client: &Client, + url: Option<&String>, + api_key: Option<&String>, + content: &str, + step_history: Vec, + ) -> Result, AgentError>; } pub mod openai; pub mod sleep_api; - diff --git a/src/managers/error.rs b/src/managers/error.rs new file mode 100644 index 000000000..51396a428 --- /dev/null +++ b/src/managers/error.rs @@ -0,0 +1,66 @@ +use crate::db::db_errors::ShinkaiDBError; +use shinkai_message_primitives::schemas::shinkai_name::ShinkaiNameError; +use std::fmt; + +#[derive(Debug)] +pub enum JobManagerError { + NotAJobMessage, + JobNotFound, + JobCreationDeserializationFailed, + JobMessageDeserializationFailed, + JobPreMessageDeserializationFailed, + MessageTypeParseFailed, + IO(String), + ShinkaiDB(ShinkaiDBError), + ShinkaiNameError(ShinkaiNameError), + AgentNotFound, + ContentParseFailed, +} + +impl fmt::Display for JobManagerError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + JobManagerError::NotAJobMessage => write!(f, "Message is not a job message"), + JobManagerError::JobNotFound => write!(f, "Job not found"), + JobManagerError::JobCreationDeserializationFailed => { + write!(f, "Failed to deserialize JobCreationInfo message") + } + JobManagerError::JobMessageDeserializationFailed => write!(f, "Failed to deserialize JobMessage"), + JobManagerError::JobPreMessageDeserializationFailed => write!(f, "Failed to deserialize JobPreMessage"), + JobManagerError::MessageTypeParseFailed => write!(f, "Could not parse message type"), + JobManagerError::IO(err) => write!(f, "IO error: {}", err), + JobManagerError::ShinkaiDB(err) => write!(f, "Shinkai DB error: {}", err), + JobManagerError::AgentNotFound => write!(f, "Agent not found"), + JobManagerError::ContentParseFailed => write!(f, "Failed to parse content"), + JobManagerError::ShinkaiNameError(err) => write!(f, "ShinkaiName error: {}", err), + } + } +} + +impl std::error::Error for JobManagerError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + JobManagerError::ShinkaiDB(err) => Some(err), + JobManagerError::ShinkaiNameError(err) => Some(err), + _ => None, + } + } +} + +impl From> for JobManagerError { + fn from(err: Box) -> JobManagerError { + JobManagerError::IO(err.to_string()) + } +} + +impl From for JobManagerError { + fn from(err: ShinkaiDBError) -> JobManagerError { + JobManagerError::ShinkaiDB(err) + } +} + +impl From for JobManagerError { + fn from(err: ShinkaiNameError) -> JobManagerError { + JobManagerError::ShinkaiNameError(err) + } +} diff --git a/src/managers/identity_manager.rs b/src/managers/identity_manager.rs index 865420a4e..087116ce2 100644 --- a/src/managers/identity_manager.rs +++ b/src/managers/identity_manager.rs @@ -1,23 +1,14 @@ +use super::identity_network_manager::IdentityNetworkManager; use crate::db::ShinkaiDB; use crate::network::node_error::NodeError; use crate::network::node_message_handlers::verify_message_signature; -use crate::schemas::identity::{DeviceIdentity, Identity, IdentityType, StandardIdentity, StandardIdentityType}; -use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey}; +use crate::schemas::identity::{DeviceIdentity, Identity, StandardIdentity, StandardIdentityType}; use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent; use 
shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message::ShinkaiMessage; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::IdentityPermissions; -use shinkai_message_primitives::shinkai_utils::encryption::{ - encryption_public_key_to_string, encryption_public_key_to_string_ref, -}; -use shinkai_message_primitives::shinkai_utils::signatures::{ - signature_public_key_to_string, signature_public_key_to_string_ref, -}; -use std::sync::{Arc, Weak}; +use std::sync::Arc; use tokio::sync::Mutex; -use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey}; - -use super::identity_network_manager::IdentityNetworkManager; #[derive(Clone)] pub struct IdentityManager { diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index 4744fc236..59bcb2edb 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -1,3 +1,4 @@ +use super::error::JobManagerError; use super::IdentityManager; use crate::agent::agent::Agent; use crate::agent::job::{Job, JobId, JobLike}; @@ -242,7 +243,7 @@ impl AgentManager { std::mem::drop(shinkai_db); // require to avoid deadlock - // let _ = self.decision_phase(&**job).await?; + let _ = self.decision_phase(&**job).await?; return Ok(job_message.job_id.clone()); } else { return Err(JobManagerError::JobNotFound); @@ -395,66 +396,3 @@ impl AgentManager { unimplemented!() } } - -#[derive(Debug)] -pub enum JobManagerError { - NotAJobMessage, - JobNotFound, - JobCreationDeserializationFailed, - JobMessageDeserializationFailed, - JobPreMessageDeserializationFailed, - MessageTypeParseFailed, - IO(String), - ShinkaiDB(ShinkaiDBError), - ShinkaiNameError(ShinkaiNameError), - AgentNotFound, - ContentParseFailed, -} - -impl fmt::Display for JobManagerError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - JobManagerError::NotAJobMessage => write!(f, "Message is not a job message"), - JobManagerError::JobNotFound => write!(f, "Job not found"), - JobManagerError::JobCreationDeserializationFailed => { - write!(f, "Failed to deserialize JobCreationInfo message") - } - JobManagerError::JobMessageDeserializationFailed => write!(f, "Failed to deserialize JobMessage"), - JobManagerError::JobPreMessageDeserializationFailed => write!(f, "Failed to deserialize JobPreMessage"), - JobManagerError::MessageTypeParseFailed => write!(f, "Could not parse message type"), - JobManagerError::IO(err) => write!(f, "IO error: {}", err), - JobManagerError::ShinkaiDB(err) => write!(f, "Shinkai DB error: {}", err), - JobManagerError::AgentNotFound => write!(f, "Agent not found"), - JobManagerError::ContentParseFailed => write!(f, "Failed to parse content"), - JobManagerError::ShinkaiNameError(err) => write!(f, "ShinkaiName error: {}", err), - } - } -} - -impl std::error::Error for JobManagerError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - JobManagerError::ShinkaiDB(err) => Some(err), - JobManagerError::ShinkaiNameError(err) => Some(err), - _ => None, - } - } -} - -impl From> for JobManagerError { - fn from(err: Box) -> JobManagerError { - JobManagerError::IO(err.to_string()) - } -} - -impl From for JobManagerError { - fn from(err: ShinkaiDBError) -> JobManagerError { - JobManagerError::ShinkaiDB(err) - } -} - -impl From for JobManagerError { - fn from(err: ShinkaiNameError) -> JobManagerError { - JobManagerError::ShinkaiNameError(err) - } -} diff --git a/src/managers/mod.rs 
b/src/managers/mod.rs index a1d12ce81..41a0c96fd 100644 --- a/src/managers/mod.rs +++ b/src/managers/mod.rs @@ -1,4 +1,5 @@ pub mod identity_manager; pub use identity_manager::IdentityManager; +pub mod error; pub mod identity_network_manager; pub mod job_manager; diff --git a/src/network/node.rs b/src/network/node.rs index 52ea25077..2e8f1a162 100644 --- a/src/network/node.rs +++ b/src/network/node.rs @@ -26,8 +26,9 @@ use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionS use crate::db::db_errors::ShinkaiDBError; use crate::db::ShinkaiDB; +use crate::managers::error::JobManagerError; use crate::managers::identity_manager::{self}; -use crate::managers::job_manager::{JobManager, JobManagerError}; +use crate::managers::job_manager::JobManager; use crate::managers::{job_manager, IdentityManager}; use crate::network::node_message_handlers::{ extract_message, handle_based_on_message_content_and_encryption, ping_pong, verify_message_signature, PingPong, diff --git a/src/network/node_error.rs b/src/network/node_error.rs index 3ef5d7e59..d074da891 100644 --- a/src/network/node_error.rs +++ b/src/network/node_error.rs @@ -1,6 +1,8 @@ -use shinkai_message_primitives::{shinkai_message::shinkai_message_error::ShinkaiMessageError, schemas::{inbox_name::InboxNameError, shinkai_name::ShinkaiNameError}}; -use crate::{managers::job_manager::JobManagerError, db::db_errors::ShinkaiDBError}; - +use crate::{db::db_errors::ShinkaiDBError, managers::error::JobManagerError}; +use shinkai_message_primitives::{ + schemas::{inbox_name::InboxNameError, shinkai_name::ShinkaiNameError}, + shinkai_message::shinkai_message_error::ShinkaiMessageError, +}; #[derive(Debug)] pub struct NodeError { @@ -69,4 +71,4 @@ impl From for NodeError { message: format!("ShinkaiNameError: {}", error.to_string()), } } -} \ No newline at end of file +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d8a430c81..8648ac4f5 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,7 +1,6 @@ pub mod args; pub mod cli; pub mod environment; -pub mod job_prompts; pub mod keys; pub mod logging_helpers; pub mod printer; From 85eb86d36ff5707f379097373992953563388e0c Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 18:59:09 +0200 Subject: [PATCH 19/26] Started plan executor impl --- src/agent/error.rs | 8 ++++++++ src/agent/execution_steps.rs | 28 +++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/agent/error.rs b/src/agent/error.rs index abbe46527..093b2f863 100644 --- a/src/agent/error.rs +++ b/src/agent/error.rs @@ -4,6 +4,7 @@ pub enum AgentError { UrlNotSet, ApiKeyNotSet, ReqwestError(reqwest::Error), + MissingInitialStepInExecutionPlan, } impl fmt::Display for AgentError { @@ -11,6 +12,10 @@ impl fmt::Display for AgentError { match self { AgentError::UrlNotSet => write!(f, "URL is not set"), AgentError::ApiKeyNotSet => write!(f, "API Key not set"), + AgentError::MissingInitialStepInExecutionPlan => write!( + f, + "The provided execution plan does not have an InitialExecutionStep as its first element." 
+ ), AgentError::ReqwestError(err) => write!(f, "Reqwest error: {}", err), } } @@ -22,6 +27,9 @@ impl fmt::Debug for AgentError { AgentError::UrlNotSet => f.debug_tuple("UrlNotSet").finish(), AgentError::ApiKeyNotSet => f.debug_tuple("ApiKeyNotSet").finish(), AgentError::ReqwestError(err) => f.debug_tuple("ReqwestError").field(err).finish(), + AgentError::MissingInitialStepInExecutionPlan => { + f.debug_tuple("MissingInitialStepInExecutionPlan").finish() + } } } } diff --git a/src/agent/execution_steps.rs b/src/agent/execution_steps.rs index fa7e2bcf2..26ca1e41d 100644 --- a/src/agent/execution_steps.rs +++ b/src/agent/execution_steps.rs @@ -1,5 +1,7 @@ use crate::tools::router::ShinkaiTool; -use serde_json::Value as JsonValue; +use std::collections::HashMap; + +use super::error::AgentError; // 1. We start with all execution plans filling the context and saving the user's message with an InitialExecutionStep // 2. We then iterate through the rest of the steps. @@ -8,9 +10,29 @@ use serde_json::Value as JsonValue; // 5. Once all execution steps have been processed, we inference the LLM one last time, providing it the whole context + the user's initial message, and tell it to respond to the user using the context. // 6. We then save the final execution context (eventually adding summarization/pruning) as the Job's persistent context, save all of the prompts/responses from the LLM in the step history, and add a ShinkaiMessage into the Job inbox with the final response. -/// Initial data to be used for execution, including filling up the context +/// Struct that executes a plan (Vec) generated from the analysis phase +pub struct PlanExecutor { + context: HashMap, + user_message: String, + execution_plan: Vec, +} + +impl PlanExecutor { + pub fn new(execution_plan: Vec) -> Result { + match execution_plan.get(0) { + Some(ExecutionStep::Initial(initial_step)) => Ok(Self { + context: initial_step.initial_context.clone(), + user_message: initial_step.user_message.clone(), + execution_plan, + }), + _ => Err(AgentError::MissingInitialStepInExecutionPlan), + } + } +} + +/// Initial data to be used by the PlanExecutor, primarily to fill up the context pub struct InitialExecutionStep { - initial_context: Option, + initial_context: HashMap, user_message: String, } From 30d580da486b19550dfc767d00ad73e68319f443 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 20:14:42 +0200 Subject: [PATCH 20/26] Added basic boilerplate for plan execution --- src/agent/agent.rs | 6 +-- src/agent/execution_steps.rs | 76 ++++++++++++++++++++++++++------ src/agent/providers/mod.rs | 2 +- src/agent/providers/openai.rs | 4 +- src/agent/providers/sleep_api.rs | 4 +- src/managers/job_manager.rs | 2 +- 6 files changed, 72 insertions(+), 22 deletions(-) diff --git a/src/agent/agent.rs b/src/agent/agent.rs index 1cf7ee963..b9180cc5d 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -1,5 +1,5 @@ use super::error::AgentError; -use super::providers::Provider; +use super::providers::LLMProvider; use reqwest::Client; use shinkai_message_primitives::{ schemas::{ @@ -111,7 +111,7 @@ impl Agent { } } - pub async fn execute(&self, content: String, context: Vec, job_id: String) { + pub async fn inference(&self, content: String, context: Vec, job_id: String) { if self.perform_locally { // No need to spawn a new task here self.process_locally(content.clone(), context.clone(), job_id).await; @@ -207,7 +207,7 @@ mod tests { tokio::spawn(async move { agent - 
.execute("Test".to_string(), context, "some_job_1".to_string()) + .inference("Test".to_string(), context, "some_job_1".to_string()) .await; }); diff --git a/src/agent/execution_steps.rs b/src/agent/execution_steps.rs index 26ca1e41d..493beb8e9 100644 --- a/src/agent/execution_steps.rs +++ b/src/agent/execution_steps.rs @@ -1,8 +1,8 @@ +use super::{agent::Agent, error::AgentError}; +use crate::agent::job_prompts::PromptGenerator; use crate::tools::router::ShinkaiTool; use std::collections::HashMap; -use super::error::AgentError; - // 1. We start with all execution plans filling the context and saving the user's message with an InitialExecutionStep // 2. We then iterate through the rest of the steps. // 3. If they're Inference steps, we just record the task message from the original bootstrap plan, and get a name for the output which we will assign before adding the result into the context. @@ -11,32 +11,80 @@ use super::error::AgentError; // 6. We then save the final execution context (eventually adding summarization/pruning) as the Job's persistent context, save all of the prompts/responses from the LLM in the step history, and add a ShinkaiMessage into the Job inbox with the final response. /// Struct that executes a plan (Vec) generated from the analysis phase -pub struct PlanExecutor { - context: HashMap, +#[derive(Clone, Debug)] +pub struct PlanExecutor<'a> { + agent: &'a Agent, + execution_context: HashMap, user_message: String, execution_plan: Vec, + inference_trace: Vec, } -impl PlanExecutor { - pub fn new(execution_plan: Vec) -> Result { +impl<'a> PlanExecutor<'a> { + pub fn new(agent: &'a Agent, execution_plan: &Vec) -> Result { match execution_plan.get(0) { - Some(ExecutionStep::Initial(initial_step)) => Ok(Self { - context: initial_step.initial_context.clone(), - user_message: initial_step.user_message.clone(), - execution_plan, - }), + Some(ExecutionStep::Initial(initial_step)) => { + let mut execution_plan = execution_plan.to_vec(); + let execution_context = initial_step.initial_execution_context.clone(); + let user_message = initial_step.user_message.clone(); + execution_plan.remove(0); // Remove the initial step + Ok(Self { + agent, + execution_context, + user_message, + execution_plan, + inference_trace: vec![], + }) + } _ => Err(AgentError::MissingInitialStepInExecutionPlan), } } + + // TODO: Properly implement this once we have jobs update for context + agent infernece/use tool + /// Executes the plan step-by-step, performing all inferencing & tool calls. + /// All content sent for inferencing and all responses from the LLM are saved in self.inference_trace + pub async fn execute(&mut self) -> Result<(), AgentError> { + for step in &self.execution_plan { + match step { + ExecutionStep::Inference(inference_step) => { + + // 1. Generate the content to be sent using prompt generator/self/step + // PromptGenerator::... + + // 2. Save the content to be sent to the LLM + // self.inference_trace.push(content) + + // 3. Inference + // self.agent + // .inference( + // inference_step.content.clone(), + // ) + // .await; + + // 4. Save full response to trace + // self.inference_trace.push(response) + + // 5. 
Find & parse the JSON in the response + } + ExecutionStep::Tool(tool_step) => { + // self.agent.use_tool(tool_step.tool.clone()).await?; + } + _ => (), + } + } + Ok(()) + } } -/// Initial data to be used by the PlanExecutor, primarily to fill up the context +/// Initial data to be consumed while creating the PlanExecutor, primarily to fill up the initial_execution_context +#[derive(Clone, Debug)] pub struct InitialExecutionStep { - initial_context: HashMap, + initial_execution_context: HashMap, user_message: String, } /// An execution step that the LLM decided it could perform without any tools. +#[derive(Clone, Debug)] pub struct InferenceExecutionStep { plan_task_message: String, output_name: String, @@ -45,11 +93,13 @@ pub struct InferenceExecutionStep { /// An execution step that requires executing a ShinkaiTool. /// Of note `output_name` is used to label the output of the tool with an alternate name /// before adding the results into the execution context +#[derive(Clone, Debug)] pub struct ToolExecutionStep { tool: ShinkaiTool, output_name: String, } +#[derive(Clone, Debug)] pub enum ExecutionStep { Initial(InitialExecutionStep), Inference(InferenceExecutionStep), diff --git a/src/agent/providers/mod.rs b/src/agent/providers/mod.rs index b1a008479..34e84de03 100644 --- a/src/agent/providers/mod.rs +++ b/src/agent/providers/mod.rs @@ -5,7 +5,7 @@ use super::error::AgentError; use async_trait::async_trait; #[async_trait] -pub trait Provider { +pub trait LLMProvider { type Response; fn parse_response(response_body: &str) -> Result>; fn extract_content(response: &Self::Response) -> Vec; diff --git a/src/agent/providers/openai.rs b/src/agent/providers/openai.rs index f1914caf0..3b17983a3 100644 --- a/src/agent/providers/openai.rs +++ b/src/agent/providers/openai.rs @@ -1,5 +1,5 @@ use super::AgentError; -use super::Provider; +use super::LLMProvider; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -40,7 +40,7 @@ struct Usage { } #[async_trait] -impl Provider for OpenAI { +impl LLMProvider for OpenAI { type Response = Response; fn parse_response(response_body: &str) -> Result> { diff --git a/src/agent/providers/sleep_api.rs b/src/agent/providers/sleep_api.rs index a64bd3187..c83344e05 100644 --- a/src/agent/providers/sleep_api.rs +++ b/src/agent/providers/sleep_api.rs @@ -1,5 +1,5 @@ use super::AgentError; -use super::Provider; +use super::LLMProvider; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -10,7 +10,7 @@ use shinkai_message_primitives::{ use tokio::time::Duration; #[async_trait] -impl Provider for SleepAPI { +impl LLMProvider for SleepAPI { type Response = (); // Empty tuple as a stand-in for no data fn parse_response(_: &str) -> Result> { diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index 59bcb2edb..6c8850d65 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -358,7 +358,7 @@ impl AgentManager { // Execute LLM inferencing let response = tokio::spawn(async move { let mut agent = agent.lock().await; - agent.execute(last_message, context, job.job_id().to_string()).await; + agent.inference(last_message, context, job.job_id().to_string()).await; }) .await?; println!("decision_iteration> response: {:?}", response); From 54f7713c97fc56637d8be14ca3ced6d4593d9832 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 20:29:27 +0200 Subject: [PATCH 21/26] Added cf handle to db struct --- 
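Note on the helper added in this patch: it converts rocksdb's Option<&ColumnFamily> into a Result so callers can propagate a missing column family with `?` instead of unwrapping at every call site. The following is a minimal, self-contained sketch of that pattern only; it assumes the single-threaded rocksdb::DB API, and the Store / StoreError names are illustrative stand-ins rather than types from this codebase.

use rocksdb::{ColumnFamily, DB};

#[derive(Debug)]
pub enum StoreError {
    /// The requested column family was never created or opened.
    FailedFetchingCF(String),
}

pub struct Store {
    db: DB,
}

impl Store {
    /// Fetches a column family handle, mapping rocksdb's `Option` into a
    /// domain error so callers can use `?` instead of `.unwrap()`.
    pub fn cf_handle(&self, name: &str) -> Result<&ColumnFamily, StoreError> {
        self.db
            .cf_handle(name)
            .ok_or_else(|| StoreError::FailedFetchingCF(name.to_string()))
    }
}

Keeping the conversion in one place means a missing column family surfaces as a recoverable error; the follow-up cleanup patch in this series switches the existing self.db.cf_handle(...).unwrap() call sites over to the new helper.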
src/agent/job.rs | 22 +++++++++++++--------- src/db/db.rs | 4 ++++ src/db/db_jobs.rs | 6 ++++++ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/agent/job.rs b/src/agent/job.rs index 1c80e7d22..6271c104c 100644 --- a/src/agent/job.rs +++ b/src/agent/job.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use shinkai_message_primitives::{schemas::inbox_name::InboxName, shinkai_message::shinkai_message_schemas::JobScope}; pub type JobId = String; @@ -14,23 +16,25 @@ pub trait JobLike: Send + Sync { // Todo: Add a persistent_context: String #[derive(Clone, Debug)] pub struct Job { - // Based on uuid + /// Based on uuid pub job_id: String, - // Format: "20230702T20533481346" or Utc::now().format("%Y%m%dT%H%M%S%f").to_string(); + /// Format: "20230702T20533481346" or Utc::now().format("%Y%m%dT%H%M%S%f").to_string(); pub datetime_created: String, - // Marks if the job is finished or not + /// Marks if the job is finished or not pub is_finished: bool, - // Identity of the parent agent. We just use a full identity name for simplicity + /// Identity of the parent agent. We just use a full identity name for simplicity pub parent_agent_id: String, - // What VectorResources the Job has access to when performing vector searches + /// What VectorResources the Job has access to when performing vector searches pub scope: JobScope, - // An inbox where messages to the agent from the user and messages from the agent are stored, - // enabling each job to have a classical chat/conversation UI + /// An inbox where messages to the agent from the user and messages from the agent are stored, + /// enabling each job to have a classical chat/conversation UI pub conversation_inbox_name: InboxName, - // The job's step history (an ordered list of all prompts/outputs from LLM inferencing when processing steps) + /// The job's step history (an ordered list of all prompts/outputs from LLM inferencing when processing steps) pub step_history: Vec, - // An ordered list of the latest messages sent to the job which are yet to be processed + /// An ordered list of the latest messages sent to the job which are yet to be processed pub unprocessed_messages: Vec, + // /// A hashmap which holds a bunch of labeled values which were generated as output from the latest Job step + // pub execution_context: HashMap, } impl JobLike for Job { diff --git a/src/db/db.rs b/src/db/db.rs index a9bb6d017..4db147e8b 100644 --- a/src/db/db.rs +++ b/src/db/db.rs @@ -257,6 +257,10 @@ impl ShinkaiDB { self.delete_cf(cf, new_key) } + /// Fetches the ColumnFamily handle. + pub fn cf_handle(&self, name: &str) -> Result<&ColumnFamily, ShinkaiDBError> { + self.db.cf_handle(name).ok_or(ShinkaiDBError::FailedFetchingCF) + } /// Saves the WriteBatch to the database pub fn write(&self, batch: WriteBatch) -> Result<(), ShinkaiDBError> { Ok(self.db.write(batch)?) 
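Note on the db_jobs.rs hunk that follows: it reserves a per-job "{job_id}_execution_context" column family at job creation time, and a later patch in this series writes a bincode-encoded HashMap of labeled string values into it and adds set/get accessors. Below is a stand-alone sketch of just that encode/decode round trip using the bincode 1.x API; the RocksDB plumbing is omitted and the example key/value pair is made up for illustration.

use std::collections::HashMap;

fn main() -> Result<(), Box<bincode::ErrorKind>> {
    // The context starts out empty at job creation and is updated after each step.
    let mut context: HashMap<String, String> = HashMap::new();
    context.insert("search_results".to_string(), "summarized output".to_string());

    // Encode before writing the value under the job's execution-context column family.
    let bytes: Vec<u8> = bincode::serialize(&context)?;

    // Decode after reading the raw bytes back out of that column family.
    let restored: HashMap<String, String> = bincode::deserialize(&bytes)?;
    assert_eq!(
        restored.get("search_results").map(String::as_str),
        Some("summarized output")
    );
    Ok(())
}

Serializing the whole map as a single value keeps the per-job read and write path simple, at the cost of rewriting the full context on each update.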
diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index 9101188a8..b69b172e3 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -50,6 +50,7 @@ impl ShinkaiDB { let cf_job_id_perms_name = format!("job_inbox::{}::false_perms", &job_id); let cf_job_id_unread_list_name = format!("job_inbox::{}::false_unread_list", &job_id); let cf_job_id_unprocessed_messages_name = format!("{}_unprocessed_messages", &job_id); + let cf_job_id_execution_context_name = format!("{}_execution_context", &job_id); // Check that the cf handles exist, and create them if self.db.cf_handle(&cf_job_id_scope_name).is_some() @@ -59,6 +60,7 @@ impl ShinkaiDB { || self.db.cf_handle(&cf_job_id_perms_name).is_some() || self.db.cf_handle(&cf_job_id_unread_list_name).is_some() || self.db.cf_handle(&cf_job_id_unprocessed_messages_name).is_some() + || self.db.cf_handle(&cf_job_id_execution_context_name).is_some() { return Err(ShinkaiDBError::JobAlreadyExists(cf_job_id_name.to_string())); } @@ -73,6 +75,7 @@ impl ShinkaiDB { self.db.create_cf(&cf_job_id_perms_name, &cf_opts)?; self.db.create_cf(&cf_job_id_unread_list_name, &cf_opts)?; self.db.create_cf(&cf_job_id_unprocessed_messages_name, &cf_opts)?; + self.db.create_cf(&cf_job_id_execution_context_name, &cf_opts)?; // Start a write batch let mut batch = WriteBatch::default(); @@ -117,6 +120,9 @@ impl ShinkaiDB { .expect("to be able to access Topic::Inbox"); batch.put_cf(cf_inbox, &cf_conversation_inbox_name, &cf_conversation_inbox_name); + // Save an empty hashmap for the initial execution context + let cf_job_id_execution_context = self.cf_handle(&cf_job_id_execution_context_name)?; + self.db.write(batch)?; Ok(()) From 42f95cc5e58332ca417e5447f725c5d7ff75541c Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 20:36:37 +0200 Subject: [PATCH 22/26] Cleanup of types to ShinkaiDBError + cf_handle use --- src/db/db_agents.rs | 24 ++++++------ src/db/db_identity.rs | 62 +++++++++++++++--------------- src/db/db_identity_registration.rs | 18 ++++----- src/db/db_inbox.rs | 14 +++---- src/db/db_jobs.rs | 6 +-- src/managers/identity_manager.rs | 3 +- 6 files changed, 64 insertions(+), 63 deletions(-) diff --git a/src/db/db_agents.rs b/src/db/db_agents.rs index 7069128db..32de7e57c 100644 --- a/src/db/db_agents.rs +++ b/src/db/db_agents.rs @@ -5,8 +5,8 @@ use shinkai_message_primitives::schemas::{agents::serialized_agent::SerializedAg impl ShinkaiDB { // Fetches all agents from the Agents topic - pub fn get_all_agents(&self) -> Result, Error> { - let cf = self.db.cf_handle(Topic::Agents.as_str()).unwrap(); + pub fn get_all_agents(&self) -> Result, ShinkaiDBError> { + let cf = self.cf_handle(Topic::Agents.as_str())?; let mut result = Vec::new(); let iter = self.db.iterator_cf(cf, rocksdb::IteratorMode::Start); @@ -16,7 +16,7 @@ impl ShinkaiDB { let agent: SerializedAgent = from_slice(value.as_ref()).unwrap(); result.push(agent); } - Err(e) => return Err(e), + Err(e) => return Err(ShinkaiDBError::RocksDBError(e)), } } @@ -41,8 +41,8 @@ impl ShinkaiDB { let mut batch = rocksdb::WriteBatch::default(); // Get handles to the newly created column families - let cf_profiles_access = self.db.cf_handle(&cf_name_profiles_access).unwrap(); - let cf_toolkits_accessible = self.db.cf_handle(&cf_name_toolkits_accessible).unwrap(); + let cf_profiles_access = self.cf_handle(&cf_name_profiles_access)?; + let cf_toolkits_accessible = self.cf_handle(&cf_name_toolkits_accessible)?; // Write profiles_with_access and toolkits_accessible to 
respective columns for profile in &agent.allowed_message_senders { @@ -54,7 +54,7 @@ impl ShinkaiDB { // Serialize the agent to bytes and write it to the Agents topic let bytes = to_vec(&agent).unwrap(); - let cf_agents = self.db.cf_handle(Topic::Agents.as_str()).unwrap(); + let cf_agents = self.cf_handle(Topic::Agents.as_str())?; batch.put_cf(cf_agents, agent.id.as_bytes(), &bytes); // Write the batch @@ -69,7 +69,7 @@ impl ShinkaiDB { let cf_name_toolkits_accessible = format!("agent_{}_toolkits_accessible", agent_id); // Get cf handle for Agents topic - let cf_agents = self.db.cf_handle(Topic::Agents.as_str()).unwrap(); + let cf_agents = self.cf_handle(Topic::Agents.as_str())?; // Start write batch for atomic operation let mut batch = rocksdb::WriteBatch::default(); @@ -139,7 +139,7 @@ impl ShinkaiDB { pub fn get_agent(&self, agent_id: &str) -> Result, ShinkaiDBError> { // Get cf handle for Agents topic - let cf_agents = self.db.cf_handle(Topic::Agents.as_str()).unwrap(); + let cf_agents = self.cf_handle(Topic::Agents.as_str())?; // Fetch the agent's bytes by their id from the Agents topic let agent_bytes = self.db.get_cf(cf_agents, agent_id.as_bytes())?; @@ -236,12 +236,12 @@ impl ShinkaiDB { ))?; let all_agents = self.get_all_agents()?; let mut result = Vec::new(); - + for agent in all_agents { let cf_name = format!("agent_{}_profiles_with_access", agent.id); - let cf = self.db.cf_handle(&cf_name).unwrap(); + let cf = self.cf_handle(&cf_name)?; let profiles = self.get_column_family_data(cf)?; - + if profiles.contains(&profile) { result.push(agent); } else { @@ -251,7 +251,7 @@ impl ShinkaiDB { } } } - + Ok(result) } } diff --git a/src/db/db_identity.rs b/src/db/db_identity.rs index b84e0f399..087640049 100644 --- a/src/db/db_identity.rs +++ b/src/db/db_identity.rs @@ -17,8 +17,8 @@ use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionS impl ShinkaiDB { pub fn get_encryption_public_key(&self, identity_public_key: &str) -> Result { - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); - let cf_encryption = self.db.cf_handle(Topic::ProfilesEncryptionKey.as_str()).unwrap(); + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; + let cf_encryption = self.cf_handle(Topic::ProfilesEncryptionKey.as_str())?; // Get the associated profile name for the identity public key let profile_name = match self.db.get_cf(cf_identity, identity_public_key)? 
{ @@ -37,14 +37,14 @@ impl ShinkaiDB { let my_node_identity_name = my_node_identity.get_node_name(); println!("my_node_identity_name: {}", my_node_identity_name); - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); - let cf_encryption = self.db.cf_handle(Topic::ProfilesEncryptionKey.as_str()).unwrap(); - let cf_type = self.db.cf_handle(Topic::ProfilesIdentityType.as_str()).unwrap(); - let cf_permission = self.db.cf_handle(Topic::ProfilesPermission.as_str()).unwrap(); // Added this line + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; + let cf_encryption = self.cf_handle(Topic::ProfilesEncryptionKey.as_str())?; + let cf_type = self.cf_handle(Topic::ProfilesIdentityType.as_str())?; + let cf_permission = self.cf_handle(Topic::ProfilesPermission.as_str())?; // Added this line // Handle node related information - let cf_node_encryption = self.db.cf_handle(Topic::ExternalNodeEncryptionKey.as_str()).unwrap(); - let cf_node_identity = self.db.cf_handle(Topic::ExternalNodeIdentityKey.as_str()).unwrap(); + let cf_node_encryption = self.cf_handle(Topic::ExternalNodeEncryptionKey.as_str())?; + let cf_node_identity = self.cf_handle(Topic::ExternalNodeIdentityKey.as_str())?; let node_encryption_public_key = match self.db.get_cf(cf_node_encryption, &my_node_identity_name)? { Some(value) => { @@ -143,11 +143,11 @@ impl ShinkaiDB { pub fn get_all_profiles_and_devices(&self, my_node_identity: ShinkaiName) -> Result, ShinkaiDBError> { let my_node_identity_name = my_node_identity.get_node_name(); - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); - let cf_encryption = self.db.cf_handle(Topic::ProfilesEncryptionKey.as_str()).unwrap(); - let cf_type = self.db.cf_handle(Topic::ProfilesIdentityType.as_str()).unwrap(); - let cf_permission = self.db.cf_handle(Topic::ProfilesPermission.as_str()).unwrap(); - let cf_device = self.db.cf_handle(Topic::DevicesIdentityKey.as_str()).unwrap(); + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; + let cf_encryption = self.cf_handle(Topic::ProfilesEncryptionKey.as_str())?; + let cf_type = self.cf_handle(Topic::ProfilesIdentityType.as_str())?; + let cf_permission = self.cf_handle(Topic::ProfilesPermission.as_str())?; + let cf_device = self.cf_handle(Topic::DevicesIdentityKey.as_str())?; let (node_encryption_public_key, node_signature_public_key) = self.get_local_node_keys(my_node_identity.clone())?; @@ -213,10 +213,10 @@ impl ShinkaiDB { identity.full_identity_name.to_string(), ))?; - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); - let cf_encryption = self.db.cf_handle(Topic::ProfilesEncryptionKey.as_str()).unwrap(); - let cf_identity_type = self.db.cf_handle(Topic::ProfilesIdentityType.as_str()).unwrap(); - let cf_permission_type = self.db.cf_handle(Topic::ProfilesPermission.as_str()).unwrap(); + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; + let cf_encryption = self.cf_handle(Topic::ProfilesEncryptionKey.as_str())?; + let cf_identity_type = self.cf_handle(Topic::ProfilesIdentityType.as_str())?; + let cf_permission_type = self.cf_handle(Topic::ProfilesPermission.as_str())?; // Check that the full identity name doesn't exist in the columns if self.db.get_cf(cf_identity, &profile_name)?.is_some() @@ -273,7 +273,7 @@ impl ShinkaiDB { .clone() .get_profile_name() .ok_or(ShinkaiDBError::InvalidIdentityName(profile_name.to_string()))?; - let cf_permission = 
self.db.cf_handle(Topic::ProfilesPermission.as_str()).unwrap(); + let cf_permission = self.cf_handle(Topic::ProfilesPermission.as_str())?; match self.db.get_cf(cf_permission, profile_name.clone())? { Some(value) => { let permission_str = std::str::from_utf8(&value).map_err(|_| { @@ -296,7 +296,7 @@ impl ShinkaiDB { let device_name = device_name.to_string(); // Get a handle to the devices' permissions column family - let cf_permission = self.db.cf_handle(Topic::DevicesPermissions.as_str()).unwrap(); + let cf_permission = self.cf_handle(Topic::DevicesPermissions.as_str())?; // Attempt to get the permission value for the device name match self.db.get_cf(cf_permission, device_name.clone())? { @@ -395,14 +395,14 @@ impl ShinkaiDB { }; // First, make sure that the profile the device is to be linked with exists - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; if self.db.get_cf(cf_identity, profile_name.clone())?.is_none() { return Err(ShinkaiDBError::ProfileNotFound(profile_name.to_string())); } // Get a handle to the device column family - let cf_device_identity = self.db.cf_handle(Topic::DevicesIdentityKey.as_str()).unwrap(); - let cf_device_encryption = self.db.cf_handle(Topic::DevicesEncryptionKey.as_str()).unwrap(); + let cf_device_identity = self.cf_handle(Topic::DevicesIdentityKey.as_str())?; + let cf_device_encryption = self.cf_handle(Topic::DevicesEncryptionKey.as_str())?; // Check that the full device identity name doesn't already exist in the column if self @@ -435,7 +435,7 @@ impl ShinkaiDB { ); // Handle for DevicePermissions column family - let cf_device_permissions = self.db.cf_handle(Topic::DevicesPermissions.as_str()).unwrap(); + let cf_device_permissions = self.cf_handle(Topic::DevicesPermissions.as_str())?; // Convert device.permission_type to a suitable format (e.g., string) for storage let permission_str = device.permission_type.to_string(); @@ -454,9 +454,9 @@ impl ShinkaiDB { } pub fn remove_profile(&self, name: &str) -> Result<(), ShinkaiDBError> { - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); - let cf_encryption = self.db.cf_handle(Topic::ProfilesEncryptionKey.as_str()).unwrap(); - let cf_permission = self.db.cf_handle(Topic::ProfilesIdentityType.as_str()).unwrap(); + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; + let cf_encryption = self.cf_handle(Topic::ProfilesEncryptionKey.as_str())?; + let cf_permission = self.cf_handle(Topic::ProfilesIdentityType.as_str())?; // Check that the profile name exists in ProfilesIdentityKey, ProfilesEncryptionKey and ProfilesIdentityType if self.db.get_cf(cf_identity, name)?.is_none() @@ -616,8 +616,8 @@ impl ShinkaiDB { .get_profile_name() .ok_or(ShinkaiDBError::InvalidIdentityName(full_identity_name.to_string()))?; - let cf_encryption = self.db.cf_handle(Topic::ProfilesEncryptionKey.as_str()).unwrap(); - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); + let cf_encryption = self.cf_handle(Topic::ProfilesEncryptionKey.as_str())?; + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; let profile_encryption_public_key_bytes = self .db @@ -656,7 +656,7 @@ impl ShinkaiDB { &self, full_identity_name: &str, ) -> Result { - let cf_encryption = self.db.cf_handle(Topic::ProfilesEncryptionKey.as_str()).unwrap(); + let cf_encryption = self.cf_handle(Topic::ProfilesEncryptionKey.as_str())?; match 
self.db.get_cf(cf_encryption, full_identity_name)? { Some(value) => { let key_string = String::from_utf8(value.to_vec()).map_err(|_| ShinkaiDBError::Utf8ConversionError)?; @@ -667,7 +667,7 @@ impl ShinkaiDB { } pub fn get_identity_type(&self, full_identity_name: &str) -> Result { - let cf_type = self.db.cf_handle(Topic::ProfilesIdentityType.as_str()).unwrap(); + let cf_type = self.cf_handle(Topic::ProfilesIdentityType.as_str())?; match self.db.get_cf(cf_type, full_identity_name)? { Some(value) => { let identity_type_str = String::from_utf8(value.to_vec()).unwrap(); @@ -681,7 +681,7 @@ impl ShinkaiDB { } pub fn get_permissions(&self, full_identity_name: &str) -> Result { - let cf_permission = self.db.cf_handle(Topic::ProfilesPermission.as_str()).unwrap(); + let cf_permission = self.cf_handle(Topic::ProfilesPermission.as_str())?; match self.db.get_cf(cf_permission, full_identity_name)? { Some(value) => { let permissions_str = String::from_utf8(value.to_vec()).unwrap(); diff --git a/src/db/db_identity_registration.rs b/src/db/db_identity_registration.rs index 85704a8b1..5b5830a70 100644 --- a/src/db/db_identity_registration.rs +++ b/src/db/db_identity_registration.rs @@ -109,13 +109,13 @@ impl ShinkaiDB { &self, permissions: IdentityPermissions, code_type: RegistrationCodeType, - ) -> Result { + ) -> Result { let mut rng = rand::thread_rng(); let mut random_bytes = [0u8; 64]; rng.fill_bytes(&mut random_bytes); let new_code = hex::encode(random_bytes); - let cf = self.db.cf_handle(Topic::OneTimeRegistrationCodes.as_str()).unwrap(); + let cf = self.cf_handle(Topic::OneTimeRegistrationCodes.as_str())?; let code_info = RegistrationCodeInfo { status: RegistrationCodeStatus::Unused, @@ -139,7 +139,7 @@ impl ShinkaiDB { device_encryption_public_key: Option<&str>, ) -> Result<(), ShinkaiDBError> { // Check if the code exists in Topic::OneTimeRegistrationCodes and its value is unused - let cf_codes = self.db.cf_handle(Topic::OneTimeRegistrationCodes.as_str()).unwrap(); + let cf_codes = self.cf_handle(Topic::OneTimeRegistrationCodes.as_str())?; let code_info: RegistrationCodeInfo = match self.db.get_cf(cf_codes, registration_code)? { Some(value) => RegistrationCodeInfo::from_slice(&value), None => return Err(ShinkaiDBError::CodeNonExistent), @@ -343,7 +343,7 @@ impl ShinkaiDB { } pub fn get_registration_code_info(&self, registration_code: &str) -> Result { - let cf_codes = self.db.cf_handle(Topic::OneTimeRegistrationCodes.as_str()).unwrap(); + let cf_codes = self.cf_handle(Topic::OneTimeRegistrationCodes.as_str())?; match self.db.get_cf(cf_codes, registration_code)? 
{ Some(value) => Ok(RegistrationCodeInfo::from_slice(&value)), None => Err(ShinkaiDBError::CodeNonExistent), @@ -351,7 +351,7 @@ impl ShinkaiDB { } pub fn check_profile_existence(&self, profile_name: &str) -> Result<(), ShinkaiDBError> { - let cf_identity = self.db.cf_handle(Topic::ProfilesIdentityKey.as_str()).unwrap(); + let cf_identity = self.cf_handle(Topic::ProfilesIdentityKey.as_str())?; if self.db.get_cf(cf_identity, profile_name)?.is_none() { return Err(ShinkaiDBError::ProfileNotFound(profile_name.to_string())); @@ -368,8 +368,8 @@ impl ShinkaiDB { ) -> Result<(), ShinkaiDBError> { let node_name = my_node_identity_name.get_node_name().to_string(); - let cf_node_encryption = self.db.cf_handle(Topic::ExternalNodeEncryptionKey.as_str()).unwrap(); - let cf_node_identity = self.db.cf_handle(Topic::ExternalNodeIdentityKey.as_str()).unwrap(); + let cf_node_encryption = self.cf_handle(Topic::ExternalNodeEncryptionKey.as_str())?; + let cf_node_identity = self.cf_handle(Topic::ExternalNodeIdentityKey.as_str())?; let mut batch = rocksdb::WriteBatch::default(); @@ -391,8 +391,8 @@ impl ShinkaiDB { ) -> Result<(EncryptionPublicKey, SignaturePublicKey), ShinkaiDBError> { let node_name = my_node_identity_name.get_node_name().to_string(); - let cf_node_encryption = self.db.cf_handle(Topic::ExternalNodeEncryptionKey.as_str()).unwrap(); - let cf_node_identity = self.db.cf_handle(Topic::ExternalNodeIdentityKey.as_str()).unwrap(); + let cf_node_encryption = self.cf_handle(Topic::ExternalNodeEncryptionKey.as_str())?; + let cf_node_identity = self.cf_handle(Topic::ExternalNodeIdentityKey.as_str())?; // Get the encryption key let encryption_pk_string = self diff --git a/src/db/db_inbox.rs b/src/db/db_inbox.rs index 9959324e1..37a89cda0 100644 --- a/src/db/db_inbox.rs +++ b/src/db/db_inbox.rs @@ -146,7 +146,7 @@ impl ShinkaiDB { }; // Fetch the column family for all messages - let messages_cf = self.db.cf_handle(Topic::AllMessages.as_str()).unwrap(); + let messages_cf = self.cf_handle(Topic::AllMessages.as_str())?; // Create an iterator for the specified inbox let mut iter = match &until_offset_key { @@ -201,10 +201,10 @@ impl ShinkaiDB { ))) } }; - + // Create an iterator for the specified unread_list, starting from the beginning let iter = self.db.iterator_cf(unread_list_cf, rocksdb::IteratorMode::Start); - + // Iterate through the unread_list and delete all messages up to the specified offset for item in iter { // Handle the Result returned by the iterator @@ -214,7 +214,7 @@ impl ShinkaiDB { Ok(s) => s, Err(_) => return Err(ShinkaiDBError::SomeError("UTF-8 conversion error".to_string())), }; - + if key_str <= up_to_offset { // Delete the message from the unread_list self.db.delete_cf(unread_list_cf, key)?; @@ -226,7 +226,7 @@ impl ShinkaiDB { Err(e) => return Err(e.into()), } } - + Ok(()) } @@ -249,7 +249,7 @@ impl ShinkaiDB { }; // Fetch the column family for all messages - let messages_cf = self.db.cf_handle(Topic::AllMessages.as_str()).unwrap(); + let messages_cf = self.cf_handle(Topic::AllMessages.as_str())?; // Create an iterator for the specified unread_list let mut iter = match &from_offset_key { @@ -270,7 +270,7 @@ impl ShinkaiDB { } None => None, }; - + let mut messages = Vec::new(); let mut first_message = true; for item in iter.take(n) { diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index b69b172e3..9a011791a 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -84,9 +84,9 @@ impl ShinkaiDB { let current_time = ShinkaiTime::generate_time_now(); let scope_bytes = scope.to_bytes()?; 
- let cf_job_id = self.db.cf_handle(&cf_job_id_name).unwrap(); - let cf_agent_id = self.db.cf_handle(&cf_agent_id_name).unwrap(); - let cf_job_id_scope = self.db.cf_handle(&cf_job_id_scope_name).unwrap(); + let cf_job_id = self.cf_handle(&cf_job_id_name)?; + let cf_agent_id = self.cf_handle(&cf_agent_id_name)?; + let cf_job_id_scope = self.cf_handle(&cf_job_id_scope_name)?; batch.put_cf(cf_agent_id, current_time.as_bytes(), job_id.as_bytes()); batch.put_cf(cf_job_id_scope, job_id.as_bytes(), &scope_bytes); diff --git a/src/managers/identity_manager.rs b/src/managers/identity_manager.rs index 087116ce2..7e05fadd3 100644 --- a/src/managers/identity_manager.rs +++ b/src/managers/identity_manager.rs @@ -1,4 +1,5 @@ use super::identity_network_manager::IdentityNetworkManager; +use crate::db::db_errors::ShinkaiDBError; use crate::db::ShinkaiDB; use crate::network::node_error::NodeError; use crate::network::node_message_handlers::verify_message_signature; @@ -200,7 +201,7 @@ impl IdentityManager { self.local_identities.clone() } - pub async fn get_all_agents(&self) -> Result, rocksdb::Error> { + pub async fn get_all_agents(&self) -> Result, ShinkaiDBError> { let db = self.db.lock().await; db.get_all_agents() } From be3b0cfb13ffbfa8773e565f5d2ca6eacb6500af Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 20:55:31 +0200 Subject: [PATCH 23/26] Added execution context functionality to db struct --- src/agent/job.rs | 2 +- src/db/db_jobs.rs | 65 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/agent/job.rs b/src/agent/job.rs index 6271c104c..37e7c376c 100644 --- a/src/agent/job.rs +++ b/src/agent/job.rs @@ -34,7 +34,7 @@ pub struct Job { /// An ordered list of the latest messages sent to the job which are yet to be processed pub unprocessed_messages: Vec, // /// A hashmap which holds a bunch of labeled values which were generated as output from the latest Job step - // pub execution_context: HashMap, + pub execution_context: HashMap, } impl JobLike for Job { diff --git a/src/db/db_jobs.rs b/src/db/db_jobs.rs index 9a011791a..745fdcf03 100644 --- a/src/db/db_jobs.rs +++ b/src/db/db_jobs.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; +use std::hash::Hash; + use super::{db::Topic, db_errors::ShinkaiDBError, ShinkaiDB}; use crate::agent::job::{Job, JobLike}; use rocksdb::{IteratorMode, Options, WriteBatch}; @@ -122,6 +125,13 @@ impl ShinkaiDB { // Save an empty hashmap for the initial execution context let cf_job_id_execution_context = self.cf_handle(&cf_job_id_execution_context_name)?; + let empty_hashmap: HashMap = HashMap::new(); + let hashmap_bytes = bincode::serialize(&empty_hashmap).unwrap(); + batch.put_cf( + cf_job_id_execution_context, + &cf_job_id_execution_context_name, + &hashmap_bytes, + ); self.db.write(batch)?; @@ -138,6 +148,7 @@ impl ShinkaiDB { conversation_inbox, step_history, unprocessed_messages, + execution_context, ) = self.get_job_data(job_id, true)?; // Construct the job @@ -150,6 +161,7 @@ impl ShinkaiDB { conversation_inbox_name: conversation_inbox, step_history: step_history.unwrap_or_else(Vec::new), unprocessed_messages, + execution_context, }; Ok(job) @@ -157,8 +169,16 @@ impl ShinkaiDB { /// Fetches a job from the DB as a Box pub fn get_job_like(&self, job_id: &str) -> Result, ShinkaiDBError> { - let (scope, is_finished, datetime_created, parent_agent_id, conversation_inbox, _, unprocessed_messages) = - self.get_job_data(job_id, false)?; + 
let ( + scope, + is_finished, + datetime_created, + parent_agent_id, + conversation_inbox, + _, + unprocessed_messages, + execution_context, + ) = self.get_job_data(job_id, false)?; // Construct the job let job = Job { @@ -170,6 +190,7 @@ impl ShinkaiDB { conversation_inbox_name: conversation_inbox, step_history: Vec::new(), // Empty step history for JobLike unprocessed_messages, + execution_context, }; Ok(Box::new(job)) @@ -189,6 +210,7 @@ impl ShinkaiDB { InboxName, Option>, Vec, + HashMap, ), ShinkaiDBError, > { @@ -267,6 +289,7 @@ impl ShinkaiDB { conversation_inbox.unwrap(), step_history, unprocessed_messages, + self.get_job_execution_context(job_id)?, )) } @@ -306,6 +329,44 @@ impl ShinkaiDB { Ok(jobs) } + /// Sets/updates the execution context for a Job in the DB + pub fn set_job_execution_context( + &self, + job_id: &str, + context: HashMap, + ) -> Result<(), ShinkaiDBError> { + let cf_job_id_execution_context_name = format!("{}_execution_context", job_id); + let cf_job_id_execution_context = self.cf_handle(&cf_job_id_execution_context_name)?; + + let context_bytes = bincode::serialize(&context).map_err(|_| { + ShinkaiDBError::SomeError("Failed converting execution context hashmap to bytes".to_string()) + })?; + self.put_cf( + cf_job_id_execution_context, + &cf_job_id_execution_context_name, + &context_bytes, + )?; + + Ok(()) + } + + /// Gets the execution context for a job + pub fn get_job_execution_context(&self, job_id: &str) -> Result, ShinkaiDBError> { + let cf_job_id_execution_context_name = format!("{}_execution_context", job_id); + let cf_job_id_execution_context = self.cf_handle(&cf_job_id_execution_context_name)?; + + let context_bytes = self + .db + .get_cf(cf_job_id_execution_context, &cf_job_id_execution_context_name)? + .ok_or(ShinkaiDBError::DataNotFound)?; + + let context = bincode::deserialize(&context_bytes).map_err(|_| { + ShinkaiDBError::SomeError("Failed converting execution context bytes to hashmap".to_string()) + })?; + + Ok(context) + } + /// Fetches all unprocessed messages for a specific Job from the DB fn get_unprocessed_messages(&self, job_id: &str) -> Result, ShinkaiDBError> { // Get the iterator From 2b05709c4e7ff68921c6146a86c68a7c188d96c7 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 22:30:52 +0200 Subject: [PATCH 24/26] Reworked inferencing to be JsonValue based --- src/agent/agent.rs | 202 ++---------------- src/agent/error.rs | 14 ++ src/agent/mod.rs | 2 +- .../{execution_steps.rs => plan_executor.rs} | 2 +- src/agent/providers/mod.rs | 49 ++++- src/agent/providers/openai.rs | 42 +--- src/agent/providers/sleep_api.rs | 26 +-- src/managers/job_manager.rs | 17 +- tests/db_agents_tests.rs | 122 ++++++++++- 9 files changed, 212 insertions(+), 264 deletions(-) rename src/agent/{execution_steps.rs => plan_executor.rs} (98%) diff --git a/src/agent/agent.rs b/src/agent/agent.rs index b9180cc5d..46ed82af8 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -1,6 +1,7 @@ use super::error::AgentError; use super::providers::LLMProvider; use reqwest::Client; +use serde_json::{Map, Value as JsonValue}; use shinkai_message_primitives::{ schemas::{ agents::serialized_agent::{AgentAPIModel, SerializedAgent}, @@ -59,87 +60,51 @@ impl Agent { } } - pub async fn call_external_api( - &self, - content: &str, - context: Vec, - ) -> Result, AgentError> { + pub async fn call_external_api(&self, content: &str) -> Result { match &self.model { AgentAPIModel::OpenAI(openai) => { openai - .call_api( - 
&self.client, - self.external_url.as_ref(), - self.api_key.as_ref(), - content, - context, - ) + .call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), content) .await } AgentAPIModel::Sleep(sleep_api) => { sleep_api - .call_api( - &self.client, - self.external_url.as_ref(), - self.api_key.as_ref(), - content, - context, - ) + .call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), content) .await } } } - pub async fn process_locally(&self, content: String, context: Vec, job_id: String) { + /// TODO: Probably just throw this away, and move this logic into a LocalLLM struct that implements the Provider trait + pub async fn inference_locally(&self, content: String) -> Result { // Here we run our GPU-intensive task on a separate thread let handle = tokio::task::spawn_blocking(move || { - // perform GPU-intensive work - vec![JobPreMessage { - tool_calls: Vec::new(), // You might want to replace this with actual values - content: "Updated response!".to_string(), - recipient: JobRecipient::SelfNode, // This is a placeholder. You should replace this with the actual recipient. - }] + let mut map = Map::new(); + map.insert( + "answer".to_string(), + JsonValue::String("\n\nHello there, how may I assist you today?".to_string()), + ); + JsonValue::Object(map) }); - let result = handle.await; - match result { - Ok(response) => { - // create ShinkaiMessage based on result and send to AgentManager - let _ = self.job_manager_sender.send((response, job_id)).await; - } - Err(e) => eprintln!("Error in processing message: {:?}", e), + match handle.await { + Ok(response) => Ok(response), + Err(e) => Err(AgentError::FailedInferencingLocalLLM), } } - pub async fn inference(&self, content: String, context: Vec, job_id: String) { + /// Inferences the LLM model tied to the agent to get a response back. + /// Note, all `content` is expected to use prompts from the PromptGenerator, + /// meaning that they tell/force the LLM to always respond in JSON. We automatically + /// parse the JSON object out of the response into a JsonValue, or error if no object is found. 
+ pub async fn inference(&self, content: String) -> Result { if self.perform_locally { // No need to spawn a new task here - self.process_locally(content.clone(), context.clone(), job_id).await; + return self.inference_locally(content.clone()).await; } else { // Call external API - let response = self.call_external_api(&content.clone(), context.clone()).await; - match response { - Ok(message) => { - // Send the message to AgentManager - println!( - "Sending message to AgentManager {:?} with context: {:?}", - message, context - ); - match self.job_manager_sender.send((message, job_id.clone())).await { - Ok(_) => println!("Message sent successfully"), - Err(e) => eprintln!("Error when sending message: {}", e), - } - } - Err(e) => eprintln!("Error when calling API: {}", e), - } + return self.call_external_api(&content.clone()).await; } - // TODO: For debugging - // // Check if the sender is still connected to the channel - // if self.job_manager_sender.is_closed() { - // eprintln!("Sender is closed"); - // } else { - // println!("Sender is still connected"); - // } } } @@ -162,126 +127,3 @@ impl Agent { ) } } - -#[cfg(test)] -mod tests { - use super::*; - use mockito::Server; - use shinkai_message_primitives::schemas::agents::serialized_agent::{OpenAI, SleepAPI}; - use tokio::sync::mpsc; - - #[tokio::test] - async fn test_agent_creation() { - let (tx, mut rx) = mpsc::channel(1); - let sleep_api = SleepAPI {}; - let agent = Agent::new( - "1".to_string(), - ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), - tx, - false, - Some("http://localhost:8000".to_string()), - Some("paramparam".to_string()), - AgentAPIModel::Sleep(sleep_api), - vec!["tk1".to_string(), "tk2".to_string()], - vec!["sb1".to_string(), "sb2".to_string()], - vec!["allowed1".to_string(), "allowed2".to_string()], - ); - let context = vec![String::from("context1"), String::from("context2")]; - - assert_eq!(agent.id, "1"); - assert_eq!( - agent.full_identity_name, - ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap() - ); - assert_eq!(agent.perform_locally, false); - assert_eq!(agent.external_url, Some("http://localhost:8000".to_string())); - assert_eq!(agent.toolkit_permissions, vec!["tk1".to_string(), "tk2".to_string()]); - assert_eq!( - agent.storage_bucket_permissions, - vec!["sb1".to_string(), "sb2".to_string()] - ); - assert_eq!( - agent.allowed_message_senders, - vec!["allowed1".to_string(), "allowed2".to_string()] - ); - - tokio::spawn(async move { - agent - .inference("Test".to_string(), context, "some_job_1".to_string()) - .await; - }); - - let val = tokio::time::timeout(std::time::Duration::from_millis(600), rx.recv()).await; - let expected_resp = JobPreMessage { - tool_calls: Vec::new(), - content: "OK".to_string(), - recipient: JobRecipient::SelfNode, - }; - - match val { - Ok(Some(response)) => assert_eq!(response.0.first().unwrap(), &expected_resp), - Ok(None) => panic!("Channel is empty"), - Err(_) => panic!("Timeout exceeded"), - } - } - - #[tokio::test] - async fn test_agent_call_external_api_openai() { - let mut server = Server::new(); - let _m = server - .mock("POST", "/v1/chat/completions") - .match_header("authorization", "Bearer mockapikey") - .with_status(200) - .with_header("content-type", "application/json") - .with_body( - r#"{ - "id": "chatcmpl-123", - "object": "chat.completion", - "created": 1677652288, - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "\n\nHello there, how may I assist you 
today?" - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 9, - "completion_tokens": 12, - "total_tokens": 21 - } - }"#, - ) - .create(); - - let context = vec![String::from("context1"), String::from("context2")]; - let (tx, _rx) = mpsc::channel(1); - let openai = OpenAI { - model_type: "gpt-3.5-turbo".to_string(), - }; - let agent = Agent::new( - "1".to_string(), - ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), - tx, - false, - Some(server.url()), // use the url of the mock server - Some("mockapikey".to_string()), - AgentAPIModel::OpenAI(openai), - vec!["tk1".to_string(), "tk2".to_string()], - vec!["sb1".to_string(), "sb2".to_string()], - vec!["allowed1".to_string(), "allowed2".to_string()], - ); - - let response = agent.call_external_api("Hello!", context).await; - let expected_resp = JobPreMessage { - tool_calls: Vec::new(), - content: "\n\nHello there, how may I assist you today?".to_string(), - recipient: JobRecipient::SelfNode, - }; - match response { - Ok(res) => assert_eq!(res.first().unwrap(), &expected_resp), - Err(e) => panic!("Error when calling API: {}", e), - } - } -} diff --git a/src/agent/error.rs b/src/agent/error.rs index 093b2f863..b8f570588 100644 --- a/src/agent/error.rs +++ b/src/agent/error.rs @@ -5,6 +5,8 @@ pub enum AgentError { ApiKeyNotSet, ReqwestError(reqwest::Error), MissingInitialStepInExecutionPlan, + FailedExtractingJSONObjectFromResponse(String), + FailedInferencingLocalLLM, } impl fmt::Display for AgentError { @@ -16,7 +18,13 @@ impl fmt::Display for AgentError { f, "The provided execution plan does not have an InitialExecutionStep as its first element." ), + AgentError::FailedExtractingJSONObjectFromResponse(s) => { + write!(f, "Could not find JSON Object in the LLM's response: {}", s) + } AgentError::ReqwestError(err) => write!(f, "Reqwest error: {}", err), + AgentError::FailedInferencingLocalLLM => { + write!(f, "Failed inferencing and getting a valid response from the local LLM") + } } } } @@ -30,6 +38,12 @@ impl fmt::Debug for AgentError { AgentError::MissingInitialStepInExecutionPlan => { f.debug_tuple("MissingInitialStepInExecutionPlan").finish() } + AgentError::FailedExtractingJSONObjectFromResponse(err) => f + .debug_tuple("FailedExtractingJSONObjectFromResponse") + .field(err) + .finish(), + + AgentError::FailedInferencingLocalLLM => f.debug_tuple("FailedInferencingLocalLLM").finish(), } } } diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 396fb8ba2..8108c4496 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,7 +1,7 @@ pub mod agent; pub mod agent_to_serialization; pub mod error; -pub mod execution_steps; pub mod job; pub mod job_prompts; +pub mod plan_executor; pub mod providers; diff --git a/src/agent/execution_steps.rs b/src/agent/plan_executor.rs similarity index 98% rename from src/agent/execution_steps.rs rename to src/agent/plan_executor.rs index 493beb8e9..94f163d24 100644 --- a/src/agent/execution_steps.rs +++ b/src/agent/plan_executor.rs @@ -43,7 +43,7 @@ impl<'a> PlanExecutor<'a> { // TODO: Properly implement this once we have jobs update for context + agent infernece/use tool /// Executes the plan step-by-step, performing all inferencing & tool calls. 
/// All content sent for inferencing and all responses from the LLM are saved in self.inference_trace - pub async fn execute(&mut self) -> Result<(), AgentError> { + pub async fn execute_plan(&mut self) -> Result<(), AgentError> { for step in &self.execution_plan { match step { ExecutionStep::Inference(inference_step) => { diff --git a/src/agent/providers/mod.rs b/src/agent/providers/mod.rs index 34e84de03..d47f3c9e0 100644 --- a/src/agent/providers/mod.rs +++ b/src/agent/providers/mod.rs @@ -1,23 +1,50 @@ -use reqwest::Client; -use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobPreMessage; - use super::error::AgentError; use async_trait::async_trait; +use reqwest::Client; +use serde_json::Value as JsonValue; + +pub mod openai; +pub mod sleep_api; #[async_trait] pub trait LLMProvider { - type Response; - fn parse_response(response_body: &str) -> Result>; - fn extract_content(response: &Self::Response) -> Vec; + // type Response; + // fn parse_response(response_body: &str) -> Result; + // fn extract_content(response: &Self::Response) -> Result; async fn call_api( &self, client: &Client, url: Option<&String>, api_key: Option<&String>, content: &str, - step_history: Vec, - ) -> Result, AgentError>; -} + ) -> Result; -pub mod openai; -pub mod sleep_api; + /// Given an input string, parses the first JSON object that it finds + fn extract_first_json_object(s: &str) -> Result { + let mut depth = 0; + let mut start = None; + + for (i, c) in s.char_indices() { + match c { + '{' => { + if depth == 0 { + start = Some(i); + } + depth += 1; + } + '}' => { + depth -= 1; + if depth == 0 { + let json_str = &s[start.unwrap()..=i]; + let json_val: JsonValue = serde_json::from_str(json_str) + .map_err(|_| AgentError::FailedExtractingJSONObjectFromResponse(s.to_string()))?; + return Ok(json_val); + } + } + _ => {} + } + } + + Err(AgentError::FailedExtractingJSONObjectFromResponse(s.to_string())) + } +} diff --git a/src/agent/providers/openai.rs b/src/agent/providers/openai.rs index 3b17983a3..83e016537 100644 --- a/src/agent/providers/openai.rs +++ b/src/agent/providers/openai.rs @@ -4,11 +4,8 @@ use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json; -use shinkai_message_primitives::{ - schemas::agents::serialized_agent::OpenAI, - shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, -}; -use std::error::Error; +use serde_json::Value as JsonValue; +use shinkai_message_primitives::schemas::agents::serialized_agent::OpenAI; #[derive(Debug, Deserialize)] pub struct Response { @@ -41,38 +38,13 @@ struct Usage { #[async_trait] impl LLMProvider for OpenAI { - type Response = Response; - - fn parse_response(response_body: &str) -> Result> { - let res: Result = serde_json::from_str(response_body); - match res { - Ok(response) => Ok(response), - Err(e) => Err(Box::new(e)), - } - } - - fn extract_content(response: &Self::Response) -> Vec { - response - .choices - .iter() - .map(|choice| { - JobPreMessage { - tool_calls: Vec::new(), // TODO: You might want to replace this with actual values - content: choice.message.content.clone(), - recipient: JobRecipient::SelfNode, // TODO: This is a placeholder. You should replace this with the actual recipient. 
- } - }) - .collect() - } - async fn call_api( &self, client: &Client, url: Option<&String>, api_key: Option<&String>, content: &str, - step_history: Vec, - ) -> Result, AgentError> { + ) -> Result { if let Some(base_url) = url { if let Some(key) = api_key { let url = format!("{}{}", base_url, "/v1/chat/completions"); @@ -99,7 +71,13 @@ impl LLMProvider for OpenAI { eprintln!("Status: {}", res.status()); let data: Response = res.json().await.map_err(AgentError::ReqwestError)?; - Ok(Self::extract_content(&data)) + let response_string: String = data + .choices + .iter() + .map(|choice| choice.message.content.clone()) + .collect::>() + .join(" "); + Self::extract_first_json_object(&response_string) } else { Err(AgentError::ApiKeyNotSet) } diff --git a/src/agent/providers/sleep_api.rs b/src/agent/providers/sleep_api.rs index c83344e05..ff9aa0f72 100644 --- a/src/agent/providers/sleep_api.rs +++ b/src/agent/providers/sleep_api.rs @@ -2,38 +2,20 @@ use super::AgentError; use super::LLMProvider; use async_trait::async_trait; use reqwest::Client; -use serde::{Deserialize, Serialize}; -use shinkai_message_primitives::{ - schemas::agents::serialized_agent::SleepAPI, - shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, -}; +use serde_json::Value as JsonValue; +use shinkai_message_primitives::schemas::agents::serialized_agent::SleepAPI; use tokio::time::Duration; #[async_trait] impl LLMProvider for SleepAPI { - type Response = (); // Empty tuple as a stand-in for no data - - fn parse_response(_: &str) -> Result> { - Ok(()) - } - - fn extract_content(_: &Self::Response) -> Vec { - vec![JobPreMessage { - tool_calls: Vec::new(), - content: "OK".to_string(), - recipient: JobRecipient::SelfNode, - }] - } - async fn call_api( &self, _: &Client, _: Option<&String>, _: Option<&String>, _: &str, - _: Vec, - ) -> Result, AgentError> { + ) -> Result { tokio::time::sleep(Duration::from_millis(500)).await; - Ok(Self::extract_content(&())) + Ok(JsonValue::Bool(true)) } } diff --git a/src/managers/job_manager.rs b/src/managers/job_manager.rs index 6c8850d65..a205fbb16 100644 --- a/src/managers/job_manager.rs +++ b/src/managers/job_manager.rs @@ -2,6 +2,7 @@ use super::error::JobManagerError; use super::IdentityManager; use crate::agent::agent::Agent; use crate::agent::job::{Job, JobId, JobLike}; +use crate::agent::plan_executor::PlanExecutor; use crate::db::{db_errors::ShinkaiDBError, ShinkaiDB}; use chrono::Utc; use ed25519_dalek::SecretKey as SignatureStaticKey; @@ -244,6 +245,11 @@ impl AgentManager { std::mem::drop(shinkai_db); // require to avoid deadlock let _ = self.decision_phase(&**job).await?; + + // After analysis phase, we execute the resulting execution plan + // let executor = PlanExecutor::new(agent, execution_plan)?; + // executor.execute_plan(); + return Ok(job_message.job_id.clone()); } else { return Err(JobManagerError::JobNotFound); @@ -358,7 +364,7 @@ impl AgentManager { // Execute LLM inferencing let response = tokio::spawn(async move { let mut agent = agent.lock().await; - agent.inference(last_message, context, job.job_id().to_string()).await; + agent.inference(last_message).await; }) .await?; println!("decision_iteration> response: {:?}", response); @@ -369,15 +375,6 @@ impl AgentManager { // self.decision_iteration(job, context, last_message, agent).await?; // } - // The expected output from the LLM is one or more `Premessage`s (a message that potentially - // still has computation that needs to be performed via tools to fill out its contents). 
- // If the output from the LLM does not fit the expected structure, then the LLM is queried again - // with the exact same inputs until a valid output is provided (potentially supplying extra text - // each time to the LLM clarifying the previous result was invalid with an example/error message). - - // Make sure the output is valid - // If not valid, keep calling the LLM until a valid output is produced - // Return the output Ok(()) } diff --git a/tests/db_agents_tests.rs b/tests/db_agents_tests.rs index d00ef6740..8b885c6be 100644 --- a/tests/db_agents_tests.rs +++ b/tests/db_agents_tests.rs @@ -1,8 +1,10 @@ -use shinkai_node::{ - db::{db_errors::ShinkaiDBError, ShinkaiDB}, -}; +use mockito::Server; +use serde_json::Value as JsonValue; +use shinkai_message_primitives::schemas::agents::serialized_agent::{OpenAI, SleepAPI}; +use shinkai_node::db::{db_errors::ShinkaiDBError, ShinkaiDB}; use std::fs; use std::path::Path; +use tokio::sync::mpsc; fn setup() { let path = Path::new("db_tests/"); @@ -11,7 +13,14 @@ fn setup() { #[cfg(test)] mod tests { - use shinkai_message_primitives::{shinkai_utils::utils::hash_string, schemas::{shinkai_name::ShinkaiName, agents::serialized_agent::{OpenAI, SerializedAgent, AgentAPIModel}}}; + use shinkai_message_primitives::{ + schemas::{ + agents::serialized_agent::{AgentAPIModel, OpenAI, SerializedAgent}, + shinkai_name::ShinkaiName, + }, + shinkai_utils::utils::hash_string, + }; + use shinkai_node::agent::{agent::Agent, error::AgentError}; use super::*; @@ -28,7 +37,8 @@ mod tests { // Create an instance of SerializedAgent let test_agent = SerializedAgent { id: "test_agent".to_string(), - full_identity_name: ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), + full_identity_name: ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()) + .unwrap(), perform_locally: false, external_url: Some("http://localhost:8080".to_string()), api_key: Some("test_api_key".to_string()), @@ -73,7 +83,8 @@ mod tests { // Create an instance of SerializedAgent let test_agent = SerializedAgent { id: "test_agent".to_string(), - full_identity_name: ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), + full_identity_name: ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()) + .unwrap(), perform_locally: false, external_url: Some("http://localhost:8080".to_string()), api_key: Some("test_api_key".to_string()), @@ -152,7 +163,8 @@ mod tests { let test_agent = SerializedAgent { id: "test_agent".to_string(), - full_identity_name: ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), + full_identity_name: ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()) + .unwrap(), perform_locally: false, external_url: Some("http://localhost:8080".to_string()), api_key: Some("test_api_key".to_string()), @@ -177,4 +189,100 @@ mod tests { let toolkits = db.get_agent_toolkits_accessible(&test_agent.id).unwrap(); assert_eq!(vec!["toolkit2"], toolkits); } + + #[tokio::test] + async fn test_agent_creation() { + let (tx, mut rx) = mpsc::channel(1); + let sleep_api = SleepAPI {}; + let agent = Agent::new( + "1".to_string(), + ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), + tx, + false, + Some("http://localhost:8000".to_string()), + Some("paramparam".to_string()), + AgentAPIModel::Sleep(sleep_api), + vec!["tk1".to_string(), "tk2".to_string()], + vec!["sb1".to_string(), 
"sb2".to_string()], + vec!["allowed1".to_string(), "allowed2".to_string()], + ); + + assert_eq!(agent.id, "1"); + assert_eq!( + agent.full_identity_name, + ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap() + ); + assert_eq!(agent.perform_locally, false); + assert_eq!(agent.external_url, Some("http://localhost:8000".to_string())); + assert_eq!(agent.toolkit_permissions, vec!["tk1".to_string(), "tk2".to_string()]); + assert_eq!( + agent.storage_bucket_permissions, + vec!["sb1".to_string(), "sb2".to_string()] + ); + assert_eq!( + agent.allowed_message_senders, + vec!["allowed1".to_string(), "allowed2".to_string()] + ); + + let handle = tokio::spawn(async move { agent.inference("Test".to_string()).await }); + let result: Result = handle.await.unwrap(); + assert_eq!(result.unwrap(), JsonValue::Bool(true)) + } + + #[tokio::test] + async fn test_agent_call_external_api_openai() { + let mut server = Server::new(); + let _m = server + .mock("POST", "/v1/chat/completions") + .match_header("authorization", "Bearer mockapikey") + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "{ \"answer\": \"\\n\\nHello there, how may I assist you today?\" }" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + }"#, + ) + .create(); + + let (tx, _rx) = mpsc::channel(1); + let openai = OpenAI { + model_type: "gpt-3.5-turbo".to_string(), + }; + let agent = Agent::new( + "1".to_string(), + ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), + tx, + false, + Some(server.url()), // use the url of the mock server + Some("mockapikey".to_string()), + AgentAPIModel::OpenAI(openai), + vec!["tk1".to_string(), "tk2".to_string()], + vec!["sb1".to_string(), "sb2".to_string()], + vec!["allowed1".to_string(), "allowed2".to_string()], + ); + + let response = agent.inference("Hello!".to_string()).await; + match response { + Ok(res) => assert_eq!( + res["answer"].as_str().unwrap(), + "\n\nHello there, how may I assist you today?".to_string() + ), + Err(e) => panic!("Error when calling API: {}", e), + } + } } From ce95cf57698203768070ccf21e0cdff6ff1b95b4 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 22:51:19 +0200 Subject: [PATCH 25/26] Simplified inferencing flow/surrounding code --- .../src/schemas/agents/serialized_agent.rs | 21 ++++--- .../serialized_agent_wrapper.rs | 7 ++- .../serialized_agent_conversion_tests.rs | 36 +++++++++--- src/agent/agent.rs | 56 ++++++++++--------- src/agent/agent_to_serialization.rs | 6 +- tests/agent_integration_tests.rs | 4 +- tests/db_agents_tests.rs | 16 +++--- tests/job_manager_tests.rs | 4 +- tests/node_retrying_tests.rs | 2 +- 9 files changed, 90 insertions(+), 62 deletions(-) diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/agents/serialized_agent.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/agents/serialized_agent.rs index 1807c2f19..cbb502a79 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/agents/serialized_agent.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/agents/serialized_agent.rs @@ -1,5 +1,6 @@ use crate::schemas::shinkai_name::ShinkaiName; use serde::{Deserialize, Serialize}; +use std::str::FromStr; // Agent has 
a few fields that are not serializable, so we need to create a struct that is serializable #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] @@ -9,7 +10,7 @@ pub struct SerializedAgent { pub perform_locally: bool, pub external_url: Option, pub api_key: Option, - pub model: AgentAPIModel, + pub model: AgentLLMInterface, pub toolkit_permissions: Vec, pub storage_bucket_permissions: Vec, pub allowed_message_senders: Vec, @@ -17,31 +18,35 @@ pub struct SerializedAgent { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] -pub enum AgentAPIModel { +pub enum AgentLLMInterface { #[serde(rename = "openai")] OpenAI(OpenAI), #[serde(rename = "sleep")] Sleep(SleepAPI), + #[serde(rename = "local-llm")] + LocalLLM(LocalLLM), } +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct LocalLLM {} + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct OpenAI { pub model_type: String, } -use std::str::FromStr; - -impl FromStr for AgentAPIModel { +impl FromStr for AgentLLMInterface { type Err = (); fn from_str(s: &str) -> Result { if s.starts_with("openai:") { let model_type = s.strip_prefix("openai:").unwrap_or("").to_string(); - Ok(AgentAPIModel::OpenAI(OpenAI { model_type })) + Ok(AgentLLMInterface::OpenAI(OpenAI { model_type })) } else { - Ok(AgentAPIModel::Sleep(SleepAPI {})) + Ok(AgentLLMInterface::Sleep(SleepAPI {})) } } } + #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] -pub struct SleepAPI {} \ No newline at end of file +pub struct SleepAPI {} diff --git a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_agent_wrapper.rs b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_agent_wrapper.rs index 5a3dcf856..883322d8c 100644 --- a/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_agent_wrapper.rs +++ b/shinkai-libs/shinkai-message-wasm/src/shinkai_wasm_wrappers/serialized_agent_wrapper.rs @@ -1,6 +1,9 @@ use serde::{Deserialize, Serialize}; use serde_wasm_bindgen::{from_value, to_value}; -use shinkai_message_primitives::schemas::{agents::serialized_agent::{SerializedAgent, AgentAPIModel}, shinkai_name::ShinkaiName}; +use shinkai_message_primitives::schemas::{ + agents::serialized_agent::{AgentLLMInterface, SerializedAgent}, + shinkai_name::ShinkaiName, +}; use wasm_bindgen::prelude::*; pub trait SerializedAgentJsValueConversion { @@ -66,7 +69,7 @@ impl SerializedAgentJsValueConversion for SerializedAgent { }; let api_key = if api_key.is_empty() { None } else { Some(api_key) }; let model = model - .parse::() + .parse::() .map_err(|_| JsValue::from_str("Invalid model"))?; let toolkit_permissions = if toolkit_permissions.is_empty() { Vec::new() diff --git a/shinkai-libs/shinkai-message-wasm/tests/serialized_agent_conversion_tests.rs b/shinkai-libs/shinkai-message-wasm/tests/serialized_agent_conversion_tests.rs index 32cd5caf5..b08306932 100644 --- a/shinkai-libs/shinkai-message-wasm/tests/serialized_agent_conversion_tests.rs +++ b/shinkai-libs/shinkai-message-wasm/tests/serialized_agent_conversion_tests.rs @@ -3,10 +3,10 @@ use wasm_bindgen_test::*; #[cfg(test)] mod tests { use super::*; - use shinkai_message_primitives::schemas::agents::serialized_agent::{SerializedAgent, AgentAPIModel, OpenAI}; + use serde_wasm_bindgen::from_value; + use shinkai_message_primitives::schemas::agents::serialized_agent::{AgentLLMInterface, OpenAI, SerializedAgent}; use shinkai_message_wasm::shinkai_wasm_wrappers::serialized_agent_wrapper::SerializedAgentWrapper; use 
wasm_bindgen::JsValue; - use serde_wasm_bindgen::from_value; #[cfg(target_arch = "wasm32")] #[wasm_bindgen_test] @@ -22,7 +22,8 @@ mod tests { "permission1,permission2".to_string(), "bucket1,bucket2".to_string(), "sender1,sender2".to_string(), - ).unwrap(); + ) + .unwrap(); // Get the inner SerializedAgent let agent_jsvalue = serialized_agent_wrapper.inner().unwrap(); @@ -30,13 +31,30 @@ mod tests { // Check that the fields are correctly converted assert_eq!(agent.id, "test_agent"); - assert_eq!(agent.full_identity_name.to_string(), "@@node.shinkai/main/agent/test_agent"); + assert_eq!( + agent.full_identity_name.to_string(), + "@@node.shinkai/main/agent/test_agent" + ); assert_eq!(agent.perform_locally, false); assert_eq!(agent.external_url, Some("http://example.com".to_string())); assert_eq!(agent.api_key, Some("123456".to_string())); - assert_eq!(agent.model, AgentAPIModel::OpenAI(OpenAI { model_type: "chatgpt3-turbo".to_string() })); - assert_eq!(agent.toolkit_permissions, vec!["permission1".to_string(), "permission2".to_string()]); - assert_eq!(agent.storage_bucket_permissions, vec!["bucket1".to_string(), "bucket2".to_string()]); - assert_eq!(agent.allowed_message_senders, vec!["sender1".to_string(), "sender2".to_string()]); + assert_eq!( + agent.model, + AgentLLMInterface::OpenAI(OpenAI { + model_type: "chatgpt3-turbo".to_string() + }) + ); + assert_eq!( + agent.toolkit_permissions, + vec!["permission1".to_string(), "permission2".to_string()] + ); + assert_eq!( + agent.storage_bucket_permissions, + vec!["bucket1".to_string(), "bucket2".to_string()] + ); + assert_eq!( + agent.allowed_message_senders, + vec!["sender1".to_string(), "sender2".to_string()] + ); } -} \ No newline at end of file +} diff --git a/src/agent/agent.rs b/src/agent/agent.rs index 46ed82af8..a9f6bea24 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -4,7 +4,7 @@ use reqwest::Client; use serde_json::{Map, Value as JsonValue}; use shinkai_message_primitives::{ schemas::{ - agents::serialized_agent::{AgentAPIModel, SerializedAgent}, + agents::serialized_agent::{AgentLLMInterface, SerializedAgent}, shinkai_name::ShinkaiName, }, shinkai_message::shinkai_message_schemas::{JobPreMessage, JobRecipient}, @@ -19,10 +19,10 @@ pub struct Agent { pub job_manager_sender: mpsc::Sender<(Vec, String)>, pub agent_receiver: Arc>>, pub client: Client, - pub perform_locally: bool, // flag to perform computation locally or not + pub perform_locally: bool, // Todo: Remove as not used anymore pub external_url: Option, // external API URL pub api_key: Option, - pub model: AgentAPIModel, + pub model: AgentLLMInterface, pub toolkit_permissions: Vec, // list of toolkits the agent has access to pub storage_bucket_permissions: Vec, // list of storage buckets the agent has access to pub allowed_message_senders: Vec, // list of sub-identities allowed to message the agent @@ -36,7 +36,7 @@ impl Agent { perform_locally: bool, external_url: Option, api_key: Option, - model: AgentAPIModel, + model: AgentLLMInterface, toolkit_permissions: Vec, storage_bucket_permissions: Vec, allowed_message_senders: Vec, @@ -60,23 +60,9 @@ impl Agent { } } - pub async fn call_external_api(&self, content: &str) -> Result { - match &self.model { - AgentAPIModel::OpenAI(openai) => { - openai - .call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), content) - .await - } - AgentAPIModel::Sleep(sleep_api) => { - sleep_api - .call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), content) - .await - } - } - } - - /// TODO: 
Probably just throw this away, and move this logic into a LocalLLM struct that implements the Provider trait - pub async fn inference_locally(&self, content: String) -> Result { + /// Inferences an LLM locally based on info held in the Agent + /// TODO: For now just mocked, eventually get around to this, and create a struct that implements the Provider trait to unify local with remote interface. + async fn inference_locally(&self, content: String) -> Result { // Here we run our GPU-intensive task on a separate thread let handle = tokio::task::spawn_blocking(move || { let mut map = Map::new(); @@ -98,12 +84,28 @@ impl Agent { /// meaning that they tell/force the LLM to always respond in JSON. We automatically /// parse the JSON object out of the response into a JsonValue, or error if no object is found. pub async fn inference(&self, content: String) -> Result { - if self.perform_locally { - // No need to spawn a new task here - return self.inference_locally(content.clone()).await; - } else { - // Call external API - return self.call_external_api(&content.clone()).await; + match &self.model { + AgentLLMInterface::OpenAI(openai) => { + openai + .call_api( + &self.client, + self.external_url.as_ref(), + self.api_key.as_ref(), + &content, + ) + .await + } + AgentLLMInterface::Sleep(sleep_api) => { + sleep_api + .call_api( + &self.client, + self.external_url.as_ref(), + self.api_key.as_ref(), + &content, + ) + .await + } + AgentLLMInterface::LocalLLM(local_llm) => self.inference_locally(content.to_string()).await, } } } diff --git a/src/agent/agent_to_serialization.rs b/src/agent/agent_to_serialization.rs index 25cd9dc77..3db3176dd 100644 --- a/src/agent/agent_to_serialization.rs +++ b/src/agent/agent_to_serialization.rs @@ -1,7 +1,7 @@ -use serde::{Serialize, Deserialize}; -use shinkai_message_primitives::schemas::{shinkai_name::ShinkaiName, agents::serialized_agent::SerializedAgent}; +use serde::{Deserialize, Serialize}; +use shinkai_message_primitives::schemas::{agents::serialized_agent::SerializedAgent, shinkai_name::ShinkaiName}; -use super::agent::{Agent}; +use super::agent::Agent; impl From for SerializedAgent { fn from(agent: Agent) -> Self { diff --git a/tests/agent_integration_tests.rs b/tests/agent_integration_tests.rs index 106c05359..184791d8c 100644 --- a/tests/agent_integration_tests.rs +++ b/tests/agent_integration_tests.rs @@ -1,5 +1,5 @@ use async_channel::{bounded, Receiver, Sender}; -use shinkai_message_primitives::schemas::agents::serialized_agent::{AgentAPIModel, OpenAI, SerializedAgent}; +use shinkai_message_primitives::schemas::agents::serialized_agent::{AgentLLMInterface, OpenAI, SerializedAgent}; use shinkai_message_primitives::schemas::inbox_name::InboxName; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::schemas::shinkai_time::ShinkaiTime; @@ -151,7 +151,7 @@ fn node_agent_registration() { perform_locally: false, external_url: Some(server.url()), api_key: Some("mockapikey".to_string()), - model: AgentAPIModel::OpenAI(open_ai), + model: AgentLLMInterface::OpenAI(open_ai), toolkit_permissions: vec![], storage_bucket_permissions: vec![], allowed_message_senders: vec![], diff --git a/tests/db_agents_tests.rs b/tests/db_agents_tests.rs index 8b885c6be..ad14ea2d1 100644 --- a/tests/db_agents_tests.rs +++ b/tests/db_agents_tests.rs @@ -15,7 +15,7 @@ fn setup() { mod tests { use shinkai_message_primitives::{ schemas::{ - agents::serialized_agent::{AgentAPIModel, OpenAI, SerializedAgent}, + 
agents::serialized_agent::{AgentLLMInterface, OpenAI, SerializedAgent}, shinkai_name::ShinkaiName, }, shinkai_utils::utils::hash_string, @@ -42,7 +42,7 @@ mod tests { perform_locally: false, external_url: Some("http://localhost:8080".to_string()), api_key: Some("test_api_key".to_string()), - model: AgentAPIModel::OpenAI(open_ai), + model: AgentLLMInterface::OpenAI(open_ai), toolkit_permissions: vec!["toolkit1".to_string(), "toolkit2".to_string()], storage_bucket_permissions: vec!["storage1".to_string(), "storage2".to_string()], allowed_message_senders: vec!["sender1".to_string(), "sender2".to_string()], @@ -88,7 +88,7 @@ mod tests { perform_locally: false, external_url: Some("http://localhost:8080".to_string()), api_key: Some("test_api_key".to_string()), - model: AgentAPIModel::OpenAI(open_ai), + model: AgentLLMInterface::OpenAI(open_ai), toolkit_permissions: vec!["toolkit1".to_string(), "toolkit2".to_string()], storage_bucket_permissions: vec!["storage1".to_string(), "storage2".to_string()], allowed_message_senders: vec!["sender1".to_string(), "sender2".to_string()], @@ -132,7 +132,7 @@ mod tests { perform_locally: false, external_url: Some("http://localhost:8080".to_string()), api_key: Some("test_api_key".to_string()), - model: AgentAPIModel::OpenAI(open_ai), + model: AgentLLMInterface::OpenAI(open_ai), toolkit_permissions: vec!["toolkit1".to_string(), "toolkit2".to_string()], storage_bucket_permissions: vec!["storage1".to_string(), "storage2".to_string()], allowed_message_senders: vec!["sender1".to_string(), "sender2".to_string()], @@ -168,7 +168,7 @@ mod tests { perform_locally: false, external_url: Some("http://localhost:8080".to_string()), api_key: Some("test_api_key".to_string()), - model: AgentAPIModel::OpenAI(open_ai), + model: AgentLLMInterface::OpenAI(open_ai), toolkit_permissions: vec!["toolkit1".to_string(), "toolkit2".to_string()], storage_bucket_permissions: vec!["storage1".to_string(), "storage2".to_string()], allowed_message_senders: vec!["sender1".to_string(), "sender2".to_string()], @@ -201,7 +201,7 @@ mod tests { false, Some("http://localhost:8000".to_string()), Some("paramparam".to_string()), - AgentAPIModel::Sleep(sleep_api), + AgentLLMInterface::Sleep(sleep_api), vec!["tk1".to_string(), "tk2".to_string()], vec!["sb1".to_string(), "sb2".to_string()], vec!["allowed1".to_string(), "allowed2".to_string()], @@ -246,7 +246,7 @@ mod tests { "index": 0, "message": { "role": "assistant", - "content": "{ \"answer\": \"\\n\\nHello there, how may I assist you today?\" }" + "content": " a bunch of other text before { \"answer\": \"\\n\\nHello there, how may I assist you today?\" } and more text after to see if it fails getting the json object" }, "finish_reason": "stop" }], @@ -270,7 +270,7 @@ mod tests { false, Some(server.url()), // use the url of the mock server Some("mockapikey".to_string()), - AgentAPIModel::OpenAI(openai), + AgentLLMInterface::OpenAI(openai), vec!["tk1".to_string(), "tk2".to_string()], vec!["sb1".to_string(), "sb2".to_string()], vec!["allowed1".to_string(), "allowed2".to_string()], diff --git a/tests/job_manager_tests.rs b/tests/job_manager_tests.rs index f9fe68279..f37ddba86 100644 --- a/tests/job_manager_tests.rs +++ b/tests/job_manager_tests.rs @@ -20,7 +20,7 @@ mod tests { use mockito::Server; use shinkai_message_primitives::{ schemas::{ - agents::serialized_agent::{AgentAPIModel, OpenAI, SerializedAgent}, + agents::serialized_agent::{AgentLLMInterface, OpenAI, SerializedAgent}, inbox_name::InboxName, shinkai_name::{ShinkaiName, ShinkaiSubidentityType}, 
}, @@ -119,7 +119,7 @@ mod tests { perform_locally: false, external_url: Some(server.url()), api_key: Some("mockapikey".to_string()), - model: AgentAPIModel::OpenAI(openai), + model: AgentLLMInterface::OpenAI(openai), toolkit_permissions: vec!["toolkit1".to_string(), "toolkit2".to_string()], storage_bucket_permissions: vec!["storage1".to_string(), "storage2".to_string()], allowed_message_senders: vec!["sender1".to_string(), "sender2".to_string()], diff --git a/tests/node_retrying_tests.rs b/tests/node_retrying_tests.rs index 373e6194c..1e4774b93 100644 --- a/tests/node_retrying_tests.rs +++ b/tests/node_retrying_tests.rs @@ -1,5 +1,5 @@ use async_channel::{bounded, Receiver, Sender}; -use shinkai_message_primitives::schemas::agents::serialized_agent::{AgentAPIModel, OpenAI, SerializedAgent}; +use shinkai_message_primitives::schemas::agents::serialized_agent::{AgentLLMInterface, OpenAI, SerializedAgent}; use shinkai_message_primitives::schemas::inbox_name::InboxName; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{JobMessage, MessageSchemaType}; From 8f6f677bf4d8c7ff9a5b404af530a9826f0ff19e Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Fri, 15 Sep 2023 22:55:39 +0200 Subject: [PATCH 26/26] Added comments --- src/agent/agent.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/agent/agent.rs b/src/agent/agent.rs index a9f6bea24..e37d0e0c4 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -23,9 +23,9 @@ pub struct Agent { pub external_url: Option, // external API URL pub api_key: Option, pub model: AgentLLMInterface, - pub toolkit_permissions: Vec, // list of toolkits the agent has access to - pub storage_bucket_permissions: Vec, // list of storage buckets the agent has access to - pub allowed_message_senders: Vec, // list of sub-identities allowed to message the agent + pub toolkit_permissions: Vec, // Todo: remove as not used + pub storage_bucket_permissions: Vec, // Todo: remove as not used + pub allowed_message_senders: Vec, // list of sub-identities allowed to message the agent } impl Agent {
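Below is a standalone sketch (not taken from the patches above) of the brace-depth extraction that LLMProvider::extract_first_json_object introduces in src/agent/providers/mod.rs, and that the updated mock in tests/db_agents_tests.rs exercises by wrapping the JSON answer in surrounding prose. The helper name, the Option return type, and the main driver are illustrative only; the trait's default method returns Result<JsonValue, AgentError> and does not guard against a stray closing brace.

use serde_json::Value as JsonValue;

// Scan left to right, tracking '{' / '}' nesting depth; the first time the
// depth returns to zero, attempt to parse the covered slice as JSON.
fn extract_first_json_object_sketch(s: &str) -> Option<JsonValue> {
    let mut depth = 0usize;
    let mut start = None;
    for (i, c) in s.char_indices() {
        match c {
            '{' => {
                if depth == 0 {
                    start = Some(i);
                }
                depth += 1;
            }
            // Only decrement when inside an object, avoiding underflow on a
            // stray '}' that precedes any '{'.
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    return serde_json::from_str(&s[start?..=i]).ok();
                }
            }
            _ => {}
        }
    }
    None
}

fn main() {
    // Mirrors the shape of the mocked OpenAI reply in tests/db_agents_tests.rs:
    // prose before and after a single JSON object.
    let raw = r#"a bunch of other text before { "answer": "Hello there" } and more text after"#;
    let obj = extract_first_json_object_sketch(raw).expect("no JSON object found");
    assert_eq!(obj["answer"].as_str(), Some("Hello there"));
}

A depth counter of this kind does not account for braces appearing inside JSON string values, so a response whose string content contains '{' or '}' can still fail to parse; in the patch that case surfaces as AgentError::FailedExtractingJSONObjectFromResponse.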