From deb00a2572dabc39acdfba16edb490c14858ad57 Mon Sep 17 00:00:00 2001
From: Nico Arqueros
Date: Thu, 9 Jan 2025 14:08:11 -0600
Subject: [PATCH] add error for ollama models, e.g. not supporting tools

---
 .../shinkai-node/src/llm_provider/error.rs    |  3 +++
 .../src/llm_provider/providers/ollama.rs      | 26 +++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/shinkai-bin/shinkai-node/src/llm_provider/error.rs b/shinkai-bin/shinkai-node/src/llm_provider/error.rs
index 5456b4bcf..443febf95 100644
--- a/shinkai-bin/shinkai-node/src/llm_provider/error.rs
+++ b/shinkai-bin/shinkai-node/src/llm_provider/error.rs
@@ -84,6 +84,7 @@ pub enum LLMProviderError {
     AgentNotFound(String),
     MessageTooLargeForLLM { max_tokens: usize, used_tokens: usize },
     SomeError(String),
+    APIError(String),
 }
 
 impl fmt::Display for LLMProviderError {
@@ -176,6 +177,7 @@ impl fmt::Display for LLMProviderError {
                 write!(f, "Message too large for LLM: Used {} tokens, but the maximum allowed is {}.", used_tokens, max_tokens)
             },
             LLMProviderError::SomeError(s) => write!(f, "{}", s),
+            LLMProviderError::APIError(s) => write!(f, "{}", s),
         }
     }
 }
@@ -256,6 +258,7 @@ impl LLMProviderError {
             LLMProviderError::AgentNotFound(_) => "AgentNotFound",
             LLMProviderError::MessageTooLargeForLLM { .. } => "MessageTooLargeForLLM",
             LLMProviderError::SomeError(_) => "SomeError",
+            LLMProviderError::APIError(_) => "APIError",
         };
 
         format!("Error {} with message: {}", error_name, self)
diff --git a/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs b/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs
index c1fe35a68..64813bfd5 100644
--- a/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs
+++ b/shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs
@@ -290,6 +290,19 @@ async fn process_stream(
         if !previous_json_chunk.is_empty() {
             chunk_str = previous_json_chunk.clone() + chunk_str.as_str();
         }
+
+        // First check if the chunk is an error response before attempting
+        // to parse it as a normal streaming payload
+        if let Ok(error_response) = serde_json::from_str::<serde_json::Value>(&chunk_str) {
+            if let Some(error_msg) = error_response.get("error").and_then(|e| e.as_str()) {
+                shinkai_log(
+                    ShinkaiLogOption::JobExecution,
+                    ShinkaiLogLevel::Error,
+                    format!("Ollama API Error: {}", error_msg).as_str(),
+                );
+                return Err(LLMProviderError::APIError(error_msg.to_string()));
+            }
+        }
+
         let data_resp: Result<OllamaAPIStreamingResponse, _> = serde_json::from_str(&chunk_str);
         match data_resp {
             Ok(data) => {
@@ -484,6 +497,19 @@ async fn handle_non_streaming_response(
             result = &mut response_future => {
                 let res = result?;
                 let response_body = res.text().await?;
+
+                // First check if the body is an error response before
+                // attempting to parse it as a normal response
+                if let Ok(error_response) = serde_json::from_str::<serde_json::Value>(&response_body) {
+                    if let Some(error_msg) = error_response.get("error").and_then(|e| e.as_str()) {
+                        shinkai_log(
+                            ShinkaiLogOption::JobExecution,
+                            ShinkaiLogLevel::Error,
+                            format!("Ollama API Error: {}", error_msg).as_str(),
+                        );
+                        return Err(LLMProviderError::APIError(error_msg.to_string()));
+                    }
+                }
+
                 let response_json: serde_json::Value = serde_json::from_str(&response_body)?;
 
                 if let Some(message) = response_json.get("message") {
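
Notes:

Both new checks key off the JSON error envelope that Ollama returns on
failure, a body of the form {"error": "..."}, for example when a model
does not support tool calls. Below is a minimal standalone sketch of the
same detection logic, assuming that envelope shape; extract_api_error is
a hypothetical helper name used for illustration and is not part of this
patch.

    use serde_json::Value;

    /// Returns the error message if `body` is an Ollama-style error
    /// envelope ({"error": "..."}), and None for normal payloads.
    fn extract_api_error(body: &str) -> Option<String> {
        serde_json::from_str::<Value>(body)
            .ok()?
            .get("error")?
            .as_str()
            .map(str::to_string)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn detects_tool_support_error() {
            let body = r#"{"error":"model does not support tools"}"#;
            assert_eq!(
                extract_api_error(body).as_deref(),
                Some("model does not support tools")
            );
        }

        #[test]
        fn ignores_normal_chunks() {
            let body = r#"{"message":{"role":"assistant","content":"hi"}}"#;
            assert!(extract_api_error(body).is_none());
        }
    }

In the streaming path this check runs per chunk, so an error payload that
arrives mid-stream is surfaced as LLMProviderError::APIError rather than
falling through to the OllamaAPIStreamingResponse parse.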