Skip to content

Commit

Permalink
update sqlite architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
nicarq committed Dec 18, 2024
1 parent 21a36cb commit 80418c3
Show file tree
Hide file tree
Showing 72 changed files with 963 additions and 1,465 deletions.
35 changes: 17 additions & 18 deletions shinkai-bin/shinkai-node/src/cron_tasks/cron_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
sync::{Arc, Weak},
};

use chrono::{Timelike, Utc};
use chrono::Utc;
use ed25519_dalek::SigningKey;
use futures::Future;
use shinkai_message_primitives::{
Expand All @@ -22,7 +22,7 @@ use shinkai_message_primitives::{
},
};
use shinkai_sqlite::{errors::SqliteManagerError, SqliteManager};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::Mutex;
use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};

use crate::{
Expand Down Expand Up @@ -83,7 +83,7 @@ impl fmt::Display for CronManagerError {
}

pub struct CronManager {
pub db: Weak<RwLock<SqliteManager>>,
pub db: Weak<SqliteManager>,
pub node_profile_name: ShinkaiName,
pub identity_secret_key: SigningKey,
pub job_manager: Arc<Mutex<JobManager>>,
Expand All @@ -96,7 +96,7 @@ pub struct CronManager {

impl CronManager {
pub async fn new(
db: Weak<RwLock<SqliteManager>>,
db: Weak<SqliteManager>,
identity_secret_key: SigningKey,
node_name: ShinkaiName,
job_manager: Arc<Mutex<JobManager>>,
Expand Down Expand Up @@ -162,7 +162,7 @@ impl CronManager {

#[allow(clippy::too_many_arguments)]
pub fn process_job_queue(
db: Weak<RwLock<SqliteManager>>,
db: Weak<SqliteManager>,
node_profile_name: ShinkaiName,
identity_sk: SigningKey,
cron_time_interval: u64,
Expand All @@ -173,7 +173,7 @@ impl CronManager {
ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
job_processing_fn: impl Fn(
CronTask,
Weak<RwLock<SqliteManager>>,
Weak<SqliteManager>,
SigningKey,
Arc<Mutex<JobManager>>,
Arc<Mutex<IdentityManager>>,
Expand Down Expand Up @@ -210,8 +210,8 @@ impl CronManager {
return;
}
let db_arc = db_arc.unwrap();
let db = db_arc.read().await;
db.get_all_cron_tasks()
db_arc
.get_all_cron_tasks()
.unwrap_or_default()
.into_iter()
.map(|task| (task.created_at.clone(), vec![(task.task_id.to_string(), task)]))
Expand Down Expand Up @@ -293,7 +293,7 @@ impl CronManager {
#[allow(clippy::too_many_arguments)]
pub async fn process_job_message_queued(
cron_job: CronTask,
db: Weak<RwLock<SqliteManager>>,
db: Weak<SqliteManager>,
identity_secret_key: SigningKey,
job_manager: Arc<Mutex<JobManager>>,
identity_manager: Arc<Mutex<IdentityManager>>,
Expand All @@ -313,7 +313,6 @@ impl CronManager {
// Update the last executed time
{
let current_time = Utc::now().to_rfc3339();
let db = db.read().await;
db.update_cron_task_last_executed(cron_job.task_id.into(), &current_time)?;
}

Expand All @@ -332,7 +331,7 @@ impl CronManager {
.await?;

// Update the job configuration
db.write().await.update_job_config(&job_id, config)?;
db.update_job_config(&job_id, config)?;

// Use send_job_message_with_bearer instead of ShinkaiMessageBuilder
Self::send_job_message_with_bearer(
Expand Down Expand Up @@ -407,16 +406,16 @@ impl CronManager {
result
}

async fn log_success_to_sqlite(db: &Arc<RwLock<SqliteManager>>, task_id: i64) {
async fn log_success_to_sqlite(db: &Arc<SqliteManager>, task_id: i64) {
let execution_time = chrono::Utc::now().to_rfc3339();
let db = db.write().await;
let db = db;
if let Err(err) = db.add_cron_task_execution(task_id, &execution_time, true, None) {
eprintln!("Failed to log success to SQLite: {}", err);
}
}

async fn send_job_message_with_bearer(
db: Arc<RwLock<SqliteManager>>,
db: Arc<SqliteManager>,
node_name_clone: ShinkaiName,
identity_manager_clone: Arc<Mutex<IdentityManager>>,
job_manager_clone: Arc<Mutex<JobManager>>,
Expand All @@ -427,7 +426,7 @@ impl CronManager {
task_id: i64,
) -> Result<(), NodeError> {
// Retrieve the bearer token from the database
let bearer = match db.read().await.read_api_v2_key() {
let bearer = match db.read_api_v2_key() {
Ok(Some(token)) => token,
Ok(None) => {
Self::log_error_to_sqlite(&db, task_id, "Bearer token not found").await;
Expand Down Expand Up @@ -472,9 +471,9 @@ impl CronManager {
Ok(())
}

async fn log_error_to_sqlite(db: &Arc<RwLock<SqliteManager>>, task_id: i64, error_message: &str) {
async fn log_error_to_sqlite(db: &Arc<SqliteManager>, task_id: i64, error_message: &str) {
let execution_time = chrono::Utc::now().to_rfc3339();
let db = db.write().await;
let db = db;
if let Err(err) = db.add_cron_task_execution(task_id, &execution_time, false, Some(error_message)) {
eprintln!("Failed to log error to SQLite: {}", err);
}
Expand Down Expand Up @@ -512,7 +511,7 @@ impl CronManager {
#[cfg(test)]
mod tests {
use super::*;
use chrono::Utc;
use chrono::{Timelike, Utc};
use shinkai_message_primitives::schemas::crontab::CronTaskAction;

fn create_test_cron_task(cron: &str) -> CronTask {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl GenericInferenceChain {

#[allow(clippy::too_many_arguments)]
pub async fn start_chain(
db: Arc<RwLock<SqliteManager>>,
db: Arc<SqliteManager>,
vector_fs: Arc<VectorFS>,
full_job: Job,
user_message: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ use shinkai_sqlite::SqliteManager;
use shinkai_vector_fs::vector_fs::vector_fs::VectorFS;
use shinkai_vector_resources::embedding_generator::RemoteEmbeddingGenerator;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::Mutex;

impl JobManager {
/// Chooses an inference chain based on the job message (using the agent's LLM)
/// and then starts using the chosen chain.
/// Returns the final String result from the inferencing, and a new execution context.
#[allow(clippy::too_many_arguments)]
pub async fn inference_chain_router(
db: Arc<RwLock<SqliteManager>>,
db: Arc<SqliteManager>,
vector_fs: Arc<VectorFS>,
llm_provider_found: Option<ProviderOrAgent>,
full_job: Job,
Expand All @@ -54,8 +54,6 @@ impl JobManager {
// If it's an agent, we need to get the LLM provider from the agent
let llm_id = llm_provider.get_llm_provider_id();
let llm_provider = db
.read()
.await
.get_llm_provider(llm_id, &user_profile)
.map_err(|e| e.to_string())?
.ok_or(LLMProviderError::LLMProviderNotFound)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub trait InferenceChainContextTrait: Send + Sync {
fn update_iteration_count(&mut self, new_iteration_count: u64);
fn update_message(&mut self, new_message: ParsedUserMessage);

fn db(&self) -> Arc<RwLock<SqliteManager>>;
fn db(&self) -> Arc<SqliteManager>;
fn vector_fs(&self) -> Arc<VectorFS>;
fn full_job(&self) -> &Job;
fn user_message(&self) -> &ParsedUserMessage;
Expand Down Expand Up @@ -106,7 +106,7 @@ impl InferenceChainContextTrait for InferenceChainContext {
self.user_message = new_message;
}

fn db(&self) -> Arc<RwLock<SqliteManager>> {
fn db(&self) -> Arc<SqliteManager> {
Arc::clone(&self.db)
}

Expand Down Expand Up @@ -199,7 +199,7 @@ impl InferenceChainContextTrait for InferenceChainContext {
/// using all fields in this struct, but they are available nonetheless.
#[derive(Clone)]
pub struct InferenceChainContext {
pub db: Arc<RwLock<SqliteManager>>,
pub db: Arc<SqliteManager>,
pub vector_fs: Arc<VectorFS>,
pub full_job: Job,
pub user_message: ParsedUserMessage,
Expand Down Expand Up @@ -227,7 +227,7 @@ pub struct InferenceChainContext {
impl InferenceChainContext {
#[allow(clippy::too_many_arguments)]
pub fn new(
db: Arc<RwLock<SqliteManager>>,
db: Arc<SqliteManager>,
vector_fs: Arc<VectorFS>,
full_job: Job,
user_message: ParsedUserMessage,
Expand Down Expand Up @@ -426,7 +426,7 @@ impl InferenceChainContextTrait for Box<dyn InferenceChainContextTrait> {
(**self).update_message(new_message)
}

fn db(&self) -> Arc<RwLock<SqliteManager>> {
fn db(&self) -> Arc<SqliteManager> {
(**self).db()
}

Expand Down Expand Up @@ -525,7 +525,7 @@ pub struct MockInferenceChainContext {
pub iteration_count: u64,
pub max_tokens_in_prompt: usize,
pub raw_files: RawFiles,
pub db: Option<Arc<RwLock<SqliteManager>>>,
pub db: Option<Arc<SqliteManager>>,
pub vector_fs: Option<Arc<VectorFS>>,
pub my_agent_payments_manager: Option<Arc<Mutex<MyAgentOfferingsManager>>>,
pub ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
Expand All @@ -543,7 +543,7 @@ impl MockInferenceChainContext {
iteration_count: u64,
max_tokens_in_prompt: usize,
raw_files: Option<Arc<Vec<(String, Vec<u8>)>>>,
db: Option<Arc<RwLock<SqliteManager>>>,
db: Option<Arc<SqliteManager>>,
vector_fs: Option<Arc<VectorFS>>,
my_agent_payments_manager: Option<Arc<Mutex<MyAgentOfferingsManager>>>,
ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
Expand Down Expand Up @@ -609,7 +609,7 @@ impl InferenceChainContextTrait for MockInferenceChainContext {
self.user_message = new_message;
}

fn db(&self) -> Arc<RwLock<SqliteManager>> {
fn db(&self) -> Arc<SqliteManager> {
self.db.clone().expect("DB is not set")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,11 +613,7 @@ mod tests {
};
use shinkai_sqlite::SqliteManager;
use shinkai_vector_resources::model_type::{EmbeddingModelType, OllamaTextEmbeddingsInference};
use std::{
fs,
path::{Path, PathBuf},
sync::Arc,
};
use std::{path::PathBuf, sync::Arc};
use tempfile::NamedTempFile;
use tokio::sync::{Mutex, RwLock};

Expand Down Expand Up @@ -657,7 +653,7 @@ mod tests {
#[tokio::test]
async fn test_set_column_with_mock_job_manager() {
let db = setup_test_db();
let db = Arc::new(RwLock::new(db));
let db = Arc::new(db);
let node_name = "@@test.arb-sep-shinkai".to_string();
let node_name = ShinkaiName::new(node_name).unwrap();
let ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>> = None;
Expand Down Expand Up @@ -738,7 +734,7 @@ mod tests {
#[tokio::test]
async fn test_update_column_with_values() {
let db = setup_test_db();
let db = Arc::new(RwLock::new(db));
let db = Arc::new(db);
let node_name = "@@test.arb-sep-shinkai".to_string();
let node_name = ShinkaiName::new(node_name).unwrap();
let ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>> = None;
Expand Down Expand Up @@ -803,7 +799,7 @@ mod tests {
#[tokio::test]
async fn test_replace_value_at_position() {
let db = setup_test_db();
let db = Arc::new(RwLock::new(db));
let db = Arc::new(db);
let node_name = "@@test.arb-sep-shinkai".to_string();
let node_name = ShinkaiName::new(node_name).unwrap();
let ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>> = None;
Expand Down Expand Up @@ -868,7 +864,7 @@ mod tests {
#[tokio::test]
async fn test_create_new_columns_with_csv() {
let db = setup_test_db();
let db = Arc::new(RwLock::new(db));
let db = Arc::new(db);
let node_name = "@@test.arb-sep-shinkai".to_string();
let node_name = ShinkaiName::new(node_name).unwrap();
let ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>> = None;
Expand Down Expand Up @@ -931,7 +927,7 @@ mod tests {
#[tokio::test]
async fn test_create_new_columns_with_large_csv() {
let db = setup_test_db();
let db = Arc::new(RwLock::new(db));
let db = Arc::new(db);
let node_name = "@@test.arb-sep-shinkai".to_string();
let node_name = ShinkaiName::new(node_name).unwrap();
let ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>> = None;
Expand Down Expand Up @@ -1184,7 +1180,7 @@ mod tests {
#[tokio::test]
async fn test_create_new_columns_with_semicolon_csv() {
let db = setup_test_db();
let db = Arc::new(RwLock::new(db));
let db = Arc::new(db);
let node_name = "@@test.arb-sep-shinkai".to_string();
let node_name = ShinkaiName::new(node_name).unwrap();
let ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>> = None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl SheetUIInferenceChain {
// the tool code handling in the future so we can reuse the code
#[allow(clippy::too_many_arguments)]
pub async fn start_chain(
db: Arc<RwLock<SqliteManager>>,
db: Arc<SqliteManager>,
vector_fs: Arc<VectorFS>,
full_job: Job,
user_message: String,
Expand Down
Loading

0 comments on commit 80418c3

Please sign in to comment.