Skip to content

Commit

Permalink
Merge pull request #103 from dcSpark/nico/queue-feature
Browse files Browse the repository at this point in the history
Nico/queue feature
  • Loading branch information
nicarq authored Oct 11, 2023
2 parents 987d084 + 9a951d9 commit 44c44ec
Show file tree
Hide file tree
Showing 16 changed files with 402 additions and 84 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ pub struct SerializedAgent {
pub enum AgentLLMInterface {
#[serde(rename = "openai")]
OpenAI(OpenAI),
#[serde(rename = "sleep")]
Sleep(SleepAPI),
#[serde(rename = "local-llm")]
LocalLLM(LocalLLM),
}
Expand All @@ -43,10 +41,8 @@ impl FromStr for AgentLLMInterface {
let model_type = s.strip_prefix("openai:").unwrap_or("").to_string();
Ok(AgentLLMInterface::OpenAI(OpenAI { model_type }))
} else {
Ok(AgentLLMInterface::Sleep(SleepAPI {}))
// TODO: nothing else for now
Err(())
}
}
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct SleepAPI {}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub struct JobCreationInfo {
pub scope: JobScope,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct JobMessage {
// TODO: scope div modifications?
pub job_id: String,
Expand Down
5 changes: 0 additions & 5 deletions src/agent/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,6 @@ impl Agent {
.call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), prompt)
.await
}
AgentLLMInterface::Sleep(sleep_api) => {
sleep_api
.call_api(&self.client, self.external_url.as_ref(), self.api_key.as_ref(), prompt)
.await
}
AgentLLMInterface::LocalLLM(local_llm) => {
self.inference_locally(prompt.generate_single_output_string()?).await
}
Expand Down
1 change: 1 addition & 0 deletions src/agent/job_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use async_recursion::async_recursion;
use blake3::Hasher;
use chrono::Utc;
use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey};
use serde::{Serialize, Deserialize};
use serde_json::{Map, Value as JsonValue};
use shinkai_message_primitives::shinkai_utils::encryption::unsafe_deterministic_encryption_keypair;
use shinkai_message_primitives::shinkai_utils::job_scope::{JobScope, LocalScopeEntry};
Expand Down
1 change: 1 addition & 0 deletions src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pub mod job_manager;
pub mod job_prompts;
pub mod plan_executor;
pub mod providers;
pub mod queue;
1 change: 0 additions & 1 deletion src/agent/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use reqwest::Client;
use serde_json::Value as JsonValue;

pub mod openai;
pub mod sleep_api;

