diff --git a/server/src/bin/bulk-ingestion-worker.rs b/server/src/bin/bulk-ingestion-worker.rs index 5041ba38ef..5125ba1beb 100644 --- a/server/src/bin/bulk-ingestion-worker.rs +++ b/server/src/bin/bulk-ingestion-worker.rs @@ -475,7 +475,7 @@ pub async fn bulk_upload_chunks( // Assuming split average is false, Assume Explicit Vectors don't exist let embedding_vectors = match dataset_config.SEMANTIC_ENABLED { true => { - let embedding_vecs = match create_embeddings( + let vectors = match create_embeddings( content_and_boosts .iter() .map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone())) @@ -492,9 +492,9 @@ pub async fn bulk_upload_chunks( err ))), }?; - Some(embedding_vecs) + vectors.into_iter().map(|vec| Some(vec)).collect() } - false => None, + false => vec![None; content_and_boosts.len()], }; embedding_transaction.finish(); @@ -528,12 +528,13 @@ pub async fn bulk_upload_chunks( embedding_transaction.finish(); - let qdrant_points = match embedding_vectors { - Some(embedding_vectors) => { - get_qdrant_points(&ingestion_data, embedding_vectors, splade_vectors, web_pool).await - } - None => get_qdrant_points_from_splade_vecs(&ingestion_data, splade_vectors, web_pool).await, - }; + let qdrant_points = get_qdrant_points( + &ingestion_data, + embedding_vectors, + splade_vectors, + web_pool.clone(), + ) + .await; if qdrant_points.iter().any(|point| point.is_err()) { Err(ServiceError::InternalServerError( @@ -1008,7 +1009,7 @@ pub async fn readd_error_to_queue( async fn get_qdrant_points( ingestion_data: &Vec, - embedding_vectors: Vec>, + embedding_vectors: Vec>>, splade_vectors: Vec>, web_pool: actix_web::web::Data, ) -> Vec> { @@ -1043,30 +1044,31 @@ async fn get_qdrant_points( ) .into(); - let vector_name = match embedding_vector.len() { - 384 => "384_vectors", - 512 => "512_vectors", - 768 => "768_vectors", - 1024 => "1024_vectors", - 3072 => "3072_vectors", - 1536 => "1536_vectors", - _ => { - return Err(ServiceError::BadRequest( - "Invalid embedding vector size".into(), - )) - } - }; - - let vector_payload = HashMap::from([ - ( - vector_name.to_string(), - Vector::from(embedding_vector.clone()), - ), - ( - "sparse_vectors".to_string(), - Vector::from(splade_vector.clone()), - ), - ]); + let mut vector_payload = HashMap::from([( + "sparse_vectors".to_string(), + Vector::from(splade_vector.clone()), + )]); + + if embedding_vector.is_some() { + let vector = embedding_vector.clone().unwrap(); + let vector_name = match vector.len() { + 384 => "384_vectors", + 512 => "512_vectors", + 768 => "768_vectors", + 1024 => "1024_vectors", + 3072 => "3072_vectors", + 1536 => "1536_vectors", + _ => { + return Err(ServiceError::BadRequest( + "Invalid embedding vector size".into(), + )) + } + }; + vector_payload.insert( + vector_name.to_string().clone(), + Vector::from(vector.clone()), + ); + } // If qdrant_point_id does not exist, does not get written to qdrant Ok(PointStruct::new( @@ -1078,51 +1080,3 @@ async fn get_qdrant_points( .collect::>>() .await } - -async fn get_qdrant_points_from_splade_vecs( - ingestion_data: &Vec, - splade_vectors: Vec>, - web_pool: actix_web::web::Data, -) -> Vec> { - tokio_stream::iter(izip!(ingestion_data.clone(), splade_vectors.iter(),)) - .then(|(chunk_data, splade_vector)| async { - let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id; - - let chunk_tags: Option>> = - if let Some(ref group_ids) = chunk_data.group_ids { - Some( - get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone()) - .await? - .iter() - .filter_map(|group| group.tag_set.clone()) - .flatten() - .dedup() - .collect(), - ) - } else { - None - }; - - let payload = QdrantPayload::new( - chunk_data.chunk_metadata, - chunk_data.group_ids, - None, - chunk_tags, - ) - .into(); - - let vector_payload = HashMap::from([( - "sparse_vectors".to_string(), - Vector::from(splade_vector.clone()), - )]); - - // If qdrant_point_id does not exist, does not get written to qdrant - Ok(PointStruct::new( - qdrant_point_id.to_string(), - vector_payload, - payload, - )) - }) - .collect::>>() - .await -} diff --git a/server/src/bin/ingestion-worker.rs b/server/src/bin/ingestion-worker.rs index d6bf22d30f..38bfc8faa3 100644 --- a/server/src/bin/ingestion-worker.rs +++ b/server/src/bin/ingestion-worker.rs @@ -535,9 +535,9 @@ pub async fn bulk_upload_chunks( ))) } }?; - Some(vectors) + vectors.into_iter().map(|vec| Some(vec)).collect() } - false => None, + false => vec![None; content_and_boosts.len()], }; // Assuming split average is false, Assume Explicit Vectors don't exist @@ -579,21 +579,13 @@ pub async fn bulk_upload_chunks( embedding_transaction.finish(); - let qdrant_points = match embedding_vectors { - Some(embedding_vectors) => { - get_qdrant_points( - &ingestion_data, - embedding_vectors, - splade_vectors, - web_pool.clone(), - ) - .await - } - None => { - get_qdrant_points_from_splade_vecs(&ingestion_data, splade_vectors, web_pool.clone()) - .await - } - }; + let qdrant_points = get_qdrant_points( + &ingestion_data, + embedding_vectors, + splade_vectors, + web_pool.clone(), + ) + .await; if qdrant_points.iter().any(|point| point.is_err()) { Err(ServiceError::InternalServerError( @@ -1039,7 +1031,7 @@ pub async fn readd_error_to_queue( async fn get_qdrant_points( ingestion_data: &Vec, - embedding_vectors: Vec>, + embedding_vectors: Vec>>, splade_vectors: Vec>, web_pool: actix_web::web::Data, ) -> Vec> { @@ -1074,30 +1066,31 @@ async fn get_qdrant_points( ) .into(); - let vector_name = match embedding_vector.len() { - 384 => "384_vectors", - 512 => "512_vectors", - 768 => "768_vectors", - 1024 => "1024_vectors", - 3072 => "3072_vectors", - 1536 => "1536_vectors", - _ => { - return Err(ServiceError::BadRequest( - "Invalid embedding vector size".into(), - )) - } - }; - - let vector_payload = HashMap::from([ - ( - vector_name.to_string(), - Vector::from(embedding_vector.clone()), - ), - ( - "sparse_vectors".to_string(), - Vector::from(splade_vector.clone()), - ), - ]); + let mut vector_payload = HashMap::from([( + "sparse_vectors".to_string(), + Vector::from(splade_vector.clone()), + )]); + + if embedding_vector.is_some() { + let vector = embedding_vector.clone().unwrap(); + let vector_name = match vector.len() { + 384 => "384_vectors", + 512 => "512_vectors", + 768 => "768_vectors", + 1024 => "1024_vectors", + 3072 => "3072_vectors", + 1536 => "1536_vectors", + _ => { + return Err(ServiceError::BadRequest( + "Invalid embedding vector size".into(), + )) + } + }; + vector_payload.insert( + vector_name.to_string().clone(), + Vector::from(vector.clone()), + ); + } // If qdrant_point_id does not exist, does not get written to qdrant Ok(PointStruct::new( @@ -1109,51 +1102,3 @@ async fn get_qdrant_points( .collect::>>() .await } - -async fn get_qdrant_points_from_splade_vecs( - ingestion_data: &Vec, - splade_vectors: Vec>, - web_pool: actix_web::web::Data, -) -> Vec> { - tokio_stream::iter(izip!(ingestion_data.clone(), splade_vectors.iter(),)) - .then(|(chunk_data, splade_vector)| async { - let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id; - - let chunk_tags: Option>> = - if let Some(ref group_ids) = chunk_data.group_ids { - Some( - get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone()) - .await? - .iter() - .filter_map(|group| group.tag_set.clone()) - .flatten() - .dedup() - .collect(), - ) - } else { - None - }; - - let payload = QdrantPayload::new( - chunk_data.chunk_metadata, - chunk_data.group_ids, - None, - chunk_tags, - ) - .into(); - - let vector_payload = HashMap::from([( - "sparse_vectors".to_string(), - Vector::from(splade_vector.clone()), - )]); - - // If qdrant_point_id does not exist, does not get written to qdrant - Ok(PointStruct::new( - qdrant_point_id.to_string(), - vector_payload, - payload, - )) - }) - .collect::>>() - .await -}