checkpointing
nicarq committed Oct 10, 2023
1 parent b72945b commit 2f14203
Showing 14 changed files with 189 additions and 79 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -51,6 +51,7 @@ hex = "=0.4.3"
aes-gcm = "0.10.3"
blake3 = "1.2.0"
tiktoken-rs = "0.5.4"
crossbeam-queue = "0.3.8"

[dependencies.rocksdb]
version = "0.21.0"
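The new crossbeam-queue dependency provides the lock-free, bounded ArrayQueue that the job queue code later in this commit builds on. A minimal standalone sketch of its API (capacity and values are illustrative):

use crossbeam_queue::ArrayQueue;

fn main() {
    // Bounded, lock-free queue with a fixed capacity of 2.
    let q: ArrayQueue<String> = ArrayQueue::new(2);
    q.push("job-1".to_string()).unwrap();
    q.push("job-2".to_string()).unwrap();
    // When full, push hands the rejected value back instead of blocking.
    assert!(q.push("job-3".to_string()).is_err());
    // pop returns values in FIFO order, and None once drained.
    assert_eq!(q.pop(), Some("job-1".to_string()));
}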
@@ -21,8 +21,6 @@ pub struct SerializedAgent {
pub enum AgentLLMInterface {
    #[serde(rename = "openai")]
    OpenAI(OpenAI),
-   #[serde(rename = "sleep")]
-   Sleep(SleepAPI),
    #[serde(rename = "local-llm")]
    LocalLLM(LocalLLM),
}
@@ -43,10 +41,8 @@ impl FromStr for AgentLLMInterface {
            let model_type = s.strip_prefix("openai:").unwrap_or("").to_string();
            Ok(AgentLLMInterface::OpenAI(OpenAI { model_type }))
        } else {
-           Ok(AgentLLMInterface::Sleep(SleepAPI {}))
+           // TODO: nothing else for now
+           Err(())
        }
    }
}

-#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
-pub struct SleepAPI {}
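With the Sleep variant gone, from_str only accepts the "openai:" prefix and errors out on anything else. A standalone sketch of the same strip_prefix pattern (the parse helper here is hypothetical, not the crate's API):

fn parse_model(s: &str) -> Result<String, ()> {
    // Matching on strip_prefix directly avoids the separate starts_with check.
    match s.strip_prefix("openai:") {
        Some(model_type) => Ok(model_type.to_string()),
        None => Err(()),
    }
}

fn main() {
    assert_eq!(parse_model("openai:gpt-4"), Ok("gpt-4".to_string()));
    assert_eq!(parse_model("sleep:"), Err(()));
}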
5 changes: 0 additions & 5 deletions src/agent/agent.rs
@@ -91,11 +91,6 @@ impl Agent {
                    .call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), prompt)
                    .await
            }
            AgentLLMInterface::Sleep(sleep_api) => {
                sleep_api
                    .call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), prompt)
                    .await
            }
            AgentLLMInterface::LocalLLM(local_llm) => {
                self.inference_locally(prompt.generate_single_output_string()?).await
            }
7 changes: 7 additions & 0 deletions src/agent/job_execution.rs
@@ -11,6 +11,7 @@ use crate::schemas::identity::Identity;
use blake3::Hasher;
use chrono::Utc;
use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey};
use serde::{Serialize, Deserialize};
use serde_json::{Map, Value as JsonValue};
use shinkai_message_primitives::shinkai_utils::encryption::unsafe_deterministic_encryption_keypair;
use shinkai_message_primitives::shinkai_utils::job_scope::LocalScopeEntry;
@@ -35,6 +36,12 @@ use std::{collections::HashMap, error::Error, sync::Arc};
use tokio::sync::{mpsc, Mutex};
use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JobForProcessing {
    job_message: JobMessage,
    profile: ShinkaiName,
}
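JobForProcessing derives Serialize, Deserialize, and Clone because queued jobs are snapshotted to RocksDB via bincode in the queue code below. A minimal round-trip sketch with a stand-in struct (the fields are hypothetical; JobMessage's are not shown in this diff):

use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
struct StandInJob {
    content: String, // stand-in for job_message
    profile: String, // stand-in for the profile ShinkaiName
}

fn main() {
    let job = StandInJob {
        content: "summarize this file".into(),
        profile: "@@alice.shinkai/main".into(),
    };
    let bytes = bincode::serialize(&job).unwrap(); // what persist_job_queues does
    let back: StandInJob = bincode::deserialize(&bytes).unwrap();
    assert_eq!(job, back);
}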

impl AgentManager {
    /// Processes a job message which will trigger a job step
    pub async fn process_job_step(
1 change: 1 addition & 0 deletions src/agent/mod.rs
@@ -7,3 +7,4 @@ pub mod job_manager;
pub mod job_prompts;
pub mod plan_executor;
pub mod providers;
pub mod queue;
1 change: 0 additions & 1 deletion src/agent/providers/mod.rs
@@ -4,7 +4,6 @@ use reqwest::Client;
use serde_json::Value as JsonValue;

pub mod openai;
pub mod sleep_api;

#[async_trait]
pub trait LLMProvider {
22 changes: 0 additions & 22 deletions src/agent/providers/sleep_api.rs

This file was deleted.

1 change: 1 addition & 0 deletions src/agent/queue/mod.rs
@@ -0,0 +1 @@
pub mod shinkai_queue;
123 changes: 123 additions & 0 deletions src/agent/queue/shinkai_queue.rs
@@ -0,0 +1,123 @@
use crossbeam_queue::ArrayQueue;
use rocksdb::DB;
use bincode::{serialize, deserialize};
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, mpsc};
use crate::agent::job_execution::JobForProcessing;
use crate::db::db_errors::ShinkaiDBError;
use crate::db::{ShinkaiDB, Topic};

type Queue = Arc<ArrayQueue<JobForProcessing>>;
type Subscriber = mpsc::Sender<JobForProcessing>;

// Note(Nico): This

#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedQueue {
    items: Vec<JobForProcessing>,
}

pub struct SharedJobQueueManager {
    // Shared via Arc so clones are cheap handles onto the same state;
    // persistence goes through `db` rather than serializing the manager itself.
    queues: Arc<Mutex<HashMap<String, Queue>>>,
    subscribers: Arc<Mutex<HashMap<String, Vec<Subscriber>>>>,
    db: Arc<ShinkaiDB>,
}

impl SharedJobQueueManager {
    pub fn new(db: Arc<ShinkaiDB>) -> Self {
        SharedJobQueueManager {
            queues: Arc::new(Mutex::new(HashMap::new())),
            subscribers: Arc::new(Mutex::new(HashMap::new())),
            db,
        }
    }

    /// Returns the in-memory queue for `key`, loading it from the db on a cache miss.
    fn get_queue(&self, key: &str) -> Result<Queue, ShinkaiDBError> {
        let mut queues = self.queues.lock().unwrap();
        if let Some(queue) = queues.get(key) {
            return Ok(queue.clone());
        }

        match self.db.get_job_queues(key) {
            Ok(queue) => {
                queues.insert(key.to_string(), Arc::new(queue));
                Ok(queues.get(key).unwrap().clone())
            }
            Err(e) => Err(e),
        }
    }

    pub fn push(&self, key: &str, value: JobForProcessing) -> Result<(), ShinkaiDBError> {
        let queue = self.get_queue(key)?;
        queue.push(value.clone()).unwrap();
        // Persist the whole queue snapshot after every mutation (checkpointing).
        self.db.persist_job_queues(key, &queue)?;

        // Notify subscribers
        if let Some(subs) = self.subscribers.lock().unwrap().get(key) {
            for sub in subs.iter() {
                sub.send(value.clone()).unwrap();
            }
        }
        Ok(())
    }

    pub fn pop(&self, key: &str) -> Result<Option<JobForProcessing>, ShinkaiDBError> {
        let queue = self.get_queue(key)?;
        let result = queue.pop();
        if result.is_some() {
            self.db.persist_job_queues(key, &queue)?;
        }
        Ok(result)
    }

    pub fn subscribe(&self, key: &str) -> mpsc::Receiver<JobForProcessing> {
        let (tx, rx) = mpsc::channel();
        self.subscribers.lock().unwrap()
            .entry(key.to_string())
            .or_insert_with(Vec::new)
            .push(tx);
        rx
    }
}

impl Clone for SharedJobQueueManager {
    fn clone(&self) -> Self {
        SharedJobQueueManager {
            queues: Arc::clone(&self.queues),
            subscribers: Arc::clone(&self.subscribers),
            db: Arc::clone(&self.db),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::ShinkaiDB;
    use std::sync::Arc;

    #[test]
    fn test_queue_manager() {
        // Open the backing store; assumes the ShinkaiDB::new(path) constructor
        // used by the other db tests.
        let db = Arc::new(ShinkaiDB::new("path_to_rocksdb").unwrap());
        let manager = SharedJobQueueManager::new(db);

        // Subscribe to notifications from "my_queue"
        let receiver = manager.subscribe("my_queue");
        let manager_clone = manager.clone();
        std::thread::spawn(move || {
            for msg in receiver.iter() {
                println!("Received (from subscriber): {:?}", msg);

                // Pop from the queue inside the subscriber thread
                if let Ok(Some(message)) = manager_clone.pop("my_queue") {
                    println!("Popped (from subscriber): {:?}", message);
                }
            }
        });

        // Push a job to the queue. `make_test_job()` is hypothetical: this
        // commit shows no public constructor for JobForProcessing.
        manager.push("my_queue", make_test_job()).unwrap();

        // Sleep to allow subscriber to process the message (just for this example)
        std::thread::sleep(std::time::Duration::from_secs(1));
    }
}
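Because every field of the manager sits behind an Arc, clone() only bumps reference counts, and all clones observe the same queues and subscribers; that is what lets the test above hand a clone to the subscriber thread. A minimal standalone sketch of that sharing pattern (generic types, not the Shinkai structs):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

#[derive(Clone)]
struct Shared {
    queues: Arc<Mutex<HashMap<String, Vec<String>>>>,
}

fn main() {
    let shared = Shared { queues: Arc::new(Mutex::new(HashMap::new())) };
    let writer = shared.clone(); // clones the Arc handle, not the map
    let handle = thread::spawn(move || {
        writer.queues.lock().unwrap().entry("my_queue".into()).or_default().push("job".into());
    });
    handle.join().unwrap();
    // The original handle sees the write made through the clone.
    assert_eq!(shared.queues.lock().unwrap()["my_queue"], vec!["job".to_string()]);
}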
3 changes: 3 additions & 0 deletions src/db/db.rs
@@ -36,6 +36,7 @@ pub enum Topic {
    MessageBoxSymmetricKeys,
    MessageBoxSymmetricKeysTimes,
    TempFilesInbox,
    JobQueues,
}

impl Topic {
@@ -64,6 +65,7 @@ impl Topic {
            Self::MessageBoxSymmetricKeys => "message_box_symmetric_keys",
            Self::MessageBoxSymmetricKeysTimes => "message_box_symmetric_keys_times",
            Self::TempFilesInbox => "temp_files_inbox",
            Self::JobQueues => "job_queues",
        }
    }
}
@@ -155,6 +157,7 @@ impl ShinkaiDB {
            Topic::MessageBoxSymmetricKeys.as_str().to_string(),
            Topic::MessageBoxSymmetricKeysTimes.as_str().to_string(),
            Topic::TempFilesInbox.as_str().to_string(),
            Topic::JobQueues.as_str().to_string(),
        ]
    };

3 changes: 2 additions & 1 deletion src/db/mod.rs
@@ -11,4 +11,5 @@ pub mod db_resources;
pub mod db_toolkits;
pub mod db_utils;
pub mod db_retry;
pub mod db_files_transmission;
pub mod db_files_transmission;
pub mod db_job_queue;
45 changes: 1 addition & 44 deletions tests/db_agents_tests.rs
@@ -1,6 +1,6 @@
use mockito::Server;
use serde_json::Value as JsonValue;
use shinkai_message_primitives::schemas::agents::serialized_agent::{OpenAI, SleepAPI};
use shinkai_message_primitives::schemas::agents::serialized_agent::OpenAI;
use shinkai_node::db::{db_errors::ShinkaiDBError, ShinkaiDB};
use std::fs;
use std::path::Path;
@@ -190,49 +190,6 @@ mod tests {
        assert_eq!(vec!["toolkit2"], toolkits);
    }

    #[tokio::test]
    async fn test_agent_creation() {
        let (tx, mut rx) = mpsc::channel(1);
        let sleep_api = SleepAPI {};
        let agent = Agent::new(
            "1".to_string(),
            ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(),
            tx,
            false,
            Some("http://localhost:8000".to_string()),
            Some("paramparam".to_string()),
            AgentLLMInterface::Sleep(sleep_api),
            vec!["tk1".to_string(), "tk2".to_string()],
            vec!["sb1".to_string(), "sb2".to_string()],
            vec!["allowed1".to_string(), "allowed2".to_string()],
        );

        assert_eq!(agent.id, "1");
        assert_eq!(
            agent.full_identity_name,
            ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap()
        );
        assert_eq!(agent.perform_locally, false);
        assert_eq!(agent.external_url, Some("http://localhost:8000".to_string()));
        assert_eq!(agent.toolkit_permissions, vec!["tk1".to_string(), "tk2".to_string()]);
        assert_eq!(
            agent.storage_bucket_permissions,
            vec!["sb1".to_string(), "sb2".to_string()]
        );
        assert_eq!(
            agent.allowed_message_senders,
            vec!["allowed1".to_string(), "allowed2".to_string()]
        );

        let handle = tokio::spawn(async move {
            agent
                .inference(JobPromptGenerator::basic_instant_response_prompt("Test".to_string()))
                .await
        });
        let result: Result<JsonValue, AgentError> = handle.await.unwrap();
        assert_eq!(result.unwrap(), JsonValue::Bool(true))
    }

    #[tokio::test]
    async fn test_agent_call_external_api_openai() {
        let mut server = Server::new();
37 changes: 37 additions & 0 deletions tests/shinkai_queue_tests.rs
@@ -0,0 +1,37 @@
use std::collections::HashMap;

use rocksdb::IteratorMode;
use serde::{de::DeserializeOwned, Serialize};

impl ShinkaiDB {
    pub fn persist_job_queues<T>(&self, key: &str, queue: &Queue<T>) -> Result<(), ShinkaiDBError>
    where
        T: Serialize + Send + Sync + Clone,
    {
        let serialized_queue = bincode::serialize(queue).map_err(|_| ShinkaiDBError::SerializationError)?;
        let cf_handle = self
            .db
            .cf_handle(Topic::JobQueues.as_str())
            .ok_or(ShinkaiDBError::ColumnFamilyNotFound(Topic::JobQueues.as_str().to_string()))?;
        self.db.put_cf(cf_handle, key.as_bytes(), &serialized_queue)?;
        Ok(())
    }

    pub fn get_job_queues<T>(&self, key: &str) -> Result<Queue<T>, ShinkaiDBError>
    where
        T: DeserializeOwned + Send + Sync + Clone,
    {
        let cf_handle = self
            .db
            .cf_handle(Topic::JobQueues.as_str())
            .ok_or(ShinkaiDBError::ColumnFamilyNotFound(Topic::JobQueues.as_str().to_string()))?;
        let serialized_queue = self.db.get_cf(cf_handle, key.as_bytes())?.ok_or(ShinkaiDBError::DataNotFound)?;
        let queue: Queue<T> = bincode::deserialize(&serialized_queue).map_err(|_| ShinkaiDBError::DeserializationError)?;
        Ok(queue)
    }

    pub fn get_all_queues<T>(&self) -> Result<HashMap<String, Queue<T>>, ShinkaiDBError>
    where
        T: DeserializeOwned + Send + Sync + Clone,
    {
        let cf_handle = self
            .db
            .cf_handle(Topic::JobQueues.as_str())
            .ok_or(ShinkaiDBError::ColumnFamilyNotFound(Topic::JobQueues.as_str().to_string()))?;
        let mut queues = HashMap::new();

        // iterator_cf yields Result<(key, value)> items in recent rocksdb versions.
        for item in self.db.iterator_cf(cf_handle, IteratorMode::Start) {
            let (key, value) = item?;
            let key = String::from_utf8(key.to_vec()).map_err(|_| ShinkaiDBError::KeyParseError)?;
            let queue: Queue<T> = bincode::deserialize(&value).map_err(|_| ShinkaiDBError::DeserializationError)?;
            queues.insert(key, queue);
        }

        Ok(queues)
    }
}
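These helpers are a thin bincode-over-RocksDB layer: serialize the queue snapshot, write it under its key in the job_queues column family, and reverse the steps on read. A standalone sketch of that round trip with the rocksdb and bincode crates (the path, key, and Vec<String> stand-in queue are illustrative):

use rocksdb::{Options, DB};

fn main() {
    let mut opts = Options::default();
    opts.create_if_missing(true);
    opts.create_missing_column_families(true);
    let db = DB::open_cf(&opts, "example_queues_db", vec!["job_queues"]).unwrap();
    let cf = db.cf_handle("job_queues").unwrap();

    // Persist: encode the queue snapshot and store it under its key.
    let queue: Vec<String> = vec!["job-1".into(), "job-2".into()];
    db.put_cf(cf, b"my_queue", bincode::serialize(&queue).unwrap()).unwrap();

    // Load: fetch the bytes and decode them back into a queue.
    let bytes = db.get_cf(cf, b"my_queue").unwrap().unwrap();
    let restored: Vec<String> = bincode::deserialize(&bytes).unwrap();
    assert_eq!(queue, restored);
}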
