Skip to content

Commit

Permalink
Fixed retrievednode resource header to be from root
Browse files Browse the repository at this point in the history
  • Loading branch information
robkorn committed Jan 16, 2024
1 parent 55bb09f commit dc87f82
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options: &Vec<TraversalOption>,
starting_path: Option<VRPath>,
) -> Vec<RetrievedNode> {
// Setup the root VRHeader that will be attached to all RetrievedNodes
let root_vr_header = self.generate_resource_header();

if let Some(path) = starting_path {
match self.retrieve_node_at_path(path.clone()) {
Ok(ret_node) => {
Expand All @@ -279,6 +282,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options,
vec![],
path,
root_vr_header.clone(),
);
}
}
Expand All @@ -293,6 +297,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options,
vec![],
VRPath::new(),
root_vr_header,
);

// After getting all results from the vector search, perform final filtering
Expand Down Expand Up @@ -340,7 +345,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
}
}

// Check if we are using traveral method unscored all nodes
// Check if we are using traversal method unscored all nodes
if traversal_method != TraversalMethod::UnscoredAllNodes {
results.truncate(num_of_results as usize);
}
Expand All @@ -357,6 +362,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options: &Vec<TraversalOption>,
hierarchical_scores: Vec<f32>,
traversal_path: VRPath,
root_vr_header: VRHeader,
) -> Vec<RetrievedNode> {
// First we fetch the embeddings we want to score
let mut embeddings_to_score = vec![];
Expand Down Expand Up @@ -410,6 +416,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options,
hierarchical_scores,
traversal_path,
root_vr_header,
)
}

Expand All @@ -423,6 +430,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options: &Vec<TraversalOption>,
hierarchical_scores: Vec<f32>,
traversal_path: VRPath,
root_vr_header: VRHeader,
) -> Vec<RetrievedNode> {
let mut current_level_results: Vec<RetrievedNode> = vec![];
let mut vector_resource_count = 0;
Expand Down Expand Up @@ -455,7 +463,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
let ret_node = RetrievedNode {
node: node.clone(),
score,
resource_header: self.generate_resource_header(),
resource_header: root_vr_header.clone(),
retrieval_path: traversal_path.clone(),
};
current_level_results.push(ret_node);
Expand Down Expand Up @@ -494,6 +502,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options,
hierarchical_scores.clone(),
traversal_path.clone(),
root_vr_header.clone(),
);
current_level_results.extend(results);
}
Expand All @@ -519,6 +528,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options: &Vec<TraversalOption>,
hierarchical_scores: Vec<f32>,
traversal_path: VRPath,
root_vr_header: VRHeader,
) -> Vec<RetrievedNode> {
let mut current_level_results: Vec<RetrievedNode> = vec![];
// Concat the current score into a new hierarchical scores Vec before moving forward
Expand All @@ -536,6 +546,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
traversal_options,
new_hierarchical_scores,
new_traversal_path.clone(),
root_vr_header.clone(),
);

// If traversing with UnscoredAllNodes, include the Vector Resource
Expand All @@ -545,7 +556,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
current_level_results.push(RetrievedNode {
node: node.clone(),
score,
resource_header: self.generate_resource_header(),
resource_header: root_vr_header.clone(),
retrieval_path: new_traversal_path,
});
}
Expand All @@ -563,7 +574,7 @@ pub trait VectorResourceSearch: VectorResourceCore {
current_level_results.push(RetrievedNode {
node: node.clone(),
score,
resource_header: self.generate_resource_header(),
resource_header: root_vr_header.clone(),
retrieval_path: new_traversal_path,
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use std::fmt;
use std::hash::{Hash, Hasher};

/// A node that was retrieved from inside of a Vector Resource. Includes extra data like the retrieval path
/// and the similarity score from the vector search.
/// and the similarity score from the vector search. The resource_header is the VRHeader from the root
/// Vector Resource the RetrievedNode is from.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct RetrievedNode {
pub node: Node,
Expand Down
7 changes: 4 additions & 3 deletions src/vector_fs/vector_fs_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ use shinkai_vector_resources::{
use std::collections::HashMap;

/// A retrieved node from within a Vector Resource inside of the VectorFS.
/// Includes FSItem
/// Includes the path of the FSItem in the VectorFS and the retrieved node
/// from the Vector Resource inside the FSItem's path.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct FSRetrievedNode {
origin_fs_item: FSItem,
retrieved_node: RetrievedNode,
fs_item_path: VRPath,
resource_retrieved_node: RetrievedNode,
}

impl VectorFS {
Expand Down

0 comments on commit dc87f82

Please sign in to comment.