Skip to content

Commit

Permalink
Merge pull request #561 from dcSpark/nico/fix_prompt_api
Browse files Browse the repository at this point in the history
fix new prompt api
  • Loading branch information
nicarq authored Sep 18, 2024
2 parents 8a5d239 + f72eed9 commit e9f139e
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 18 deletions.
15 changes: 8 additions & 7 deletions shinkai-bin/shinkai-node/src/lance_db/shinkai_lance_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,18 @@ impl LanceShinkaiDb {
return Ok(());
}

self.tool_table
.create_index(&["tool_seo"], Index::FTS(FtsIndexBuilder::default()))
.execute()
.await?;

// Check the number of elements in the table
let element_count = self.tool_table.count_rows(None).await?;
if element_count < 100 {
self.tool_table
.create_index(&["tool_seo"], Index::FTS(FtsIndexBuilder::default()))
.execute()
.await?;

if element_count < 256 {
eprintln!("Not enough elements to create other indices. Skipping index creation for tool table.");
return Ok(());
}

// Create the indices
self.tool_table
.create_index(&["tool_key"], Index::Auto)
Expand All @@ -144,6 +144,7 @@ impl LanceShinkaiDb {
.await?;

self.tool_table.create_index(&["vector"], Index::Auto).execute().await?;

self.tool_table
.create_index(
&[ShinkaiToolSchema::vector_field()],
Expand Down
13 changes: 7 additions & 6 deletions shinkai-bin/shinkai-node/src/lance_db/shinkai_prompt_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ impl LanceShinkaiDb {
return Ok(());
}

self.prompt_table
.create_index(&["prompt"], Index::FTS(FtsIndexBuilder::default()))
.execute()
.await?;

// Check the number of elements in the table
let element_count = self.prompt_table.count_rows(None).await?;
if element_count < 100 {
self.prompt_table
.create_index(&["prompt"], Index::FTS(FtsIndexBuilder::default()))
.execute()
.await?;

if element_count < 256 {
eprintln!("Not enough elements to create other indices. Skipping index creation for prompt table.");
return Ok(());
}
Expand All @@ -79,6 +79,7 @@ impl LanceShinkaiDb {
.create_index(&["vector"], Index::Auto)
.execute()
.await?;

self.prompt_table
.create_index(
&[ShinkaiPromptSchema::vector_field()],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use std::{sync::Arc, time::Instant};

use async_channel::Sender;
use reqwest::StatusCode;
use tokio::sync::{RwLock};
use tokio::sync::RwLock;

use crate::{
db::ShinkaiDB,
lance_db::shinkai_lance_db::LanceShinkaiDb,
network::{node_api_router::APIError, node_error::NodeError, Node}, prompts::custom_prompt::CustomPrompt,
network::{node_api_router::APIError, node_error::NodeError, Node},
prompts::custom_prompt::CustomPrompt,
};

impl Node {
Expand Down Expand Up @@ -100,7 +101,16 @@ impl Node {
// Get all prompts from the LanceShinkaiDb
match lance_db.read().await.get_all_prompts().await {
Ok(prompts) => {
let _ = res.send(Ok(prompts)).await;
// Set embeddings to None before returning
let prompts_without_embeddings: Vec<CustomPrompt> = prompts
.into_iter()
.map(|mut prompt| {
prompt.embedding = None;
prompt
})
.collect();

let _ = res.send(Ok(prompts_without_embeddings)).await;
Ok(())
}
Err(err) => {
Expand Down Expand Up @@ -172,13 +182,22 @@ impl Node {
// Perform the internal search using LanceShinkaiDb
match lance_db.read().await.prompt_vector_search(&query, 5).await {
Ok(prompts) => {
// Set embeddings to None before returning
let prompts_without_embeddings: Vec<CustomPrompt> = prompts
.into_iter()
.map(|mut prompt| {
prompt.embedding = None;
prompt
})
.collect();

// Log the elapsed time if LOG_ALL is set to 1
if std::env::var("LOG_ALL").unwrap_or_default() == "1" {
let elapsed_time = start_time.elapsed();
println!("Time taken for custom prompt search: {:?}", elapsed_time);
println!("Number of custom prompt results: {}", prompts.len());
println!("Number of custom prompt results: {}", prompts_without_embeddings.len());
}
let _ = res.send(Ok(prompts)).await;
let _ = res.send(Ok(prompts_without_embeddings)).await;
Ok(())
}
Err(err) => {
Expand Down
4 changes: 4 additions & 0 deletions shinkai-bin/shinkai-node/src/network/v2_api/api_v2_router.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::network::node_commands::NodeCommand;
use crate::prompts::custom_prompt;

use super::api_v2_handlers_ext_agent_offers::ext_agent_offers_routes;
use super::api_v2_handlers_jobs::job_routes;
use super::api_v2_handlers_prompts::prompt_routes;
use super::api_v2_handlers_swagger_ui::swagger_ui_routes;
use super::api_v2_handlers_vecfs::vecfs_routes;
use super::api_v2_handlers_wallets::wallet_routes;
Expand All @@ -24,6 +26,7 @@ pub fn v2_routes(
let workflows_routes = workflows_routes(node_commands_sender.clone());
let ext_agent_offers = ext_agent_offers_routes(node_commands_sender.clone());
let wallet_routes = wallet_routes(node_commands_sender.clone());
let custom_prompt = prompt_routes(node_commands_sender.clone());
let swagger_ui_routes = swagger_ui_routes();

general_routes
Expand All @@ -33,6 +36,7 @@ pub fn v2_routes(
.or(workflows_routes)
.or(ext_agent_offers)
.or(wallet_routes)
.or(custom_prompt)
.or(swagger_ui_routes)
}

Expand Down
4 changes: 4 additions & 0 deletions shinkai-bin/shinkai-node/src/tools/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl ToolRouter {
}

self.lance_db.write().await.create_tool_indices_if_needed().await?;
self.lance_db.write().await.create_prompt_indices_if_needed().await?;

Ok(())
}
Expand All @@ -87,10 +88,13 @@ impl ToolRouter {
// Add JS tools
let _ = self.add_js_tools().await;

let _ = self.add_static_prompts(generator).await;

// Set the latest version in the database
self.set_lancedb_version(LATEST_ROUTER_DB_VERSION).await?;

self.lance_db.write().await.create_tool_indices_if_needed().await?;
self.lance_db.write().await.create_prompt_indices_if_needed().await?;

Ok(())
}
Expand Down

0 comments on commit e9f139e

Please sign in to comment.