Skip to content

Commit

Permalink
Start of fixing scoring of VectorFS resource nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
robkorn committed Jan 17, 2024
1 parent 20e1ef5 commit 183f55e
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -685,13 +685,19 @@ impl VRPath {
new_path
}

/// Creates a cloned VRPath and removes an element from the end
/// Returns a cloned VRPath with the last id removed from the end
pub fn pop_cloned(&self) -> Self {
let mut new_path = self.clone();
new_path.pop();
new_path
}

/// Returns a VRPath which is the path prior to self (the "parent path").
/// Ie. For path "/a/b/c", this will return "/a/b".
pub fn parent_path(&self) -> Self {
self.pop_cloned()
}

/// Create a VRPath from a path string
pub fn from_string(path_string: &str) -> Result<Self, VRError> {
if !path_string.starts_with('/') {
Expand Down
33 changes: 12 additions & 21 deletions src/vector_fs/vector_fs_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ impl VectorFS {
let items = self.vector_search_fs_item(reader, query.clone(), num_of_resources_to_search_into)?;

for item in items {
println!("Item: {:?}", item);
// Create a new reader at the path of the fs_item, and then fetch the VR from there
let new_reader = reader._new_reader_copied_data(item.path.clone(), self)?;
let resource = self.retrieve_vector_resource(&new_reader)?;
Expand All @@ -102,7 +101,6 @@ impl VectorFS {

// Perform the internal vector search into the resource itself
let results = resource.as_trait_object().vector_search(query.clone(), num_of_results);
println!("\nResults: {:?}\n\n", results);
ret_nodes.extend(results);
}

Expand Down Expand Up @@ -130,10 +128,14 @@ impl VectorFS {

let mut fs_items = vec![];
for ret_node in ret_nodes {
println!("Ret Node: {:?}", ret_node);
if let NodeContent::VRHeader(_) = ret_node.node.content {
println!(
"Merkle hash: {:?} -- Score: {}",
ret_node.node.get_merkle_hash(),
&ret_node.score
);
fs_items.push(FSItem::from_vr_header_node(
ret_node.node,
ret_node.node.clone(),
ret_node.retrieval_path,
&internals.last_read_index,
)?)
Expand All @@ -155,7 +157,8 @@ impl VectorFS {
let mut results = vec![];

for item in items {
let res_pair = self.retrieve_vr_and_source_file_map_in_folder(reader, item.name())?;
let new_reader = reader._new_reader_copied_data(item.path.parent_path(), self)?;
let res_pair = self.retrieve_vr_and_source_file_map_in_folder(&new_reader, item.name())?;
results.push(res_pair);
}
Ok(results)
Expand All @@ -173,7 +176,8 @@ impl VectorFS {
let mut results = vec![];

for item in items {
let res = self.retrieve_vector_resource_in_folder(reader, item.name())?;
let new_reader = reader._new_reader_copied_data(item.path.parent_path(), self)?;
let res = self.retrieve_vector_resource_in_folder(&new_reader, item.name())?;
results.push(res);
}
Ok(results)
Expand All @@ -191,7 +195,8 @@ impl VectorFS {
let mut results = vec![];

for item in items {
let res = self.retrieve_source_file_map_in_folder(reader, item.name())?;
let new_reader = reader._new_reader_copied_data(item.path.parent_path(), self)?;
let res = self.retrieve_source_file_map_in_folder(&new_reader, item.name())?;
results.push(res);
}
Ok(results)
Expand Down Expand Up @@ -251,11 +256,6 @@ impl VectorFS {
)),
));

println!(
"Core resource node count: {}",
internals.fs_core_resource.get_nodes().len()
);

let results = internals.fs_core_resource.vector_search_customized(
query,
num_of_results,
Expand All @@ -264,8 +264,6 @@ impl VectorFS {
Some(reader.path.clone()),
);

println!("Results: {:?}", results);

Ok(results)
}
}
Expand All @@ -275,10 +273,6 @@ impl VectorFS {
fn _permissions_validation_func(_: &Node, path: &VRPath, hashmap: HashMap<VRPath, String>) -> bool {
// If the specified path has no permissions, then the default is to now allow traversing deeper
if !hashmap.contains_key(path) {
println!(" path being checked in permissions hashmap: {}", path);
println!("doesn't contain key");

println!("Permissions hashmap: {:?}", hashmap);
return false;
}

Expand All @@ -291,12 +285,9 @@ fn _permissions_validation_func(_: &Node, path: &VRPath, hashmap: HashMap<VRPath
None => return false,
};

println!("got reader");
// Initialize the PermissionsIndex struct
let perm_index = PermissionsIndex::from_hashmap(reader.profile.clone(), hashmap);

println!("initialized perm index");

perm_index
.validate_read_permission(&reader.requester_name, path)
.is_ok()
Expand Down
6 changes: 5 additions & 1 deletion src/vector_fs/vector_fs_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,14 @@ impl VectorFS {
let node_id = vr_header.resource_name.clone();
let resource = node.get_vector_resource_content_mut()?;
let new_vr_header_node = Node::new_vr_header(node_id, &vr_header, metadata.clone(), &vec![]);
let new_node_embedding = vr_header
.resource_embedding
.clone()
.ok_or(VRError::NoEmbeddingProvided)?;
resource.as_trait_object_mut().insert_node(
vr_header.resource_name.clone(),
new_vr_header_node,
embedding.clone(),
new_node_embedding,
Some(current_datetime),
)?;
Ok(())
Expand Down
98 changes: 90 additions & 8 deletions tests/it/vec_fs_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use shinkai_node::vector_fs::{db::fs_db::VectorFSDB, vector_fs::VectorFS, vector
use shinkai_vector_resources::embedding_generator::{EmbeddingGenerator, RemoteEmbeddingGenerator};
use shinkai_vector_resources::model_type::{EmbeddingModelType, TextEmbeddingsInference};
use shinkai_vector_resources::vector_resource::{
BaseVectorResource, VRPath, VectorResource, VectorResourceCore, VectorResourceSearch,
BaseVectorResource, DocumentVectorResource, VRPath, VRSource, VectorResource, VectorResourceCore,
VectorResourceSearch,
};
use std::collections::HashMap;
use std::fs;
Expand Down Expand Up @@ -67,7 +68,16 @@ async fn test_vector_fs_saving_reading() {
.new_writer(default_test_profile(), path.clone(), default_test_profile())
.unwrap();
let folder_name = "first_folder";
vector_fs.create_new_folder(&writer, folder_name).unwrap();
vector_fs.create_new_folder(&writer, folder_name.clone()).unwrap();
let writer = vector_fs
.new_writer(
default_test_profile(),
path.push_cloned(folder_name.to_string()),
default_test_profile(),
)
.unwrap();
let folder_name_2 = "second_folder";
vector_fs.create_new_folder(&writer, folder_name_2).unwrap();

// Validate new folder path points to an entry at all (not empty), then specifically a folder, and finally not to an item.
let folder_path = path.push_cloned(folder_name.to_string());
Expand Down Expand Up @@ -135,6 +145,55 @@ async fn test_vector_fs_saving_reading() {
// Vector Search Tests
//

// First add a 2nd VR into the VecFS
let generator = RemoteEmbeddingGenerator::new_default();
let mut doc = DocumentVectorResource::new_empty(
"3 Animal Facts",
Some("A bunch of facts about animals and wildlife"),
VRSource::new_uri_ref("animalwildlife.com", None),
true,
);
doc.set_embedding_model_used(generator.model_type());
doc.update_resource_embedding(&generator, vec!["animal".to_string(), "wild life".to_string()])
.await
.unwrap();
let fact1 = "Dogs are creatures with 4 legs that bark.";
let fact1_embedding = generator.generate_embedding_default(fact1).await.unwrap();
let fact2 = "Camels are slow animals with large humps.";
let fact2_embedding = generator.generate_embedding_default(fact2).await.unwrap();
let fact3 = "Seals swim in the ocean.";
let fact3_embedding = generator.generate_embedding_default(fact3).await.unwrap();
doc.append_text_node(fact1.clone(), None, fact1_embedding.clone(), &vec![])
.unwrap();
doc.append_text_node(fact2.clone(), None, fact2_embedding.clone(), &vec![])
.unwrap();
doc.append_text_node(fact3.clone(), None, fact3_embedding.clone(), &vec![])
.unwrap();

let writer = vector_fs
.new_writer(default_test_profile(), folder_path.clone(), default_test_profile())
.unwrap();
vector_fs
.save_vector_resource_in_folder(
&writer,
BaseVectorResource::Document(doc),
Some(source_file_map.clone()),
DistributionOrigin::None,
)
.unwrap();

// Searching for FSItems
let reader = vector_fs
.new_reader(default_test_profile(), VRPath::root(), default_test_profile())
.unwrap();
let query_string = "Who is building Shinkai?".to_string();
let query_embedding = vector_fs
.generate_query_embedding_using_reader(query_string, &reader)
.await
.unwrap();
let res = vector_fs.vector_search_fs_item(&reader, query_embedding, 100).unwrap();
assert_eq!(res[0].name(), "shinkai_intro");

// Searching into the Vector Resources themselves in the VectorFS to acquire internal nodes
let reader = vector_fs
.new_reader(default_test_profile(), VRPath::root(), default_test_profile())
Expand All @@ -145,7 +204,7 @@ async fn test_vector_fs_saving_reading() {
.await
.unwrap();
let res = vector_fs
.vector_search_fs_retrieved_node(&reader, query_embedding, 100, 100)
.vector_search_fs_retrieved_node(&reader, query_embedding.clone(), 100, 100)
.unwrap();
assert_eq!(
"Shinkai Network Manifesto (Early Preview) Robert Kornacki [email protected] Nicolas Arqueros",
Expand All @@ -156,15 +215,38 @@ async fn test_vector_fs_saving_reading() {
.unwrap()
.to_string()
);
let res = vector_fs
.vector_search_vector_resource(&reader, query_embedding, 1)
.unwrap();
assert_eq!("shinkai_intro", res[0].as_trait_object().name());

// Searching for FSItems
let reader = vector_fs
.new_reader(default_test_profile(), VRPath::root(), default_test_profile())
// Animal facts search
let query_string = "What do you know about camels?".to_string();
let query_embedding = vector_fs
.generate_query_embedding_using_reader(query_string, &reader)
.await
.unwrap();
let query_string = "Who is building Shinkai?".to_string();
let res = vector_fs
.vector_search_fs_retrieved_node(&reader, query_embedding.clone(), 100, 100)
.unwrap();
assert_eq!(
"Camels are slow animals with large humps.",
res[0]
.resource_retrieved_node
.node
.get_text_content()
.unwrap()
.to_string()
);

let query_string = "What are popular animals?".to_string();
let query_embedding = vector_fs
.generate_query_embedding_using_reader(query_string, &reader)
.await
.unwrap();
let res = vector_fs.vector_search_fs_item(&reader, query_embedding, 100).unwrap();
let res = vector_fs
.vector_search_vector_resource(&reader, query_embedding, 100)
.unwrap();
assert_eq!("3 Animal Facts", res[1].as_trait_object().name());
assert_eq!("3 Animal Facts", res[0].as_trait_object().name());
}

0 comments on commit 183f55e

Please sign in to comment.