#[async_trait]
pub trait LLMProvider {
Expand Down
22 changes: 0 additions & 22 deletions src/agent/providers/sleep_api.rs

This file was deleted.

285 changes: 285 additions & 0 deletions src/agent/queue/job_queue_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
use crate::db::db_errors::ShinkaiDBError;
use crate::db::ShinkaiDB;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobMessage;
use std::collections::HashMap;
use std::sync::{mpsc, Arc, Mutex};

type MutexQueue<T> = Arc<Mutex<Vec<T>>>;
type Subscriber<T> = mpsc::Sender<T>;

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JobForProcessing {
job_message: JobMessage,
profile: ShinkaiName,
}

#[derive(Debug)]
pub struct JobQueueManager<T> {
queues: HashMap<String, MutexQueue<T>>,
subscribers: HashMap<String, Vec<Subscriber<T>>>,
db: Arc<Mutex<ShinkaiDB>>,
}

impl<T: Clone + Send + 'static + DeserializeOwned + Serialize> JobQueueManager<T> {
pub fn new(db: Arc<Mutex<ShinkaiDB>>) -> Result<Self, ShinkaiDBError> {
// Lock the db for safe access
let db_lock = db.lock().unwrap();

// Call the get_all_queues method to get all queue data from the db
match db_lock.get_all_queues() {
Ok(db_queues) => {
// Initialize the queues field with Mutex-wrapped Vecs from the db data
let manager_queues = db_queues
.into_iter()
.map(|(key, vec)| (key, Arc::new(Mutex::new(vec))))
.collect();

// Return a new SharedJobQueueManager with the loaded queue data
Ok(JobQueueManager {
queues: manager_queues,
subscribers: HashMap::new(),
db: Arc::clone(&db),
})
}
Err(e) => Err(e),
}
}

fn get_queue(&self, key: &str) -> Result<Vec<T>, ShinkaiDBError> {
let db = self.db.lock().unwrap();
db.get_job_queues(key)
}

pub fn push(&mut self, key: &str, value: T) -> Result<(), ShinkaiDBError> {
let queue = self
.queues
.entry(key.to_string())
.or_insert_with(|| Arc::new(Mutex::new(Vec::new())));

let mut guarded_queue = queue.lock().unwrap();
guarded_queue.push(value.clone());

// Persist queue to the database
let db = self.db.lock().unwrap();
db.persist_queue(key, &guarded_queue)?;

// Notify subscribers
if let Some(subs) = self.subscribers.get(key) {
for sub in subs.iter() {
sub.send(value.clone()).unwrap();
}
}
Ok(())
}

pub fn dequeue(&mut self, key: &str) -> Result<Option<T>, ShinkaiDBError> {
// Ensure the specified key exists in the queues hashmap, initializing it with an empty queue if necessary
let queue = self
.queues
.entry(key.to_string())
.or_insert_with(|| Arc::new(Mutex::new(Vec::new())));
let mut guarded_queue = queue.lock().unwrap();

// Check if there's an element to dequeue, and remove it if so
let result = if guarded_queue.get(0).is_some() {
Some(guarded_queue.remove(0))
} else {
None
};

// Persist queue to the database
let db = self.db.lock().unwrap();
db.persist_queue(key, &guarded_queue)?;

Ok(result)
}

pub fn subscribe(&mut self, key: &str) -> mpsc::Receiver<T> {
let (tx, rx) = mpsc::channel();
self.subscribers
.entry(key.to_string())
.or_insert_with(Vec::new)
.push(tx);
rx
}
}

impl<T: Clone + Send + 'static> Clone for JobQueueManager<T> {
fn clone(&self) -> Self {
JobQueueManager {
queues: self.queues.clone(),
subscribers: self.subscribers.clone(),
db: Arc::clone(&self.db),
}
}
}

