From e35e07446a074468a73279fc379fd2c17d9ba767 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Mon, 30 Dec 2024 17:47:15 -0300 Subject: [PATCH] Refactor tool router key serialization and deserialization --- .../src/schemas/tool_router_key.rs | 45 +++++++++++++++++++ .../src/tools/deno_tools.rs | 21 +-------- .../src/tools/python_tools.rs | 24 ++-------- 3 files changed, 50 insertions(+), 40 deletions(-) diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs index 347864c46..21c264900 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/tool_router_key.rs @@ -21,6 +21,51 @@ impl ToolRouterKey { } } + pub fn deserialize_tool_router_keys<'de, D>(deserializer: D) -> Result>, D::Error> + where + D: serde::Deserializer<'de>, + { + let string_vec: Option> = Option::deserialize(deserializer)?; + + match string_vec { + Some(vec) => { + let router_keys = vec + .into_iter() + .map(|s| Self::from_string(&s)) + .collect::, _>>() + .map_err(serde::de::Error::custom)?; + Ok(Some(router_keys)) + } + None => Ok(None), + } + } + + pub fn serialize_tool_router_keys( + keys: &Option>, + serializer: S + ) -> Result + where + S: serde::Serializer, + { + match keys { + Some(keys) => { + let strings: Vec = keys + .iter() + .map(|k| { + // If version is Some, use to_string_with_version() + if k.version.is_some() { + k.to_string_with_version() + } else { + k.to_string_without_version() + } + }) + .collect(); + strings.serialize(serializer) + } + None => serializer.serialize_none(), + } + } + fn sanitize(input: &str) -> String { input.chars() .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' }) diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs index 9af694fd8..d27a8894a 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/deno_tools.rs @@ -31,7 +31,8 @@ pub struct DenoTool { pub version: String, pub js_code: String, #[serde(default)] - #[serde(deserialize_with = "deserialize_tool_router_keys")] + #[serde(deserialize_with = "ToolRouterKey::deserialize_tool_router_keys")] + #[serde(serialize_with = "ToolRouterKey::serialize_tool_router_keys")] pub tools: Option>, pub config: Vec, pub description: String, @@ -48,24 +49,6 @@ pub struct DenoTool { pub assets: Option>, } -fn deserialize_tool_router_keys<'de, D>(deserializer: D) -> Result>, D::Error> -where - D: Deserializer<'de>, -{ - let string_vec: Option> = Option::deserialize(deserializer)?; - - match string_vec { - Some(vec) => { - let router_keys = vec - .into_iter() - .filter_map(|s| ToolRouterKey::from_string(&s).ok()) - .collect(); - Ok(Some(router_keys)) - } - None => Ok(None), - } -} - impl DenoTool { /// Default name of the rust toolkit pub fn toolkit_name(&self) -> String { diff --git a/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs b/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs index 04ec4feb5..d6e5d889c 100644 --- a/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs +++ b/shinkai-libs/shinkai-tools-primitives/src/tools/python_tools.rs @@ -4,7 +4,6 @@ use std::time::{SystemTime, UNIX_EPOCH}; use std::{env, thread}; use crate::tools::error::ToolError; -use serde::Deserialize; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::schemas::tool_router_key::ToolRouterKey; use shinkai_tools_runner::tools::code_files::CodeFiles; @@ -30,7 +29,9 @@ pub struct PythonTool { pub name: String, pub author: String, pub py_code: String, - #[serde(deserialize_with = "deserialize_tool_router_keys")] + #[serde(default)] + #[serde(deserialize_with = "ToolRouterKey::deserialize_tool_router_keys")] + #[serde(serialize_with = "ToolRouterKey::serialize_tool_router_keys")] pub tools: Option>, pub config: Vec, pub description: String, @@ -47,25 +48,6 @@ pub struct PythonTool { pub assets: Option>, } -fn deserialize_tool_router_keys<'de, D>(deserializer: D) -> Result>, D::Error> -where - D: serde::Deserializer<'de>, -{ - let string_vec: Option> = Option::deserialize(deserializer)?; - - match string_vec { - Some(vec) => { - let router_keys = vec - .into_iter() - .map(|s| ToolRouterKey::from_string(&s)) - .collect::, _>>() - .map_err(serde::de::Error::custom)?; - Ok(Some(router_keys)) - } - None => Ok(None), - } -} - impl PythonTool { /// Default name of the rust toolkit pub fn toolkit_name(&self) -> String {