Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: os path for multiplatform compatibility #789

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions shinkai-libs/shinkai-fs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ csv = "1.1.6"
utoipa = "4.2.3"
regex = { workspace = true }

os_path = "0.8.0"

[dependencies.serde]
workspace = true
features = ["derive"]
Expand Down
18 changes: 12 additions & 6 deletions shinkai-libs/shinkai-fs/src/shinkai_file_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,10 @@ mod tests {
}

fn create_test_parsed_file(id: i64, relative_path: &str) -> ParsedFile {
let pf_relative_path = SqliteManager::normalize_path(relative_path);
ParsedFile {
id: Some(id),
relative_path: relative_path.to_string(),
relative_path: pf_relative_path.to_string(),
original_extension: None,
description: None,
source: None,
Expand Down Expand Up @@ -713,29 +714,34 @@ mod tests {
let level2_contents = level1_info.children.as_ref().unwrap();
assert_eq!(level2_contents.len(), 2); // One directory and one file

let file1_info = level2_contents.iter().find(|info| info.path == "level1/file1.txt").unwrap();
let file1_path = os_path::OsPath::from("level1/file1.txt").to_string();
let file1_info = level2_contents.iter().find(|info| info.path == file1_path).unwrap();
assert!(!file1_info.is_directory);
assert!(file1_info.has_embeddings, "File 'level1/file1.txt' should have embeddings.");

let level2_info = level2_contents.iter().find(|info| info.path == "level1/level2").unwrap();
let level2_path = os_path::OsPath::from("level1/level2").to_string();
let level2_info = level2_contents.iter().find(|info| info.path == level2_path).unwrap();
assert!(level2_info.is_directory);
assert!(level2_info.children.is_some());

let level3_contents = level2_info.children.as_ref().unwrap();
assert_eq!(level3_contents.len(), 2); // One directory and one file

let file2_info = level3_contents.iter().find(|info| info.path == "level1/level2/file2.txt").unwrap();
let file2_path = os_path::OsPath::from("level1/level2/file2.txt").to_string();
let file2_info = level3_contents.iter().find(|info| info.path == file2_path).unwrap();
assert!(!file2_info.is_directory);
assert!(file2_info.has_embeddings, "File 'level1/level2/file2.txt' should have embeddings.");

let level3_info = level3_contents.iter().find(|info| info.path == "level1/level2/level3").unwrap();
let level3_path = os_path::OsPath::from("level1/level2/level3").to_string();
let level3_info = level3_contents.iter().find(|info| info.path == level3_path).unwrap();
assert!(level3_info.is_directory);
assert!(level3_info.children.is_some());

let level3_files = level3_info.children.as_ref().unwrap();
assert_eq!(level3_files.len(), 1); // Only one file

let file3_info = level3_files.iter().find(|info| info.path == "level1/level2/level3/file3.txt").unwrap();
let file3_path = os_path::OsPath::from("level1/level2/level3/file3.txt").to_string();
let file3_info = level3_files.iter().find(|info| info.path == file3_path).unwrap();
assert!(!file3_info.is_directory);
assert!(!file3_info.has_embeddings, "File 'level1/level2/level3/file3.txt' should not have embeddings.");
}
Expand Down
2 changes: 2 additions & 0 deletions shinkai-libs/shinkai-message-primitives/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ tracing = { version = "0.1.40", optional = true }

tracing-subscriber = { version = "0.3", optional = true }

os_path = { version = "0.8.0" }

[lib]
crate-type = ["cdylib", "rlib"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ mod tests {
let deserialized: MinimalJobScope = serde_json::from_value(json_data).expect("Failed to deserialize");

assert_eq!(deserialized.vector_fs_items.len(), 2);
assert_eq!(deserialized.vector_fs_items[0].relative_path(), "path/to/file1");
assert_eq!(deserialized.vector_fs_items[1].relative_path(), "path/to/file2");
assert_eq!(deserialized.vector_fs_items[0].relative_path(), os_path::OsPath::from("path/to/file1").to_string());
assert_eq!(deserialized.vector_fs_items[1].relative_path(), os_path::OsPath::from("path/to/file2").to_string());
assert_eq!(deserialized.vector_fs_folders.len(), 1);
assert_eq!(deserialized.vector_fs_folders[0].relative_path(), "My Files (Private)");
assert_eq!(deserialized.vector_search_mode, VectorSearchMode::FillUpTo25k);
Expand All @@ -93,9 +93,9 @@ mod tests {
let deserialized: MinimalJobScope = serde_json::from_value(json_data).expect("Failed to deserialize");

assert_eq!(deserialized.vector_fs_items.len(), 1);
assert_eq!(deserialized.vector_fs_items[0].relative_path(), "path/to/file1");
assert_eq!(deserialized.vector_fs_items[0].relative_path(), os_path::OsPath::from("path/to/file1").to_string());
assert_eq!(deserialized.vector_fs_folders.len(), 1);
assert_eq!(deserialized.vector_fs_folders[0].relative_path(), "My Files (Private)");
assert_eq!(deserialized.vector_fs_folders[0].relative_path(), os_path::OsPath::from("My Files (Private)").to_string());
assert_eq!(deserialized.vector_search_mode, VectorSearchMode::FillUpTo25k); // Check default
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl ShinkaiPath {
/// Private helper method to create a ShinkaiPath from a &str.
pub fn new(path: &str) -> Self {
let base_path = Self::base_path();
let path_buf = PathBuf::from(path);
let path_buf = os_path::OsPath::from(path).to_pathbuf(); // PathBuf::from(path);

let final_path = if path_buf.is_absolute() {
if path_buf.starts_with(&base_path) {
Expand Down Expand Up @@ -227,7 +227,7 @@ mod tests {
env::var("NODE_STORAGE_PATH").unwrap()
))
);
assert_eq!(path.relative_path(), "word_files/christmas.docx");
assert_eq!(path.relative_path(), os_path::OsPath::from("word_files/christmas.docx").to_string());
}

#[test]
Expand All @@ -240,15 +240,15 @@ mod tests {
path.as_path(),
Path::new("storage/filesystem/word_files/christmas.docx")
);
assert_eq!(path.relative_path(), "word_files/christmas.docx");
assert_eq!(path.relative_path(), os_path::OsPath::from("word_files/christmas.docx").to_string());
}

#[test]
#[serial]
fn test_relative_path_outside_base() {
let _dir = testing_create_tempdir_and_set_env_var();
let absolute_outside = ShinkaiPath::from_string("/some/other/path".to_string());
assert_eq!(absolute_outside.relative_path(), "some/other/path");
assert_eq!(absolute_outside.relative_path(), os_path::OsPath::from("some/other/path").to_string());
}

#[test]
Expand Down Expand Up @@ -349,6 +349,8 @@ mod tests {
let serialized_path = serde_json::to_string(&path).unwrap();

// Check if the serialized output matches the expected relative path
assert_eq!(serialized_path, "\"word_files/christmas.docx\"");
let serialized_path_str = serde_json::to_string(&os_path::OsPath::from("word_files/christmas.docx").to_string()).unwrap();

assert_eq!(serialized_path, serialized_path_str);
}
}
25 changes: 22 additions & 3 deletions shinkai-libs/shinkai-sqlite/src/file_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ use shinkai_message_primitives::{
};

impl SqliteManager {
// TODO: This is a temporary workaround for Windows paths. We should handle this more robustly.
pub fn normalize_path(path: &str) -> String {
let mut path = path.replace("\\\\", "/");
path = path.replace("\\", "/");
path
}

pub fn initialize_filesystem_tables(conn: &rusqlite::Connection) -> Result<(), rusqlite::Error> {
// parsed_files table
conn.execute(
Expand Down Expand Up @@ -70,21 +77,24 @@ impl SqliteManager {
let mut conn = self.get_connection()?;
let tx = conn.transaction()?;

let pf_relative_path = Self::normalize_path(&pf.relative_path);

let exists: bool = tx.query_row(
"SELECT EXISTS(SELECT 1 FROM parsed_files WHERE relative_path = ?)",
[&pf.relative_path],
[&pf_relative_path],
|row| row.get(0),
)?;
if exists {
return Err(SqliteManagerError::DataAlreadyExists);
}

let relative_path = Self::normalize_path(&pf.relative_path);
tx.execute(
"INSERT INTO parsed_files (relative_path, original_extension, description, source, embedding_model_used,
keywords, distribution_info, created_time, tags, total_tokens, total_characters)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
params![
pf.relative_path,
relative_path,
pf.original_extension,
pf.description,
pf.source,
Expand All @@ -104,6 +114,9 @@ impl SqliteManager {

pub fn get_parsed_file_by_rel_path(&self, rel_path: &str) -> Result<Option<ParsedFile>, SqliteManagerError> {
let conn = self.get_connection()?;

let rel_path = Self::normalize_path(rel_path);

let mut stmt = conn.prepare(
"
SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords,
Expand Down Expand Up @@ -150,13 +163,14 @@ impl SqliteManager {
return Err(SqliteManagerError::DataNotFound);
}

let relative_path = Self::normalize_path(&pf.relative_path);
tx.execute(
"UPDATE parsed_files
SET relative_path = ?1, original_extension = ?2, description = ?3, source = ?4, embedding_model_used = ?5,
keywords = ?6, distribution_info = ?7, created_time = ?8, tags = ?9, total_tokens = ?10, total_characters = ?11
WHERE id = ?12",
params![
pf.relative_path,
relative_path,
pf.original_extension,
pf.description,
pf.source,
Expand Down Expand Up @@ -448,6 +462,9 @@ impl SqliteManager {
// -------------------------

pub fn update_folder_paths(&self, old_prefix: &str, new_prefix: &str) -> Result<(), SqliteManagerError> {
let old_prefix = Self::normalize_path(old_prefix);
let new_prefix = Self::normalize_path(new_prefix);

let mut conn = self.get_connection()?;
let tx = conn.transaction()?;

Expand All @@ -469,6 +486,7 @@ impl SqliteManager {
&self,
directory_path: &str,
) -> Result<Vec<ParsedFile>, SqliteManagerError> {
let directory_path = Self::normalize_path(directory_path);
let conn = self.get_connection()?;
let mut stmt = conn.prepare(
"SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords,
Expand Down Expand Up @@ -547,6 +565,7 @@ impl SqliteManager {

/// Retrieve all parsed files whose relative paths start with the given prefix.
pub fn get_parsed_files_by_prefix(&self, prefix: &str) -> Result<Vec<ParsedFile>, SqliteManagerError> {
let prefix = Self::normalize_path(prefix);
let conn = self.get_connection()?;
let mut stmt = conn.prepare(
"SELECT id, relative_path, original_extension, description, source, embedding_model_used, keywords,
Expand Down