Skip to content

Commit

Permalink
chore: find question message from reply message (#1085)
Browse files Browse the repository at this point in the history
* chore: find question message from answer message id

* chore: sqlx

* test: fix tests

* test: fix test

* chore: apply code suggestions to 2 files
  • Loading branch information
richardshiue authored Dec 18, 2024
1 parent e758f18 commit ecadf8e
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 9 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions libs/client-api/src/http_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,28 @@ impl Client {
.into_data()
}

pub async fn get_question_message_from_answer_id(
&self,
workspace_id: &str,
chat_id: &str,
answer_message_id: i64,
) -> Result<Option<ChatMessage>, AppResponseError> {
let url = format!(
"{}/api/chat/{workspace_id}/{chat_id}/message/find_question",
self.base_url
);

let resp = self
.http_client_with_auth(Method::GET, &url)
.await?
.query(&[("answer_message_id", answer_message_id)])
.send()
.await?;
AppResponse::<Option<ChatMessage>>::from_response(resp)
.await?
.into_data()
}

pub async fn calculate_similarity(
&self,
params: CalculateSimilarityParams,
Expand Down
37 changes: 37 additions & 0 deletions libs/database/src/chat/chat_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,40 @@ pub async fn select_chat_message_content<'a, E: Executor<'a, Database = Postgres
.await?;
Ok((row.content, row.meta_data))
}

pub async fn select_chat_message_matching_reply_message_id(
txn: &mut Transaction<'_, Postgres>,
chat_id: &str,
reply_message_id: i64,
) -> Result<Option<ChatMessage>, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let row = sqlx::query!(
r#"
SELECT message_id, content, created_at, author, meta_data, reply_message_id
FROM af_chat_messages
WHERE chat_id = $1
AND reply_message_id = $2
"#,
&chat_id,
reply_message_id
)
.fetch_one(txn.deref_mut())
.await?;

let message = match serde_json::from_value::<ChatAuthor>(row.author) {
Ok(author) => Some(ChatMessage {
author,
message_id: row.message_id,
content: row.content,
created_at: row.created_at,
meta_data: row.meta_data,
reply_message_id: row.reply_message_id,
}),
Err(err) => {
warn!("Failed to deserialize author: {}", err);
None
},
};

Ok(message)
}
23 changes: 22 additions & 1 deletion src/api/chat.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::biz::chat::ops::{
create_chat, create_chat_message, delete_chat, generate_chat_message_answer, get_chat_messages,
update_chat_message,
get_question_message, update_chat_message,
};
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, HttpRequest, HttpResponse, Scope};
use serde::Deserialize;

use crate::api::util::ai_model_from_header;
use app_error::AppError;
Expand Down Expand Up @@ -69,6 +70,10 @@ pub fn chat_scope() -> Scope {
web::resource("/{chat_id}/message/answer")
.route(web::post().to(save_answer_handler))
)
.service(
web::resource("/{chat_id}/message/find_question")
.route(web::get().to(get_chat_question_message_handler))
)

// AI response generation
.service(
Expand Down Expand Up @@ -349,6 +354,17 @@ async fn get_chat_message_handler(
Ok(AppResponse::Ok().with_data(messages).into())
}

#[instrument(level = "debug", skip_all, err)]
async fn get_chat_question_message_handler(
path: web::Path<(String, String)>,
query: web::Query<FindQuestionParams>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<Option<ChatMessage>>> {
let (_workspace_id, chat_id) = path.into_inner();
let message = get_question_message(&state.pg_pool, &chat_id, query.0.answer_message_id).await?;
Ok(AppResponse::Ok().with_data(message).into())
}

#[instrument(level = "debug", skip_all, err)]
async fn get_chat_settings_handler(
path: web::Path<(String, String)>,
Expand Down Expand Up @@ -501,3 +517,8 @@ where
}
}
}

#[derive(Debug, Deserialize)]
struct FindQuestionParams {
answer_message_id: i64,
}
14 changes: 13 additions & 1 deletion src/biz/chat/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use database::chat;
use database::chat::chat_ops::{
delete_answer_message_by_question_message_id, insert_answer_message,
insert_answer_message_with_transaction, insert_chat, insert_question_message,
select_chat_messages,
select_chat_message_matching_reply_message_id, select_chat_messages,
};
use futures::stream::Stream;
use serde_json::json;
Expand Down Expand Up @@ -232,3 +232,15 @@ pub async fn get_chat_messages(
txn.commit().await?;
Ok(messages)
}

