From b9ec30ca6a4896d0dafb3483c1d989973cb93399 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 10 Jan 2025 21:47:47 -0800 Subject: [PATCH] update --- ee/tabby-webserver/src/service/answer.rs | 53 ++++++++++--------- .../src/service/answer/prompt_tools.rs | 22 ++++++++ 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/ee/tabby-webserver/src/service/answer.rs b/ee/tabby-webserver/src/service/answer.rs index bdd5224cf1e6..eae12ad0845e 100644 --- a/ee/tabby-webserver/src/service/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -21,8 +21,7 @@ use async_openai_alt::{ use async_stream::stream; use futures::stream::BoxStream; use prompt_tools::{ - pipeline_decide_need_codebase_commit_history, pipeline_decide_need_codebase_directory_tree, - pipeline_related_questions, + pipeline_decide_need_codebase_commit_history, pipeline_decide_need_codebase_directory_tree, pipeline_decide_need_codebase_snippet, pipeline_related_questions }; use tabby_common::{ api::{ @@ -32,7 +31,7 @@ use tabby_common::{ }, structured_doc::{DocSearch, DocSearchDocument, DocSearchError, DocSearchHit}, }, - config::AnswerConfig, + config::AnswerConfig, index::code, }; use tabby_inference::ChatCompletionStream; use tabby_schema::{ @@ -115,15 +114,6 @@ impl AnswerService { // 1. Collect relevant code if needed. if let Some(code_query) = options.code_query.as_ref() { if let Some(repository) = self.find_repository(&context_info_helper, code_query, policy.clone()).await { - let hits = self.collect_relevant_code( - &repository, - &context_info_helper, - code_query, - &self.config.code_search_params, - options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref()), - ).await; - attachment.code = hits.iter().map(|x| x.doc.clone().into()).collect::>(); - let need_codebase_directory_tree = pipeline_decide_need_codebase_directory_tree(self.chat.clone(), &query.content).await.unwrap_or_default(); if need_codebase_directory_tree { // List at most 300 files in the repository. @@ -148,12 +138,25 @@ impl AnswerService { } } - if !hits.is_empty() { - let hits = hits.into_iter().map(|x| x.into()).collect::>(); - yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsCode( - ThreadAssistantMessageAttachmentsCode { code_source_id: repository.source_id, hits } - )); + let need_codebase_snippet = pipeline_decide_need_codebase_snippet(self.chat.clone(), &query.content).await.unwrap_or_default(); + if need_codebase_snippet { + let hits = self.collect_relevant_code( + &repository, + &context_info_helper, + code_query, + &self.config.code_search_params, + options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref()), + ).await; + attachment.code = hits.iter().map(|x| x.doc.clone().into()).collect::>(); + + if !hits.is_empty() { + let hits = hits.into_iter().map(|x| x.into()).collect::>(); + yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsCode( + ThreadAssistantMessageAttachmentsCode { code_source_id: repository.source_id, hits } + )); + } } + }; }; @@ -476,7 +479,6 @@ fn convert_messages_to_chat_completion_request( user_attachment_input, code_file_list, ); - debug!(?user_prompt); output.push(ChatCompletionRequestMessage::User( ChatCompletionRequestUserMessage { @@ -500,11 +502,12 @@ fn build_user_prompt( .unwrap_or(true) && assistant_attachment.code.is_empty() && assistant_attachment.doc.is_empty() + && code_file_list.is_none() { return user_input.to_owned(); } - let context = { + let maybe_file_snippet_context= { let snippets: Vec = assistant_attachment .doc .iter() @@ -536,14 +539,18 @@ fn build_user_prompt( .map(|(i, snippet)| format!("[[citation:{}]]\n{}", i + 1, *snippet)) .collect(); - citations.join("\n\n") + if !citations.is_empty() { + format!("Here are set of contexts:\n\n{}\n\n", citations.join("\n\n")) + } else { + String::default() + } }; let maybe_file_list_context = code_file_list .filter(|file_list| !file_list.is_empty()) .map(|file_list| { format!( - "Here is the list of files in the workspace available for reference:\n\n{}", + "Here is the list of files in the workspace available for reference:\n\n{}\n\n", file_list.join("\n") ) }) @@ -556,9 +563,7 @@ Your answer must be correct, accurate and written by an expert using an unbiased Please cite the contexts with the reference numbers, in the format [[citation:x]]. If a sentence comes from multiple contexts, please list all applicable citations, like [[citation:3]][[citation:5]]. Other than code and specific names and citations, your answer must be written in the same language as the question. -{maybe_file_list_context}Here are the set of contexts: - -{context} +{maybe_file_list_context}{maybe_file_snippet_context} Remember, don't blindly repeat the contexts verbatim. When possible, give code snippet to demonstrate the answer. And here is the user question: diff --git a/ee/tabby-webserver/src/service/answer/prompt_tools.rs b/ee/tabby-webserver/src/service/answer/prompt_tools.rs index 8b660974ac69..38b64e37f8cd 100644 --- a/ee/tabby-webserver/src/service/answer/prompt_tools.rs +++ b/ee/tabby-webserver/src/service/answer/prompt_tools.rs @@ -76,6 +76,28 @@ fn detect_yes(content: &str) -> bool { content.to_lowercase().contains("yes") } +/// Decide whether the question requires knowledge from codebase content. +pub async fn pipeline_decide_need_codebase_snippet( + chat: Arc, + question: &str, +) -> Result { + let prompt = format!( + r#"You are a helpful assistant that helps the user to decide whether the question requires content / snippets from codebase. If it requires, return "Yes", otherwise return "No". + +Here's a few examples: +"How to implement an embedding api?" -> Yes +"Which file contains http api definitions" -> Yes +"How many python files is in the codebase?" -> No + +Here's the original question: +{question} +"# + ); + + let content = request_llm(chat, &prompt).await?; + Ok(detect_yes(&content)) +} + /// Decide whether the question requires knowledge from codebase directory structure. pub async fn pipeline_decide_need_codebase_directory_tree( chat: Arc,