Skip to content

Commit

Permalink
fix: prevent deadlocks in rust node manager implementation
Browse files Browse the repository at this point in the history
Co-Authored-By: Nicolas Arqueros <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and nicarq committed Dec 19, 2024
1 parent ce92e41 commit 3a95cae
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ impl ProcessHandler {
/// NOTE(review): presumably the maximum number of log entries retained; its use is not visible here — confirm.
const MAX_LOGS_LENGTH: usize = 500;
/// Minimum time, in milliseconds, a freshly spawned process must stay alive; it is compared
/// against elapsed time via `Duration::from_millis` when checking for an early crash.
const MIN_MS_ALIVE: u64 = 5000;

/// Initializes a new ProcessHandler with default or provided options
pub(crate) fn new(
app: AppHandle,
process_name: String,
Expand Down Expand Up @@ -93,52 +92,13 @@ impl ProcessHandler {
process.is_some()
}

pub async fn spawn(
&self,
env: HashMap<String, String>,
args: Vec<&str>,
current_dir: Option<PathBuf>,
) -> Result<(), String> {
{
let process = self.process.lock().await;
if process.is_some() {
log::warn!("process {} is already running", self.process_name);
return Ok(());
}
}

let mut logger = self.logger.write().await;
let shell = self.app.shell();
let (mut rx, child) = shell
.sidecar(self.process_name.clone())
.map_err(|error| {
let message = format!("failed to spawn, error: {}", error);
logger.add_log(message.clone());
message
})?
.envs(env.clone())
.current_dir(current_dir.unwrap_or_else(|| std::path::PathBuf::from("./")))
.args(args)
.spawn()
.map_err(|error| {
let message = format!("failed to spawn error: {}", error);
logger.add_log(message.clone());
message
})?;
drop(logger);

{
let mut process = self.process.lock().await;
*process = Some(child);
}

async fn handle_process_events(&self, mut rx: tokio::sync::mpsc::Receiver<CommandEvent>) {
let process_mutex = Arc::clone(&self.process);
let logger_mutex = Arc::clone(&self.logger);
let event_sender_mutex = Arc::clone(&self.event_sender);
let is_ready_mutex = Arc::new(Mutex::new(false));
let is_ready_mutex_clone = is_ready_mutex.clone();

let ready_matcher = self.ready_matcher.clone();

tauri::async_runtime::spawn(async move {
while let Some(event) = rx.recv().await {
let message = Self::command_event_to_message(event.clone());
Expand All @@ -157,37 +117,52 @@ impl ProcessHandler {
}
}
});
}

let start_time = std::time::Instant::now();
let logger_mutex = self.logger.clone();
let process_mutex = self.process.clone();
let event_sender_mutex = Arc::clone(&self.event_sender);
tauri::async_runtime::spawn(async move {
while std::time::Instant::now().duration_since(start_time)
< std::time::Duration::from_millis(Self::MIN_MS_ALIVE)
{
let process = process_mutex.lock().await;
let is_ready = is_ready_mutex_clone.lock().await;
if process.is_none() {
let event_sender = event_sender_mutex.lock().await;
let mut logger = logger_mutex.write().await;
let message = "failed to spawn shinkai-node, it crashed before min time alive"
.to_string();
let log_entry = logger.add_log(message.clone());
let _ = event_sender.send(ProcessHandlerEvent::Log(log_entry)).await;
return Err(message.to_string());
} else if *is_ready {
break;
}
std::thread::sleep(std::time::Duration::from_millis(500));
pub async fn spawn(
&self,
env: HashMap<String, String>,
args: Vec<&str>,
current_dir: Option<PathBuf>,
) -> Result<(), String> {
{
let process = self.process.lock().await;
if process.is_some() {
log::warn!("process {} is already running", self.process_name);
return Ok(());
}
Ok(())
})
.await
.unwrap()?;
}

self.emit_event(ProcessHandlerEvent::Started).await;
let child = {
let mut logger = self.logger.write().await;
let shell = self.app.shell();
let (rx, child) = shell
.sidecar(self.process_name.clone())
.map_err(|error| {
let message = format!("failed to spawn, error: {}", error);
logger.add_log(message.clone());
message
})?
.envs(env.clone())
.current_dir(current_dir.unwrap_or_else(|| std::path::PathBuf::from("./")))
.args(args)
.spawn()
.map_err(|error| {
let message = format!("failed to spawn error: {}", error);
logger.add_log(message.clone());
message
})?;

self.handle_process_events(rx);
child
};

{
let mut process = self.process.lock().await;
*process = Some(child);
}

self.emit_event(ProcessHandlerEvent::Started).await;
Ok(())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,33 @@ impl ShinkaiNodeProcessHandler {
let _ = self.kill().await;

let env = options_to_env(&self.options.clone());
self.process_handler.spawn(env, [].to_vec(), None).await?;
if let Err(e) = self.wait_shinkai_node_server().await {
self.process_handler.kill().await;
return Err(e);

// Bound the spawn operation with a 30-second timeout.
let spawn_result = tokio::time::timeout(
Duration::from_secs(30),
self.process_handler.spawn(env, [].to_vec(), None)
).await;

match spawn_result {
Ok(Ok(_)) => {
match tokio::time::timeout(
Duration::from_millis(Self::HEALTH_TIMEOUT_MS),
self.wait_shinkai_node_server()
).await {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) => {
self.process_handler.kill().await;
Err(e)
},
Err(_) => {
self.process_handler.kill().await;
Err("Health check timeout".to_string())
}
}
},
Ok(Err(e)) => Err(e),
Err(_) => Err("Spawn timeout".to_string())
}
Ok(())
}

pub async fn get_last_n_logs(&self, n: usize) -> Vec<LogEntry> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,28 @@ impl ShinkaiNodeManager {
}

pub async fn spawn(&mut self) -> Result<(), String> {
// Clean up any state left over from a previous failed start before spawning again.
self.kill().await;

// Start Ollama with timeout
self.emit_event(ShinkaiNodeManagerEvent::StartingOllama);
match self.ollama_process.spawn(None).await {
Ok(_) => {
match tokio::time::timeout(
Duration::from_secs(30),
self.ollama_process.spawn(None)
).await {
Ok(Ok(_)) => {
self.emit_event(ShinkaiNodeManagerEvent::OllamaStarted);
}
Err(e) => {
},
Ok(Err(e)) => {
self.kill().await;
self.emit_event(ShinkaiNodeManagerEvent::OllamaStartError { error: e.clone() });
return Err(e);
},
Err(_) => {
self.kill().await;
let error = "Ollama start timeout".to_string();
self.emit_event(ShinkaiNodeManagerEvent::OllamaStartError { error: error.clone() });
return Err(error);
}
}

Expand Down Expand Up @@ -223,19 +236,26 @@ impl ShinkaiNodeManager {
}

self.emit_event(ShinkaiNodeManagerEvent::StartingShinkaiNode);
match self.shinkai_node_process.spawn().await {
Ok(_) => {
match tokio::time::timeout(
Duration::from_secs(30),
self.shinkai_node_process.spawn()
).await {
Ok(Ok(_)) => {
self.emit_event(ShinkaiNodeManagerEvent::ShinkaiNodeStarted);
}
Err(e) => {
Ok(())
},
Ok(Err(e)) => {
self.kill().await;
self.emit_event(ShinkaiNodeManagerEvent::ShinkaiNodeStartError {
error: e.clone(),
});
return Err(e);
self.emit_event(ShinkaiNodeManagerEvent::ShinkaiNodeStartError { error: e.clone() });
Err(e)
},
Err(_) => {
self.kill().await;
let error = "Shinkai node start timeout".to_string();
self.emit_event(ShinkaiNodeManagerEvent::ShinkaiNodeStartError { error: error.clone() });
Err(error)
}
}
Ok(())
}

pub async fn kill(&mut self) {
Expand Down

0 comments on commit 3a95cae

Please sign in to comment.