From ec75423981285e40db937d1a65c989c08257d7d9 Mon Sep 17 00:00:00 2001 From: skeptrune Date: Thu, 11 Jul 2024 17:23:17 -0700 Subject: [PATCH] cleanup: properly apply threshold when use_reranker is set to true + docs: update to explain how threshold works before weight/bias + cleanup: fully remove get_collisions --- .../search/src/components/ResultsPage.tsx | 1 - server/src/handlers/chunk_handler.rs | 27 +++---- server/src/handlers/group_handler.rs | 27 +++---- server/src/handlers/message_handler.rs | 12 +-- server/src/lib.rs | 2 +- server/src/operators/search_operator.rs | 75 ++++++++++++++----- 6 files changed, 82 insertions(+), 62 deletions(-) diff --git a/frontends/search/src/components/ResultsPage.tsx b/frontends/search/src/components/ResultsPage.tsx index 1b1b2cce29..3a7e11cc53 100644 --- a/frontends/search/src/components/ResultsPage.tsx +++ b/frontends/search/src/components/ResultsPage.tsx @@ -179,7 +179,6 @@ const ResultsPage = (props: ResultsPageProps) => { : props.search.debounced.searchType, score_threshold: props.search.debounced.scoreThreshold, recency_bias: props.search.debounced.recencyBias ?? 0.0, - get_collisions: true, slim_chunks: props.search.debounced.slimChunks ?? false, page_size: props.search.debounced.pageSize ?? 10, get_total_pages: props.search.debounced.getTotalPages ?? false, diff --git a/server/src/handlers/chunk_handler.rs b/server/src/handlers/chunk_handler.rs index 6d2bc205c7..82bc34b865 100644 --- a/server/src/handlers/chunk_handler.rs +++ b/server/src/handlers/chunk_handler.rs @@ -954,13 +954,12 @@ pub struct ChunkFilter { }, "recency_bias": 1.0, "use_weights": true, - "get_collisions": true, "highlight_results": true, "highlight_delimiters": ["?", ",", ".", "!"], "score_threshold": 0.5 }))] pub struct SearchChunksReqPayload { - /// Can be either "semantic", "fulltext", or "hybrid". If specified as "hybrid", it will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using BAAI/bge-reranker-large. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. + /// Can be either "semantic", "fulltext", or "hybrid". If specified as "hybrid", it will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. pub search_type: SearchMethod, /// Query is the search query. This can be any string. The query will be used to create an embedding vector and/or SPLADE vector which will be used to find the result set. pub query: String, @@ -980,8 +979,6 @@ pub struct SearchChunksReqPayload { pub use_weights: Option, /// Tag weights is a JSON object which can be used to boost the ranking of chunks with certain tags. This is useful for when you want to be able to bias towards chunks with a certain tag on the fly. The keys are the tag names and the values are the weights. pub tag_weights: Option>, - /// Set get_collisions to true to get the collisions for each chunk. This will only apply if environment variable COLLISIONS_ENABLED is set to true. - pub get_collisions: Option, /// Set highlight_results to false for a slight latency improvement (1-10ms). If not specified, this defaults to true. This will add `` tags to the chunk_html of the chunks to highlight matching splits and return the highlights on each scored chunk in the response. pub highlight_results: Option, /// Set highlight_threshold to a lower or higher value to adjust the sensitivity of the highlights applied to the chunk html. If not specified, this defaults to 0.8. The range is 0.0 to 1.0. @@ -994,14 +991,13 @@ pub struct SearchChunksReqPayload { pub highlight_max_num: Option, /// Set highlight_window to a number to control the amount of words that are returned around the matched phrases. If not specified, this defaults to 0. This is useful for when you want to show more context around the matched words. When specified, window/2 whitespace separated words are added before and after each highlight in the response's highlights array. If an extended highlight overlaps with another highlight, the overlapping words are only included once. pub highlight_window: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typically 10-50ms). Default is false. pub slim_chunks: Option, /// Set content_only to true to only returning the chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typically 10-50ms). Default is false. pub content_only: Option, - /// If true, chunks will be reranked using BAAI/bge-reranker-large. "hybrid" search does this - /// by default and this flag does not affect it. + /// If true, chunks will be reranked using scores from a cross encoder model. "hybrid" search will always use the reranker regardless of this setting. pub use_reranker: Option, } @@ -1018,7 +1014,6 @@ impl Default for SearchChunksReqPayload { location_bias: None, use_weights: None, tag_weights: None, - get_collisions: None, highlight_results: None, highlight_threshold: None, highlight_delimiters: None, @@ -1089,7 +1084,6 @@ pub fn parse_query(query: String) -> ParsedQuery { request_body(content = SearchChunksReqPayload, description = "JSON request payload to semantically search for chunks (chunks)", content_type = "application/json"), responses( (status = 200, description = "Chunks with embedding vectors which are similar to those in the request body", body = SearchChunkQueryResponseBody), - (status = 400, description = "Service error relating to searching", body = ErrorResponseBody), ), params( @@ -1246,7 +1240,6 @@ pub async fn search_chunks( }, "recency_bias": 1.0, "use_weights": true, - "get_collisions": true, "highlight_results": true, "highlight_delimiters": ["?", ",", ".", "!"], "score_threshold": 0.5 @@ -1282,14 +1275,13 @@ pub struct AutocompleteReqPayload { pub highlight_max_num: Option, /// Set highlight_window to a number to control the amount of words that are returned around the matched phrases. If not specified, this defaults to 0. This is useful for when you want to show more context around the matched words. When specified, window/2 whitespace separated words are added before and after each highlight in the response's highlights array. If an extended highlight overlaps with another highlight, the overlapping words are only included once. pub highlight_window: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typically 10-50ms). Default is false. pub slim_chunks: Option, /// Set content_only to true to only returning the chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typically 10-50ms). Default is false. pub content_only: Option, - /// If true, chunks will be reranked using BAAI/bge-reranker-large. "hybrid" search does this - /// by default and this flag does not affect it. + /// If true, chunks will be reranked using scores from a cross encoder model. "hybrid" search will always use the reranker regardless of this setting. pub use_reranker: Option, } @@ -1305,7 +1297,6 @@ impl From for SearchChunksReqPayload { recency_bias: autocomplete_data.recency_bias, use_weights: autocomplete_data.use_weights, tag_weights: autocomplete_data.tag_weights, - get_collisions: None, highlight_results: autocomplete_data.highlight_results, highlight_threshold: autocomplete_data.highlight_threshold, highlight_delimiters: Some( @@ -1501,7 +1492,7 @@ pub struct CountChunksReqPayload { pub query: String, /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. pub filters: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. Will restrict the count to only chunks with a score above the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Set limit to restrict the maximum number of chunks to count. This is useful for when you want to reduce the latency of the count operation. By default the limit will be the number of chunks in the dataset. pub limit: Option, @@ -1514,7 +1505,7 @@ pub struct CountChunkQueryResponseBody { /// Count chunks above threshold /// -/// This route can be used to determine the number of chunks that match a given search criteria including filters and score threshold. It may be high latency for large datasets. Auth'ed user or api key must have an admin or owner role for the specified dataset's organization. +/// This route can be used to determine the number of chunks that match a given search criteria including filters and score threshold. It may be high latency for large datasets. Auth'ed user or api key must have an admin or owner role for the specified dataset's organization. There is a dataset configuration imposed restriction on the maximum limit value (default 10,000) which is used to prevent DDOS attacks. #[utoipa::path( post, path = "/chunk/count", @@ -1528,13 +1519,13 @@ pub struct CountChunkQueryResponseBody { ("TR-Dataset" = String, Header, description = "The dataset id to use for the request"), ), security( - ("ApiKey" = ["admin"]), + ("ApiKey" = ["readonly"]), ) )] #[tracing::instrument(skip(pool))] pub async fn count_chunks( data: web::Json, - _user: AdminOnly, + _user: LoggedUser, pool: web::Data, dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan, ) -> Result { diff --git a/server/src/handlers/group_handler.rs b/server/src/handlers/group_handler.rs index e23cac1ced..5d43365c50 100644 --- a/server/src/handlers/group_handler.rs +++ b/server/src/handlers/group_handler.rs @@ -1078,13 +1078,8 @@ pub async fn get_recommended_groups( timer.add("recommend_qdrant_groups_query"); - let recommended_chunk_metadatas = get_metadata_from_groups( - group_qdrant_query_result.clone(), - Some(false), - data.slim_chunks, - pool, - ) - .await?; + let recommended_chunk_metadatas = + get_metadata_from_groups(group_qdrant_query_result.clone(), data.slim_chunks, pool).await?; let recommended_chunk_metadatas = recommended_groups_from_qdrant .into_iter() @@ -1156,7 +1151,7 @@ pub struct SearchWithinGroupData { pub group_id: Option, /// Group_tracking_id specifies the group to search within by tracking id. Results will only consist of chunks which are bookmarks within the specified group. If both group_id and group_tracking_id are provided, group_id will be used. pub group_tracking_id: Option, - /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using BAAI/bge-reranker-large. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. + /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. pub search_type: SearchMethod, /// Location lets you rank your results by distance from a location. If not specified, this has no effect. Bias allows you to determine how much of an effect the location of chunks will have on the search results. If not specified, this defaults to 0.0. We recommend setting this to 1.0 for a gentle reranking of the results, >3.0 for a strong reranking of the results. pub location_bias: Option, @@ -1178,12 +1173,11 @@ pub struct SearchWithinGroupData { pub highlight_max_num: Option, /// Set highlight_window to a number to control the amount of words that are returned around the matched phrases. If not specified, this defaults to 0. This is useful for when you want to show more context around the matched words. When specified, window/2 whitespace separated words are added before and after each highlight in the response's highlights array. If an extended highlight overlaps with another highlight, the overlapping words are only included once. pub highlight_window: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typicall 10-50ms). Default is false. pub slim_chunks: Option, - /// If true, chunks will be reranked using BAAI/bge-reranker-large. "hybrid" search does this - /// by default and this flag does not affect it. + /// If true, chunks will be reranked using scores from a cross encoder model. "hybrid" search will always use the reranker regardless of this setting. pub use_reranker: Option, } @@ -1200,7 +1194,6 @@ impl From for SearchChunksReqPayload { location_bias: search_within_group_data.location_bias, use_weights: search_within_group_data.use_weights, tag_weights: search_within_group_data.tag_weights, - get_collisions: Some(false), highlight_results: search_within_group_data.highlight_results, highlight_threshold: search_within_group_data.highlight_threshold, highlight_delimiters: search_within_group_data.highlight_delimiters, @@ -1224,7 +1217,7 @@ pub struct SearchWithinGroupResults { /// Search Within Group /// -/// This route allows you to search only within a group. This is useful for when you only want search results to contain chunks which are members of a specific group. If choosing hybrid search, the results will be re-ranked using BAAI/bge-reranker-large. +/// This route allows you to search only within a group. This is useful for when you only want search results to contain chunks which are members of a specific group. If choosing hybrid search, the results will be re-ranked using scores from a cross encoder model. #[utoipa::path( post, path = "/chunk_group/search", @@ -1354,7 +1347,7 @@ pub async fn search_within_group( #[derive(Serialize, Deserialize, Debug, Clone, ToSchema)] pub struct SearchOverGroupsData { - /// Can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using BAAI/bge-reranker-large. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. + /// Can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. pub search_type: SearchMethod, /// Query is the search query. This can be any string. The query will be used to create an embedding vector and/or SPLADE vector which will be used to find the result set. pub query: String, @@ -1366,8 +1359,6 @@ pub struct SearchOverGroupsData { pub get_total_pages: Option, /// Filters is a JSON object which can be used to filter chunks. The values on each key in the object will be used to check for an exact substring match on the metadata values for each existing chunk. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. pub filters: Option, - /// Set get_collisions to true to get the collisions for each chunk. This will only apply if environment variable COLLISIONS_ENABLED is set to true. - pub get_collisions: Option, /// Set highlight_results to false for a slight latency improvement (1-10ms). If not specified, this defaults to true. This will add `` tags to the chunk_html of the chunks to highlight matching splits and return the highlights on each scored chunk in the response. pub highlight_results: Option, /// Set highlight_threshold to a lower or higher value to adjust the sensitivity of the highlights applied to the chunk html. If not specified, this defaults to 0.8. The range is 0.0 to 1.0. @@ -1380,7 +1371,7 @@ pub struct SearchOverGroupsData { pub highlight_max_num: Option, /// Set highlight_window to a number to control the amount of words that are returned around the matched phrases. If not specified, this defaults to 0. This is useful for when you want to show more context around the matched words. When specified, window/2 whitespace separated words are added before and after each highlight in the response's highlights array. If an extended highlight overlaps with another highlight, the overlapping words are only included once. pub highlight_window: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Group_size is the number of chunks to fetch for each group. The default is 3. If a group has less than group_size chunks, all chunks will be returned. If this is set to a large number, we recommend setting slim_chunks to true to avoid returning the content and chunk_html of the chunks so as to lower the amount of time required for content download and serialization. pub group_size: Option, @@ -1390,7 +1381,7 @@ pub struct SearchOverGroupsData { /// Search Over Groups /// -/// This route allows you to get groups as results instead of chunks. Each group returned will have the matching chunks sorted by similarity within the group. This is useful for when you want to get groups of chunks which are similar to the search query. If choosing hybrid search, the results will be re-ranked using BAAI/bge-reranker-large. Compatible with semantic, fulltext, or hybrid search modes. +/// This route allows you to get groups as results instead of chunks. Each group returned will have the matching chunks sorted by similarity within the group. This is useful for when you want to get groups of chunks which are similar to the search query. If choosing hybrid search, the results will be re-ranked using scores from a cross encoder model. Compatible with semantic, fulltext, or hybrid search modes. #[utoipa::path( post, path = "/chunk_group/group_oriented_search", diff --git a/server/src/handlers/message_handler.rs b/server/src/handlers/message_handler.rs index c6c443092b..784e3bf102 100644 --- a/server/src/handlers/message_handler.rs +++ b/server/src/handlers/message_handler.rs @@ -85,7 +85,7 @@ pub struct CreateMessageReqPayload { pub highlight_results: Option, /// The delimiters to use for highlighting the citations. If this is not included, the default delimiters will be used. Default is `[".", "!", "?", "\n", "\t", ","]`. pub highlight_delimiters: Option>, - /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using BAAI/bge-reranker-large. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. Default is "hybrid". + /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. Default is "hybrid". pub search_type: Option, /// If concat user messages query is set to true, all of the user messages in the topic will be concatenated together and used as the search query. If not specified, this defaults to false. Default is false. pub concat_user_messages_query: Option, @@ -95,7 +95,7 @@ pub struct CreateMessageReqPayload { pub page_size: Option, /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. pub filters: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Completion first decides whether the stream should contain the stream of the completion response or the chunks first. Default is false. Keep in mind that || is used to separate the chunks from the completion response. If || is in the completion then you may want to split on ||{ instead. pub completion_first: Option, @@ -274,7 +274,7 @@ pub struct RegenerateMessageReqPayload { pub highlight_citations: Option, /// The delimiters to use for highlighting the citations. If this is not included, the default delimiters will be used. Default is `[".", "!", "?", "\n", "\t", ","]`. pub highlight_delimiters: Option>, - /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using BAAI/bge-reranker-large. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. + /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. pub search_type: Option, /// If concat user messages query is set to true, all of the user messages in the topic will be concatenated together and used as the search query. If not specified, this defaults to false. Default is false. pub concat_user_messages_query: Option, @@ -284,7 +284,7 @@ pub struct RegenerateMessageReqPayload { pub page_size: Option, /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. pub filters: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Completion first decides whether the stream should contain the stream of the completion response or the chunks first. Default is false. Keep in mind that || is used to separate the chunks from the completion response. If || is in the completion then you may want to split on ||{ instead. pub completion_first: Option, @@ -314,7 +314,7 @@ pub struct EditMessageReqPayload { pub highlight_citations: Option, /// The delimiters to use for highlighting the citations. If this is not included, the default delimiters will be used. Default is `[".", "!", "?", "\n", "\t", ","]`. pub highlight_delimiters: Option>, - /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using BAAI/bge-reranker-large. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. + /// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. pub search_type: Option, /// If concat user messages query is set to true, all of the user messages in the topic will be concatenated together and used as the search query. If not specified, this defaults to false. Default is false. pub concat_user_messages_query: Option, @@ -324,7 +324,7 @@ pub struct EditMessageReqPayload { pub page_size: Option, /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. pub filters: Option, - /// Set score_threshold to a float to filter out chunks with a score below the threshold. + /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, /// Completion first decides whether the stream should contain the stream of the completion response or the chunks first. Default is false. Keep in mind that || is used to separate the chunks from the completion response. If || is in the completion then you may want to split on ||{ instead. pub completion_first: Option, diff --git a/server/src/lib.rs b/server/src/lib.rs index 8be6f869c4..179a6a9902 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -140,7 +140,7 @@ impl Modify for SecurityAddon { name = "BSL", url = "https://github.com/devflowinc/trieve/blob/main/LICENSE.txt", ), - version = "0.10.10", + version = "0.10.11", ), servers( (url = "https://api.trieve.ai", diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs index a55366f41b..ed2bd68163 100644 --- a/server/src/operators/search_operator.rs +++ b/server/src/operators/search_operator.rs @@ -1031,7 +1031,6 @@ pub async fn retrieve_chunks_for_groups( #[tracing::instrument(skip(pool))] pub async fn get_metadata_from_groups( search_over_groups_query_result: SearchOverGroupsQueryResult, - get_collisions: Option, slim_chunks: Option, pool: web::Data, ) -> Result, actix_web::Error> { @@ -1430,7 +1429,11 @@ pub async fn search_semantic_chunks( let qdrant_query = RetrievePointQuery { vector: VectorType::Dense(embedding_vector), - score_threshold: data.score_threshold, + score_threshold: if data.use_reranker.unwrap_or(false) { + None + } else { + data.score_threshold + }, filter: data.filters.clone(), } .into_qdrant_query(parsed_query, dataset.id, None, pool.clone()) @@ -1455,13 +1458,19 @@ pub async fn search_semantic_chunks( let rerank_chunks_input = match data.use_reranker { Some(false) | None => result_chunks.score_chunks, Some(true) => { - cross_encoder( + let mut cross_encoder_results = cross_encoder( data.query.clone(), data.page_size.unwrap_or(10), result_chunks.score_chunks, config, ) - .await? + .await?; + + if let Some(score_threshold) = data.score_threshold { + cross_encoder_results.retain(|chunk| chunk.score >= score_threshold.into()); + } + + cross_encoder_results } }; @@ -1512,7 +1521,11 @@ pub async fn search_full_text_chunks( let qdrant_query = RetrievePointQuery { vector: VectorType::Sparse(sparse_vector), - score_threshold: data.score_threshold, + score_threshold: if data.use_reranker.unwrap_or(false) { + None + } else { + data.score_threshold + }, filter: data.filters.clone(), } .into_qdrant_query(parsed_query, dataset.id, None, pool.clone()) @@ -1537,13 +1550,19 @@ pub async fn search_full_text_chunks( let rerank_chunks_input = match data.use_reranker { Some(false) | None => result_chunks.score_chunks, Some(true) => { - cross_encoder( + let mut cross_encoder_results = cross_encoder( data.query.clone(), data.page_size.unwrap_or(10), result_chunks.score_chunks, config, ) - .await? + .await?; + + if let Some(score_threshold) = data.score_threshold { + cross_encoder_results.retain(|chunk| chunk.score >= score_threshold.into()); + } + + cross_encoder_results } }; @@ -1643,7 +1662,7 @@ pub async fn search_hybrid_chunks( let mut reranked_chunks = { let mut reranked_chunks = { - let cross_encoder_results = cross_encoder( + let mut cross_encoder_results = cross_encoder( data.query.clone(), data.page_size.unwrap_or(10), result_chunks.score_chunks, @@ -1651,6 +1670,10 @@ pub async fn search_hybrid_chunks( ) .await?; + if let Some(score_threshold) = data.score_threshold { + cross_encoder_results.retain(|chunk| chunk.score >= score_threshold.into()); + } + rerank_chunks( cross_encoder_results, data.recency_bias, @@ -1664,10 +1687,6 @@ pub async fn search_hybrid_chunks( timer.add("reranking"); - if let Some(score_threshold) = data.score_threshold { - reranked_chunks.retain(|chunk| chunk.score >= score_threshold.into()); - } - SearchChunkQueryResponseBody { score_chunks: reranked_chunks, total_chunk_pages: result_chunks.total_chunk_pages, @@ -1720,7 +1739,11 @@ pub async fn search_semantic_groups( let qdrant_query = RetrievePointQuery { vector: VectorType::Dense(embedding_vector), - score_threshold: data.score_threshold, + score_threshold: if data.use_reranker.unwrap_or(false) { + None + } else { + data.score_threshold + }, filter: data.filters.clone(), } .into_qdrant_query(parsed_query, dataset.id, Some(group.id), pool.clone()) @@ -1745,13 +1768,19 @@ pub async fn search_semantic_groups( let rerank_chunks_input = match data.use_reranker { Some(false) | None => result_chunks.score_chunks, Some(true) => { - cross_encoder( + let mut cross_encoder_results = cross_encoder( data.query.clone(), data.page_size.unwrap_or(10), result_chunks.score_chunks, config, ) - .await? + .await?; + + if let Some(score_threshold) = data.score_threshold { + cross_encoder_results.retain(|chunk| chunk.score >= score_threshold.into()); + } + + cross_encoder_results } }; @@ -1786,7 +1815,11 @@ pub async fn search_full_text_groups( let qdrant_query = RetrievePointQuery { vector: VectorType::Sparse(sparse_vector), - score_threshold: data.score_threshold, + score_threshold: if data.use_reranker.unwrap_or(false) { + None + } else { + data.score_threshold + }, filter: data.filters.clone(), } .into_qdrant_query(parsed_query, dataset.id, Some(group.id), pool.clone()) @@ -1811,13 +1844,19 @@ pub async fn search_full_text_groups( let rerank_chunks_input = match data.use_reranker { Some(false) | None => result_chunks.score_chunks, Some(true) => { - cross_encoder( + let mut cross_encoder_results = cross_encoder( data.query.clone(), data.page_size.unwrap_or(10), result_chunks.score_chunks, config, ) - .await? + .await?; + + if let Some(score_threshold) = data.score_threshold { + cross_encoder_results.retain(|chunk| chunk.score >= score_threshold.into()); + } + + cross_encoder_results } };