added files to prompt
acedward committed Dec 22, 2024
1 parent 5c32ef5 commit 8d89a06
Showing 4 changed files with 52 additions and 7 deletions.
Changed file 1 of 4
@@ -11,6 +11,7 @@ use crate::managers::sheet_manager::SheetManager;
use crate::managers::tool_router::{ToolCallFunctionResponse, ToolRouter};
use crate::network::agent_payments_manager::external_agent_offerings_manager::ExtAgentOfferingsManager;
use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager;
use crate::utils::environment::{fetch_node_environment, NodeEnvironment};
use async_trait::async_trait;
use shinkai_message_primitives::schemas::inbox_name::InboxName;
use shinkai_message_primitives::schemas::job::{Job, JobLike};
@@ -83,6 +84,7 @@ impl InferenceChain for GenericInferenceChain {
self.context.ext_agent_payments_manager.clone(),
// self.context.sqlite_logger.clone(),
self.context.llm_stopper.clone(),
fetch_node_environment(),
)
.await?;
Ok(response)
@@ -122,6 +124,7 @@ impl GenericInferenceChain {
ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
// sqlite_logger: Option<Arc<SqliteLogger>>,
llm_stopper: Arc<LLMStopper>,
node_env: NodeEnvironment,
) -> Result<InferenceChainResult, LLMProviderError> {
shinkai_log(
ShinkaiLogOption::JobExecution,
@@ -327,6 +330,8 @@ impl GenericInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
None,
full_job.job_id.clone(),
node_env.clone(),
);

let mut iteration_count = 0;
@@ -394,7 +399,10 @@ impl GenericInferenceChain {

// 6) Call workflow or tooling
// Find the ShinkaiTool that has a tool with the function name
let shinkai_tool = tools.iter().find(|tool| tool.name() == function_call.name || tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default());
let shinkai_tool = tools.iter().find(|tool| {
tool.name() == function_call.name
|| tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default()
});
if shinkai_tool.is_none() {
eprintln!("Function not found: {}", function_call.name);
return Err(LLMProviderError::FunctionNotFound(function_call.name.clone()));
@@ -443,6 +451,8 @@ impl GenericInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
Some(function_response),
full_job.job_id.clone(),
node_env.clone(),
);
} else {
// No more function calls required, return the final response
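What this file's hunks add up to: the trait implementation resolves the node environment once via fetch_node_environment() and forwards it into the chain function, whose signature gains a node_env: NodeEnvironment parameter; the chain then clones node_env and full_job.job_id into every prompt build, both the initial one and the rebuild after each tool call. A minimal, self-contained sketch of that flow, using simplified stand-in types rather than the crate's real ones (the job id and storage path are hypothetical):

// Sketch only: stand-ins that mirror the shape of the change, not the crate's code.
#[derive(Clone, Debug, Default)]
struct NodeEnvironment {
    node_storage_path: Option<String>,
}

fn fetch_node_environment() -> NodeEnvironment {
    // The real helper reads node configuration; this stub hard-codes a path.
    NodeEnvironment { node_storage_path: Some("./storage".to_string()) }
}

fn generate_prompt(job_id: &str, node_env: &NodeEnvironment) {
    // Stands in for JobPromptGenerator::generic_inference_prompt(.., job_id, node_env),
    // which now receives both extra arguments.
    println!("building prompt for {job_id} with storage {:?}", node_env.node_storage_path);
}

fn main() {
    let node_env = fetch_node_environment();
    let job_id = "job_123"; // hypothetical; the chain passes full_job.job_id.clone()
    generate_prompt(job_id, &node_env); // initial prompt before the tool-call loop
    generate_prompt(job_id, &node_env); // rebuilt after a function/tool response
}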
Changed file 2 of 4
@@ -2,6 +2,9 @@ use std::collections::HashMap;

use crate::llm_provider::execution::prompts::general_prompts::JobPromptGenerator;
use crate::managers::tool_router::ToolCallFunctionResponse;
use crate::network::v2_api::api_v2_commands_app_files::get_app_folder_path;
use crate::network::Node;
use crate::utils::environment::NodeEnvironment;
use serde_json::json;
use shinkai_message_primitives::schemas::job::JobStepResult;
use shinkai_message_primitives::schemas::prompts::Prompt;
@@ -23,6 +26,8 @@ impl JobPromptGenerator {
job_step_history: Option<Vec<JobStepResult>>,
tools: Vec<ShinkaiTool>,
function_call: Option<ToolCallFunctionResponse>,
job_id: String,
node_env: NodeEnvironment,
) -> Prompt {
let mut prompt = Prompt::new();

@@ -52,6 +57,12 @@
priority = priority.saturating_sub(1);
}
}

let current_files = Node::v2_api_list_app_files_internal(get_app_folder_path(node_env, job_id));
if let Ok(current_files) = current_files {
let content = format!("Current files: {:?}", current_files);
prompt.add_content(content, SubPromptType::ExtraContext, 97);
}
}

// Parses the retrieved nodes as individual sub-prompts, to support priority pruning
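The new block is the heart of the commit: each time the prompt is generated, the job's app folder is listed via Node::v2_api_list_app_files_internal(get_app_folder_path(node_env, job_id)) and, if the listing succeeds, the result is appended as an ExtraContext sub-prompt with priority 97, so the model sees the job's current files on every turn. A small sketch of the string that ends up in the prompt (the file names are made up):

fn main() {
    // Hypothetical listing result; in the change it comes from
    // Node::v2_api_list_app_files_internal(get_app_folder_path(node_env, job_id)).
    let current_files = vec!["notes.txt".to_string(), "data.csv".to_string()];
    let content = format!("Current files: {:?}", current_files);
    assert_eq!(content, r#"Current files: ["notes.txt", "data.csv"]"#);
    // prompt.add_content(content, SubPromptType::ExtraContext, 97); // as in the diff
    println!("{content}");
}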
Changed file 3 of 4
@@ -12,6 +12,7 @@ use crate::managers::sheet_manager::SheetManager;
use crate::managers::tool_router::{ToolCallFunctionResponse, ToolRouter};
use crate::network::agent_payments_manager::external_agent_offerings_manager::ExtAgentOfferingsManager;
use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager;
use crate::utils::environment::{fetch_node_environment, NodeEnvironment};
use async_trait::async_trait;
use shinkai_message_primitives::schemas::inbox_name::InboxName;
use shinkai_message_primitives::schemas::job::{Job, JobLike};
@@ -79,6 +80,7 @@ impl InferenceChain for SheetUIInferenceChain {
self.context.ext_agent_payments_manager.clone(),
// self.context.sqlite_logger.clone(),
self.context.llm_stopper.clone(),
fetch_node_environment(),
)
.await?;
let job_execution_context = self.context.execution_context.clone();
@@ -123,6 +125,7 @@ impl SheetUIInferenceChain {
ext_agent_payments_manager: Option<Arc<Mutex<ExtAgentOfferingsManager>>>,
// sqlite_logger: Option<Arc<SqliteLogger>>,
llm_stopper: Arc<LLMStopper>,
node_env: NodeEnvironment,
) -> Result<String, LLMProviderError> {
shinkai_log(
ShinkaiLogOption::JobExecution,
@@ -283,6 +286,8 @@ impl SheetUIInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
None,
full_job.job_id.clone(),
node_env.clone(),
);

let mut iteration_count = 0;
@@ -419,6 +424,8 @@ impl SheetUIInferenceChain {
Some(full_job.step_history.clone()),
tools.clone(),
Some(function_response),
full_job.job_id.clone(),
node_env.clone(),
);
} else {
// No more function calls required, return the final response
Changed file 4 of 4
@@ -12,7 +12,7 @@ use std::sync::Arc;

use shinkai_http_api::node_api_router::APIError;

fn get_app_folder_path(node_env: NodeEnvironment, app_id: String) -> PathBuf {
pub fn get_app_folder_path(node_env: NodeEnvironment, app_id: String) -> PathBuf {
let mut origin_path: PathBuf = PathBuf::from(node_env.node_storage_path.clone().unwrap_or_default());
origin_path.push("app_files");
origin_path.push(app_id);
@@ -143,26 +143,43 @@ impl Node {
return Ok(());
}
let app_folder_path = get_app_folder_path(node_env, app_id);
let result = Self::v2_api_list_app_files_internal(app_folder_path);
match result {
Ok(file_list) => {
let _ = res
.send(Ok(Value::Array(
file_list.iter().map(|file| Value::String(file.clone())).collect(),
)))
.await;
Ok(())
}
Err(err) => {
let _ = res.send(Err(err)).await;
Ok(())
}
}
}

pub fn v2_api_list_app_files_internal(app_folder_path: PathBuf) -> Result<Vec<String>, APIError> {
let files = std::fs::read_dir(&app_folder_path);
if let Err(err) = files {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to read directory: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
return Err(api_error);
}
let mut file_list = Vec::new();

let files = files.unwrap();
for file in files {
if let Ok(file) = file {
let file_name = file.file_name().to_string_lossy().to_string();
file_list.push(Value::String(file_name));
file_list.push(file_name);
}
}
let _ = res.send(Ok(Value::Array(file_list))).await;
Ok(())
Ok(file_list)
}

pub async fn v2_api_delete_app_file(
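The refactor in this file makes get_app_folder_path public and splits the directory listing out of the channel-based handler into v2_api_list_app_files_internal, which returns Result<Vec<String>, APIError> and can therefore be reused directly by the prompt generator above. A stand-alone approximation of the two helpers (same layout in spirit, but using std::io::Error in place of the crate's APIError, and with a hypothetical storage path and job id):

use std::path::PathBuf;

// Mirrors get_app_folder_path: <node_storage_path>/app_files/<app_id>.
fn app_folder_path(node_storage_path: Option<String>, app_id: &str) -> PathBuf {
    let mut path = PathBuf::from(node_storage_path.unwrap_or_default());
    path.push("app_files");
    path.push(app_id);
    path
}

// Mirrors v2_api_list_app_files_internal: collect file names, propagate read errors.
fn list_app_files(app_folder_path: PathBuf) -> std::io::Result<Vec<String>> {
    let mut file_list = Vec::new();
    for entry in std::fs::read_dir(&app_folder_path)? {
        let entry = entry?;
        file_list.push(entry.file_name().to_string_lossy().to_string());
    }
    Ok(file_list)
}

fn main() {
    let folder = app_folder_path(Some("./storage".to_string()), "job_123");
    match list_app_files(folder) {
        Ok(files) => println!("Current files: {:?}", files),
        Err(err) => eprintln!("Failed to read directory: {}", err),
    }
}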
