Skip to content

Commit

Permalink
Merge pull request #647 from dcSpark/nico/extend_config
Browse files Browse the repository at this point in the history
extend config with tools and system prompt
  • Loading branch information
nicarq authored Nov 11, 2024
2 parents 3263d62 + 04552db commit 8014418
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,14 @@ impl GenericInferenceChain {
);
let mut tools = vec![];
let stream = job_config.as_ref().and_then(|config| config.stream);
let tools_allowed = job_config.as_ref().and_then(|config| config.use_tools).unwrap_or(true);
let use_tools = ModelCapabilitiesManager::has_tool_capabilities_for_provider_or_agent(
llm_provider.clone(),
db.clone(),
stream,
);

if use_tools {
if use_tools && tools_allowed {
// If the llm_provider is an Agent, retrieve tools directly from the Agent struct
if let ProviderOrAgent::Agent(agent) = &llm_provider {
for tool_name in &agent.tools {
Expand Down Expand Up @@ -251,9 +252,17 @@ impl GenericInferenceChain {
}
});

let custom_system_prompt = job_config.and_then(|config| config.custom_system_prompt.clone()).or_else(|| {
if let ProviderOrAgent::Agent(agent) = &llm_provider {
agent.config.as_ref().and_then(|config| config.custom_system_prompt.clone())
} else {
None
}
});

let mut filled_prompt = JobPromptGenerator::generic_inference_prompt(
custom_system_prompt,
custom_prompt,
None, // TODO: connect later on
user_message.clone(),
image_files.clone(),
ret_nodes.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,22 @@ impl JobPromptGenerator {

// Add the user question and the preference prompt for the answer
if !user_message.is_empty() {
let user_prompt = custom_user_prompt.unwrap_or_default();
let content = if user_prompt.is_empty() {
user_message.clone()
} else {
format!("{}\n {}", user_message, user_prompt)
};
let mut content = user_message.clone();

// If a custom user prompt is provided, use it as a template
if let Some(template) = custom_user_prompt {
let mut template_content = template.clone();

// Insert user_message into the template
if !template_content.contains("{{user_message}}") {
template_content.push_str(&format!("\n{}", user_message));
} else {
template_content = template_content.replace("{{user_message}}", &user_message);
}

content = template_content;
}

prompt.add_omni(content, image_files, SubPromptType::UserLastMessage, 100);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,14 @@ impl SheetUIInferenceChain {
let job_config = full_job.config();
let mut tools = vec![];
let stream = job_config.as_ref().and_then(|config| config.stream);
let tools_allowed = job_config.as_ref().and_then(|config| config.use_tools).unwrap_or(true);
let use_tools = ModelCapabilitiesManager::has_tool_capabilities_for_provider_or_agent(
llm_provider.clone(),
db.clone(),
stream,
);

if use_tools {
if use_tools && tools_allowed {
tools.extend(SheetRustFunctions::sheet_rust_fn());

if let Some(tool_router) = &tool_router {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ impl Node {
match db.get_job_with_options(&job_id, false, false) {
Ok(job) => {
let config = job.config().cloned().unwrap_or_else(|| JobConfig {
custom_system_prompt: None,
custom_prompt: None,
temperature: None,
seed: None,
Expand All @@ -717,6 +718,7 @@ impl Node {
stream: None,
max_tokens: None,
other_model_params: None,
use_tools: None,
});
let _ = res.send(Ok(config)).await;
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ use utoipa::ToSchema;

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, ToSchema)]
pub struct JobConfig {
    /// Optional system-prompt override for the job; when `None`, callers fall
    /// back to the agent's configured system prompt (see `merge`).
    pub custom_system_prompt: Option<String>,
    /// Optional user-prompt text/template; the prompt generator substitutes a
    /// `{{user_message}}` placeholder when present, otherwise appends the
    /// user message after the template.
    pub custom_prompt: Option<String>,
    /// Optional sampling temperature override — presumably forwarded to the
    /// model provider; verify against the inference call site.
    pub temperature: Option<f64>,
    /// Optional cap on the number of tokens the model may generate.
    pub max_tokens: Option<u64>,
    /// Optional sampling seed, when supported by the provider.
    pub seed: Option<u64>,
    /// Optional top-k sampling override.
    pub top_k: Option<u64>,
    /// Optional top-p (nucleus) sampling override.
    pub top_p: Option<f64>,
    /// Whether responses should be streamed from the provider.
    pub stream: Option<bool>,
    /// Provider-specific parameters forwarded opaquely as JSON.
    pub other_model_params: Option<Value>,
    /// Whether tool use is permitted for this job; `None` is treated as
    /// enabled (`unwrap_or(true)`) at the inference chains.
    pub use_tools: Option<bool>,
    // TODO: add ctx_...
}

Expand All @@ -21,13 +22,15 @@ impl JobConfig {
pub fn merge(&self, other: &JobConfig) -> JobConfig {
JobConfig {
// Prefer `self` (provided config) over `other` (agent's config)
custom_system_prompt: self.custom_system_prompt.clone().or_else(|| other.custom_system_prompt.clone()),
custom_prompt: self.custom_prompt.clone().or_else(|| other.custom_prompt.clone()),
temperature: self.temperature.or(other.temperature),
max_tokens: self.max_tokens.or(other.max_tokens),
seed: self.seed.or(other.seed),
top_k: self.top_k.or(other.top_k),
top_p: self.top_p.or(other.top_p),
stream: self.stream.or(other.stream),
use_tools: self.use_tools.or(other.use_tools),
other_model_params: self
.other_model_params
.clone()
Expand All @@ -38,6 +41,7 @@ impl JobConfig {
/// Creates an empty JobConfig with all fields set to None.
pub fn empty() -> JobConfig {
JobConfig {
custom_system_prompt: None,
custom_prompt: None,
temperature: None,
max_tokens: None,
Expand All @@ -46,6 +50,7 @@ impl JobConfig {
top_p: None,
stream: None,
other_model_params: None,
use_tools: None,
}
}
}

0 comments on commit 8014418

Please sign in to comment.