Skip to content

Commit

Permalink
feature: return corrected query back to the user
Browse files Browse the repository at this point in the history
  • Loading branch information
densumesh committed Aug 28, 2024
1 parent 4b6f449 commit 41eec72
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 26 deletions.
3 changes: 3 additions & 0 deletions server/src/handlers/chunk_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ impl Default for SearchChunksReqPayload {
#[schema(title = "V1")]
pub struct SearchChunkQueryResponseBody {
pub score_chunks: Vec<ScoreChunkDTO>,
pub corrected_query: Option<String>,
pub total_chunk_pages: i64,
}

Expand All @@ -1001,6 +1002,7 @@ pub struct SearchChunkQueryResponseBody {
pub struct SearchResponseBody {
pub id: uuid::Uuid,
pub chunks: Vec<ScoreChunk>,
pub corrected_query: Option<String>,
pub total_pages: i64,
}

Expand All @@ -1022,6 +1024,7 @@ impl SearchChunkQueryResponseBody {
.into_iter()
.map(|chunk| chunk.into())
.collect(),
corrected_query: self.corrected_query,
total_pages: self.total_chunk_pages,
}
}
Expand Down
4 changes: 4 additions & 0 deletions server/src/handlers/group_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,7 @@ pub async fn get_recommended_groups(

let group_qdrant_query_result = SearchOverGroupsQueryResult {
search_results: recommended_groups_from_qdrant.clone(),
corrected_query: None,
total_chunk_pages: (recommended_groups_from_qdrant.len() as f64 / 10.0).ceil() as i64,
};

Expand Down Expand Up @@ -1439,6 +1440,7 @@ impl From<SearchWithinGroupReqPayload> for SearchChunksReqPayload {
pub struct SearchWithinGroupResults {
pub bookmarks: Vec<ScoreChunkDTO>,
pub group: ChunkGroupAndFileId,
pub corrected_query: Option<String>,
pub total_pages: i64,
}

Expand All @@ -1447,6 +1449,7 @@ pub struct SearchWithinGroupResults {
pub struct SearchWithinGroupResponseBody {
pub id: uuid::Uuid,
pub chunks: Vec<ScoreChunk>,
pub corrected_query: Option<String>,
pub total_pages: i64,
}

Expand All @@ -1468,6 +1471,7 @@ impl SearchWithinGroupResults {
.into_iter()
.map(|chunk| chunk.into())
.collect(),
corrected_query: self.corrected_query,
total_pages: self.total_pages,
}
}
Expand Down
79 changes: 63 additions & 16 deletions server/src/operators/search_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,7 @@ pub async fn get_group_tag_set_filter_condition(
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SearchOverGroupsQueryResult {
pub search_results: Vec<GroupSearchResults>,
pub corrected_query: Option<String>,
pub total_chunk_pages: i64,
}

Expand Down Expand Up @@ -1043,6 +1044,7 @@ pub async fn retrieve_group_qdrant_points_query(

Ok(SearchOverGroupsQueryResult {
search_results: point_ids,
corrected_query: None,
total_chunk_pages: pages,
})
}
Expand Down Expand Up @@ -1105,6 +1107,7 @@ impl From<GroupScoreChunk> for SearchOverGroupsResults {
#[schema(title = "V1")]
pub struct DeprecatedSearchOverGroupsResponseBody {
pub group_chunks: Vec<GroupScoreChunk>,
pub corrected_query: Option<String>,
pub total_chunk_pages: i64,
}

Expand All @@ -1117,6 +1120,7 @@ impl DeprecatedSearchOverGroupsResponseBody {
.into_iter()
.map(|chunk| chunk.into())
.collect(),
corrected_query: self.corrected_query,
total_pages: self.total_chunk_pages,
}
}
Expand All @@ -1127,6 +1131,7 @@ impl DeprecatedSearchOverGroupsResponseBody {
pub struct SearchOverGroupsResponseBody {
pub id: uuid::Uuid,
pub results: Vec<SearchOverGroupsResults>,
pub corrected_query: Option<String>,
pub total_pages: i64,
}

Expand Down Expand Up @@ -1289,6 +1294,7 @@ pub async fn retrieve_chunks_for_groups(

Ok(DeprecatedSearchOverGroupsResponseBody {
group_chunks,
corrected_query: None,
total_chunk_pages: search_over_groups_query_result.total_chunk_pages,
})
}
Expand Down Expand Up @@ -1530,6 +1536,7 @@ pub async fn retrieve_chunks_from_point_ids(

Ok(SearchChunkQueryResponseBody {
score_chunks,
corrected_query: None,
total_chunk_pages: search_chunk_query_results.total_chunk_pages,
})
}
Expand Down Expand Up @@ -1804,20 +1811,23 @@ pub async fn search_chunks_query(
sentry::configure_scope(|scope| scope.set_span(Some(transaction.clone())));

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
match parsed_query {
ParsedQueryTypes::Single(ref mut query) => {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
data.query = QueryTypes::Single(query.query.clone());
}
ParsedQueryTypes::Multi(ref mut queries) => {
for (query, _) in queries {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
.await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
}
}
}
Expand Down Expand Up @@ -1916,6 +1926,8 @@ pub async fn search_chunks_query(
timer.add("reranking");
transaction.finish();

result_chunks.corrected_query = corrected_query;

Ok(result_chunks)
}

Expand Down Expand Up @@ -1943,11 +1955,15 @@ pub async fn search_hybrid_chunks(
sentry::configure_scope(|scope| scope.set_span(Some(transaction.clone())));

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
parsed_query.query =
correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
corrected_query =
correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
parsed_query.query = corrected_query
.clone()
.unwrap_or(parsed_query.query.clone());
data.query = QueryTypes::Single(parsed_query.query.clone());

timer.add("corrected query");
Expand Down Expand Up @@ -2068,6 +2084,7 @@ pub async fn search_hybrid_chunks(

SearchChunkQueryResponseBody {
score_chunks: reranked_chunks,
corrected_query,
total_chunk_pages: result_chunks.total_chunk_pages,
}
};
Expand Down Expand Up @@ -2115,20 +2132,23 @@ pub async fn search_groups_query(
let vector = get_qdrant_vector(data.clone().search_type, parsed_query.clone(), config).await?;

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
match parsed_query {
ParsedQueryTypes::Single(ref mut query) => {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
data.query = QueryTypes::Single(query.query.clone());
}
ParsedQueryTypes::Multi(ref mut queries) => {
for (query, _) in queries {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
.await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
}
}
}
Expand Down Expand Up @@ -2219,6 +2239,7 @@ pub async fn search_groups_query(
Ok(SearchWithinGroupResults {
bookmarks: result_chunks.score_chunks,
group,
corrected_query,
total_pages: result_chunks.total_chunk_pages,
})
}
Expand All @@ -2238,11 +2259,15 @@ pub async fn search_hybrid_groups(
let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone());

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
parsed_query.query =
correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
corrected_query =
correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
parsed_query.query = corrected_query
.clone()
.unwrap_or(parsed_query.query.clone());
data.query = QueryTypes::Single(parsed_query.query.clone());

timer.add("corrected query");
Expand Down Expand Up @@ -2399,13 +2424,15 @@ pub async fn search_hybrid_groups(

SearchChunkQueryResponseBody {
score_chunks: reranked_chunks,
corrected_query: None,
total_chunk_pages: result_chunks.total_chunk_pages,
}
};

Ok(SearchWithinGroupResults {
bookmarks: reranked_chunks.score_chunks,
group,
corrected_query,
total_pages: result_chunks.total_chunk_pages,
})
}
Expand All @@ -2423,20 +2450,23 @@ pub async fn semantic_search_over_groups(
let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone());

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
match parsed_query {
ParsedQueryTypes::Single(ref mut query) => {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
data.query = QueryTypes::Single(query.query.clone());
}
ParsedQueryTypes::Multi(ref mut queries) => {
for (query, _) in queries {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
.await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
}
}
}
Expand Down Expand Up @@ -2493,6 +2523,7 @@ pub async fn semantic_search_over_groups(
timer.add("fetched from postgres");

//TODO: rerank for groups
result_chunks.corrected_query = corrected_query;

Ok(result_chunks)
}
Expand All @@ -2519,20 +2550,23 @@ pub async fn full_text_search_over_groups(
timer.add("computed sparse vector");

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
match parsed_query {
ParsedQueryTypes::Single(ref mut query) => {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool, options).await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
data.query = QueryTypes::Single(query.query.clone());
}
ParsedQueryTypes::Multi(ref mut queries) => {
for (query, _) in queries {
query.query =
corrected_query =
correct_query(query.query.clone(), dataset.id, redis_pool.clone(), options)
.await?;
query.query = corrected_query.clone().unwrap_or(query.query.clone());
}
}
}
Expand Down Expand Up @@ -2578,6 +2612,7 @@ pub async fn full_text_search_over_groups(
timer.add("fetched from postgres");

//TODO: rerank for groups
result_groups_with_chunk_hits.corrected_query = corrected_query;

Ok(result_groups_with_chunk_hits)
}
Expand Down Expand Up @@ -2655,11 +2690,15 @@ pub async fn hybrid_search_over_groups(
let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone());

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
parsed_query.query =
correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
corrected_query =
correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
parsed_query.query = corrected_query
.clone()
.unwrap_or(parsed_query.query.clone());
data.query = QueryTypes::Single(parsed_query.query.clone());

timer.add("corrected query");
Expand Down Expand Up @@ -2729,6 +2768,7 @@ pub async fn hybrid_search_over_groups(

let combined_search_chunk_query_results = SearchOverGroupsQueryResult {
search_results: combined_results,
corrected_query: None,
total_chunk_pages: semantic_results.total_chunk_pages,
};

Expand Down Expand Up @@ -2789,6 +2829,7 @@ pub async fn hybrid_search_over_groups(

let result_chunks = DeprecatedSearchOverGroupsResponseBody {
group_chunks: reranked_chunks,
corrected_query,
total_chunk_pages: combined_search_chunk_query_results.total_chunk_pages,
};

Expand All @@ -2810,11 +2851,15 @@ pub async fn autocomplete_chunks_query(
let parent_span = sentry::configure_scope(|scope| scope.get_span());

let mut parsed_query = parsed_query.clone();
let mut corrected_query = None;

if let Some(options) = &data.typo_options {
timer.add("start correcting query");
parsed_query.query =
correct_query(parsed_query.query, dataset.id, redis_pool, options).await?;
corrected_query =
correct_query(parsed_query.query.clone(), dataset.id, redis_pool, options).await?;
parsed_query.query = corrected_query
.clone()
.unwrap_or(parsed_query.query.clone());
data.query.clone_from(&parsed_query.query);

timer.add("corrected query");
Expand Down Expand Up @@ -2954,6 +2999,8 @@ pub async fn autocomplete_chunks_query(
timer.add("reranking");
transaction.finish();

result_chunks.corrected_query = corrected_query;

Ok(result_chunks)
}

Expand Down
Loading

0 comments on commit 41eec72

Please sign in to comment.