add error for Ollama models, e.g. not supporting tools
nicarq committed Jan 9, 2025
1 parent 5387089 commit deb00a2
Showing 2 changed files with 29 additions and 0 deletions.
3 changes: 3 additions & 0 deletions shinkai-bin/shinkai-node/src/llm_provider/error.rs
@@ -84,6 +84,7 @@ pub enum LLMProviderError {
     AgentNotFound(String),
     MessageTooLargeForLLM { max_tokens: usize, used_tokens: usize },
     SomeError(String),
+    APIError(String),
 }
 
 impl fmt::Display for LLMProviderError {
@@ -176,6 +177,7 @@ impl fmt::Display for LLMProviderError {
                 write!(f, "Message too large for LLM: Used {} tokens, but the maximum allowed is {}.", used_tokens, max_tokens)
             },
             LLMProviderError::SomeError(s) => write!(f, "{}", s),
+            LLMProviderError::APIError(s) => write!(f, "{}", s),
         }
     }
 }
@@ -256,6 +258,7 @@ impl LLMProviderError {
             LLMProviderError::AgentNotFound(_) => "AgentNotFound",
             LLMProviderError::MessageTooLargeForLLM { .. } => "MessageTooLargeForLLM",
             LLMProviderError::SomeError(_) => "SomeError",
+            LLMProviderError::APIError(_) => "APIError",
         };
 
         format!("Error {} with message: {}", error_name, self)
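The new APIError variant gives callers a way to distinguish failures reported by the upstream provider (such as Ollama rejecting a request because the model does not support tools) from the catch-all SomeError. A minimal, self-contained sketch of how calling code might branch on it; the reduced enum and the run_request helper are illustrative stand-ins, not this repository's actual API:

use std::fmt;

// Reduced stand-in for the repository's LLMProviderError, keeping only
// the variants relevant here (the real enum has many more).
#[derive(Debug)]
enum LLMProviderError {
    SomeError(String),
    APIError(String),
}

impl fmt::Display for LLMProviderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LLMProviderError::SomeError(s) => write!(f, "{}", s),
            LLMProviderError::APIError(s) => write!(f, "{}", s),
        }
    }
}

// Hypothetical helper that surfaces a provider-side failure.
fn run_request() -> Result<String, LLMProviderError> {
    Err(LLMProviderError::APIError(
        "model does not support tools".to_string(),
    ))
}

fn main() {
    match run_request() {
        Ok(body) => println!("{}", body),
        // Provider-reported errors can now be handled on their own...
        Err(LLMProviderError::APIError(msg)) => eprintln!("upstream API error: {}", msg),
        // ...separately from everything else.
        Err(other) => eprintln!("internal error: {}", other),
    }
}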
26 changes: 26 additions & 0 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/ollama.rs
@@ -290,6 +290,19 @@ async fn process_stream(
         if !previous_json_chunk.is_empty() {
             chunk_str = previous_json_chunk.clone() + chunk_str.as_str();
         }
+
+        // First check if it's an error response
+        if let Ok(error_response) = serde_json::from_str::<serde_json::Value>(&chunk_str) {
+            if let Some(error_msg) = error_response.get("error").and_then(|e| e.as_str()) {
+                shinkai_log(
+                    ShinkaiLogOption::JobExecution,
+                    ShinkaiLogLevel::Error,
+                    format!("Ollama API Error: {}", error_msg).as_str(),
+                );
+                return Err(LLMProviderError::APIError(error_msg.to_string()));
+            }
+        }
+
         let data_resp: Result<OllamaAPIStreamingResponse, _> = serde_json::from_str(&chunk_str);
         match data_resp {
             Ok(data) => {
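When a request fails (for example because the selected model does not support tools), Ollama replies with a JSON body of the form {"error": "..."}, which is the shape this guard looks for before attempting the typed OllamaAPIStreamingResponse parse. A standalone sketch of that check, with made-up error strings for illustration:

use serde_json::Value;

// Returns Some(message) when a chunk is an Ollama-style error body,
// i.e. a JSON object with a top-level "error" string (the shape the
// guard above checks for); returns None for normal chunks.
fn check_for_api_error(chunk_str: &str) -> Option<String> {
    let parsed: Value = serde_json::from_str(chunk_str).ok()?;
    parsed.get("error")?.as_str().map(str::to_string)
}

fn main() {
    // Illustrative error body, e.g. from a model without tool support.
    let err_chunk = r#"{"error":"llama2 does not support tools"}"#;
    assert_eq!(
        check_for_api_error(err_chunk),
        Some("llama2 does not support tools".to_string())
    );

    // A regular streaming chunk has no "error" key and passes through.
    let ok_chunk = r#"{"message":{"role":"assistant","content":"hi"},"done":false}"#;
    assert_eq!(check_for_api_error(ok_chunk), None);
}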
@@ -484,6 +497,19 @@ async fn handle_non_streaming_response(
         result = &mut response_future => {
             let res = result?;
             let response_body = res.text().await?;
+
+            // First check if it's an error response
+            if let Ok(error_response) = serde_json::from_str::<serde_json::Value>(&response_body) {
+                if let Some(error_msg) = error_response.get("error").and_then(|e| e.as_str()) {
+                    shinkai_log(
+                        ShinkaiLogOption::JobExecution,
+                        ShinkaiLogLevel::Error,
+                        format!("Ollama API Error: {}", error_msg).as_str(),
+                    );
+                    return Err(LLMProviderError::APIError(error_msg.to_string()));
+                }
+            }
+
             let response_json: serde_json::Value = serde_json::from_str(&response_body)?;
 
             if let Some(message) = response_json.get("message") {
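One trade-off worth noting in the non-streaming path: the response body is parsed twice, once into a generic serde_json::Value for the error check and again into response_json immediately afterwards. Reusing the first parse would avoid the duplicate work; the duplication here keeps the guard textually identical to the streaming version, which is arguably easier to review.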