diff --git a/server/src/handlers/chunk_handler.rs b/server/src/handlers/chunk_handler.rs
index f8c70e0a99..421f2e9ecd 100644
--- a/server/src/handlers/chunk_handler.rs
+++ b/server/src/handlers/chunk_handler.rs
@@ -993,6 +993,7 @@ impl Default for SearchChunksReqPayload {
 #[schema(title = "V1")]
 pub struct SearchChunkQueryResponseBody {
     pub score_chunks: Vec<ScoreChunkDTO>,
+    pub corrected_query: Option<String>,
     pub total_chunk_pages: i64,
 }
 
@@ -1001,6 +1002,7 @@ pub struct SearchChunkQueryResponseBody {
 pub struct SearchResponseBody {
     pub id: uuid::Uuid,
     pub chunks: Vec<ScoreChunk>,
+    pub corrected_query: Option<String>,
     pub total_pages: i64,
 }
 
@@ -1022,6 +1024,7 @@ impl SearchChunkQueryResponseBody {
                 .into_iter()
                 .map(|chunk| chunk.into())
                 .collect(),
+            corrected_query: self.corrected_query,
             total_pages: self.total_chunk_pages,
         }
     }
diff --git a/server/src/handlers/group_handler.rs b/server/src/handlers/group_handler.rs
index e739965699..87d4b4c468 100644
--- a/server/src/handlers/group_handler.rs
+++ b/server/src/handlers/group_handler.rs
@@ -1298,6 +1298,7 @@ pub async fn get_recommended_groups(
 
     let group_qdrant_query_result = SearchOverGroupsQueryResult {
         search_results: recommended_groups_from_qdrant.clone(),
+        corrected_query: None,
         total_chunk_pages: (recommended_groups_from_qdrant.len() as f64 / 10.0).ceil() as i64,
     };
 
@@ -1439,6 +1440,7 @@ impl From<SearchWithinGroupReqPayload> for SearchChunksReqPayload {
 pub struct SearchWithinGroupResults {
     pub bookmarks: Vec<ScoreChunkDTO>,
     pub group: ChunkGroupAndFileId,
+    pub corrected_query: Option<String>,
     pub total_pages: i64,
 }
 
@@ -1447,6 +1449,7 @@ pub struct SearchWithinGroupResults {
 pub struct SearchWithinGroupResponseBody {
     pub id: uuid::Uuid,
     pub chunks: Vec<ScoreChunk>,
+    pub corrected_query: Option<String>,
     pub total_pages: i64,
 }
 
@@ -1468,6 +1471,7 @@ impl SearchWithinGroupResults {
                 .into_iter()
                 .map(|chunk| chunk.into())
                 .collect(),
+            corrected_query: self.corrected_query,
             total_pages: self.total_pages,
         }
     }
diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs
index b146f8fa8a..4578492e3b 100644
--- a/server/src/operators/search_operator.rs
+++ b/server/src/operators/search_operator.rs
@@ -986,6 +986,7 @@ pub async fn get_group_tag_set_filter_condition(
 #[derive(Serialize, Deserialize, Clone, Debug)]
 pub struct SearchOverGroupsQueryResult {
     pub search_results: Vec<GroupSearchResults>,
+    pub corrected_query: Option<String>,
     pub total_chunk_pages: i64,
 }
 
@@ -1043,6 +1044,7 @@ pub async fn retrieve_group_qdrant_points_query(
 
     Ok(SearchOverGroupsQueryResult {
         search_results: point_ids,
+        corrected_query: None,
         total_chunk_pages: pages,
     })
 }
@@ -1105,6 +1107,7 @@ impl From<GroupScoreChunk> for SearchOverGroupsResults {
 #[schema(title = "V1")]
 pub struct DeprecatedSearchOverGroupsResponseBody {
     pub group_chunks: Vec<GroupScoreChunk>,
+    pub corrected_query: Option<String>,
     pub total_chunk_pages: i64,
 }
 
@@ -1117,6 +1120,7 @@ impl DeprecatedSearchOverGroupsResponseBody {
                 .into_iter()
                 .map(|chunk| chunk.into())
                 .collect(),
+            corrected_query: self.corrected_query,
             total_pages: self.total_chunk_pages,
         }
     }
@@ -1127,6 +1131,7 @@ impl DeprecatedSearchOverGroupsResponseBody {
 pub struct SearchOverGroupsResponseBody {
     pub id: uuid::Uuid,
     pub results: Vec<SearchOverGroupsResults>,
+    pub corrected_query: Option<String>,
     pub total_pages: i64,
 }
 
@@ -1289,6 +1294,7 @@ pub async fn retrieve_chunks_for_groups(
 
     Ok(DeprecatedSearchOverGroupsResponseBody {
         group_chunks,
+        corrected_query: None,
         total_chunk_pages: search_over_groups_query_result.total_chunk_pages,
     })
 }
@@ -1530,6 +1536,7 @@ pub async fn retrieve_chunks_from_point_ids(
 
     Ok(SearchChunkQueryResponseBody {
         score_chunks,
+        corrected_query: None,
         total_chunk_pages: search_chunk_query_results.total_chunk_pages,
     })
 }
@@ -1804,20 +1811,23 @@ pub async fn search_chunks_query(
     sentry::configure_scope(|scope| scope.set_span(Some(transaction.clone())));
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
         match parsed_query {
             ParsedQueryTypes::Single(ref mut query) => {
-                query.query =
+                corrected_query =
                     correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
+                query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 data.query = QueryTypes::Single(query.query.clone());
             }
             ParsedQueryTypes::Multi(ref mut queries) => {
                 for (query, _) in queries {
-                    query.query =
+                    corrected_query =
                         correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
                             .await?;
+                    query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 }
             }
         }
@@ -1916,6 +1926,8 @@ pub async fn search_chunks_query(
     timer.add("reranking");
     transaction.finish();
 
+    result_chunks.corrected_query = corrected_query;
+
     Ok(result_chunks)
 }
 
@@ -1943,11 +1955,15 @@ pub async fn search_hybrid_chunks(
     sentry::configure_scope(|scope| scope.set_span(Some(transaction.clone())));
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
-        parsed_query.query =
-            correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
+        corrected_query =
+            correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
+        parsed_query.query = corrected_query
+            .clone()
+            .unwrap_or(parsed_query.query.clone());
         data.query = QueryTypes::Single(parsed_query.query.clone());
 
         timer.add("corrected query");
@@ -2068,6 +2084,7 @@ pub async fn search_hybrid_chunks(
 
         SearchChunkQueryResponseBody {
             score_chunks: reranked_chunks,
+            corrected_query,
             total_chunk_pages: result_chunks.total_chunk_pages,
         }
     };
@@ -2115,20 +2132,23 @@ pub async fn search_groups_query(
     let vector = get_qdrant_vector(data.clone().search_type, parsed_query.clone(), config).await?;
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
         match parsed_query {
             ParsedQueryTypes::Single(ref mut query) => {
-                query.query =
+                corrected_query =
                     correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
+                query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 data.query = QueryTypes::Single(query.query.clone());
             }
             ParsedQueryTypes::Multi(ref mut queries) => {
                 for (query, _) in queries {
-                    query.query =
+                    corrected_query =
                         correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
                             .await?;
+                    query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 }
             }
         }
@@ -2219,6 +2239,7 @@ pub async fn search_groups_query(
     Ok(SearchWithinGroupResults {
         bookmarks: result_chunks.score_chunks,
         group,
+        corrected_query,
         total_pages: result_chunks.total_chunk_pages,
     })
 }
@@ -2238,11 +2259,15 @@ pub async fn search_hybrid_groups(
     let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone());
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
-        parsed_query.query =
-            correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
+        corrected_query =
+            correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
+        parsed_query.query = corrected_query
+            .clone()
+            .unwrap_or(parsed_query.query.clone());
         data.query = QueryTypes::Single(parsed_query.query.clone());
 
         timer.add("corrected query");
@@ -2399,6 +2424,7 @@ pub async fn search_hybrid_groups(
 
         SearchChunkQueryResponseBody {
             score_chunks: reranked_chunks,
+            corrected_query: None,
             total_chunk_pages: result_chunks.total_chunk_pages,
         }
     };
@@ -2406,6 +2432,7 @@ pub async fn search_hybrid_groups(
     Ok(SearchWithinGroupResults {
         bookmarks: reranked_chunks.score_chunks,
         group,
+        corrected_query,
         total_pages: result_chunks.total_chunk_pages,
     })
 }
@@ -2423,20 +2450,23 @@ pub async fn semantic_search_over_groups(
     let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone());
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
         match parsed_query {
             ParsedQueryTypes::Single(ref mut query) => {
-                query.query =
+                corrected_query =
                     correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
+                query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 data.query = QueryTypes::Single(query.query.clone());
             }
             ParsedQueryTypes::Multi(ref mut queries) => {
                 for (query, _) in queries {
-                    query.query =
+                    corrected_query =
                         correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
                             .await?;
+                    query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 }
             }
         }
@@ -2493,6 +2523,7 @@ pub async fn semantic_search_over_groups(
     timer.add("fetched from postgres");
 
     //TODO: rerank for groups
+    result_chunks.corrected_query = corrected_query;
 
     Ok(result_chunks)
 }
@@ -2519,20 +2550,23 @@ pub async fn full_text_search_over_groups(
     timer.add("computed sparse vector");
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
         match parsed_query {
             ParsedQueryTypes::Single(ref mut query) => {
-                query.query =
+                corrected_query =
                     correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
+                query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 data.query = QueryTypes::Single(query.query.clone());
             }
             ParsedQueryTypes::Multi(ref mut queries) => {
                 for (query, _) in queries {
-                    query.query =
+                    corrected_query =
                         correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
                             .await?;
+                    query.query = corrected_query.clone().unwrap_or(query.query.clone());
                 }
             }
         }
@@ -2578,6 +2612,7 @@ pub async fn full_text_search_over_groups(
     timer.add("fetched from postgres");
 
     //TODO: rerank for groups
+    result_groups_with_chunk_hits.corrected_query = corrected_query;
 
     Ok(result_groups_with_chunk_hits)
 }
@@ -2655,11 +2690,15 @@ pub async fn hybrid_search_over_groups(
     let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone());
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
-        parsed_query.query =
-            correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
+        corrected_query =
+            correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
+        parsed_query.query = corrected_query
+            .clone()
+            .unwrap_or(parsed_query.query.clone());
         data.query = QueryTypes::Single(parsed_query.query.clone());
 
         timer.add("corrected query");
@@ -2729,6 +2768,7 @@ pub async fn hybrid_search_over_groups(
 
     let combined_search_chunk_query_results = SearchOverGroupsQueryResult {
         search_results: combined_results,
+        corrected_query: None,
         total_chunk_pages: semantic_results.total_chunk_pages,
     };
 
@@ -2789,6 +2829,7 @@ pub async fn hybrid_search_over_groups(
 
     let result_chunks = DeprecatedSearchOverGroupsResponseBody {
         group_chunks: reranked_chunks,
+        corrected_query,
         total_chunk_pages: combined_search_chunk_query_results.total_chunk_pages,
     };
 
@@ -2810,11 +2851,15 @@ pub async fn autocomplete_chunks_query(
     let parent_span = sentry::configure_scope(|scope| scope.get_span());
 
     let mut parsed_query = parsed_query.clone();
+    let mut corrected_query = None;
     if let Some(options) = &data.typo_options {
         timer.add("start correcting query");
 
-        parsed_query.query =
-            correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
+        corrected_query =
+            correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
+        parsed_query.query = corrected_query
+            .clone()
+            .unwrap_or(parsed_query.query.clone());
         data.query.clone_from(&parsed_query.query);
 
         timer.add("corrected query");
@@ -2954,6 +2999,8 @@ pub async fn autocomplete_chunks_query(
     timer.add("reranking");
     transaction.finish();
 
+    result_chunks.corrected_query = corrected_query;
+
     Ok(result_chunks)
 }
diff --git a/server/src/operators/words_operator.rs b/server/src/operators/words_operator.rs
index f0534cd222..3bc97b0c10 100644
--- a/server/src/operators/words_operator.rs
+++ b/server/src/operators/words_operator.rs
@@ -469,7 +469,7 @@ impl BKTreeCache {
     }
 }
 
-fn correct_query_helper(tree: &BkTree, query: String, options: &TypoOptions) -> String {
+fn correct_query_helper(tree: &BkTree, query: String, options: &TypoOptions) -> Option<String> {
     let query_split_by_whitespace = query
         .split_whitespace()
         .map(|s| s.to_string())
@@ -533,9 +533,10 @@ fn correct_query_helper(tree: &BkTree, query: String, options: &TypoOptions) -> String {
         for (og_string, correction) in query_split_to_correction {
             corrected_query = corrected_query.replacen(&og_string, &correction, 1);
         }
+        Some(corrected_query)
+    } else {
+        None
     }
-
-    corrected_query
 }
 
 #[tracing::instrument(skip(redis_pool))]
@@ -544,9 +545,9 @@ pub async fn correct_query(
     dataset_id: uuid::Uuid,
     redis_pool: web::Data<RedisPool>,
     options: &TypoOptions,
-) -> Result<String, ServiceError> {
+) -> Result<Option<String>, ServiceError> {
     if matches!(options.correct_typos, None | Some(false)) {
-        return Ok(query);
+        return Ok(None);
     }
 
     match BKTREE_CACHE.get_if_valid(&dataset_id) {
@@ -581,7 +582,7 @@ pub async fn correct_query(
                     }
                 };
             });
-            Ok(query)
+            Ok(None)
         }
     }
 }