Skip to content

Commit

Permalink
Merge pull request #585 from dcSpark/nico/fix_prompt_search
Browse files Browse the repository at this point in the history
fix prompt search
  • Loading branch information
nicarq authored Sep 30, 2024
2 parents dce2fdd + bb0517f commit e64209c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 46 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,4 @@ storage_debug_agent_provider
agent_provider
image_0.txt
storage_debug_0_7_33.zip
shinkai-libs/shinkai-lancedb/lance_db_tests
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl Node {
let start_time = Instant::now();

// Perform the internal search using LanceShinkaiDb
match lance_db.read().await.prompt_vector_search(&query, 5).await {
match lance_db.read().await.prompt_vector_search(&query, 10).await {
Ok(prompts) => {
// Set embeddings to None before returning
let prompts_without_embeddings: Vec<CustomPrompt> = prompts
Expand Down
2 changes: 1 addition & 1 deletion shinkai-libs/shinkai-lancedb/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ uuid = { version = "1.6.1", features = ["v4"] }
shinkai_tools_primitives = { workspace = true }
shinkai_vector_resources = { workspace = true }
shinkai_message_primitives = { workspace = true }
shinkai_tools_runner = { version = "0.7.14", features = ["built-in-tools"] }
shinkai_tools_runner = { version = "0.7.15", features = ["built-in-tools"] }
regex = "1"
base64 = "0.22.0"
lancedb = "0.10.0"
Expand Down

Large diffs are not rendered by default.

102 changes: 59 additions & 43 deletions shinkai-libs/shinkai-lancedb/src/lance_db/shinkai_prompt_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl LanceShinkaiDb {
Ok(())
}

fn convert_batch_to_prompt(batch: &RecordBatch) -> Option<CustomPrompt> {
fn convert_batch_to_prompt(batch: &RecordBatch) -> Vec<CustomPrompt> {
let name_array = batch
.column_by_name(ShinkaiPromptSchema::name_field())
.unwrap()
Expand Down Expand Up @@ -134,13 +134,15 @@ impl LanceShinkaiDb {
.downcast_ref::<FixedSizeListArray>()
.unwrap();

if name_array.len() > 0 {
let embedding = if vector_array.is_null(0) {
let mut prompts = Vec::new();

for i in 0..name_array.len() {
let embedding = if vector_array.is_null(i) {
None
} else {
Some(
vector_array
.value(0)
.value(i)
.as_any()
.downcast_ref::<Float32Array>()
.unwrap()
Expand All @@ -149,18 +151,18 @@ impl LanceShinkaiDb {
)
};

Some(CustomPrompt {
name: name_array.value(0).to_string(),
prompt: prompt_array.value(0).to_string(),
is_system: is_system_array.value(0),
is_enabled: is_enabled_array.value(0),
version: version_array.value(0).to_string(),
is_favorite: is_favorite_array.value(0),
prompts.push(CustomPrompt {
name: name_array.value(i).to_string(),
prompt: prompt_array.value(i).to_string(),
is_system: is_system_array.value(i),
is_enabled: is_enabled_array.value(i),
version: version_array.value(i).to_string(),
is_favorite: is_favorite_array.value(i),
embedding,
})
} else {
None
});
}

prompts
}

pub async fn get_prompt(&self, name: &str) -> Result<Option<CustomPrompt>, ShinkaiLanceDBError> {
Expand Down Expand Up @@ -188,7 +190,8 @@ impl LanceShinkaiDb {
.map_err(|e| ShinkaiLanceDBError::DatabaseError(e.to_string()))?;

for batch in results {
if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
let prompts = Self::convert_batch_to_prompt(&batch);
for prompt in prompts {
return Ok(Some(prompt));
}
}
Expand Down Expand Up @@ -281,11 +284,8 @@ impl LanceShinkaiDb {
let mut prompts = Vec::new();
let mut res = query;
while let Some(Ok(batch)) = res.next().await {
for _i in 0..batch.num_rows() {
if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
prompts.push(prompt);
}
}
let prompts_batch = Self::convert_batch_to_prompt(&batch);
prompts.extend(prompts_batch);
}

Ok(prompts)
Expand All @@ -311,7 +311,12 @@ impl LanceShinkaiDb {
let fts_query_builder = self
.prompt_table
.query()
.full_text_search(FullTextSearchQuery::new(query.to_owned()))
.full_text_search(FullTextSearchQuery {
columns: vec!["prompt".to_string()],
query: query.to_owned(),
limit: Some(num_results as i64),
wand_factor: Some(1.0),
})
.select(Select::columns(&[
ShinkaiPromptSchema::name_field(),
ShinkaiPromptSchema::prompt_field(),
Expand Down Expand Up @@ -357,16 +362,14 @@ impl LanceShinkaiDb {

let mut fts_res = fts_query;
while let Some(Ok(batch)) = fts_res.next().await {
if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
fts_results.push(prompt);
}
let prompts = Self::convert_batch_to_prompt(&batch);
fts_results.extend(prompts);
}

let mut vector_res = vector_query;
while let Some(Ok(batch)) = vector_res.next().await {
if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
vector_results.push(prompt);
}
let prompts = Self::convert_batch_to_prompt(&batch);
vector_results.extend(prompts);
}

// Merge results using interleave and remove duplicates
Expand Down Expand Up @@ -399,6 +402,21 @@ impl LanceShinkaiDb {
}
}

// Continue to add results from the remaining iterator if needed
while combined_results.len() < num_results as usize {
if let Some(fts_item) = fts_iter.next() {
if seen.insert(fts_item.name.clone()) {
combined_results.push(fts_item);
}
} else if let Some(vector_item) = vector_iter.next() {
if seen.insert(vector_item.name.clone()) {
combined_results.push(vector_item);
}
} else {
break;
}
}

Ok(combined_results)
}

Expand All @@ -422,11 +440,8 @@ impl LanceShinkaiDb {
let mut prompts = Vec::new();
let mut res = query;
while let Some(Ok(batch)) = res.next().await {
for _i in 0..batch.num_rows() {
if let Some(prompt) = Self::convert_batch_to_prompt(&batch) {
prompts.push(prompt);
}
}
let prompts_batch = Self::convert_batch_to_prompt(&batch);
prompts.extend(prompts_batch);
}

Ok(prompts)
Expand Down Expand Up @@ -746,7 +761,7 @@ mod tests {
// Perform a vector search
let search_query = "first test prompt";
let search_results = db.prompt_vector_search(search_query, 2).await?;
assert_eq!(search_results.len(), 1, "There should be 1 search result");
assert_eq!(search_results.len(), 2, "There should be 2 search results");

Ok(())
}
Expand Down Expand Up @@ -798,6 +813,7 @@ mod tests {
prompts_json_testing.push(json!(prompt));
}

// Note: I'm commenting out the ones that are empty or don't work
// Generate prompts for production
env::set_var("IS_TESTING", "0");
let prompts = vec![
Expand All @@ -818,7 +834,7 @@ mod tests {
create_custom_prompt("Analyze Prose System", ANALYZE_PROSE_SYSTEM),
create_custom_prompt("Analyze Spiritual Text System", ANALYZE_SPIRITUAL_TEXT_SYSTEM),
create_custom_prompt("Analyze Tech Impact System", ANALYZE_TECH_IMPACT_SYSTEM),
create_custom_prompt("Analyze Threat Report System", ANALYZE_THREAT_REPORT_SYSTEM),
// create_custom_prompt("Analyze Threat Report System", ANALYZE_THREAT_REPORT_SYSTEM),
create_custom_prompt(
"Analyze Threat Report Trends System",
ANALYZE_THREAT_REPORT_TRENDS_SYSTEM,
Expand Down Expand Up @@ -867,16 +883,16 @@ mod tests {
"Create Network Threat Landscape System",
CREATE_NETWORK_THREAT_LANDSCAPE_SYSTEM,
),
create_custom_prompt(
"Create Network Threat Landscape User",
CREATE_NETWORK_THREAT_LANDSCAPE_USER,
),
// create_custom_prompt(
// "Create Network Threat Landscape User",
// CREATE_NETWORK_THREAT_LANDSCAPE_USER,
// ),
create_custom_prompt("Create NPC System", CREATE_NPC_SYSTEM),
create_custom_prompt("Create Pattern System", CREATE_PATTERN_SYSTEM),
create_custom_prompt("Create Quiz System", CREATE_QUIZ_SYSTEM),
create_custom_prompt("Create Reading Plan System", CREATE_READING_PLAN_SYSTEM),
create_custom_prompt("Create Report Finding System", CREATE_REPORT_FINDING_SYSTEM),
create_custom_prompt("Create Report Finding User", CREATE_REPORT_FINDING_USER),
// create_custom_prompt("Create Report Finding User", CREATE_REPORT_FINDING_USER),
create_custom_prompt("Create Security Update System", CREATE_SECURITY_UPDATE_SYSTEM),
create_custom_prompt("Create Show Intro System", CREATE_SHOW_INTRO_SYSTEM),
create_custom_prompt("Create Sigma Rules System", CREATE_SIGMA_RULES_SYSTEM),
Expand All @@ -888,7 +904,7 @@ mod tests {
create_custom_prompt("Create Video Chapters System", CREATE_VIDEO_CHAPTERS_SYSTEM),
create_custom_prompt("Create Visualization System", CREATE_VISUALIZATION_SYSTEM),
create_custom_prompt("Explain Code System", EXPLAIN_CODE_SYSTEM),
create_custom_prompt("Explain Code User", EXPLAIN_CODE_USER),
// create_custom_prompt("Explain Code User", EXPLAIN_CODE_USER),
create_custom_prompt("Explain Docs System", EXPLAIN_DOCS_SYSTEM),
create_custom_prompt("Explain Project System", EXPLAIN_PROJECT_SYSTEM),
create_custom_prompt("Explain Terms System", EXPLAIN_TERMS_SYSTEM),
Expand All @@ -898,7 +914,7 @@ mod tests {
EXTRACT_ALGORITHM_UPDATE_RECOMMENDATIONS_SYSTEM,
),
create_custom_prompt("Extract Article Wisdom System", EXTRACT_ARTICLE_WISDOM_SYSTEM),
create_custom_prompt("Extract Article Wisdom User", EXTRACT_ARTICLE_WISDOM_USER),
// create_custom_prompt("Extract Article Wisdom User", EXTRACT_ARTICLE_WISDOM_USER),
create_custom_prompt("Extract Book Ideas System", EXTRACT_BOOK_IDEAS_SYSTEM),
create_custom_prompt(
"Extract Book Recommendations System",
Expand Down Expand Up @@ -934,15 +950,15 @@ mod tests {
create_custom_prompt("Improve Academic Writing System", IMPROVE_ACADEMIC_WRITING_SYSTEM),
create_custom_prompt("Improve Prompt System", IMPROVE_PROMPT_SYSTEM),
create_custom_prompt("Improve Report Finding System", IMPROVE_REPORT_FINDING_SYSTEM),
create_custom_prompt("Improve Report Finding User", IMPROVE_REPORT_FINDING_USER),
// create_custom_prompt("Improve Report Finding User", IMPROVE_REPORT_FINDING_USER),
create_custom_prompt("Improve Writing System", IMPROVE_WRITING_SYSTEM),
create_custom_prompt("Label And Rate System", LABEL_AND_RATE_SYSTEM),
create_custom_prompt("Official Pattern Template System", OFFICIAL_PATTERN_TEMPLATE_SYSTEM),
create_custom_prompt("Provide Guidance System", PROVIDE_GUIDANCE_SYSTEM),
create_custom_prompt("Rate AI Response System", RATE_AI_RESPONSE_SYSTEM),
create_custom_prompt("Rate AI Result System", RATE_AI_RESULT_SYSTEM),
create_custom_prompt("Rate Content System", RATE_CONTENT_SYSTEM),
create_custom_prompt("Rate Content User", RATE_CONTENT_USER),
// create_custom_prompt("Rate Content User", RATE_CONTENT_USER),
create_custom_prompt("Rate Value System", RATE_VALUE_SYSTEM),
create_custom_prompt("Raw Query System", RAW_QUERY_SYSTEM),
create_custom_prompt("Recommend Artists System", RECOMMEND_ARTISTS_SYSTEM),
Expand Down

0 comments on commit e64209c

Please sign in to comment.