diff --git a/ee/tabby-webserver/src/service/answer.rs b/ee/tabby-webserver/src/service/answer.rs index 5992030a5405..ea81da9550ca 100644 --- a/ee/tabby-webserver/src/service/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -1,4 +1,10 @@ -use std::sync::Arc; +use std::{ + collections::HashMap, + fs::File, + io::{BufRead, BufReader, Read}, + path::{Path, PathBuf}, + sync::Arc, +}; use anyhow::anyhow; use async_openai::{ @@ -12,7 +18,10 @@ use async_stream::stream; use futures::stream::BoxStream; use tabby_common::{ api::{ - code::{CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchParams, CodeSearchQuery}, + code::{ + CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchParams, CodeSearchQuery, + CodeSearchScores, + }, doc::{DocSearch, DocSearchError, DocSearchHit}, }, config::AnswerConfig, @@ -21,6 +30,7 @@ use tabby_inference::ChatCompletionStream; use tabby_schema::{ context::{ContextInfoHelper, ContextService}, policy::AccessPolicy, + repository::{Repository, RepositoryService}, thread::{ self, CodeQueryInput, CodeSearchParamsOverrideInput, DocQueryInput, MessageAttachment, ThreadAssistantMessageAttachmentsCode, ThreadAssistantMessageAttachmentsDoc, @@ -39,6 +49,7 @@ pub struct AnswerService { doc: Arc, context: Arc, serper: Option>, + repository: Arc, } impl AnswerService { @@ -49,6 +60,7 @@ impl AnswerService { doc: Arc, context: Arc, serper: Option>, + repository: Arc, ) -> Self { Self { config: config.clone(), @@ -57,6 +69,7 @@ impl AnswerService { doc, context, serper, + repository, } } @@ -92,7 +105,8 @@ impl AnswerService { &context_info_helper, code_query, &self.config.code_search_params, - options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref()) + options.debug_options.as_ref().and_then(|x| x.code_search_params_override.as_ref()), + policy.clone(), ).await; attachment.code = hits.iter().map(|x| x.doc.clone().into()).collect::>(); @@ -112,6 +126,8 @@ impl AnswerService { .map(|x| x.doc.clone().into()) .collect::>(); + debug!("doc content: {:?}", doc_query.content); + if !attachment.doc.is_empty() { let hits = hits.into_iter().map(|x| x.into()).collect::>(); yield Ok(ThreadRunItem::ThreadAssistantMessageAttachmentsDoc( @@ -184,10 +200,10 @@ impl AnswerService { input: &CodeQueryInput, params: &CodeSearchParams, override_params: Option<&CodeSearchParamsOverrideInput>, + policy: AccessPolicy, ) -> Vec { let source_id: Option<&str> = { if let Some(source_id) = &input.source_id { - // If source_id doesn't exist, return empty result. if helper.can_access_source_id(source_id) { Some(source_id.as_str()) } else { @@ -204,6 +220,11 @@ impl AnswerService { return vec![]; }; + let repo = match self.repository.repository_list(Some(&policy)).await { + Ok(repos) => repos.into_iter().find(|x| x.source_id == source_id), + Err(_) => return vec![], + }; + let query = CodeSearchQuery::new( input.filepath.clone(), input.language.clone(), @@ -212,12 +233,12 @@ impl AnswerService { ); let mut params = params.clone(); - override_params - .as_ref() - .inspect(|x| x.override_params(&mut params)); + if let Some(override_params) = override_params { + override_params.override_params(&mut params); + } match self.code.search_in_language(query, params).await { - Ok(docs) => docs.hits, + Ok(docs) => merge_code_snippets(repo, docs.hits).await, Err(err) => { if let CodeSearchError::NotReady = err { debug!("Code search is not ready yet"); @@ -358,8 +379,9 @@ pub fn create( doc: Arc, context: Arc, serper: Option>, + repository: Arc, ) -> AnswerService { - AnswerService::new(config, chat, code, doc, context, serper) + AnswerService::new(config, chat, code, doc, context, serper, repository) } fn convert_messages_to_chat_completion_request( @@ -490,6 +512,95 @@ Remember, don't blindly repeat the contexts verbatim. When possible, give code s ) } +/// Combine code snippets from search results rather than utilizing multiple hits: Presently, there is only one rule: if the number of lines of code (LoC) is less than 200, and there are multiple hits (number of hits > 1), include the entire file. +pub async fn merge_code_snippets( + repository: Option, + hits: Vec, +) -> Vec { + let Some(repository) = repository else { + return hits; + }; + + // group hits by filepath + let mut file_hits: HashMap> = HashMap::new(); + for hit in hits.clone().into_iter() { + let key = format!("{}-{}", repository.source_id, hit.doc.filepath); + file_hits.entry(key).or_default().push(hit); + } + + let mut result = Vec::with_capacity(file_hits.len()); + + for (_, file_hits) in file_hits { + if file_hits.len() > 1 { + // construct the full path to the file + let path: PathBuf = repository.dir.join(&file_hits[0].doc.filepath); + let file_content = match read_file_content(&path) { + Some(lines) => lines, + None => { + //cannot read the file, just extend the hits + result.extend(file_hits); + continue; + } + }; + + if !file_content.is_empty() { + let mut insert_hit = file_hits[0].clone(); + insert_hit.scores = + file_hits + .iter() + .fold(CodeSearchScores::default(), |mut acc, hit| { + acc.bm25 += hit.scores.bm25; + acc.embedding += hit.scores.embedding; + acc.rrf += hit.scores.rrf; + acc + }); + // average the scores + let num_files = file_hits.len() as f32; + insert_hit.scores.bm25 /= num_files; + insert_hit.scores.embedding /= num_files; + insert_hit.scores.rrf /= num_files; + insert_hit.doc.body = file_content; + result.push(insert_hit); + } + } else { + result.extend(file_hits); + } + } + result +} + +/// Read file content and return raw file content string, it will return nothing if the file is over 200 lines +pub fn read_file_content(path: &Path) -> Option { + if count_lines(path).ok()? > 200 { + return None; + } + + let mut file = match File::open(path) { + Ok(file) => file, + Err(e) => { + warn!("Error opening file {}: {}", path.display(), e); + return None; + } + }; + let mut content = String::new(); + match file.read_to_string(&mut content) { + Ok(_) => Some(content), + Err(e) => { + warn!("Error reading file {}: {}", path.display(), e); + None + } + } +} + +fn count_lines(path: &Path) -> std::io::Result { + let mut count = 0; + for line in BufReader::new(File::open(path)?).lines() { + line?; + count += 1; + } + Ok(count) +} + #[cfg(test)] pub mod testutils; @@ -501,7 +612,9 @@ mod tests { use juniper::ID; use tabby_common::{ api::{ - code::{CodeSearch, CodeSearchParams}, + code::{ + CodeSearch, CodeSearchDocument, CodeSearchHit, CodeSearchParams, CodeSearchScores, + }, doc::DocSearch, }, config::AnswerConfig, @@ -517,9 +630,10 @@ mod tests { }; use crate::answer::{ + merge_code_snippets, testutils::{ - FakeChatCompletionStream, FakeCodeSearch, FakeCodeSearchFail, - FakeCodeSearchFailNotReady, FakeContextService, FakeDocSearch, + make_policy, make_repository_service, FakeChatCompletionStream, FakeCodeSearch, + FakeCodeSearchFail, FakeCodeSearchFailNotReady, FakeContextService, FakeDocSearch, }, trim_bullet, AnswerService, }; @@ -694,6 +808,10 @@ mod tests { let context: Arc = Arc::new(FakeContextService); let mut serper = Some(Box::new(FakeDocSearch) as Box); let config = make_answer_config(); + + let db = DbConn::new_in_memory().await.unwrap(); + let repo = make_repository_service(db).await.unwrap(); + let mut service = AnswerService::new( &config, chat.clone(), @@ -701,6 +819,7 @@ mod tests { doc.clone(), context.clone(), serper, + repo.clone(), ); let code_query_input_could_access = make_code_query_input(Some(TEST_SOURCE_ID), Some(TEST_GIT_URL)); @@ -708,12 +827,15 @@ mod tests { let context_info_helper: ContextInfoHelper = make_context_info_helper(); debug_assert!(context_info_helper.can_access_source_id("source-1")); + let policy = make_policy().await; + service .collect_relevant_code( &context_info_helper, &code_query_input_could_access, &code_search_params, None, + policy.clone(), ) .await; @@ -724,6 +846,7 @@ mod tests { &code_query_input_not_access, &code_search_params, None, + policy.clone(), ) .await; @@ -734,6 +857,7 @@ mod tests { &code_query_input_with_only_git, &code_search_params, None, + policy.clone(), ) .await; @@ -744,6 +868,7 @@ mod tests { &code_query_input_with_only_git, &code_search_params, None, + policy.clone(), ) .await; @@ -757,6 +882,7 @@ mod tests { doc.clone(), context.clone(), serper, + repo.clone(), ); let code_fail = Arc::new(FakeCodeSearchFail); @@ -769,7 +895,8 @@ mod tests { doc.clone(), context.clone(), serper, - ) + repo.clone(), + ); } #[tokio::test] @@ -780,6 +907,9 @@ mod tests { let context: Arc = Arc::new(FakeContextService); let serper = Some(Box::new(FakeDocSearch) as Box); let config = make_answer_config(); + let db = DbConn::new_in_memory().await.unwrap(); + let repo = make_repository_service(db).await.unwrap(); + let service = AnswerService::new( &config, chat.clone(), @@ -787,6 +917,7 @@ mod tests { doc.clone(), context.clone(), serper, + repo, ); let attachment = MessageAttachment { @@ -831,6 +962,9 @@ mod tests { let context: Arc = Arc::new(FakeContextService); let serper = Some(Box::new(FakeDocSearch) as Box); let config = make_answer_config(); + let db = DbConn::new_in_memory().await.unwrap(); + let repo = make_repository_service(db).await.unwrap(); + let service = AnswerService::new( &config, chat.clone(), @@ -838,6 +972,7 @@ mod tests { doc.clone(), context.clone(), serper, + repo, ); let context_info_helper = make_context_info_helper(); @@ -894,9 +1029,15 @@ mod tests { let context: Arc = Arc::new(FakeContextService); let serper = Some(Box::new(FakeDocSearch) as Box); - let config = make_answer_config(); + let config = AnswerConfig { + code_search_params: make_code_search_params(), + presence_penalty: 0.1, + system_prompt: AnswerConfig::default_system_prompt(), + }; + let db = DbConn::new_in_memory().await.unwrap(); + let repo = make_repository_service(db).await.unwrap(); let service = Arc::new(AnswerService::new( - &config, chat, code, doc, context, serper, + &config, chat, code, doc, context, serper, repo, )); let db = DbConn::new_in_memory().await.unwrap(); @@ -944,4 +1085,62 @@ mod tests { "Expected 4 items in the result stream" ); } + #[tokio::test] + async fn test_merge_code_snippets() { + let db = DbConn::new_in_memory().await.unwrap(); + let repo_service = make_repository_service(db).await.unwrap(); + + let git_url = "https://github.com/test/repo.git".to_string(); + let id = repo_service + .git() + .create("repo".to_string(), git_url.clone()) + .await + .unwrap(); + + let policy = make_policy().await; + let repo = repo_service + .repository_list(Some(&policy)) + .await + .unwrap() + .pop(); + + let hits = vec![ + CodeSearchHit { + doc: CodeSearchDocument { + file_id: "file1".to_string(), + chunk_id: "chunk1".to_string(), + body: "fn test1() {}\nfn test2() {}".to_string(), + filepath: "test.rs".to_string(), + git_url: "https://github.com/test/repo.git".to_string(), + language: "rust".to_string(), + start_line: 1, + }, + scores: CodeSearchScores { + bm25: 0.5, + embedding: 0.7, + rrf: 0.3, + }, + }, + CodeSearchHit { + doc: CodeSearchDocument { + file_id: "file1".to_string(), + chunk_id: "chunk2".to_string(), + body: "fn test3() {}\nfn test4() {}".to_string(), + filepath: "test.rs".to_string(), + git_url: "https://github.com/test/repo.git".to_string(), + language: "rust".to_string(), + start_line: 3, + }, + scores: CodeSearchScores { + bm25: 0.6, + embedding: 0.8, + rrf: 0.4, + }, + }, + ]; + + let result = merge_code_snippets(repo, hits).await; + + assert_eq!(result.len(), 2); + } } diff --git a/ee/tabby-webserver/src/service/answer/testutils/mod.rs b/ee/tabby-webserver/src/service/answer/testutils/mod.rs index 32577d416555..287838c62691 100644 --- a/ee/tabby-webserver/src/service/answer/testutils/mod.rs +++ b/ee/tabby-webserver/src/service/answer/testutils/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_openai::{ error::OpenAIError, types::{ @@ -7,17 +9,24 @@ use async_openai::{ }, }; use axum::async_trait; +use juniper::ID; use tabby_common::api::{ code::{CodeSearch, CodeSearchError, CodeSearchParams, CodeSearchQuery, CodeSearchResponse}, doc::{DocSearch, DocSearchDocument, DocSearchError, DocSearchHit, DocSearchResponse}, }; +use tabby_db::DbConn; use tabby_inference::ChatCompletionStream; use tabby_schema::{ context::{ContextInfo, ContextService}, + integration::IntegrationService, + job::JobService, policy::AccessPolicy, + repository::RepositoryService, Result, }; +use crate::{integration, job, repository}; + pub struct FakeChatCompletionStream; #[async_trait] impl ChatCompletionStream for FakeChatCompletionStream { @@ -200,3 +209,20 @@ impl ContextService for FakeContextService { Ok(ContextInfo { sources: vec![] }) } } +pub async fn make_repository_service(db: DbConn) -> Result> { + let job_service: Arc = Arc::new(job::create(db.clone()).await); + let integration_service: Arc = + Arc::new(integration::create(db.clone(), job_service.clone())); + Ok(repository::create( + db.clone(), + integration_service.clone(), + job_service.clone(), + )) +} +pub async fn make_policy() -> AccessPolicy { + AccessPolicy::new( + DbConn::new_in_memory().await.unwrap(), + &ID::from("nihao".to_string()), + false, + ) +} diff --git a/ee/tabby-webserver/src/service/thread.rs b/ee/tabby-webserver/src/service/thread.rs index 03e8186925c7..6bc6d09c5256 100644 --- a/ee/tabby-webserver/src/service/thread.rs +++ b/ee/tabby-webserver/src/service/thread.rs @@ -287,7 +287,8 @@ mod tests { use super::*; use crate::answer::testutils::{ - FakeChatCompletionStream, FakeCodeSearch, FakeContextService, FakeDocSearch, + make_repository_service, FakeChatCompletionStream, FakeCodeSearch, FakeContextService, + FakeDocSearch, }; #[tokio::test] @@ -510,6 +511,7 @@ mod tests { let context: Arc = Arc::new(FakeContextService); let serper = Some(Box::new(FakeDocSearch) as Box); let config = make_answer_config(); + let repo = make_repository_service(db.clone()).await.unwrap(); let answer_service = Arc::new(crate::answer::create( &config, chat.clone(), @@ -517,6 +519,7 @@ mod tests { doc.clone(), context.clone(), serper, + repo, )); let service = create(db.clone(), Some(answer_service)); diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/webserver.rs index 3dc5f2f24689..082b90f7260c 100644 --- a/ee/tabby-webserver/src/webserver.rs +++ b/ee/tabby-webserver/src/webserver.rs @@ -95,6 +95,7 @@ impl Webserver { docsearch.clone(), context.clone(), serper, + repository.clone(), )) });