diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index 6efebef06bf1..d78f4b861ee3 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -4,6 +4,7 @@ use lazy_static::lazy_static; use regex::Regex; use strfmt::strfmt; use tabby_common::languages::get_language; +use tantivy::{query::BooleanQuery, query_grammar::Occur}; use textdistance::Algorithm; use tracing::warn; @@ -106,17 +107,16 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec { let mut ret = Vec::new(); - let mut tokens = Box::new(tokenize_text(text)); + let mut tokens = tokenize_text(text); - let sanitized_text = tokens.join(" "); - let sanitized_text = sanitized_text.trim(); - if sanitized_text.is_empty() { - return ret; - } - - let query_text = format!("language:{} AND ({})", language, sanitized_text); + let language_query = index_server.language_query(language).unwrap(); + let body_query = index_server.body_query(&tokens).unwrap(); + let query = BooleanQuery::new(vec![ + (Occur::Must, language_query), + (Occur::Must, body_query), + ]); - let serp = match index_server.search(&query_text, MAX_SNIPPETS_TO_FETCH, 0) { + let serp = match index_server.search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0) { Ok(serp) => serp, Err(IndexServerError::NotReady) => { // Ignore. @@ -154,7 +154,7 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V // Prepend body tokens and update tokens, so future similarity calculation will consider // added snippets. body_tokens.append(&mut tokens); - *tokens = body_tokens; + tokens.append(&mut body_tokens); count_characters += body.len(); ret.push(Snippet { @@ -172,11 +172,7 @@ lazy_static! { } fn tokenize_text(text: &str) -> Vec { - TOKENIZER - .split(text) - .filter(|s| *s != "AND" && *s != "OR" && *s != "NOT" && !s.is_empty()) - .map(|x| x.to_owned()) - .collect() + TOKENIZER.split(text).map(|x| x.to_owned()).collect() } #[cfg(test)] diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/serve/search.rs index f8249a20d024..2be9e1579468 100644 --- a/crates/tabby/src/serve/search.rs +++ b/crates/tabby/src/serve/search.rs @@ -10,13 +10,13 @@ use serde::{Deserialize, Serialize}; use tabby_common::{index::IndexExt, path}; use tantivy::{ collector::{Count, TopDocs}, - query::QueryParser, - schema::Field, - DocAddress, Document, Index, IndexReader, + query::{QueryParser, TermQuery, TermSetQuery}, + schema::{Field, IndexRecordOption}, + DocAddress, Document, Index, IndexReader, Term, }; use thiserror::Error; use tokio::{sync::OnceCell, task, time::sleep}; -use tracing::{debug, instrument, log::info}; +use tracing::{debug, instrument, log::info, warn}; use utoipa::{IntoParams, ToSchema}; #[derive(Deserialize, IntoParams)] @@ -70,15 +70,18 @@ pub async fn search( State(state): State>, query: Query, ) -> Result, StatusCode> { - let Ok(serp) = state.search( + match state.search( &query.q, query.limit.unwrap_or(20), query.offset.unwrap_or(0), - ) else { - return Err(StatusCode::NOT_IMPLEMENTED); - }; - - Ok(Json(serp)) + ) { + Ok(serp) => Ok(Json(serp)), + Err(IndexServerError::NotReady) => Err(StatusCode::NOT_IMPLEMENTED), + Err(IndexServerError::TantivyError(err)) => { + warn!("{}", err); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + } } struct IndexServerImpl { @@ -119,17 +122,19 @@ impl IndexServerImpl { } pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result { - let query = self - .query_parser - .parse_query(q) - .expect("Parsing the query failed"); + let query = self.query_parser.parse_query(q)?; + self.search_with_query(&query, limit, offset) + } + + pub fn search_with_query( + &self, + q: &dyn tantivy::query::Query, + limit: usize, + offset: usize, + ) -> tantivy::Result { let searcher = self.reader.searcher(); - let (top_docs, num_hits) = { - searcher.search( - &query, - &(TopDocs::with_limit(limit).and_offset(offset), Count), - )? - }; + let (top_docs, num_hits) = + { searcher.search(q, &(TopDocs::with_limit(limit).and_offset(offset), Count))? }; let hits: Vec = { top_docs .iter() @@ -179,8 +184,15 @@ impl IndexServer { Self {} } - fn get_cell(&self) -> Option<&IndexServerImpl> { - IMPL.get() + fn with_impl(&self, op: F) -> Result + where + F: FnOnce(&IndexServerImpl) -> Result, + { + if let Some(imp) = IMPL.get() { + op(imp) + } else { + Err(IndexServerError::NotReady) + } } async fn worker() -> IndexServerImpl { @@ -199,17 +211,41 @@ impl IndexServer { } } + pub fn language_query(&self, language: &str) -> Result, IndexServerError> { + self.with_impl(|imp| { + Ok(Box::new(TermQuery::new( + Term::from_field_text(imp.field_language, language), + IndexRecordOption::WithFreqsAndPositions, + ))) + }) + } + + pub fn body_query(&self, tokens: &[String]) -> Result, IndexServerError> { + self.with_impl(|imp| { + Ok(Box::new(TermSetQuery::new( + tokens + .iter() + .map(|x| Term::from_field_text(imp.field_body, x)), + ))) + }) + } + pub fn search( &self, q: &str, limit: usize, offset: usize, ) -> Result { - if let Some(imp) = self.get_cell() { - Ok(imp.search(q, limit, offset)?) - } else { - Err(IndexServerError::NotReady) - } + self.with_impl(|imp| Ok(imp.search(q, limit, offset)?)) + } + + pub fn search_with_query( + &self, + q: &dyn tantivy::query::Query, + limit: usize, + offset: usize, + ) -> Result { + self.with_impl(|imp| Ok(imp.search_with_query(q, limit, offset)?)) } } @@ -218,6 +254,6 @@ pub enum IndexServerError { #[error("index not ready")] NotReady, - #[error("underlying tantivy error")] + #[error("{0}")] TantivyError(#[from] tantivy::TantivyError), }