#[cfg(test)]
mod tests {
use crate::agent::{error::AgentError, queue::job_queue_manager_error::JobQueueManagerError};

use super::*;
use serde_json::Value as JsonValue;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
use std::{fs, path::Path};

#[test]
fn setup() {
let path = Path::new("db_tests/");
let _ = fs::remove_dir_all(&path);
}

#[test]
fn test_queue_manager() {
setup();
let db = Arc::new(Mutex::new(ShinkaiDB::new("db_tests/").unwrap()));
let mut manager = JobQueueManager::<JobForProcessing>::new(db).unwrap();

// Subscribe to notifications from "my_queue"
let receiver = manager.subscribe("my_queue");
let mut manager_clone = manager.clone();
let handle = std::thread::spawn(move || {
for msg in receiver.iter() {
println!("Received (from subscriber): {:?}", msg);

// Dequeue from the queue inside the subscriber thread
if let Ok(Some(message)) = manager_clone.dequeue("my_queue") {
println!("Dequeued (from subscriber): {:?}", message);

// Assert that the subscriber dequeued the correct message
assert_eq!(message, msg, "Dequeued message does not match received message");
}

eprintln!("Dequeued (from subscriber): {:?}", msg);
// Assert that the queue is now empty
match manager_clone.dequeue("my_queue") {
Ok(None) => (),
Ok(Some(_)) => panic!("Queue is not empty!"),
Err(e) => panic!("Failed to dequeue from queue: {:?}", e),
}
break;
}
});

// Push to a queue
let job = JobForProcessing {
job_message: JobMessage {
job_id: "job_id::123::false".to_string(),
content: "my content".to_string(),
files_inbox: "".to_string(),
},
profile: ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
};
manager.push("my_queue", job.clone()).unwrap();

// Sleep to allow subscriber to process the message (just for this example)
std::thread::sleep(std::time::Duration::from_millis(500));

handle.join().unwrap();
}

#[test]
fn test_queue_manager_consistency() {
setup();
let db_path = "db_tests/";
let db = Arc::new(Mutex::new(ShinkaiDB::new(db_path).unwrap()));
let mut manager = JobQueueManager::<JobForProcessing>::new(Arc::clone(&db)).unwrap();

// Push to a queue
let job = JobForProcessing {
job_message: JobMessage {
job_id: "job_id::123::false".to_string(),
content: "my content".to_string(),
files_inbox: "".to_string(),
},
profile: ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
};
let job2 = JobForProcessing {
job_message: JobMessage {
job_id: "job_id::123::false".to_string(),
content: "my content 2".to_string(),
files_inbox: "".to_string(),
},
profile: ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
};
manager.push("my_queue", job.clone()).unwrap();
manager.push("my_queue", job2.clone()).unwrap();

// Sleep to allow subscriber to process the message (just for this example)
std::thread::sleep(std::time::Duration::from_millis(500));

// Create a new manager and recover the state
let mut new_manager = JobQueueManager::<JobForProcessing>::new(Arc::clone(&db)).unwrap();

// Try to pop the job from the queue using the new manager
match new_manager.dequeue("my_queue") {
Ok(Some(recovered_job)) => {
shinkai_log(
ShinkaiLogOption::Tests,
ShinkaiLogLevel::Info,
format!("Recovered job: {:?}", recovered_job).as_str(),
);
assert_eq!(recovered_job, job);
}
Ok(None) => panic!("No job found in the queue!"),
Err(e) => panic!("Failed to pop job from queue: {:?}", e),
}

match new_manager.dequeue("my_queue") {
Ok(Some(recovered_job)) => {
shinkai_log(
ShinkaiLogOption::Tests,
ShinkaiLogLevel::Info,
format!("Recovered job: {:?}", recovered_job).as_str(),
);
assert_eq!(recovered_job, job2);
}
Ok(None) => panic!("No job found in the queue!"),
Err(e) => panic!("Failed to pop job from queue: {:?}", e),
}
}

#[test]
fn test_queue_manager_with_jsonvalue() {
setup();
let db = Arc::new(Mutex::new(ShinkaiDB::new("db_tests/").unwrap()));
let mut manager = JobQueueManager::<Result<JsonValue, JobQueueManagerError>>::new(db).unwrap();

// Subscribe to notifications from "my_queue"
let receiver = manager.subscribe("my_queue");
let mut manager_clone = manager.clone();
let handle = std::thread::spawn(move || {
for msg in receiver.iter() {
println!("Received (from subscriber): {:?}", msg);

// Dequeue from the queue inside the subscriber thread
if let Ok(Some(message)) = manager_clone.dequeue("my_queue") {
println!("Dequeued (from subscriber): {:?}", message);

// Assert that the subscriber dequeued the correct message
assert_eq!(message, msg, "Dequeued message does not match received message");
}

eprintln!("Dequeued (from subscriber): {:?}", msg);
// Assert that the queue is now empty
match manager_clone.dequeue("my_queue") {
Ok(None) => (),
Ok(Some(_)) => panic!("Queue is not empty!"),
Err(e) => panic!("Failed to dequeue from queue: {:?}", e),
}
break;
}
});

// Push to a queue
let job = Ok(JsonValue::String("my content".to_string()));
manager.push("my_queue", job.clone()).unwrap();

// Sleep to allow subscriber to process the message (just for this example)
std::thread::sleep(std::time::Duration::from_millis(500));
handle.join().unwrap();
}
}
22 changes: 22 additions & 0 deletions src/agent/queue/job_queue_manager_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use serde::{Serialize, Deserialize};
use std::error::Error;
use std::fmt;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum JobQueueManagerError {
UrlNotSet,
ApiKeyNotSet,
ReqwestError(String),
}

impl fmt::Display for JobQueueManagerError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
JobQueueManagerError::UrlNotSet => write!(f, "URL is not set"),
JobQueueManagerError::ApiKeyNotSet => write!(f, "API Key not set"),
JobQueueManagerError::ReqwestError(err) => write!(f, "Reqwest error: {}", err),
}
}
}

impl Error for JobQueueManagerError {}
Loading

0 comments on commit 44c44ec

Please sign in to comment.