Skip to content

Commit

Permalink
wip: modify embedding server calls
Browse files (browse the repository at this point in the history)
  • Loading branch information
aaryanpunia authored and cdxker committed Jul 5, 2024
1 parent 69b3b69 commit 08f509c
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 67 deletions.
156 changes: 89 additions & 67 deletions server/src/bin/bulk-ingestion-worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ pub async fn bulk_upload_chunks(
let chunk_metadata = ChunkMetadata {
id: message.ingest_specific_chunk_metadata.id,
link: message.chunk.link.clone(),
qdrant_point_id: qdrant_point_id,
qdrant_point_id,
created_at: chrono::Utc::now().naive_local(),
updated_at: chrono::Utc::now().naive_local(),
chunk_html: message.chunk.chunk_html.clone(),
Expand Down Expand Up @@ -473,23 +473,29 @@ pub async fn bulk_upload_chunks(
);

// Assumes split average is false and that explicit vectors don't exist.
let embedding_vectors = match create_embeddings(
content_and_boosts
.iter()
.map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone()))
.collect(),
"doc",
dataset_config.clone(),
reqwest_client.clone(),
)
.await
{
Ok(vectors) => Ok(vectors),
Err(err) => Err(ServiceError::InternalServerError(format!(
"Failed to create embeddings: {:?}",
err
))),
}?;
let embedding_vectors = match dataset_config.SEMANTIC_ENABLED {
true => {
let embedding_vecs = match create_embeddings(
content_and_boosts
.iter()
.map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone()))
.collect(),
"doc",
dataset_config.clone(),
reqwest_client.clone(),
)
.await
{
Ok(vectors) => Ok(vectors),
Err(err) => Err(ServiceError::InternalServerError(format!(
"Failed to create embeddings: {:?}",
err
))),
}?;
Some(embedding_vecs)
}
false => None,
};

embedding_transaction.finish();

Expand Down Expand Up @@ -716,42 +722,50 @@ async fn upload_chunk(
num_value: payload.chunk.num_value,
};

let embedding_vector = match payload.chunk.split_avg.unwrap_or(false) {
true => {
let chunks = coarse_doc_chunker(content.clone(), None, false, 20);
let embedding_vector = match dataset_config.FULLTEXT_ENABLED {
true => match payload.chunk.split_avg.unwrap_or(false) {
true => {
let chunks = coarse_doc_chunker(content.clone(), None, false, 20);

let embeddings = create_embeddings(
chunks
.iter()
.map(|chunk| (chunk.clone(), payload.chunk.distance_phrase.clone()))
.collect(),
"doc",
dataset_config.clone(),
reqwest_client.clone(),
)
.await?;
let embeddings = create_embeddings(
chunks
.iter()
.map(|chunk| (chunk.clone(), payload.chunk.distance_phrase.clone()))
.collect(),
"doc",
dataset_config.clone(),
reqwest_client.clone(),
)
.await?;

average_embeddings(embeddings)?
}
false => {
let embedding_vectors = create_embeddings(
vec![(content.clone(), payload.chunk.distance_phrase.clone())],
"doc",
dataset_config.clone(),
reqwest_client.clone(),
)
.await
.map_err(|err| {
ServiceError::InternalServerError(format!("Failed to create embedding: {:?}", err))
})?;
Some(average_embeddings(embeddings)?)
}
false => {
let embedding_vectors = create_embeddings(
vec![(content.clone(), payload.chunk.distance_phrase.clone())],
"doc",
dataset_config.clone(),
reqwest_client.clone(),
)
.await
.map_err(|err| {
ServiceError::InternalServerError(format!(
"Failed to create embedding: {:?}",
err
))
})?;

embedding_vectors
.first()
.ok_or(ServiceError::InternalServerError(
"Failed to get first embedding".into(),
))?
.clone()
}
Some(
embedding_vectors
.first()
.ok_or(ServiceError::InternalServerError(
"Failed to get first embedding".into(),
))?
.clone(),
)
}
},
false => None,
};

let splade_vector = if dataset_config.FULLTEXT_ENABLED {
Expand Down Expand Up @@ -797,24 +811,32 @@ async fn upload_chunk(
)
.into();

let vector_name = match embedding_vector.len() {
384 => "384_vectors",
512 => "512_vectors",
768 => "768_vectors",
1024 => "1024_vectors",
3072 => "3072_vectors",
1536 => "1536_vectors",
_ => {
return Err(ServiceError::BadRequest(
"Invalid embedding vector size".into(),
))
}
let vector_name = match &embedding_vector {
Some(embedding_vector) => match embedding_vector.len() {
384 => Some("384_vectors"),
512 => Some("512_vectors"),
768 => Some("768_vectors"),
1024 => Some("1024_vectors"),
3072 => Some("3072_vectors"),
1536 => Some("1536_vectors"),
_ => {
return Err(ServiceError::BadRequest(
"Invalid embedding vector size".into(),
))
}
},
None => None,
};

let vector_payload = HashMap::from([
(vector_name.to_string(), Vector::from(embedding_vector)),
("sparse_vectors".to_string(), Vector::from(splade_vector)),
]);
let mut vector_payload =
HashMap::from([("sparse_vectors".to_string(), Vector::from(splade_vector))]);

if embedding_vector.is_some() && vector_name.is_some() {
vector_payload.insert(
vector_name.unwrap().to_string(),
Vector::from(embedding_vector.unwrap()),
);
}

let point = PointStruct::new(
qdrant_point_id.clone().to_string(),
Expand Down
18 changes: 18 additions & 0 deletions server/src/handlers/chunk_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,12 @@ pub async fn search_chunks(
.await?
}
"semantic" => {
if !server_dataset_config.SEMANTIC_ENABLED {
return Err(ServiceError::BadRequest(
"Semantic search is not enabled for this dataset".into(),
)
.into());
}
search_semantic_chunks(
data.clone(),
parsed_query,
Expand Down Expand Up @@ -1325,6 +1331,12 @@ pub async fn autocomplete(
.await?
}
"semantic" => {
if !server_dataset_config.SEMANTIC_ENABLED {
return Err(ServiceError::BadRequest(
"Semantic search is not enabled for this dataset".into(),
)
.into());
}
autocomplete_semantic_chunks(
data.clone(),
parsed_query,
Expand Down Expand Up @@ -1519,6 +1531,12 @@ pub async fn count_chunks(
.await?
}
"semantic" => {
if !server_dataset_config.SEMANTIC_ENABLED {
return Err(ServiceError::BadRequest(
"Semantic search is not enabled for this dataset".into(),
)
.into());
}
count_semantic_chunks(
search_req_data.clone(),
parsed_query,
Expand Down
12 changes: 12 additions & 0 deletions server/src/handlers/group_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,12 @@ pub async fn search_within_group(
.await?
}
_ => {
if !server_dataset_config.SEMANTIC_ENABLED {
return Err(ServiceError::BadRequest(
"Semantic search is not enabled for this dataset".into(),
)
.into());
}
search_semantic_groups(
data.clone(),
parsed_query,
Expand Down Expand Up @@ -1405,6 +1411,12 @@ pub async fn search_over_groups(
.await?
}
_ => {
if !server_dataset_config.SEMANTIC_ENABLED {
return Err(ServiceError::BadRequest(
"Semantic search is not enabled for this dataset".into(),
)
.into());
}
semantic_search_over_groups(
data.clone(),
parsed_query,
Expand Down

0 comments on commit 08f509c

Please sign in to comment.