diff --git a/server/src/bin/bulk-ingestion-worker.rs b/server/src/bin/bulk-ingestion-worker.rs
index 685654888d..5041ba38ef 100644
--- a/server/src/bin/bulk-ingestion-worker.rs
+++ b/server/src/bin/bulk-ingestion-worker.rs
@@ -528,71 +528,12 @@ pub async fn bulk_upload_chunks(
     embedding_transaction.finish();

-    let qdrant_points = tokio_stream::iter(izip!(
-        ingestion_data.clone(),
-        embedding_vectors.iter(),
-        splade_vectors.iter(),
-    ))
-    .then(|(chunk_data, embedding_vector, splade_vector)| async {
-        let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
-
-        let chunk_tags: Option<Vec<Option<String>>> =
-            if let Some(ref group_ids) = chunk_data.group_ids {
-                Some(
-                    get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
-                        .await?
-                        .iter()
-                        .filter_map(|group| group.tag_set.clone())
-                        .flatten()
-                        .dedup()
-                        .collect(),
-                )
-            } else {
-                None
-            };
-
-        let payload = QdrantPayload::new(
-            chunk_data.chunk_metadata,
-            chunk_data.group_ids,
-            None,
-            chunk_tags,
-        )
-        .into();
-
-        let vector_name = match embedding_vector.len() {
-            384 => "384_vectors",
-            512 => "512_vectors",
-            768 => "768_vectors",
-            1024 => "1024_vectors",
-            3072 => "3072_vectors",
-            1536 => "1536_vectors",
-            _ => {
-                return Err(ServiceError::BadRequest(
-                    "Invalid embedding vector size".into(),
-                ))
-            }
-        };
-
-        let vector_payload = HashMap::from([
-            (
-                vector_name.to_string(),
-                Vector::from(embedding_vector.clone()),
-            ),
-            (
-                "sparse_vectors".to_string(),
-                Vector::from(splade_vector.clone()),
-            ),
-        ]);
-
-        // If qdrant_point_id does not exist, does not get written to qdrant
-        Ok(PointStruct::new(
-            qdrant_point_id.to_string(),
-            vector_payload,
-            payload,
-        ))
-    })
-    .collect::<Vec<Result<PointStruct, ServiceError>>>()
-    .await;
+    let qdrant_points = match embedding_vectors {
+        Some(embedding_vectors) => {
+            get_qdrant_points(&ingestion_data, embedding_vectors, splade_vectors, web_pool).await
+        }
+        None => get_qdrant_points_from_splade_vecs(&ingestion_data, splade_vectors, web_pool).await,
+    };

     if qdrant_points.iter().any(|point| point.is_err()) {
         Err(ServiceError::InternalServerError(
@@ -722,7 +663,7 @@ async fn upload_chunk(
         num_value: payload.chunk.num_value,
     };

-    let embedding_vector = match dataset_config.FULLTEXT_ENABLED {
+    let embedding_vector = match dataset_config.SEMANTIC_ENABLED {
         true => match payload.chunk.split_avg.unwrap_or(false) {
             true => {
                 let chunks = coarse_doc_chunker(content.clone(), None, false, 20);
@@ -901,14 +842,20 @@ async fn update_chunk(

     let chunk_metadata = payload.chunk_metadata.clone();

-    let embedding_vector = create_embedding(
-        content.to_string(),
-        payload.distance_phrase,
-        "doc",
-        server_dataset_config.clone(),
-    )
-    .await
-    .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+    let embedding_vector = match server_dataset_config.SEMANTIC_ENABLED {
+        true => {
+            let embedding = create_embedding(
+                content.to_string(),
+                payload.distance_phrase,
+                "doc",
+                server_dataset_config.clone(),
+            )
+            .await
+            .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+            Some(embedding)
+        }
+        false => None,
+    };

     let splade_vector = if server_dataset_config.FULLTEXT_ENABLED {
         let reqwest_client = reqwest::Client::new();
@@ -942,7 +889,7 @@
     update_qdrant_point_query(
         // If the chunk is a collision, we don't want to update the qdrant point
         chunk_metadata,
-        Some(embedding_vector),
+        embedding_vector,
         Some(chunk_group_ids),
         payload.dataset_id,
         splade_vector,
@@ -963,7 +910,7 @@
     update_qdrant_point_query(
         // If the chunk is a collision, we don't want to update the qdrant point
         chunk_metadata,
-        Some(embedding_vector),
+        embedding_vector,
         None,
         payload.dataset_id,
         splade_vector,
@@ -1058,3 +1005,124 @@ pub async fn readd_error_to_queue(

     Ok(())
 }
+
+async fn get_qdrant_points(
+    ingestion_data: &Vec<ChunkData>,
+    embedding_vectors: Vec<Vec<f32>>,
+    splade_vectors: Vec<Vec<(u32, f32)>>,
+    web_pool: actix_web::web::Data<models::Pool>,
+) -> Vec<Result<PointStruct, ServiceError>> {
+    tokio_stream::iter(izip!(
+        ingestion_data.clone(),
+        embedding_vectors.iter(),
+        splade_vectors.iter(),
+    ))
+    .then(|(chunk_data, embedding_vector, splade_vector)| async {
+        let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
+
+        let chunk_tags: Option<Vec<Option<String>>> =
+            if let Some(ref group_ids) = chunk_data.group_ids {
+                Some(
+                    get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
+                        .await?
+                        .iter()
+                        .filter_map(|group| group.tag_set.clone())
+                        .flatten()
+                        .dedup()
+                        .collect(),
+                )
+            } else {
+                None
+            };
+
+        let payload = QdrantPayload::new(
+            chunk_data.chunk_metadata,
+            chunk_data.group_ids,
+            None,
+            chunk_tags,
+        )
+        .into();
+
+        let vector_name = match embedding_vector.len() {
+            384 => "384_vectors",
+            512 => "512_vectors",
+            768 => "768_vectors",
+            1024 => "1024_vectors",
+            3072 => "3072_vectors",
+            1536 => "1536_vectors",
+            _ => {
+                return Err(ServiceError::BadRequest(
+                    "Invalid embedding vector size".into(),
+                ))
+            }
+        };
+
+        let vector_payload = HashMap::from([
+            (
+                vector_name.to_string(),
+                Vector::from(embedding_vector.clone()),
+            ),
+            (
+                "sparse_vectors".to_string(),
+                Vector::from(splade_vector.clone()),
+            ),
+        ]);
+
+        // If qdrant_point_id does not exist, does not get written to qdrant
+        Ok(PointStruct::new(
+            qdrant_point_id.to_string(),
+            vector_payload,
+            payload,
+        ))
+    })
+    .collect::<Vec<Result<PointStruct, ServiceError>>>()
+    .await
+}
+
+async fn get_qdrant_points_from_splade_vecs(
+    ingestion_data: &Vec<ChunkData>,
+    splade_vectors: Vec<Vec<(u32, f32)>>,
+    web_pool: actix_web::web::Data<models::Pool>,
+) -> Vec<Result<PointStruct, ServiceError>> {
+    tokio_stream::iter(izip!(ingestion_data.clone(), splade_vectors.iter(),))
+        .then(|(chunk_data, splade_vector)| async {
+            let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
+
+            let chunk_tags: Option<Vec<Option<String>>> =
+                if let Some(ref group_ids) = chunk_data.group_ids {
+                    Some(
+                        get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
+                            .await?
+                            .iter()
+                            .filter_map(|group| group.tag_set.clone())
+                            .flatten()
+                            .dedup()
+                            .collect(),
+                    )
+                } else {
+                    None
+                };
+
+            let payload = QdrantPayload::new(
+                chunk_data.chunk_metadata,
+                chunk_data.group_ids,
+                None,
+                chunk_tags,
+            )
+            .into();
+
+            let vector_payload = HashMap::from([(
+                "sparse_vectors".to_string(),
+                Vector::from(splade_vector.clone()),
+            )]);
+
+            // If qdrant_point_id does not exist, does not get written to qdrant
+            Ok(PointStruct::new(
+                qdrant_point_id.to_string(),
+                vector_payload,
+                payload,
+            ))
+        })
+        .collect::<Vec<Result<PointStruct, ServiceError>>>()
+        .await
+}
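Note (annotation, not part of the patch): both workers now key the dense vector by its dimensionality and simply omit it from the point's vector map when SEMANTIC_ENABLED is false, while the splade vector is always written under "sparse_vectors". A minimal standalone sketch of that selection logic, using an AnyVector stand-in for the qdrant-client Vector type; the names here (AnyVector, build_vector_payload) are illustrative, and only the dimension-to-name mapping and the gating mirror the diff:

use std::collections::HashMap;

// Stand-in for qdrant_client's Vector, which can hold either a dense
// embedding or a sparse (index, weight) list.
enum AnyVector {
    Dense(Vec<f32>),
    Sparse(Vec<(u32, f32)>),
}

// Mirrors the dimension-to-name match in get_qdrant_points; an unsupported
// size surfaces as a BadRequest in the worker.
fn vector_name(len: usize) -> Option<&'static str> {
    match len {
        384 => Some("384_vectors"),
        512 => Some("512_vectors"),
        768 => Some("768_vectors"),
        1024 => Some("1024_vectors"),
        1536 => Some("1536_vectors"),
        3072 => Some("3072_vectors"),
        _ => None,
    }
}

// The sparse vector is always written; the dense entry is added only when
// SEMANTIC_ENABLED produced an embedding.
fn build_vector_payload(
    dense: Option<Vec<f32>>,
    sparse: Vec<(u32, f32)>,
) -> Result<HashMap<String, AnyVector>, String> {
    let mut payload = HashMap::from([("sparse_vectors".to_string(), AnyVector::Sparse(sparse))]);
    if let Some(embedding) = dense {
        let name = vector_name(embedding.len())
            .ok_or_else(|| "Invalid embedding vector size".to_string())?;
        payload.insert(name.to_string(), AnyVector::Dense(embedding));
    }
    Ok(payload)
}

fn main() {
    // SEMANTIC_ENABLED = false: sparse-only point, as in
    // get_qdrant_points_from_splade_vecs.
    let sparse_only = build_vector_payload(None, vec![(7, 0.42)]).unwrap();
    assert_eq!(sparse_only.len(), 1);

    // SEMANTIC_ENABLED = true: the dense vector is keyed by its dimension.
    let both = build_vector_payload(Some(vec![0.0; 768]), vec![(7, 0.42)]).unwrap();
    assert!(both.contains_key("768_vectors"));
}

An if-let as above would also sidestep the paired is_some()/unwrap() calls that the new upload_chunk code uses for the same check.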
diff --git a/server/src/bin/ingestion-worker.rs b/server/src/bin/ingestion-worker.rs
index 8a20ee1ba5..d6bf22d30f 100644
--- a/server/src/bin/ingestion-worker.rs
+++ b/server/src/bin/ingestion-worker.rs
@@ -509,32 +509,38 @@ pub async fn bulk_upload_chunks(
         "calling_create_all_embeddings",
     );

-    // Assuming split average is false, Assume Explicit Vectors don't exist
-    let embedding_vectors = match create_embeddings(
-        content_and_boosts
-            .iter()
-            .map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone()))
-            .collect(),
-        "doc",
-        dataset_config.clone(),
-        reqwest_client.clone(),
-    )
-    .await
-    {
-        Ok(vectors) => Ok(vectors),
-        Err(err) => {
-            bulk_revert_insert_chunk_metadata_query(
-                inserted_chunk_metadata_ids.clone(),
-                web_pool.clone(),
+    let embedding_vectors = match dataset_config.SEMANTIC_ENABLED {
+        true => {
+            let vectors = match create_embeddings(
+                content_and_boosts
+                    .iter()
+                    .map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone()))
+                    .collect(),
+                "doc",
+                dataset_config.clone(),
+                reqwest_client.clone(),
             )
-            .await?;
-            Err(ServiceError::InternalServerError(format!(
-                "Failed to create embeddings: {:?}",
-                err
-            )))
+            .await
+            {
+                Ok(vectors) => Ok(vectors),
+                Err(err) => {
+                    bulk_revert_insert_chunk_metadata_query(
+                        inserted_chunk_metadata_ids.clone(),
+                        web_pool.clone(),
+                    )
+                    .await?;
+                    Err(ServiceError::InternalServerError(format!(
+                        "Failed to create embeddings: {:?}",
+                        err
+                    )))
+                }
+            }?;
+            Some(vectors)
         }
-    }?;
+        false => None,
+    };
+    // Assuming split average is false, Assume Explicit Vectors don't exist

     embedding_transaction.finish();

     let embedding_transaction = transaction.start_child(
@@ -573,71 +579,21 @@
     embedding_transaction.finish();

-    let qdrant_points = tokio_stream::iter(izip!(
-        inserted_chunk_metadatas.clone(),
-        embedding_vectors.iter(),
-        splade_vectors.iter(),
-    ))
-    .then(|(chunk_data, embedding_vector, splade_vector)| async {
-        let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
-
-        let chunk_tags: Option<Vec<Option<String>>> =
-            if let Some(ref group_ids) = chunk_data.group_ids {
-                Some(
-                    get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
-                        .await?
-                        .iter()
-                        .filter_map(|group| group.tag_set.clone())
-                        .flatten()
-                        .dedup()
-                        .collect(),
-                )
-            } else {
-                None
-            };
-
-        let payload = QdrantPayload::new(
-            chunk_data.chunk_metadata,
-            chunk_data.group_ids,
-            None,
-            chunk_tags,
-        )
-        .into();
-
-        let vector_name = match embedding_vector.len() {
-            384 => "384_vectors",
-            512 => "512_vectors",
-            768 => "768_vectors",
-            1024 => "1024_vectors",
-            3072 => "3072_vectors",
-            1536 => "1536_vectors",
-            _ => {
-                return Err(ServiceError::BadRequest(
-                    "Invalid embedding vector size".into(),
-                ))
-            }
-        };
-
-        let vector_payload = HashMap::from([
-            (
-                vector_name.to_string(),
-                Vector::from(embedding_vector.clone()),
-            ),
-            (
-                "sparse_vectors".to_string(),
-                Vector::from(splade_vector.clone()),
-            ),
-        ]);
-
-        // If qdrant_point_id does not exist, does not get written to qdrant
-        Ok(PointStruct::new(
-            qdrant_point_id.to_string(),
-            vector_payload,
-            payload,
-        ))
-    })
-    .collect::<Vec<Result<PointStruct, ServiceError>>>()
-    .await;
+    let qdrant_points = match embedding_vectors {
+        Some(embedding_vectors) => {
+            get_qdrant_points(
+                &ingestion_data,
+                embedding_vectors,
+                splade_vectors,
+                web_pool.clone(),
+            )
+            .await
+        }
+        None => {
+            get_qdrant_points_from_splade_vecs(&ingestion_data, splade_vectors, web_pool.clone())
+                .await
+        }
+    };

     if qdrant_points.iter().any(|point| point.is_err()) {
         Err(ServiceError::InternalServerError(
@@ -748,42 +704,51 @@ async fn upload_chunk(
         num_value: payload.chunk.num_value,
     };

-    let embedding_vector = match payload.chunk.split_avg.unwrap_or(false) {
+    let embedding_vector = match dataset_config.SEMANTIC_ENABLED {
         true => {
-            let chunks = coarse_doc_chunker(content.clone(), None, false, 20);
-
-            let embeddings = create_embeddings(
-                chunks
-                    .iter()
-                    .map(|chunk| (chunk.clone(), payload.chunk.distance_phrase.clone()))
-                    .collect(),
-                "doc",
-                dataset_config.clone(),
-                reqwest_client.clone(),
-            )
-            .await?;
-
-            average_embeddings(embeddings)?
-        }
-        false => {
-            let embedding_vectors = create_embeddings(
-                vec![(content.clone(), payload.chunk.distance_phrase.clone())],
-                "doc",
-                dataset_config.clone(),
-                reqwest_client.clone(),
-            )
-            .await
-            .map_err(|err| {
-                ServiceError::InternalServerError(format!("Failed to create embedding: {:?}", err))
-            })?;
-
-            embedding_vectors
-                .first()
-                .ok_or(ServiceError::InternalServerError(
-                    "Failed to get first embedding".into(),
-                ))?
-                .clone()
+            let embedding = match payload.chunk.split_avg.unwrap_or(false) {
+                true => {
+                    let chunks = coarse_doc_chunker(content.clone(), None, false, 20);
+
+                    let embeddings = create_embeddings(
+                        chunks
+                            .iter()
+                            .map(|chunk| (chunk.clone(), payload.chunk.distance_phrase.clone()))
+                            .collect(),
+                        "doc",
+                        dataset_config.clone(),
+                        reqwest_client.clone(),
+                    )
+                    .await?;
+
+                    average_embeddings(embeddings)?
+                }
+                false => {
+                    let embedding_vectors = create_embeddings(
+                        vec![(content.clone(), payload.chunk.distance_phrase.clone())],
+                        "doc",
+                        dataset_config.clone(),
+                        reqwest_client.clone(),
+                    )
+                    .await
+                    .map_err(|err| {
+                        ServiceError::InternalServerError(format!(
+                            "Failed to create embedding: {:?}",
+                            err
+                        ))
+                    })?;
+
+                    embedding_vectors
+                        .first()
+                        .ok_or(ServiceError::InternalServerError(
+                            "Failed to get first embedding".into(),
+                        ))?
+                        .clone()
+                }
+            };
+            Some(embedding)
         }
+        false => None,
     };

     let splade_vector = if dataset_config.FULLTEXT_ENABLED {
@@ -834,24 +799,32 @@
     let payload = QdrantPayload::new(chunk_metadata, payload.chunk.group_ids, None, chunk_tags).into();

-    let vector_name = match embedding_vector.len() {
-        384 => "384_vectors",
-        512 => "512_vectors",
-        768 => "768_vectors",
-        1024 => "1024_vectors",
-        3072 => "3072_vectors",
-        1536 => "1536_vectors",
-        _ => {
-            return Err(ServiceError::BadRequest(
-                "Invalid embedding vector size".into(),
-            ))
-        }
+    let vector_name = match &embedding_vector {
+        Some(embedding_vector) => match embedding_vector.len() {
+            384 => Some("384_vectors"),
+            512 => Some("512_vectors"),
+            768 => Some("768_vectors"),
+            1024 => Some("1024_vectors"),
+            3072 => Some("3072_vectors"),
+            1536 => Some("1536_vectors"),
+            _ => {
+                return Err(ServiceError::BadRequest(
+                    "Invalid embedding vector size".into(),
+                ))
+            }
+        },
+        None => None,
     };

-    let vector_payload = HashMap::from([
-        (vector_name.to_string(), Vector::from(embedding_vector)),
-        ("sparse_vectors".to_string(), Vector::from(splade_vector)),
-    ]);
+    let mut vector_payload =
+        HashMap::from([("sparse_vectors".to_string(), Vector::from(splade_vector))]);
+
+    if embedding_vector.is_some() && vector_name.is_some() {
+        vector_payload.insert(
+            vector_name.unwrap().to_string(),
+            Vector::from(embedding_vector.unwrap()),
+        );
+    }

     let point = PointStruct::new(qdrant_point_id.clone().to_string(), vector_payload, payload);

     let insert_tx = transaction.start_child(
@@ -900,14 +873,20 @@ async fn update_chunk(

     let chunk_metadata = payload.chunk_metadata.clone();

-    let embedding_vector = create_embedding(
-        content.to_string(),
-        payload.distance_phrase,
-        "doc",
-        server_dataset_config.clone(),
-    )
-    .await
-    .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+    let embedding_vector = match server_dataset_config.SEMANTIC_ENABLED {
+        true => {
+            let embedding = create_embedding(
+                content.to_string(),
+                payload.distance_phrase,
+                "doc",
+                server_dataset_config.clone(),
+            )
+            .await
+            .map_err(|err| ServiceError::BadRequest(err.to_string()))?;
+            Some(embedding)
+        }
+        false => None,
+    };

     let splade_vector = if server_dataset_config.FULLTEXT_ENABLED {
         let reqwest_client = reqwest::Client::new();
@@ -941,7 +920,7 @@
     update_qdrant_point_query(
         // If the chunk is a collision, we don't want to update the qdrant point
         chunk_metadata,
-        Some(embedding_vector),
+        embedding_vector,
         Some(chunk_group_ids),
         payload.dataset_id,
         splade_vector,
@@ -962,7 +941,7 @@
     update_qdrant_point_query(
         // If the chunk is a collision, we don't want to update the qdrant point
         chunk_metadata,
-        Some(embedding_vector),
+        embedding_vector,
         None,
         payload.dataset_id,
         splade_vector,
@@ -1057,3 +1036,124 @@ pub async fn readd_error_to_queue(

     Ok(())
 }
+
+async fn get_qdrant_points(
+    ingestion_data: &Vec<ChunkData>,
+    embedding_vectors: Vec<Vec<f32>>,
+    splade_vectors: Vec<Vec<(u32, f32)>>,
+    web_pool: actix_web::web::Data<models::Pool>,
+) -> Vec<Result<PointStruct, ServiceError>> {
+    tokio_stream::iter(izip!(
+        ingestion_data.clone(),
+        embedding_vectors.iter(),
+        splade_vectors.iter(),
+    ))
+    .then(|(chunk_data, embedding_vector, splade_vector)| async {
+        let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
+
+        let chunk_tags: Option<Vec<Option<String>>> =
+            if let Some(ref group_ids) = chunk_data.group_ids {
+                Some(
+                    get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
+                        .await?
+                        .iter()
+                        .filter_map(|group| group.tag_set.clone())
+                        .flatten()
+                        .dedup()
+                        .collect(),
+                )
+            } else {
+                None
+            };
+
+        let payload = QdrantPayload::new(
+            chunk_data.chunk_metadata,
+            chunk_data.group_ids,
+            None,
+            chunk_tags,
+        )
+        .into();
+
+        let vector_name = match embedding_vector.len() {
+            384 => "384_vectors",
+            512 => "512_vectors",
+            768 => "768_vectors",
+            1024 => "1024_vectors",
+            3072 => "3072_vectors",
+            1536 => "1536_vectors",
+            _ => {
+                return Err(ServiceError::BadRequest(
+                    "Invalid embedding vector size".into(),
+                ))
+            }
+        };
+
+        let vector_payload = HashMap::from([
+            (
+                vector_name.to_string(),
+                Vector::from(embedding_vector.clone()),
+            ),
+            (
+                "sparse_vectors".to_string(),
+                Vector::from(splade_vector.clone()),
+            ),
+        ]);
+
+        // If qdrant_point_id does not exist, does not get written to qdrant
+        Ok(PointStruct::new(
+            qdrant_point_id.to_string(),
+            vector_payload,
+            payload,
+        ))
+    })
+    .collect::<Vec<Result<PointStruct, ServiceError>>>()
+    .await
+}
+
+async fn get_qdrant_points_from_splade_vecs(
+    ingestion_data: &Vec<ChunkData>,
+    splade_vectors: Vec<Vec<(u32, f32)>>,
+    web_pool: actix_web::web::Data<models::Pool>,
+) -> Vec<Result<PointStruct, ServiceError>> {
+    tokio_stream::iter(izip!(ingestion_data.clone(), splade_vectors.iter(),))
+        .then(|(chunk_data, splade_vector)| async {
+            let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
+
+            let chunk_tags: Option<Vec<Option<String>>> =
+                if let Some(ref group_ids) = chunk_data.group_ids {
+                    Some(
+                        get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
+                            .await?
+                            .iter()
+                            .filter_map(|group| group.tag_set.clone())
+                            .flatten()
+                            .dedup()
+                            .collect(),
+                    )
+                } else {
+                    None
+                };
+
+            let payload = QdrantPayload::new(
+                chunk_data.chunk_metadata,
+                chunk_data.group_ids,
+                None,
+                chunk_tags,
+            )
+            .into();
+
+            let vector_payload = HashMap::from([(
+                "sparse_vectors".to_string(),
+                Vector::from(splade_vector.clone()),
+            )]);
+
+            // If qdrant_point_id does not exist, does not get written to qdrant
+            Ok(PointStruct::new(
+                qdrant_point_id.to_string(),
+                vector_payload,
+                payload,
+            ))
+        })
+        .collect::<Vec<Result<PointStruct, ServiceError>>>()
+        .await
+}
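Note (annotation, not part of the patch): the tag aggregation copied into both new helpers relies on Itertools::dedup, which only collapses consecutive duplicates, so a tag repeated across non-adjacent groups still appears twice. A self-contained sketch with a hypothetical Group shape (the real struct comes from the server's models module; requires the itertools crate):

use itertools::Itertools;

// Hypothetical stand-in for the server's group model; only the field the
// tag aggregation touches is reproduced here.
struct Group {
    tag_set: Option<Vec<Option<String>>>,
}

// Mirrors the chunk_tags aggregation in get_qdrant_points and
// get_qdrant_points_from_splade_vecs.
fn collect_tags(groups: &[Group]) -> Vec<Option<String>> {
    groups
        .iter()
        .filter_map(|group| group.tag_set.clone()) // drop groups without tags
        .flatten() // one Vec<Option<String>> per group -> flat stream of tags
        .dedup() // itertools: removes *consecutive* duplicates only
        .collect()
}

fn main() {
    let groups = vec![
        Group {
            tag_set: Some(vec![Some("a".to_string()), Some("a".to_string())]),
        },
        Group { tag_set: None },
        Group {
            tag_set: Some(vec![Some("b".to_string()), Some("a".to_string())]),
        },
    ];
    // Adjacent "a"s are deduped, but the non-adjacent trailing "a" survives.
    assert_eq!(
        collect_tags(&groups),
        vec![
            Some("a".to_string()),
            Some("b".to_string()),
            Some("a".to_string())
        ]
    );
}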