diff --git a/server/src/handlers/chunk_handler.rs b/server/src/handlers/chunk_handler.rs
index 1c087dad60..1258a0b7da 100644
--- a/server/src/handlers/chunk_handler.rs
+++ b/server/src/handlers/chunk_handler.rs
@@ -969,6 +969,8 @@ pub struct SearchChunksReqPayload {
     pub tag_weights: Option<HashMap<String, f32>>,
     /// Set highlight_results to false for a slight latency improvement (1-10ms). If not specified, this defaults to true. This will add `<b><mark>` tags to the chunk_html of the chunks to highlight matching splits and return the highlights on each scored chunk in the response.
     pub highlight_results: Option<bool>,
+    /// Set highlight_strategy to `exactmatch` to highlight exact matches from your query. If not specified, this defaults to `v1`, the original fuzzy highlighting strategy.
+    pub highlight_strategy: Option<HighlightStrategy>,
     /// Set highlight_threshold to a lower or higher value to adjust the sensitivity of the highlights applied to the chunk html. If not specified, this defaults to 0.8. The range is 0.0 to 1.0.
     pub highlight_threshold: Option<f64>,
     /// Set highlight_delimiters to a list of strings to use as delimiters for highlighting. If not specified, this defaults to ["?", ",", ".", "!"]. These are the characters that will be used to split the chunk_html into splits for highlighting.
@@ -1003,6 +1005,7 @@ impl Default for SearchChunksReqPayload {
             use_weights: None,
             tag_weights: None,
             highlight_results: None,
+            highlight_strategy: None,
             highlight_threshold: None,
             highlight_delimiters: None,
             highlight_max_length: None,
@@ -1277,6 +1280,8 @@ pub struct AutocompleteReqPayload {
     pub tag_weights: Option<HashMap<String, f32>>,
     /// Set highlight_results to false for a slight latency improvement (1-10ms). If not specified, this defaults to true. This will add `<b><mark>` tags to the chunk_html of the chunks to highlight matching splits and return the highlights on each scored chunk in the response.
     pub highlight_results: Option<bool>,
+    /// Set highlight_strategy to `exactmatch` to highlight exact matches from your query. If not specified, this defaults to `v1`, the original fuzzy highlighting strategy.
+    pub highlight_strategy: Option<HighlightStrategy>,
     /// Set highlight_threshold to a lower or higher value to adjust the sensitivity of the highlights applied to the chunk html. If not specified, this defaults to 0.8. The range is 0.0 to 1.0.
     pub highlight_threshold: Option<f64>,
     /// Set highlight_delimiters to a list of strings to use as delimiters for highlighting. If not specified, this defaults to ["?", ",", ".", "!"]. These are the characters that will be used to split the chunk_html into splits for highlighting.
@@ -1312,6 +1317,7 @@ impl From<AutocompleteReqPayload> for SearchChunksReqPayload {
             use_weights: autocomplete_data.use_weights,
             tag_weights: autocomplete_data.tag_weights,
             highlight_results: autocomplete_data.highlight_results,
+            highlight_strategy: autocomplete_data.highlight_strategy,
             highlight_threshold: autocomplete_data.highlight_threshold,
             highlight_delimiters: Some(
                 autocomplete_data
@@ -1590,6 +1596,7 @@ impl From for SearchChunksReqPayload {
             use_weights: None,
             tag_weights: None,
             highlight_results: None,
+            highlight_strategy: None,
             highlight_threshold: None,
             highlight_delimiters: None,
             highlight_max_length: None,
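A note on the wire format of the new field: `HighlightStrategy` (defined in `chunk_operator.rs` further down) derives `Serialize`/`Deserialize` with `#[serde(rename_all = "lowercase")]`, so clients send `"exactmatch"` or `"v1"`, not the Rust variant names. A minimal, self-contained round-trip sketch of that behavior:

```rust
use serde::{Deserialize, Serialize};

// Mirrors the enum added in chunk_operator.rs below.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
enum HighlightStrategy {
    ExactMatch,
    V1,
}

fn main() {
    // `rename_all = "lowercase"` lowercases the variant name with no separator.
    assert_eq!(
        serde_json::to_string(&HighlightStrategy::ExactMatch).unwrap(),
        "\"exactmatch\""
    );
    let parsed: HighlightStrategy = serde_json::from_str("\"v1\"").unwrap();
    assert_eq!(parsed, HighlightStrategy::V1);
}
```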
diff --git a/server/src/handlers/group_handler.rs b/server/src/handlers/group_handler.rs
index 214ddf4ffe..3a5d5f0e21 100644
--- a/server/src/handlers/group_handler.rs
+++ b/server/src/handlers/group_handler.rs
@@ -13,7 +13,7 @@ use crate::{
     errors::ServiceError,
     middleware::api_version::APIVersion,
     operators::{
-        chunk_operator::get_metadata_from_tracking_id_query,
+        chunk_operator::{get_metadata_from_tracking_id_query, HighlightStrategy},
         clickhouse_operator::{get_latency_from_header, send_to_clickhouse, ClickHouseEvent},
         group_operator::*,
         qdrant_operator::{
@@ -1389,6 +1389,8 @@ pub struct SearchWithinGroupReqPayload {
     pub tag_weights: Option<HashMap<String, f32>>,
     /// Set highlight_results to false for a slight latency improvement (1-10ms). If not specified, this defaults to true. This will add `<b><mark>` tags to the chunk_html of the chunks to highlight matching splits and return the highlights on each scored chunk in the response.
     pub highlight_results: Option<bool>,
+    /// Set highlight_strategy to `exactmatch` to highlight exact matches from your query. If not specified, this defaults to `v1`, the original fuzzy highlighting strategy.
+    pub highlight_strategy: Option<HighlightStrategy>,
     /// Set highlight_threshold to a lower or higher value to adjust the sensitivity of the highlights applied to the chunk html. If not specified, this defaults to 0.8. The range is 0.0 to 1.0.
     pub highlight_threshold: Option<f64>,
     /// Set highlight_delimiters to a list of strings to use as delimiters for highlighting. If not specified, this defaults to ["?", ",", ".", "!"]. These are the characters that will be used to split the chunk_html into splits for highlighting.
@@ -1425,6 +1427,7 @@ impl From<SearchWithinGroupReqPayload> for SearchChunksReqPayload {
             use_weights: search_within_group_data.use_weights,
             tag_weights: search_within_group_data.tag_weights,
             highlight_results: search_within_group_data.highlight_results,
+            highlight_strategy: search_within_group_data.highlight_strategy,
             highlight_threshold: search_within_group_data.highlight_threshold,
             highlight_delimiters: search_within_group_data.highlight_delimiters,
             highlight_max_length: search_within_group_data.highlight_max_length,
@@ -1613,6 +1616,8 @@ pub struct SearchOverGroupsReqPayload {
     pub filters: Option<ChunkFilter>,
     /// Set highlight_results to false for a slight latency improvement (1-10ms). If not specified, this defaults to true. This will add `<b><mark>` tags to the chunk_html of the chunks to highlight matching splits and return the highlights on each scored chunk in the response.
     pub highlight_results: Option<bool>,
+    /// Set highlight_strategy to `exactmatch` to highlight exact matches from your query. If not specified, this defaults to `v1`, the original fuzzy highlighting strategy.
+    pub highlight_strategy: Option<HighlightStrategy>,
     /// Set highlight_threshold to a lower or higher value to adjust the sensitivity of the highlights applied to the chunk html. If not specified, this defaults to 0.8. The range is 0.0 to 1.0.
     pub highlight_threshold: Option<f64>,
     /// Set highlight_delimiters to a list of strings to use as delimiters for highlighting. If not specified, this defaults to ["?", ",", ".", "!"]. These are the characters that will be used to split the chunk_html into splits for highlighting.
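All four search payloads (`SearchChunksReqPayload`, `AutocompleteReqPayload`, `SearchWithinGroupReqPayload`, `SearchOverGroupsReqPayload`) gain the same optional field, so opting in is a one-line addition to an existing request. A hypothetical request body is sketched below — every field other than `highlight_strategy` is a pre-existing payload field, and the endpoint path is an assumption for illustration:

```rust
use serde_json::json;

// Hypothetical body for a search request such as POST /api/chunk/search.
fn example_body() -> serde_json::Value {
    json!({
        "query": "sparse vector search",
        "search_type": "hybrid",
        "highlight_results": true,
        // New field: omit it (or send "v1") to keep the old fuzzy behavior.
        "highlight_strategy": "exactmatch"
    })
}
```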
diff --git a/server/src/lib.rs b/server/src/lib.rs
index 4894610533..64d260032e 100644
--- a/server/src/lib.rs
+++ b/server/src/lib.rs
@@ -333,6 +333,7 @@ impl Modify for SecurityAddon {
         operators::analytics_operator::CTRRecommendationsWithClicksResponse,
         operators::analytics_operator::CTRRecommendationsWithoutClicksResponse,
         handlers::analytics_handler::CTRDataRequestBody,
+        operators::chunk_operator::HighlightStrategy,
         handlers::stripe_handler::CreateSetupCheckoutSessionResPayload,
         data::models::DateRange,
         data::models::FieldCondition,
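One reviewer note on the `chunk_operator.rs` changes below: the stop-word list is embedded at compile time with `include_str!`, whose path resolves relative to the containing source file — from `server/src/operators/chunk_operator.rs`, `"../stop-words.txt"` points at the `server/src/stop-words.txt` file added at the end of this diff. A missing or relocated list therefore fails the build rather than failing at runtime. The loader, as it appears below:

```rust
// "../stop-words.txt" resolves to server/src/stop-words.txt; the file
// contents are baked into the binary at compile time.
fn get_stop_words() -> Vec<String> {
    include_str!("../stop-words.txt")
        .lines()
        .map(|line| line.to_string())
        .collect()
}
```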
diff --git a/server/src/operators/chunk_operator.rs b/server/src/operators/chunk_operator.rs
index 4606a4d618..29d5251be1 100644
--- a/server/src/operators/chunk_operator.rs
+++ b/server/src/operators/chunk_operator.rs
@@ -26,7 +26,9 @@ use diesel::upsert::excluded;
 use diesel_async::scoped_futures::ScopedFutureExt;
 use diesel_async::{AsyncConnection, RunQueryDsl};
 use itertools::Itertools;
+use serde::{Deserialize, Serialize};
 use simsearch::{SearchOptions, SimSearch};
+use utoipa::ToSchema;
 
 #[tracing::instrument(skip(pool))]
 pub async fn get_chunk_metadatas_from_point_ids(
@@ -1338,6 +1340,320 @@ pub fn get_slice_from_vec_string(vec: Vec<String>, index: usize) -> Result<String, ServiceError> {
+pub fn get_stop_words() -> Vec<String> {
+    include_str!("../stop-words.txt")
+        .lines()
+        .map(|x| x.to_string())
+        .collect()
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
+#[serde(rename_all = "lowercase")]
+pub enum HighlightStrategy {
+    ExactMatch,
+    V1,
+}
+
+pub fn get_highlights_with_exact_match(
+    input: ChunkMetadata,
+    query: String,
+    threshold: Option<f64>,
+    delimiters: Vec<String>,
+    max_length: Option<u32>,
+    max_num: Option<u32>,
+    window_size: Option<u32>,
+) -> Result<(ChunkMetadata, Vec<String>), ServiceError> {
+    let stop_words: HashSet<String> = HashSet::from_iter(get_stop_words().iter().cloned());
+
+    // remove delimiters from query except ' '
+    let query = query.replace(
+        |c: char| delimiters.contains(&c.to_string()) && c != ' ',
+        "",
+    );
+    let query_parts_split_by_stop_words: Vec<String> = query
+        .split(' ')
+        .collect_vec()
+        .chunk_by(|a, b| {
+            !stop_words.contains(&a.to_lowercase()) && !stop_words.contains(&b.to_lowercase())
+        })
+        .map(|chunk| {
+            chunk
+                .iter()
+                .filter_map(|word| match stop_words.contains(&word.to_lowercase()) {
+                    true => None,
+                    false => Some(word.to_string()),
+                })
+                .collect_vec()
+        })
+        .filter_map(|chunk| match chunk.is_empty() {
+            true => None,
+            false => Some(chunk.join(" ")),
+        })
+        .collect_vec();
+
+    let mut matching_parts: Vec<String> = Vec::new();
+
+    let content = convert_html_to_text(&(input.chunk_html.clone().unwrap_or_default()));
+
+    let (parts_in_content, parts_not_in_content): (Vec<_>, Vec<_>) =
+        query_parts_split_by_stop_words
+            .iter()
+            .cloned()
+            .map(|x| {
+                let lowercase_content = content.to_lowercase();
+                let x_lower = x.to_lowercase();
+                if let Some(found_idx) = lowercase_content.find(&x_lower) {
+                    content[found_idx..(found_idx + x_lower.len())].to_string()
+                } else {
+                    x
+                }
+            })
+            .partition(|x| content.contains(x));
+
+    matching_parts.extend(
+        parts_in_content
+            .into_iter()
+            .take(max_num.unwrap_or(3) as usize),
+    );
+
+    // run old algorithm for parts not in content
+    let search_options = SearchOptions::new().threshold(threshold.unwrap_or(0.8));
+    let mut engine: SimSearch<usize> = SimSearch::new_with(search_options);
+
+    let mut split_content = content
+        .split_inclusive(|c: char| delimiters.contains(&c.to_string()))
+        .flat_map(|x| {
+            x.to_string()
+                .split_inclusive(' ')
+                .map(|x| x.to_string())
+                .collect::<Vec<String>>()
+                .chunks(max_length.unwrap_or(5) as usize)
+                .map(|x| x.join(""))
+                .collect::<Vec<String>>()
+        })
+        .collect::<Vec<String>>();
+
+    split_content.iter().enumerate().for_each(|(i, x)| {
+        engine.insert(i, x);
+    });
+
+    let new_output = input;
+    let results: Vec<usize> = parts_not_in_content
+        .into_iter()
+        .flat_map(|part| engine.search(&part))
+        .collect();
+
+    let mut matched_idxs = vec![];
+    let mut matched_idxs_set = HashSet::new();
+
+    for x in results.iter().take(
+        max_num
+            .unwrap_or(3)
+            .saturating_sub(matching_parts.len() as u32) as usize,
+    ) {
+        matched_idxs_set.insert(*x);
+        matched_idxs.push(*x);
+    }
+    matched_idxs.sort();
+
+    // sort parts by when they occur in content
+    matching_parts.sort_by(|a, b| {
+        let a_idx = content.find(a).unwrap_or(usize::MAX);
+        let b_idx = content.find(b).unwrap_or(usize::MAX);
+        a_idx.cmp(&b_idx)
+    });
+
+    // hack split_content to accept matching_parts as if it were a result.
+    for part in &matching_parts {
+        let start_index_in_content = content.find(part).ok_or(
+            ServiceError::InternalServerError("Part not in content".to_string()),
+        )?;
+
+        let end_index_in_content = (start_index_in_content + part.len()).saturating_sub(1);
+        let mut idx_of_first_split_containing_part = 0;
+        let mut idx_of_last_split_containing_part = 0;
+        let mut curr_length = 0;
+
+        for (i, split_part) in split_content.iter().enumerate() {
+            if start_index_in_content < curr_length + split_part.len() {
+                idx_of_first_split_containing_part = i;
+                break;
+            }
+            curr_length += split_part.len();
+        }
+
+        curr_length = 0;
+        for (i, split_part) in split_content.iter().enumerate() {
+            if end_index_in_content < curr_length + split_part.len() {
+                idx_of_last_split_containing_part = i;
+                break;
+            }
+            curr_length += split_part.len();
+        }
+
+        // part found across multiple splits, merge the splits
+        if idx_of_first_split_containing_part <= idx_of_last_split_containing_part {
+            // remove them from matched_idxs
+            matched_idxs.retain(|x| {
+                *x < idx_of_first_split_containing_part || *x > idx_of_last_split_containing_part
+            });
+            matched_idxs_set = HashSet::from_iter(matched_idxs.clone());
+
+            let merged_split = split_content
+                [idx_of_first_split_containing_part..=idx_of_last_split_containing_part]
+                .join("");
+
+            let (prefix, suffix) =
+                merged_split
+                    .split_once(part)
+                    .ok_or(ServiceError::InternalServerError(
+                        "Part not in split".to_string(),
+                    ))?;
+
+            let mut new_splits = vec![];
+            if !prefix.is_empty() {
+                new_splits.push(prefix.to_string());
+            }
+            new_splits.push(part.to_string());
+            if !suffix.is_empty() {
+                new_splits.push(suffix.to_string());
+            }
+            split_content.splice(
+                idx_of_first_split_containing_part..=idx_of_last_split_containing_part,
+                new_splits.clone().into_iter(),
+            );
+
+            // add new matched_idx
+            if prefix.is_empty() {
+                matched_idxs.push(idx_of_first_split_containing_part);
+                matched_idxs_set.insert(idx_of_first_split_containing_part);
+            } else {
+                matched_idxs.push(idx_of_first_split_containing_part + 1);
+                matched_idxs_set.insert(idx_of_first_split_containing_part + 1);
+            }
+            matched_idxs.sort();
+        }
+    }
+
+    let window = window_size.unwrap_or(0);
+    if window == 0 {
+        let phrases = matched_idxs
+            .iter()
+            .map(|x| split_content.get(*x))
+            .filter_map(|x| x.map(|x| x.to_string()))
+            .collect::<Vec<String>>();
+
+        return Ok((
+            apply_highlights_to_html(new_output, phrases.clone()),
+            phrases.clone(),
+        ));
+    }
+
+    let half_window = std::cmp::max(window / 2, 1);
+    // edge case 1: When the half window size is greater than the length of left or right phrase,
+    // we need to search further to get the correct windowed phrase
+    // edge case 2: When two windowed phrases overlap, we need to trim the first one.
+    let mut windowed_phrases = vec![];
+    // Used to keep track of the number of words used in the phrase
+    let mut used_phrases: HashMap<usize, usize> = HashMap::new();
+    for idx in matched_idxs.clone() {
+        let phrase = get_slice_from_vec_string(split_content.clone(), idx)?;
+        let mut next_phrase = String::new();
+        if idx < split_content.len() - 1 {
+            let mut start = idx + 1;
+            let mut count: usize = 0;
+            while (count as u32) < half_window {
+                if start >= split_content.len() || matched_idxs_set.contains(&start) {
+                    break;
+                }
+                let slice = get_slice_from_vec_string(split_content.clone(), start)?;
+                let candidate_words = slice
+                    .split_inclusive(' ')
+                    .take(half_window as usize - count)
+                    .collect::<Vec<&str>>();
+                used_phrases.insert(
+                    start,
+                    std::cmp::min(candidate_words.len(), half_window as usize - count),
+                );
+                count += candidate_words.len();
+                next_phrase.push_str(&candidate_words.join(""));
+                start += 1;
+            }
+        }
+        let mut prev_phrase = String::new();
+        if idx > 0 {
+            let mut start = idx - 1;
+            let mut count: usize = 0;
+            while (count as u32) < half_window {
+                let slice = get_slice_from_vec_string(split_content.clone(), start)?;
+                let split_words = slice.split_inclusive(' ').collect::<Vec<&str>>();
+                if matched_idxs_set.contains(&start) {
+                    break;
+                }
+                if used_phrases.contains_key(&start)
+                    && split_words.len()
+                        > *used_phrases
+                            .get(&start)
+                            .ok_or(ServiceError::BadRequest("Index out of bounds".to_string()))?
+                {
+                    let remaining_count = half_window as usize - count;
+                    let available_word_len = split_words.len()
+                        - *used_phrases
+                            .get(&start)
+                            .ok_or(ServiceError::BadRequest("Index out of bounds".to_string()))?;
+                    if remaining_count > available_word_len {
+                        count += remaining_count - available_word_len;
+                    } else {
+                        break;
+                    }
+                }
+                if used_phrases.contains_key(&start)
+                    && split_words.len()
+                        <= *used_phrases
+                            .get(&start)
+                            .ok_or(ServiceError::BadRequest("Index out of bounds".to_string()))?
+                {
+                    break;
+                }
+                let candidate_words = split_words
+                    .into_iter()
+                    .rev()
+                    .take(half_window as usize - count)
+                    .collect::<Vec<&str>>();
+                count += candidate_words.len();
+                prev_phrase = format!("{}{}", candidate_words.iter().rev().join(""), prev_phrase);
+                if start == 0 {
+                    break;
+                }
+                start -= 1;
+            }
+        }
+        let highlighted_phrase = phrase.replace(
+            phrase.trim(),
+            &format!("<mark><b>{}</b></mark>", phrase.trim()),
+        );
+        let windowed_phrase = format!("{}{}{}", prev_phrase, highlighted_phrase, next_phrase);
+        windowed_phrases.push(windowed_phrase);
+    }
+    let matched_phrases = matched_idxs
+        .clone()
+        .iter()
+        .filter_map(|x| split_content.get(*x).cloned())
+        .collect::<Vec<String>>();
+    let result_matches = if windowed_phrases.is_empty() {
+        matched_phrases.clone()
+    } else {
+        windowed_phrases.clone()
+    };
+
+    Ok((
+        apply_highlights_to_html(new_output, matched_phrases),
+        result_matches,
+    ))
+}
+
+#[allow(clippy::too_many_arguments)]
 #[tracing::instrument]
 pub fn get_highlights(
     input: ChunkMetadata,
@@ -1369,30 +1685,30 @@ pub fn get_highlights(
         });
 
     let new_output = input;
-    let results = engine.search(&query);
+    let results: Vec<usize> = engine.search(&query);
+
     let mut matched_idxs = vec![];
     let mut matched_idxs_set = HashSet::new();
     for x in results.iter().take(max_num.unwrap_or(3) as usize) {
         matched_idxs_set.insert(*x);
         matched_idxs.push(*x);
     }
+    matched_idxs.sort();
+
     let window = window_size.unwrap_or(0);
     if window == 0 {
+        let phrases = matched_idxs
+            .iter()
+            .map(|x| split_content.get(*x))
+            .filter_map(|x| x.map(|x| x.to_string()))
+            .collect::<Vec<String>>();
         return Ok((
-            apply_highlights_to_html(
-                new_output,
-                matched_idxs
-                    .iter()
-                    .map(|x| split_content.get(*x).unwrap().clone())
-                    .collect(),
-            ),
-            matched_idxs
-                .iter()
-                .map(|x| split_content.get(*x).unwrap().clone())
-                .collect(),
+            apply_highlights_to_html(new_output, phrases.clone()),
+            phrases.clone(),
         ));
     }
+
     let half_window = std::cmp::max(window / 2, 1);
     // edge case 1: When the half window size is greater than the length of left or right phrase,
     // we need to search further to get the correct windowed phrase
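The core of the exact-match strategy is the query pre-processing at the top of `get_highlights_with_exact_match`: stop words act as hard boundaries, and each maximal run of non-stop words becomes one candidate phrase that is searched for verbatim in the chunk text. A self-contained sketch of that splitting step, using a toy stop-word set instead of stop-words.txt (requires Rust 1.77+ for `slice::chunk_by`):

```rust
use std::collections::HashSet;

fn main() {
    let stop_words: HashSet<&str> = ["of", "the", "is"].into_iter().collect();
    let query = "history of the steam engine";

    let candidates: Vec<String> = query
        .split(' ')
        .collect::<Vec<_>>()
        // Keep consecutive words together only while neither is a stop word.
        .chunk_by(|a, b| !stop_words.contains(*a) && !stop_words.contains(*b))
        .map(|run| {
            run.iter()
                .filter(|w| !stop_words.contains(**w))
                .copied()
                .collect::<Vec<_>>()
                .join(" ")
        })
        .filter(|phrase| !phrase.is_empty())
        .collect();

    // "of" and "the" split the query into two exact-match candidates.
    assert_eq!(
        candidates,
        vec!["history".to_string(), "steam engine".to_string()]
    );
}
```

Each candidate is then looked up case-insensitively in the chunk text; hits are highlighted verbatim (preserving the chunk's original casing via the byte-offset slice in the partition step), and only the misses fall through to the pre-existing SimSearch fuzzy path.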
diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs
index 13f53bff11..a6ac6605cd 100644
--- a/server/src/operators/search_operator.rs
+++ b/server/src/operators/search_operator.rs
@@ -1,7 +1,7 @@
 use super::chunk_operator::{
     get_chunk_metadatas_and_collided_chunks_from_point_ids_query,
-    get_content_chunk_from_point_ids_query, get_highlights, get_qdrant_ids_from_chunk_ids_query,
-    get_slim_chunks_from_point_ids_query,
+    get_content_chunk_from_point_ids_query, get_highlights, get_highlights_with_exact_match,
+    get_qdrant_ids_from_chunk_ids_query, get_slim_chunks_from_point_ids_query, HighlightStrategy,
 };
 use super::group_operator::{
     get_group_ids_from_tracking_ids_query, get_groups_from_group_ids_query,
@@ -1145,23 +1145,46 @@ pub async fn retrieve_chunks_for_groups(
         let mut highlights: Option<Vec<String>> = None;
         if data.highlight_results.unwrap_or(true) && !data.slim_chunks.unwrap_or(false) {
-            let (highlighted_chunk, highlighted_snippets) = get_highlights(
-                chunk.clone().into(),
-                data.query.clone(),
-                data.highlight_threshold,
-                data.highlight_delimiters.clone().unwrap_or(vec![
-                    ".".to_string(),
-                    "!".to_string(),
-                    "?".to_string(),
-                    "\n".to_string(),
-                    "\t".to_string(),
-                    ",".to_string(),
-                ]),
-                data.highlight_max_length,
-                data.highlight_max_num,
-                data.highlight_window
-            )
-            .unwrap_or((chunk.clone().into(), vec![]));
+            let (highlighted_chunk, highlighted_snippets) = match data.highlight_strategy {
+                Some(HighlightStrategy::ExactMatch) => {
+                    get_highlights_with_exact_match(
+                        chunk.clone().into(),
+                        data.query.clone(),
+                        data.highlight_threshold,
+                        data.highlight_delimiters.clone().unwrap_or(vec![
+                            ".".to_string(),
+                            "!".to_string(),
+                            "?".to_string(),
+                            "\n".to_string(),
+                            "\t".to_string(),
+                            ",".to_string(),
+                        ]),
+                        data.highlight_max_length,
+                        data.highlight_max_num,
+                        data.highlight_window
+                    )
+                    .unwrap_or((chunk.clone().into(), vec![]))
+                },
+                _ => {
+                    get_highlights(
+                        chunk.clone().into(),
+                        data.query.clone(),
+                        data.highlight_threshold,
+                        data.highlight_delimiters.clone().unwrap_or(vec![
+                            ".".to_string(),
+                            "!".to_string(),
+                            "?".to_string(),
+                            "\n".to_string(),
+                            "\t".to_string(),
+                            ",".to_string(),
+                        ]),
+                        data.highlight_max_length,
+                        data.highlight_max_num,
+                        data.highlight_window,
+                    )
+                    .unwrap_or((chunk.clone().into(), vec![]))
+                },
+            };
 
             highlights = Some(highlighted_snippets);
@@ -1400,23 +1423,42 @@ pub async fn retrieve_chunks_from_point_ids(
         let mut highlights: Option<Vec<String>> = None;
         if data.highlight_results.unwrap_or(true) && !data.slim_chunks.unwrap_or(false) {
-            let (highlighted_chunk, highlighted_snippets) = get_highlights(
-                chunk.clone().into(),
-                data.query.clone(),
-                data.highlight_threshold,
-                data.highlight_delimiters.clone().unwrap_or(vec![
-                    ".".to_string(),
-                    "!".to_string(),
-                    "?".to_string(),
-                    "\n".to_string(),
-                    "\t".to_string(),
-                    ",".to_string(),
-                ]),
-                data.highlight_max_length,
-                data.highlight_max_num,
-                data.highlight_window,
-            )
-            .unwrap_or((chunk.clone().into(), vec![]));
+            let (highlighted_chunk, highlighted_snippets) = match data.highlight_strategy {
+                Some(HighlightStrategy::ExactMatch) => get_highlights_with_exact_match(
+                    chunk.clone().into(),
+                    data.query.clone(),
+                    data.highlight_threshold,
+                    data.highlight_delimiters.clone().unwrap_or(vec![
+                        ".".to_string(),
+                        "!".to_string(),
+                        "?".to_string(),
+                        "\n".to_string(),
+                        "\t".to_string(),
+                        ",".to_string(),
+                    ]),
+                    data.highlight_max_length,
+                    data.highlight_max_num,
+                    data.highlight_window,
+                )
+                .unwrap_or((chunk.clone().into(), vec![])),
+                _ => get_highlights(
+                    chunk.clone().into(),
+                    data.query.clone(),
+                    data.highlight_threshold,
+                    data.highlight_delimiters.clone().unwrap_or(vec![
+                        ".".to_string(),
+                        "!".to_string(),
+                        "?".to_string(),
+                        "\n".to_string(),
+                        "\t".to_string(),
+                        ",".to_string(),
+                    ]),
+                    data.highlight_max_length,
+                    data.highlight_max_num,
+                    data.highlight_window,
+                )
+                .unwrap_or((chunk.clone().into(), vec![])),
+            };
 
             highlights = Some(highlighted_snippets);
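Both call sites dispatch through a catch-all `_` arm, so the new path is strictly opt-in: `None` and `Some(V1)` behave identically and keep the legacy fuzzy highlighter. A small sketch of that semantics (the enum is copied locally so the snippet stands alone):

```rust
// Mirrors the enum from chunk_operator.rs above.
#[derive(Debug, Clone)]
enum HighlightStrategy {
    ExactMatch,
    V1,
}

// Matches the `match data.highlight_strategy` arms: only an explicit
// ExactMatch selects the new highlighter.
fn uses_exact_match(strategy: &Option<HighlightStrategy>) -> bool {
    matches!(strategy, Some(HighlightStrategy::ExactMatch))
}

fn main() {
    assert!(uses_exact_match(&Some(HighlightStrategy::ExactMatch)));
    assert!(!uses_exact_match(&Some(HighlightStrategy::V1)));
    assert!(!uses_exact_match(&None));
}
```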
diff --git a/server/src/stop-words.txt b/server/src/stop-words.txt
new file mode 100644
index 0000000000..5b16b504d6
--- /dev/null
+++ b/server/src/stop-words.txt
@@ -0,0 +1,127 @@
+i
+me
+my
+myself
+we
+our
+ours
+ourselves
+you
+your
+yours
+yourself
+yourselves
+he
+him
+his
+himself
+she
+her
+hers
+herself
+it
+its
+itself
+they
+them
+their
+theirs
+themselves
+what
+which
+who
+whom
+this
+that
+these
+those
+am
+is
+are
+was
+were
+be
+been
+being
+have
+has
+had
+having
+do
+does
+did
+doing
+a
+an
+the
+and
+but
+if
+or
+because
+as
+until
+while
+of
+at
+by
+for
+with
+about
+against
+between
+into
+through
+during
+before
+after
+above
+below
+to
+from
+up
+down
+in
+out
+on
+off
+over
+under
+again
+further
+then
+once
+here
+there
+when
+where
+why
+how
+all
+any
+both
+each
+few
+more
+most
+other
+some
+such
+no
+nor
+not
+only
+own
+same
+so
+than
+too
+very
+s
+t
+can
+will
+just
+don
+should
+now