Remove step history and execution context #717

Status: Closed. Wants to merge 2 commits.
==== File 1 of 6 ====

@@ -71,7 +71,6 @@ impl InferenceChain for GenericInferenceChain {
             self.context.message_hash_id.clone(),
             self.context.image_files.clone(),
             self.context.llm_provider.clone(),
-            self.context.execution_context.clone(),
             self.context.generator.clone(),
             self.context.user_profile.clone(),
             self.context.max_iterations,
@@ -110,7 +109,6 @@ impl GenericInferenceChain {
         message_hash_id: Option<String>,
         image_files: HashMap<String, String>,
         llm_provider: ProviderOrAgent,
-        execution_context: HashMap<String, String>,
         generator: RemoteEmbeddingGenerator,
         user_profile: ShinkaiName,
         max_iterations: u64,
@@ -324,7 +322,7 @@ impl GenericInferenceChain {
             image_files.clone(),
             ret_nodes.clone(),
             summary_node_text.clone(),
-            Some(full_job.step_history.clone()),
+            full_job.prompts.clone(),
             tools.clone(),
             None,
         );
@@ -378,7 +376,6 @@ impl GenericInferenceChain {
             message_hash_id.clone(),
             image_files.clone(),
             llm_provider.clone(),
-            execution_context.clone(),
             generator.clone(),
             user_profile.clone(),
             max_iterations,
@@ -394,7 +391,10 @@ impl GenericInferenceChain {

             // 6) Call workflow or tooling
             // Find the ShinkaiTool that has a tool with the function name
-            let shinkai_tool = tools.iter().find(|tool| tool.name() == function_call.name || tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default());
+            let shinkai_tool = tools.iter().find(|tool| {
+                tool.name() == function_call.name
+                    || tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default()
+            });
             if shinkai_tool.is_none() {
                 eprintln!("Function not found: {}", function_call.name);
                 return Err(LLMProviderError::FunctionNotFound(function_call.name.clone()));
@@ -439,7 +439,7 @@ impl GenericInferenceChain {
             image_files.clone(),
             ret_nodes.clone(),
             summary_node_text.clone(),
-            Some(full_job.step_history.clone()),
+            full_job.prompts.clone(),
             tools.clone(),
             Some(function_response),
         );
@@ -451,7 +451,6 @@ impl GenericInferenceChain {
             response.response_string,
             response.tps.map(|tps| tps.to_string()),
             answer_duration_ms,
-            execution_context.clone(),
             Some(tool_calls_history.clone()),
         );
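The tool-lookup change in the `@@ -394` hunk above only reflows an overlong line; the logic is unchanged. As a standalone illustration of that predicate, with stand-in types rather than the crate's real `ShinkaiTool` and function-call structs:

```rust
// Stand-in types; the real ShinkaiTool and FunctionCall live elsewhere in the crate.
struct Tool {
    name: String,
    router_key: String,
}

struct Call {
    name: String,
    tool_router_key: Option<String>,
}

// Mirrors the find() predicate from the hunk: match on the tool name, or fall
// back to the router key carried by the call (empty string when absent).
fn find_tool<'a>(tools: &'a [Tool], call: &Call) -> Option<&'a Tool> {
    tools.iter().find(|tool| {
        tool.name == call.name
            || tool.router_key == call.tool_router_key.clone().unwrap_or_default()
    })
}

fn main() {
    // The router-key format here is illustrative only.
    let tools = vec![Tool { name: "search".into(), router_key: "local:::search".into() }];
    let call = Call { name: "search".into(), tool_router_key: None };
    assert!(find_tool(&tools, &call).is_some());
}
```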
==== File 2 of 6 ====

@@ -20,7 +20,7 @@ impl JobPromptGenerator {
         image_files: HashMap<String, String>,
         ret_nodes: Vec<RetrievedNode>,
         _summary_text: Option<String>,
-        job_step_history: Option<Vec<JobStepResult>>,
+        job_prompts: Vec<Prompt>,
         tools: Vec<ShinkaiTool>,
         function_call: Option<ToolCallFunctionResponse>,
     ) -> Prompt {
@@ -36,9 +36,12 @@ impl JobPromptGenerator {
         let has_ret_nodes = !ret_nodes.is_empty();

         // Add previous messages
-        // TODO: this should be full messages with assets and not just strings
-        if let Some(step_history) = job_step_history {
-            prompt.add_step_history(step_history, 97);
+        if !job_prompts.is_empty() {
+            let sub_prompts = job_prompts
+                .into_iter()
+                .flat_map(|prompt| prompt.sub_prompts.clone())
+                .collect();
+            prompt.add_sub_prompts_with_new_priority(sub_prompts, 97);
         }

         // Add tools if any. Decrease priority every 2 tools
Review comment (Contributor): @benolt we removed the step history here but we are not adding the previous messages to the conversation 👀
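The reviewer's concern is about exactly this hunk: the old `add_step_history` call replayed prior conversation turns, while the new code replays the sub-prompts of the job's stored prompts at the same priority slot. A minimal self-contained sketch of the new behavior, with `Prompt` and `SubPrompt` as stand-ins for the real Shinkai types:

```rust
// Stand-ins; the real Prompt/SubPrompt carry more data (roles, assets, priorities).
#[derive(Clone)]
struct SubPrompt(String);

struct Prompt {
    sub_prompts: Vec<SubPrompt>,
}

impl Prompt {
    fn add_sub_prompts_with_new_priority(&mut self, subs: Vec<SubPrompt>, _priority: u8) {
        // The real implementation also re-assigns priorities; this mock just appends.
        self.sub_prompts.extend(subs);
    }
}

fn replay_previous_messages(prompt: &mut Prompt, job_prompts: Vec<Prompt>) {
    if !job_prompts.is_empty() {
        // Flatten every stored prompt's sub-prompts into one list, then add
        // them at priority 97, the slot step history used to occupy.
        let sub_prompts: Vec<SubPrompt> = job_prompts
            .into_iter()
            .flat_map(|p| p.sub_prompts.clone())
            .collect();
        prompt.add_sub_prompts_with_new_priority(sub_prompts, 97);
    }
}

fn main() {
    let mut current = Prompt { sub_prompts: vec![] };
    let history = vec![Prompt { sub_prompts: vec![SubPrompt("hi".into())] }];
    replay_previous_messages(&mut current, history);
    assert_eq!(current.sub_prompts.len(), 1);
}
```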
==== File 3 of 6 ====

@@ -24,7 +24,7 @@ use tokio::sync::{Mutex, RwLock};
 impl JobManager {
     /// Chooses an inference chain based on the job message (using the agent's LLM)
     /// and then starts using the chosen chain.
-    /// Returns the final String result from the inferencing, and a new execution context.
+    /// Returns the final String result from the inferencing.
     #[allow(clippy::too_many_arguments)]
     pub async fn inference_chain_router(
         db: Arc<RwLock<SqliteManager>>,
@@ -34,7 +34,6 @@ impl JobManager {
         job_message: JobMessage,
         message_hash_id: Option<String>,
         image_files: HashMap<String, String>,
-        prev_execution_context: HashMap<String, String>,
         generator: RemoteEmbeddingGenerator,
         user_profile: ShinkaiName,
         ws_manager_trait: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
@@ -75,7 +74,6 @@ impl JobManager {
             message_hash_id,
             image_files,
             llm_provider,
-            prev_execution_context,
             generator,
             user_profile,
             3,
==== File 4 of 6 ====

@@ -65,7 +65,6 @@ pub trait InferenceChainContextTrait: Send + Sync {
     fn message_hash_id(&self) -> Option<String>;
     fn image_files(&self) -> &HashMap<String, String>;
     fn agent(&self) -> &ProviderOrAgent;
-    fn execution_context(&self) -> &HashMap<String, String>;
     fn generator(&self) -> &RemoteEmbeddingGenerator;
     fn user_profile(&self) -> &ShinkaiName;
     fn max_iterations(&self) -> u64;
@@ -134,10 +133,6 @@ impl InferenceChainContextTrait for InferenceChainContext {
         &self.llm_provider
     }

-    fn execution_context(&self) -> &HashMap<String, String> {
-        &self.execution_context
-    }
-
     fn generator(&self) -> &RemoteEmbeddingGenerator {
         &self.generator
     }
@@ -207,8 +202,6 @@ pub struct InferenceChainContext {
     pub message_hash_id: Option<String>,
     pub image_files: HashMap<String, String>,
     pub llm_provider: ProviderOrAgent,
-    /// Job's execution context, used to store potentially relevant data across job steps.
-    pub execution_context: HashMap<String, String>,
     pub generator: RemoteEmbeddingGenerator,
     pub user_profile: ShinkaiName,
     pub max_iterations: u64,
@@ -235,7 +228,6 @@ impl InferenceChainContext {
         message_hash_id: Option<String>,
         image_files: HashMap<String, String>,
         llm_provider: ProviderOrAgent,
-        execution_context: HashMap<String, String>,
         generator: RemoteEmbeddingGenerator,
         user_profile: ShinkaiName,
         max_iterations: u64,
@@ -257,7 +249,6 @@ impl InferenceChainContext {
             message_hash_id,
             image_files,
             llm_provider,
-            execution_context,
             generator,
             user_profile,
             max_iterations,
@@ -296,7 +287,6 @@ impl fmt::Debug for InferenceChainContext {
             .field("message_hash_id", &self.message_hash_id)
             .field("image_files", &self.image_files.len())
             .field("llm_provider", &self.llm_provider)
-            .field("execution_context", &self.execution_context)
             .field("generator", &self.generator)
             .field("user_profile", &self.user_profile)
             .field("max_iterations", &self.max_iterations)
@@ -319,15 +309,13 @@ pub struct InferenceChainResult {
     pub response: String,
     pub tps: Option<String>,
     pub answer_duration: Option<String>,
-    pub new_job_execution_context: HashMap<String, String>,
     pub tool_calls: Option<Vec<FunctionCall>>,
 }

 impl InferenceChainResult {
-    pub fn new(response: String, new_job_execution_context: HashMap<String, String>) -> Self {
+    pub fn new(response: String) -> Self {
         Self {
             response,
-            new_job_execution_context,
             tps: None,
             answer_duration: None,
             tool_calls: None,
@@ -338,26 +326,16 @@ impl InferenceChainResult {
         response: String,
         tps: Option<String>,
         answer_duration_ms: Option<String>,
-        new_job_execution_context: HashMap<String, String>,
         tool_calls: Option<Vec<FunctionCall>>,
     ) -> Self {
         Self {
             response,
             tps,
             answer_duration: answer_duration_ms,
-            new_job_execution_context,
             tool_calls,
         }
     }

-    pub fn new_empty_execution_context(response: String) -> Self {
-        Self::new(response, HashMap::new())
-    }
-
-    pub fn new_empty() -> Self {
-        Self::new_empty_execution_context(String::new())
-    }
-
     pub fn tool_calls_metadata(&self) -> Option<Vec<FunctionCallMetadata>> {
         self.tool_calls
             .as_ref()
@@ -452,10 +430,6 @@ impl InferenceChainContextTrait for Box<dyn InferenceChainContextTrait> {
         (**self).agent()
     }

-    fn execution_context(&self) -> &HashMap<String, String> {
-        (**self).execution_context()
-    }
-
     fn generator(&self) -> &RemoteEmbeddingGenerator {
         (**self).generator()
     }
@@ -517,7 +491,6 @@ impl InferenceChainContextTrait for Box<dyn InferenceChainContextTrait> {
 pub struct MockInferenceChainContext {
     pub user_message: ParsedUserMessage,
     pub image_files: HashMap<String, String>,
-    pub execution_context: HashMap<String, String>,
     pub user_profile: ShinkaiName,
     pub max_iterations: u64,
     pub iteration_count: u64,
@@ -535,7 +508,6 @@ impl MockInferenceChainContext {
     #[allow(dead_code)]
     pub fn new(
         user_message: ParsedUserMessage,
-        execution_context: HashMap<String, String>,
         user_profile: ShinkaiName,
         max_iterations: u64,
         iteration_count: u64,
@@ -550,7 +522,6 @@ impl MockInferenceChainContext {
         Self {
             user_message,
             image_files: HashMap::new(),
-            execution_context,
             user_profile,
             max_iterations,
             iteration_count,
@@ -575,7 +546,6 @@ impl Default for MockInferenceChainContext {
         Self {
             user_message,
             image_files: HashMap::new(),
-            execution_context: HashMap::new(),
             user_profile,
             max_iterations: 10,
             iteration_count: 0,
@@ -635,10 +605,6 @@ impl InferenceChainContextTrait for MockInferenceChainContext {
         unimplemented!()
     }

-    fn execution_context(&self) -> &HashMap<String, String> {
-        &self.execution_context
-    }
-
     fn generator(&self) -> &RemoteEmbeddingGenerator {
         unimplemented!()
     }
@@ -701,7 +667,6 @@ impl Clone for MockInferenceChainContext {
         Self {
             user_message: self.user_message.clone(),
             image_files: self.image_files.clone(),
-            execution_context: self.execution_context.clone(),
             user_profile: self.user_profile.clone(),
             max_iterations: self.max_iterations,
             iteration_count: self.iteration_count,
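With the execution context gone from `InferenceChainResult`, the `new_empty_execution_context` and `new_empty` helpers above disappear as well, and `new` takes only the response. A sketch of the slimmed constructor, using a stand-in struct with a subset of the real fields (`tool_calls` omitted for brevity):

```rust
// Stand-in mirroring the fields InferenceChainResult keeps after this change.
struct InferenceChainResult {
    response: String,
    tps: Option<String>,
    answer_duration: Option<String>,
}

impl InferenceChainResult {
    // Callers now pass only the response; there is no HashMap execution context.
    fn new(response: String) -> Self {
        Self {
            response,
            tps: None,
            answer_duration: None,
        }
    }
}

fn main() {
    let result = InferenceChainResult::new("answer".to_string());
    assert!(result.tps.is_none());
    println!("{}", result.response);
}
```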
==== File 5 of 6 ====

@@ -66,7 +66,6 @@ impl InferenceChain for SheetUIInferenceChain {
             self.context.message_hash_id.clone(),
             self.context.image_files.clone(),
             self.context.llm_provider.clone(),
-            self.context.execution_context.clone(),
             self.context.generator.clone(),
             self.context.user_profile.clone(),
             self.context.max_iterations,
@@ -81,8 +80,7 @@ impl InferenceChain for SheetUIInferenceChain {
             self.context.llm_stopper.clone(),
         )
         .await?;
-        let job_execution_context = self.context.execution_context.clone();
-        Ok(InferenceChainResult::new(response, job_execution_context))
+        Ok(InferenceChainResult::new(response))
     }
 }

@@ -110,7 +108,6 @@ impl SheetUIInferenceChain {
         message_hash_id: Option<String>,
         image_files: HashMap<String, String>,
         llm_provider: ProviderOrAgent,
-        execution_context: HashMap<String, String>,
         generator: RemoteEmbeddingGenerator,
         user_profile: ShinkaiName,
         max_iterations: u64,
@@ -280,7 +277,7 @@ impl SheetUIInferenceChain {
             image_files.clone(),
             ret_nodes.clone(),
             summary_node_text.clone(),
-            Some(full_job.step_history.clone()),
+            full_job.prompts.clone(),
             tools.clone(),
             None,
         );
@@ -380,7 +377,6 @@ impl SheetUIInferenceChain {
             message_hash_id.clone(),
             image_files.clone(),
             llm_provider.clone(),
-            execution_context.clone(),
             generator.clone(),
             user_profile.clone(),
             max_iterations,
@@ -416,7 +412,7 @@ impl SheetUIInferenceChain {
             image_files.clone(),
             ret_nodes.clone(),
             summary_node_text.clone(),
-            Some(full_job.step_history.clone()),
+            full_job.prompts.clone(),
             tools.clone(),
             Some(function_response),
         );
==== File 6 of 6 ====

@@ -273,14 +273,6 @@ impl JobManager {
             ShinkaiLogLevel::Debug,
             &format!("Retrieved {} image files", image_files.len()),
         );
-
-        // Setup initial data to get ready to call a specific inference chain
-        let prev_execution_context = full_job.execution_context.clone();
-        shinkai_log(
-            ShinkaiLogOption::JobExecution,
-            ShinkaiLogLevel::Debug,
-            &format!("Prev Execution Context: {:?}", prev_execution_context),
-        );
         let start = Instant::now();

         // Call the inference chain router to choose which chain to use, and call it
@@ -292,7 +284,6 @@ impl JobManager {
             job_message.clone(),
             message_hash_id,
             image_files.clone(),
-            prev_execution_context,
             generator,
             user_profile.clone(),
             ws_manager.clone(),
@@ -305,7 +296,6 @@ impl JobManager {
         )
         .await?;
         let inference_response_content = inference_response.response.clone();
-        let new_execution_context = inference_response.new_job_execution_context.clone();

         let duration = start.elapsed();
         shinkai_log(
@@ -340,7 +330,7 @@ impl JobManager {
         );

         // Save response data to DB
-        db.write().await.add_step_history(
+        db.write().await.add_job_prompt(
             job_message.job_id.clone(),
             job_message.content,
             Some(image_files),
@@ -352,9 +342,6 @@ impl JobManager {
         .await
         .add_message_to_job_inbox(&job_message.job_id.clone(), &shinkai_message, None, ws_manager)
         .await?;
-        db.write()
-            .await
-            .set_job_execution_context(job_message.job_id.clone(), new_execution_context, None)?;

         // Check for callbacks and add them to the JobManagerQueue if required
         if let Some(callback) = &job_message.callback {
@@ -466,7 +453,6 @@ impl JobManager {
             job_message.clone(),
             message_hash_id,
             empty_files,
-            HashMap::new(), // Assuming prev_execution_context is an empty HashMap
             generator,
             user_profile.clone(),
             ws_manager.clone(),
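Net effect in this file: the response is persisted with `add_job_prompt` where `add_step_history` used to be called, and the trailing `set_job_execution_context` write is gone entirely. A toy sketch of that flow, with `Db` as a mock rather than the real `SqliteManager` API:

```rust
use std::collections::HashMap;

// Mock store; the real SqliteManager call also takes image files and more.
struct Db {
    prompts: HashMap<String, Vec<String>>,
}

impl Db {
    fn add_job_prompt(&mut self, job_id: String, content: String) {
        self.prompts.entry(job_id).or_default().push(content);
    }
}

fn save_response(db: &mut Db, job_id: String, content: String) {
    // Before this PR: db.add_step_history(...) followed by
    // db.set_job_execution_context(...). After: a single prompt write.
    db.add_job_prompt(job_id, content);
}

fn main() {
    let mut db = Db { prompts: HashMap::new() };
    save_response(&mut db, "job-1".into(), "user message".into());
    assert_eq!(db.prompts["job-1"].len(), 1);
}
```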