refactor ingestion workers
aaryanpunia authored and cdxker committed Jul 5, 2024
1 parent 93fda63 commit 30c6e79
Showing 2 changed files with 71 additions and 172 deletions.
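The core of this refactor is a change of shape: instead of an `Option<Vec<Vec<f32>>>` that forces a separate sparse-only code path (`get_qdrant_points_from_splade_vecs`), both workers now carry dense embeddings as `Vec<Option<Vec<f32>>>`, one optional vector per chunk, so a single `get_qdrant_points` builds every point. The sketch below is a minimal, self-contained illustration of that idea only; the `Vector` enum, `vector_name`, and `build_vector_payloads` are stand-ins for the worker's real qdrant-client types and functions, not code from this commit.

```rust
use std::collections::HashMap;

// Stand-in for qdrant_client's `Vector`; only the shape matters in this sketch.
#[derive(Debug)]
enum Vector {
    Dense(Vec<f32>),
    Sparse(Vec<(u32, f32)>),
}

// Map a dense vector's dimensionality to its named vector in Qdrant,
// mirroring the match inside `get_qdrant_points`.
fn vector_name(len: usize) -> Result<&'static str, String> {
    match len {
        384 => Ok("384_vectors"),
        512 => Ok("512_vectors"),
        768 => Ok("768_vectors"),
        1024 => Ok("1024_vectors"),
        1536 => Ok("1536_vectors"),
        3072 => Ok("3072_vectors"),
        _ => Err("Invalid embedding vector size".to_string()),
    }
}

// One named-vector map per chunk: always the sparse vector, plus the dense
// vector when the dataset has semantic search enabled.
fn build_vector_payloads(
    dense: Vec<Option<Vec<f32>>>, // vec![None; n] when SEMANTIC_ENABLED is false
    sparse: Vec<Vec<(u32, f32)>>,
) -> Result<Vec<HashMap<String, Vector>>, String> {
    dense
        .into_iter()
        .zip(sparse)
        .map(|(embedding, splade)| {
            let mut payload =
                HashMap::from([("sparse_vectors".to_string(), Vector::Sparse(splade))]);
            if let Some(embedding) = embedding {
                payload.insert(
                    vector_name(embedding.len())?.to_string(),
                    Vector::Dense(embedding),
                );
            }
            Ok(payload)
        })
        .collect()
}

fn main() -> Result<(), String> {
    // Two chunks: one with a 384-dim dense embedding, one sparse-only.
    let dense = vec![Some(vec![0.1_f32; 384]), None];
    let sparse = vec![vec![(3, 0.7), (11, 0.2)], vec![(5, 0.9)]];
    for payload in build_vector_payloads(dense, sparse)? {
        println!("named vectors: {:?}", payload.keys().collect::<Vec<_>>());
    }
    Ok(())
}
```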
118 changes: 36 additions & 82 deletions server/src/bin/bulk-ingestion-worker.rs
@@ -475,7 +475,7 @@ pub async fn bulk_upload_chunks(
     // Assuming split average is false, Assume Explicit Vectors don't exist
     let embedding_vectors = match dataset_config.SEMANTIC_ENABLED {
         true => {
-            let embedding_vecs = match create_embeddings(
+            let vectors = match create_embeddings(
                 content_and_boosts
                     .iter()
                     .map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone()))
@@ -492,9 +492,9 @@ pub async fn bulk_upload_chunks(
                     err
                 ))),
             }?;
-            Some(embedding_vecs)
+            vectors.into_iter().map(|vec| Some(vec)).collect()
         }
-        false => None,
+        false => vec![None; content_and_boosts.len()],
     };
 
     embedding_transaction.finish();
@@ -528,12 +528,13 @@ pub async fn bulk_upload_chunks(
 
     embedding_transaction.finish();
 
-    let qdrant_points = match embedding_vectors {
-        Some(embedding_vectors) => {
-            get_qdrant_points(&ingestion_data, embedding_vectors, splade_vectors, web_pool).await
-        }
-        None => get_qdrant_points_from_splade_vecs(&ingestion_data, splade_vectors, web_pool).await,
-    };
+    let qdrant_points = get_qdrant_points(
+        &ingestion_data,
+        embedding_vectors,
+        splade_vectors,
+        web_pool.clone(),
+    )
+    .await;
 
     if qdrant_points.iter().any(|point| point.is_err()) {
         Err(ServiceError::InternalServerError(
@@ -1008,7 +1009,7 @@ pub async fn readd_error_to_queue(
 
 async fn get_qdrant_points(
     ingestion_data: &Vec<ChunkData>,
-    embedding_vectors: Vec<Vec<f32>>,
+    embedding_vectors: Vec<Option<Vec<f32>>>,
     splade_vectors: Vec<Vec<(u32, f32)>>,
     web_pool: actix_web::web::Data<models::Pool>,
 ) -> Vec<Result<PointStruct, ServiceError>> {
@@ -1043,30 +1044,31 @@ async fn get_qdrant_points(
             )
             .into();
 
-            let vector_name = match embedding_vector.len() {
-                384 => "384_vectors",
-                512 => "512_vectors",
-                768 => "768_vectors",
-                1024 => "1024_vectors",
-                3072 => "3072_vectors",
-                1536 => "1536_vectors",
-                _ => {
-                    return Err(ServiceError::BadRequest(
-                        "Invalid embedding vector size".into(),
-                    ))
-                }
-            };
-
-            let vector_payload = HashMap::from([
-                (
-                    vector_name.to_string(),
-                    Vector::from(embedding_vector.clone()),
-                ),
-                (
-                    "sparse_vectors".to_string(),
-                    Vector::from(splade_vector.clone()),
-                ),
-            ]);
+            let mut vector_payload = HashMap::from([(
+                "sparse_vectors".to_string(),
+                Vector::from(splade_vector.clone()),
+            )]);
+
+            if embedding_vector.is_some() {
+                let vector = embedding_vector.clone().unwrap();
+                let vector_name = match vector.len() {
+                    384 => "384_vectors",
+                    512 => "512_vectors",
+                    768 => "768_vectors",
+                    1024 => "1024_vectors",
+                    3072 => "3072_vectors",
+                    1536 => "1536_vectors",
+                    _ => {
+                        return Err(ServiceError::BadRequest(
+                            "Invalid embedding vector size".into(),
+                        ))
+                    }
+                };
+                vector_payload.insert(
+                    vector_name.to_string().clone(),
+                    Vector::from(vector.clone()),
+                );
+            }
 
             // If qdrant_point_id does not exist, does not get written to qdrant
             Ok(PointStruct::new(
@@ -1078,51 +1080,3 @@
         .collect::<Vec<Result<PointStruct, ServiceError>>>()
         .await
 }
-
-async fn get_qdrant_points_from_splade_vecs(
-    ingestion_data: &Vec<ChunkData>,
-    splade_vectors: Vec<Vec<(u32, f32)>>,
-    web_pool: actix_web::web::Data<models::Pool>,
-) -> Vec<Result<PointStruct, ServiceError>> {
-    tokio_stream::iter(izip!(ingestion_data.clone(), splade_vectors.iter(),))
-        .then(|(chunk_data, splade_vector)| async {
-            let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
-
-            let chunk_tags: Option<Vec<Option<String>>> =
-                if let Some(ref group_ids) = chunk_data.group_ids {
-                    Some(
-                        get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
-                            .await?
-                            .iter()
-                            .filter_map(|group| group.tag_set.clone())
-                            .flatten()
-                            .dedup()
-                            .collect(),
-                    )
-                } else {
-                    None
-                };
-
-            let payload = QdrantPayload::new(
-                chunk_data.chunk_metadata,
-                chunk_data.group_ids,
-                None,
-                chunk_tags,
-            )
-            .into();
-
-            let vector_payload = HashMap::from([(
-                "sparse_vectors".to_string(),
-                Vector::from(splade_vector.clone()),
-            )]);
-
-            // If qdrant_point_id does not exist, does not get written to qdrant
-            Ok(PointStruct::new(
-                qdrant_point_id.to_string(),
-                vector_payload,
-                payload,
-            ))
-        })
-        .collect::<Vec<Result<PointStruct, ServiceError>>>()
-        .await
-}
125 changes: 35 additions & 90 deletions server/src/bin/ingestion-worker.rs
@@ -535,9 +535,9 @@ pub async fn bulk_upload_chunks(
                     )))
                 }
             }?;
-            Some(vectors)
+            vectors.into_iter().map(|vec| Some(vec)).collect()
         }
-        false => None,
+        false => vec![None; content_and_boosts.len()],
     };
 
     // Assuming split average is false, Assume Explicit Vectors don't exist
@@ -579,21 +579,13 @@ pub async fn bulk_upload_chunks(
 
     embedding_transaction.finish();
 
-    let qdrant_points = match embedding_vectors {
-        Some(embedding_vectors) => {
-            get_qdrant_points(
-                &ingestion_data,
-                embedding_vectors,
-                splade_vectors,
-                web_pool.clone(),
-            )
-            .await
-        }
-        None => {
-            get_qdrant_points_from_splade_vecs(&ingestion_data, splade_vectors, web_pool.clone())
-                .await
-        }
-    };
+    let qdrant_points = get_qdrant_points(
+        &ingestion_data,
+        embedding_vectors,
+        splade_vectors,
+        web_pool.clone(),
+    )
+    .await;
 
     if qdrant_points.iter().any(|point| point.is_err()) {
         Err(ServiceError::InternalServerError(
@@ -1039,7 +1031,7 @@ pub async fn readd_error_to_queue(
 
 async fn get_qdrant_points(
     ingestion_data: &Vec<ChunkData>,
-    embedding_vectors: Vec<Vec<f32>>,
+    embedding_vectors: Vec<Option<Vec<f32>>>,
     splade_vectors: Vec<Vec<(u32, f32)>>,
     web_pool: actix_web::web::Data<models::Pool>,
 ) -> Vec<Result<PointStruct, ServiceError>> {
@@ -1074,30 +1066,31 @@ async fn get_qdrant_points(
             )
             .into();
 
-            let vector_name = match embedding_vector.len() {
-                384 => "384_vectors",
-                512 => "512_vectors",
-                768 => "768_vectors",
-                1024 => "1024_vectors",
-                3072 => "3072_vectors",
-                1536 => "1536_vectors",
-                _ => {
-                    return Err(ServiceError::BadRequest(
-                        "Invalid embedding vector size".into(),
-                    ))
-                }
-            };
-
-            let vector_payload = HashMap::from([
-                (
-                    vector_name.to_string(),
-                    Vector::from(embedding_vector.clone()),
-                ),
-                (
-                    "sparse_vectors".to_string(),
-                    Vector::from(splade_vector.clone()),
-                ),
-            ]);
+            let mut vector_payload = HashMap::from([(
+                "sparse_vectors".to_string(),
+                Vector::from(splade_vector.clone()),
+            )]);
+
+            if embedding_vector.is_some() {
+                let vector = embedding_vector.clone().unwrap();
+                let vector_name = match vector.len() {
+                    384 => "384_vectors",
+                    512 => "512_vectors",
+                    768 => "768_vectors",
+                    1024 => "1024_vectors",
+                    3072 => "3072_vectors",
+                    1536 => "1536_vectors",
+                    _ => {
+                        return Err(ServiceError::BadRequest(
+                            "Invalid embedding vector size".into(),
+                        ))
+                    }
+                };
+                vector_payload.insert(
+                    vector_name.to_string().clone(),
+                    Vector::from(vector.clone()),
+                );
+            }
 
             // If qdrant_point_id does not exist, does not get written to qdrant
             Ok(PointStruct::new(
@@ -1109,51 +1102,3 @@
         .collect::<Vec<Result<PointStruct, ServiceError>>>()
         .await
 }
-
-async fn get_qdrant_points_from_splade_vecs(
-    ingestion_data: &Vec<ChunkData>,
-    splade_vectors: Vec<Vec<(u32, f32)>>,
-    web_pool: actix_web::web::Data<models::Pool>,
-) -> Vec<Result<PointStruct, ServiceError>> {
-    tokio_stream::iter(izip!(ingestion_data.clone(), splade_vectors.iter(),))
-        .then(|(chunk_data, splade_vector)| async {
-            let qdrant_point_id = chunk_data.chunk_metadata.qdrant_point_id;
-
-            let chunk_tags: Option<Vec<Option<String>>> =
-                if let Some(ref group_ids) = chunk_data.group_ids {
-                    Some(
-                        get_groups_from_group_ids_query(group_ids.clone(), web_pool.clone())
-                            .await?
-                            .iter()
-                            .filter_map(|group| group.tag_set.clone())
-                            .flatten()
-                            .dedup()
-                            .collect(),
-                    )
-                } else {
-                    None
-                };
-
-            let payload = QdrantPayload::new(
-                chunk_data.chunk_metadata,
-                chunk_data.group_ids,
-                None,
-                chunk_tags,
-            )
-            .into();
-
-            let vector_payload = HashMap::from([(
-                "sparse_vectors".to_string(),
-                Vector::from(splade_vector.clone()),
-            )]);
-
-            // If qdrant_point_id does not exist, does not get written to qdrant
-            Ok(PointStruct::new(
-                qdrant_point_id.to_string(),
-                vector_payload,
-                payload,
-            ))
-        })
-        .collect::<Vec<Result<PointStruct, ServiceError>>>()
-        .await
-}
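Both workers build that per-chunk dense column the same way, which is what lets the sparse-only helper be deleted: when semantic search is disabled they emit `vec![None; n]` rather than skipping the column entirely, so it stays the same length as the chunk list and still zips cleanly with the splade vectors. A small length-preserving sketch of that construction (`create_embeddings` here is a dummy stand-in for the real embedding call, not this commit's function):

```rust
// Build the dense-embedding column with exactly one entry per chunk.
fn dense_column(contents: &[String], semantic_enabled: bool) -> Vec<Option<Vec<f32>>> {
    if semantic_enabled {
        create_embeddings(contents).into_iter().map(Some).collect()
    } else {
        // Same length as `contents`, so zipping with the splade vectors
        // downstream still pairs chunk i with vector i.
        vec![None; contents.len()]
    }
}

// Dummy: a 384-dim zero vector per input, in place of the real model call.
fn create_embeddings(contents: &[String]) -> Vec<Vec<f32>> {
    contents.iter().map(|_| vec![0.0_f32; 384]).collect()
}

fn main() {
    let contents = vec!["chunk one".to_string(), "chunk two".to_string()];
    assert_eq!(dense_column(&contents, false).len(), contents.len());
    assert_eq!(dense_column(&contents, true).len(), contents.len());
    println!("dense column stays aligned with the chunk list either way");
}
```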
