diff --git a/Cargo.lock b/Cargo.lock index 633b67a9e..71649811e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "Inflector" @@ -4542,6 +4542,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "os_path" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "360a6ecb129f544ba5ae18776ca8779cf3cf979c8133e9eefe9464ea74741f6b" +dependencies = [ + "regex", + "serde", +] + [[package]] name = "os_str_bytes" version = "6.6.1" @@ -6636,6 +6646,7 @@ dependencies = [ "futures", "keyphrases", "lazy_static", + "os_path", "rand 0.8.5", "regex", "reqwest 0.11.27", @@ -6705,6 +6716,7 @@ dependencies = [ "chrono", "ed25519-dalek", "hex", + "os_path", "rand 0.8.5", "regex", "rust_decimal", diff --git a/shinkai-libs/shinkai-fs/Cargo.toml b/shinkai-libs/shinkai-fs/Cargo.toml index 498985927..016e5258c 100644 --- a/shinkai-libs/shinkai-fs/Cargo.toml +++ b/shinkai-libs/shinkai-fs/Cargo.toml @@ -29,6 +29,8 @@ csv = "1.1.6" utoipa = "4.2.3" regex = { workspace = true } +os_path = "0.8.0" + [dependencies.serde] workspace = true features = ["derive"] diff --git a/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs b/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs index 621b4fa26..56e041093 100644 --- a/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs +++ b/shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs @@ -317,9 +317,10 @@ mod tests { } fn create_test_parsed_file(id: i64, relative_path: &str) -> ParsedFile { + let pf_relative_path = SqliteManager::normalize_path(relative_path); ParsedFile { id: Some(id), - relative_path: relative_path.to_string(), + relative_path: pf_relative_path.to_string(), original_extension: None, description: None, source: None, @@ -713,29 +714,34 @@ mod tests { let level2_contents = level1_info.children.as_ref().unwrap(); assert_eq!(level2_contents.len(), 2); // One directory and one file - let file1_info = level2_contents.iter().find(|info| info.path == "level1/file1.txt").unwrap(); + let file1_path = os_path::OsPath::from("level1/file1.txt").to_string(); + let file1_info = level2_contents.iter().find(|info| info.path == file1_path).unwrap(); assert!(!file1_info.is_directory); assert!(file1_info.has_embeddings, "File 'level1/file1.txt' should have embeddings."); - let level2_info = level2_contents.iter().find(|info| info.path == "level1/level2").unwrap(); + let level2_path = os_path::OsPath::from("level1/level2").to_string(); + let level2_info = level2_contents.iter().find(|info| info.path == level2_path).unwrap(); assert!(level2_info.is_directory); assert!(level2_info.children.is_some()); let level3_contents = level2_info.children.as_ref().unwrap(); assert_eq!(level3_contents.len(), 2); // One directory and one file - let file2_info = level3_contents.iter().find(|info| info.path == "level1/level2/file2.txt").unwrap(); + let file2_path = os_path::OsPath::from("level1/level2/file2.txt").to_string(); + let file2_info = level3_contents.iter().find(|info| info.path == file2_path).unwrap(); assert!(!file2_info.is_directory); assert!(file2_info.has_embeddings, "File 'level1/level2/file2.txt' should have embeddings."); - let level3_info = level3_contents.iter().find(|info| info.path == "level1/level2/level3").unwrap(); + let level3_path = os_path::OsPath::from("level1/level2/level3").to_string(); + let level3_info = level3_contents.iter().find(|info| info.path == level3_path).unwrap(); assert!(level3_info.is_directory); assert!(level3_info.children.is_some()); let level3_files = level3_info.children.as_ref().unwrap(); assert_eq!(level3_files.len(), 1); // Only one file - let file3_info = level3_files.iter().find(|info| info.path == "level1/level2/level3/file3.txt").unwrap(); + let file3_path = os_path::OsPath::from("level1/level2/level3/file3.txt").to_string(); + let file3_info = level3_files.iter().find(|info| info.path == file3_path).unwrap(); assert!(!file3_info.is_directory); assert!(!file3_info.has_embeddings, "File 'level1/level2/level3/file3.txt' should not have embeddings."); } diff --git a/shinkai-libs/shinkai-message-primitives/Cargo.toml b/shinkai-libs/shinkai-message-primitives/Cargo.toml index 5c06caaa3..3ca9d74e0 100644 --- a/shinkai-libs/shinkai-message-primitives/Cargo.toml +++ b/shinkai-libs/shinkai-message-primitives/Cargo.toml @@ -29,6 +29,8 @@ tracing = { version = "0.1.40", optional = true } tracing-subscriber = { version = "0.3", optional = true } +os_path = { version = "0.8.0" } + [lib] crate-type = ["cdylib", "rlib"] diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs index 94a9e2aea..4ffe965ed 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/job_scope.rs @@ -75,8 +75,8 @@ mod tests { let deserialized: MinimalJobScope = serde_json::from_value(json_data).expect("Failed to deserialize"); assert_eq!(deserialized.vector_fs_items.len(), 2); - assert_eq!(deserialized.vector_fs_items[0].relative_path(), "path/to/file1"); - assert_eq!(deserialized.vector_fs_items[1].relative_path(), "path/to/file2"); + assert_eq!(deserialized.vector_fs_items[0].relative_path(), os_path::OsPath::from("path/to/file1").to_string()); + assert_eq!(deserialized.vector_fs_items[1].relative_path(), os_path::OsPath::from("path/to/file2").to_string()); assert_eq!(deserialized.vector_fs_folders.len(), 1); assert_eq!(deserialized.vector_fs_folders[0].relative_path(), "My Files (Private)"); assert_eq!(deserialized.vector_search_mode, VectorSearchMode::FillUpTo25k); @@ -93,9 +93,9 @@ mod tests { let deserialized: MinimalJobScope = serde_json::from_value(json_data).expect("Failed to deserialize"); assert_eq!(deserialized.vector_fs_items.len(), 1); - assert_eq!(deserialized.vector_fs_items[0].relative_path(), "path/to/file1"); + assert_eq!(deserialized.vector_fs_items[0].relative_path(), os_path::OsPath::from("path/to/file1").to_string()); assert_eq!(deserialized.vector_fs_folders.len(), 1); - assert_eq!(deserialized.vector_fs_folders[0].relative_path(), "My Files (Private)"); + assert_eq!(deserialized.vector_fs_folders[0].relative_path(), os_path::OsPath::from("My Files (Private)").to_string()); assert_eq!(deserialized.vector_search_mode, VectorSearchMode::FillUpTo25k); // Check default } } diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs index fdf29932c..f6e0afc53 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_utils/shinkai_path.rs @@ -15,7 +15,7 @@ impl ShinkaiPath { /// Private helper method to create a ShinkaiPath from a &str. pub fn new(path: &str) -> Self { let base_path = Self::base_path(); - let path_buf = PathBuf::from(path); + let path_buf = os_path::OsPath::from(path).to_pathbuf(); // PathBuf::from(path); let final_path = if path_buf.is_absolute() { if path_buf.starts_with(&base_path) { @@ -227,7 +227,7 @@ mod tests { env::var("NODE_STORAGE_PATH").unwrap() )) ); - assert_eq!(path.relative_path(), "word_files/christmas.docx"); + assert_eq!(path.relative_path(), os_path::OsPath::from("word_files/christmas.docx").to_string()); } #[test] @@ -240,7 +240,7 @@ mod tests { path.as_path(), Path::new("storage/filesystem/word_files/christmas.docx") ); - assert_eq!(path.relative_path(), "word_files/christmas.docx"); + assert_eq!(path.relative_path(), os_path::OsPath::from("word_files/christmas.docx").to_string()); } #[test] @@ -248,7 +248,7 @@ mod tests { fn test_relative_path_outside_base() { let _dir = testing_create_tempdir_and_set_env_var(); let absolute_outside = ShinkaiPath::from_string("/some/other/path".to_string()); - assert_eq!(absolute_outside.relative_path(), "some/other/path"); + assert_eq!(absolute_outside.relative_path(), os_path::OsPath::from("some/other/path").to_string()); } #[test] @@ -349,6 +349,8 @@ mod tests { let serialized_path = serde_json::to_string(&path).unwrap(); // Check if the serialized output matches the expected relative path - assert_eq!(serialized_path, "\"word_files/christmas.docx\""); + let serialized_path_str = serde_json::to_string(&os_path::OsPath::from("word_files/christmas.docx").to_string()).unwrap(); + + assert_eq!(serialized_path, serialized_path_str); } } diff --git a/shinkai-libs/shinkai-sqlite/src/file_system.rs b/shinkai-libs/shinkai-sqlite/src/file_system.rs index 22953bbc9..79967851b 100644 --- a/shinkai-libs/shinkai-sqlite/src/file_system.rs +++ b/shinkai-libs/shinkai-sqlite/src/file_system.rs @@ -6,6 +6,13 @@ use shinkai_message_primitives::{ }; impl SqliteManager { + // TODO: This is a temporary workaround for Windows paths. We should handle this more robustly. + pub fn normalize_path(path: &str) -> String { + let mut path = path.replace("\\\\", "/"); + path = path.replace("\\", "/"); + path + } + pub fn initialize_filesystem_tables(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> { // parsed_files table conn.execute( @@ -70,21 +77,24 @@ impl SqliteManager { let mut conn = self.get_connection()?; let tx = conn.transaction()?; + let pf_relative_path = Self::normalize_path(&pf.relative_path); + let exists: bool = tx.query_row( "SELECT EXISTS(SELECT 1 FROM parsed_files WHERE relative_path = ?)", - [&pf.relative_path], + [&pf_relative_path], |row| row.get(0), )?; if exists { return Err(SqliteManagerError::DataAlreadyExists); } + let relative_path = Self::normalize_path(&pf.relative_path); tx.execute( "INSERT INTO parsed_files (relative_path, original_extension, description, source, embedding_model_used, keywords, distribution_info, created_time, tags, total_tokens, total_characters) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", params![ - pf.relative_path, + relative_path, pf.original_extension, pf.description, pf.source, @@ -104,6 +114,9 @@ impl SqliteManager { pub fn get_parsed_file_by_rel_path(&self, rel_path: &str) -> Result, SqliteManagerError> { let conn = self.get_connection()?; + + let rel_path = Self::normalize_path(rel_path); + let mut stmt = conn.prepare( " SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords, @@ -150,13 +163,14 @@ impl SqliteManager { return Err(SqliteManagerError::DataNotFound); } + let relative_path = Self::normalize_path(&pf.relative_path); tx.execute( "UPDATE parsed_files SET relative_path = ?1, original_extension = ?2, description = ?3, source = ?4, embedding_model_used = ?5, keywords = ?6, distribution_info = ?7, created_time = ?8, tags = ?9, total_tokens = ?10, total_characters = ?11 WHERE id = ?12", params![ - pf.relative_path, + relative_path, pf.original_extension, pf.description, pf.source, @@ -448,6 +462,9 @@ impl SqliteManager { // ------------------------- pub fn update_folder_paths(&self, old_prefix: &str, new_prefix: &str) -> Result<(), SqliteManagerError> { + let old_prefix = Self::normalize_path(old_prefix); + let new_prefix = Self::normalize_path(new_prefix); + let mut conn = self.get_connection()?; let tx = conn.transaction()?; @@ -469,6 +486,7 @@ impl SqliteManager { &self, directory_path: &str, ) -> Result, SqliteManagerError> { + let directory_path = Self::normalize_path(directory_path); let conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords, @@ -547,6 +565,7 @@ impl SqliteManager { /// Retrieve all parsed files whose relative paths start with the given prefix. pub fn get_parsed_files_by_prefix(&self, prefix: &str) -> Result, SqliteManagerError> { + let prefix = Self::normalize_path(prefix); let conn = self.get_connection()?; let mut stmt = conn.prepare( "SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords,