-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
189 additions
and
79 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,4 @@ pub mod job_manager; | |
pub mod job_prompts; | ||
pub mod plan_executor; | ||
pub mod providers; | ||
pub mod queue; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pub mod shinkai_queue; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
use crossbeam_queue::ArrayQueue; | ||
use rocksdb::DB; | ||
use bincode::{serialize, deserialize}; | ||
use serde::{Serialize, Deserialize}; | ||
use std::collections::HashMap; | ||
use std::sync::{Arc, Mutex, mpsc}; | ||
use crate::agent::job_execution::JobForProcessing; | ||
use crate::db::db_errors::ShinkaiDBError; | ||
use crate::db::{ShinkaiDB, Topic}; | ||
|
||
type Queue = Arc<ArrayQueue<JobForProcessing>>; | ||
type Subscriber = mpsc::Sender<JobForProcessing>; | ||
|
||
// Note(Nico): This | ||
|
||
#[derive(Serialize, Deserialize, Debug)] | ||
pub struct SerializedQueue { | ||
items: Vec<JobForProcessing>, | ||
} | ||
|
||
#[derive(Serialize, Deserialize)] | ||
pub struct SharedJobQueueManager { | ||
queues: HashMap<String, Queue>, | ||
subscribers: HashMap<String, Vec<Subscriber>>, | ||
db: ShinkaiDB, | ||
} | ||
|
||
impl SharedJobQueueManager { | ||
pub fn new(db: Arc<ShinkaiDB>) -> Self { | ||
SharedJobQueueManager { | ||
queues: Arc::new(Mutex::new(HashMap::new())), | ||
subscribers: Arc::new(Mutex::new(HashMap::new())), | ||
db, | ||
} | ||
} | ||
|
||
fn get_queue(&self, key: &str) -> Result<Queue, ShinkaiDBError> { | ||
let mut queues = self.queues.lock().unwrap(); | ||
if let Some(queue) = queues.get(key) { | ||
return Ok(queue.clone()); | ||
} | ||
|
||
match self.db.get_job_queues(key) { | ||
Ok(queue) => { | ||
queues.insert(key.to_string(), Arc::new(queue)); | ||
Ok(queues.get(key).unwrap().clone()) | ||
}, | ||
Err(e) => Err(e), | ||
} | ||
} | ||
|
||
pub fn push(&self, key: &str, value: JobForProcessing) -> Result<(), ShinkaiDBError> { | ||
let queue = self.get_queue(key)?; | ||
queue.push(value.clone()).unwrap(); | ||
self.db.persist_job_queues(key, &queue)?; | ||
|
||
// Notify subscribers | ||
if let Some(subs) = self.subscribers.lock().unwrap().get(key) { | ||
for sub in subs.iter() { | ||
sub.send(value.clone()).unwrap(); | ||
} | ||
} | ||
Ok(()) | ||
} | ||
|
||
pub fn pop(&self, key: &str) -> Result<Option<JobForProcessing>, ShinkaiDBError> { | ||
let queue = self.get_queue(key)?; | ||
let result = queue.pop(); | ||
if result.is_some() { | ||
self.db.persist_job_queues(key, &queue)?; | ||
} | ||
Ok(result) | ||
} | ||
|
||
pub fn subscribe(&self, key: &str) -> mpsc::Receiver<JobForProcessing> { | ||
let (tx, rx) = mpsc::channel(); | ||
self.subscribers.lock().unwrap() | ||
.entry(key.to_string()) | ||
.or_insert_with(Vec::new) | ||
.push(tx); | ||
rx | ||
} | ||
} | ||
|
||
impl Clone for SharedJobQueueManager { | ||
fn clone(&self) -> Self { | ||
SharedJobQueueManager { | ||
queues: Arc::clone(&self.queues), | ||
subscribers: Arc::clone(&self.subscribers), | ||
db: Arc::clone(&self.db), | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn test_queue_manager() { | ||
let manager = SharedJobQueueManager::<String>::new("path_to_rocksdb"); | ||
|
||
// Subscribe to notifications from "my_queue" | ||
let receiver = manager.subscribe("my_queue"); | ||
let manager_clone = manager.clone(); | ||
std::thread::spawn(move || { | ||
for msg in receiver.iter() { | ||
println!("Received (from subscriber): {}", msg); | ||
|
||
// Pop from the queue inside the subscriber thread | ||
if let Some(message) = manager_clone.pop("my_queue") { | ||
println!("Popped (from subscriber): {}", message); | ||
} | ||
} | ||
}); | ||
|
||
// Push to a queue | ||
manager.push("my_queue", "Hello".to_string()); | ||
|
||
// Sleep to allow subscriber to process the message (just for this example) | ||
std::thread::sleep(std::time::Duration::from_secs(1)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
impl ShinkaiDB { | ||
pub fn persist_job_queues<T>(&self, key: &str, queue: &Queue<T>) -> Result<(), ShinkaiDBError> | ||
where | ||
T: Serialize + Send + Sync + Clone, | ||
{ | ||
let serialized_queue = bincode::serialize(queue).map_err(|_| ShinkaiDBError::SerializationError)?; | ||
let cf_handle = self.db.cf_handle(Topic::JobQueues.as_str()).ok_or(ShinkaiDBError::ColumnFamilyNotFound(Topic::JobQueues.as_str().to_string()))?; | ||
self.db.put_cf(cf_handle, key.as_bytes(), &serialized_queue)?; | ||
Ok(()) | ||
} | ||
|
||
pub fn get_job_queues<T>(&self, key: &str) -> Result<Queue<T>, ShinkaiDBError> | ||
where | ||
T: Deserialize<'static> + Send + Sync + Clone, | ||
{ | ||
let cf_handle = self.db.cf_handle(Topic::JobQueues.as_str()).ok_or(ShinkaiDBError::ColumnFamilyNotFound(Topic::JobQueues.as_str().to_string()))?; | ||
let serialized_queue = self.db.get_cf(cf_handle, key.as_bytes())?.ok_or(ShinkaiDBError::DataNotFound)?; | ||
let queue: Queue<T> = bincode::deserialize(&serialized_queue).map_err(|_| ShinkaiDBError::DeserializationError)?; | ||
Ok(queue) | ||
} | ||
|
||
pub fn get_all_queues<T>(&self) -> Result<HashMap<String, Queue<T>>, ShinkaiDBError> | ||
where | ||
T: Deserialize<'static> + Send + Sync + Clone, | ||
{ | ||
let cf_handle = self.db.cf_handle(Topic::JobQueues.as_str()).ok_or(ShinkaiDBError::ColumnFamilyNotFound(Topic::JobQueues.as_str().to_string()))?; | ||
let mut queues = HashMap::new(); | ||
|
||
for (key, value) in self.db.iterator_cf(cf_handle, IteratorMode::Start) { | ||
let key = String::from_utf8(key.to_vec()).map_err(|_| ShinkaiDBError::KeyParseError)?; | ||
let queue: Queue<T> = bincode::deserialize(&value).map_err(|_| ShinkaiDBError::DeserializationError)?; | ||
queues.insert(key, queue); | ||
} | ||
|
||
Ok(queues) | ||
} | ||
} |