From 14adf250a87b10afe1c5dd169f6a3049d37b08ca Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:10:46 +0800 Subject: [PATCH] chore: update ai api endpoint & rename structs (#953) * chore: revamp loader type * chore: revamp context * chore: fix test * chore: remove unused handler * chore: create new chat api endpoint * chore: add docs * chore: clippy * chore: fix test --- Cargo.lock | 3 + libs/appflowy-ai-client/src/client.rs | 7 +- libs/appflowy-ai-client/src/dto.rs | 75 +--- .../tests/chat_test/context_test.rs | 6 +- libs/client-api/src/http_chat.rs | 2 +- libs/database-entity/Cargo.toml | 1 + libs/database-entity/src/dto.rs | 317 +-------------- libs/database-entity/src/lib.rs | 1 - libs/database/Cargo.toml | 1 + libs/database/src/chat/chat_ops.rs | 12 +- libs/database/src/index/search_ops.rs | 7 + libs/infra/Cargo.toml | 4 + libs/infra/src/lib.rs | 1 + .../src/util.rs => infra/src/validate.rs} | 4 +- libs/shared-entity/Cargo.toml | 6 +- libs/shared-entity/src/dto/ai_dto.rs | 2 +- libs/shared-entity/src/dto/chat_dto.rs | 363 ++++++++++++++++++ libs/shared-entity/src/dto/mod.rs | 1 + script/client_api_deps_check.sh | 2 +- services/appflowy-worker/Cargo.toml | 2 +- src/api/chat.rs | 189 ++++----- src/biz/chat/ops.rs | 80 +--- tests/ai_test/chat_test.rs | 27 +- tests/sql_test/chat_test.rs | 6 +- 24 files changed, 526 insertions(+), 593 deletions(-) rename libs/{database-entity/src/util.rs => infra/src/validate.rs} (59%) create mode 100644 libs/shared-entity/src/dto/chat_dto.rs diff --git a/Cargo.lock b/Cargo.lock index 200303262..7824ccceb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2981,6 +2981,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "shared-entity", "sqlx", "tokio", "tonic-proto", @@ -3000,6 +3001,7 @@ dependencies = [ "bytes", "chrono", "collab-entity", + "infra", "prost", "serde", "serde_json", @@ -4197,6 +4199,7 @@ dependencies = [ "serde_json", "tokio", "tracing", + "validator", 
] [[package]] diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 1110dda74..7c57425de 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,5 +1,5 @@ use crate::dto::{ - AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateTextChatContext, + AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateChatContext, CustomPrompt, Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData, RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, TranslateRowData, TranslateRowResponse, @@ -188,10 +188,7 @@ impl AppFlowyAIClient { .into_data() } - pub async fn create_chat_text_context( - &self, - context: CreateTextChatContext, - ) -> Result<(), AIError> { + pub async fn create_chat_text_context(&self, context: CreateChatContext) -> Result<(), AIError> { let url = format!("{}/chat/context/text", self.url); let resp = self .http_client(Method::POST, &url)? 
diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index ac01caeff..19b1dbbcd 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -1,4 +1,5 @@ -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Serialize, Serializer}; +use serde_json::json; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::collections::HashMap; use std::fmt::{Display, Formatter}; @@ -274,87 +275,35 @@ pub struct LocalAIConfig { pub plugin: AppFlowyOfflineAI, } -#[derive(Debug, Clone)] -pub enum ChatContextLoader { - Txt, - Markdown, -} - -impl Display for ChatContextLoader { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - ChatContextLoader::Txt => write!(f, "text"), - ChatContextLoader::Markdown => write!(f, "markdown"), - } - } -} - -impl FromStr for ChatContextLoader { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s { - "text" => Ok(ChatContextLoader::Txt), - "markdown" => Ok(ChatContextLoader::Markdown), - _ => Err(anyhow::anyhow!("unknown context loader type")), - } - } -} -impl Serialize for ChatContextLoader { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - ChatContextLoader::Txt => serializer.serialize_str("text"), - ChatContextLoader::Markdown => serializer.serialize_str("markdown"), - } - } -} - -impl<'de> Deserialize<'de> for ChatContextLoader { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - match s.as_str() { - "text" => Ok(ChatContextLoader::Txt), - "markdown" => Ok(ChatContextLoader::Markdown), - _ => Err(serde::de::Error::custom("unknown value")), - } - } -} - #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct CreateTextChatContext { +pub struct CreateChatContext { pub chat_id: String, - pub context_loader: ChatContextLoader, + pub context_loader: String, pub content: 
String, pub chunk_size: i32, pub chunk_overlap: i32, - pub metadata: HashMap, + pub metadata: serde_json::Value, } -impl CreateTextChatContext { - pub fn new(chat_id: String, context_loader: ChatContextLoader, text: String) -> Self { - CreateTextChatContext { +impl CreateChatContext { + pub fn new(chat_id: String, context_loader: String, text: String) -> Self { + CreateChatContext { chat_id, context_loader, content: text, chunk_size: 2000, chunk_overlap: 20, - metadata: HashMap::new(), + metadata: json!({}), } } - pub fn with_metadata(mut self, metadata: HashMap) -> Self { - self.metadata = metadata; + pub fn with_metadata(mut self, metadata: T) -> Self { + self.metadata = json!(metadata); self } } -impl Display for CreateTextChatContext { +impl Display for CreateChatContext { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "Create Chat context: {{ chat_id: {}, content_type: {}, content size: {}, metadata: {:?} }}", diff --git a/libs/appflowy-ai-client/tests/chat_test/context_test.rs b/libs/appflowy-ai-client/tests/chat_test/context_test.rs index 61ba07171..d9a265725 100644 --- a/libs/appflowy-ai-client/tests/chat_test/context_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/context_test.rs @@ -1,12 +1,12 @@ use crate::appflowy_ai_client; -use appflowy_ai_client::dto::{AIModel, ChatContextLoader, CreateTextChatContext}; +use appflowy_ai_client::dto::{AIModel, CreateChatContext}; #[tokio::test] async fn create_chat_context_test() { let client = appflowy_ai_client(); let chat_id = uuid::Uuid::new_v4().to_string(); - let context = CreateTextChatContext { + let context = CreateChatContext { chat_id: chat_id.clone(), - context_loader: ChatContextLoader::Txt, + context_loader: "text".to_string(), content: "I have lived in the US for five years".to_string(), chunk_size: 1000, chunk_overlap: 20, diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index 1cd4e7165..06a410d86 100644 --- 
a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -1,7 +1,7 @@ use crate::http::log_request_id; use crate::Client; -use client_api_entity::{ +use client_api_entity::chat_dto::{ ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams, MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams, }; diff --git a/libs/database-entity/Cargo.toml b/libs/database-entity/Cargo.toml index 17d6457ff..6d5c720c9 100644 --- a/libs/database-entity/Cargo.toml +++ b/libs/database-entity/Cargo.toml @@ -23,3 +23,4 @@ bincode = "1.3.3" appflowy-ai-client = { workspace = true, features = ["dto"] } bytes.workspace = true prost = "0.12" +infra.workspace = true diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index c8da00c6f..b614c565d 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -1,14 +1,14 @@ use crate::error::EntityError; use crate::error::EntityError::{DeserializationError, InvalidData}; -use crate::util::{validate_not_empty_payload, validate_not_empty_str}; -use appflowy_ai_client::dto::AIModel; + use bytes::Bytes; use chrono::{DateTime, Utc}; use collab_entity::proto; use collab_entity::CollabType; +use infra::validate::{validate_not_empty_payload, validate_not_empty_str}; use prost::Message; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; + use serde_repr::{Deserialize_repr, Serialize_repr}; use std::cmp::Ordering; use std::collections::HashMap; @@ -718,301 +718,6 @@ impl From for AFWorkspaceInvitationStatus { } } -#[derive(Debug, Clone, Validate, Serialize, Deserialize)] -pub struct CreateChatParams { - #[validate(custom = "validate_not_empty_str")] - pub chat_id: String, - pub name: String, - pub rag_ids: Vec, -} - -#[derive(Debug, Clone, Validate, Serialize, Deserialize)] -pub struct UpdateChatParams { - #[validate(custom = "validate_not_empty_str")] - pub chat_id: String, - - #[validate(custom = "validate_not_empty_str")] - pub 
name: Option, - - pub metadata: Option, -} - -#[derive(Debug, Clone, Validate, Serialize, Deserialize)] -pub struct CreateChatMessageParams { - #[validate(custom = "validate_not_empty_str")] - pub content: String, - pub message_type: ChatMessageType, - - /// metadata is json array object - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option, -} - -/// [ChatMessageMetadata] is used when creating a new question message. -/// All the properties of [ChatMessageMetadata] except [ChatMetadataData] will be stored as a -/// metadata for specific [ChatMessage] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatMessageMetadata { - pub data: ChatMetadataData, - /// The id for the metadata. It can be a file_id, view_id - pub id: String, - /// The name for the metadata. For example, @xxx, @xx.txt - pub name: String, - pub source: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub extract: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatMetadataData { - /// The textual content of the metadata. This field can contain raw text data from a specific - /// document or any other text content that is indexable. This content is typically used for - /// search and indexing purposes within the chat context. - pub content: String, - - /// The type of content represented by this metadata. This could indicate the format or - /// nature of the content (e.g., text, markdown, PDF). The `content_type` helps in - /// processing or rendering the content appropriately. - pub content_type: ChatMetadataContentType, - - /// The size of the content in bytes. - pub size: i64, -} - -impl ChatMetadataData { - pub fn from_text(text: String) -> Self { - let size = text.len() as i64; - Self { - content: text, - content_type: ChatMetadataContentType::Text, - size, - } - } -} - -impl ChatMetadataData { - /// Validates the `ChatMetadataData` instance. 
- /// - /// This method checks the validity of the data based on the content type and the presence of content or URL. - /// - If `content` is empty, the method checks if `url` is provided. If `url` is also empty, the data is invalid. - /// - For `Text` and `Markdown`, it ensures that the content length matches the specified size if content is present. - /// - For `Unknown` and `PDF`, it currently returns `false` as these types are either unsupported or - /// require additional validation logic. - /// - /// Returns `true` if the data is valid according to its content type and the presence of content or URL, otherwise `false`. - pub fn validate(&self) -> Result<(), anyhow::Error> { - match self.content_type { - ChatMetadataContentType::Text | ChatMetadataContentType::Markdown => { - if self.content.len() != self.size as usize { - return Err(anyhow::anyhow!( - "Invalid content size: content size: {}, expected size: {}", - self.content.len(), - self.size - )); - } - }, - ChatMetadataContentType::PDF => { - if self.content.is_empty() { - return Err(anyhow::anyhow!("Invalid content: content is empty")); - } - }, - ChatMetadataContentType::Unknown => { - return Err(anyhow::anyhow!( - "Unsupported content type: {:?}", - self.content_type - )); - }, - } - Ok(()) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum ChatMetadataContentType { - Unknown, - Text, - Markdown, - PDF, -} - -impl Display for ChatMetadataContentType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ChatMetadataContentType::Unknown => write!(f, "unknown"), - ChatMetadataContentType::Text => write!(f, "text"), - ChatMetadataContentType::Markdown => write!(f, "markdown"), - ChatMetadataContentType::PDF => write!(f, "pdf"), - } - } -} - -impl ChatMetadataData { - pub fn new_text(content: String) -> Self { - let size = content.len(); - Self { - content, - content_type: ChatMetadataContentType::Text, - size: size as 
i64, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpdateChatMessageMetaParams { - pub message_id: i64, - pub meta_data: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpdateChatMessageContentParams { - pub chat_id: String, - pub message_id: i64, - pub content: String, - #[serde(default)] - pub model: AIModel, -} - -#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)] -#[repr(u8)] -pub enum ChatMessageType { - System = 0, - #[default] - User = 1, -} - -impl CreateChatMessageParams { - pub fn new_system(content: T) -> Self { - Self { - content: content.to_string(), - message_type: ChatMessageType::System, - metadata: None, - } - } - - pub fn new_user(content: T) -> Self { - Self { - content: content.to_string(), - message_type: ChatMessageType::User, - metadata: None, - } - } - - pub fn with_metadata(mut self, metadata: T) -> Self { - if let Ok(metadata) = serde_json::to_value(&metadata) { - if !matches!(metadata, Value::Array(_)) { - self.metadata = Some(json!([metadata])); - } else { - self.metadata = Some(metadata); - } - } - self - } -} -#[derive(Debug, Clone, Validate, Serialize, Deserialize)] -pub struct GetChatMessageParams { - pub cursor: MessageCursor, - pub limit: u64, -} - -impl GetChatMessageParams { - pub fn offset(offset: u64, limit: u64) -> Self { - Self { - cursor: MessageCursor::Offset(offset), - limit, - } - } - - pub fn after_message_id(after_message_id: i64, limit: u64) -> Self { - Self { - cursor: MessageCursor::AfterMessageId(after_message_id), - limit, - } - } - pub fn before_message_id(before_message_id: i64, limit: u64) -> Self { - Self { - cursor: MessageCursor::BeforeMessageId(before_message_id), - limit, - } - } - - pub fn next_back(limit: u64) -> Self { - Self { - cursor: MessageCursor::NextBack, - limit, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum MessageCursor { - Offset(u64), - AfterMessageId(i64), - BeforeMessageId(i64), - NextBack, -} - 
-#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatMessage { - pub author: ChatAuthor, - pub message_id: i64, - pub content: String, - pub created_at: DateTime, - pub meta_data: serde_json::Value, - pub reply_message_id: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct QAChatMessage { - pub question: ChatMessage, - pub answer: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RepeatedChatMessage { - pub messages: Vec, - pub has_more: bool, - pub total: i64, -} - -#[derive(Debug, Default, Clone, Serialize_repr, Deserialize_repr)] -#[repr(u8)] -pub enum ChatAuthorType { - Unknown = 0, - Human = 1, - #[default] - System = 2, - AI = 3, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatAuthor { - pub author_id: i64, - #[serde(default)] - pub author_type: ChatAuthorType, - #[serde(default)] - #[serde(skip_serializing_if = "Option::is_none")] - pub meta: Option, -} - -impl ChatAuthor { - pub fn new(author_id: i64, author_type: ChatAuthorType) -> Self { - Self { - author_id, - author_type, - meta: None, - } - } - - pub fn ai() -> Self { - Self { - author_id: 0, - author_type: ChatAuthorType::AI, - meta: None, - } - } -} - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct AFCollabEmbeddingParams { pub fragment_id: String, @@ -1117,22 +822,6 @@ impl EmbeddingContentType { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpdateChatMessageResponse { - pub answer: Option, -} - -#[derive(Debug, Clone, Validate, Serialize, Deserialize)] -pub struct CreateAnswerMessageParams { - #[validate(custom = "validate_not_empty_str")] - pub content: String, - - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option, - - pub question_message_id: i64, -} - #[derive(Serialize, Deserialize, Debug)] pub struct PublishCollabMetadata { pub view_id: uuid::Uuid, diff --git a/libs/database-entity/src/lib.rs b/libs/database-entity/src/lib.rs index 
1c4fa4ccd..24dbf7cdf 100644 --- a/libs/database-entity/src/lib.rs +++ b/libs/database-entity/src/lib.rs @@ -1,4 +1,3 @@ pub mod dto; pub mod error; pub mod file_dto; -mod util; diff --git a/libs/database/Cargo.toml b/libs/database/Cargo.toml index 89b37ec9b..96fbafe34 100644 --- a/libs/database/Cargo.toml +++ b/libs/database/Cargo.toml @@ -11,6 +11,7 @@ collab-entity = { workspace = true } collab-rt-entity = { workspace = true } validator = { version = "0.16", features = ["validator_derive", "derive"] } database-entity.workspace = true +shared-entity.workspace = true app-error = { workspace = true, features = ["sqlx_error", "validation_error"] } tokio = { workspace = true, features = ["sync"] } diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index 59e9a1dfa..e2b563b8c 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -3,15 +3,15 @@ use crate::workspace::is_workspace_exist; use anyhow::anyhow; use app_error::AppError; use chrono::{DateTime, Utc}; -use database_entity::dto::{ - ChatAuthor, ChatMessage, CreateChatParams, GetChatMessageParams, MessageCursor, - RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatMessageMetaParams, +use shared_entity::dto::chat_dto::{ + ChatAuthor, ChatMessage, ChatMessageMetadata, CreateChatParams, GetChatMessageParams, + MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams, UpdateChatMessageMetaParams, UpdateChatParams, }; use serde_json::json; use sqlx::postgres::PgArguments; -use sqlx::types::JsonValue; + use sqlx::{Arguments, Executor, PgPool, Postgres, Transaction}; use std::ops::DerefMut; use std::str::FromStr; @@ -283,9 +283,9 @@ pub async fn insert_question_message<'a, E: Executor<'a, Database = Postgres>>( author: ChatAuthor, chat_id: &str, content: String, - metadata: Option, + metadata: Vec, ) -> Result { - let metadata = metadata.unwrap_or_else(|| json!({})); + let metadata = json!(metadata); let chat_id = 
Uuid::from_str(chat_id)?; let row = sqlx::query!( r#" diff --git a/libs/database/src/index/search_ops.rs b/libs/database/src/index/search_ops.rs index e7ddba84b..589ec50d0 100644 --- a/libs/database/src/index/search_ops.rs +++ b/libs/database/src/index/search_ops.rs @@ -5,6 +5,13 @@ use pgvector::Vector; use sqlx::Transaction; use uuid::Uuid; +/// Logs each search request to track usage by workspace. It either inserts a new record or updates +/// an existing one with the current date, workspace ID, request count, and token usage. This ensures +/// accurate usage tracking for billing or monitoring. +/// +/// Searches and retrieves documents based on their similarity to a given search embedding. +/// It filters by workspace, user access, and document status, and returns a limited number +/// of the most relevant documents, sorted by similarity score. pub async fn search_documents( tx: &mut Transaction<'_, sqlx::Postgres>, params: SearchDocumentParams, diff --git a/libs/infra/Cargo.toml b/libs/infra/Cargo.toml index bc2d152e9..8c3a8dab9 100644 --- a/libs/infra/Cargo.toml +++ b/libs/infra/Cargo.toml @@ -15,6 +15,10 @@ bytes = { workspace = true } tokio = { workspace = true, optional = true } pin-project.workspace = true futures = "0.3.30" +validator = { version = "0.16", features = [ + "validator_derive", + "derive", +] } [features] file_util = ["tokio/fs"] diff --git a/libs/infra/src/lib.rs b/libs/infra/src/lib.rs index 16119205a..19d24b423 100644 --- a/libs/infra/src/lib.rs +++ b/libs/infra/src/lib.rs @@ -4,3 +4,4 @@ pub mod env_util; pub mod file_util; #[cfg(feature = "request_util")] pub mod reqwest; +pub mod validate; diff --git a/libs/database-entity/src/util.rs b/libs/infra/src/validate.rs similarity index 59% rename from libs/database-entity/src/util.rs rename to libs/infra/src/validate.rs index 31604ce32..a80c248c0 100644 --- a/libs/database-entity/src/util.rs +++ b/libs/infra/src/validate.rs @@ -1,13 +1,13 @@ use validator::ValidationError; -pub(crate) fn 
validate_not_empty_str(s: &str) -> Result<(), ValidationError> { +pub fn validate_not_empty_str(s: &str) -> Result<(), ValidationError> { if s.is_empty() { return Err(ValidationError::new("should not be empty string")); } Ok(()) } -pub(crate) fn validate_not_empty_payload(payload: &[u8]) -> Result<(), ValidationError> { +pub fn validate_not_empty_payload(payload: &[u8]) -> Result<(), ValidationError> { if payload.is_empty() { return Err(ValidationError::new("should not be empty payload")); } diff --git a/libs/shared-entity/Cargo.toml b/libs/shared-entity/Cargo.toml index 7682077f1..cd5acd7d8 100644 --- a/libs/shared-entity/Cargo.toml +++ b/libs/shared-entity/Cargo.toml @@ -17,7 +17,7 @@ reqwest = { workspace = true, features = ["stream"] } uuid = { version = "1.6.1", features = ["v4"] } gotrue-entity = { path = "../gotrue-entity" } database-entity.workspace = true -infra.workspace = true +infra = { workspace = true, features = ["request_util"] } collab-entity = { workspace = true } app-error = { workspace = true } chrono = "0.4.31" @@ -32,7 +32,7 @@ actix-web = { version = "4.4.1", default-features = false, features = [ validator = { version = "0.16", features = [ "validator_derive", "derive", -], optional = true } +] } futures = "0.3.30" bytes = "1.6.0" log = "0.4.21" @@ -40,4 +40,4 @@ tracing = { workspace = true } [features] -cloud = ["actix-web", "validator"] +cloud = ["actix-web"] diff --git a/libs/shared-entity/src/dto/ai_dto.rs b/libs/shared-entity/src/dto/ai_dto.rs index dcb6253c0..f198894ed 100644 --- a/libs/shared-entity/src/dto/ai_dto.rs +++ b/libs/shared-entity/src/dto/ai_dto.rs @@ -1,8 +1,8 @@ use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; +use crate::dto::chat_dto::ChatMessage; pub use appflowy_ai_client::dto::*; -use database_entity::dto::ChatMessage; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SummarizeRowParams { diff --git a/libs/shared-entity/src/dto/chat_dto.rs b/libs/shared-entity/src/dto/chat_dto.rs new file 
mode 100644 index 000000000..aeb973d94 --- /dev/null +++ b/libs/shared-entity/src/dto/chat_dto.rs @@ -0,0 +1,363 @@ +use appflowy_ai_client::dto::AIModel; +use chrono::{DateTime, Utc}; +use infra::validate::validate_not_empty_str; +use serde::{Deserialize, Deserializer, Serialize}; + +use serde_repr::{Deserialize_repr, Serialize_repr}; +use std::collections::HashMap; +use std::fmt::Display; +use validator::Validate; + +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +pub struct CreateChatParams { + #[validate(custom = "validate_not_empty_str")] + pub chat_id: String, + pub name: String, + pub rag_ids: Vec, +} + +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +pub struct UpdateChatParams { + #[validate(custom = "validate_not_empty_str")] + pub chat_id: String, + + #[validate(custom = "validate_not_empty_str")] + pub name: Option, + + pub metadata: Option, +} + +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +pub struct CreateChatMessageParams { + #[validate(custom = "validate_not_empty_str")] + pub content: String, + pub message_type: ChatMessageType, + #[serde(deserialize_with = "deserialize_metadata")] + #[serde(default)] + #[serde(skip_serializing_if = "Vec::is_empty")] + pub metadata: Vec, +} + +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +pub struct CreateChatMessageParamsV2 { + #[validate(custom = "validate_not_empty_str")] + pub content: String, + pub message_type: ChatMessageType, + #[serde(deserialize_with = "deserialize_metadata")] + #[serde(default)] + #[serde(skip_serializing_if = "Vec::is_empty")] + pub metadata: Vec, +} + +fn deserialize_metadata<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let raw_value = Option::::deserialize(deserializer)?; + match raw_value { + Some(serde_json::Value::Array(arr)) => { + serde_json::from_value(serde_json::Value::Array(arr)).map_err(serde::de::Error::custom) + }, + Some(_) => Err(serde::de::Error::custom( + "Expected metadata to be an 
array of ChatMessageMetadata.", + )), + None => Ok(vec![]), + } +} + +/// [ChatMessageMetadata] is used when creating a new question message. +/// All the properties of [ChatMessageMetadata] except [ChatRAGData] will be stored as a +/// metadata for specific [ChatMessage] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessageMetadata { + pub data: ChatRAGData, + /// The id for the metadata. It can be a file_id, view_id + pub id: String, + /// The name for the metadata. For example, @xxx, @xx.txt + pub name: String, + pub source: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub extra: Option, +} + +impl ChatMessageMetadata { + pub fn split_data(self) -> (ChatRAGData, ChatMetadataDescription) { + ( + self.data, + ChatMetadataDescription { + id: self.id, + name: self.name, + source: self.source, + extra: self.extra, + }, + ) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMetadataDescription { + pub id: String, + pub name: String, + pub source: String, + pub extra: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatRAGData { + /// The textual content of the metadata. This field can contain raw text data from a specific + /// document or any other text content that is indexable. This content is typically used for + /// search and indexing purposes within the chat context. + pub content: String, + + /// The type of content represented by this metadata. This could indicate the format or + /// nature of the content (e.g., text, markdown, PDF). The `content_type` helps in + /// processing or rendering the content appropriately. + pub content_type: ContextLoader, + + /// The size of the content in bytes. + pub size: i64, +} + +impl ChatRAGData { + pub fn from_text(text: String) -> Self { + let size = text.len() as i64; + Self { + content: text, + content_type: ContextLoader::Text, + size, + } + } +} + +impl ChatRAGData { + /// Validates the `ChatRAGData` instance. 
+ /// + /// This method checks the validity of the data based on its `content_type`. + /// - For `Text` and `Markdown`, it ensures that the byte length of `content` equals `size`. + /// - For `PDF`, it ensures that `content` is not empty. + /// - For `Unknown`, it always returns an error because the content type is + /// unsupported. + /// + /// Returns `Ok(())` if the data is valid for its content type, otherwise an `anyhow::Error` describing the failure. + pub fn validate(&self) -> Result<(), anyhow::Error> { + match self.content_type { + ContextLoader::Text | ContextLoader::Markdown => { + if self.content.len() != self.size as usize { + return Err(anyhow::anyhow!( + "Invalid content size: content size: {}, expected size: {}", + self.content.len(), + self.size + )); + } + }, + ContextLoader::PDF => { + if self.content.is_empty() { + return Err(anyhow::anyhow!("Invalid content: content is empty")); + } + }, + ContextLoader::Unknown => { + return Err(anyhow::anyhow!( + "Unsupported content type: {:?}", + self.content_type + )); + }, + } + Ok(()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ContextLoader { + Unknown, + Text, + Markdown, + PDF, +} + +impl Display for ContextLoader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ContextLoader::Unknown => write!(f, "unknown"), + ContextLoader::Text => write!(f, "text"), + ContextLoader::Markdown => write!(f, "markdown"), + ContextLoader::PDF => write!(f, "pdf"), + } + } +} + +impl ChatRAGData { + pub fn new_text(content: String) -> Self { + let size = content.len(); + Self { + content, + content_type: ContextLoader::Text, + size: size as i64, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateChatMessageMetaParams { + pub 
message_id: i64, + pub meta_data: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateChatMessageContentParams { + pub chat_id: String, + pub message_id: i64, + pub content: String, + #[serde(default)] + pub model: AIModel, +} + +#[derive(Debug, Clone, Default, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum ChatMessageType { + System = 0, + #[default] + User = 1, +} + +impl CreateChatMessageParams { + pub fn new_system(content: T) -> Self { + Self { + content: content.to_string(), + message_type: ChatMessageType::System, + metadata: vec![], + } + } + + pub fn new_user(content: T) -> Self { + Self { + content: content.to_string(), + message_type: ChatMessageType::User, + metadata: vec![], + } + } + + pub fn with_metadata(mut self, metadata: ChatMessageMetadata) -> Self { + self.metadata.push(metadata); + self + } +} +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +pub struct GetChatMessageParams { + pub cursor: MessageCursor, + pub limit: u64, +} + +impl GetChatMessageParams { + pub fn offset(offset: u64, limit: u64) -> Self { + Self { + cursor: MessageCursor::Offset(offset), + limit, + } + } + + pub fn after_message_id(after_message_id: i64, limit: u64) -> Self { + Self { + cursor: MessageCursor::AfterMessageId(after_message_id), + limit, + } + } + pub fn before_message_id(before_message_id: i64, limit: u64) -> Self { + Self { + cursor: MessageCursor::BeforeMessageId(before_message_id), + limit, + } + } + + pub fn next_back(limit: u64) -> Self { + Self { + cursor: MessageCursor::NextBack, + limit, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum MessageCursor { + Offset(u64), + AfterMessageId(i64), + BeforeMessageId(i64), + NextBack, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub author: ChatAuthor, + pub message_id: i64, + pub content: String, + pub created_at: DateTime, + pub meta_data: serde_json::Value, + pub reply_message_id: Option, +} + 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QAChatMessage { + pub question: ChatMessage, + pub answer: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RepeatedChatMessage { + pub messages: Vec, + pub has_more: bool, + pub total: i64, +} + +#[derive(Debug, Default, Clone, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum ChatAuthorType { + Unknown = 0, + Human = 1, + #[default] + System = 2, + AI = 3, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatAuthor { + pub author_id: i64, + #[serde(default)] + pub author_type: ChatAuthorType, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +impl ChatAuthor { + pub fn new(author_id: i64, author_type: ChatAuthorType) -> Self { + Self { + author_id, + author_type, + meta: None, + } + } + + pub fn ai() -> Self { + Self { + author_id: 0, + author_type: ChatAuthorType::AI, + meta: None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateChatMessageResponse { + pub answer: Option, +} + +#[derive(Debug, Clone, Validate, Serialize, Deserialize)] +pub struct CreateAnswerMessageParams { + #[validate(custom = "validate_not_empty_str")] + pub content: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + + pub question_message_id: i64, +} diff --git a/libs/shared-entity/src/dto/mod.rs b/libs/shared-entity/src/dto/mod.rs index 0f21bd785..363b208ef 100644 --- a/libs/shared-entity/src/dto/mod.rs +++ b/libs/shared-entity/src/dto/mod.rs @@ -2,6 +2,7 @@ pub mod access_request_dto; pub mod ai_dto; pub mod auth_dto; pub mod billing_dto; +pub mod chat_dto; pub mod history_dto; pub mod import_dto; pub mod publish_dto; diff --git a/script/client_api_deps_check.sh b/script/client_api_deps_check.sh index af717b3c9..22b94b3d4 100755 --- a/script/client_api_deps_check.sh +++ b/script/client_api_deps_check.sh @@ -3,7 +3,7 @@ # Generate the current dependency list cargo 
tree > current_deps.txt -BASELINE_COUNT=634 +BASELINE_COUNT=637 CURRENT_COUNT=$(cat current_deps.txt | wc -l) echo "Expected dependency count (baseline): $BASELINE_COUNT" diff --git a/services/appflowy-worker/Cargo.toml b/services/appflowy-worker/Cargo.toml index f5d578a63..a83eaedce 100644 --- a/services/appflowy-worker/Cargo.toml +++ b/services/appflowy-worker/Cargo.toml @@ -32,7 +32,7 @@ thiserror = "1.0.58" tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } serde_repr = "0.1.18" futures = "0.3.30" -infra.workspace = true +infra = { workspace = true, features = ["request_util"] } sqlx = { workspace = true, default-features = false, features = ["runtime-tokio-rustls", "macros", "postgres", "uuid", "chrono", "migrate"] } secrecy = { version = "0.8", features = ["serde"] } aws-sdk-s3 = { version = "1.36.0", features = [ diff --git a/src/api/chat.rs b/src/api/chat.rs index d69210957..62da371e1 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -1,6 +1,5 @@ use crate::biz::chat::ops::{ - create_chat, create_chat_message, create_chat_message_stream, delete_chat, - extract_chat_message_metadata, generate_chat_message_answer, get_chat_messages, + create_chat, create_chat_message, delete_chat, generate_chat_message_answer, get_chat_messages, update_chat_message, }; use crate::state::AppState; @@ -8,21 +7,21 @@ use actix_web::web::{Data, Json}; use actix_web::{web, HttpRequest, HttpResponse, Scope}; use app_error::AppError; -use appflowy_ai_client::dto::{ChatContextLoader, CreateTextChatContext, RepeatedRelatedQuestion}; +use appflowy_ai_client::dto::{CreateChatContext, RepeatedRelatedQuestion}; use authentication::jwt::UserUuid; use bytes::Bytes; -use database_entity::dto::{ - ChatAuthor, ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, CreateChatParams, - GetChatMessageParams, MessageCursor, RepeatedChatMessage, UpdateChatMessageContentParams, -}; use futures::Stream; use futures_util::stream; use futures_util::{FutureExt, 
TryStreamExt}; use pin_project::pin_project; +use shared_entity::dto::chat_dto::{ + ChatAuthor, ChatMessage, CreateAnswerMessageParams, CreateChatMessageParams, + CreateChatMessageParamsV2, CreateChatParams, GetChatMessageParams, MessageCursor, + RepeatedChatMessage, UpdateChatMessageContentParams, +}; use shared_entity::response::{AppResponse, JsonAppResponse}; use std::collections::HashMap; use std::pin::Pin; -use std::str::FromStr; use std::task::{Context, Poll}; use tokio::sync::oneshot; use tokio::task; @@ -32,55 +31,62 @@ use database::chat; use crate::api::util::ai_model_from_header; use database::chat::chat_ops::insert_answer_message; -use tracing::{instrument, trace, warn}; +use tracing::{error, instrument, trace}; use validator::Validate; - pub fn chat_scope() -> Scope { web::scope("/api/chat/{workspace_id}") - .service(web::resource("").route(web::post().to(create_chat_handler))) - .service( - web::resource("/{chat_id}") - .route(web::delete().to(delete_chat_handler)) - .route(web::get().to(get_chat_message_handler)), - ) - .service( - web::resource("/{chat_id}/{message_id}/related_question") - .route(web::get().to(get_related_message_handler)), - ) - .service( - web::resource("/{chat_id}/message") - // create_chat_message_handler is deprecated. No long used after frontend application v0.6.2 - .route(web::post().to(create_chat_message_handler)) - .route(web::put().to(update_chat_message_handler)), - ) - .service( - // Creating a [ChatMessage] for given content. - // When client asks a question, it will use this API to create a chat message - web::resource("/{chat_id}/message/question").route(web::post().to(create_question_handler)), - ) - // Writing the final answer for a given chat. - // After the streaming is finished, the client will use this API to save the message to disk. - .service(web::resource("/{chat_id}/message/answer").route(web::post().to(save_answer_handler))) - .service( - // Use AI to generate a response for a specified message ID. 
- // To generate an answer for a given question, use "/answer/stream" to receive the answer in a stream. - web::resource("/{chat_id}/{message_id}/answer").route(web::get().to(answer_handler)), - ) - // Deprecated! use "v2/answer/stream" - // Use AI to generate a response for a specified message ID. This response will be return as a stream. - .service( - web::resource("/{chat_id}/{message_id}/answer/stream") - .route(web::get().to(answer_stream_handler)), - ) + // Chat management + .service( + web::resource("") + .route(web::post().to(create_chat_handler)) + ) + .service( + web::resource("/{chat_id}") + .route(web::delete().to(delete_chat_handler)) + .route(web::get().to(get_chat_message_handler)) + ) + + // Message management + .service( + web::resource("/{chat_id}/message") + .route(web::put().to(update_question_handler)) + ) + .service( + web::resource("/{chat_id}/message/question") + .route(web::post().to(create_question_handler)) + ) + .service( + web::resource("/{chat_id}/v2/message/question") + .route(web::post().to(create_question_handler_v2)) + ) + .service( + web::resource("/{chat_id}/message/answer") + .route(web::post().to(save_answer_handler)) + ) + + // AI response generation + .service( + web::resource("/{chat_id}/{message_id}/answer") + .route(web::get().to(answer_handler)) + ) + .service( + web::resource("/{chat_id}/{message_id}/answer/stream") + .route(web::get().to(answer_stream_handler)) // Deprecated + ) .service( web::resource("/{chat_id}/{message_id}/v2/answer/stream") - .route(web::get().to(answer_stream_v2_handler)), + .route(web::get().to(answer_stream_v2_handler)) + ) + + // Additional functionality + .service( + web::resource("/{chat_id}/{message_id}/related_question") + .route(web::get().to(get_related_message_handler)) + ) + .service( + web::resource("/{chat_id}/context/text") + .route(web::post().to(create_chat_context_handler)) ) - .service( - // Create chat context for a given chat. 
- web::resource("/{chat_id}/context/text") - .route(web::post().to(create_chat_context_handler)) - ) } async fn create_chat_handler( path: web::Path, @@ -103,44 +109,10 @@ async fn delete_chat_handler( Ok(AppResponse::Ok().into()) } -#[instrument(level = "info", skip_all, err)] -async fn create_chat_message_handler( - state: Data, - path: web::Path<(String, String)>, - payload: Json, - uuid: UserUuid, - req: HttpRequest, -) -> actix_web::Result { - let (_workspace_id, chat_id) = path.into_inner(); - let params = payload.into_inner(); - - if let Err(err) = params.validate() { - return Ok(HttpResponse::from_error(AppError::from(err))); - } - - let ai_model = ai_model_from_header(&req); - let uid = state.user_cache.get_user_uid(&uuid).await?; - let message_stream = create_chat_message_stream( - &state.pg_pool, - uid, - chat_id, - params, - state.ai_client.clone(), - ai_model, - ) - .await; - - Ok( - HttpResponse::Ok() - .content_type("application/json") - .streaming(message_stream), - ) -} - #[instrument(level = "debug", skip_all, err)] async fn create_chat_context_handler( state: Data, - payload: Json, + payload: Json, ) -> actix_web::Result> { let params = payload.into_inner(); state @@ -151,7 +123,7 @@ async fn create_chat_context_handler( Ok(AppResponse::Ok().into()) } -async fn update_chat_message_handler( +async fn update_question_handler( state: Data, payload: Json, req: HttpRequest, @@ -185,27 +157,26 @@ async fn create_question_handler( uuid: UserUuid, ) -> actix_web::Result> { let (_workspace_id, chat_id) = path.into_inner(); - let mut params = payload.into_inner(); + let params = payload.into_inner(); // When create a question, we will extract the metadata from the question content. // metadata might include user mention file,page,or user. For example, @Get started. 
- for extract_context in extract_chat_message_metadata(&mut params) { - match ChatContextLoader::from_str(&extract_context.content_type) { - Ok(context_loader) => { - let context = - CreateTextChatContext::new(chat_id.clone(), context_loader, extract_context.content) - .with_metadata(extract_context.metadata); - trace!("create context for question: {}", context); - state - .ai_client - .create_chat_text_context(context) - .await - .map_err(AppError::from)?; - }, - Err(err) => { - warn!("Failed to parse chat context loader: {}", err); - }, + for metadata in params.metadata.clone() { + let (data, desc) = metadata.split_data(); + if let Err(err) = data.validate() { + error!("Failed to validate metadata: {}", err); + continue; } + + let context = + CreateChatContext::new(chat_id.clone(), data.content_type.to_string(), data.content) + .with_metadata(desc); + trace!("create context for question: {}", context); + state + .ai_client + .create_chat_text_context(context) + .await + .map_err(AppError::from)?; } let uid = state.user_cache.get_user_uid(&uuid).await?; @@ -213,6 +184,16 @@ async fn create_question_handler( Ok(AppResponse::Ok().with_data(resp).into()) } +#[instrument(level = "debug", skip_all, err)] +async fn create_question_handler_v2( + _state: Data, + _path: web::Path<(String, String)>, + _payload: Json, + _uuid: UserUuid, +) -> actix_web::Result> { + todo!() +} + async fn save_answer_handler( path: web::Path<(String, String)>, payload: Json, diff --git a/src/biz/chat/ops.rs b/src/biz/chat/ops.rs index f352ecb6e..484d5df7a 100644 --- a/src/biz/chat/ops.rs +++ b/src/biz/chat/ops.rs @@ -1,6 +1,5 @@ use actix_web::web::Bytes; use anyhow::anyhow; -use std::collections::HashMap; use app_error::AppError; use appflowy_ai_client::client::AppFlowyAIClient; @@ -11,15 +10,14 @@ use database::chat::chat_ops::{ insert_answer_message_with_transaction, insert_chat, insert_question_message, select_chat_messages, }; -use database_entity::dto::{ - ChatAuthor, ChatAuthorType, 
ChatMessage, ChatMessageType, ChatMetadataData, - CreateChatMessageParams, CreateChatParams, GetChatMessageParams, RepeatedChatMessage, - UpdateChatMessageContentParams, -}; use futures::stream::Stream; -use serde_json::Value; +use serde_json::json; +use shared_entity::dto::chat_dto::{ + ChatAuthor, ChatAuthorType, ChatMessage, ChatMessageType, CreateChatMessageParams, + CreateChatParams, GetChatMessageParams, RepeatedChatMessage, UpdateChatMessageContentParams, +}; use sqlx::PgPool; -use tracing::{error, info, trace}; +use tracing::{error, info}; use appflowy_ai_client::dto::AIModel; use validator::Validate; @@ -118,7 +116,6 @@ pub async fn create_chat_message( chat_id: String, params: CreateChatMessageParams, ) -> Result { - let params = params.clone(); let chat_id = chat_id.clone(); let pg_pool = pg_pool.clone(); @@ -126,74 +123,13 @@ pub async fn create_chat_message( &pg_pool, ChatAuthor::new(uid, ChatAuthorType::Human), &chat_id, - params.content.clone(), + params.content, params.metadata, ) .await?; Ok(question) } -/// Extracts the chat context from the metadata. Currently, we only support text as a context. In -/// the future, we will support other types of context. -pub(crate) struct ExtractChatMetadata { - pub(crate) content: String, - pub(crate) content_type: String, - pub(crate) metadata: HashMap, -} - -/// Removes the "content" field from the metadata if the "ty" field is equal to "text". 
-/// The metadata struct is shown below: -/// { -/// "data": { -/// "content": "hello world" -/// "size": 122, -/// "content_type": "text", -/// }, -/// "id": "id", -/// "name": "name" -/// } -/// -/// the root json is point to the struct [database_entity::dto::ChatMessageMetadata] -fn extract_message_metadata( - message_metadata: &mut serde_json::Value, -) -> Option { - trace!("Extracting metadata: {:?}", message_metadata); - - if let Value::Object(message_metadata) = message_metadata { - // remove the "data" field - if let Some(data) = message_metadata - .remove("data") - .and_then(|value| serde_json::from_value::(value.clone()).ok()) - { - if data.validate().is_ok() { - return Some(ExtractChatMetadata { - content: data.content, - content_type: data.content_type.to_string(), - metadata: message_metadata.clone().into_iter().collect(), - }); - } - } - } - - None -} - -pub(crate) fn extract_chat_message_metadata( - params: &mut CreateChatMessageParams, -) -> Vec { - let mut extract_metadatas = vec![]; - trace!("chat metadata: {:?}", params.metadata); - if let Some(Value::Array(ref mut list)) = params.metadata { - for metadata in list { - if let Some(extract_context) = extract_message_metadata(metadata) { - extract_metadatas.push(extract_context); - } - } - } - - extract_metadatas -} - pub async fn create_chat_message_stream( pg_pool: &PgPool, uid: i64, @@ -238,7 +174,7 @@ pub async fn create_chat_message_stream( match params.message_type { ChatMessageType::System => {} ChatMessageType::User => { - let answer = match ai_client.send_question(&chat_id, ¶ms.content, &ai_model, params.metadata).await { + let answer = match ai_client.send_question(&chat_id, ¶ms.content, &ai_model, Some(json!(params.metadata))).await { Ok(response) => response, Err(err) => { error!("Failed to send question to AI: {}", err); diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index d428ebd82..eaab06e2f 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs 
@@ -1,13 +1,13 @@ use crate::ai_test::util::read_text_from_asset; -use assert_json_diff::assert_json_eq; +use assert_json_diff::{assert_json_eq, assert_json_include}; use client_api::entity::{QuestionStream, QuestionStreamValue}; use client_api_test::{local_ai_test_enabled, TestClient}; -use database_entity::dto::{ - ChatMessageMetadata, ChatMetadataData, CreateChatMessageParams, CreateChatParams, MessageCursor, -}; use futures_util::StreamExt; use serde_json::json; +use shared_entity::dto::chat_dto::{ + ChatMessageMetadata, ChatRAGData, CreateChatMessageParams, CreateChatParams, MessageCursor, +}; #[tokio::test] async fn create_chat_and_create_messages_test() { @@ -120,28 +120,27 @@ async fn chat_qa_test() { let content = read_text_from_asset("my_profile.txt"); let metadata = ChatMessageMetadata { - data: ChatMetadataData::new_text(content), + data: ChatRAGData::new_text(content), id: "123".to_string(), name: "test context".to_string(), source: "user added".to_string(), - extract: Some(json!({"created_at": 123})), + extra: Some(json!({"created_at": 123})), }; - let params = - CreateChatMessageParams::new_user("Where lucas live?").with_metadata(json!(vec![metadata])); + let params = CreateChatMessageParams::new_user("Where lucas live?").with_metadata(metadata); let question = test_client .api_client .create_question(&workspace_id, &chat_id, params) .await .unwrap(); - assert_json_eq!( - question.meta_data, - json!([ + assert_json_include!( + actual: question.meta_data, + expected: json!([ { "id": "123", "name": "test context", "source": "user added", - "extract": { + "extra": { "created_at": 123 } } @@ -230,11 +229,11 @@ async fn create_chat_context_test() { let content = "Lacus have lived in the US for five years".to_string(); let metadata = ChatMessageMetadata { - data: ChatMetadataData::from_text(content), + data: ChatRAGData::from_text(content), id: chat_id.clone(), name: "".to_string(), source: "appflowy".to_string(), - extract: None, + extra: None, }; let 
params = CreateChatMessageParams::new_user("Where Lacus live?").with_metadata(metadata); diff --git a/tests/sql_test/chat_test.rs b/tests/sql_test/chat_test.rs index 32ec462ca..476809481 100644 --- a/tests/sql_test/chat_test.rs +++ b/tests/sql_test/chat_test.rs @@ -3,8 +3,10 @@ use database::chat::chat_ops::{ delete_chat, get_all_chat_messages, insert_chat, insert_question_message, select_chat, select_chat_messages, }; -use database_entity::dto::{ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams}; use serde_json::json; +use shared_entity::dto::chat_dto::{ + ChatAuthor, ChatAuthorType, CreateChatParams, GetChatMessageParams, +}; use sqlx::PgPool; #[sqlx::test(migrations = false)] @@ -96,7 +98,7 @@ async fn chat_message_crud_test(pool: PgPool) { ChatAuthor::new(0, ChatAuthorType::System), &chat_id, format!("message {}", i), - None, + vec![], ) .await .unwrap();