Skip to content

Commit

Permalink
Merge pull request #71 from dcSpark/job-toolkit-integration
Browse files Browse the repository at this point in the history
Base Agent & Job Cleanup
  • Loading branch information
robkorn authored Sep 15, 2023
2 parents d2c65d0 + 8f6f677 commit 6add7e8
Show file tree
Hide file tree
Showing 43 changed files with 1,440 additions and 802 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::schemas::shinkai_name::ShinkaiName;
use serde::{Deserialize, Serialize};
use std::str::FromStr;

// Agent has a few fields that are not serializable, so we need to create a struct that is serializable
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
Expand All @@ -9,39 +10,43 @@ pub struct SerializedAgent {
pub perform_locally: bool,
pub external_url: Option<String>,
pub api_key: Option<String>,
pub model: AgentAPIModel,
pub model: AgentLLMInterface,
pub toolkit_permissions: Vec<String>,
pub storage_bucket_permissions: Vec<String>,
pub allowed_message_senders: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum AgentAPIModel {
pub enum AgentLLMInterface {
#[serde(rename = "openai")]
OpenAI(OpenAI),
#[serde(rename = "sleep")]
Sleep(SleepAPI),
#[serde(rename = "local-llm")]
LocalLLM(LocalLLM),
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LocalLLM {}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct OpenAI {
pub model_type: String,
}

use std::str::FromStr;

impl FromStr for AgentAPIModel {
impl FromStr for AgentLLMInterface {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.starts_with("openai:") {
let model_type = s.strip_prefix("openai:").unwrap_or("").to_string();
Ok(AgentAPIModel::OpenAI(OpenAI { model_type }))
Ok(AgentLLMInterface::OpenAI(OpenAI { model_type }))
} else {
Ok(AgentAPIModel::Sleep(SleepAPI {}))
Ok(AgentLLMInterface::Sleep(SleepAPI {}))
}
}
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct SleepAPI {}
pub struct SleepAPI {}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::fmt;

use serde::{Deserialize, Serialize, Serializer, Deserializer};
use serde_json::Result;
use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Result;

use crate::schemas::{inbox_name::InboxName, shinkai_name::ShinkaiName, agents::serialized_agent::SerializedAgent};
use crate::schemas::{agents::serialized_agent::SerializedAgent, inbox_name::InboxName, shinkai_name::ShinkaiName};

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum MessageSchemaType {
Expand All @@ -17,7 +17,7 @@ pub enum MessageSchemaType {
APIReadUpToTimeRequest,
APIAddAgentRequest,
TextContent,
Empty
Empty,
}

impl MessageSchemaType {
Expand Down Expand Up @@ -95,7 +95,7 @@ impl JobScope {
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct JobCreation {
pub struct JobCreationInfo {
pub scope: JobScope,
}

Expand Down Expand Up @@ -214,9 +214,9 @@ pub struct RegistrationCodeRequest {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum IdentityPermissions {
Admin, // can create and delete other profiles
Admin, // can create and delete other profiles
Standard, // can add / remove devices
None, // none of the above
None, // none of the above
}

impl IdentityPermissions {
Expand Down Expand Up @@ -272,7 +272,7 @@ impl Serialize for RegistrationCodeType {
RegistrationCodeType::Device(device_name) => {
let s = format!("device:{}", device_name);
serializer.serialize_str(&s)
},
}
RegistrationCodeType::Profile => serializer.serialize_str("profile"),
}
}
Expand All @@ -289,7 +289,7 @@ impl<'de> Deserialize<'de> for RegistrationCodeType {
Some(&"device") => {
let device_name = parts.get(1).unwrap_or(&"main");
Ok(RegistrationCodeType::Device(device_name.to_string()))
},
}
Some(&"profile") => Ok(RegistrationCodeType::Profile),
_ => Err(serde::de::Error::custom("Unexpected variant")),
}
Expand All @@ -303,4 +303,4 @@ impl fmt::Display for RegistrationCodeType {
RegistrationCodeType::Profile => write!(f, "profile"),
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
},
shinkai_message_schemas::{
APIAddAgentRequest, APIGetMessagesFromInboxRequest, APIReadUpToTimeRequest, IdentityPermissions,
JobCreation, JobMessage, JobScope, MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType,
JobCreationInfo, JobMessage, JobScope, MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType,
},
},
shinkai_utils::{
Expand Down Expand Up @@ -394,7 +394,7 @@ impl ShinkaiMessageBuilder {
node_receiver: ProfileName,
node_receiver_subidentity: ProfileName,
) -> Result<ShinkaiMessage, &'static str> {
let job_creation = JobCreation { scope };
let job_creation = JobCreationInfo { scope };
let body = serde_json::to_string(&job_creation).map_err(|_| "Failed to serialize job creation to JSON")?;

ShinkaiMessageBuilder::new(my_encryption_secret_key, my_signature_secret_key, receiver_public_key)
Expand Down Expand Up @@ -461,18 +461,22 @@ impl ShinkaiMessageBuilder {
// Use for placeholder. These messages *are not* encrypted so it's not required
let (placeholder_encryption_sk, placeholder_encryption_pk) = unsafe_deterministic_encryption_keypair(0);

ShinkaiMessageBuilder::new(placeholder_encryption_sk, my_signature_secret_key, placeholder_encryption_pk)
.message_raw_content(body)
.internal_metadata_with_schema(
"".to_string(),
"".to_string(),
inbox,
MessageSchemaType::JobMessageSchema,
EncryptionMethod::None,
)
.body_encryption(EncryptionMethod::None)
.external_metadata(node_receiver, node_sender)
.build()
ShinkaiMessageBuilder::new(
placeholder_encryption_sk,
my_signature_secret_key,
placeholder_encryption_pk,
)
.message_raw_content(body)
.internal_metadata_with_schema(
"".to_string(),
"".to_string(),
inbox,
MessageSchemaType::JobMessageSchema,
EncryptionMethod::None,
)
.body_encryption(EncryptionMethod::None)
.external_metadata(node_receiver, node_sender)
.build()
}

pub fn terminate_message(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use serde::{Deserialize, Serialize};
use serde_wasm_bindgen::{from_value, to_value};
use shinkai_message_primitives::schemas::{agents::serialized_agent::{SerializedAgent, AgentAPIModel}, shinkai_name::ShinkaiName};
use shinkai_message_primitives::schemas::{
agents::serialized_agent::{AgentLLMInterface, SerializedAgent},
shinkai_name::ShinkaiName,
};
use wasm_bindgen::prelude::*;

pub trait SerializedAgentJsValueConversion {
Expand Down Expand Up @@ -66,7 +69,7 @@ impl SerializedAgentJsValueConversion for SerializedAgent {
};
let api_key = if api_key.is_empty() { None } else { Some(api_key) };
let model = model
.parse::<AgentAPIModel>()
.parse::<AgentLLMInterface>()
.map_err(|_| JsValue::from_str("Invalid model"))?;
let toolkit_permissions = if toolkit_permissions.is_empty() {
Vec::new()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use serde::{Deserialize, Serialize};
use serde_wasm_bindgen;
use shinkai_message_primitives::{shinkai_message::shinkai_message_schemas::{JobScope, JobCreation, JobMessage}, schemas::inbox_name::InboxName};
use shinkai_message_primitives::{
schemas::inbox_name::InboxName,
shinkai_message::shinkai_message_schemas::{JobCreationInfo, JobMessage, JobScope},
};
use wasm_bindgen::prelude::*;

use crate::shinkai_wasm_wrappers::shinkai_wasm_error::ShinkaiWasmError;
Expand Down Expand Up @@ -36,15 +39,15 @@ impl JobScopeWrapper {
#[wasm_bindgen]
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct JobCreationWrapper {
inner: JobCreation,
inner: JobCreationInfo,
}

#[wasm_bindgen]
impl JobCreationWrapper {
#[wasm_bindgen(constructor)]
pub fn new(scope_js: &JsValue) -> Result<JobCreationWrapper, JsValue> {
let scope: JobScope = serde_wasm_bindgen::from_value(scope_js.clone())?;
let job_creation = JobCreation { scope };
let job_creation = JobCreationInfo { scope };
Ok(JobCreationWrapper { inner: job_creation })
}

Expand All @@ -66,13 +69,13 @@ impl JobCreationWrapper {

#[wasm_bindgen]
pub fn from_json_str(s: &str) -> Result<JobCreationWrapper, JsValue> {
let deserialized: JobCreation = serde_json::from_str(s).map_err(|e| JsValue::from_str(&e.to_string()))?;
let deserialized: JobCreationInfo = serde_json::from_str(s).map_err(|e| JsValue::from_str(&e.to_string()))?;
Ok(JobCreationWrapper { inner: deserialized })
}

#[wasm_bindgen]
pub fn from_jsvalue(js_value: &JsValue) -> Result<JobCreationWrapper, JsValue> {
let deserialized: JobCreation = serde_wasm_bindgen::from_value(js_value.clone())?;
let deserialized: JobCreationInfo = serde_wasm_bindgen::from_value(js_value.clone())?;
Ok(JobCreationWrapper { inner: deserialized })
}

Expand All @@ -81,7 +84,9 @@ impl JobCreationWrapper {
let buckets: Vec<InboxName> = Vec::new();
let documents: Vec<String> = Vec::new();
let job_scope = JobScope::new(Some(buckets), Some(documents));
Ok(JobCreationWrapper { inner: JobCreation { scope: job_scope } })
Ok(JobCreationWrapper {
inner: JobCreationInfo { scope: job_scope },
})
}
}

Expand Down Expand Up @@ -126,7 +131,10 @@ impl JobMessageWrapper {

#[wasm_bindgen(js_name = fromStrings)]
pub fn from_strings(job_id: &str, content: &str) -> JobMessageWrapper {
let job_message = JobMessage { job_id: job_id.to_string(), content: content.to_string() };
let job_message = JobMessage {
job_id: job_id.to_string(),
content: content.to_string(),
};
JobMessageWrapper { inner: job_message }
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
use crate::shinkai_wasm_wrappers::{shinkai_message_wrapper::ShinkaiMessageWrapper, wasm_shinkai_message::SerdeWasmMethods, shinkai_wasm_error::{WasmErrorWrapper, ShinkaiWasmError}};
use crate::shinkai_wasm_wrappers::{
shinkai_message_wrapper::ShinkaiMessageWrapper,
shinkai_wasm_error::{ShinkaiWasmError, WasmErrorWrapper},
wasm_shinkai_message::SerdeWasmMethods,
};
use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey};
use js_sys::Uint8Array;
use serde::{Deserialize, Serialize};
use shinkai_message_primitives::{shinkai_utils::{encryption::{string_to_encryption_static_key, string_to_encryption_public_key, encryption_public_key_to_string, EncryptionMethod}, signatures::{string_to_signature_secret_key, signature_public_key_to_string}, shinkai_message_builder::{ShinkaiMessageBuilder, ProfileName}}, shinkai_message::shinkai_message_schemas::{IdentityPermissions, RegistrationCodeType, RegistrationCodeRequest, MessageSchemaType, APIGetMessagesFromInboxRequest, APIAddAgentRequest, APIReadUpToTimeRequest, JobScope, JobCreation, JobMessage}, schemas::{registration_code::RegistrationCode, inbox_name::InboxName, agents::serialized_agent::SerializedAgent}};
use serde_wasm_bindgen::{from_value, to_value};
use shinkai_message_primitives::{
schemas::{agents::serialized_agent::SerializedAgent, inbox_name::InboxName, registration_code::RegistrationCode},
shinkai_message::shinkai_message_schemas::{
APIAddAgentRequest, APIGetMessagesFromInboxRequest, APIReadUpToTimeRequest, IdentityPermissions,
JobCreationInfo, JobMessage, JobScope, MessageSchemaType, RegistrationCodeRequest, RegistrationCodeType,
},
shinkai_utils::{
encryption::{
encryption_public_key_to_string, string_to_encryption_public_key, string_to_encryption_static_key,
EncryptionMethod,
},
shinkai_message_builder::{ProfileName, ShinkaiMessageBuilder},
signatures::{signature_public_key_to_string, string_to_signature_secret_key},
},
};
use wasm_bindgen::prelude::*;
use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};
use serde_wasm_bindgen::{from_value, to_value};

#[wasm_bindgen]
pub struct ShinkaiMessageBuilderWrapper {
Expand Down Expand Up @@ -240,7 +258,8 @@ impl ShinkaiMessageBuilderWrapper {
match builder.build() {
Ok(shinkai_message) => {
let js_value = shinkai_message.to_jsvalue().map_err(WasmErrorWrapper)?;
Ok(ShinkaiMessageWrapper::from_jsvalue(&js_value).map_err(|e| WasmErrorWrapper::new(ShinkaiWasmError::from(e)))?)
Ok(ShinkaiMessageWrapper::from_jsvalue(&js_value)
.map_err(|e| WasmErrorWrapper::new(ShinkaiWasmError::from(e)))?)
}
Err(e) => Err(JsValue::from_str(&e.to_string())),
}
Expand All @@ -255,9 +274,9 @@ impl ShinkaiMessageBuilderWrapper {
pub fn build_to_jsvalue(&mut self) -> Result<JsValue, JsValue> {
if let Some(ref builder) = self.inner {
match builder.build() {
Ok(shinkai_message) => {
shinkai_message.to_jsvalue().map_err(|e| JsValue::from_str(&e.to_string()))
}
Ok(shinkai_message) => shinkai_message
.to_jsvalue()
.map_err(|e| JsValue::from_str(&e.to_string())),
Err(e) => Err(JsValue::from_str(e)),
}
} else {
Expand Down Expand Up @@ -649,7 +668,7 @@ impl ShinkaiMessageBuilderWrapper {
) -> Result<String, JsValue> {
let scope: JobScope = serde_wasm_bindgen::from_value(scope).map_err(|e| JsValue::from_str(&e.to_string()))?;

let job_creation = JobCreation { scope };
let job_creation = JobCreationInfo { scope };
let body = serde_json::to_string(&job_creation).map_err(|e| JsValue::from_str(&e.to_string()))?;

let mut builder =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use wasm_bindgen_test::*;
#[cfg(test)]
mod tests {
use super::*;
use shinkai_message_primitives::schemas::agents::serialized_agent::{SerializedAgent, AgentAPIModel, OpenAI};
use serde_wasm_bindgen::from_value;
use shinkai_message_primitives::schemas::agents::serialized_agent::{AgentLLMInterface, OpenAI, SerializedAgent};
use shinkai_message_wasm::shinkai_wasm_wrappers::serialized_agent_wrapper::SerializedAgentWrapper;
use wasm_bindgen::JsValue;
use serde_wasm_bindgen::from_value;

#[cfg(target_arch = "wasm32")]
#[wasm_bindgen_test]
Expand All @@ -22,21 +22,39 @@ mod tests {
"permission1,permission2".to_string(),
"bucket1,bucket2".to_string(),
"sender1,sender2".to_string(),
).unwrap();
)
.unwrap();

// Get the inner SerializedAgent
let agent_jsvalue = serialized_agent_wrapper.inner().unwrap();
let agent: SerializedAgent = from_value(agent_jsvalue).unwrap();

// Check that the fields are correctly converted
assert_eq!(agent.id, "test_agent");
assert_eq!(agent.full_identity_name.to_string(), "@@node.shinkai/main/agent/test_agent");
assert_eq!(
agent.full_identity_name.to_string(),
"@@node.shinkai/main/agent/test_agent"
);
assert_eq!(agent.perform_locally, false);
assert_eq!(agent.external_url, Some("http://example.com".to_string()));
assert_eq!(agent.api_key, Some("123456".to_string()));
assert_eq!(agent.model, AgentAPIModel::OpenAI(OpenAI { model_type: "chatgpt3-turbo".to_string() }));
assert_eq!(agent.toolkit_permissions, vec!["permission1".to_string(), "permission2".to_string()]);
assert_eq!(agent.storage_bucket_permissions, vec!["bucket1".to_string(), "bucket2".to_string()]);
assert_eq!(agent.allowed_message_senders, vec!["sender1".to_string(), "sender2".to_string()]);
assert_eq!(
agent.model,
AgentLLMInterface::OpenAI(OpenAI {
model_type: "chatgpt3-turbo".to_string()
})
);
assert_eq!(
agent.toolkit_permissions,
vec!["permission1".to_string(), "permission2".to_string()]
);
assert_eq!(
agent.storage_bucket_permissions,
vec!["bucket1".to_string(), "bucket2".to_string()]
);
assert_eq!(
agent.allowed_message_senders,
vec!["sender1".to_string(), "sender2".to_string()]
);
}
}
}
Loading

0 comments on commit 6add7e8

Please sign in to comment.