pub async fn get_question_message(
pg_pool: &PgPool,
chat_id: &str,
answer_message_id: i64,
) -> Result<Option<ChatMessage>, AppError> {
let mut txn = pg_pool.begin().await?;
let message =
select_chat_message_matching_reply_message_id(&mut txn, chat_id, answer_message_id).await?;
txn.commit().await?;
Ok(message)
}
76 changes: 69 additions & 7 deletions tests/ai_test/chat_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use client_api_test::{ai_test_enabled, TestClient};
use futures_util::StreamExt;
use serde_json::json;
use shared_entity::dto::chat_dto::{
ChatMessageMetadata, ChatRAGData, CreateChatMessageParams, CreateChatParams, MessageCursor,
UpdateChatParams,
ChatMessageMetadata, ChatRAGData, CreateAnswerMessageParams, CreateChatMessageParams,
CreateChatParams, MessageCursor, UpdateChatParams,
};

#[tokio::test]
Expand Down Expand Up @@ -344,6 +344,10 @@ async fn create_chat_context_test() {

// #[tokio::test]
// async fn update_chat_message_test() {
// if !ai_test_enabled() {
// return;
// }

// let test_client = TestClient::new_user_without_ws_conn().await;
// let workspace_id = test_client.workspace_id().await;
// let chat_id = uuid::Uuid::new_v4().to_string();
Expand All @@ -352,13 +356,13 @@ async fn create_chat_context_test() {
// name: "my second chat".to_string(),
// rag_ids: vec![],
// };
//

// test_client
// .api_client
// .create_chat(&workspace_id, params)
// .await
// .unwrap();
//

// let params = CreateChatMessageParams::new_user("where is singapore?");
// let stream = test_client
// .api_client
Expand All @@ -367,7 +371,7 @@ async fn create_chat_context_test() {
// .unwrap();
// let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
// assert_eq!(messages.len(), 2);
//

// let params = UpdateChatMessageContentParams {
// chat_id: chat_id.clone(),
// message_id: messages[0].message_id,
Expand All @@ -378,7 +382,7 @@ async fn create_chat_context_test() {
// .update_chat_message(&workspace_id, &chat_id, params)
// .await
// .unwrap();
//

// let remote_messages = test_client
// .api_client
// .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
Expand All @@ -387,11 +391,69 @@ async fn create_chat_context_test() {
// .messages;
// assert_eq!(remote_messages[0].content, "where is China?");
// assert_eq!(remote_messages.len(), 2);
//

// // when the question was updated, the answer should be different
// assert_ne!(remote_messages[1].content, messages[1].content);
// }

#[tokio::test]
async fn get_question_message_test() {
if !ai_test_enabled() {
return;
}

let test_client = TestClient::new_user_without_ws_conn().await;
let workspace_id = test_client.workspace_id().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let params = CreateChatParams {
chat_id: chat_id.clone(),
name: "my ai chat".to_string(),
rag_ids: vec![],
};

test_client
.api_client
.create_chat(&workspace_id, params)
.await
.unwrap();

let params = CreateChatMessageParams::new_user("where is singapore?");
let question = test_client
.api_client
.create_question(&workspace_id, &chat_id, params)
.await
.unwrap();

let answer = test_client
.api_client
.get_answer(&workspace_id, &chat_id, question.message_id)
.await
.unwrap();

test_client
.api_client
.save_answer(
&workspace_id,
&chat_id,
CreateAnswerMessageParams {
content: answer.content,
metadata: None,
question_message_id: question.message_id,
},
)
.await
.unwrap();

let find_question = test_client
.api_client
.get_question_message_from_answer_id(&workspace_id, &chat_id, answer.message_id)
.await
.unwrap()
.unwrap();

assert_eq!(find_question.reply_message_id.unwrap(), answer.message_id);
}

async fn collect_answer(mut stream: QuestionStream) -> String {
let mut answer = String::new();
while let Some(value) = stream.next().await {
Expand Down

0 comments on commit ecadf8e

Please sign in to comment.