Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: chunks are stored in completion-first order when completion_first = true #2245

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions frontends/chat/src/components/Atoms/AfMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ export const AfMessage = (props: AfMessageProps) => {
const split_content = props.content.split("||");
let content = props.content;
if (split_content.length > 1) {
setChunkMetadatas(JSON.parse(split_content[0]));
content = split_content[1];
if (split_content[0].startsWith("[{")) {
setChunkMetadatas(JSON.parse(split_content[0]));
content = split_content[1];
} else {
content = split_content[0];
setChunkMetadatas(JSON.parse(split_content[1]));
}
} else if (props.content.length > 25) {
return {
content:
Expand Down
25 changes: 25 additions & 0 deletions frontends/chat/src/components/Layouts/MainLayout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ const MainLayout = (props: LayoutProps) => {
const [concatUserMessagesQuery, setConcatUserMessagesQuery] = createSignal<
boolean | null
>(null);

const [streamCompletionsFirst, setStreamCompletionsFirst] = createSignal<
boolean | null
>(null);

const [pageSize, setPageSize] = createSignal<number | null>(null);
const [searchQuery, setSearchQuery] = createSignal<string | null>(null);
const [minScore, setMinScore] = createSignal<number | null>(null);
Expand Down Expand Up @@ -193,6 +198,9 @@ const MainLayout = (props: LayoutProps) => {
new_message_content,
topic_id: finalTopicId,
system_prompt: systemPrompt(),
llm_options: {
completion_first: streamCompletionsFirst(),
},
}),
signal: completionAbortController().signal,
});
Expand Down Expand Up @@ -304,6 +312,9 @@ const MainLayout = (props: LayoutProps) => {
new_message_content: content,
message_sort_order: idx(),
topic_id: props.selectedTopic?.id,
llm_options: {
completion_first: streamCompletionsFirst(),
},
}),
})
.then((response) => {
Expand Down Expand Up @@ -384,6 +395,20 @@ const MainLayout = (props: LayoutProps) => {
tabIndex={0}
>
<div class="flex flex-col gap-2">
<div class="flex w-full items-center gap-x-2">
<label for="stream_completion_first">
Stream Completions First
</label>
<input
type="checkbox"
id="stream_completion_first"
class="h-4 w-4 rounded-md border border-neutral-300 bg-neutral-100 p-1 dark:border-neutral-900 dark:bg-neutral-800"
checked={streamCompletionsFirst() ?? false}
onChange={(e) => {
setStreamCompletionsFirst(e.target.checked);
}}
/>
</div>
<div class="flex w-full items-center gap-x-2">
<label for="concat_user_messages">
Concatenate User Messages:
Expand Down
157 changes: 59 additions & 98 deletions server/src/handlers/message_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,23 @@ pub async fn create_message(
.map(|message| {
let mut message = message;
if message.role == "assistant" {
message.content = message
.content
.split("||")
.last()
.unwrap_or("I give up, I can't find chunks for this message")
.to_string();
if message.content.starts_with("[{") {
// This is (chunks, content)
message.content = message
.content
.split("||")
.last()
.unwrap_or("I give up, I can't find a citation")
.to_string();
} else {
// This is (content, chunks)
message.content = message
.content
.rsplit("||")
.last()
.unwrap_or("I give up, I can't find a citation")
.to_string();
}
}
message
})
Expand Down Expand Up @@ -247,8 +258,23 @@ pub async fn get_all_topic_messages(
) -> Result<HttpResponse, actix_web::Error> {
let topic_id: uuid::Uuid = messages_topic_id.into_inner();

let messages =
get_messages_for_topic_query(topic_id, dataset_org_plan_sub.dataset.id, &pool).await?;
let messages: Vec<models::Message> =
get_messages_for_topic_query(topic_id, dataset_org_plan_sub.dataset.id, &pool)
.await?
.into_iter()
.filter_map(|mut message| {
if message.content.starts_with("||[{") {
match message.content.rsplit_once("}]") {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a Rust format string, `}}` is the escape sequence for a single literal `}` — a lone `}` would be a format-string syntax error, so the doubled brace here is required rather than a stylistic choice.

Some((chunks, ai_message)) => {
message.content = format!("{}{}}}]", ai_message, chunks);
}
_ => return None,
}
}

Some(message)
})
.collect();

Ok(HttpResponse::Ok().json(messages))
}
Expand Down Expand Up @@ -471,12 +497,30 @@ pub async fn regenerate_message_patch(
.map(|message| {
let mut message = message;
if message.role == "assistant" {
message.content = message
.content
.split("||")
.last()
.unwrap_or("I give up, I can't find a citation")
.to_string();
if message.content.starts_with("||[{") {
match message.content.rsplit_once("}]") {
Some((_, ai_message)) => {
message.content = ai_message.to_string();
}
_ => return message,
}
} else if message.content.starts_with("[{") {
// This is (chunks, content)
message.content = message
.content
.split("||")
.last()
.unwrap_or("I give up, I can't find a citation")
.to_string();
} else {
// This is (content, chunks)
message.content = message
.content
.rsplit("||")
.last()
.unwrap_or("I give up, I can't find a citation")
.to_string();
}
}
message
})
Expand Down Expand Up @@ -549,90 +593,7 @@ pub async fn regenerate_message(
pool: web::Data<Pool>,
event_queue: web::Data<EventQueue>,
) -> Result<HttpResponse, actix_web::Error> {
let topic_id = data.topic_id;
let dataset_config =
DatasetConfiguration::from_json(dataset_org_plan_sub.dataset.server_configuration.clone());

check_completion_param_validity(data.llm_options.clone())?;

let get_messages_pool = pool.clone();
let create_message_pool = pool.clone();
let dataset_id = dataset_org_plan_sub.dataset.id;

let mut previous_messages =
get_topic_messages(topic_id, dataset_id, &get_messages_pool).await?;

if previous_messages.len() < 2 {
return Err(
ServiceError::BadRequest("Not enough messages to regenerate".to_string()).into(),
);
}

if previous_messages.len() == 2 {
return stream_response(
previous_messages,
topic_id,
dataset_org_plan_sub.dataset,
create_message_pool,
event_queue,
dataset_config,
data.into_inner().into(),
)
.await;
}

// remove citations from the previous messages
previous_messages = previous_messages
.into_iter()
.map(|message| {
let mut message = message;
if message.role == "assistant" {
message.content = message
.content
.split("||")
.last()
.unwrap_or("I give up, I can't find a citation")
.to_string();
}
message
})
.collect::<Vec<models::Message>>();

let mut message_to_regenerate = None;
for message in previous_messages.iter().rev() {
if message.role == "assistant" {
message_to_regenerate = Some(message.clone());
break;
}
}

let message_id = match message_to_regenerate {
Some(message) => message.id,
None => {
return Err(ServiceError::BadRequest("No message to regenerate".to_string()).into());
}
};

let mut previous_messages_to_regenerate = Vec::new();
for message in previous_messages.iter() {
if message.id == message_id {
break;
}
previous_messages_to_regenerate.push(message.clone());
}

delete_message_query(message_id, topic_id, dataset_id, &pool).await?;

stream_response(
previous_messages_to_regenerate,
topic_id,
dataset_org_plan_sub.dataset,
create_message_pool,
event_queue,
dataset_config,
data.into_inner().into(),
)
.await
regenerate_message_patch(data, user, dataset_org_plan_sub, pool, event_queue).await
}

#[derive(Deserialize, Serialize, Debug, ToSchema)]
Expand Down
15 changes: 14 additions & 1 deletion server/src/operators/message_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -611,12 +611,25 @@ pub async fn stream_response(
let (s, r) = unbounded::<String>();
let stream = client.chat().create_stream(parameters).await.unwrap();

let completion_first = create_message_req_payload
.llm_options
.as_ref()
.map(|x| x.completion_first)
.unwrap_or(Some(false))
.unwrap_or(false);

Arbiter::new().spawn(async move {
let chunk_v: Vec<String> = r.iter().collect();
let completion = chunk_v.join("");

let message_to_be_stored = if completion_first {
format!("{}{}", completion, chunk_metadatas_stringified)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes the most sense to store the chunks in the same order the user asked for them to be streamed (completion-first or chunks-first). This should also make it much easier to keep the front ends modular.

} else {
format!("{}{}", chunk_metadatas_stringified, completion)
};

let new_message = models::Message::from_details(
format!("{}{}", chunk_metadatas_stringified, completion),
message_to_be_stored,
topic_id,
next_message_order().try_into().unwrap(),
"assistant".to_string(),
Expand Down
Loading
Loading