From 787e6beb0d880f6e1723abee2889b0d1c34d2c00 Mon Sep 17 00:00:00 2001 From: Zack Fu Zi Xiang Date: Tue, 12 Nov 2024 23:48:41 +0800 Subject: [PATCH 01/20] feat: add name filtering --- libs/client-api/src/http_collab.rs | 4 +- libs/shared-entity/src/dto/workspace_dto.rs | 13 +++- src/api/workspace.rs | 3 + src/biz/collab/ops.rs | 52 ++++++++------ tests/workspace/workspace_crud.rs | 80 ++++++++++++++------- 5 files changed, 103 insertions(+), 49 deletions(-) diff --git a/libs/client-api/src/http_collab.rs b/libs/client-api/src/http_collab.rs index bafa05405..cf4a240ea 100644 --- a/libs/client-api/src/http_collab.rs +++ b/libs/client-api/src/http_collab.rs @@ -1,7 +1,7 @@ use crate::http::log_request_id; use crate::{blocking_brotli_compress, Client}; use app_error::AppError; -use client_api_entity::workspace_dto::AFDatabase; +use client_api_entity::workspace_dto::{AFDatabase, ListDatabaseParam}; use client_api_entity::{ BatchQueryCollabParams, BatchQueryCollabResult, CreateCollabParams, DeleteCollabParams, QueryCollab, UpdateCollabWebParams, @@ -145,11 +145,13 @@ impl Client { pub async fn list_databases( &self, workspace_id: &str, + name_filter: Option, ) -> Result, AppResponseError> { let url = format!("{}/api/workspace/{}/database", self.base_url, workspace_id); let resp = self .http_client_with_auth(Method::GET, &url) .await? 
+ .query(&ListDatabaseParam { name_filter }) .send() .await?; log_request_id(&resp); diff --git a/libs/shared-entity/src/dto/workspace_dto.rs b/libs/shared-entity/src/dto/workspace_dto.rs index ec6b4ea28..4741a1ef9 100644 --- a/libs/shared-entity/src/dto/workspace_dto.rs +++ b/libs/shared-entity/src/dto/workspace_dto.rs @@ -263,6 +263,11 @@ pub struct QueryWorkspaceParam { pub include_member_count: Option, } +#[derive(Default, Debug, Deserialize, Serialize)] +pub struct ListDatabaseParam { + pub name_filter: Option, // logic: if database name contains +} + #[derive(Default, Debug, Deserialize, Serialize)] pub struct QueryWorkspaceFolder { pub depth: Option, @@ -284,7 +289,7 @@ pub struct PublishedView { #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct AFDatabase { pub id: String, - pub name: String, + pub names: Vec, pub fields: Vec, } @@ -293,3 +298,9 @@ pub struct AFDatabaseField { pub name: String, pub field_type: String, } + +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct AFDatabaseMeta { + pub name: String, + pub icon: String, +} diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 39114eed7..58f4837b8 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -1801,7 +1801,9 @@ async fn list_database_handler( user_uuid: UserUuid, workspace_id: web::Path, state: Data, + query: web::Query, ) -> Result>>> { + let name_filter = query.into_inner().name_filter; let uid = state.user_cache.get_user_uid(&user_uuid).await?; let workspace_id = workspace_id.into_inner(); let dbs = biz::collab::ops::list_database( @@ -1809,6 +1811,7 @@ async fn list_database_handler( &state.collab_access_control_storage, uid, workspace_id, + name_filter, ) .await?; Ok(Json(AppResponse::Ok().with_data(dbs))) diff --git a/src/biz/collab/ops.rs b/src/biz/collab/ops.rs index 02e1707e5..adc929bd7 100644 --- a/src/biz/collab/ops.rs +++ b/src/biz/collab/ops.rs @@ -363,6 +363,7 @@ pub async fn list_database( collab_storage: 
&CollabAccessControlStorage, uid: i64, workspace_uuid_str: String, + name_filter: Option, ) -> Result, AppError> { let workspace_uuid: Uuid = workspace_uuid_str.as_str().parse()?; let ws_db_oid = select_workspace_database_oid(pg_pool, &workspace_uuid).await?; @@ -416,28 +417,35 @@ pub async fn list_database( Arc::new(NoPersistenceDatabaseCollabService), None, ) { - Some(db_body) => match db_body.metas.get_inline_view_id(&txn) { - Some(iid) => match db_body.views.get_view(&txn, &iid) { - Some(iview) => { - let name = iview.name; - - let db_fields = db_body.fields.get_all_fields(&txn); - let mut af_fields: Vec = Vec::with_capacity(db_fields.len()); - for db_field in db_fields { - af_fields.push(AFDatabaseField { - name: db_field.name, - field_type: format!("{:?}", FieldType::from(db_field.field_type)), - }); - } - af_databases.push(AFDatabase { - id: db_body.get_database_id(&txn), - name, - fields: af_fields, - }); - }, - None => tracing::warn!("Failed to get inline view: {}", iid), - }, - None => tracing::error!("Failed to get inline view id for database: {}", oid), + Some(db_body) => { + let db_views = db_body.views.get_all_views_meta(&txn); + let names = db_views + .iter() + .map(|v| v.name.clone()) + .filter(|name| !name.is_empty()) + .collect::>(); + + // if there exists a name filter, + // there must be at least one view name that contains the filter + if let Some(name_filter) = &name_filter { + if !names.iter().any(|name| name.contains(name_filter)) { + continue; + } + } + + let db_fields = db_body.fields.get_all_fields(&txn); + let mut af_fields: Vec = Vec::with_capacity(db_fields.len()); + for db_field in db_fields { + af_fields.push(AFDatabaseField { + name: db_field.name, + field_type: format!("{:?}", FieldType::from(db_field.field_type)), + }); + } + af_databases.push(AFDatabase { + id: db_body.get_database_id(&txn), + names, + fields: af_fields, + }); }, None => tracing::error!("Failed to create db_body from db_collab, oid: {}", oid), }, diff --git 
a/tests/workspace/workspace_crud.rs b/tests/workspace/workspace_crud.rs index 76e01385e..6e1835dde 100644 --- a/tests/workspace/workspace_crud.rs +++ b/tests/workspace/workspace_crud.rs @@ -8,33 +8,63 @@ use shared_entity::dto::workspace_dto::PatchWorkspaceParam; #[tokio::test] async fn workspace_list_database() { let (c, _user) = generate_unique_registered_user_client().await; - let workspace_id = c.get_workspaces().await.unwrap()[0].workspace_id; - let dbs = c.list_databases(&workspace_id.to_string()).await.unwrap(); - assert_eq!(dbs.len(), 1); + let workspace_id = c.get_workspaces().await.unwrap()[0] + .workspace_id + .to_string(); - let db = &dbs[0]; + { + let dbs = c.list_databases(&workspace_id, None).await.unwrap(); + assert_eq!(dbs.len(), 1); + + let db = &dbs[0]; + + assert_eq!(db.names.len(), 2); + assert!(db.names.contains(&String::from("Untitled"))); + assert!(db.names.contains(&String::from("Grid"))); - assert_eq!(db.name, ""); - assert!(db.fields.contains(&AFDatabaseField { - name: "Last modified".to_string(), - field_type: "LastEditedTime".to_string(), - })); - assert!(db.fields.contains(&AFDatabaseField { - name: "Multiselect".to_string(), - field_type: "MultiSelect".to_string(), - })); - assert!(db.fields.contains(&AFDatabaseField { - name: "Tasks".to_string(), - field_type: "Checklist".to_string(), - })); - assert!(db.fields.contains(&AFDatabaseField { - name: "Status".to_string(), - field_type: "SingleSelect".to_string(), - })); - assert!(db.fields.contains(&AFDatabaseField { - name: "Description".to_string(), - field_type: "RichText".to_string(), - })); + assert!(db.fields.contains(&AFDatabaseField { + name: "Last modified".to_string(), + field_type: "LastEditedTime".to_string(), + })); + assert!(db.fields.contains(&AFDatabaseField { + name: "Multiselect".to_string(), + field_type: "MultiSelect".to_string(), + })); + assert!(db.fields.contains(&AFDatabaseField { + name: "Tasks".to_string(), + field_type: "Checklist".to_string(), + })); + 
assert!(db.fields.contains(&AFDatabaseField { + name: "Status".to_string(), + field_type: "SingleSelect".to_string(), + })); + assert!(db.fields.contains(&AFDatabaseField { + name: "Description".to_string(), + field_type: "RichText".to_string(), + })); + } + + { + let dbs = c + .list_databases(&workspace_id, Some(String::from("nomatch"))) + .await + .unwrap(); + assert_eq!(dbs.len(), 0); + } + { + let dbs = c + .list_databases(&workspace_id, Some(String::from("ntitle"))) + .await + .unwrap(); + assert_eq!(dbs.len(), 1); + } + { + let dbs = c + .list_databases(&workspace_id, Some(String::from("rid"))) + .await + .unwrap(); + assert_eq!(dbs.len(), 1); + } } #[tokio::test] From 97f9ff3dd8ee95b932fcc9b769de9d33add77ab3 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Fri, 15 Nov 2024 13:38:08 +0800 Subject: [PATCH 02/20] fix: byte index 8000 is not a char boundary (#995) * chore: fix split text boundary error and add related tests * chore: reduce clone --- Cargo.lock | 1 + services/appflowy-collaborate/Cargo.toml | 1 + .../src/indexer/document_indexer.rs | 375 ++++++++++++++++-- 3 files changed, 341 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b9bd93825..2009d9658 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -752,6 +752,7 @@ dependencies = [ "tokio-util", "tracing", "tracing-subscriber", + "unicode-segmentation", "uuid", "validator", "workspace-template", diff --git a/services/appflowy-collaborate/Cargo.toml b/services/appflowy-collaborate/Cargo.toml index 2e3151a8f..070547f36 100644 --- a/services/appflowy-collaborate/Cargo.toml +++ b/services/appflowy-collaborate/Cargo.toml @@ -87,6 +87,7 @@ lazy_static = "1.4.0" itertools = "0.12.0" validator = "0.16.1" rayon.workspace = true +unicode-segmentation = "1.9.0" [dev-dependencies] rand = "0.8.5" diff --git a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/services/appflowy-collaborate/src/indexer/document_indexer.rs index 
7cd42abb3..6dd25c3a1 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ b/services/appflowy-collaborate/src/indexer/document_indexer.rs @@ -13,6 +13,7 @@ use collab_document::document::DocumentBody; use collab_document::error::DocumentError; use collab_entity::CollabType; use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, EmbeddingContentType}; +use unicode_segmentation::UnicodeSegmentation; use uuid::Uuid; use crate::indexer::{DocumentDataExt, Indexer}; @@ -45,42 +46,12 @@ impl Indexer for DocumentIndexer { match result { Ok(document_data) => { let content = document_data.to_plain_text(); - let mut result = Vec::with_capacity(1 + content.len() / Self::DOC_CONTENT_SPLIT); - - let mut slice = content.as_str(); - while slice.len() > Self::DOC_CONTENT_SPLIT { - // we should split document into multiple fragments - let (left, right) = slice.split_at(Self::DOC_CONTENT_SPLIT); - let param = AFCollabEmbeddingParams { - fragment_id: Uuid::new_v4().to_string(), - object_id: object_id.clone(), - collab_type: CollabType::Document, - content_type: EmbeddingContentType::PlainText, - content: left.to_string(), - embedding: None, - }; - result.push(param); - slice = right; - } - - let content = if slice.len() == content.len() { - content // we didn't slice the content - } else { - slice.to_string() - }; - if !content.is_empty() { - let param = AFCollabEmbeddingParams { - fragment_id: object_id.clone(), - object_id: object_id.clone(), - collab_type: CollabType::Document, - content_type: EmbeddingContentType::PlainText, - content, - embedding: None, - }; - result.push(param); - } - - Ok(result) + create_embedding_params( + object_id, + content, + CollabType::Document, + Self::DOC_CONTENT_SPLIT, + ) }, Err(err) => { if matches!(err, DocumentError::NoRequiredData) { @@ -141,3 +112,335 @@ impl Indexer for DocumentIndexer { })) } } +#[inline] +fn create_embedding_params( + object_id: String, + content: String, + collab_type: CollabType, 
+ max_content_len: usize, +) -> Result, AppError> { + if content.is_empty() { + return Ok(vec![]); + } + + // Helper function to create AFCollabEmbeddingParams + fn create_param( + fragment_id: String, + object_id: &str, + collab_type: &CollabType, + content: String, + ) -> AFCollabEmbeddingParams { + AFCollabEmbeddingParams { + fragment_id, + object_id: object_id.to_string(), + collab_type: collab_type.clone(), + content_type: EmbeddingContentType::PlainText, + content, + embedding: None, + } + } + + if content.len() <= max_content_len { + // Content is short enough; return as a single fragment + let param = create_param(object_id.clone(), &object_id, &collab_type, content); + return Ok(vec![param]); + } + + // Content is longer than max_content_len; need to split + let mut result = Vec::with_capacity(1 + content.len() / max_content_len); + let mut fragment = String::with_capacity(max_content_len); + let mut current_len = 0; + + for grapheme in content.graphemes(true) { + let grapheme_len = grapheme.len(); + if current_len + grapheme_len > max_content_len { + if !fragment.is_empty() { + // Move the fragment to avoid cloning + result.push(create_param( + Uuid::new_v4().to_string(), + &object_id, + &collab_type, + std::mem::take(&mut fragment), + )); + } + current_len = 0; + +3 // Check if the grapheme itself is longer than max_content_len + if grapheme_len > max_content_len { + // Push the grapheme as a fragment on its own + result.push(create_param( + Uuid::new_v4().to_string(), + &object_id, + &collab_type, + grapheme.to_string(), + )); + continue; + } + } + fragment.push_str(grapheme); + current_len += grapheme_len; + } + + // Add the last fragment if it's not empty + if !fragment.is_empty() { + result.push(create_param( + object_id.clone(), + &object_id, + &collab_type, + fragment, + )); + } + + Ok(result) +} +#[cfg(test)] +mod tests { + use crate::indexer::document_indexer::create_embedding_params; + use collab_entity::CollabType; + + #[test] + fn 
test_split_at_non_utf8() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 10; // Small number for testing + + // Content with multibyte characters (emojis) + let content = "Hello πŸ˜ƒ World 🌍! This is a test πŸš€.".to_string(); + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + // Ensure that we didn't split in the middle of a multibyte character + for param in params { + assert!(param.content.is_char_boundary(0)); + assert!(param.content.is_char_boundary(param.content.len())); + } + } + + #[test] + fn test_exact_boundary_split() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 5; // Set to 5 for testing + + // Content length is exactly a multiple of max_content_len + let content = "abcdefghij".to_string(); // 10 characters + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + assert_eq!(params.len(), 2); + assert_eq!(params[0].content, "abcde"); + assert_eq!(params[1].content, "fghij"); + } + + #[test] + fn test_content_shorter_than_max_len() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 100; + + let content = "Short content".to_string(); + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + assert_eq!(params.len(), 1); + assert_eq!(params[0].content, content); + } + + #[test] + fn test_empty_content() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 10; + + let content = "".to_string(); + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + 
assert_eq!(params.len(), 0); + } + + #[test] + fn test_content_with_only_multibyte_characters() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 4; // Small number for testing + + // Each emoji is 4 bytes in UTF-8 + let content = "πŸ˜€πŸ˜ƒπŸ˜„πŸ˜πŸ˜†".to_string(); + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + assert_eq!(params.len(), 5); + let expected_contents = vec!["πŸ˜€", "πŸ˜ƒ", "πŸ˜„", "😁", "πŸ˜†"]; + for (param, expected) in params.iter().zip(expected_contents.iter()) { + assert_eq!(param.content, *expected); + } + } + + #[test] + fn test_split_with_combining_characters() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 5; // Small number for testing + + // String with combining characters (e.g., letters with accents) + let content = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}".to_string(); // "áéíóú" + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + assert_eq!(params.len(), 5); + let expected_contents = vec!["á", "é", "í", "ó", "ú"]; + for (param, expected) in params.iter().zip(expected_contents.iter()) { + assert_eq!(param.content, *expected); + } + } + + #[test] + fn test_large_content() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 1000; + + // Generate a large content string + let content = "a".repeat(5000); // 5000 characters + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + assert_eq!(params.len(), 5); // 5000 / 1000 = 5 + for param in params { + assert_eq!(param.content.len(), 1000); + } + } + #[test] + fn test_non_ascii_characters() { + let object_id = 
"test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 5; + + // Non-ASCII characters: "Ñéíóú" + let content = "Ñéíóú".to_string(); + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + // Content should be split into two fragments + assert_eq!(params.len(), 3); + assert_eq!(params[0].content, "Ñé"); + assert_eq!(params[1].content, "Γ­Γ³"); + assert_eq!(params[2].content, "ΓΊ"); + } + + #[test] + fn test_content_with_leading_and_trailing_whitespace() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 5; + + let content = " abcde ".to_string(); + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + // Content should include leading and trailing whitespace + assert_eq!(params.len(), 2); + assert_eq!(params[0].content, " abc"); + assert_eq!(params[1].content, "de "); + } + + #[test] + fn test_content_with_multiple_zero_width_joiners() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 10; + + // Complex emoji sequence with multiple zero-width joiners + let content = "πŸ‘©β€πŸ‘©β€πŸ‘§β€πŸ‘§πŸ‘¨β€πŸ‘¨β€πŸ‘¦β€πŸ‘¦".to_string(); + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + // Each complex emoji should be treated as a single grapheme + assert_eq!(params.len(), 2); + assert_eq!(params[0].content, "πŸ‘©β€πŸ‘©β€πŸ‘§β€πŸ‘§"); + assert_eq!(params[1].content, "πŸ‘¨β€πŸ‘¨β€πŸ‘¦β€πŸ‘¦"); + } + + #[test] + fn test_content_with_long_combining_sequences() { + let object_id = "test_object".to_string(); + let collab_type = CollabType::Document; + let max_content_len = 5; + + // Character with multiple combining marks + let content = 
"a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}".to_string(); // a with multiple accents + + let params = create_embedding_params( + object_id.clone(), + content.clone(), + collab_type.clone(), + max_content_len, + ) + .unwrap(); + + // The entire combining sequence should be in one fragment + assert_eq!(params.len(), 1); + assert_eq!(params[0].content, content); + } +} From b10a0e78294522d0030ce087d8267559d9cda6fb Mon Sep 17 00:00:00 2001 From: nathan Date: Fri, 15 Nov 2024 13:45:25 +0800 Subject: [PATCH 03/20] chore: fix typo --- services/appflowy-collaborate/src/indexer/document_indexer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/services/appflowy-collaborate/src/indexer/document_indexer.rs index 6dd25c3a1..23799073b 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ b/services/appflowy-collaborate/src/indexer/document_indexer.rs @@ -165,7 +165,7 @@ fn create_embedding_params( } current_len = 0; -3 // Check if the grapheme itself is longer than max_content_len + // Check if the grapheme itself is longer than max_content_len if grapheme_len > max_content_len { // Push the grapheme as a fragment on its own result.push(create_param( From d0c212ff10ecf758764bf209fa91db0f5a769dbf Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:08:12 +0800 Subject: [PATCH 04/20] refactor: remove enforcer group (#948) --- libs/access-control/src/casbin/access.rs | 44 ++-- libs/access-control/src/casbin/collab.rs | 20 +- libs/access-control/src/casbin/enforcer.rs | 220 +++++++----------- libs/access-control/src/casbin/mod.rs | 1 - .../access-control/src/casbin/notification.rs | 80 ------- libs/access-control/src/casbin/workspace.rs | 14 +- libs/access-control/src/entity.rs | 15 ++ libs/access-control/src/request.rs | 27 --- .../appflowy-collaborate/src/application.rs | 8 - 
.../appflowy-collaborate/src/collab/mod.rs | 1 - .../src/collab/notification.rs | 95 -------- src/biz/pg_listener.rs | 4 - 12 files changed, 129 insertions(+), 400 deletions(-) delete mode 100644 libs/access-control/src/casbin/notification.rs delete mode 100644 services/appflowy-collaborate/src/collab/notification.rs diff --git a/libs/access-control/src/casbin/access.rs b/libs/access-control/src/casbin/access.rs index 5c1b56276..4fc5feb92 100644 --- a/libs/access-control/src/casbin/access.rs +++ b/libs/access-control/src/casbin/access.rs @@ -1,7 +1,7 @@ use super::adapter::PgAdapter; -use super::enforcer::{AFEnforcer, NoEnforceGroup}; +use super::enforcer::AFEnforcer; use crate::act::{Action, ActionVariant, Acts}; -use crate::entity::ObjectType; +use crate::entity::{ObjectType, SubjectType}; use crate::metrics::{tick_metric, AccessControlMetrics}; use anyhow::anyhow; @@ -14,15 +14,8 @@ use database_entity::dto::{AFAccessLevel, AFRole}; use sqlx::PgPool; use std::sync::Arc; -use tokio::sync::broadcast; use tracing::trace; -#[derive(Debug, Clone)] -pub enum AccessControlChange { - UpdatePolicy { uid: i64, oid: String }, - RemovePolicy { uid: i64, oid: String }, -} - /// Manages access control. /// /// Stores access control policies in the form `subject, object, role` @@ -37,10 +30,9 @@ pub enum AccessControlChange { /// according to the model defined. 
#[derive(Clone)] pub struct AccessControl { - enforcer: Arc>, + enforcer: Arc, #[allow(dead_code)] access_control_metrics: Arc, - change_tx: broadcast::Sender, } impl AccessControl { @@ -55,41 +47,33 @@ impl AccessControl { })?; enforcer.add_function("cmpRoleOrLevel", OperatorFunction::Arg2(cmp_role_or_level)); - let enforcer = Arc::new(AFEnforcer::new(enforcer, NoEnforceGroup).await?); + let enforcer = Arc::new(AFEnforcer::new(enforcer).await?); tick_metric( enforcer.metrics_state.clone(), access_control_metrics.clone(), ); - let (change_tx, _) = broadcast::channel(1000); Ok(Self { enforcer, access_control_metrics, - change_tx, }) } - pub fn subscribe_change(&self) -> broadcast::Receiver { - self.change_tx.subscribe() - } - pub async fn update_policy( &self, - uid: &i64, + sub: SubjectType, obj: ObjectType<'_>, act: ActionVariant<'_>, ) -> Result<(), AppError> { - let access_control_change = self.enforcer.update_policy(uid, obj, act).await?; - if let Some(change) = access_control_change { - let _ = self.change_tx.send(change); - } + self.enforcer.update_policy(sub, obj, act).await?; Ok(()) } - pub async fn remove_policy(&self, uid: &i64, obj: &ObjectType<'_>) -> Result<(), AppError> { - let access_control_change = self.enforcer.remove_policy(uid, obj).await?; - if let Some(change) = access_control_change { - let _ = self.change_tx.send(change); - } + pub async fn remove_policy( + &self, + sub: &SubjectType, + obj: &ObjectType<'_>, + ) -> Result<(), AppError> { + self.enforcer.remove_policy(sub, obj).await?; Ok(()) } @@ -169,13 +153,13 @@ r = sub, obj, act p = sub, obj, act [role_definition] -g = _, _ # role and access level rule +g = _, _ # grouping rule [policy_effect] e = some(where (p.eft == allow)) [matchers] -m = r.sub == p.sub && p.obj == r.obj && (g(p.act, r.act) || cmpRoleOrLevel(r.act, p.act)) +m = g(r.sub, p.sub) && p.obj == r.obj && (g(p.act, r.act) || cmpRoleOrLevel(r.act, p.act)) "###; pub async fn casbin_model() -> Result { diff --git 
a/libs/access-control/src/casbin/collab.rs b/libs/access-control/src/casbin/collab.rs index 7646599c4..f64289a46 100644 --- a/libs/access-control/src/casbin/collab.rs +++ b/libs/access-control/src/casbin/collab.rs @@ -6,7 +6,7 @@ use tracing::instrument; use crate::{ act::{Action, ActionVariant}, collab::{CollabAccessControl, RealtimeAccessControl}, - entity::ObjectType, + entity::{ObjectType, SubjectType}, }; use super::access::AccessControl; @@ -70,7 +70,7 @@ impl CollabAccessControl for CollabAccessControlImpl { self .access_control .update_policy( - uid, + SubjectType::User(*uid), ObjectType::Collab(oid), ActionVariant::FromAccessLevel(&level), ) @@ -83,7 +83,7 @@ impl CollabAccessControl for CollabAccessControlImpl { async fn remove_access_level(&self, uid: &i64, oid: &str) -> Result<(), AppError> { self .access_control - .remove_policy(uid, &ObjectType::Collab(oid)) + .remove_policy(&SubjectType::User(*uid), &ObjectType::Collab(oid)) .await?; Ok(()) } @@ -96,20 +96,6 @@ pub struct RealtimeCollabAccessControlImpl { impl RealtimeCollabAccessControlImpl { pub fn new(access_control: AccessControl) -> Self { - // let action_by_oid = Arc::new(DashMap::new()); - // let mut sub = access_control.subscribe_change(); - // let weak_action_by_oid = Arc::downgrade(&action_by_oid); - // tokio::spawn(async move { - // while let Ok(change) = sub.recv().await { - // match weak_action_by_oid.upgrade() { - // None => break, - // Some(action_by_oid) => match change { - // AccessControlChange::UpdatePolicy { uid, oid } => {}, - // AccessControlChange::RemovePolicy { uid, oid } => {}, - // }, - // } - // } - // }); Self { access_control } } diff --git a/libs/access-control/src/casbin/enforcer.rs b/libs/access-control/src/casbin/enforcer.rs index b3f885ab0..f16c6b6b8 100644 --- a/libs/access-control/src/casbin/enforcer.rs +++ b/libs/access-control/src/casbin/enforcer.rs @@ -1,41 +1,26 @@ -use super::access::{ - load_group_policies, AccessControlChange, POLICY_FIELD_INDEX_OBJECT, 
POLICY_FIELD_INDEX_SUBJECT, -}; +use super::access::{load_group_policies, POLICY_FIELD_INDEX_OBJECT, POLICY_FIELD_INDEX_SUBJECT}; use crate::act::ActionVariant; -use crate::entity::ObjectType; +use crate::entity::{ObjectType, SubjectType}; use crate::metrics::MetricsCalState; -use crate::request::{GroupPolicyRequest, PolicyRequest, WorkspacePolicyRequest}; +use crate::request::{PolicyRequest, WorkspacePolicyRequest}; use anyhow::anyhow; use app_error::AppError; -use async_trait::async_trait; use casbin::{CoreApi, Enforcer, MgmtApi}; use std::sync::atomic::Ordering; use tokio::sync::RwLock; use tracing::{event, instrument, trace}; -#[async_trait] -pub trait EnforcerGroup { - /// Get the group id of the user. - /// User might belong to multiple groups. So return the highest permission group id. - async fn get_enforce_group_id(&self, uid: &i64) -> Option; -} - -pub struct AFEnforcer { +pub struct AFEnforcer { enforcer: RwLock, pub(crate) metrics_state: MetricsCalState, - enforce_group: T, } -impl AFEnforcer -where - T: EnforcerGroup, -{ - pub async fn new(mut enforcer: Enforcer, enforce_group: T) -> Result { +impl AFEnforcer { + pub async fn new(mut enforcer: Enforcer) -> Result { load_group_policies(&mut enforcer).await?; Ok(Self { enforcer: RwLock::new(enforcer), metrics_state: MetricsCalState::new(), - enforce_group, }) } @@ -47,18 +32,17 @@ where #[instrument(level = "debug", skip_all, err)] pub async fn update_policy( &self, - uid: &i64, + sub: SubjectType, obj: ObjectType<'_>, act: ActionVariant<'_>, - ) -> Result, AppError> { + ) -> Result<(), AppError> { validate_obj_action(&obj, &act)?; let policies = act .policy_acts() .into_iter() - .map(|act| vec![uid.to_string(), obj.policy_object(), act.to_string()]) + .map(|act| vec![sub.policy_subject(), obj.policy_object(), act.to_string()]) .collect::>>(); - let number_of_updated_policies = policies.len(); trace!("[access control]: add policy:{:?}", policies); self @@ -69,35 +53,40 @@ where .await .map_err(|e| 
AppError::Internal(anyhow!("fail to add policy: {e:?}")))?; - if number_of_updated_policies > 0 { - Ok(Some(AccessControlChange::UpdatePolicy { - uid: *uid, - oid: obj.object_id().to_string(), - })) - } else { - Ok(None) - } + Ok(()) } /// Returns policies that match the filter. pub async fn remove_policy( &self, - uid: &i64, + sub: &SubjectType, object_type: &ObjectType<'_>, - ) -> Result, AppError> { + ) -> Result<(), AppError> { let mut enforcer = self.enforcer.write().await; self - .remove_with_enforcer(uid, object_type, &mut enforcer) + .remove_with_enforcer(sub, object_type, &mut enforcer) .await } + /// Add a grouping policy. + #[allow(dead_code)] + pub async fn add_grouping_policy( + &self, + sub: &SubjectType, + group_sub: &SubjectType, + ) -> Result<(), AppError> { + let mut enforcer = self.enforcer.write().await; + enforcer + .add_grouping_policy(vec![sub.policy_subject(), group_sub.policy_subject()]) + .await + .map_err(|e| AppError::Internal(anyhow!("fail to add grouping policy: {e:?}")))?; + Ok(()) + } + /// 1. **Workspace Policy**: Initially, it checks if the user has permission at the workspace level. If the user /// has permission to perform the action on the workspace, the function returns `true` without further checks. /// - /// 2. **Group Policy**: (If applicable) If the workspace policy check fails (`false`), the function will then - /// evaluate group-level policies. - /// - /// 3. **Object-Specific Policy**: If both previous checks fail, the function finally evaluates the policy + /// 2. **Object-Specific Policy**: If workspace policy check fail, the function evaluates the policy /// specific to the object itself. /// /// ## Parameters: @@ -134,20 +123,7 @@ where .enforce(policy) .map_err(|e| AppError::Internal(anyhow!("enforce: {e:?}")))?; - // 2. Fallback to group policy if workspace-level check fails. 
- if !result { - if let Some(guid) = self.enforce_group.get_enforce_group_id(uid).await { - let policy_request = GroupPolicyRequest::new(&guid, &obj, &act); - result = self - .enforcer - .read() - .await - .enforce(policy_request.to_policy()) - .map_err(|e| AppError::Internal(anyhow!("enforce: {e:?}")))?; - } - } - - // 3. Finally, enforce object-specific policy if previous checks fail. + // 2. Finally, enforce object-specific policy if previous checks fail. if !result { let policy_request = PolicyRequest::new(*uid, &obj, &act); let policy = policy_request.to_policy(); @@ -172,22 +148,17 @@ where #[inline] async fn remove_with_enforcer( &self, - uid: &i64, + sub: &SubjectType, object_type: &ObjectType<'_>, enforcer: &mut Enforcer, - ) -> Result, AppError> { + ) -> Result<(), AppError> { let policies_for_user_on_object = - policies_for_subject_with_given_object(uid, object_type, enforcer).await; - - // if there are no policies for the user on the object, return early. - if policies_for_user_on_object.is_empty() { - return Ok(None); - } + policies_for_subject_with_given_object(sub, object_type, enforcer).await; event!( tracing::Level::INFO, - "[access control]: remove policy:user={}, object={}, policies={:?}", - uid, + "[access control]: remove policy:subject={}, object={}, policies={:?}", + sub.policy_subject(), object_type.policy_object(), policies_for_user_on_object ); @@ -197,10 +168,7 @@ where .await .map_err(|e| AppError::Internal(anyhow!("error enforce: {e:?}")))?; - Ok(Some(AccessControlChange::RemovePolicy { - uid: *uid, - oid: object_type.object_id().to_string(), - })) + Ok(()) } } @@ -216,87 +184,73 @@ fn validate_obj_action(obj: &ObjectType<'_>, act: &ActionVariant) -> Result<(), } } #[inline] -async fn policies_for_subject_with_given_object( - subject: T, +async fn policies_for_subject_with_given_object( + subject: &SubjectType, object_type: &ObjectType<'_>, enforcer: &Enforcer, ) -> Vec> { - let subject = subject.to_string(); + let subject_id = 
subject.policy_subject(); let object_type_id = object_type.policy_object(); let policies_related_to_object = enforcer.get_filtered_policy(POLICY_FIELD_INDEX_OBJECT, vec![object_type_id]); policies_related_to_object .into_iter() - .filter(|p| p[POLICY_FIELD_INDEX_SUBJECT] == subject) + .filter(|p| p[POLICY_FIELD_INDEX_SUBJECT] == subject_id) .collect::>() } -pub struct NoEnforceGroup; -#[async_trait] -impl EnforcerGroup for NoEnforceGroup { - async fn get_enforce_group_id(&self, _uid: &i64) -> Option { - None - } -} - #[cfg(test)] mod tests { use crate::{ act::{Action, ActionVariant}, - casbin::{ - access::{casbin_model, cmp_role_or_level}, - enforcer::NoEnforceGroup, - }, - entity::ObjectType, + casbin::access::{casbin_model, cmp_role_or_level}, + entity::{ObjectType, SubjectType}, }; use app_error::ErrorCode; - use async_trait::async_trait; use casbin::{function_map::OperatorFunction, prelude::*}; use database_entity::dto::{AFAccessLevel, AFRole}; - use super::{AFEnforcer, EnforcerGroup}; + use super::AFEnforcer; - pub struct TestEnforceGroup { - guid: String, - } - #[async_trait] - impl EnforcerGroup for TestEnforceGroup { - async fn get_enforce_group_id(&self, _uid: &i64) -> Option { - Some(self.guid.clone()) - } - } - - async fn test_enforcer(enforce_group: T) -> AFEnforcer - where - T: EnforcerGroup, - { + async fn test_enforcer() -> AFEnforcer { let model = casbin_model().await.unwrap(); let mut enforcer = casbin::Enforcer::new(model, MemoryAdapter::default()) .await .unwrap(); enforcer.add_function("cmpRoleOrLevel", OperatorFunction::Arg2(cmp_role_or_level)); - AFEnforcer::new(enforcer, enforce_group).await.unwrap() + AFEnforcer::new(enforcer).await.unwrap() } + #[tokio::test] async fn collab_group_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; + let group_id = "collab_owner_group:w1"; let workspace_id = "w1"; let object_1 = "o1"; - // add user as a member of the collab + // allow workspace 
member to access collab enforcer .update_policy( - &uid, + SubjectType::Group(group_id.to_string()), ObjectType::Collab(object_1), ActionVariant::FromAccessLevel(&AFAccessLevel::FullAccess), ) .await .unwrap(); + // include user in the collab owner group + enforcer + .add_grouping_policy( + &SubjectType::User(uid), + &SubjectType::Group(group_id.to_string()), + ) + .await + .unwrap(); + // when the user is the owner of the collab, then the user should have access to the collab for action in [Action::Write, Action::Read] { let result = enforcer @@ -313,14 +267,14 @@ mod tests { #[tokio::test] async fn workspace_group_policy_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Member), ) @@ -343,7 +297,7 @@ mod tests { #[tokio::test] async fn workspace_owner_and_try_to_full_access_collab_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; @@ -352,7 +306,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Owner), ) @@ -374,7 +328,7 @@ mod tests { #[tokio::test] async fn workspace_member_collab_owner_try_to_full_access_collab_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; @@ -383,7 +337,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Member), ) @@ -392,7 +346,7 @@ mod tests { enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Collab(object_1), 
ActionVariant::FromAccessLevel(&AFAccessLevel::FullAccess), ) @@ -414,7 +368,7 @@ mod tests { #[tokio::test] async fn workspace_owner_collab_member_try_to_full_access_collab_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; @@ -423,7 +377,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Owner), ) @@ -432,7 +386,7 @@ mod tests { enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Collab(object_1), ActionVariant::FromAccessLevel(&AFAccessLevel::ReadAndWrite), ) @@ -454,7 +408,7 @@ mod tests { #[tokio::test] async fn workspace_member_collab_member_try_to_full_access_collab_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; @@ -463,7 +417,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Member), ) @@ -472,7 +426,7 @@ mod tests { enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Collab(object_1), ActionVariant::FromAccessLevel(&AFAccessLevel::ReadAndWrite), ) @@ -506,7 +460,7 @@ mod tests { #[tokio::test] async fn workspace_member_but_not_collab_member_and_try_full_access_collab_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; @@ -515,7 +469,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Member), ) @@ -552,14 +506,14 @@ mod tests { #[tokio::test] async fn not_workspace_member_but_collab_owner_try_full_access_collab_test() { - let enforcer = 
test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; let object_1 = "o1"; enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Collab(object_1), ActionVariant::FromAccessLevel(&AFAccessLevel::FullAccess), ) @@ -581,7 +535,7 @@ mod tests { #[tokio::test] async fn not_workspace_member_not_collab_member_and_try_full_access_collab_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; let object_1 = "o1"; @@ -609,7 +563,7 @@ mod tests { #[tokio::test] async fn cmp_owner_role_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; let object_1 = "o1"; @@ -617,7 +571,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Owner), ) @@ -648,7 +602,7 @@ mod tests { #[tokio::test] async fn cmp_member_role_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; let object_1 = "o1"; @@ -656,7 +610,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Member), ) @@ -713,7 +667,7 @@ mod tests { #[tokio::test] async fn cmp_guest_role_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; let object_1 = "o1"; @@ -721,7 +675,7 @@ mod tests { // add user as a member of the workspace enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Workspace(workspace_id), ActionVariant::FromRole(&AFRole::Guest), ) @@ -757,14 +711,14 @@ mod tests { #[tokio::test] async fn cmp_full_access_level_test() { - 
let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; let object_1 = "o1"; enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Collab(object_1), ActionVariant::FromAccessLevel(&AFAccessLevel::FullAccess), ) @@ -790,14 +744,14 @@ mod tests { #[tokio::test] async fn cmp_read_only_level_test() { - let enforcer = test_enforcer(NoEnforceGroup).await; + let enforcer = test_enforcer().await; let uid = 1; let workspace_id = "w1"; let object_1 = "o1"; enforcer .update_policy( - &uid, + SubjectType::User(uid), ObjectType::Collab(object_1), ActionVariant::FromAccessLevel(&AFAccessLevel::ReadOnly), ) diff --git a/libs/access-control/src/casbin/mod.rs b/libs/access-control/src/casbin/mod.rs index 22e3bc5e1..af149de31 100644 --- a/libs/access-control/src/casbin/mod.rs +++ b/libs/access-control/src/casbin/mod.rs @@ -2,5 +2,4 @@ pub mod access; mod adapter; pub mod collab; mod enforcer; -pub mod notification; pub mod workspace; diff --git a/libs/access-control/src/casbin/notification.rs b/libs/access-control/src/casbin/notification.rs deleted file mode 100644 index 1f87fa83e..000000000 --- a/libs/access-control/src/casbin/notification.rs +++ /dev/null @@ -1,80 +0,0 @@ -use super::access::AccessControl; -use crate::act::ActionVariant; -use crate::entity::ObjectType; -use database_entity::dto::AFRole; -use serde::Deserialize; -use tokio::sync::broadcast; -use tracing::error; -use tracing::log::warn; -use uuid::Uuid; - -pub fn spawn_listen_on_workspace_member_change( - mut listener: broadcast::Receiver, - access_control: AccessControl, -) { - tokio::spawn(async move { - while let Ok(change) = listener.recv().await { - match change.action_type { - WorkspaceMemberAction::INSERT | WorkspaceMemberAction::UPDATE => match change.new { - None => { - warn!("The workspace member change can't be None when the action is INSERT or UPDATE") - }, - Some(member_row) => { - if let Err(err) = access_control 
- .update_policy( - &member_row.uid, - ObjectType::Workspace(&member_row.workspace_id.to_string()), - ActionVariant::FromRole(&AFRole::from(member_row.role_id as i32)), - ) - .await - { - error!( - "Failed to update the user:{} workspace:{} access control, error: {}", - member_row.uid, member_row.workspace_id, err - ); - } - }, - }, - WorkspaceMemberAction::DELETE => match change.old { - None => warn!("The workspace member change can't be None when the action is DELETE"), - Some(member_row) => { - if let Err(err) = access_control - .remove_policy( - &member_row.uid, - &ObjectType::Workspace(&member_row.workspace_id.to_string()), - ) - .await - { - error!( - "Failed to remove the user:{} workspace: {} access control, error: {}", - member_row.uid, member_row.workspace_id, err - ); - } - }, - }, - } - } - }); -} - -#[allow(clippy::upper_case_acronyms)] -#[derive(Deserialize, Clone, Debug)] -pub enum WorkspaceMemberAction { - INSERT, - UPDATE, - DELETE, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct WorkspaceMemberNotification { - pub old: Option, - pub new: Option, - pub action_type: WorkspaceMemberAction, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct WorkspaceMemberRow { - pub uid: i64, - pub role_id: i64, - pub workspace_id: Uuid, -} diff --git a/libs/access-control/src/casbin/workspace.rs b/libs/access-control/src/casbin/workspace.rs index 882105fe7..d15557910 100644 --- a/libs/access-control/src/casbin/workspace.rs +++ b/libs/access-control/src/casbin/workspace.rs @@ -4,7 +4,7 @@ use uuid::Uuid; use super::access::AccessControl; use crate::act::{Action, ActionVariant}; -use crate::entity::ObjectType; +use crate::entity::{ObjectType, SubjectType}; use crate::workspace::WorkspaceAccessControl; use app_error::AppError; use database_entity::dto::AFRole; @@ -66,7 +66,7 @@ impl WorkspaceAccessControl for WorkspaceAccessControlImpl { self .access_control .update_policy( - uid, + SubjectType::User(*uid), ObjectType::Workspace(&workspace_id.to_string()), 
ActionVariant::FromRole(&role), ) @@ -82,12 +82,18 @@ impl WorkspaceAccessControl for WorkspaceAccessControlImpl { ) -> Result<(), AppError> { self .access_control - .remove_policy(uid, &ObjectType::Workspace(&workspace_id.to_string())) + .remove_policy( + &SubjectType::User(*uid), + &ObjectType::Workspace(&workspace_id.to_string()), + ) .await?; self .access_control - .remove_policy(uid, &ObjectType::Collab(&workspace_id.to_string())) + .remove_policy( + &SubjectType::User(*uid), + &ObjectType::Collab(&workspace_id.to_string()), + ) .await?; Ok(()) } diff --git a/libs/access-control/src/entity.rs b/libs/access-control/src/entity.rs index 1bbfa3822..f020a4714 100644 --- a/libs/access-control/src/entity.rs +++ b/libs/access-control/src/entity.rs @@ -1,3 +1,18 @@ +#[derive(Debug)] +pub enum SubjectType { + User(i64), + Group(String), +} + +impl SubjectType { + pub fn policy_subject(&self) -> String { + match self { + SubjectType::User(i) => i.to_string(), + SubjectType::Group(s) => s.clone(), + } + } +} + /// Represents the object type that is stored in the access control policy. 
#[derive(Debug)] pub enum ObjectType<'id> { diff --git a/libs/access-control/src/request.rs b/libs/access-control/src/request.rs index c4d0f522c..7ec329f08 100644 --- a/libs/access-control/src/request.rs +++ b/libs/access-control/src/request.rs @@ -1,33 +1,6 @@ use crate::act::ActionVariant; use crate::entity::ObjectType; -pub struct GroupPolicyRequest<'a> { - pub guid: &'a str, - pub object_type: &'a ObjectType<'a>, - pub action: &'a ActionVariant<'a>, -} - -impl GroupPolicyRequest<'_> { - pub fn new<'a>( - guid: &'a str, - object_type: &'a ObjectType<'a>, - action: &'a ActionVariant<'a>, - ) -> GroupPolicyRequest<'a> { - GroupPolicyRequest { - guid, - object_type, - action, - } - } - pub fn to_policy(&self) -> Vec { - vec![ - self.guid.to_string(), - self.object_type.policy_object(), - self.action.to_enforce_act().to_string(), - ] - } -} - pub struct WorkspacePolicyRequest<'a> { workspace_id: &'a str, uid: &'a i64, diff --git a/services/appflowy-collaborate/src/application.rs b/services/appflowy-collaborate/src/application.rs index 83126f03f..b0763df37 100644 --- a/services/appflowy-collaborate/src/application.rs +++ b/services/appflowy-collaborate/src/application.rs @@ -112,14 +112,6 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result, - access_control: AccessControl, -) { - tokio::spawn(async move { - while let Ok(change) = listener.recv().await { - match change.action_type { - CollabMemberAction::INSERT | CollabMemberAction::UPDATE => { - if let Some(member_row) = change.new { - let permission_row = select_permission(&pg_pool, &member_row.permission_id).await; - if let Ok(Some(row)) = permission_row { - if let Err(err) = access_control - .update_policy( - &member_row.uid, - ObjectType::Collab(&member_row.oid), - ActionVariant::FromAccessLevel(&row.access_level), - ) - .await - { - error!( - "Failed to update the user:{} collab{} access control, error: {}", - member_row.uid, member_row.oid, err - ); - } - } - } else { - error!("The 
new collab member is None") - } - }, - CollabMemberAction::DELETE => { - if let (Some(oid), Some(uid)) = (change.old_oid(), change.old_uid()) { - if let Err(err) = access_control - .remove_policy(uid, &ObjectType::Collab(oid)) - .await - { - warn!( - "Failed to remove the user:{} collab{} access control, error: {}", - uid, oid, err - ); - } - } else { - warn!("The oid or uid is None") - } - }, - } - } - }); -} - -#[allow(clippy::upper_case_acronyms)] -#[derive(Deserialize, Clone, Debug)] -pub enum CollabMemberAction { - INSERT, - UPDATE, - DELETE, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct CollabMemberNotification { - /// The old will be None if the row does not exist before - pub old: Option, - /// The new will be None if the row is deleted - pub new: Option, - /// Represent the action of the database. Such as INSERT, UPDATE, DELETE - pub action_type: CollabMemberAction, -} - -impl CollabMemberNotification { - pub fn old_uid(&self) -> Option<&i64> { - self.old.as_ref().map(|o| &o.uid) - } - - pub fn old_oid(&self) -> Option<&str> { - self.old.as_ref().map(|o| o.oid.as_str()) - } - pub fn new_uid(&self) -> Option<&i64> { - self.new.as_ref().map(|n| &n.uid) - } - pub fn new_oid(&self) -> Option<&str> { - self.new.as_ref().map(|n| n.oid.as_str()) - } -} diff --git a/src/biz/pg_listener.rs b/src/biz/pg_listener.rs index e04ba53d0..cb26a8235 100644 --- a/src/biz/pg_listener.rs +++ b/src/biz/pg_listener.rs @@ -1,6 +1,4 @@ -use access_control::casbin::notification::WorkspaceMemberNotification; use anyhow::Error; -use appflowy_collaborate::collab::notification::CollabMemberNotification; use database::listener::PostgresDBListener; use database::pg_row::AFUserNotification; use sqlx::PgPool; @@ -31,6 +29,4 @@ impl PgListeners { } } -pub type CollabMemberListener = PostgresDBListener; pub type UserListener = PostgresDBListener; -pub type WorkspaceMemberListener = PostgresDBListener; From 187cadaa01c47f87ac61c05b40e12e15cd6a93e9 Mon Sep 17 00:00:00 2001 From: Zack 
Fu Zi Xiang Date: Fri, 15 Nov 2024 15:45:36 +0800 Subject: [PATCH 05/20] feat: allow underscore in publish url --- src/biz/workspace/publish.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/biz/workspace/publish.rs b/src/biz/workspace/publish.rs index e2c2d26fe..61d78df78 100644 --- a/src/biz/workspace/publish.rs +++ b/src/biz/workspace/publish.rs @@ -74,7 +74,7 @@ fn check_collab_publish_name(publish_name: &str) -> Result<(), AppError> { // Only contain alphanumeric characters and hyphens for c in publish_name.chars() { - if !c.is_alphanumeric() && c != '-' { + if !c.is_alphanumeric() && c != '-' && c != '_' { return Err(AppError::PublishNameInvalidCharacter { character: c }); } } @@ -246,8 +246,9 @@ pub async fn list_collab_publish_info( async fn check_workspace_namespace(new_namespace: &str) -> Result<(), AppError> { // Must be url safe // Only contain alphanumeric characters and hyphens + // and underscores (discouraged) for c in new_namespace.chars() { - if !c.is_alphanumeric() && c != '-' { + if !c.is_alphanumeric() && c != '-' && c != '_' { return Err(AppError::CustomNamespaceInvalidCharacter { character: c }); } } From 4703b90751f14fc222eb13fdb0f353c715870f77 Mon Sep 17 00:00:00 2001 From: Zack Fu Zi Xiang Date: Fri, 15 Nov 2024 16:37:02 +0800 Subject: [PATCH 06/20] fix: modify test case for underscore --- tests/workspace/publish.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/workspace/publish.rs b/tests/workspace/publish.rs index 5296a50b4..72f774d00 100644 --- a/tests/workspace/publish.rs +++ b/tests/workspace/publish.rs @@ -39,7 +39,7 @@ async fn test_set_publish_namespace_set() { .unwrap(); } - let new_namespace = uuid::Uuid::new_v4().to_string(); + let new_namespace = format!("namespace_{}", uuid::Uuid::new_v4()); c.set_workspace_publish_namespace(&workspace_id.to_string(), new_namespace.clone()) .await .unwrap(); @@ -156,9 +156,9 @@ async fn test_publish_doc() { assert_eq!(err.code, 
ErrorCode::PublishNameTooLong, "{:?}", err); } - let publish_name_1 = "publish-name-1"; + let publish_name_1 = "publish_name-1"; let view_id_1 = uuid::Uuid::new_v4(); - let publish_name_2 = "publish-name-2"; + let publish_name_2 = "publish_name-2"; let view_id_2 = uuid::Uuid::new_v4(); // User publishes two collabs From 655f13bc27ba355eb14cebb05fc845822fcf4acb Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sat, 16 Nov 2024 14:52:12 +0800 Subject: [PATCH 07/20] chore: search and chat (#999) * chore: add test for search and chat * chore: update test * chore: update test * chore: update ci * chore: fix security audio * chore: multiple core docker build * chore: multiple core docker build * chore: update ci * chore: update model setting * chore: test ci * chore: use tiktoken to calcualte token length * chore: remove env * chore: use spawn_blocking with condition * chore: docs * chore: clippy * chore: clippy * chore: docker logs * chore: pass message id * chore: clippy --- .github/workflows/integration_test.yml | 23 +- .github/workflows/push_latest_docker.yml | 1 + ...82274d9d75766bd9a5c383b96bd60e9c5c866.json | 22 + Cargo.lock | 47 +- Cargo.toml | 6 + Dockerfile | 7 +- deny.toml | 2 +- dev.env | 1 + docker-compose-ci.yml | 4 +- libs/app-error/src/lib.rs | 3 +- libs/appflowy-ai-client/src/client.rs | 34 +- libs/appflowy-ai-client/src/dto.rs | 69 ++- .../tests/chat_test/context_test.rs | 2 +- .../tests/chat_test/embedding_test.rs | 6 +- .../tests/chat_test/qa_test.rs | 13 +- libs/client-api-test/src/test_client.rs | 43 +- libs/client-api/src/http_chat.rs | 25 +- libs/database/src/chat/chat_ops.rs | 20 +- services/appflowy-collaborate/Cargo.toml | 3 +- .../src/group/persistence.rs | 2 +- .../src/indexer/document_indexer.rs | 525 +++++++++--------- .../appflowy-collaborate/src/indexer/mod.rs | 1 - .../src/indexer/provider.rs | 7 +- src/api/ai.rs | 21 +- src/api/chat.rs | 21 +- src/biz/chat/ops.rs | 21 +- 
src/biz/search/ops.rs | 8 +- tests/collab/collab_curd_test.rs | 18 - tests/search/asset/appflowy_values.md | 54 ++ tests/search/asset/kathryn_tennis_story.md | 54 ++ .../asset/the_five_dysfunctions_of_a_team.md | 125 +++++ tests/search/document_search.rs | 146 ++++- 32 files changed, 969 insertions(+), 365 deletions(-) create mode 100644 .sqlx/query-dbc31936b3e79632f9c8bae449182274d9d75766bd9a5c383b96bd60e9c5c866.json create mode 100644 tests/search/asset/appflowy_values.md create mode 100644 tests/search/asset/kathryn_tennis_story.md create mode 100644 tests/search/asset/the_five_dysfunctions_of_a_team.md diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 87cba860c..5b9e028fc 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -105,6 +105,12 @@ jobs: # the wasm-pack headless tests will run on random ports, so we need to allow all origins run: sed -i 's/http:\/\/127\.0\.0\.1:8000/http:\/\/127.0.0.1/g' nginx/nginx.conf + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_HUB_USERNAME }} + password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} + - name: Run Docker-Compose run: | export APPFLOWY_WORKER_VERSION=${GITHUB_SHA} @@ -113,29 +119,20 @@ jobs: docker compose -f docker-compose-ci.yml up -d docker ps -a - container_id=$(docker ps --filter name=appflowy-cloud-ai-1 -q) - if [ -n "$container_id" ]; then - echo "Displaying logs for the AppFlowy-AI container..." - docker logs "$container_id" - else - echo "No running container found to display logs." 
- fi - - name: Install prerequisites run: | sudo apt-get update - sudo apt-get install protobuf-compiler + sudo apt-get install -y protobuf-compiler - name: Run Tests run: | echo "Running tests for ${{ matrix.test_service }} with flags: ${{ matrix.test_cmd }}" RUST_LOG="info" DISABLE_CI_TEST_LOG="true" cargo test ${{ matrix.test_cmd }} - - name: Run Tests from main branch + - name: Docker Logs + if: always() run: | - git fetch origin main - git checkout main - RUST_LOG="info" DISABLE_CI_TEST_LOG="true" cargo test ${{ matrix.test_cmd }} + docker logs appflowy-cloud-ai-1 cleanup: name: Cleanup Docker Images diff --git a/.github/workflows/push_latest_docker.yml b/.github/workflows/push_latest_docker.yml index 90a9d7e00..361867969 100644 --- a/.github/workflows/push_latest_docker.yml +++ b/.github/workflows/push_latest_docker.yml @@ -95,6 +95,7 @@ jobs: labels: ${{ steps.meta.outputs.labels }} provenance: false build-args: | + PROFILE=release FEATURES= - name: Logout from Docker Hub diff --git a/.sqlx/query-dbc31936b3e79632f9c8bae449182274d9d75766bd9a5c383b96bd60e9c5c866.json b/.sqlx/query-dbc31936b3e79632f9c8bae449182274d9d75766bd9a5c383b96bd60e9c5c866.json new file mode 100644 index 000000000..bd35ee2c0 --- /dev/null +++ b/.sqlx/query-dbc31936b3e79632f9c8bae449182274d9d75766bd9a5c383b96bd60e9c5c866.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT rag_ids\n FROM af_chat\n WHERE chat_id = $1 AND deleted_at IS NULL\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "rag_ids", + "type_info": "Jsonb" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "dbc31936b3e79632f9c8bae449182274d9d75766bd9a5c383b96bd60e9c5c866" +} diff --git a/Cargo.lock b/Cargo.lock index 2009d9658..f5602dfc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -747,12 +747,13 @@ dependencies = [ "shared-entity", "sqlx", "thiserror", + "tiktoken-rs", "tokio", "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", 
- "unicode-segmentation", + "unicode-normalization", "uuid", "validator", "workspace-template", @@ -3341,9 +3342,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -3351,9 +3352,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" @@ -3379,9 +3380,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -3398,9 +3399,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -3409,15 +3410,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" 
+version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -3427,9 +3428,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -7170,6 +7171,22 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken-rs" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44075987ee2486402f0808505dd65692163d243a337fc54363d49afac41087f6" +dependencies = [ + "anyhow", + "base64 0.21.7", + "bstr", + "fancy-regex 0.13.0", + "lazy_static", + "parking_lot 0.12.3", + "regex", + "rustc-hash", +] + [[package]] name = "time" version = "0.3.36" diff --git a/Cargo.toml b/Cargo.toml index e14041757..1a1d72ed7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -301,6 +301,11 @@ codegen-units = 1 inherits = "release" debug = true +[profile.ci] +inherits = "release" +opt-level = 2 +lto = false # Disable Link-Time Optimization + [patch.crates-io] # It's diffcult to resovle different version with the same crate used in AppFlowy Frontend and the Client-API crate. # So using patch to workaround this issue. @@ -314,4 +319,5 @@ collab-importer = { git = "https://github.com/AppFlowy-IO/AppFlowy-Collab", rev [features] history = [] +# Some AI test features are not available for self-hosted AppFlowy Cloud. Therefore, AI testing is disabled by default. 
ai-test-enabled = ["client-api-test/ai-test-enabled"] diff --git a/Dockerfile b/Dockerfile index de8b72726..e31bc4ee1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,16 +16,19 @@ RUN apt update && apt install -y protobuf-compiler lld clang # Specify a default value for FEATURES; it could be an empty string if no features are enabled by default ARG FEATURES="" +ARG PROFILE="release" COPY --from=planner /app/recipe.json recipe.json # Build our project dependencies +ENV CARGO_BUILD_JOBS=4 RUN cargo chef cook --release --recipe-path recipe.json + COPY . . ENV SQLX_OFFLINE true # Build the project -RUN echo "Building with features: ${FEATURES}" -RUN cargo build --profile=release --features "${FEATURES}" --bin appflowy_cloud +RUN echo "Building with profile: ${PROFILE}, features: ${FEATURES}, " +RUN cargo build --profile=${PROFILE} --features "${FEATURES}" --bin appflowy_cloud FROM debian:bookworm-slim AS runtime WORKDIR /app diff --git a/deny.toml b/deny.toml index 852fe786b..9aaada8c7 100644 --- a/deny.toml +++ b/deny.toml @@ -1,2 +1,2 @@ [advisories] -ignore = ["RUSTSEC-2024-0370"] +ignore = ["RUSTSEC-2024-0370", "RUSTSEC-2024-0384"] diff --git a/dev.env b/dev.env index 8832a3721..fbe59882a 100644 --- a/dev.env +++ b/dev.env @@ -4,6 +4,7 @@ APPFLOWY_DATABASE_URL=postgres://postgres:password@localhost:5432/postgres APPFLOWY_ACCESS_CONTROL=true APPFLOWY_WEBSOCKET_MAILBOX_SIZE=6000 APPFLOWY_DATABASE_MAX_CONNECTIONS=40 +APPFLOWY_DOCUMENT_CONTENT_SPLIT_LEN=8000 # This file is used to set the environment variables for local development # Copy this file to .env and change the values as needed diff --git a/docker-compose-ci.yml b/docker-compose-ci.yml index 3dfcb7ade..36356a98e 100644 --- a/docker-compose-ci.yml +++ b/docker-compose-ci.yml @@ -120,6 +120,7 @@ services: dockerfile: Dockerfile args: FEATURES: "" + PROFILE: ci image: appflowyinc/appflowy_cloud:${APPFLOWY_CLOUD_VERSION:-latest} admin_frontend: @@ -138,7 +139,7 @@ services: ai: restart: on-failure - image: 
appflowyinc/appflowy_ai:${APPFLOWY_AI_VERSION:-latest} + image: appflowyinc/appflowy_ai_premium:${APPFLOWY_AI_VERSION:-latest} ports: - "5001:5001" environment: @@ -147,6 +148,7 @@ services: - LOCAL_AI_AWS_SECRET_ACCESS_KEY=${LOCAL_AI_AWS_SECRET_ACCESS_KEY} - APPFLOWY_AI_SERVER_PORT=${APPFLOWY_AI_SERVER_PORT} - APPFLOWY_AI_DATABASE_URL=${APPFLOWY_AI_DATABASE_URL} + - APPFLOWY_AI_REDIS_URL=${APPFLOWY_REDIS_URI} appflowy_worker: restart: on-failure diff --git a/libs/app-error/src/lib.rs b/libs/app-error/src/lib.rs index 572f82d8f..7aa803585 100644 --- a/libs/app-error/src/lib.rs +++ b/libs/app-error/src/lib.rs @@ -3,7 +3,6 @@ pub mod gotrue; #[cfg(feature = "gotrue_error")] use crate::gotrue::GoTrueError; -use std::error::Error as StdError; use std::string::FromUtf8Error; #[cfg(feature = "appflowy_ai_error")] @@ -92,7 +91,7 @@ pub enum AppError { #[error("{desc}: {err}")] SqlxArgEncodingError { desc: String, - err: Box, + err: Box, }, #[cfg(feature = "validation_error")] diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 7c57425de..e90c008e1 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,8 +1,9 @@ use crate::dto::{ - AIModel, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateChatContext, - CustomPrompt, Document, EmbeddingRequest, EmbeddingResponse, LocalAIConfig, MessageData, - RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SummarizeRowResponse, - TranslateRowData, TranslateRowResponse, + AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextResponse, + CompletionType, CreateChatContext, CustomPrompt, Document, EmbeddingRequest, EmbeddingResponse, + LocalAIConfig, MessageData, RepeatedLocalAIPackage, RepeatedRelatedQuestion, + SearchDocumentsRequest, SimilarityResponse, SummarizeRowResponse, TranslateRowData, + TranslateRowResponse, }; use crate::error::AIError; @@ -202,6 +203,7 @@ impl AppFlowyAIClient 
{ pub async fn send_question( &self, chat_id: &str, + question_id: i64, content: &str, model: &AIModel, metadata: Option, @@ -211,6 +213,8 @@ impl AppFlowyAIClient { data: MessageData { content: content.to_string(), metadata, + rag_ids: vec![], + message_id: Some(question_id.to_string()), }, }; let url = format!("{}/chat/message", self.url); @@ -230,6 +234,7 @@ impl AppFlowyAIClient { chat_id: &str, content: &str, metadata: Option, + rag_ids: Vec, model: &AIModel, ) -> Result>, AIError> { let json = ChatQuestion { @@ -237,6 +242,8 @@ impl AppFlowyAIClient { data: MessageData { content: content.to_string(), metadata, + rag_ids, + message_id: None, }, }; let url = format!("{}/chat/message/stream", self.url); @@ -253,8 +260,10 @@ impl AppFlowyAIClient { pub async fn stream_question_v2( &self, chat_id: &str, + question_id: i64, content: &str, metadata: Option, + rag_ids: Vec, model: &AIModel, ) -> Result>, AIError> { let json = ChatQuestion { @@ -262,6 +271,8 @@ impl AppFlowyAIClient { data: MessageData { content: content.to_string(), metadata, + rag_ids, + message_id: Some(question_id.to_string()), }, }; let url = format!("{}/v2/chat/message/stream", self.url); @@ -323,6 +334,21 @@ impl AppFlowyAIClient { .into_data() } + pub async fn calculate_similarity( + &self, + params: CalculateSimilarityParams, + ) -> Result { + let url = format!("{}/similarity", self.url); + let resp = self + .http_client(Method::POST, &url)? + .json(¶ms) + .send() + .await?; + AIResponse::::from_response(resp) + .await? 
+ .into_data() + } + fn http_client(&self, method: Method, url: &str) -> Result { let request_builder = self.client.request(method, url); Ok(request_builder) diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index 19b1dbbcd..eed948546 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -23,6 +23,10 @@ pub struct MessageData { pub content: String, #[serde(skip_serializing_if = "Option::is_none")] pub metadata: Option, + #[serde(default)] + pub rag_ids: Vec, + #[serde(default)] + pub message_id: Option, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -182,7 +186,7 @@ pub struct EmbeddingRequest { } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum EmbeddingsModel { +pub enum EmbeddingModel { #[serde(rename = "text-embedding-3-small")] TextEmbedding3Small, #[serde(rename = "text-embedding-3-large")] @@ -191,12 +195,55 @@ pub enum EmbeddingsModel { TextEmbeddingAda002, } -impl Display for EmbeddingsModel { +impl EmbeddingModel { + pub fn supported_models() -> &'static [&'static str] { + &[ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ] + } + + pub fn max_token(&self) -> usize { + match self { + EmbeddingModel::TextEmbeddingAda002 => 8191, + EmbeddingModel::TextEmbedding3Large => 8191, + EmbeddingModel::TextEmbedding3Small => 8191, + } + } + + pub fn default_dimensions(&self) -> i32 { + match self { + EmbeddingModel::TextEmbeddingAda002 => 1536, + EmbeddingModel::TextEmbedding3Large => 3072, + EmbeddingModel::TextEmbedding3Small => 1536, + } + } + + pub fn name(&self) -> &'static str { + match self { + EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", + EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large", + EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small", + } + } + + pub fn from_name(name: &str) -> Option { + match name { + "text-embedding-ada-002" => 
Some(EmbeddingModel::TextEmbeddingAda002), + "text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large), + "text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small), + _ => None, + } + } +} + +impl Display for EmbeddingModel { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - EmbeddingsModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"), - EmbeddingsModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"), - EmbeddingsModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"), + EmbeddingModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"), + EmbeddingModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"), + EmbeddingModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"), } } } @@ -320,3 +367,15 @@ pub struct CustomPrompt { pub system: String, pub user: Option, } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CalculateSimilarityParams { + pub workspace_id: String, + pub input: String, + pub expected: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SimilarityResponse { + pub score: f64, +} diff --git a/libs/appflowy-ai-client/tests/chat_test/context_test.rs b/libs/appflowy-ai-client/tests/chat_test/context_test.rs index d9a265725..79cfceb0f 100644 --- a/libs/appflowy-ai-client/tests/chat_test/context_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/context_test.rs @@ -14,7 +14,7 @@ async fn create_chat_context_test() { }; client.create_chat_text_context(context).await.unwrap(); let resp = client - .send_question(&chat_id, "Where I live?", &AIModel::GPT4oMini, None) + .send_question(&chat_id, 1, "Where I live?", &AIModel::GPT4oMini, None) .await .unwrap(); // response will be something like: diff --git a/libs/appflowy-ai-client/tests/chat_test/embedding_test.rs b/libs/appflowy-ai-client/tests/chat_test/embedding_test.rs index 1536a58b6..20f9aaaf7 100644 --- 
a/libs/appflowy-ai-client/tests/chat_test/embedding_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/embedding_test.rs @@ -1,7 +1,7 @@ use crate::appflowy_ai_client; use appflowy_ai_client::dto::{ - EmbeddingEncodingFormat, EmbeddingInput, EmbeddingRequest, EmbeddingsModel, + EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingRequest, }; #[tokio::test] @@ -9,10 +9,10 @@ async fn embedding_test() { let client = appflowy_ai_client(); let request = EmbeddingRequest { input: EmbeddingInput::String("hello world".to_string()), - model: EmbeddingsModel::TextEmbedding3Small.to_string(), + model: EmbeddingModel::TextEmbedding3Small.to_string(), chunk_size: 1000, encoding_format: EmbeddingEncodingFormat::Float, - dimensions: 1536, + dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(), }; let result = client.embeddings(request).await.unwrap(); assert!(result.total_tokens > 0); diff --git a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs index f0f7fabf1..2aac663ae 100644 --- a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs @@ -11,7 +11,7 @@ async fn qa_test() { client.health_check().await.unwrap(); let chat_id = uuid::Uuid::new_v4().to_string(); let resp = client - .send_question(&chat_id, "I feel hungry", &AIModel::GPT4o, None) + .send_question(&chat_id, 1, "I feel hungry", &AIModel::GPT4o, None) .await .unwrap(); assert!(!resp.content.is_empty()); @@ -30,7 +30,7 @@ async fn stop_stream_test() { client.health_check().await.unwrap(); let chat_id = uuid::Uuid::new_v4().to_string(); let mut stream = client - .stream_question(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini) + .stream_question(&chat_id, "I feel hungry", None, vec![], &AIModel::GPT4oMini) .await .unwrap(); @@ -52,7 +52,14 @@ async fn stream_test() { client.health_check().await.expect("Health check failed"); let chat_id = uuid::Uuid::new_v4().to_string(); let 
stream = client - .stream_question_v2(&chat_id, "I feel hungry", None, &AIModel::GPT4oMini) + .stream_question_v2( + &chat_id, + 1, + "I feel hungry", + None, + vec![], + &AIModel::GPT4oMini, + ) .await .expect("Failed to initiate question stream"); diff --git a/libs/client-api-test/src/test_client.rs b/libs/client-api-test/src/test_client.rs index 4149cc9ee..1970d2a65 100644 --- a/libs/client-api-test/src/test_client.rs +++ b/libs/client-api-test/src/test_client.rs @@ -31,7 +31,10 @@ use uuid::Uuid; #[cfg(feature = "collab-sync")] use client_api::collab_sync::{SinkConfig, SyncObject, SyncPlugin}; use client_api::entity::id::user_awareness_object_id; -use client_api::entity::{PublishCollabItem, PublishCollabMetadata, QueryWorkspaceMember}; +use client_api::entity::{ + PublishCollabItem, PublishCollabMetadata, QueryWorkspaceMember, QuestionStream, + QuestionStreamValue, +}; use client_api::ws::{WSClient, WSClientConfig}; use database_entity::dto::{ AFAccessLevel, AFRole, AFSnapshotMeta, AFSnapshotMetas, AFUserProfile, AFUserWorkspaceInfo, @@ -845,24 +848,21 @@ impl TestClient { #[allow(unused_variables)] pub async fn create_collab_with_data( &mut self, - object_id: String, workspace_id: &str, + object_id: &str, collab_type: CollabType, - encoded_collab_v1: Option, + encoded_collab_v1: EncodedCollab, ) -> Result<(), AppResponseError> { // Subscribe to object let origin = CollabOrigin::Client(CollabClient::new(self.uid().await, self.device_id.clone())); - let collab = match encoded_collab_v1 { - None => Collab::new_with_origin(origin.clone(), &object_id, vec![], false), - Some(data) => Collab::new_with_source( - origin.clone(), - &object_id, - DataSource::DocStateV1(data.doc_state.to_vec()), - vec![], - false, - ) - .unwrap(), - }; + let collab = Collab::new_with_source( + origin.clone(), + object_id, + DataSource::DocStateV1(encoded_collab_v1.doc_state.to_vec()), + vec![], + false, + ) + .unwrap(); let encoded_collab_v1 = collab .encode_collab_v1(|collab| 
collab_type.validate_require_data(collab)) @@ -873,7 +873,7 @@ impl TestClient { self .api_client .create_collab(CreateCollabParams { - object_id: object_id.clone(), + object_id: object_id.to_string(), encoded_collab_v1, collab_type: collab_type.clone(), workspace_id: workspace_id.to_string(), @@ -1167,3 +1167,16 @@ pub async fn get_collab_json_from_server( .unwrap() .to_json_value() } + +pub async fn collect_answer(mut stream: QuestionStream) -> String { + let mut answer = String::new(); + while let Some(value) = stream.next().await { + match value.unwrap() { + QuestionStreamValue::Answer { value } => { + answer.push_str(&value); + }, + QuestionStreamValue::Metadata { .. } => {}, + } + } + answer +} diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index 06a410d86..6f020c632 100644 --- a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -9,7 +9,10 @@ use futures_core::{ready, Stream}; use pin_project::pin_project; use reqwest::Method; use serde_json::Value; -use shared_entity::dto::ai_dto::{RepeatedRelatedQuestion, STREAM_ANSWER_KEY, STREAM_METADATA_KEY}; +use shared_entity::dto::ai_dto::{ + CalculateSimilarityParams, RepeatedRelatedQuestion, SimilarityResponse, STREAM_ANSWER_KEY, + STREAM_METADATA_KEY, +}; use shared_entity::response::{AppResponse, AppResponseError}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -215,6 +218,26 @@ impl Client { .await? .into_data() } + + pub async fn calculate_similarity( + &self, + params: CalculateSimilarityParams, + ) -> Result { + let url = format!( + "{}/api/ai/{}/calculate_similarity", + self.base_url, ¶ms.workspace_id + ); + let resp = self + .http_client_with_auth(Method::POST, &url) + .await? + .json(¶ms) + .send() + .await?; + log_request_id(&resp); + AppResponse::::from_response(resp) + .await? 
+ .into_data() + } } #[pin_project] diff --git a/libs/database/src/chat/chat_ops.rs b/libs/database/src/chat/chat_ops.rs index e2b563b8c..c0fffded7 100644 --- a/libs/database/src/chat/chat_ops.rs +++ b/libs/database/src/chat/chat_ops.rs @@ -33,7 +33,6 @@ pub async fn insert_chat( ))); } let rag_ids = json!(params.rag_ids); - sqlx::query!( r#" INSERT INTO af_chat (chat_id, name, workspace_id, rag_ids) @@ -145,6 +144,25 @@ pub async fn select_chat<'a, E: Executor<'a, Database = Postgres>>( } } +pub async fn select_chat_rag_ids<'a, E: Executor<'a, Database = Postgres>>( + executor: E, + chat_id: &str, +) -> Result, AppError> { + let chat_id = Uuid::from_str(chat_id)?; + let row = sqlx::query!( + r#" + SELECT rag_ids + FROM af_chat + WHERE chat_id = $1 AND deleted_at IS NULL + "#, + &chat_id, + ) + .fetch_one(executor) + .await?; + let rag_ids = serde_json::from_value::>(row.rag_ids).unwrap_or_default(); + Ok(rag_ids) +} + pub async fn insert_answer_message_with_transaction( transaction: &mut Transaction<'_, Postgres>, author: ChatAuthor, diff --git a/services/appflowy-collaborate/Cargo.toml b/services/appflowy-collaborate/Cargo.toml index 070547f36..61488ed70 100644 --- a/services/appflowy-collaborate/Cargo.toml +++ b/services/appflowy-collaborate/Cargo.toml @@ -87,8 +87,9 @@ lazy_static = "1.4.0" itertools = "0.12.0" validator = "0.16.1" rayon.workspace = true -unicode-segmentation = "1.9.0" +tiktoken-rs = "0.6.0" [dev-dependencies] rand = "0.8.5" workspace-template.workspace = true +unicode-normalization = "0.1.24" diff --git a/services/appflowy-collaborate/src/group/persistence.rs b/services/appflowy-collaborate/src/group/persistence.rs index 60e77d05c..0337fa3a6 100644 --- a/services/appflowy-collaborate/src/group/persistence.rs +++ b/services/appflowy-collaborate/src/group/persistence.rs @@ -134,7 +134,7 @@ where let lock = collab.read().await; if let Some(indexer) = &self.indexer { - match indexer.embedding_params(&lock) { + match 
indexer.embedding_params(&lock).await { Ok(embedding_params) => { drop(lock); // we no longer need the lock match indexer.embeddings(embedding_params).await { diff --git a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/services/appflowy-collaborate/src/indexer/document_indexer.rs index 23799073b..2b7e2367f 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ b/services/appflowy-collaborate/src/indexer/document_indexer.rs @@ -4,36 +4,44 @@ use anyhow::anyhow; use async_trait::async_trait; use collab::preclude::Collab; +use crate::indexer::{DocumentDataExt, Indexer}; use app_error::AppError; use appflowy_ai_client::client::AppFlowyAIClient; use appflowy_ai_client::dto::{ - EmbeddingEncodingFormat, EmbeddingInput, EmbeddingOutput, EmbeddingRequest, EmbeddingsModel, + EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest, }; use collab_document::document::DocumentBody; use collab_document::error::DocumentError; use collab_entity::CollabType; use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, EmbeddingContentType}; -use unicode_segmentation::UnicodeSegmentation; -use uuid::Uuid; -use crate::indexer::{DocumentDataExt, Indexer}; +use tiktoken_rs::CoreBPE; +use tracing::trace; +use uuid::Uuid; pub struct DocumentIndexer { ai_client: AppFlowyAIClient, + tokenizer: Arc, + embedding_model: EmbeddingModel, } impl DocumentIndexer { - /// We assume that every token is ~4 bytes. We're going to split document content into fragments - /// of ~2000 tokens each. 
- pub const DOC_CONTENT_SPLIT: usize = 8000; pub fn new(ai_client: AppFlowyAIClient) -> Arc { - Arc::new(Self { ai_client }) + let tokenizer = tiktoken_rs::cl100k_base().unwrap(); + Arc::new(Self { + ai_client, + tokenizer: Arc::new(tokenizer), + embedding_model: EmbeddingModel::TextEmbedding3Small, + }) } } #[async_trait] impl Indexer for DocumentIndexer { - fn embedding_params(&self, collab: &Collab) -> Result, AppError> { + async fn embedding_params( + &self, + collab: &Collab, + ) -> Result, AppError> { let object_id = collab.object_id().to_string(); let document = DocumentBody::from_collab(collab).ok_or_else(|| { anyhow!( @@ -46,12 +54,15 @@ impl Indexer for DocumentIndexer { match result { Ok(document_data) => { let content = document_data.to_plain_text(); - create_embedding_params( + let max_tokens = self.embedding_model.default_dimensions() as usize; + create_embedding( object_id, content, CollabType::Document, - Self::DOC_CONTENT_SPLIT, + max_tokens, + self.tokenizer.clone(), ) + .await }, Err(err) => { if matches!(err, DocumentError::NoRequiredData) { @@ -80,12 +91,17 @@ impl Indexer for DocumentIndexer { .ai_client .embeddings(EmbeddingRequest { input: EmbeddingInput::StringArray(contents), - model: EmbeddingsModel::TextEmbedding3Small.to_string(), - chunk_size: (Self::DOC_CONTENT_SPLIT / 4) as i32, + model: EmbeddingModel::TextEmbedding3Small.to_string(), + chunk_size: 2000, encoding_format: EmbeddingEncodingFormat::Float, - dimensions: 1536, + dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(), }) .await?; + trace!( + "[Embedding] request {} embeddings, received {} embeddings", + params.len(), + resp.data.len() + ); for embedding in resp.data { let param = &mut params[embedding.index as usize]; @@ -112,335 +128,322 @@ impl Indexer for DocumentIndexer { })) } } -#[inline] -fn create_embedding_params( + +/// ## Execution Time Comparison Results +/// +/// The following results were observed when running `execution_time_comparison_tests`: 
+/// +/// | Content Size (chars) | Direct Time (ms) | spawn_blocking Time (ms) | +/// |-----------------------|------------------|--------------------------| +/// | 500 | 1 | 1 | +/// | 1000 | 2 | 2 | +/// | 2000 | 5 | 5 | +/// | 5000 | 11 | 11 | +/// | 20000 | 49 | 48 | +/// +/// ## Guidelines for Using `spawn_blocking` +/// +/// - **Short Tasks (< 1 ms)**: +/// Use direct execution on the async runtime. The minimal execution time has negligible impact. +/// +/// - **Moderate Tasks (1–10 ms)**: +/// - For infrequent or low-concurrency tasks, direct execution is acceptable. +/// - For frequent or high-concurrency tasks, consider using `spawn_blocking` to avoid delays. +/// +/// - **Long Tasks (> 10 ms)**: +/// Always offload to a blocking thread with `spawn_blocking` to maintain runtime efficiency and responsiveness. +/// +/// Related blog: +/// https://tokio.rs/blog/2020-04-preemption +/// https://ryhl.io/blog/async-what-is-blocking/ +async fn create_embedding( object_id: String, content: String, collab_type: CollabType, - max_content_len: usize, + max_tokens: usize, + tokenizer: Arc, ) -> Result, AppError> { + let split_contents = if content.len() < 500 { + split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())? + } else { + tokio::task::spawn_blocking(move || { + split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()) + }) + .await?? 
+ }; + + Ok( + split_contents + .into_iter() + .map(|content| AFCollabEmbeddingParams { + fragment_id: Uuid::new_v4().to_string(), + object_id: object_id.clone(), + collab_type: collab_type.clone(), + content_type: EmbeddingContentType::PlainText, + content, + embedding: None, + }) + .collect(), + ) +} + +fn split_text_by_max_tokens( + content: String, + max_tokens: usize, + tokenizer: &CoreBPE, +) -> Result, AppError> { if content.is_empty() { return Ok(vec![]); } - // Helper function to create AFCollabEmbeddingParams - fn create_param( - fragment_id: String, - object_id: &str, - collab_type: &CollabType, - content: String, - ) -> AFCollabEmbeddingParams { - AFCollabEmbeddingParams { - fragment_id, - object_id: object_id.to_string(), - collab_type: collab_type.clone(), - content_type: EmbeddingContentType::PlainText, - content, - embedding: None, - } - } - - if content.len() <= max_content_len { - // Content is short enough; return as a single fragment - let param = create_param(object_id.clone(), &object_id, &collab_type, content); - return Ok(vec![param]); + let token_ids = tokenizer.encode_ordinary(&content); + let total_tokens = token_ids.len(); + if total_tokens <= max_tokens { + return Ok(vec![content]); } - // Content is longer than max_content_len; need to split - let mut result = Vec::with_capacity(1 + content.len() / max_content_len); - let mut fragment = String::with_capacity(max_content_len); - let mut current_len = 0; - - for grapheme in content.graphemes(true) { - let grapheme_len = grapheme.len(); - if current_len + grapheme_len > max_content_len { - if !fragment.is_empty() { - // Move the fragment to avoid cloning - result.push(create_param( - Uuid::new_v4().to_string(), - &object_id, - &collab_type, - std::mem::take(&mut fragment), - )); - } - current_len = 0; - - // Check if the grapheme itself is longer than max_content_len - if grapheme_len > max_content_len { - // Push the grapheme as a fragment on its own - result.push(create_param( - 
Uuid::new_v4().to_string(), - &object_id, - &collab_type, - grapheme.to_string(), - )); - continue; + let mut chunks = Vec::new(); + let mut start_idx = 0; + while start_idx < total_tokens { + let mut end_idx = (start_idx + max_tokens).min(total_tokens); + let mut decoded = false; + // Try to decode the chunk, adjust end_idx if decoding fails + while !decoded { + let token_chunk = &token_ids[start_idx..end_idx]; + // Attempt to decode the current chunk + match tokenizer.decode(token_chunk.to_vec()) { + Ok(chunk_text) => { + chunks.push(chunk_text); + start_idx = end_idx; + decoded = true; + }, + Err(_) => { + // If we can extend the chunk, do so + if end_idx < total_tokens { + end_idx += 1; + } else if start_idx + 1 < total_tokens { + // Skip the problematic token at start_idx + start_idx += 1; + end_idx = (start_idx + max_tokens).min(total_tokens); + } else { + // Cannot decode any further, break to avoid infinite loop + start_idx = total_tokens; + break; + } + }, } } - fragment.push_str(grapheme); - current_len += grapheme_len; - } - - // Add the last fragment if it's not empty - if !fragment.is_empty() { - result.push(create_param( - object_id.clone(), - &object_id, - &collab_type, - fragment, - )); } - Ok(result) + Ok(chunks) } + #[cfg(test)] mod tests { - use crate::indexer::document_indexer::create_embedding_params; - use collab_entity::CollabType; + use crate::indexer::document_indexer::split_text_by_max_tokens; + + use tiktoken_rs::cl100k_base; #[test] fn test_split_at_non_utf8() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 10; // Small number for testing + let max_tokens = 10; // Small number for testing // Content with multibyte characters (emojis) let content = "Hello πŸ˜ƒ World 🌍! 
This is a test πŸš€.".to_string(); - - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); // Ensure that we didn't split in the middle of a multibyte character - for param in params { - assert!(param.content.is_char_boundary(0)); - assert!(param.content.is_char_boundary(param.content.len())); + for content in params { + assert!(content.is_char_boundary(0)); + assert!(content.is_char_boundary(content.len())); } } - #[test] fn test_exact_boundary_split() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 5; // Set to 5 for testing - - // Content length is exactly a multiple of max_content_len - let content = "abcdefghij".to_string(); // 10 characters - - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - assert_eq!(params.len(), 2); - assert_eq!(params[0].content, "abcde"); - assert_eq!(params[1].content, "fghij"); + let max_tokens = 5; // Set to 5 tokens for testing + let content = "The quick brown fox jumps over the lazy dog".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); } #[test] fn test_content_shorter_than_max_len() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 100; - + let max_tokens = 100; let content = "Short content".to_string(); - - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - 
max_content_len, - ) - .unwrap(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); assert_eq!(params.len(), 1); - assert_eq!(params[0].content, content); + assert_eq!(params[0], content); } #[test] fn test_empty_content() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 10; - + let max_tokens = 10; let content = "".to_string(); - - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); assert_eq!(params.len(), 0); } #[test] fn test_content_with_only_multibyte_characters() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 4; // Small number for testing - - // Each emoji is 4 bytes in UTF-8 + let max_tokens = 1; // Set to 1 token for testing let content = "πŸ˜€πŸ˜ƒπŸ˜„πŸ˜πŸ˜†".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - assert_eq!(params.len(), 5); - let expected_contents = vec!["πŸ˜€", "πŸ˜ƒ", "πŸ˜„", "😁", "πŸ˜†"]; - for (param, expected) in params.iter().zip(expected_contents.iter()) { - assert_eq!(param.content, *expected); + let emojis: Vec = content.chars().map(|c| c.to_string()).collect(); + for (param, emoji) in params.iter().zip(emojis.iter()) { + assert_eq!(param, emoji); } } #[test] fn test_split_with_combining_characters() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 5; // Small number for testing - - // String with combining 
characters (e.g., letters with accents) + let max_tokens = 1; // Set to 1 token for testing let content = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}".to_string(); // "áéíóú" + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - assert_eq!(params.len(), 5); - let expected_contents = vec!["á", "é", "í", "ó", "ú"]; - for (param, expected) in params.iter().zip(expected_contents.iter()) { - assert_eq!(param.content, *expected); - } + let total_tokens = tokenizer.encode_ordinary(&content).len(); + assert_eq!(params.len(), total_tokens); + + let reconstructed_content = params.join(""); + assert_eq!(reconstructed_content, content); } #[test] fn test_large_content() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 1000; - - // Generate a large content string + let max_tokens = 1000; let content = "a".repeat(5000); // 5000 characters + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - assert_eq!(params.len(), 5); // 5000 / 1000 = 5 - for param in params { - assert_eq!(param.content.len(), 1000); - } + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); } + #[test] fn test_non_ascii_characters() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 5; - - // Non-ASCII characters: "Ñéíóú" + let max_tokens = 2; let content = "Ñéíóú".to_string(); + let tokenizer = cl100k_base().unwrap(); 
+ let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - // Content should be split into two fragments - assert_eq!(params.len(), 3); - assert_eq!(params[0].content, "Ñé"); - assert_eq!(params[1].content, "Γ­Γ³"); - assert_eq!(params[2].content, "ΓΊ"); + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); + + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); } #[test] fn test_content_with_leading_and_trailing_whitespace() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 5; - + let max_tokens = 3; let content = " abcde ".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - // Content should include leading and trailing whitespace - assert_eq!(params.len(), 2); - assert_eq!(params[0].content, " abc"); - assert_eq!(params[1].content, "de "); + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); + + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); } #[test] fn test_content_with_multiple_zero_width_joiners() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 10; - - // Complex emoji sequence with multiple zero-width joiners + let max_tokens = 1; let content = 
"πŸ‘©β€πŸ‘©β€πŸ‘§β€πŸ‘§πŸ‘¨β€πŸ‘¨β€πŸ‘¦β€πŸ‘¦".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - // Each complex emoji should be treated as a single grapheme - assert_eq!(params.len(), 2); - assert_eq!(params[0].content, "πŸ‘©β€πŸ‘©β€πŸ‘§β€πŸ‘§"); - assert_eq!(params[1].content, "πŸ‘¨β€πŸ‘¨β€πŸ‘¦β€πŸ‘¦"); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); } #[test] fn test_content_with_long_combining_sequences() { - let object_id = "test_object".to_string(); - let collab_type = CollabType::Document; - let max_content_len = 5; - - // Character with multiple combining marks - let content = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}".to_string(); // a with multiple accents - - let params = create_embedding_params( - object_id.clone(), - content.clone(), - collab_type.clone(), - max_content_len, - ) - .unwrap(); - - // The entire combining sequence should be in one fragment - assert_eq!(params.len(), 1); - assert_eq!(params[0].content, content); + let max_tokens = 1; + let content = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); } } + +// #[cfg(test)] +// mod execution_time_comparison_tests { +// use crate::indexer::document_indexer::split_text_by_max_tokens; +// use rand::distributions::Alphanumeric; +// use rand::{thread_rng, Rng}; +// use std::sync::Arc; +// use std::time::Instant; +// use tiktoken_rs::{cl100k_base, CoreBPE}; +// +// #[tokio::test] +// async fn test_execution_time_comparison() { +// let tokenizer = 
Arc::new(cl100k_base().unwrap()); +// let max_tokens = 100; +// +// let sizes = vec![500, 1000, 2000, 5000, 20000]; // Content sizes to test +// for size in sizes { +// let content = generate_random_string(size); +// +// // Measure direct execution time +// let direct_time = measure_direct_execution(content.clone(), max_tokens, &tokenizer); +// +// // Measure spawn_blocking execution time +// let spawn_blocking_time = +// measure_spawn_blocking_execution(content, max_tokens, Arc::clone(&tokenizer)).await; +// +// println!( +// "Content Size: {} | Direct Time: {}ms | spawn_blocking Time: {}ms", +// size, direct_time, spawn_blocking_time +// ); +// } +// } +// +// // Measure direct execution time +// fn measure_direct_execution(content: String, max_tokens: usize, tokenizer: &CoreBPE) -> u128 { +// let start = Instant::now(); +// split_text_by_max_tokens(content, max_tokens, tokenizer).unwrap(); +// start.elapsed().as_millis() +// } +// +// // Measure `spawn_blocking` execution time +// async fn measure_spawn_blocking_execution( +// content: String, +// max_tokens: usize, +// tokenizer: Arc, +// ) -> u128 { +// let start = Instant::now(); +// tokio::task::spawn_blocking(move || { +// split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()).unwrap() +// }) +// .await +// .unwrap(); +// start.elapsed().as_millis() +// } +// +// pub fn generate_random_string(len: usize) -> String { +// let rng = thread_rng(); +// rng +// .sample_iter(&Alphanumeric) +// .take(len) +// .map(char::from) +// .collect() +// } +// } diff --git a/services/appflowy-collaborate/src/indexer/mod.rs b/services/appflowy-collaborate/src/indexer/mod.rs index 09581298b..c9fe41a08 100644 --- a/services/appflowy-collaborate/src/indexer/mod.rs +++ b/services/appflowy-collaborate/src/indexer/mod.rs @@ -1,7 +1,6 @@ mod document_indexer; mod ext; mod provider; - pub use document_indexer::DocumentIndexer; pub use ext::DocumentDataExt; pub use provider::*; diff --git 
a/services/appflowy-collaborate/src/indexer/provider.rs b/services/appflowy-collaborate/src/indexer/provider.rs index 036c0ade1..f56b6a078 100644 --- a/services/appflowy-collaborate/src/indexer/provider.rs +++ b/services/appflowy-collaborate/src/indexer/provider.rs @@ -26,7 +26,10 @@ use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, CollabPa #[async_trait] pub trait Indexer: Send + Sync { - fn embedding_params(&self, collab: &Collab) -> Result, AppError>; + async fn embedding_params( + &self, + collab: &Collab, + ) -> Result, AppError>; async fn embeddings( &self, @@ -46,7 +49,7 @@ pub trait Indexer: Send + Sync { false, ) .map_err(|err| AppError::Internal(err.into()))?; - let embedding_params = self.embedding_params(&collab)?; + let embedding_params = self.embedding_params(&collab).await?; self.embeddings(embedding_params).await } } diff --git a/src/api/ai.rs b/src/api/ai.rs index 109fc59b4..60ba4cdb9 100644 --- a/src/api/ai.rs +++ b/src/api/ai.rs @@ -5,7 +5,8 @@ use actix_web::web::{Data, Json}; use actix_web::{web, HttpRequest, HttpResponse, Scope}; use app_error::AppError; use appflowy_ai_client::dto::{ - CompleteTextResponse, LocalAIConfig, TranslateRowParams, TranslateRowResponse, + CalculateSimilarityParams, CompleteTextResponse, LocalAIConfig, SimilarityResponse, + TranslateRowParams, TranslateRowResponse, }; use futures_util::{stream, TryStreamExt}; @@ -25,6 +26,9 @@ pub fn ai_completion_scope() -> Scope { .service(web::resource("/summarize_row").route(web::post().to(summarize_row_handler))) .service(web::resource("/translate_row").route(web::post().to(translate_row_handler))) .service(web::resource("/local/config").route(web::get().to(local_ai_config_handler))) + .service( + web::resource("/calculate_similarity").route(web::post().to(calculate_similarity_handler)), + ) } async fn complete_text_handler( @@ -163,3 +167,18 @@ async fn local_ai_config_handler( .map_err(|err| AppError::AIServiceUnavailable(err.to_string()))?; 
Ok(AppResponse::Ok().with_data(config).into()) } + +#[instrument(level = "debug", skip_all, err)] +async fn calculate_similarity_handler( + state: web::Data, + payload: web::Json, +) -> actix_web::Result>> { + let params = payload.into_inner(); + + let response = state + .ai_client + .calculate_similarity(params) + .await + .map_err(|err| AppError::AIServiceUnavailable(err.to_string()))?; + Ok(AppResponse::Ok().with_data(response).into()) +} diff --git a/src/api/chat.rs b/src/api/chat.rs index 62da371e1..8dbc3e0db 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -95,7 +95,6 @@ async fn create_chat_handler( ) -> actix_web::Result> { let workspace_id = path.into_inner(); let params = payload.into_inner(); - trace!("create new chat: {:?}", params); create_chat(&state.pg_pool, params, &workspace_id).await?; Ok(AppResponse::Ok().into()) } @@ -242,10 +241,11 @@ async fn answer_stream_handler( let (_workspace_id, chat_id, question_id) = path.into_inner(); let (content, metadata) = chat::chat_ops::select_chat_message_content(&state.pg_pool, question_id).await?; + let rag_ids = chat::chat_ops::select_chat_rag_ids(&state.pg_pool, &chat_id).await?; let ai_model = ai_model_from_header(&req); match state .ai_client - .stream_question(&chat_id, &content, Some(metadata), &ai_model) + .stream_question(&chat_id, &content, Some(metadata), rag_ids, &ai_model) .await { Ok(answer_stream) => { @@ -275,10 +275,25 @@ async fn answer_stream_v2_handler( let (_workspace_id, chat_id, question_id) = path.into_inner(); let (content, metadata) = chat::chat_ops::select_chat_message_content(&state.pg_pool, question_id).await?; + let rag_ids = chat::chat_ops::select_chat_rag_ids(&state.pg_pool, &chat_id).await?; let ai_model = ai_model_from_header(&req); + + trace!( + "[Chat] stream answer for chat: {}, question: {}, rag_ids: {:?}", + chat_id, + content, + rag_ids + ); match state .ai_client - .stream_question_v2(&chat_id, &content, Some(metadata), &ai_model) + .stream_question_v2( + &chat_id, 
+ question_id, + &content, + Some(metadata), + rag_ids, + &ai_model, + ) .await { Ok(answer_stream) => { diff --git a/src/biz/chat/ops.rs b/src/biz/chat/ops.rs index 484d5df7a..ef23df17a 100644 --- a/src/biz/chat/ops.rs +++ b/src/biz/chat/ops.rs @@ -17,7 +17,7 @@ use shared_entity::dto::chat_dto::{ CreateChatParams, GetChatMessageParams, RepeatedChatMessage, UpdateChatMessageContentParams, }; use sqlx::PgPool; -use tracing::{error, info}; +use tracing::{error, info, trace}; use appflowy_ai_client::dto::AIModel; use validator::Validate; @@ -28,6 +28,7 @@ pub(crate) async fn create_chat( workspace_id: &str, ) -> Result<(), AppError> { params.validate()?; + trace!("[Chat] create chat {:?}", params); let mut txn = pg_pool.begin().await?; insert_chat(&mut txn, workspace_id, params).await?; @@ -60,7 +61,13 @@ pub async fn update_chat_message( // TODO(nathan): query the metadata from the database let new_answer = ai_client - .send_question(¶ms.chat_id, ¶ms.content, &ai_model, None) + .send_question( + ¶ms.chat_id, + params.message_id, + ¶ms.content, + &ai_model, + None, + ) .await?; let _answer = insert_answer_message( pg_pool, @@ -85,7 +92,13 @@ pub async fn generate_chat_message_answer( let (content, metadata) = chat::chat_ops::select_chat_message_content(pg_pool, question_message_id).await?; let new_answer = ai_client - .send_question(chat_id, &content, &ai_model, Some(metadata)) + .send_question( + chat_id, + question_message_id, + &content, + &ai_model, + Some(metadata), + ) .await?; info!("new_answer: {:?}", new_answer); @@ -174,7 +187,7 @@ pub async fn create_chat_message_stream( match params.message_type { ChatMessageType::System => {} ChatMessageType::User => { - let answer = match ai_client.send_question(&chat_id, ¶ms.content, &ai_model, Some(json!(params.metadata))).await { + let answer = match ai_client.send_question(&chat_id,question_id, ¶ms.content, &ai_model, Some(json!(params.metadata))).await { Ok(response) => response, Err(err) => { error!("Failed to 
send question to AI: {}", err); diff --git a/src/biz/search/ops.rs b/src/biz/search/ops.rs index 239a42161..0938ac057 100644 --- a/src/biz/search/ops.rs +++ b/src/biz/search/ops.rs @@ -2,7 +2,7 @@ use crate::api::metrics::RequestMetrics; use app_error::ErrorCode; use appflowy_ai_client::client::AppFlowyAIClient; use appflowy_ai_client::dto::{ - EmbeddingEncodingFormat, EmbeddingInput, EmbeddingOutput, EmbeddingRequest, EmbeddingsModel, + EmbeddingEncodingFormat, EmbeddingInput, EmbeddingModel, EmbeddingOutput, EmbeddingRequest, }; use database::index::{search_documents, SearchDocumentParams}; @@ -25,10 +25,10 @@ pub async fn search_document( let embeddings = ai_client .embeddings(EmbeddingRequest { input: EmbeddingInput::String(request.query.clone()), - model: EmbeddingsModel::TextEmbedding3Small.to_string(), + model: EmbeddingModel::TextEmbedding3Small.to_string(), chunk_size: 500, encoding_format: EmbeddingEncodingFormat::Float, - dimensions: 1536, + dimensions: EmbeddingModel::TextEmbedding3Small.default_dimensions(), }) .await .map_err(|e| AppResponseError::new(ErrorCode::Internal, e.to_string()))?; @@ -64,7 +64,7 @@ pub async fn search_document( user_id: uid, workspace_id, limit: request.limit.unwrap_or(10) as i32, - preview: request.preview_size.unwrap_or(180) as i32, + preview: request.preview_size.unwrap_or(500) as i32, embedding, }, total_tokens, diff --git a/tests/collab/collab_curd_test.rs b/tests/collab/collab_curd_test.rs index fbf9bdaec..1d49f422b 100644 --- a/tests/collab/collab_curd_test.rs +++ b/tests/collab/collab_curd_test.rs @@ -37,24 +37,6 @@ async fn get_collab_response_compatible_test() { assert_eq!(collab_resp.encode_collab, encode_collab); } -#[tokio::test] -#[should_panic] -async fn create_collab_workspace_id_equal_to_object_id_test() { - let mut test_client = TestClient::new_user().await; - let workspace_id = test_client.workspace_id().await; - // Only the object with [CollabType::Folder] can have the same object_id as workspace_id. 
But - // it should use create workspace API - test_client - .create_collab_with_data( - workspace_id.clone(), - &workspace_id, - CollabType::Unknown, - None, - ) - .await - .unwrap() -} - #[tokio::test] async fn batch_insert_collab_with_empty_payload_test() { let mut test_client = TestClient::new_user().await; diff --git a/tests/search/asset/appflowy_values.md b/tests/search/asset/appflowy_values.md new file mode 100644 index 000000000..bcefe5ece --- /dev/null +++ b/tests/search/asset/appflowy_values.md @@ -0,0 +1,54 @@ +# AppFlowy Values + +## Mission Driven + +- Our mission is to enable everyone to unleash the potential and achieve more with secure workplace tools. +- We are true believers in open sourceβ€”a fundamentally superior approach to achieve the mission. +- We actively lead and support the AppFlowy open-source community, where a diverse group of people is empowered to + contribute to the common good. +- We think strategically, make wise decisions, and act accordingly, with an eye toward what’s sustainable in the long + run and not what’s convenient in the moment. + +## Aim High and Iterate + +1. We strive for excellence with a growth mindset. +2. We dream big, start small, and move fast. +3. We take smaller steps and ship smaller, simpler features. +4. We don’t wait, but instead iterate and work as part of the community. +5. We focus on results over process and prioritize progress over perfection. + +## Transparency + +1. We make information about AppFlowy public by default unless there is a compelling reason not to. +2. We are straightforward and kind with ourselves and each other. + +- We surface issues constructively and proactively. +- We say β€œwhy” and provide sufficient context for our actions rather than just disclosing the β€œwhat.” + +## Collaboration + +> We pride ourselves on being a great team. +> + +> We foster collaboration, value diversity and inclusion, and encourage sharing. 
+> + +> We thrive as individuals within the context of our team and succeed together. +> + +> We play very effectively with people of diverse backgrounds and cultures. +> + +> We make time to help each other in pursuit of our common goals. +> + +Honesty + +We are honest with ourselves. + +We admit mistakes freely and openly. + +We provide candid, helpful, timely feedback to colleagues with respect, regardless of their status or whether they +disagree with us. + +We are vulnerable in search of truth and don’t defend our point to just win over others. \ No newline at end of file diff --git a/tests/search/asset/kathryn_tennis_story.md b/tests/search/asset/kathryn_tennis_story.md new file mode 100644 index 000000000..d8fbb4dfc --- /dev/null +++ b/tests/search/asset/kathryn_tennis_story.md @@ -0,0 +1,54 @@ +Kathryn’s Journey to Becoming a Tennis Player + +Kathryn’s love for tennis began on a warm summer day when she was eight years old. She stumbled across a local park +where players were volleying back and forth. The sound of the ball hitting the racket and the sheer energy of the game +captivated her. That evening, she begged her parents for a tennis racket, and the very next weekend, she was on the +court for the first time. + +Learning the Basics + +Kathryn’s first lessons were clumsy but full of enthusiasm. She struggled with her serves, missed easy shots, and often +hit the ball over the fence. But every mistake made her more determined to improve. Her first coach, Mr. Evans, taught +her the fundamentals—how to grip the racket, the importance of footwork, and how to keep her eye on the ball. “Tennis is +about focus and persistence,” he would say, and Kathryn took that advice to heart. + +By the time she was 12, Kathryn was playing in local junior tournaments. At first, she lost more matches than she won, +but she never let the defeats discourage her. “Every loss teaches you something,” she told herself. 
Gradually, her +skills improved, and she started to win. + +The Turning Point + +As Kathryn entered high school, her passion for tennis only grew stronger. She spent hours after school practicing her +backhand and perfecting her serve. She joined her school’s tennis team, where she met her new coach, Ms. Carter. Unlike +her earlier coaches, Ms. Carter focused on strategy and mental toughness. + +“Kathryn, tennis isn’t just physical. It’s a mental game too,” she said one day after a tough match. “You need to stay +calm under pressure and think a few steps ahead of your opponent.” + +That advice changed everything for Kathryn. She began analyzing her matches, understanding her opponents’ patterns, and +using strategy to outplay them. By her senior year, she was the captain of her team and had won several regional +championships. + +Chasing the Dream + +After high school, Kathryn decided to pursue tennis seriously. She joined a competitive training academy, where the +practices were grueling, and the competition was fierce. There were times she doubted herself, especially after losing +matches to stronger players. But her love for the game kept her going. + +Her coaches helped her refine her technique, adding finesse to her volleys and power to her forehand. She also learned +to play smarter, conserving energy during long matches and capitalizing on her opponents’ weaknesses. + +Becoming a Player + +By the time Kathryn was in her early 20s, she was competing in national tournaments. She wasn’t the biggest name on the +court, but her hard work and persistence earned her respect. Each match was a chance to learn, grow, and prove herself. + +She eventually won her first title at a mid-level tournament, a moment she would never forget. Standing on the podium, +holding the trophy, she realized how far she had come—from the little girl who couldn’t hit a serve to a tennis player +with real potential. 
+ +A Life of Tennis + +Today, Kathryn continues to play with the same passion she had when she first picked up a racket. She travels to +tournaments, trains every day, and inspires young players to follow their dreams. For her, tennis is more than a +sport—it’s a lifelong journey of growth, persistence, and joy. \ No newline at end of file diff --git a/tests/search/asset/the_five_dysfunctions_of_a_team.md b/tests/search/asset/the_five_dysfunctions_of_a_team.md new file mode 100644 index 000000000..10ee4ad97 --- /dev/null +++ b/tests/search/asset/the_five_dysfunctions_of_a_team.md @@ -0,0 +1,125 @@ +# *The Five Dysfunctions of a Team* by Patrick Lencioni + +*The Five Dysfunctions of a Team* by Patrick Lencioni is a compelling exploration of team dynamics and the common +pitfalls that undermine successful collaboration. Through the lens of a fictional story about a Silicon Valley startup, +DecisionTech, and its CEO Kathryn Petersen, Lencioni provides a practical framework to address and resolve issues that +commonly disrupt team cohesion and performance. Below is a chapter-by-chapter look at the book’s content, capturing its +essential lessons and actionable insights. + +--- + +## Part I: Underachievement + +In this introductory section, we meet Kathryn Petersen, the newly appointed CEO of DecisionTech, a struggling Silicon +Valley startup with a dysfunctional executive team. Kathryn steps into a role where the team is plagued by poor +communication, lack of trust, and weak commitment. + +Lencioni uses this setup to introduce readers to the core problems affecting team productivity and morale. Kathryn +realizes that the team’s challenges are deeply rooted in its dynamics rather than surface-level operational issues. +Through her initial observations, she identifies that turning around the team will require addressing foundational +issues like trust, respect, and open communication. 
+ +--- + +## Part II: Lighting the Fire + +To start addressing these issues, Kathryn organizes an offsite meeting in Napa Valley. This setting becomes a +transformative space where Kathryn pushes the team to be present, vulnerable, and engaged. Her goal is to build trust, a +critical foundation for any team. + +Kathryn leads exercises that reveal personal histories, enabling the team members to see each other beyond their +professional roles. She also introduces the idea of constructive conflict, encouraging open discussion about +disagreements and differing opinions. Despite the discomfort this causes for some team members who are used to +individualistic work styles, Kathryn emphasizes that trust and openness are crucial for effective teamwork. + +--- + +## Part III: Heavy Lifting + +With initial trust in place, Kathryn shifts her focus to accountability and responsibility. This part highlights the +challenges team members face when taking ownership of collective goals. + +Kathryn holds the team to high standards, stressing the importance of addressing issues directly instead of avoiding +them. This section also examines the role of healthy conflict as a mechanism for growth, as team members begin to hold +each other accountable for their contributions. Through challenging conversations, they tackle topics like performance +expectations and role clarity. Kathryn’s persistence helps the team understand that embracing accountability is +essential for progress, even if it leads to uncomfortable discussions. + +--- + +## Part IV: Traction + +By this stage, Kathryn reinforces the team’s commitment to shared goals. The team starts experiencing the tangible +benefits of improved trust and open conflict. Accountability has now become an expected part of their routine, and +meetings are increasingly productive. + +As they move towards achieving measurable results, the focus shifts from individual successes to collective +achievements. 
Kathryn ensures that each member appreciates the value of prioritizing team success over personal gain. +Through this unified approach, the team’s motivation and performance visibly improve, demonstrating the power of +cohesive collaboration. + +--- + +## The Model: Overcoming the Five Dysfunctions + +Lencioni introduces a model that identifies the five key dysfunctions of a team and provides strategies to overcome +them: + +1. **Absence of Trust** + The lack of trust prevents team members from being vulnerable and open with each other. Lencioni suggests exercises + that encourage personal sharing to build this essential foundation. + +2. **Fear of Conflict** + Teams that avoid conflict miss out on critical discussions that lead to better decision-making. Lencioni recommends + fostering a safe environment where team members feel comfortable challenging each other’s ideas without fear of + reprisal. + +3. **Lack of Commitment** + Without clarity and buy-in, team decisions become fragmented. Leaders should ensure everyone understands and agrees + on goals to achieve genuine commitment. + +4. **Avoidance of Accountability** + When team members don’t hold each other accountable, performance suffers. Regular check-ins and peer accountability + encourage responsibility and consistency. + +5. **Inattention to Results** + Prioritizing individual goals over collective outcomes dilutes team success. Aligning rewards and recognition with + team achievements helps refocus efforts on shared objectives. + +--- + +## Understanding and Overcoming Each Dysfunction + +Each dysfunction is further broken down with practical strategies: + +- **Building Trust** + Kathryn’s personal history exercise is one example of building trust. By sharing backgrounds and opening up, team + members foster a culture of vulnerability and connection. + +- **Encouraging Conflict** + Constructive conflict allows ideas to be challenged and strengthened. 
Kathryn’s insistence on open debate helps the + team reach better, more robust decisions. + +- **Ensuring Commitment** + Lencioni highlights the importance of clarity and alignment, which Kathryn reinforces by facilitating discussions that + ensure all team members are on the same page about their goals. + +- **Embracing Accountability** + Accountability becomes ingrained as team members regularly check in with each other, creating a culture of mutual + responsibility and high standards. + +- **Focusing on Results** + Kathryn’s focus on collective achievements over individual successes aligns with Lencioni’s advice to reward team + efforts, ensuring the entire group works toward a shared purpose. + +--- + +## Final Thoughts + +*The Five Dysfunctions of a Team* illustrates the importance of cohesive team behavior and effective leadership in +overcoming common organizational challenges. Through Kathryn’s story, Lencioni provides a practical roadmap for leaders +and teams to diagnose and address dysfunctions, ultimately fostering an environment where trust, accountability, and +shared goals drive performance. + +This book remains a valuable resource for anyone seeking to understand and improve team dynamics, with lessons that +apply well beyond the workplace. 
\ No newline at end of file diff --git a/tests/search/document_search.rs b/tests/search/document_search.rs index a88d4f79a..84267d9f4 100644 --- a/tests/search/document_search.rs +++ b/tests/search/document_search.rs @@ -1,13 +1,140 @@ +use std::path::PathBuf; use std::time::Duration; +use appflowy_ai_client::dto::CalculateSimilarityParams; +use client_api_test::{collect_answer, TestClient}; use collab::preclude::Collab; use collab_document::document::Document; +use collab_document::importer::md_importer::MDImporter; use collab_entity::CollabType; +use shared_entity::dto::chat_dto::{CreateChatMessageParams, CreateChatParams}; use tokio::time::sleep; - -use client_api_test::TestClient; use workspace_template::document::getting_started::getting_started_document_data; +#[tokio::test] +async fn test_embedding_when_create_document() { + let mut test_client = TestClient::new_user().await; + let workspace_id = test_client.workspace_id().await; + + let object_id_1 = uuid::Uuid::new_v4().to_string(); + let the_five_dysfunctions_of_a_team = + create_document_collab(&object_id_1, "the_five_dysfunctions_of_a_team.md").await; + let encoded_collab = the_five_dysfunctions_of_a_team.encode_collab().unwrap(); + test_client + .create_collab_with_data( + &workspace_id, + &object_id_1, + CollabType::Document, + encoded_collab, + ) + .await + .unwrap(); + + let object_id_2 = uuid::Uuid::new_v4().to_string(); + let tennis_player = create_document_collab(&object_id_2, "kathryn_tennis_story.md").await; + let encoded_collab = tennis_player.encode_collab().unwrap(); + test_client + .create_collab_with_data( + &workspace_id, + &object_id_2, + CollabType::Document, + encoded_collab, + ) + .await + .unwrap(); + + let search_resp = test_client + .api_client + .search_documents(&workspace_id, "Kathryn", 5, 100) + .await + .unwrap(); + // The number of returned documents affected by the max token size when splitting the document + // into chunks. 
+ assert_eq!(search_resp.len(), 2); + + if ai_test_enabled() { + let previews = search_resp + .iter() + .map(|item| item.preview.clone().unwrap()) + .collect::>() + .join("\n"); + let params = CalculateSimilarityParams { + workspace_id: workspace_id.clone(), + input: previews, + expected: r#" + "Kathryn’s Journey to Becoming a Tennis Player Kathryn’s love for tennis began on a warm summer day w +yn decided to pursue tennis seriously. She joined a competitive training academy, where the +practice +mwork. Part III: Heavy Lifting With initial trust in place, Kathryn shifts her focus to accountabili +’s ideas without fear of +reprisal. Lack of Commitment Without clarity and buy-in, team decisions bec +The Five Dysfunctions of a Team by Patrick Lencioni The Five Dysfunctions of a Team by Patrick Lenci" + "# + .to_string(), + }; + let score = test_client + .api_client + .calculate_similarity(params) + .await + .unwrap() + .score; + + assert!( + score > 0.85, + "preview score should greater than 0.85, but got: {}", + score + ); + + // Create a chat to ask questions that related to the five dysfunctions of a team. 
+ let chat_id = uuid::Uuid::new_v4().to_string(); + let params = CreateChatParams { + chat_id: chat_id.clone(), + name: "chat with the five dysfunctions of a team".to_string(), + rag_ids: vec![object_id_1], + }; + + test_client + .api_client + .create_chat(&workspace_id, params) + .await + .unwrap(); + + let params = CreateChatMessageParams::new_user("Tell me what Kathryn concisely?"); + let question = test_client + .api_client + .create_question(&workspace_id, &chat_id, params) + .await + .unwrap(); + let answer_stream = test_client + .api_client + .stream_answer_v2(&workspace_id, &chat_id, question.message_id) + .await + .unwrap(); + let answer = collect_answer(answer_stream).await; + + let params = CalculateSimilarityParams { + workspace_id, + input: answer, + expected: r#" + Kathryn Petersen is the newly appointed CEO of DecisionTech, a struggling Silicon Valley startup. + She steps into a role facing a dysfunctional executive team characterized by poor communication, + lack of trust, and weak commitment. Throughout the narrative, Kathryn focuses on addressing + foundational team issues by fostering trust, encouraging open conflict, and promoting accountability, + ultimately leading her team toward improved collaboration and performance. 
+ "# + .to_string(), + }; + let score = test_client + .api_client + .calculate_similarity(params) + .await + .unwrap() + .score; + + assert!(score > 0.9, "score: {}", score); + } +} + #[ignore] #[tokio::test] async fn test_document_indexing_and_search() { @@ -56,3 +183,18 @@ async fn test_document_indexing_and_search() { let preview = item.preview.clone().unwrap(); assert!(preview.contains("Welcome to AppFlowy")); } + +async fn create_document_collab(document_id: &str, file_name: &str) -> Document { + let file_path = PathBuf::from(format!("tests/search/asset/{}", file_name)); + let md = std::fs::read_to_string(file_path).unwrap(); + let importer = MDImporter::new(None); + let document_data = importer.import(document_id, md).unwrap(); + Document::create(document_id, document_data).unwrap() +} + +pub fn ai_test_enabled() -> bool { + if cfg!(feature = "ai-test-enabled") { + return true; + } + false +} From dcbc84dacc81dfd1d06a303bc14daeeb7644cd8c Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sun, 17 Nov 2024 12:14:50 +0800 Subject: [PATCH 08/20] chore: recreate group if it isn't exist (#1001) --- services/appflowy-worker/src/error.rs | 3 + .../src/import_worker/worker.rs | 66 +++++++++++++------ 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/services/appflowy-worker/src/error.rs b/services/appflowy-worker/src/error.rs index daaa98b45..77985ae65 100644 --- a/services/appflowy-worker/src/error.rs +++ b/services/appflowy-worker/src/error.rs @@ -16,6 +16,9 @@ pub enum WorkerError { #[error("S3 service unavailable: {0}")] S3ServiceUnavailable(String), + #[error("Redis stream group not exist: {0}")] + StreamGroupNotExist(String), + #[error(transparent)] Internal(#[from] anyhow::Error), } diff --git a/services/appflowy-worker/src/import_worker/worker.rs b/services/appflowy-worker/src/import_worker/worker.rs index e3289807a..6d07ebe2e 100644 --- a/services/appflowy-worker/src/import_worker/worker.rs +++ 
b/services/appflowy-worker/src/import_worker/worker.rs @@ -80,10 +80,7 @@ pub async fn run_import_worker( tick_interval_secs: u64, ) -> Result<(), ImportError> { info!("Starting importer worker"); - if let Err(err) = ensure_consumer_group(stream_name, GROUP_NAME, &mut redis_client) - .await - .map_err(ImportError::Internal) - { + if let Err(err) = ensure_consumer_group(stream_name, GROUP_NAME, &mut redis_client).await { error!("Failed to ensure consumer group: {:?}", err); } @@ -179,6 +176,7 @@ async fn process_upcoming_tasks( loop { interval.tick().await; + let tasks: StreamReadReply = match redis_client .xread_options(&[stream_name], &[">"], &options) .await @@ -186,6 +184,17 @@ async fn process_upcoming_tasks( Ok(tasks) => tasks, Err(err) => { error!("Failed to read tasks from Redis stream: {:?}", err); + + // Use command: + // docker exec -it appflowy-cloud-redis-1 redis-cli FLUSHDB to generate the error + // NOGROUP: No such key 'import_task_stream' or consumer group 'import_task_group' in XREADGROUP with GROUP option + if let Some(code) = err.code() { + if code == "NOGROUP" { + if let Err(err) = ensure_consumer_group(stream_name, GROUP_NAME, redis_client).await { + error!("Failed to ensure consumer group: {:?}", err); + } + } + } continue; }, }; @@ -198,6 +207,7 @@ async fn process_upcoming_tasks( Ok(import_task) => { let stream_name = stream_name.to_string(); let group_name = group_name.to_string(); + let context = TaskContext { storage_dir: storage_dir.to_path_buf(), redis_client: redis_client.clone(), @@ -206,7 +216,8 @@ async fn process_upcoming_tasks( notifier: notifier.clone(), metrics: metrics.clone(), }; - task_handlers.push(spawn_local(async move { + + let handle = spawn_local(async move { consume_task( context, import_task, @@ -216,7 +227,8 @@ async fn process_upcoming_tasks( ) .await?; Ok::<(), ImportError>(()) - })); + }); + task_handlers.push(handle); }, Err(err) => { error!("Failed to deserialize task: {:?}", err); @@ -233,6 +245,7 @@ async fn 
process_upcoming_tasks( } } } + info!("[Import] stop reading tasks from stream"); } #[derive(Clone)] struct TaskContext { @@ -280,8 +293,12 @@ async fn consume_task( if task.last_process_at.is_none() { task.last_process_at = Some(Utc::now().timestamp()); } + process_and_ack_task(context, import_task, stream_name, group_name, &entry_id).await } else { - trace!("[Import] {} file not found, queue task", task.workspace_id); + info!( + "[Import] {} zip file not found, queue task", + task.workspace_id + ); push_task( &mut context.redis_client, stream_name, @@ -290,12 +307,12 @@ async fn consume_task( &entry_id, ) .await?; - return Ok(()); + Ok(()) } + } else { + // If the task is not a notion task, proceed directly to processing + process_and_ack_task(context, import_task, stream_name, group_name, &entry_id).await } - - // Process and acknowledge the task - process_and_ack_task(context, import_task, stream_name, group_name, &entry_id).await } async fn handle_expired_task( @@ -308,7 +325,7 @@ async fn handle_expired_task( reason: &str, ) -> Result<(), ImportError> { info!( - "[Import]: {} import is expired with reason:{}, delete workspace", + "[Import]: {} import is expired with reason:{}", task.workspace_id, reason ); @@ -323,6 +340,7 @@ async fn handle_expired_task( ImportError::Internal(e.into()) })?; remove_workspace(&import_record.workspace_id, &context.pg_pool).await; + info!("[Import]: deleted workspace {}", task.workspace_id); if let Err(err) = context.s3_client.delete_blob(task.s3_key.as_str()).await { error!( @@ -330,7 +348,12 @@ async fn handle_expired_task( task.workspace_id, err ); } - let _ = xack_task(&mut context.redis_client, stream_name, group_name, entry_id).await; + if let Err(err) = xack_task(&mut context.redis_client, stream_name, group_name, entry_id).await { + error!( + "[Import] failed to acknowledge task:{} error:{:?}", + task.workspace_id, err + ); + } notify_user( task, Err(ImportError::UploadFileExpire), @@ -388,7 +411,7 @@ fn 
is_task_expired(created_timestamp: i64, last_process_at: Option) -> Resu if elapsed.num_hours() >= hours { return Err(format!( - "[Import] task is expired: created_at: {}, last_process_at: {:?}, elapsed: {} hours", + "task is expired: created_at: {}, last_process_at: {:?}, elapsed: {} hours", created_at.format("%m/%d/%y %H:%M"), last_process_at, elapsed.num_hours() @@ -485,10 +508,7 @@ async fn process_task( .parse() .unwrap_or(false); - info!( - "[Import]: Processing task: {}, retry interval: {}, streaming: {}", - import_task, retry_interval, streaming - ); + info!("[Import]: Processing task: {}", import_task); match import_task { ImportTask::Notion(task) => { @@ -1285,7 +1305,7 @@ async fn ensure_consumer_group( stream_key: &str, group_name: &str, redis_client: &mut ConnectionManager, -) -> Result<(), anyhow::Error> { +) -> Result<(), WorkerError> { let result: RedisResult<()> = redis_client .xgroup_create_mkstream(stream_key, group_name, "0") .await; @@ -1293,11 +1313,15 @@ async fn ensure_consumer_group( if let Err(redis_error) = result { if let Some(code) = redis_error.code() { if code == "BUSYGROUP" { - return Ok(()); // Group already exists, considered as success. 
+ return Ok(()); + } + + if code == "NOGROUP" { + return Err(WorkerError::StreamGroupNotExist(group_name.to_string())); } } error!("Error when creating consumer group: {:?}", redis_error); - return Err(redis_error.into()); + return Err(WorkerError::Internal(redis_error.into())); } Ok(()) From d798c81ba4ddc29317a7ecd02d44f06b7b1bb0ac Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Sun, 17 Nov 2024 12:45:20 +0800 Subject: [PATCH 09/20] chore: support split by text len (#1002) * chore: support split by text len * chore: update docs * chore: update tests --- Cargo.lock | 1 + services/appflowy-collaborate/Cargo.toml | 2 + .../src/indexer/document_indexer.rs | 322 ++-------------- .../appflowy-collaborate/src/indexer/mod.rs | 2 + .../src/indexer/open_ai.rs | 361 ++++++++++++++++++ .../src/import_worker/worker.rs | 1 - tests/search/document_search.rs | 4 +- 7 files changed, 393 insertions(+), 300 deletions(-) create mode 100644 services/appflowy-collaborate/src/indexer/open_ai.rs diff --git a/Cargo.lock b/Cargo.lock index f5602dfc5..1e9082fe6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -754,6 +754,7 @@ dependencies = [ "tracing", "tracing-subscriber", "unicode-normalization", + "unicode-segmentation", "uuid", "validator", "workspace-template", diff --git a/services/appflowy-collaborate/Cargo.toml b/services/appflowy-collaborate/Cargo.toml index 61488ed70..96232c813 100644 --- a/services/appflowy-collaborate/Cargo.toml +++ b/services/appflowy-collaborate/Cargo.toml @@ -88,6 +88,8 @@ itertools = "0.12.0" validator = "0.16.1" rayon.workspace = true tiktoken-rs = "0.6.0" +unicode-segmentation = "1.9.0" + [dev-dependencies] rand = "0.8.5" diff --git a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/services/appflowy-collaborate/src/indexer/document_indexer.rs index 2b7e2367f..d722598f6 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ 
b/services/appflowy-collaborate/src/indexer/document_indexer.rs @@ -15,6 +15,8 @@ use collab_document::error::DocumentError; use collab_entity::CollabType; use database_entity::dto::{AFCollabEmbeddingParams, AFCollabEmbeddings, EmbeddingContentType}; +use crate::config::get_env_var; +use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens}; use tiktoken_rs::CoreBPE; use tracing::trace; use uuid::Uuid; @@ -54,12 +56,11 @@ impl Indexer for DocumentIndexer { match result { Ok(document_data) => { let content = document_data.to_plain_text(); - let max_tokens = self.embedding_model.default_dimensions() as usize; create_embedding( object_id, content, CollabType::Document, - max_tokens, + &self.embedding_model, self.tokenizer.clone(), ) .await @@ -129,47 +130,35 @@ impl Indexer for DocumentIndexer { } } -/// ## Execution Time Comparison Results -/// -/// The following results were observed when running `execution_time_comparison_tests`: -/// -/// | Content Size (chars) | Direct Time (ms) | spawn_blocking Time (ms) | -/// |-----------------------|------------------|--------------------------| -/// | 500 | 1 | 1 | -/// | 1000 | 2 | 2 | -/// | 2000 | 5 | 5 | -/// | 5000 | 11 | 11 | -/// | 20000 | 49 | 48 | -/// -/// ## Guidelines for Using `spawn_blocking` -/// -/// - **Short Tasks (< 1 ms)**: -/// Use direct execution on the async runtime. The minimal execution time has negligible impact. -/// -/// - **Moderate Tasks (1–10 ms)**: -/// - For infrequent or low-concurrency tasks, direct execution is acceptable. -/// - For frequent or high-concurrency tasks, consider using `spawn_blocking` to avoid delays. -/// -/// - **Long Tasks (> 10 ms)**: -/// Always offload to a blocking thread with `spawn_blocking` to maintain runtime efficiency and responsiveness. 
-/// -/// Related blog: -/// https://tokio.rs/blog/2020-04-preemption -/// https://ryhl.io/blog/async-what-is-blocking/ async fn create_embedding( object_id: String, content: String, collab_type: CollabType, - max_tokens: usize, + embedding_model: &EmbeddingModel, tokenizer: Arc, ) -> Result, AppError> { - let split_contents = if content.len() < 500 { - split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())? + let use_tiktoken = get_env_var("APPFLOWY_AI_CONTENT_SPLITTER_TIKTOKEN", "false") + .parse::() + .unwrap_or(false); + + let split_contents = if use_tiktoken { + let max_tokens = embedding_model.default_dimensions() as usize; + if content.len() < 500 { + split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref())? + } else { + tokio::task::spawn_blocking(move || { + split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()) + }) + .await?? + } } else { - tokio::task::spawn_blocking(move || { - split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()) - }) - .await?? + debug_assert!(matches!( + embedding_model, + EmbeddingModel::TextEmbedding3Small + )); + // We assume that every token is ~4 bytes. We're going to split document content into fragments + // of ~2000 tokens each. + split_text_by_max_content_len(content, 8000)? 
}; Ok( @@ -186,264 +175,3 @@ async fn create_embedding( .collect(), ) } - -fn split_text_by_max_tokens( - content: String, - max_tokens: usize, - tokenizer: &CoreBPE, -) -> Result, AppError> { - if content.is_empty() { - return Ok(vec![]); - } - - let token_ids = tokenizer.encode_ordinary(&content); - let total_tokens = token_ids.len(); - if total_tokens <= max_tokens { - return Ok(vec![content]); - } - - let mut chunks = Vec::new(); - let mut start_idx = 0; - while start_idx < total_tokens { - let mut end_idx = (start_idx + max_tokens).min(total_tokens); - let mut decoded = false; - // Try to decode the chunk, adjust end_idx if decoding fails - while !decoded { - let token_chunk = &token_ids[start_idx..end_idx]; - // Attempt to decode the current chunk - match tokenizer.decode(token_chunk.to_vec()) { - Ok(chunk_text) => { - chunks.push(chunk_text); - start_idx = end_idx; - decoded = true; - }, - Err(_) => { - // If we can extend the chunk, do so - if end_idx < total_tokens { - end_idx += 1; - } else if start_idx + 1 < total_tokens { - // Skip the problematic token at start_idx - start_idx += 1; - end_idx = (start_idx + max_tokens).min(total_tokens); - } else { - // Cannot decode any further, break to avoid infinite loop - start_idx = total_tokens; - break; - } - }, - } - } - } - - Ok(chunks) -} - -#[cfg(test)] -mod tests { - use crate::indexer::document_indexer::split_text_by_max_tokens; - - use tiktoken_rs::cl100k_base; - - #[test] - fn test_split_at_non_utf8() { - let max_tokens = 10; // Small number for testing - - // Content with multibyte characters (emojis) - let content = "Hello πŸ˜ƒ World 🌍! 
This is a test πŸš€.".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - // Ensure that we didn't split in the middle of a multibyte character - for content in params { - assert!(content.is_char_boundary(0)); - assert!(content.is_char_boundary(content.len())); - } - } - #[test] - fn test_exact_boundary_split() { - let max_tokens = 5; // Set to 5 tokens for testing - let content = "The quick brown fox jumps over the lazy dog".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let total_tokens = tokenizer.encode_ordinary(&content).len(); - let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; - assert_eq!(params.len(), expected_fragments); - } - - #[test] - fn test_content_shorter_than_max_len() { - let max_tokens = 100; - let content = "Short content".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - assert_eq!(params.len(), 1); - assert_eq!(params[0], content); - } - - #[test] - fn test_empty_content() { - let max_tokens = 10; - let content = "".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - assert_eq!(params.len(), 0); - } - - #[test] - fn test_content_with_only_multibyte_characters() { - let max_tokens = 1; // Set to 1 token for testing - let content = "πŸ˜€πŸ˜ƒπŸ˜„πŸ˜πŸ˜†".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let emojis: Vec = content.chars().map(|c| c.to_string()).collect(); - for (param, emoji) in params.iter().zip(emojis.iter()) { - assert_eq!(param, emoji); - } - } - - #[test] - fn test_split_with_combining_characters() { - let 
max_tokens = 1; // Set to 1 token for testing - let content = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}".to_string(); // "áéíóú" - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let total_tokens = tokenizer.encode_ordinary(&content).len(); - assert_eq!(params.len(), total_tokens); - - let reconstructed_content = params.join(""); - assert_eq!(reconstructed_content, content); - } - - #[test] - fn test_large_content() { - let max_tokens = 1000; - let content = "a".repeat(5000); // 5000 characters - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let total_tokens = tokenizer.encode_ordinary(&content).len(); - let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; - assert_eq!(params.len(), expected_fragments); - } - - #[test] - fn test_non_ascii_characters() { - let max_tokens = 2; - let content = "Ñéíóú".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let total_tokens = tokenizer.encode_ordinary(&content).len(); - let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; - assert_eq!(params.len(), expected_fragments); - - let reconstructed_content: String = params.concat(); - assert_eq!(reconstructed_content, content); - } - - #[test] - fn test_content_with_leading_and_trailing_whitespace() { - let max_tokens = 3; - let content = " abcde ".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let total_tokens = tokenizer.encode_ordinary(&content).len(); - let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; - assert_eq!(params.len(), expected_fragments); - - let reconstructed_content: String = params.concat(); - assert_eq!(reconstructed_content, 
content); - } - - #[test] - fn test_content_with_multiple_zero_width_joiners() { - let max_tokens = 1; - let content = "πŸ‘©β€πŸ‘©β€πŸ‘§β€πŸ‘§πŸ‘¨β€πŸ‘¨β€πŸ‘¦β€πŸ‘¦".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let reconstructed_content: String = params.concat(); - assert_eq!(reconstructed_content, content); - } - - #[test] - fn test_content_with_long_combining_sequences() { - let max_tokens = 1; - let content = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}".to_string(); - let tokenizer = cl100k_base().unwrap(); - let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); - - let reconstructed_content: String = params.concat(); - assert_eq!(reconstructed_content, content); - } -} - -// #[cfg(test)] -// mod execution_time_comparison_tests { -// use crate::indexer::document_indexer::split_text_by_max_tokens; -// use rand::distributions::Alphanumeric; -// use rand::{thread_rng, Rng}; -// use std::sync::Arc; -// use std::time::Instant; -// use tiktoken_rs::{cl100k_base, CoreBPE}; -// -// #[tokio::test] -// async fn test_execution_time_comparison() { -// let tokenizer = Arc::new(cl100k_base().unwrap()); -// let max_tokens = 100; -// -// let sizes = vec![500, 1000, 2000, 5000, 20000]; // Content sizes to test -// for size in sizes { -// let content = generate_random_string(size); -// -// // Measure direct execution time -// let direct_time = measure_direct_execution(content.clone(), max_tokens, &tokenizer); -// -// // Measure spawn_blocking execution time -// let spawn_blocking_time = -// measure_spawn_blocking_execution(content, max_tokens, Arc::clone(&tokenizer)).await; -// -// println!( -// "Content Size: {} | Direct Time: {}ms | spawn_blocking Time: {}ms", -// size, direct_time, spawn_blocking_time -// ); -// } -// } -// -// // Measure direct execution time -// fn measure_direct_execution(content: String, max_tokens: usize, 
tokenizer: &CoreBPE) -> u128 { -// let start = Instant::now(); -// split_text_by_max_tokens(content, max_tokens, tokenizer).unwrap(); -// start.elapsed().as_millis() -// } -// -// // Measure `spawn_blocking` execution time -// async fn measure_spawn_blocking_execution( -// content: String, -// max_tokens: usize, -// tokenizer: Arc, -// ) -> u128 { -// let start = Instant::now(); -// tokio::task::spawn_blocking(move || { -// split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()).unwrap() -// }) -// .await -// .unwrap(); -// start.elapsed().as_millis() -// } -// -// pub fn generate_random_string(len: usize) -> String { -// let rng = thread_rng(); -// rng -// .sample_iter(&Alphanumeric) -// .take(len) -// .map(char::from) -// .collect() -// } -// } diff --git a/services/appflowy-collaborate/src/indexer/mod.rs b/services/appflowy-collaborate/src/indexer/mod.rs index c9fe41a08..e354f9e01 100644 --- a/services/appflowy-collaborate/src/indexer/mod.rs +++ b/services/appflowy-collaborate/src/indexer/mod.rs @@ -1,6 +1,8 @@ mod document_indexer; mod ext; +mod open_ai; mod provider; + pub use document_indexer::DocumentIndexer; pub use ext::DocumentDataExt; pub use provider::*; diff --git a/services/appflowy-collaborate/src/indexer/open_ai.rs b/services/appflowy-collaborate/src/indexer/open_ai.rs new file mode 100644 index 000000000..db5fb26a3 --- /dev/null +++ b/services/appflowy-collaborate/src/indexer/open_ai.rs @@ -0,0 +1,361 @@ +use app_error::AppError; +use tiktoken_rs::CoreBPE; +use unicode_segmentation::UnicodeSegmentation; + +/// ## Execution Time Comparison Results +/// +/// The following results were observed when running `execution_time_comparison_tests`: +/// +/// | Content Size (chars) | Direct Time (ms) | spawn_blocking Time (ms) | +/// |-----------------------|------------------|--------------------------| +/// | 500 | 1 | 1 | +/// | 1000 | 2 | 2 | +/// | 2000 | 5 | 5 | +/// | 5000 | 11 | 11 | +/// | 20000 | 49 | 48 | +/// +/// ## Guidelines for Using 
`spawn_blocking` +/// +/// - **Short Tasks (< 1 ms)**: +/// Use direct execution on the async runtime. The minimal execution time has negligible impact. +/// +/// - **Moderate Tasks (1–10 ms)**: +/// - For infrequent or low-concurrency tasks, direct execution is acceptable. +/// - For frequent or high-concurrency tasks, consider using `spawn_blocking` to avoid delays. +/// +/// - **Long Tasks (> 10 ms)**: +/// Always offload to a blocking thread with `spawn_blocking` to maintain runtime efficiency and responsiveness. +/// +/// Related blog: +/// https://tokio.rs/blog/2020-04-preemption +/// https://ryhl.io/blog/async-what-is-blocking/ +#[inline] +pub fn split_text_by_max_tokens( + content: String, + max_tokens: usize, + tokenizer: &CoreBPE, +) -> Result, AppError> { + if content.is_empty() { + return Ok(vec![]); + } + + let token_ids = tokenizer.encode_ordinary(&content); + let total_tokens = token_ids.len(); + if total_tokens <= max_tokens { + return Ok(vec![content]); + } + + let mut chunks = Vec::new(); + let mut start_idx = 0; + while start_idx < total_tokens { + let mut end_idx = (start_idx + max_tokens).min(total_tokens); + let mut decoded = false; + // Try to decode the chunk, adjust end_idx if decoding fails + while !decoded { + let token_chunk = &token_ids[start_idx..end_idx]; + // Attempt to decode the current chunk + match tokenizer.decode(token_chunk.to_vec()) { + Ok(chunk_text) => { + chunks.push(chunk_text); + start_idx = end_idx; + decoded = true; + }, + Err(_) => { + // If we can extend the chunk, do so + if end_idx < total_tokens { + end_idx += 1; + } else if start_idx + 1 < total_tokens { + // Skip the problematic token at start_idx + start_idx += 1; + end_idx = (start_idx + max_tokens).min(total_tokens); + } else { + // Cannot decode any further, break to avoid infinite loop + start_idx = total_tokens; + break; + } + }, + } + } + } + + Ok(chunks) +} + +#[inline] +pub fn split_text_by_max_content_len( + content: String, + max_content_len: usize, 
+) -> Result, AppError> { + if content.is_empty() { + return Ok(vec![]); + } + + if content.len() <= max_content_len { + return Ok(vec![content]); + } + + // Content is longer than max_content_len; need to split + let mut result = Vec::with_capacity(1 + content.len() / max_content_len); + let mut fragment = String::with_capacity(max_content_len); + let mut current_len = 0; + + for grapheme in content.graphemes(true) { + let grapheme_len = grapheme.len(); + if current_len + grapheme_len > max_content_len { + if !fragment.is_empty() { + result.push(std::mem::take(&mut fragment)); + } + current_len = 0; + + if grapheme_len > max_content_len { + // Push the grapheme as a fragment on its own + result.push(grapheme.to_string()); + continue; + } + } + fragment.push_str(grapheme); + current_len += grapheme_len; + } + + // Add the last fragment if it's not empty + if !fragment.is_empty() { + result.push(fragment); + } + Ok(result) +} + +#[cfg(test)] +mod tests { + + use crate::indexer::open_ai::{split_text_by_max_content_len, split_text_by_max_tokens}; + use tiktoken_rs::cl100k_base; + + #[test] + fn test_split_at_non_utf8() { + let max_tokens = 10; // Small number for testing + + // Content with multibyte characters (emojis) + let content = "Hello πŸ˜ƒ World 🌍! 
This is a test πŸš€.".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + for content in params { + assert!(content.is_char_boundary(0)); + assert!(content.is_char_boundary(content.len())); + } + + let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + for content in params { + assert!(content.is_char_boundary(0)); + assert!(content.is_char_boundary(content.len())); + } + } + #[test] + fn test_exact_boundary_split() { + let max_tokens = 5; // Set to 5 tokens for testing + let content = "The quick brown fox jumps over the lazy dog".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); + } + + #[test] + fn test_content_shorter_than_max_len() { + let max_tokens = 100; + let content = "Short content".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + + assert_eq!(params.len(), 1); + assert_eq!(params[0], content); + } + + #[test] + fn test_empty_content() { + let max_tokens = 10; + let content = "".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + assert_eq!(params.len(), 0); + + let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + assert_eq!(params.len(), 0); + } + + #[test] + fn test_content_with_only_multibyte_characters() { + let max_tokens = 1; // Set to 1 token for testing + let content = "πŸ˜€πŸ˜ƒπŸ˜„πŸ˜πŸ˜†".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, 
&tokenizer).unwrap(); + + let emojis: Vec = content.chars().map(|c| c.to_string()).collect(); + for (param, emoji) in params.iter().zip(emojis.iter()) { + assert_eq!(param, emoji); + } + + let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + for (param, emoji) in params.iter().zip(emojis.iter()) { + assert_eq!(param, emoji); + } + } + + #[test] + fn test_split_with_combining_characters() { + let max_tokens = 1; // Set to 1 token for testing + let content = "a\u{0301}e\u{0301}i\u{0301}o\u{0301}u\u{0301}".to_string(); // "áéíóú" + + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + let total_tokens = tokenizer.encode_ordinary(&content).len(); + assert_eq!(params.len(), total_tokens); + let reconstructed_content = params.join(""); + assert_eq!(reconstructed_content, content); + + let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + } + + #[test] + fn test_large_content() { + let max_tokens = 1000; + let content = "a".repeat(5000); // 5000 characters + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); + } + + #[test] + fn test_non_ascii_characters() { + let max_tokens = 2; + let content = "Ñéíóú".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); + let reconstructed_content: String = 
params.concat(); + assert_eq!(reconstructed_content, content); + + let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + } + + #[test] + fn test_content_with_leading_and_trailing_whitespace() { + let max_tokens = 3; + let content = " abcde ".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + + let total_tokens = tokenizer.encode_ordinary(&content).len(); + let expected_fragments = (total_tokens + max_tokens - 1) / max_tokens; + assert_eq!(params.len(), expected_fragments); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + + let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + } + + #[test] + fn test_content_with_multiple_zero_width_joiners() { + let max_tokens = 1; + let content = "πŸ‘©β€πŸ‘©β€πŸ‘§β€πŸ‘§πŸ‘¨β€πŸ‘¨β€πŸ‘¦β€πŸ‘¦".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + + let params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + } + + #[test] + fn test_content_with_long_combining_sequences() { + let max_tokens = 1; + let content = "a\u{0300}\u{0301}\u{0302}\u{0303}\u{0304}".to_string(); + let tokenizer = cl100k_base().unwrap(); + let params = split_text_by_max_tokens(content.clone(), max_tokens, &tokenizer).unwrap(); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + + let 
params = split_text_by_max_content_len(content.clone(), max_tokens).unwrap(); + let reconstructed_content: String = params.concat(); + assert_eq!(reconstructed_content, content); + } +} + +// #[cfg(test)] +// mod execution_time_comparison_tests { +// use crate::indexer::document_indexer::split_text_by_max_tokens; +// use rand::distributions::Alphanumeric; +// use rand::{thread_rng, Rng}; +// use std::sync::Arc; +// use std::time::Instant; +// use tiktoken_rs::{cl100k_base, CoreBPE}; +// +// #[tokio::test] +// async fn test_execution_time_comparison() { +// let tokenizer = Arc::new(cl100k_base().unwrap()); +// let max_tokens = 100; +// +// let sizes = vec![500, 1000, 2000, 5000, 20000]; // Content sizes to test +// for size in sizes { +// let content = generate_random_string(size); +// +// // Measure direct execution time +// let direct_time = measure_direct_execution(content.clone(), max_tokens, &tokenizer); +// +// // Measure spawn_blocking execution time +// let spawn_blocking_time = +// measure_spawn_blocking_execution(content, max_tokens, Arc::clone(&tokenizer)).await; +// +// println!( +// "Content Size: {} | Direct Time: {}ms | spawn_blocking Time: {}ms", +// size, direct_time, spawn_blocking_time +// ); +// } +// } +// +// // Measure direct execution time +// fn measure_direct_execution(content: String, max_tokens: usize, tokenizer: &CoreBPE) -> u128 { +// let start = Instant::now(); +// split_text_by_max_tokens(content, max_tokens, tokenizer).unwrap(); +// start.elapsed().as_millis() +// } +// +// // Measure `spawn_blocking` execution time +// async fn measure_spawn_blocking_execution( +// content: String, +// max_tokens: usize, +// tokenizer: Arc, +// ) -> u128 { +// let start = Instant::now(); +// tokio::task::spawn_blocking(move || { +// split_text_by_max_tokens(content, max_tokens, tokenizer.as_ref()).unwrap() +// }) +// .await +// .unwrap(); +// start.elapsed().as_millis() +// } +// +// pub fn generate_random_string(len: usize) -> String { +// let rng 
= thread_rng(); +// rng +// .sample_iter(&Alphanumeric) +// .take(len) +// .map(char::from) +// .collect() +// } +// } diff --git a/services/appflowy-worker/src/import_worker/worker.rs b/services/appflowy-worker/src/import_worker/worker.rs index 6d07ebe2e..6be6767d1 100644 --- a/services/appflowy-worker/src/import_worker/worker.rs +++ b/services/appflowy-worker/src/import_worker/worker.rs @@ -245,7 +245,6 @@ async fn process_upcoming_tasks( } } } - info!("[Import] stop reading tasks from stream"); } #[derive(Clone)] struct TaskContext { diff --git a/tests/search/document_search.rs b/tests/search/document_search.rs index 84267d9f4..e696c4369 100644 --- a/tests/search/document_search.rs +++ b/tests/search/document_search.rs @@ -114,7 +114,7 @@ The Five Dysfunctions of a Team by Patrick Lencioni The Five Dysfunctions of a T let params = CalculateSimilarityParams { workspace_id, - input: answer, + input: answer.clone(), expected: r#" Kathryn Petersen is the newly appointed CEO of DecisionTech, a struggling Silicon Valley startup. 
She steps into a role facing a dysfunctional executive team characterized by poor communication, @@ -131,7 +131,7 @@ The Five Dysfunctions of a Team by Patrick Lencioni The Five Dysfunctions of a T .unwrap() .score; - assert!(score > 0.9, "score: {}", score); + assert!(score > 0.9, "score: {}, input:{}", score, answer); } } From 3799966f1235bceeb376defdcce036b82ea024a2 Mon Sep 17 00:00:00 2001 From: Bartosz Sypytkowski Date: Sun, 17 Nov 2024 06:25:42 +0100 Subject: [PATCH 10/20] chore: store pending collab writes in memory (#1000) * chore: write immediatelly actually writes immediatelly * chore: fix clippy errors * chore: add metrics to new storage queue impl * chore: set collab batch write capacity to the same as on main branch --- libs/client-websocket/src/error.rs | 2 +- libs/client-websocket/src/native.rs | 2 +- .../appflowy-collaborate/src/application.rs | 1 - .../appflowy-collaborate/src/collab/mod.rs | 4 - .../appflowy-collaborate/src/collab/queue.rs | 577 ------------------ .../src/collab/queue_redis_ops.rs | 418 ------------- .../src/collab/storage.rs | 175 +++++- services/appflowy-collaborate/src/metrics.rs | 30 +- src/application.rs | 1 - tests/collab/storage_test.rs | 173 +----- 10 files changed, 168 insertions(+), 1215 deletions(-) delete mode 100644 services/appflowy-collaborate/src/collab/queue.rs delete mode 100644 services/appflowy-collaborate/src/collab/queue_redis_ops.rs diff --git a/libs/client-websocket/src/error.rs b/libs/client-websocket/src/error.rs index f5b4cde62..60b2fa72c 100644 --- a/libs/client-websocket/src/error.rs +++ b/libs/client-websocket/src/error.rs @@ -59,7 +59,7 @@ pub enum Error { #[error("URL error: {0}")] Url(#[from] UrlError), #[error("HTTP error: {}", .0.status())] - Http(Response>>), + Http(Box>>>), #[error("HTTP format error: {0}")] HttpFormat(#[from] http::Error), #[error("Parsing blobs is unsupported")] diff --git a/libs/client-websocket/src/native.rs b/libs/client-websocket/src/native.rs index 2b0b0fc84..920f63884 
100644 --- a/libs/client-websocket/src/native.rs +++ b/libs/client-websocket/src/native.rs @@ -152,7 +152,7 @@ impl From for crate::Error { Error::Utf8 => crate::Error::Utf8, Error::AttackAttempt => crate::Error::AttackAttempt, Error::Url(inner) => crate::Error::Url(inner.into()), - Error::Http(inner) => crate::Error::Http(inner), + Error::Http(inner) => crate::Error::Http(inner.into()), Error::HttpFormat(inner) => crate::Error::HttpFormat(inner), } } diff --git a/services/appflowy-collaborate/src/application.rs b/services/appflowy-collaborate/src/application.rs index b0763df37..645162730 100644 --- a/services/appflowy-collaborate/src/application.rs +++ b/services/appflowy-collaborate/src/application.rs @@ -133,7 +133,6 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result; -#[derive(Clone)] -pub struct StorageQueue { - collab_cache: CollabCache, - connection_manager: RedisConnectionManager, - pending_write_set: PendingWriteSet, - pending_id_counter: Arc, - total_queue_collab_count: Arc, - success_queue_collab_count: Arc, -} - -pub const REDIS_PENDING_WRITE_QUEUE: &str = "collab_pending_write_queue_v0"; - -impl StorageQueue { - pub fn new( - collab_cache: CollabCache, - connection_manager: RedisConnectionManager, - queue_name: &str, - ) -> Self { - Self::new_with_metrics(collab_cache, connection_manager, queue_name, None) - } - - pub fn new_with_metrics( - collab_cache: CollabCache, - connection_manager: RedisConnectionManager, - queue_name: &str, - metrics: Option>, - ) -> Self { - let next_duration = Arc::new(Mutex::from(Duration::from_secs(1))); - let pending_id_counter = Arc::new(AtomicI64::new(0)); - let pending_write_set = Arc::new(RedisSortedSet::new(connection_manager.clone(), queue_name)); - - let total_queue_collab_count = Arc::new(AtomicI64::new(0)); - let success_queue_collab_count = Arc::new(AtomicI64::new(0)); - - // Spawns a task that periodically writes pending collaboration objects to the database. 
- spawn_period_write( - next_duration.clone(), - collab_cache.clone(), - connection_manager.clone(), - pending_write_set.clone(), - metrics.clone(), - total_queue_collab_count.clone(), - success_queue_collab_count.clone(), - ); - - spawn_period_check_pg_conn_count(collab_cache.pg_pool().clone(), next_duration); - - Self { - collab_cache, - connection_manager, - pending_write_set, - pending_id_counter, - total_queue_collab_count, - success_queue_collab_count, - } - } - - /// Enqueues a object for deferred processing. High priority writes are processed before low priority writes. - /// - /// adds a write task to a pending queue, which is periodically flushed by another task that batches - /// and writes the queued collaboration objects to a PostgreSQL database. - /// - /// This data is stored temporarily in the `collab_cache` and is intended for later persistent storage - /// in the database. It can also be retrieved during subsequent calls in the [CollabStorageImpl::get_encode_collab] - /// to enhance performance and reduce database reads. 
- /// - #[instrument(level = "trace", skip_all)] - pub async fn push( - &self, - workspace_id: &str, - uid: &i64, - params: &CollabParams, - priority: WritePriority, - ) -> Result<(), AppError> { - trace!("queuing {} object to pending write queue", params.object_id,); - self - .collab_cache - .insert_encode_collab_to_mem(params) - .await?; - - let seq = self - .pending_id_counter - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); - - let pending_write = PendingWrite { - object_id: params.object_id.clone(), - seq, - data_len: params.encoded_collab_v1.len(), - priority, - }; - - let pending_meta = PendingWriteMeta { - uid: *uid, - workspace_id: workspace_id.to_string(), - object_id: params.object_id.clone(), - collab_type: params.collab_type.clone(), - embeddings: params.embeddings.clone(), - }; - - self - .total_queue_collab_count - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - // If the queueing fails, write the data to the database immediately - if let Err(err) = self - .queue_pending(params, pending_write, pending_meta) - .await - { - error!( - "Failed to queue pending write for object {}: {:?}", - params.object_id, err - ); - - let mut transaction = self - .collab_cache - .pg_pool() - .begin() - .await - .context("acquire transaction to upsert collab") - .map_err(AppError::from)?; - self - .collab_cache - .insert_encode_collab_data(workspace_id, uid, params, &mut transaction) - .await?; - transaction - .commit() - .await - .context("fail to commit the transaction to upsert collab") - .map_err(AppError::from)?; - } else { - self - .success_queue_collab_count - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - trace!( - "did queue {}:{} object for deferred writing to disk", - params.object_id, - seq - ); - } - - Ok(()) - } - - #[cfg(debug_assertions)] - pub async fn clear(&self) -> Result<(), AppError> { - self.pending_write_set.clear().await?; - crate::collab::queue_redis_ops::remove_all_pending_meta(self.connection_manager.clone()) - .await?; - 
Ok(()) - } - - #[inline] - async fn queue_pending( - &self, - params: &CollabParams, - pending_write: PendingWrite, - pending_write_meta: PendingWriteMeta, - ) -> Result<(), anyhow::Error> { - trace!( - "queue pending write: {}:{}", - pending_write_meta.object_id, - pending_write_meta.collab_type - ); - const MAX_RETRIES: usize = 3; - const BASE_DELAY_MS: u64 = 200; - const BACKOFF_FACTOR: u64 = 2; - - // these serialization seems very fast, so we don't need to worry about the performance and no - // need to use spawn_blocking or block_in_place - let pending_write_data = serde_json::to_vec(&pending_write)?; - let pending_write_meta_data = serde_json::to_vec(&pending_write_meta)?; - - let key = storage_cache_key(&params.object_id, params.encoded_collab_v1.len()); - let mut conn = self.connection_manager.clone(); - for attempt in 0..MAX_RETRIES { - let mut pipe = redis::pipe(); - // Prepare the pipeline with both commands - // 1. ZADD to add the pending write to the queue - // 2. SETEX to add the pending metadata to the cache - pipe - // .atomic() - .cmd("ZADD") - .arg(self.pending_write_set.queue_name()) - .arg(pending_write.score()) - .arg(&pending_write_data) - .ignore() - .cmd("SETEX") - .arg(&key) - .arg(PENDING_WRITE_META_EXPIRE_SECS) - .arg(&pending_write_meta_data) - .ignore(); - - match pipe.query_async::<_, ()>(&mut conn).await { - Ok(_) => return Ok(()), - Err(e) => { - if attempt == MAX_RETRIES - 1 { - return Err(e.into()); - } - - // 200ms, 400ms, 800ms - let delay = BASE_DELAY_MS * BACKOFF_FACTOR.pow(attempt as u32); - sleep(Duration::from_millis(delay)).await; - }, - } - } - Err(anyhow!("Failed to execute redis pipeline after retries")) - } -} - -/// Spawn a task that periodically checks the number of active connections in the PostgreSQL pool -/// It aims to adjust the write interval based on the number of active connections.
-fn spawn_period_check_pg_conn_count(pg_pool: PgPool, next_duration: Arc>) { - let mut interval = interval(tokio::time::Duration::from_secs(10)); - tokio::spawn(async move { - loop { - interval.tick().await; - // these range values are arbitrary and can be adjusted as needed - match pg_pool.size() { - 0..=40 => { - *next_duration.lock().await = Duration::from_secs(1); - }, - _ => { - *next_duration.lock().await = Duration::from_secs(5); - }, - } - } - }); -} - -fn spawn_period_write( - next_duration: Arc>, - collab_cache: CollabCache, - connection_manager: RedisConnectionManager, - pending_write_set: PendingWriteSet, - metrics: Option>, - total_queue_collab_count: Arc, - success_queue_collab_count: Arc, -) { - let total_write_count = Arc::new(AtomicI64::new(0)); - let success_write_count = Arc::new(AtomicI64::new(0)); - tokio::spawn(async move { - loop { - // The next_duration will be changed by spawn_period_check_pg_conn_count. When the number of - // active connections is high, the interval will be longer. 
- let instant = Instant::now() + *next_duration.lock().await; - sleep_until(instant).await; - - if let Some(metrics) = metrics.as_ref() { - metrics.record_write_collab( - success_write_count.load(std::sync::atomic::Ordering::Relaxed), - total_write_count.load(std::sync::atomic::Ordering::Relaxed), - ); - - metrics.record_queue_collab( - success_queue_collab_count.load(std::sync::atomic::Ordering::Relaxed), - total_queue_collab_count.load(std::sync::atomic::Ordering::Relaxed), - ); - } - - let chunk_keys = consume_pending_write(&pending_write_set, 20, 5).await; - if chunk_keys.is_empty() { - continue; - } - - for keys in chunk_keys { - trace!( - "start writing {} pending collaboration data to disk", - keys.len() - ); - let cloned_collab_cache = collab_cache.clone(); - let mut cloned_connection_manager = connection_manager.clone(); - let cloned_total_write_count = total_write_count.clone(); - let cloned_total_success_write_count = success_write_count.clone(); - - if let Ok(metas) = get_pending_meta(&keys, &mut cloned_connection_manager).await { - if metas.is_empty() { - error!("the pending write keys is not empty, but metas is empty"); - return; - } - - match retry_write_pending_to_disk(&cloned_collab_cache, metas).await { - Ok(success_result) => { - #[cfg(debug_assertions)] - tracing::info!("success write pending: {:?}", keys,); - - trace!("{:?}", success_result); - cloned_total_write_count.fetch_add( - success_result.expected as i64, - std::sync::atomic::Ordering::Relaxed, - ); - cloned_total_success_write_count.fetch_add( - success_result.success as i64, - std::sync::atomic::Ordering::Relaxed, - ); - }, - Err(err) => error!("{:?}", err), - } - // Remove pending metadata from Redis even if some records fail to write to disk after retries. - // Records that fail repeatedly are considered potentially corrupt or invalid. 
- let _ = remove_pending_meta(&keys, &mut cloned_connection_manager).await; - } - } - } - }); -} - -async fn retry_write_pending_to_disk( - collab_cache: &CollabCache, - mut metas: Vec, -) -> Result { - const RETRY_DELAYS: [Duration; 2] = [Duration::from_secs(1), Duration::from_secs(2)]; - - let expected = metas.len(); - let mut successes = Vec::with_capacity(metas.len()); - - for &delay in RETRY_DELAYS.iter() { - match write_pending_to_disk(&metas, collab_cache).await { - Ok(success_write_objects) => { - if !success_write_objects.is_empty() { - successes.extend_from_slice(&success_write_objects); - metas.retain(|meta| !success_write_objects.contains(&meta.object_id)); - } - - // If there are no more metas to process, return the successes - if metas.is_empty() { - return Ok(WritePendingResult { - expected, - success: successes.len(), - fail: 0, - }); - } - }, - Err(err) => { - warn!( - "Error writing to disk: {:?}, retrying after {:?}", - err, delay - ); - }, - } - - // Only sleep if there are more attempts left - if !metas.is_empty() { - sleep(delay).await; - } - } - - if expected >= successes.len() { - Ok(WritePendingResult { - expected, - success: successes.len(), - fail: expected - successes.len(), - }) - } else { - Err(AppError::Internal(anyhow!( - "the len of expected is less than success" - ))) - } -} - -#[derive(Debug)] -struct WritePendingResult { - expected: usize, - success: usize, - #[allow(dead_code)] - fail: usize, -} - -async fn write_pending_to_disk( - pending_metas: &[PendingWriteMeta], - collab_cache: &CollabCache, -) -> Result, AppError> { - let mut success_write_objects = Vec::with_capacity(pending_metas.len()); - // Convert pending metadata into query parameters for batch fetching - let queries = pending_metas - .iter() - .map(QueryCollab::from) - .collect::>(); - - // Retrieve encoded collaboration data in batch - let results = collab_cache.batch_get_encode_collab(queries).await; - - // Create a mapping from object IDs to their corresponding 
metadata - let meta_map = pending_metas - .iter() - .map(|meta| (meta.object_id.clone(), meta)) - .collect::>(); - - // Prepare collaboration data for writing to the database - let records = results - .into_iter() - .filter_map(|(object_id, result)| { - if let QueryCollabResult::Success { encode_collab_v1 } = result { - meta_map.get(&object_id).map(|meta| PendingWriteData { - uid: meta.uid, - workspace_id: meta.workspace_id.clone(), - object_id: meta.object_id.clone(), - collab_type: meta.collab_type.clone(), - encode_collab_v1: encode_collab_v1.into(), - embeddings: meta.embeddings.clone(), - }) - } else { - None - } - }) - .collect::>(); - - // Start a database transaction - let mut transaction = collab_cache - .pg_pool() - .begin() - .await - .context("Failed to acquire transaction for writing pending collaboration data") - .map_err(AppError::from)?; - - // Insert each record into the database within the transaction context - let mut action_description = String::new(); - for (index, record) in records.into_iter().enumerate() { - let params = CollabParams { - object_id: record.object_id.clone(), - collab_type: record.collab_type, - encoded_collab_v1: record.encode_collab_v1, - embeddings: record.embeddings, - }; - action_description = format!("{}", params); - let savepoint_name = format!("sp_{}", index); - - // using savepoint to rollback the transaction if the insert fails - sqlx::query(&format!("SAVEPOINT {}", savepoint_name)) - .execute(transaction.deref_mut()) - .await?; - if let Err(_err) = collab_cache - .insert_encode_collab_to_disk(&record.workspace_id, &record.uid, params, &mut transaction) - .await - { - sqlx::query(&format!("ROLLBACK TO SAVEPOINT {}", savepoint_name)) - .execute(transaction.deref_mut()) - .await?; - } else { - success_write_objects.push(record.object_id); - } - } - - // Commit the transaction to finalize all writes - match tokio::time::timeout(Duration::from_secs(10), transaction.commit()).await { - Ok(result) => { - 
result.map_err(AppError::from)?; - Ok(success_write_objects) - }, - Err(_) => { - error!( - "Timeout waiting for committing the transaction for pending write:{}", - action_description - ); - Err(AppError::Internal(anyhow!( - "Timeout when committing the transaction for pending collaboration data" - ))) - }, - } -} - -const MAXIMUM_CHUNK_SIZE: usize = 5 * 1024 * 1024; -#[inline] -pub async fn consume_pending_write( - pending_write_set: &PendingWriteSet, - maximum_consume_item: usize, - num_of_item_each_chunk: usize, -) -> Vec> { - let mut chunks = Vec::new(); - let mut current_chunk = Vec::with_capacity(maximum_consume_item); - let mut current_chunk_data_size = 0; - - if let Ok(items) = pending_write_set.pop(maximum_consume_item).await { - #[cfg(debug_assertions)] - if !items.is_empty() { - trace!("Consuming {} pending write items", items.len()); - } - - for item in items { - let item_size = item.data_len; - // Check if adding this item would exceed the maximum chunk size or item limit - if current_chunk_data_size + item_size > MAXIMUM_CHUNK_SIZE - || current_chunk.len() >= num_of_item_each_chunk - { - if !current_chunk.is_empty() { - chunks.push(std::mem::take(&mut current_chunk)); - } - current_chunk_data_size = 0; - } - - // Add the item to the current batch and update the batch size - current_chunk.push(item); - current_chunk_data_size += item_size; - } - } - - if !current_chunk.is_empty() { - chunks.push(current_chunk); - } - // Convert each batch of items into a batch of keys - chunks - .into_iter() - .map(|batch| { - batch - .into_iter() - .map(|pending| storage_cache_key(&pending.object_id, pending.data_len)) - .collect() - }) - .collect() -} - -#[derive(Debug, PartialEq, Serialize, Deserialize)] -pub struct PendingWriteMeta { - pub uid: i64, - pub workspace_id: String, - pub object_id: String, - pub collab_type: CollabType, - #[serde(default)] - pub embeddings: Option, -} - -impl From<&PendingWriteMeta> for QueryCollab { - fn from(meta: &PendingWriteMeta) 
-> Self { - QueryCollab { - object_id: meta.object_id.clone(), - collab_type: meta.collab_type.clone(), - } - } -} - -#[derive(PartialEq, Debug)] -pub struct PendingWriteData { - pub uid: i64, - pub workspace_id: String, - pub object_id: String, - pub collab_type: CollabType, - pub encode_collab_v1: Bytes, - pub embeddings: Option, -} - -impl From for CollabParams { - fn from(data: PendingWriteData) -> Self { - CollabParams { - object_id: data.object_id, - collab_type: data.collab_type, - encoded_collab_v1: data.encode_collab_v1, - embeddings: data.embeddings, - } - } -} diff --git a/services/appflowy-collaborate/src/collab/queue_redis_ops.rs b/services/appflowy-collaborate/src/collab/queue_redis_ops.rs deleted file mode 100644 index 59c4c544e..000000000 --- a/services/appflowy-collaborate/src/collab/queue_redis_ops.rs +++ /dev/null @@ -1,418 +0,0 @@ -use crate::collab::queue::PendingWriteMeta; -use crate::state::RedisConnectionManager; -use app_error::AppError; -use futures_util::StreamExt; -use redis::{AsyncCommands, AsyncIter, Script}; -use serde::{Deserialize, Serialize}; -use serde_repr::{Deserialize_repr, Serialize_repr}; - -pub(crate) const PENDING_WRITE_META_EXPIRE_SECS: u64 = 604800; // 7 days in seconds - -#[allow(dead_code)] -pub(crate) async fn remove_all_pending_meta( - mut connection_manager: RedisConnectionManager, -) -> Result<(), AppError> { - let pattern = format!("{}*", QUEUE_COLLAB_PREFIX); - let iter: AsyncIter = connection_manager - .scan_match(pattern) - .await - .map_err(|err| AppError::Internal(err.into()))?; - let keys: Vec<_> = iter.collect().await; - - if keys.is_empty() { - return Ok(()); - } - connection_manager - .del(keys) - .await - .map_err(|err| AppError::Internal(err.into()))?; - Ok(()) -} - -#[inline] -pub(crate) async fn get_pending_meta( - keys: &[String], - connection_manager: &mut RedisConnectionManager, -) -> Result, AppError> { - let results: Vec>> = connection_manager - .get(keys) - .await - .map_err(|err| 
AppError::Internal(err.into()))?; - - let metas = results - .into_iter() - .filter_map(|value| value.and_then(|data| serde_json::from_slice(&data).ok())) - .collect::>(); - - Ok(metas) -} - -#[inline] -pub(crate) async fn remove_pending_meta( - keys: &[String], - connection_manager: &mut RedisConnectionManager, -) -> Result<(), AppError> { - connection_manager - .del(keys) - .await - .map_err(|err| AppError::Internal(err.into()))?; - Ok(()) -} - -pub(crate) const QUEUE_COLLAB_PREFIX: &str = "storage_pending_meta_v0:"; - -#[inline] -pub(crate) fn storage_cache_key(object_id: &str, data_len: usize) -> String { - format!("{}{}:{}", QUEUE_COLLAB_PREFIX, object_id, data_len) -} - -#[derive(Clone)] -pub struct RedisSortedSet { - conn: RedisConnectionManager, - name: String, -} - -impl RedisSortedSet { - pub fn new(conn: RedisConnectionManager, name: &str) -> Self { - Self { - conn, - name: name.to_string(), - } - } - - pub async fn push(&self, item: PendingWrite) -> Result<(), anyhow::Error> { - let data = serde_json::to_vec(&item)?; - redis::cmd("ZADD") - .arg(&self.name) - .arg(item.score()) - .arg(data) - .query_async(&mut self.conn.clone()) - .await?; - Ok(()) - } - - pub async fn push_with_conn( - &self, - item: PendingWrite, - conn: &mut RedisConnectionManager, - ) -> Result<(), anyhow::Error> { - let data = serde_json::to_vec(&item)?; - redis::cmd("ZADD") - .arg(&self.name) - .arg(item.score()) - .arg(data) - .query_async(conn) - .await?; - Ok(()) - } - - pub fn queue_name(&self) -> &str { - &self.name - } - - /// Pops items from a Redis sorted set. - /// - /// This asynchronous function retrieves and removes the top `len` items from a Redis sorted set specified by `self.name`. - /// It uses a Lua script to atomically perform the operation to maintain data integrity during concurrent access. - /// - /// # Parameters - /// - `len`: The number of items to pop from the sorted set. If `len` is 0, the function returns an empty vector. 
- /// - pub async fn pop(&self, len: usize) -> Result, anyhow::Error> { - if len == 0 { - return Ok(vec![]); - } - - let script = Script::new( - r#" - local items = redis.call('ZRANGE', KEYS[1], 0, ARGV[1], 'WITHSCORES') - if #items > 0 then - redis.call('ZREMRANGEBYRANK', KEYS[1], 0, #items / 2 - 1) - end - return items - "#, - ); - let mut conn = self.conn.clone(); - let items: Vec<(String, f64)> = script - .key(&self.name) - .arg(len - 1) - .invoke_async(&mut conn) - .await?; - - let results = items - .iter() - .map(|(data, _score)| serde_json::from_str::(data).map_err(|e| e.into())) - .collect::, anyhow::Error>>()?; - - Ok(results) - } - - pub async fn peek(&self, n: usize) -> Result, anyhow::Error> { - let mut conn = self.conn.clone(); - let items: Vec<(String, f64)> = redis::cmd("ZREVRANGE") - .arg(&self.name) - .arg(0) - .arg(n - 1) - .arg("WITHSCORES") - .query_async(&mut conn) - .await?; - - let results = items - .iter() - .map(|(data, _score)| serde_json::from_str::(data).map_err(|e| e.into())) - .collect::, anyhow::Error>>()?; - - Ok(results) - } - pub async fn remove_items>( - &self, - items_to_remove: Vec, - ) -> Result<(), anyhow::Error> { - let mut conn = self.conn.clone(); - let mut pipe = redis::pipe(); - for item in items_to_remove { - pipe.cmd("ZREM").arg(&self.name).arg(item.as_ref()).ignore(); - } - pipe.query_async::<_, ()>(&mut conn).await?; - Ok(()) - } - - pub async fn clear(&self) -> Result<(), anyhow::Error> { - let mut conn = self.conn.clone(); - conn.del(&self.name).await?; - Ok(()) - } -} - -#[derive(Clone, Serialize, Deserialize, Debug)] -pub struct PendingWrite { - pub object_id: String, - pub seq: i64, - pub data_len: usize, - pub priority: WritePriority, -} - -impl PendingWrite { - pub fn score(&self) -> i64 { - match self.priority { - WritePriority::High => 0, - WritePriority::Low => self.seq + 1, - } - } -} - -#[derive(Clone, Serialize_repr, Deserialize_repr, Debug)] -#[repr(u8)] -pub enum WritePriority { - High = 0, - Low = 1, 
-} - -impl Eq for PendingWrite {} -impl PartialEq for PendingWrite { - fn eq(&self, other: &Self) -> bool { - self.object_id == other.object_id - } -} - -impl Ord for PendingWrite { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - match (&self.priority, &other.priority) { - (WritePriority::High, WritePriority::Low) => std::cmp::Ordering::Greater, - (WritePriority::Low, WritePriority::High) => std::cmp::Ordering::Less, - _ => { - // Assuming lower seq is higher priority - other.seq.cmp(&self.seq) - }, - } - } -} - -impl PartialOrd for PendingWrite { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -#[cfg(test)] -mod tests { - use crate::collab::{PendingWrite, RedisSortedSet, WritePriority}; - use anyhow::Context; - use std::time::Duration; - - #[tokio::test] - async fn pending_write_sorted_set_test() { - let conn = redis_client().await.get_connection_manager().await.unwrap(); - let set_name = uuid::Uuid::new_v4().to_string(); - let sorted_set = RedisSortedSet::new(conn.clone(), &set_name); - - let pending_writes = vec![ - PendingWrite { - object_id: "o1".to_string(), - seq: 1, - data_len: 0, - priority: WritePriority::Low, - }, - PendingWrite { - object_id: "o2".to_string(), - seq: 2, - data_len: 0, - priority: WritePriority::Low, - }, - PendingWrite { - object_id: "o3".to_string(), - seq: 0, - data_len: 0, - priority: WritePriority::High, - }, - ]; - - for item in &pending_writes { - sorted_set.push(item.clone()).await.unwrap(); - } - - let pending_writes_from_sorted_set = sorted_set.pop(3).await.unwrap(); - assert_eq!(pending_writes_from_sorted_set[0].object_id, "o3"); - assert_eq!(pending_writes_from_sorted_set[1].object_id, "o1"); - assert_eq!(pending_writes_from_sorted_set[2].object_id, "o2"); - - let items = sorted_set.pop(2).await.unwrap(); - assert!(items.is_empty()); - } - - #[tokio::test] - async fn sorted_set_consume_partial_items_test() { - let conn = redis_client().await.get_connection_manager().await.unwrap(); - 
let set_name = uuid::Uuid::new_v4().to_string(); - let sorted_set_1 = RedisSortedSet::new(conn.clone(), &set_name); - - let pending_writes = vec![ - PendingWrite { - object_id: "o1".to_string(), - seq: 1, - data_len: 0, - priority: WritePriority::Low, - }, - PendingWrite { - object_id: "o1".to_string(), - seq: 1, - data_len: 0, - priority: WritePriority::Low, - }, - PendingWrite { - object_id: "o2".to_string(), - seq: 2, - data_len: 0, - priority: WritePriority::Low, - }, - PendingWrite { - object_id: "o3".to_string(), - seq: 0, - data_len: 0, - priority: WritePriority::High, - }, - ]; - - for item in &pending_writes { - sorted_set_1.push(item.clone()).await.unwrap(); - } - - let pending_writes_from_sorted_set = sorted_set_1.pop(1).await.unwrap(); - assert_eq!(pending_writes_from_sorted_set[0].object_id, "o3"); - - let sorted_set_2 = RedisSortedSet::new(conn.clone(), &set_name); - let pending_writes_from_sorted_set = sorted_set_2.pop(10).await.unwrap(); - assert_eq!(pending_writes_from_sorted_set.len(), 2); - assert_eq!(pending_writes_from_sorted_set[0].object_id, "o1"); - assert_eq!(pending_writes_from_sorted_set[1].object_id, "o2"); - - assert!(sorted_set_1.pop(10).await.unwrap().is_empty()); - assert!(sorted_set_2.pop(10).await.unwrap().is_empty()); - } - - #[tokio::test] - async fn large_num_set_test() { - let conn = redis_client().await.get_connection_manager().await.unwrap(); - let set_name = uuid::Uuid::new_v4().to_string(); - let sorted_set = RedisSortedSet::new(conn.clone(), &set_name); - assert!(sorted_set.pop(10).await.unwrap().is_empty()); - - for i in 0..100 { - let pending_write = PendingWrite { - object_id: format!("o{}", i), - seq: i, - data_len: 0, - priority: WritePriority::Low, - }; - sorted_set.push(pending_write).await.unwrap(); - } - - let set_1 = sorted_set.pop(20).await.unwrap(); - assert_eq!(set_1.len(), 20); - assert_eq!(set_1[19].object_id, "o19"); - - let set_2 = sorted_set.pop(30).await.unwrap(); - assert_eq!(set_2.len(), 30); - 
assert_eq!(set_2[0].object_id, "o20"); - assert_eq!(set_2[29].object_id, "o49"); - - let set_3 = sorted_set.pop(1).await.unwrap(); - assert_eq!(set_3.len(), 1); - assert_eq!(set_3[0].object_id, "o50"); - - let set_4 = sorted_set.pop(200).await.unwrap(); - assert_eq!(set_4.len(), 49); - } - - #[tokio::test] - async fn multi_threads_sorted_set_test() { - let conn = redis_client().await.get_connection_manager().await.unwrap(); - let set_name = uuid::Uuid::new_v4().to_string(); - let sorted_set = RedisSortedSet::new(conn.clone(), &set_name); - - let mut handles = vec![]; - for i in 0..100 { - let cloned_sorted_set = sorted_set.clone(); - let handle = tokio::spawn(async move { - let pending_write = PendingWrite { - object_id: format!("o{}", i), - seq: i, - data_len: 0, - priority: WritePriority::Low, - }; - tokio::time::sleep(Duration::from_millis(rand::random::() % 100)).await; - cloned_sorted_set - .push(pending_write) - .await - .expect("Failed to push data") - }); - handles.push(handle); - } - futures::future::join_all(handles).await; - - let mut handles = vec![]; - for _ in 0..10 { - let cloned_sorted_set = sorted_set.clone(); - let handle = tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(rand::random::() % 100)).await; - let items = cloned_sorted_set - .pop(10) - .await - .expect("Failed to pop items"); - assert_eq!(items.len(), 10, "Expected exactly 10 items to be popped"); - }); - handles.push(handle); - } - let results = futures::future::join_all(handles).await; - for result in results { - result.expect("A thread panicked or errored out"); - } - } - - async fn redis_client() -> redis::Client { - let redis_uri = "redis://localhost:6379"; - redis::Client::open(redis_uri) - .context("failed to connect to redis") - .unwrap() - } -} diff --git a/services/appflowy-collaborate/src/collab/storage.rs b/services/appflowy-collaborate/src/collab/storage.rs index 00f0701a6..3a0c6c55c 100644 --- a/services/appflowy-collaborate/src/collab/storage.rs +++ 
b/services/appflowy-collaborate/src/collab/storage.rs @@ -1,7 +1,6 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; +#![allow(unused_imports)] +use anyhow::{anyhow, Context}; use async_trait::async_trait; use collab::entity::EncodedCollab; use collab_entity::CollabType; @@ -10,10 +9,15 @@ use database::collab::cache::CollabCache; use itertools::{Either, Itertools}; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use sqlx::Transaction; - +use std::collections::HashMap; +use std::ops::DerefMut; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::time::timeout; use tracing::warn; use tracing::{error, instrument, trace}; +use uuid::Uuid; use validator::Validate; use crate::command::{CLCommandSender, CollaborationCommand}; @@ -28,15 +32,28 @@ use database_entity::dto::{ }; use crate::collab::access_control::CollabStorageAccessControlImpl; -use crate::collab::queue::{StorageQueue, REDIS_PENDING_WRITE_QUEUE}; -use crate::collab::queue_redis_ops::WritePriority; use crate::collab::validator::CollabValidator; use crate::metrics::CollabMetrics; use crate::snapshot::SnapshotControl; -use crate::state::RedisConnectionManager; pub type CollabAccessControlStorage = CollabStorageImpl; +struct PendingCollabWrite { + workspace_id: String, + uid: i64, + params: CollabParams, +} + +impl PendingCollabWrite { + fn new(workspace_id: String, uid: i64, params: CollabParams) -> Self { + PendingCollabWrite { + workspace_id, + uid, + params, + } + } +} + /// A wrapper around the actual storage implementation that provides access control and caching. 
#[derive(Clone)] pub struct CollabStorageImpl { @@ -45,7 +62,8 @@ pub struct CollabStorageImpl { access_control: AC, snapshot_control: SnapshotControl, rt_cmd_sender: CLCommandSender, - queue: Arc, + queue: Sender, + metrics: Arc, } impl CollabStorageImpl @@ -57,14 +75,13 @@ where access_control: AC, snapshot_control: SnapshotControl, rt_cmd_sender: CLCommandSender, - redis_conn_manager: RedisConnectionManager, metrics: Arc, ) -> Self { - let queue = Arc::new(StorageQueue::new_with_metrics( + let (queue, reader) = channel(1000); + tokio::spawn(Self::periodic_write_task( cache.clone(), - redis_conn_manager, - REDIS_PENDING_WRITE_QUEUE, - Some(metrics), + metrics.clone(), + reader, )); Self { cache, @@ -72,9 +89,117 @@ where snapshot_control, rt_cmd_sender, queue, + metrics, } } + const PENDING_WRITE_BUF_CAPACITY: usize = 20; + async fn periodic_write_task( + cache: CollabCache, + metrics: Arc, + mut reader: Receiver, + ) { + let mut buf = Vec::with_capacity(Self::PENDING_WRITE_BUF_CAPACITY); + loop { + let n = reader + .recv_many(&mut buf, Self::PENDING_WRITE_BUF_CAPACITY) + .await; + if n == 0 { + break; + } + let pending = buf.drain(..n); + if let Err(e) = Self::persist(&cache, &metrics, pending).await { + tracing::error!("failed to persist {} collabs: {}", n, e); + } + } + } + + async fn persist( + cache: &CollabCache, + metrics: &CollabMetrics, + records: impl ExactSizeIterator, + ) -> Result<(), AppError> { + // Start a database transaction + let mut transaction = cache + .pg_pool() + .begin() + .await + .context("Failed to acquire transaction for writing pending collaboration data") + .map_err(AppError::from)?; + + let total_records = records.len(); + let mut successful_writes = 0; + // Insert each record into the database within the transaction context + let mut action_description = String::new(); + for (index, record) in records.into_iter().enumerate() { + let params = record.params; + action_description = format!("{}", params); + let savepoint_name = 
format!("sp_{}", index); + + // using savepoint to rollback the transaction if the insert fails + sqlx::query(&format!("SAVEPOINT {}", savepoint_name)) + .execute(transaction.deref_mut()) + .await?; + if let Err(_err) = cache + .insert_encode_collab_to_disk(&record.workspace_id, &record.uid, params, &mut transaction) + .await + { + sqlx::query(&format!("ROLLBACK TO SAVEPOINT {}", savepoint_name)) + .execute(transaction.deref_mut()) + .await?; + } else { + successful_writes += 1; + } + } + + metrics.record_write_collab(successful_writes, total_records as _); + + // Commit the transaction to finalize all writes + match tokio::time::timeout(Duration::from_secs(10), transaction.commit()).await { + Ok(result) => { + result.map_err(AppError::from)?; + Ok(()) + }, + Err(_) => { + error!( + "Timeout waiting for committing the transaction for pending write:{}", + action_description + ); + Err(AppError::Internal(anyhow!( + "Timeout when committing the transaction for pending collaboration data" + ))) + }, + } + } + + async fn insert_collab( + &self, + workspace_id: &str, + uid: &i64, + params: CollabParams, + ) -> AppResult<()> { + // Start a database transaction + let mut transaction = self + .cache + .pg_pool() + .begin() + .await + .context("Failed to acquire transaction for writing pending collaboration data") + .map_err(AppError::from)?; + self + .cache + .insert_encode_collab_to_disk(workspace_id, uid, params, &mut transaction) + .await?; + tokio::time::timeout(Duration::from_secs(10), transaction.commit()) + .await + .map_err(|_| { + AppError::Internal(anyhow!( + "Timeout when committing the transaction for pending collaboration data" + )) + })??; + Ok(()) + } + async fn check_write_workspace_permission( &self, workspace_id: &str, @@ -178,7 +303,6 @@ where workspace_id: &str, uid: &i64, params: CollabParams, - priority: WritePriority, ) -> Result<(), AppError> { trace!( "Queue insert collab:{}:{}", @@ -192,11 +316,13 @@ where ))); } - self - .queue - 
.push(workspace_id, uid, &params, priority) - .await - .map_err(AppError::from) + let pending = PendingCollabWrite::new(workspace_id.into(), *uid, params); + if let Err(e) = self.queue.send(pending).await { + tracing::error!("Failed to queue insert collab doc state: {}", e); + } else { + self.metrics.record_queue_collab(1); + } + Ok(()) } async fn batch_insert_collabs( @@ -259,14 +385,11 @@ where .update_policy(uid, &params.object_id, AFAccessLevel::FullAccess) .await?; } - let priority = if write_immediately { - WritePriority::High + if write_immediately { + self.insert_collab(workspace_id, uid, params).await?; } else { - WritePriority::Low - }; - self - .queue_insert_collab(workspace_id, uid, params, priority) - .await?; + self.queue_insert_collab(workspace_id, uid, params).await?; + } Ok(()) } diff --git a/services/appflowy-collaborate/src/metrics.rs b/services/appflowy-collaborate/src/metrics.rs index 3a37c40c1..33b21cd33 100644 --- a/services/appflowy-collaborate/src/metrics.rs +++ b/services/appflowy-collaborate/src/metrics.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; -use std::time::Duration; - +use prometheus_client::metrics::counter::Counter; use prometheus_client::metrics::gauge::Gauge; use prometheus_client::metrics::histogram::Histogram; use prometheus_client::registry::Registry; +use std::sync::Arc; +use std::time::Duration; use tokio::time::interval; use database::collab::CollabStorage; @@ -146,10 +146,9 @@ where pub struct CollabMetrics { success_write_snapshot_count: Gauge, total_write_snapshot_count: Gauge, - success_write_collab_count: Gauge, - total_write_collab_count: Gauge, - total_queue_collab_count: Gauge, - success_queue_collab_count: Gauge, + success_write_collab_count: Counter, + total_write_collab_count: Counter, + success_queue_collab_count: Counter, } impl CollabMetrics { @@ -159,7 +158,6 @@ impl CollabMetrics { total_write_snapshot_count: Default::default(), success_write_collab_count: Default::default(), total_write_collab_count:
Default::default(), - total_queue_collab_count: Default::default(), success_queue_collab_count: Default::default(), } } @@ -192,11 +190,6 @@ impl CollabMetrics { "success queue collab", metrics.success_queue_collab_count.clone(), ); - realtime_registry.register( - "total_queue_collab_count", - "total queue pending collab", - metrics.total_queue_collab_count.clone(), - ); metrics } @@ -206,13 +199,12 @@ impl CollabMetrics { self.total_write_snapshot_count.set(total_attempt); } - pub fn record_write_collab(&self, success_attempt: i64, total_attempt: i64) { - self.success_write_collab_count.set(success_attempt); - self.total_write_collab_count.set(total_attempt); + pub fn record_write_collab(&self, success_attempt: u64, total_attempt: u64) { + self.success_write_collab_count.inc_by(success_attempt); + self.total_write_collab_count.inc_by(total_attempt); } - pub fn record_queue_collab(&self, success_attempt: i64, total_attempt: i64) { - self.success_queue_collab_count.set(success_attempt); - self.total_queue_collab_count.set(total_attempt); + pub fn record_queue_collab(&self, attempt: u64) { + self.success_queue_collab_count.inc_by(attempt); } } diff --git a/src/application.rs b/src/application.rs index 311f2c82a..f91947c79 100644 --- a/src/application.rs +++ b/src/application.rs @@ -299,7 +299,6 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result Date: Sun, 17 Nov 2024 14:26:59 +0800 Subject: [PATCH 11/20] chore: adjust redis env (#1003) --- .github/workflows/integration_test.yml | 1 + deploy.env | 1 + dev.env | 1 + docker-compose-ci.yml | 2 +- docker-compose-dev.yml | 1 + docker-compose.yml | 1 + .../src/indexer/document_indexer.rs | 12 ++++++++---- 7 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 5b9e028fc..079111e96 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -24,6 +24,7 @@ env: 
LOCALHOST_URL: http://localhost LOCALHOST_WS: ws://localhost/ws/v1 APPFLOWY_REDIS_URI: redis://redis:6379 + APPFLOWY_AI_REDIS_URL: redis://redis:6379 LOCALHOST_GOTRUE: http://localhost/gotrue POSTGRES_PASSWORD: password DATABASE_URL: postgres://postgres:password@localhost:5432/postgres diff --git a/deploy.env b/deploy.env index c1799af87..afae1ebfa 100644 --- a/deploy.env +++ b/deploy.env @@ -146,6 +146,7 @@ APPFLOWY_AI_OPENAI_API_KEY= APPFLOWY_AI_SERVER_PORT=5001 APPFLOWY_AI_SERVER_HOST=ai APPFLOWY_AI_DATABASE_URL=postgresql+psycopg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} +APPFLOWY_AI_REDIS_URL=redis://${REDIS_HOST}:${REDIS_PORT} APPFLOWY_LOCAL_AI_TEST_ENABLED=false # AppFlowy Indexer diff --git a/dev.env b/dev.env index fbe59882a..57c52fcf2 100644 --- a/dev.env +++ b/dev.env @@ -113,6 +113,7 @@ APPFLOWY_AI_OPENAI_API_KEY= APPFLOWY_AI_SERVER_PORT=5001 APPFLOWY_AI_SERVER_HOST=localhost APPFLOWY_AI_DATABASE_URL=postgresql+psycopg://postgres:password@postgres:5432/postgres +APPFLOWY_AI_REDIS_URL=redis://redis:6379 APPFLOWY_LOCAL_AI_TEST_ENABLED=false # AppFlowy Indexer diff --git a/docker-compose-ci.yml b/docker-compose-ci.yml index 36356a98e..b56851b98 100644 --- a/docker-compose-ci.yml +++ b/docker-compose-ci.yml @@ -148,7 +148,7 @@ services: - LOCAL_AI_AWS_SECRET_ACCESS_KEY=${LOCAL_AI_AWS_SECRET_ACCESS_KEY} - APPFLOWY_AI_SERVER_PORT=${APPFLOWY_AI_SERVER_PORT} - APPFLOWY_AI_DATABASE_URL=${APPFLOWY_AI_DATABASE_URL} - - APPFLOWY_AI_REDIS_URL=${APPFLOWY_REDIS_URI} + - APPFLOWY_AI_REDIS_URL=${APPFLOWY_AI_REDIS_URL} appflowy_worker: restart: on-failure diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 33dd18f41..1f314aae4 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -122,6 +122,7 @@ services: - OPENAI_API_KEY=${APPFLOWY_AI_OPENAI_API_KEY} - APPFLOWY_AI_SERVER_PORT=${APPFLOWY_AI_SERVER_PORT} - APPFLOWY_AI_DATABASE_URL=${APPFLOWY_AI_DATABASE_URL} + - 
APPFLOWY_AI_REDIS_URL=${APPFLOWY_AI_REDIS_URL} volumes: postgres_data: diff --git a/docker-compose.yml b/docker-compose.yml index 54ac4faff..b39056cd8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -140,6 +140,7 @@ services: - OPENAI_API_KEY=${APPFLOWY_AI_OPENAI_API_KEY} - APPFLOWY_AI_SERVER_PORT=${APPFLOWY_AI_SERVER_PORT} - APPFLOWY_AI_DATABASE_URL=${APPFLOWY_AI_DATABASE_URL} + - APPFLOWY_AI_REDIS_URL=${APPFLOWY_AI_REDIS_URL} appflowy_worker: restart: on-failure diff --git a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/services/appflowy-collaborate/src/indexer/document_indexer.rs index d722598f6..d4d5e392b 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ b/services/appflowy-collaborate/src/indexer/document_indexer.rs @@ -25,15 +25,21 @@ pub struct DocumentIndexer { ai_client: AppFlowyAIClient, tokenizer: Arc, embedding_model: EmbeddingModel, + use_tiktoken: bool, } impl DocumentIndexer { pub fn new(ai_client: AppFlowyAIClient) -> Arc { let tokenizer = tiktoken_rs::cl100k_base().unwrap(); + let use_tiktoken = get_env_var("APPFLOWY_AI_CONTENT_SPLITTER_TIKTOKEN", "false") + .parse::() + .unwrap_or(false); + Arc::new(Self { ai_client, tokenizer: Arc::new(tokenizer), embedding_model: EmbeddingModel::TextEmbedding3Small, + use_tiktoken, }) } } @@ -62,6 +68,7 @@ impl Indexer for DocumentIndexer { CollabType::Document, &self.embedding_model, self.tokenizer.clone(), + self.use_tiktoken, ) .await }, @@ -136,11 +143,8 @@ async fn create_embedding( collab_type: CollabType, embedding_model: &EmbeddingModel, tokenizer: Arc, + use_tiktoken: bool, ) -> Result, AppError> { - let use_tiktoken = get_env_var("APPFLOWY_AI_CONTENT_SPLITTER_TIKTOKEN", "false") - .parse::() - .unwrap_or(false); - let split_contents = if use_tiktoken { let max_tokens = embedding_model.default_dimensions() as usize; if content.len() < 500 { From 51bd650644766c52fdc447322ff8b33b07cdaf1a Mon Sep 17 00:00:00 2001 From: Khor Shu Heng 
<32997938+khorshuheng@users.noreply.github.com> Date: Mon, 18 Nov 2024 09:42:57 +0800 Subject: [PATCH 12/20] chore: fix flaky test for get section items (#1004) --- tests/workspace/workspace_folder.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/workspace/workspace_folder.rs b/tests/workspace/workspace_folder.rs index e5478b680..14831749f 100644 --- a/tests/workspace/workspace_folder.rs +++ b/tests/workspace/workspace_folder.rs @@ -1,7 +1,10 @@ +use std::time::Duration; + use client_api::entity::{CreateCollabParams, QueryCollabParams}; use client_api_test::generate_unique_registered_user_client; use collab::core::origin::CollabClient; use collab_folder::{CollabOrigin, Folder}; +use tokio::time::sleep; #[tokio::test] async fn get_workpace_folder() { @@ -83,6 +86,8 @@ async fn get_section_items() { }) .await .unwrap(); + // Collab update is performed asynchronously via a queue + sleep(Duration::from_secs(1)).await; let favorite_section_items = c.get_workspace_favorite(&workspace_id).await.unwrap(); assert_eq!(favorite_section_items.views.len(), 1); assert_eq!( From 2647d41f3a9f3ef53861788f8aef5bb9cc5486ce Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:46:23 +0800 Subject: [PATCH 13/20] feat: allow create page API to accept a view name (#1005) --- libs/shared-entity/src/dto/workspace_dto.rs | 1 + src/api/workspace.rs | 1 + src/biz/workspace/page_view.rs | 17 ++++- tests/workspace/page_view.rs | 4 +- tests/workspace/workspace_folder.rs | 73 --------------------- 5 files changed, 20 insertions(+), 76 deletions(-) diff --git a/libs/shared-entity/src/dto/workspace_dto.rs b/libs/shared-entity/src/dto/workspace_dto.rs index af56fc2e2..67d46ffed 100644 --- a/libs/shared-entity/src/dto/workspace_dto.rs +++ b/libs/shared-entity/src/dto/workspace_dto.rs @@ -132,6 +132,7 @@ pub struct Page { pub struct CreatePageParams { pub parent_view_id: String, pub layout: ViewLayout, + pub name: Option, 
} #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/api/workspace.rs b/src/api/workspace.rs index be368858b..850603458 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -916,6 +916,7 @@ async fn post_page_view_handler( workspace_uuid, &payload.parent_view_id, &payload.layout, + payload.name.as_deref(), ) .await?; Ok(Json(AppResponse::Ok().with_data(page))) diff --git a/src/biz/workspace/page_view.rs b/src/biz/workspace/page_view.rs index b6e1fd9ae..5c837174c 100644 --- a/src/biz/workspace/page_view.rs +++ b/src/biz/workspace/page_view.rs @@ -49,13 +49,22 @@ pub async fn create_page( workspace_id: Uuid, parent_view_id: &str, view_layout: &ViewLayout, + name: Option<&str>, ) -> Result { if *view_layout != ViewLayout::Document { return Err(AppError::InvalidRequest( "Only document layout is supported for page creation".to_string(), )); } - create_document_page(pg_pool, collab_storage, uid, workspace_id, parent_view_id).await + create_document_page( + pg_pool, + collab_storage, + uid, + workspace_id, + parent_view_id, + name, + ) + .await } fn prepare_default_document_collab_param() -> Result { @@ -80,10 +89,12 @@ async fn add_new_view_to_folder( parent_view_id: &str, view_id: &str, folder: &mut Folder, + name: Option<&str>, ) -> Result { let encoded_update = { let view = NestedChildViewBuilder::new(uid, parent_view_id.to_string()) .with_view_id(view_id) + .with_name(name.unwrap_or_default()) .build() .view; let mut txn = folder.collab.transact_mut(); @@ -234,13 +245,15 @@ async fn create_document_page( uid: i64, workspace_id: Uuid, parent_view_id: &str, + name: Option<&str>, ) -> Result { let default_document_collab_params = prepare_default_document_collab_param()?; let view_id = default_document_collab_params.object_id.clone(); let collab_origin = GetCollabOrigin::User { uid }; let mut folder = get_latest_collab_folder(collab_storage, collab_origin, &workspace_id.to_string()).await?; - let folder_update = add_new_view_to_folder(uid, 
parent_view_id, &view_id, &mut folder).await?; + let folder_update = + add_new_view_to_folder(uid, parent_view_id, &view_id, &mut folder, name).await?; let mut transaction = pg_pool.begin().await?; let action = format!("Create new collab: {}", view_id); collab_storage diff --git a/tests/workspace/page_view.rs b/tests/workspace/page_view.rs index d4e922395..461c83d61 100644 --- a/tests/workspace/page_view.rs +++ b/tests/workspace/page_view.rs @@ -100,6 +100,7 @@ async fn create_new_document_page() { &CreatePageParams { parent_view_id: general_space.view_id.clone(), layout: ViewLayout::Document, + name: Some("New document".to_string()), }, ) .await @@ -114,11 +115,12 @@ async fn create_new_document_page() { .into_iter() .find(|v| v.name == "General") .unwrap(); - general_space + let view = general_space .children .iter() .find(|v| v.view_id == page.view_id) .unwrap(); + assert_eq!(view.name, "New document"); c.get_collab(QueryCollabParams { workspace_id: workspace_id.to_string(), inner: QueryCollab { diff --git a/tests/workspace/workspace_folder.rs b/tests/workspace/workspace_folder.rs index 14831749f..8b3738c13 100644 --- a/tests/workspace/workspace_folder.rs +++ b/tests/workspace/workspace_folder.rs @@ -1,10 +1,4 @@ -use std::time::Duration; - -use client_api::entity::{CreateCollabParams, QueryCollabParams}; use client_api_test::generate_unique_registered_user_client; -use collab::core::origin::CollabClient; -use collab_folder::{CollabOrigin, Folder}; -use tokio::time::sleep; #[tokio::test] async fn get_workpace_folder() { @@ -37,70 +31,3 @@ async fn get_workpace_folder() { .unwrap(); assert_eq!(folder_view.children.len(), 2); } - -#[tokio::test] -async fn get_section_items() { - let (c, _user) = generate_unique_registered_user_client().await; - let user_workspace_info = c.get_user_workspace_info().await.unwrap(); - let workspaces = c.get_workspaces().await.unwrap(); - assert_eq!(workspaces.len(), 1); - let workspace_id = workspaces[0].workspace_id.to_string(); - 
let folder_collab = c - .get_collab(QueryCollabParams::new( - workspace_id.clone(), - collab_entity::CollabType::Folder, - workspace_id.clone(), - )) - .await - .unwrap() - .encode_collab; - let uid = user_workspace_info.user_profile.uid; - let mut folder = Folder::from_collab_doc_state( - uid, - CollabOrigin::Client(CollabClient::new(uid, c.device_id.clone())), - folder_collab.into(), - &workspace_id, - vec![], - ) - .unwrap(); - let views = folder.get_views_belong_to(&workspace_id); - let new_favorite_id = views[0].children[0].id.clone(); - let to_be_deleted_favorite_id = views[0].children[1].id.clone(); - folder.add_favorite_view_ids(vec![ - new_favorite_id.clone(), - to_be_deleted_favorite_id.clone(), - ]); - folder.add_trash_view_ids(vec![to_be_deleted_favorite_id.clone()]); - let recent_id = folder.get_views_belong_to(&new_favorite_id)[0].id.clone(); - folder.add_recent_view_ids(vec![recent_id.clone()]); - let collab_type = collab_entity::CollabType::Folder; - c.update_collab(CreateCollabParams { - workspace_id: workspace_id.clone(), - collab_type: collab_type.clone(), - object_id: workspace_id.clone(), - encoded_collab_v1: folder - .encode_collab_v1(|collab| collab_type.validate_require_data(collab)) - .unwrap() - .encode_to_bytes() - .unwrap(), - }) - .await - .unwrap(); - // Collab update is performed asynchronously via a queue - sleep(Duration::from_secs(1)).await; - let favorite_section_items = c.get_workspace_favorite(&workspace_id).await.unwrap(); - assert_eq!(favorite_section_items.views.len(), 1); - assert_eq!( - favorite_section_items.views[0].view.view_id, - new_favorite_id - ); - let trash_section_items = c.get_workspace_trash(&workspace_id).await.unwrap(); - assert_eq!(trash_section_items.views.len(), 1); - assert_eq!( - trash_section_items.views[0].view.view_id, - to_be_deleted_favorite_id - ); - let recent_section_items = c.get_workspace_recent(&workspace_id).await.unwrap(); - assert_eq!(recent_section_items.views.len(), 1); - 
assert_eq!(recent_section_items.views[0].view.view_id, recent_id); -} From 0818cf7565451d196bafefc29860f1d78092c3c9 Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:54:10 +0800 Subject: [PATCH 14/20] feat: api for create space (#1006) --- libs/client-api/src/http_view.rs | 19 +++- libs/shared-entity/src/dto/workspace_dto.rs | 20 +++++ src/api/workspace.rs | 27 +++++- src/biz/collab/folder_view.rs | 11 ++- src/biz/workspace/page_view.rs | 96 ++++++++++++++++++++- tests/workspace/page_view.rs | 74 +++++++++++++++- 6 files changed, 240 insertions(+), 7 deletions(-) diff --git a/libs/client-api/src/http_view.rs b/libs/client-api/src/http_view.rs index 14101253d..a5b55cd6f 100644 --- a/libs/client-api/src/http_view.rs +++ b/libs/client-api/src/http_view.rs @@ -1,4 +1,6 @@ -use client_api_entity::workspace_dto::{CreatePageParams, Page, PageCollab, UpdatePageParams}; +use client_api_entity::workspace_dto::{ + CreatePageParams, CreateSpaceParams, Page, PageCollab, Space, UpdatePageParams, +}; use reqwest::Method; use serde_json::json; use shared_entity::response::{AppResponse, AppResponseError}; @@ -112,4 +114,19 @@ impl Client { .await? .into_data() } + + pub async fn create_space( + &self, + workspace_id: Uuid, + params: &CreateSpaceParams, + ) -> Result { + let url = format!("{}/api/workspace/{}/space", self.base_url, workspace_id,); + let resp = self + .http_client_with_auth(Method::POST, &url) + .await? 
+ .json(params) + .send() + .await?; + AppResponse::::from_response(resp).await?.into_data() + } } diff --git a/libs/shared-entity/src/dto/workspace_dto.rs b/libs/shared-entity/src/dto/workspace_dto.rs index 67d46ffed..63eae5141 100644 --- a/libs/shared-entity/src/dto/workspace_dto.rs +++ b/libs/shared-entity/src/dto/workspace_dto.rs @@ -123,11 +123,24 @@ pub struct CollabResponse { pub object_id: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Space { + pub view_id: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Page { pub view_id: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateSpaceParams { + pub space_permission: SpacePermission, + pub name: String, + pub space_icon: String, + pub space_icon_color: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CreatePageParams { pub parent_view_id: String, @@ -259,6 +272,13 @@ impl Default for ViewLayout { } } +#[derive(Eq, PartialEq, Debug, Hash, Clone, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum SpacePermission { + PublicToAll = 0, + Private = 1, +} + #[derive(Default, Debug, Deserialize, Serialize)] pub struct QueryWorkspaceParam { pub include_member_count: Option, diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 850603458..97d12bcb3 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -49,8 +49,8 @@ use crate::biz::workspace::ops::{ get_reactions_on_published_view, remove_comment_on_published_view, remove_reaction_on_comment, }; use crate::biz::workspace::page_view::{ - create_page, get_page_view_collab, move_page_to_trash, restore_all_pages_from_trash, - restore_page_from_trash, update_page, update_page_collab_data, + create_page, create_space, get_page_view_collab, move_page_to_trash, + restore_all_pages_from_trash, restore_page_from_trash, update_page, update_page_collab_data, }; use crate::biz::workspace::publish::get_workspace_default_publish_view_info_meta; use 
crate::domain::compression::{ @@ -127,6 +127,7 @@ pub fn workspace_scope() -> Scope { web::resource("/v1/{workspace_id}/collab/{object_id}/web-update") .route(web::post().to(post_web_update_handler)), ) + .service(web::resource("/{workspace_id}/space").route(web::post().to(post_space_handler))) .service( web::resource("/{workspace_id}/page-view").route(web::post().to(post_page_view_handler)), ) @@ -901,6 +902,28 @@ async fn post_web_update_handler( Ok(Json(AppResponse::Ok())) } +async fn post_space_handler( + user_uuid: UserUuid, + path: web::Path, + payload: Json, + state: Data, +) -> Result>> { + let uid = state.user_cache.get_user_uid(&user_uuid).await?; + let workspace_uuid = path.into_inner(); + let space = create_space( + &state.pg_pool, + &state.collab_access_control_storage, + uid, + workspace_uuid, + &payload.space_permission, + &payload.name, + &payload.space_icon, + &payload.space_icon_color, + ) + .await?; + Ok(Json(AppResponse::Ok().with_data(space))) +} + async fn post_page_view_handler( user_uuid: UserUuid, path: web::Path, diff --git a/src/biz/collab/folder_view.rs b/src/biz/collab/folder_view.rs index 783b22a7e..46b77b8b2 100644 --- a/src/biz/collab/folder_view.rs +++ b/src/biz/collab/folder_view.rs @@ -2,7 +2,9 @@ use std::collections::HashSet; use app_error::AppError; use chrono::DateTime; -use collab_folder::{Folder, SectionItem, ViewLayout as CollabFolderViewLayout}; +use collab_folder::{ + hierarchy_builder::SpacePermission, Folder, SectionItem, ViewLayout as CollabFolderViewLayout, +}; use shared_entity::dto::workspace_dto::{ self, FavoriteFolderView, FolderView, FolderViewMinimal, RecentFolderView, TrashFolderView, ViewLayout, @@ -301,3 +303,10 @@ pub fn to_folder_view_layout(layout: workspace_dto::ViewLayout) -> collab_folder ViewLayout::Chat => collab_folder::ViewLayout::Chat, } } + +pub fn to_space_permission(space_permission: &workspace_dto::SpacePermission) -> SpacePermission { + match space_permission { + 
workspace_dto::SpacePermission::PublicToAll => SpacePermission::PublicToAll, + workspace_dto::SpacePermission::Private => SpacePermission::Private, + } +} diff --git a/src/biz/workspace/page_view.rs b/src/biz/workspace/page_view.rs index 5c837174c..ef61a2024 100644 --- a/src/biz/workspace/page_view.rs +++ b/src/biz/workspace/page_view.rs @@ -16,8 +16,9 @@ use database::user::select_web_user_from_uid; use database_entity::dto::{CollabParams, QueryCollab, QueryCollabParams, QueryCollabResult}; use itertools::Itertools; use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use serde_json::json; use shared_entity::dto::workspace_dto::{ - FolderView, Page, PageCollab, PageCollabData, ViewIcon, ViewLayout, + FolderView, Page, PageCollab, PageCollabData, Space, SpacePermission, ViewIcon, ViewLayout, }; use sqlx::{PgPool, Transaction}; use std::collections::{HashMap, HashSet}; @@ -29,6 +30,7 @@ use yrs::Update; use crate::api::metrics::AppFlowyWebMetrics; use crate::biz::collab::folder_view::{ parse_extra_field_as_json, to_dto_view_icon, to_dto_view_layout, to_folder_view_icon, + to_space_permission, }; use crate::biz::collab::{ folder_view::view_is_space, @@ -42,6 +44,56 @@ struct FolderUpdate { pub encoded_updates: Vec, } +#[allow(clippy::too_many_arguments)] +pub async fn create_space( + pg_pool: &PgPool, + collab_storage: &CollabAccessControlStorage, + uid: i64, + workspace_id: Uuid, + space_permission: &SpacePermission, + name: &str, + space_icon: &str, + space_color: &str, +) -> Result { + let default_document_collab_params = prepare_default_document_collab_param()?; + let view_id = default_document_collab_params.object_id.clone(); + let collab_origin = GetCollabOrigin::User { uid }; + let mut folder = + get_latest_collab_folder(collab_storage, collab_origin, &workspace_id.to_string()).await?; + let folder_update = add_new_space_to_folder( + uid, + &workspace_id.to_string(), + &view_id, + &mut folder, + space_permission, + name, + space_icon, + space_color, + ) 
+ .await?; + let mut transaction = pg_pool.begin().await?; + let action = format!("Create new space: {}", view_id); + collab_storage + .insert_new_collab_with_transaction( + &workspace_id.to_string(), + &uid, + default_document_collab_params, + &mut transaction, + &action, + ) + .await?; + insert_and_broadcast_workspace_folder_update( + uid, + workspace_id, + folder_update, + collab_storage, + &mut transaction, + ) + .await?; + transaction.commit().await?; + Ok(Space { view_id }) +} + pub async fn create_page( pg_pool: &PgPool, collab_storage: &CollabAccessControlStorage, @@ -84,6 +136,47 @@ fn prepare_default_document_collab_param() -> Result { }) } +#[allow(clippy::too_many_arguments)] +async fn add_new_space_to_folder( + uid: i64, + workspace_id: &str, + view_id: &str, + folder: &mut Folder, + space_permission: &SpacePermission, + name: &str, + space_icon: &str, + space_color: &str, +) -> Result { + let encoded_update = { + let view = NestedChildViewBuilder::new(uid, workspace_id.to_string()) + .with_view_id(view_id) + .with_name(name) + .with_extra(|builder| { + let mut extra = builder + .is_space(true, to_space_permission(space_permission)) + .build(); + extra["space_icon_color"] = json!(space_color); + extra["space_icon"] = json!(space_icon); + extra + }) + .build() + .view; + let mut txn = folder.collab.transact_mut(); + folder.body.views.insert(&mut txn, view, None); + if *space_permission == SpacePermission::Private { + folder + .body + .views + .update_view(&mut txn, view_id, |update| update.set_private(true).done()); + } + txn.encode_update_v1() + }; + Ok(FolderUpdate { + updated_encoded_collab: folder_to_encoded_collab(folder)?, + encoded_updates: encoded_update, + }) +} + async fn add_new_view_to_folder( uid: i64, parent_view_id: &str, @@ -99,6 +192,7 @@ async fn add_new_view_to_folder( .view; let mut txn = folder.collab.transact_mut(); folder.body.views.insert(&mut txn, view, None); + txn.encode_update_v1() }; Ok(FolderUpdate { diff --git 
a/tests/workspace/page_view.rs b/tests/workspace/page_view.rs index 461c83d61..88734fed7 100644 --- a/tests/workspace/page_view.rs +++ b/tests/workspace/page_view.rs @@ -7,9 +7,10 @@ use client_api_test::{ use collab::{core::origin::CollabClient, preclude::Collab}; use collab_entity::CollabType; use collab_folder::{CollabOrigin, Folder}; -use serde_json::json; +use serde_json::{json, Value}; use shared_entity::dto::workspace_dto::{ - CreatePageParams, IconType, UpdatePageParams, ViewIcon, ViewLayout, + CreatePageParams, CreateSpaceParams, IconType, SpacePermission, UpdatePageParams, ViewIcon, + ViewLayout, }; use tokio::time::sleep; use uuid::Uuid; @@ -283,3 +284,72 @@ async fn update_page() { Some(json!({"is_pinned": true}).to_string()) ); } + +#[tokio::test] +async fn create_space() { + let registered_user = generate_unique_registered_user().await; + let mut app_client = TestClient::user_with_new_device(registered_user.clone()).await; + let web_client = TestClient::user_with_new_device(registered_user.clone()).await; + let workspace_id = app_client.workspace_id().await; + app_client.open_workspace_collab(&workspace_id).await; + app_client + .wait_object_sync_complete(&workspace_id) + .await + .unwrap(); + let workspace_uuid = Uuid::parse_str(&workspace_id).unwrap(); + let public_space = web_client + .api_client + .create_space( + workspace_uuid, + &CreateSpaceParams { + space_permission: SpacePermission::PublicToAll, + name: "Public Space".to_string(), + space_icon: "space_icon_1".to_string(), + space_icon_color: "0xFFA34AFD".to_string(), + }, + ) + .await + .unwrap(); + web_client + .api_client + .create_space( + workspace_uuid, + &CreateSpaceParams { + space_permission: SpacePermission::Private, + name: "Private Space".to_string(), + space_icon: "space_icon_2".to_string(), + space_icon_color: "0xFFA34AFD".to_string(), + }, + ) + .await + .unwrap(); + let folder = get_latest_folder(&app_client, &workspace_id).await; + let view = 
folder.get_view(&public_space.view_id).unwrap(); + let space_info: Value = serde_json::from_str(view.extra.as_ref().unwrap()).unwrap(); + assert!(space_info["is_space"].as_bool().unwrap()); + assert_eq!( + space_info["space_permission"].as_u64().unwrap() as u8, + SpacePermission::PublicToAll as u8 + ); + assert_eq!(space_info["space_icon"].as_str().unwrap(), "space_icon_1"); + assert_eq!( + space_info["space_icon_color"].as_str().unwrap(), + "0xFFA34AFD" + ); + let folder_view = web_client + .api_client + .get_workspace_folder(&workspace_id, Some(2), Some(workspace_id.to_string())) + .await + .unwrap(); + folder_view + .children + .iter() + .find(|v| v.name == "Public Space") + .unwrap(); + let private_space = folder_view + .children + .iter() + .find(|v| v.name == "Private Space") + .unwrap(); + assert!(private_space.is_private); +} From 00a6189cf332bc60e788ba9e21dba5a68c3086c0 Mon Sep 17 00:00:00 2001 From: khorshuheng Date: Tue, 19 Nov 2024 12:58:07 +0800 Subject: [PATCH 15/20] fix: stop loading collab policies to improve access control evaluation --- libs/access-control/src/casbin/adapter.rs | 30 ----------------------- 1 file changed, 30 deletions(-) diff --git a/libs/access-control/src/casbin/adapter.rs b/libs/access-control/src/casbin/adapter.rs index 629c80802..72026639f 100644 --- a/libs/access-control/src/casbin/adapter.rs +++ b/libs/access-control/src/casbin/adapter.rs @@ -7,8 +7,6 @@ use casbin::Filter; use casbin::Model; use casbin::Result; -use database::collab::select_collab_member_access_level; -use database::pg_row::AFCollabMemberAccessLevelRow; use database::pg_row::AFWorkspaceMemberPermRow; use database::workspace::select_workspace_member_perm_stream; @@ -35,28 +33,6 @@ impl PgAdapter { } } -async fn load_collab_policies( - mut stream: BoxStream<'_, sqlx::Result>, -) -> Result>> { - let mut policies: Vec> = Vec::new(); - - while let Some(Ok(member_access_lv)) = stream.next().await { - let uid = member_access_lv.uid; - let object_type = 
ObjectType::Collab(&member_access_lv.oid); - for act in member_access_lv.access_level.policy_acts() { - let policy = [ - uid.to_string(), - object_type.policy_object(), - act.to_string(), - ] - .to_vec(); - policies.push(policy); - } - } - - Ok(policies) -} - /// Loads workspace policies from a given stream of workspace member permissions. /// /// This function iterates over the stream of member permissions, constructing and accumulating @@ -128,12 +104,6 @@ impl Adapter for PgAdapter { // Policy definition `p` of type `p`. See `model.conf` model.add_policies("p", "p", workspace_policies); - let collab_member_access_lv_stream = select_collab_member_access_level(&self.pg_pool); - let collab_policies = load_collab_policies(collab_member_access_lv_stream).await?; - - // Policy definition `p` of type `p`. See `model.conf` - model.add_policies("p", "p", collab_policies); - self .access_control_metrics .record_load_all_policies_in_ms(start.elapsed().as_millis() as u64); From 40835f00d137bce3c38b8b462e44ea448b56e962 Mon Sep 17 00:00:00 2001 From: khorshuheng Date: Tue, 19 Nov 2024 21:11:18 +0800 Subject: [PATCH 16/20] fix: revert changes to casbin matcher --- libs/access-control/src/casbin/access.rs | 2 +- libs/access-control/src/casbin/enforcer.rs | 42 ---------------------- 2 files changed, 1 insertion(+), 43 deletions(-) diff --git a/libs/access-control/src/casbin/access.rs b/libs/access-control/src/casbin/access.rs index 4fc5feb92..6b42ecdac 100644 --- a/libs/access-control/src/casbin/access.rs +++ b/libs/access-control/src/casbin/access.rs @@ -159,7 +159,7 @@ g = _, _ # grouping rule e = some(where (p.eft == allow)) [matchers] -m = g(r.sub, p.sub) && p.obj == r.obj && (g(p.act, r.act) || cmpRoleOrLevel(r.act, p.act)) +m = r.sub == p.sub && p.obj == r.obj && (g(p.act, r.act) || cmpRoleOrLevel(r.act, p.act)) "###; pub async fn casbin_model() -> Result { diff --git a/libs/access-control/src/casbin/enforcer.rs b/libs/access-control/src/casbin/enforcer.rs index 
f16c6b6b8..ef7b1c9e4 100644 --- a/libs/access-control/src/casbin/enforcer.rs +++ b/libs/access-control/src/casbin/enforcer.rs @@ -223,48 +223,6 @@ mod tests { AFEnforcer::new(enforcer).await.unwrap() } - #[tokio::test] - async fn collab_group_test() { - let enforcer = test_enforcer().await; - - let uid = 1; - let group_id = "collab_owner_group:w1"; - let workspace_id = "w1"; - let object_1 = "o1"; - - // allow workspace member to access collab - enforcer - .update_policy( - SubjectType::Group(group_id.to_string()), - ObjectType::Collab(object_1), - ActionVariant::FromAccessLevel(&AFAccessLevel::FullAccess), - ) - .await - .unwrap(); - - // include user in the collab owner group - enforcer - .add_grouping_policy( - &SubjectType::User(uid), - &SubjectType::Group(group_id.to_string()), - ) - .await - .unwrap(); - - // when the user is the owner of the collab, then the user should have access to the collab - for action in [Action::Write, Action::Read] { - let result = enforcer - .enforce_policy( - workspace_id, - &uid, - ObjectType::Collab(object_1), - ActionVariant::FromAction(&action), - ) - .await; - assert!(result.is_ok()); - } - } - #[tokio::test] async fn workspace_group_policy_test() { let enforcer = test_enforcer().await; From d2a82db3004ca8c5bdf2bed0b706c2dbd21e64fe Mon Sep 17 00:00:00 2001 From: nathan Date: Wed, 20 Nov 2024 10:53:20 +0800 Subject: [PATCH 17/20] chore: remove err log --- src/api/workspace.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 97d12bcb3..e8356e55a 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -607,7 +607,7 @@ async fn update_workspace_member_handler( Ok(AppResponse::Ok().into()) } -#[instrument(skip(state, payload), err)] +#[instrument(skip(state, payload))] async fn create_collab_handler( user_uuid: UserUuid, payload: Bytes, From e6dbc95641a3c24d90a30df52fcff82fd522d228 Mon Sep 17 00:00:00 2001 From: Khor Shu Heng 
<32997938+khorshuheng@users.noreply.github.com> Date: Wed, 20 Nov 2024 11:32:29 +0800 Subject: [PATCH 18/20] feat: api to update space (#1009) --- libs/client-api/src/http_view.rs | 21 +++++- libs/shared-entity/src/dto/workspace_dto.rs | 8 +++ src/api/workspace.rs | 27 ++++++++ src/biz/workspace/page_view.rs | 77 ++++++++++++++++++++- tests/workspace/page_view.rs | 29 +++++++- 5 files changed, 156 insertions(+), 6 deletions(-) diff --git a/libs/client-api/src/http_view.rs b/libs/client-api/src/http_view.rs index a5b55cd6f..61177a8dc 100644 --- a/libs/client-api/src/http_view.rs +++ b/libs/client-api/src/http_view.rs @@ -1,5 +1,5 @@ use client_api_entity::workspace_dto::{ - CreatePageParams, CreateSpaceParams, Page, PageCollab, Space, UpdatePageParams, + CreatePageParams, CreateSpaceParams, Page, PageCollab, Space, UpdatePageParams, UpdateSpaceParams, }; use reqwest::Method; use serde_json::json; @@ -129,4 +129,23 @@ impl Client { .await?; AppResponse::::from_response(resp).await?.into_data() } + + pub async fn update_space( + &self, + workspace_id: Uuid, + view_id: &str, + params: &UpdateSpaceParams, + ) -> Result<(), AppResponseError> { + let url = format!( + "{}/api/workspace/{}/space/{}", + self.base_url, workspace_id, view_id + ); + let resp = self + .http_client_with_auth(Method::PATCH, &url) + .await? 
+ .json(params) + .send() + .await?; + AppResponse::<()>::from_response(resp).await?.into_error() + } } diff --git a/libs/shared-entity/src/dto/workspace_dto.rs b/libs/shared-entity/src/dto/workspace_dto.rs index 63eae5141..5671a11e7 100644 --- a/libs/shared-entity/src/dto/workspace_dto.rs +++ b/libs/shared-entity/src/dto/workspace_dto.rs @@ -141,6 +141,14 @@ pub struct CreateSpaceParams { pub space_icon_color: String, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateSpaceParams { + pub space_permission: SpacePermission, + pub name: String, + pub space_icon: String, + pub space_icon_color: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CreatePageParams { pub parent_view_id: String, diff --git a/src/api/workspace.rs b/src/api/workspace.rs index e8356e55a..195cdd1a7 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -51,6 +51,7 @@ use crate::biz::workspace::ops::{ use crate::biz::workspace::page_view::{ create_page, create_space, get_page_view_collab, move_page_to_trash, restore_all_pages_from_trash, restore_page_from_trash, update_page, update_page_collab_data, + update_space, }; use crate::biz::workspace::publish::get_workspace_default_publish_view_info_meta; use crate::domain::compression::{ @@ -128,6 +129,9 @@ pub fn workspace_scope() -> Scope { .route(web::post().to(post_web_update_handler)), ) .service(web::resource("/{workspace_id}/space").route(web::post().to(post_space_handler))) + .service( + web::resource("/{workspace_id}/space/{view_id}").route(web::patch().to(update_space_handler)), + ) .service( web::resource("/{workspace_id}/page-view").route(web::post().to(post_page_view_handler)), ) @@ -924,6 +928,29 @@ async fn post_space_handler( Ok(Json(AppResponse::Ok().with_data(space))) } +async fn update_space_handler( + user_uuid: UserUuid, + path: web::Path<(Uuid, String)>, + payload: Json, + state: Data, +) -> Result>> { + let uid = state.user_cache.get_user_uid(&user_uuid).await?; + let 
(workspace_uuid, view_id) = path.into_inner(); + update_space( + &state.pg_pool, + &state.collab_access_control_storage, + uid, + workspace_uuid, + &view_id, + &payload.space_permission, + &payload.name, + &payload.space_icon, + &payload.space_icon_color, + ) + .await?; + Ok(Json(AppResponse::Ok())) +} + async fn post_page_view_handler( user_uuid: UserUuid, path: web::Path, diff --git a/src/biz/workspace/page_view.rs b/src/biz/workspace/page_view.rs index ef61a2024..c0c1a86f5 100644 --- a/src/biz/workspace/page_view.rs +++ b/src/biz/workspace/page_view.rs @@ -9,7 +9,7 @@ use collab_document::document::Document; use collab_document::document_data::default_document_data; use collab_entity::{CollabType, EncodedCollab}; use collab_folder::hierarchy_builder::NestedChildViewBuilder; -use collab_folder::{CollabOrigin, Folder}; +use collab_folder::{timestamp, CollabOrigin, Folder}; use database::collab::{select_workspace_database_oid, CollabStorage, GetCollabOrigin}; use database::publish::select_published_view_ids_for_workspace; use database::user::select_web_user_from_uid; @@ -44,6 +44,43 @@ struct FolderUpdate { pub encoded_updates: Vec, } +#[allow(clippy::too_many_arguments)] +pub async fn update_space( + pg_pool: &PgPool, + collab_storage: &CollabAccessControlStorage, + uid: i64, + workspace_id: Uuid, + view_id: &str, + space_permission: &SpacePermission, + name: &str, + space_icon: &str, + space_icon_color: &str, +) -> Result<(), AppError> { + let collab_origin = GetCollabOrigin::User { uid }; + let mut folder = + get_latest_collab_folder(collab_storage, collab_origin, &workspace_id.to_string()).await?; + let folder_update = update_space_properties( + view_id, + &mut folder, + space_permission, + name, + space_icon, + space_icon_color, + ) + .await?; + let mut transaction = pg_pool.begin().await?; + insert_and_broadcast_workspace_folder_update( + uid, + workspace_id, + folder_update, + collab_storage, + &mut transaction, + ) + .await?; + transaction.commit().await?; 
+ Ok(()) +} + #[allow(clippy::too_many_arguments)] pub async fn create_space( pg_pool: &PgPool, @@ -145,7 +182,7 @@ async fn add_new_space_to_folder( space_permission: &SpacePermission, name: &str, space_icon: &str, - space_color: &str, + space_icon_color: &str, ) -> Result { let encoded_update = { let view = NestedChildViewBuilder::new(uid, workspace_id.to_string()) @@ -155,7 +192,7 @@ async fn add_new_space_to_folder( let mut extra = builder .is_space(true, to_space_permission(space_permission)) .build(); - extra["space_icon_color"] = json!(space_color); + extra["space_icon_color"] = json!(space_icon_color); extra["space_icon"] = json!(space_icon); extra }) @@ -177,6 +214,40 @@ async fn add_new_space_to_folder( }) } +async fn update_space_properties( + view_id: &str, + folder: &mut Folder, + space_permission: &SpacePermission, + name: &str, + space_icon: &str, + space_icon_color: &str, +) -> Result { + let encoded_update = { + let mut txn = folder.collab.transact_mut(); + folder.body.views.update_view(&mut txn, view_id, |update| { + let extra = json!({ + "is_space": true, + "space_permission": to_space_permission(space_permission) as u8, + "space_created_at": timestamp(), + "space_icon": space_icon, + "space_icon_color": space_icon_color, + }) + .to_string(); + let is_private = *space_permission == SpacePermission::Private; + update + .set_name(name) + .set_extra(&extra) + .set_private(is_private) + .done() + }); + txn.encode_update_v1() + }; + Ok(FolderUpdate { + updated_encoded_collab: folder_to_encoded_collab(folder)?, + encoded_updates: encoded_update, + }) +} + async fn add_new_view_to_folder( uid: i64, parent_view_id: &str, diff --git a/tests/workspace/page_view.rs b/tests/workspace/page_view.rs index 88734fed7..07193d628 100644 --- a/tests/workspace/page_view.rs +++ b/tests/workspace/page_view.rs @@ -9,8 +9,8 @@ use collab_entity::CollabType; use collab_folder::{CollabOrigin, Folder}; use serde_json::{json, Value}; use shared_entity::dto::workspace_dto::{ 
- CreatePageParams, CreateSpaceParams, IconType, SpacePermission, UpdatePageParams, ViewIcon, - ViewLayout, + CreatePageParams, CreateSpaceParams, IconType, SpacePermission, UpdatePageParams, + UpdateSpaceParams, ViewIcon, ViewLayout, }; use tokio::time::sleep; use uuid::Uuid; @@ -352,4 +352,29 @@ async fn create_space() { .find(|v| v.name == "Private Space") .unwrap(); assert!(private_space.is_private); + + web_client + .api_client + .update_space( + workspace_uuid, + &private_space.view_id, + &UpdateSpaceParams { + space_permission: SpacePermission::PublicToAll, + name: "Renamed Space".to_string(), + space_icon: "space_icon_3".to_string(), + space_icon_color: "#000000".to_string(), + }, + ) + .await + .unwrap(); + let folder = get_latest_folder(&app_client, &workspace_id).await; + let view = folder.get_view(&private_space.view_id).unwrap(); + let space_info: Value = serde_json::from_str(view.extra.as_ref().unwrap()).unwrap(); + assert!(space_info["is_space"].as_bool().unwrap()); + assert_eq!( + space_info["space_permission"].as_u64().unwrap() as u8, + SpacePermission::PublicToAll as u8 + ); + assert_eq!(space_info["space_icon"].as_str().unwrap(), "space_icon_3"); + assert_eq!(space_info["space_icon_color"].as_str().unwrap(), "#000000"); } From afeaeb77969b519bf606e121345b63b450ba8440 Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Wed, 20 Nov 2024 12:29:16 +0800 Subject: [PATCH 19/20] chore: simplify collab level access control (#1008) --- libs/access-control/src/casbin/collab.rs | 85 +++++++++++++++--------- libs/client-api-test/src/test_client.rs | 47 +------------ tests/collab/awareness_test.rs | 10 +-- tests/collab/missing_update_test.rs | 12 ++-- tests/collab/multi_devices_edit.rs | 15 ++--- tests/collab/permission_test.rs | 70 ++++++------------- tests/collab/single_device_edit.rs | 12 ++-- 7 files changed, 91 insertions(+), 160 deletions(-) diff --git a/libs/access-control/src/casbin/collab.rs 
b/libs/access-control/src/casbin/collab.rs index f64289a46..a50ffd811 100644 --- a/libs/access-control/src/casbin/collab.rs +++ b/libs/access-control/src/casbin/collab.rs @@ -6,7 +6,7 @@ use tracing::instrument; use crate::{ act::{Action, ActionVariant}, collab::{CollabAccessControl, RealtimeAccessControl}, - entity::{ObjectType, SubjectType}, + entity::ObjectType, }; use super::access::AccessControl; @@ -28,16 +28,25 @@ impl CollabAccessControl for CollabAccessControlImpl { &self, workspace_id: &str, uid: &i64, - oid: &str, + _oid: &str, action: Action, ) -> Result<(), AppError> { + // TODO: allow non workspace member to read a collab. + + // Anyone who can write to a workspace, can also delete a collab. + let workspace_action = match action { + Action::Read => Action::Read, + Action::Write => Action::Write, + Action::Delete => Action::Write, + }; + self .access_control .enforce( workspace_id, uid, - ObjectType::Collab(oid), - ActionVariant::FromAction(&action), + ObjectType::Workspace(workspace_id), + ActionVariant::FromAction(&workspace_action), ) .await } @@ -46,16 +55,26 @@ impl CollabAccessControl for CollabAccessControlImpl { &self, workspace_id: &str, uid: &i64, - oid: &str, + _oid: &str, access_level: AFAccessLevel, ) -> Result<(), AppError> { + // TODO: allow non workspace member to read a collab. + + // Anyone who can write to a workspace, also have full access to a collab. 
+ let workspace_action = match access_level { + AFAccessLevel::ReadOnly => Action::Read, + AFAccessLevel::ReadAndComment => Action::Read, + AFAccessLevel::ReadAndWrite => Action::Write, + AFAccessLevel::FullAccess => Action::Write, + }; + self .access_control .enforce( workspace_id, uid, - ObjectType::Collab(oid), - ActionVariant::FromAccessLevel(&access_level), + ObjectType::Workspace(workspace_id), + ActionVariant::FromAction(&workspace_action), ) .await } @@ -63,28 +82,17 @@ impl CollabAccessControl for CollabAccessControlImpl { #[instrument(level = "info", skip_all)] async fn update_access_level_policy( &self, - uid: &i64, - oid: &str, - level: AFAccessLevel, + _uid: &i64, + _oid: &str, + _level: AFAccessLevel, ) -> Result<(), AppError> { - self - .access_control - .update_policy( - SubjectType::User(*uid), - ObjectType::Collab(oid), - ActionVariant::FromAccessLevel(&level), - ) - .await?; - + // TODO: allow non workspace member to read a collab. Ok(()) } #[instrument(level = "info", skip_all)] - async fn remove_access_level(&self, uid: &i64, oid: &str) -> Result<(), AppError> { - self - .access_control - .remove_policy(&SubjectType::User(*uid), &ObjectType::Collab(oid)) - .await?; + async fn remove_access_level(&self, _uid: &i64, _oid: &str) -> Result<(), AppError> { + // TODO: allow non workspace member to read a collab. Ok(()) } } @@ -103,20 +111,35 @@ impl RealtimeCollabAccessControlImpl { &self, workspace_id: &str, uid: &i64, - oid: &str, + _oid: &str, required_action: Action, ) -> Result { - self + // TODO: allow non workspace member to read a collab. + + // Anyone who can write to a workspace, can also delete a collab. 
+ let workspace_action = match required_action { + Action::Read => Action::Read, + Action::Write => Action::Write, + Action::Delete => Action::Write, + }; + + let enforcement_result = self .access_control .enforce( workspace_id, uid, - ObjectType::Collab(oid), - ActionVariant::FromAction(&required_action), + ObjectType::Workspace(workspace_id), + ActionVariant::FromAction(&workspace_action), ) - .await?; - - Ok(true) + .await; + match enforcement_result { + Ok(_) => Ok(true), + Err(AppError::NotEnoughPermissions { + user: _user, + workspace_id: _workspace_id, + }) => Ok(false), + Err(e) => Err(e), + } } } diff --git a/libs/client-api-test/src/test_client.rs b/libs/client-api-test/src/test_client.rs index 1970d2a65..dcc0b7377 100644 --- a/libs/client-api-test/src/test_client.rs +++ b/libs/client-api-test/src/test_client.rs @@ -37,10 +37,9 @@ use client_api::entity::{ }; use client_api::ws::{WSClient, WSClientConfig}; use database_entity::dto::{ - AFAccessLevel, AFRole, AFSnapshotMeta, AFSnapshotMetas, AFUserProfile, AFUserWorkspaceInfo, - AFWorkspace, AFWorkspaceInvitationStatus, AFWorkspaceMember, BatchQueryCollabResult, - CollabParams, CreateCollabParams, InsertCollabMemberParams, QueryCollab, QueryCollabParams, - QuerySnapshotParams, SnapshotData, UpdateCollabMemberParams, + AFRole, AFSnapshotMeta, AFSnapshotMetas, AFUserProfile, AFUserWorkspaceInfo, AFWorkspace, + AFWorkspaceInvitationStatus, AFWorkspaceMember, BatchQueryCollabResult, CollabParams, + CreateCollabParams, QueryCollab, QueryCollabParams, QuerySnapshotParams, SnapshotData, }; use shared_entity::dto::workspace_dto::{ BlobMetadata, CollabResponse, PublishedDuplicate, WorkspaceMemberChangeset, @@ -441,46 +440,6 @@ impl TestClient { self.api_client.get_workspace_member(params).await } - pub async fn add_collab_member( - &self, - workspace_id: &str, - object_id: &str, - other_client: &TestClient, - access_level: AFAccessLevel, - ) { - let uid = other_client.uid().await; - self - .api_client - 
.add_collab_member(InsertCollabMemberParams { - uid, - workspace_id: workspace_id.to_string(), - object_id: object_id.to_string(), - access_level, - }) - .await - .unwrap(); - } - - pub async fn update_collab_member_access_level( - &self, - workspace_id: &str, - object_id: &str, - other_client: &TestClient, - access_level: AFAccessLevel, - ) { - let uid = other_client.uid().await; - self - .api_client - .update_collab_member(UpdateCollabMemberParams { - uid, - workspace_id: workspace_id.to_string(), - object_id: object_id.to_string(), - access_level, - }) - .await - .unwrap(); - } - pub async fn wait_object_sync_complete(&self, object_id: &str) -> Result<(), Error> { self .wait_object_sync_complete_with_secs(object_id, 60) diff --git a/tests/collab/awareness_test.rs b/tests/collab/awareness_test.rs index 169f1cdd4..d6f45f318 100644 --- a/tests/collab/awareness_test.rs +++ b/tests/collab/awareness_test.rs @@ -4,7 +4,7 @@ use collab_entity::CollabType; use tokio::time::sleep; use client_api_test::TestClient; -use database_entity::dto::{AFAccessLevel, AFRole}; +use database_entity::dto::AFRole; #[tokio::test] async fn viewing_document_editing_users_test() { @@ -27,14 +27,6 @@ async fn viewing_document_editing_users_test() { assert_eq!(clients.len(), 1); assert_eq!(clients[0], owner_uid); - owner - .add_collab_member( - &workspace_id, - &object_id, - &guest, - AFAccessLevel::ReadAndWrite, - ) - .await; guest .open_collab(&workspace_id, &object_id, collab_type) .await; diff --git a/tests/collab/missing_update_test.rs b/tests/collab/missing_update_test.rs index 545c8ec95..166b24808 100644 --- a/tests/collab/missing_update_test.rs +++ b/tests/collab/missing_update_test.rs @@ -1,11 +1,11 @@ use std::time::Duration; +use client_api::entity::AFRole; use collab_entity::CollabType; use serde_json::{json, Value}; use tokio::time::sleep; use client_api_test::{assert_client_collab_include_value, TestClient}; -use database_entity::dto::AFAccessLevel; #[tokio::test] async fn 
client_apply_update_find_missing_update_test() { @@ -56,13 +56,9 @@ async fn make_clients() -> (TestClient, TestClient, String, Value) { .create_and_edit_collab(&workspace_id, collab_type.clone()) .await; client_1 - .add_collab_member( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::ReadAndWrite, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &client_2, AFRole::Member) + .await + .unwrap(); // after client 2 finish init sync and then disable receive message client_2 diff --git a/tests/collab/multi_devices_edit.rs b/tests/collab/multi_devices_edit.rs index 92aee778a..907356ab1 100644 --- a/tests/collab/multi_devices_edit.rs +++ b/tests/collab/multi_devices_edit.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use client_api::entity::AFRole; use collab_entity::CollabType; use serde_json::json; use sqlx::types::uuid; @@ -7,7 +8,7 @@ use tokio::time::sleep; use tracing::trace; use client_api_test::*; -use database_entity::dto::{AFAccessLevel, QueryCollabParams}; +use database_entity::dto::QueryCollabParams; #[tokio::test] async fn sync_collab_content_after_reconnect_test() { @@ -200,15 +201,11 @@ async fn edit_document_with_both_clients_offline_then_online_sync_test() { .create_and_edit_collab(&workspace_id, collab_type.clone()) .await; - // add client 2 as a member of the collab + // add client 2 as a member of the workspace client_1 - .add_collab_member( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::ReadAndWrite, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &client_2, AFRole::Member) + .await + .unwrap(); client_1.disconnect().await; client_2 diff --git a/tests/collab/permission_test.rs b/tests/collab/permission_test.rs index 5e0cea744..d2c9b3822 100644 --- a/tests/collab/permission_test.rs +++ b/tests/collab/permission_test.rs @@ -12,7 +12,7 @@ use client_api_test::{ assert_client_collab_include_value, assert_client_collab_within_secs, assert_server_collab, TestClient, }; -use 
database_entity::dto::{AFAccessLevel, AFRole}; +use database_entity::dto::AFRole; use crate::collab::util::generate_random_string; @@ -164,13 +164,9 @@ async fn edit_collab_with_readonly_permission_test() { // Add client 2 as the member of the collab then the client 2 will receive the update. client_1 - .add_collab_member( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::ReadOnly, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &client_2, AFRole::Guest) + .await + .unwrap(); client_2 .open_collab(&workspace_id, &object_id, collab_type.clone()) @@ -214,13 +210,9 @@ async fn edit_collab_with_read_and_write_permission_test() { // Add client 2 as the member of the collab then the client 2 will receive the update. client_1 - .add_collab_member( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::ReadAndWrite, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &client_2, AFRole::Member) + .await + .unwrap(); client_2 .open_collab(&workspace_id, &object_id, collab_type.clone()) @@ -265,13 +257,9 @@ async fn edit_collab_with_full_access_permission_test() { // Add client 2 as the member of the collab then the client 2 will receive the update. client_1 - .add_collab_member( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::FullAccess, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &client_2, AFRole::Member) + .await + .unwrap(); client_2 .open_collab(&workspace_id, &object_id, collab_type.clone()) @@ -314,13 +302,9 @@ async fn edit_collab_with_full_access_then_readonly_permission() { // Add client 2 as the member of the collab then the client 2 will receive the update. 
client_1 - .add_collab_member( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::FullAccess, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &client_2, AFRole::Member) + .await + .unwrap(); // client 2 edit the collab and then the server will broadcast the update { @@ -340,13 +324,9 @@ async fn edit_collab_with_full_access_then_readonly_permission() { // updates generated by client 2 { client_1 - .update_collab_member_access_level( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::ReadOnly, - ) - .await; + .try_update_workspace_member(&workspace_id, &client_2, AFRole::Guest) + .await + .unwrap(); client_2 .insert_into(&object_id, "subtitle", "Writing Rust, fun") .await; @@ -404,14 +384,6 @@ async fn multiple_user_with_read_and_write_permission_edit_same_collab_test() { .invite_and_accepted_workspace_member(&workspace_id, &new_member, AFRole::Member) .await .unwrap(); - owner - .add_collab_member( - &workspace_id, - &object_id, - &new_member, - AFAccessLevel::ReadAndWrite, - ) - .await; new_member .open_collab(&workspace_id, &object_id, collab_type.clone()) @@ -490,13 +462,9 @@ async fn multiple_user_with_read_only_permission_edit_same_collab_test() { // sleep 2 secs to make sure it do not trigger register user too fast in gotrue sleep(Duration::from_secs(i % 2)).await; owner - .add_collab_member( - &workspace_id, - &object_id, - &new_user, - AFAccessLevel::ReadOnly, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &new_user, AFRole::Guest) + .await + .unwrap(); new_user .open_collab(&workspace_id, &object_id, collab_type.clone()) diff --git a/tests/collab/single_device_edit.rs b/tests/collab/single_device_edit.rs index 08a21c4fd..bb5a85adb 100644 --- a/tests/collab/single_device_edit.rs +++ b/tests/collab/single_device_edit.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use std::time::Duration; use assert_json_diff::assert_json_eq; +use client_api::entity::AFRole; use collab::core::origin::CollabOrigin; use 
collab_entity::CollabType; use serde_json::json; @@ -11,7 +12,6 @@ use uuid::Uuid; use client_api_test::*; use collab_rt_entity::{CollabMessage, RealtimeMessage, UpdateSync, MAXIMUM_REALTIME_MESSAGE_SIZE}; -use database_entity::dto::AFAccessLevel; use crate::collab::util::{ generate_random_bytes, generate_random_string, make_big_collab_doc_state, @@ -325,13 +325,9 @@ async fn two_direction_peer_sync_test() { // Before the client_2 want to edit the collab object, it needs to become a member of the collab // Otherwise, the server will reject the edit request client_1 - .add_collab_member( - &workspace_id, - &object_id, - &client_2, - AFAccessLevel::FullAccess, - ) - .await; + .invite_and_accepted_workspace_member(&workspace_id, &client_2, AFRole::Member) + .await + .unwrap(); client_2 .open_collab(&workspace_id, &object_id, collab_type.clone()) From 1e18180e9d03cb0fa8babd3940092acd5754cc20 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Wed, 20 Nov 2024 14:07:36 +0800 Subject: [PATCH 20/20] chore: set max import zip file size (#1011) * chore: set max import zip file size * chore: fix test --- services/appflowy-worker/src/application.rs | 7 ++++ services/appflowy-worker/src/error.rs | 30 +++++++++++++++++ .../src/import_worker/worker.rs | 33 +++++++++++++++++-- services/appflowy-worker/tests/import_test.rs | 2 ++ tests/search/document_search.rs | 7 +++- 5 files changed, 76 insertions(+), 3 deletions(-) diff --git a/services/appflowy-worker/src/application.rs b/services/appflowy-worker/src/application.rs index 55cc2106c..db82ca9c9 100644 --- a/services/appflowy-worker/src/application.rs +++ b/services/appflowy-worker/src/application.rs @@ -107,6 +107,12 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err .parse::() .unwrap_or(10); + // Maximum file size for import + let maximum_import_file_size = + get_env_var("APPFLOWY_WORKER_MAX_IMPORT_FILE_SIZE", "1_000_000_000") + .parse::() + 
.unwrap_or(1_000_000_000); + let import_worker_fut = local_set.run_until(run_import_worker( state.pg_pool.clone(), state.redis_client.clone(), @@ -115,6 +121,7 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err Arc::new(email_notifier), "import_task_stream", tick_interval, + maximum_import_file_size, )); let app = Router::new() diff --git a/services/appflowy-worker/src/error.rs b/services/appflowy-worker/src/error.rs index 77985ae65..781e52eb6 100644 --- a/services/appflowy-worker/src/error.rs +++ b/services/appflowy-worker/src/error.rs @@ -40,6 +40,15 @@ pub enum ImportError { #[error("Upload file expired")] UploadFileExpire, + #[error("Please upgrade to the latest version of the app")] + UpgradeToLatestVersion(String), + + #[error("Upload file too large")] + UploadFileTooLarge { + file_size_in_mb: f64, + max_size_in_mb: f64, + }, + #[error(transparent)] Internal(#[from] anyhow::Error), } @@ -184,6 +193,27 @@ impl ImportError { format!("Task ID: {} - Upload file expired", task_id), ) } + ImportError::UpgradeToLatestVersion(s) => { + ( + format!( + "Task ID: {} - {}, please upgrade to the latest version of the app to import this file", + task_id, + s, + + ), + format!("Task ID: {} - Upgrade to latest version", task_id), + ) + } + ImportError::UploadFileTooLarge{ file_size_in_mb, max_size_in_mb}=> { + ( + format!( + "Task ID: {} - The file size is too large. The maximum file size allowed is {} MB. 
Please upload a smaller file.", + task_id, + max_size_in_mb, + ), + format!("Task ID: {} - Upload file too large: {} MB", task_id, file_size_in_mb), + ) + } } } } diff --git a/services/appflowy-worker/src/import_worker/worker.rs b/services/appflowy-worker/src/import_worker/worker.rs index 6be6767d1..03df19e1e 100644 --- a/services/appflowy-worker/src/import_worker/worker.rs +++ b/services/appflowy-worker/src/import_worker/worker.rs @@ -78,6 +78,7 @@ pub async fn run_import_worker( notifier: Arc, stream_name: &str, tick_interval_secs: u64, + max_import_file_size: u64, ) -> Result<(), ImportError> { info!("Starting importer worker"); if let Err(err) = ensure_consumer_group(stream_name, GROUP_NAME, &mut redis_client).await { @@ -95,6 +96,7 @@ pub async fn run_import_worker( CONSUMER_NAME, notifier.clone(), &metrics, + max_import_file_size, ) .await; @@ -109,6 +111,7 @@ pub async fn run_import_worker( notifier.clone(), tick_interval_secs, &metrics, + max_import_file_size, ) .await?; @@ -126,6 +129,7 @@ async fn process_un_acked_tasks( consumer_name: &str, notifier: Arc, metrics: &Option>, + maximum_import_file_size: u64, ) { // when server restarts, we need to check if there are any unacknowledged tasks match get_un_ack_tasks(stream_name, group_name, consumer_name, redis_client).await { @@ -139,6 +143,7 @@ async fn process_un_acked_tasks( pg_pool: pg_pool.clone(), notifier: notifier.clone(), metrics: metrics.clone(), + maximum_import_file_size, }; // Ignore the error here since the consume task will handle the error let _ = consume_task( @@ -167,6 +172,7 @@ async fn process_upcoming_tasks( notifier: Arc, interval_secs: u64, metrics: &Option>, + maximum_import_file_size: u64, ) -> Result<(), ImportError> { let options = StreamReadOptions::default() .group(group_name, consumer_name) @@ -215,6 +221,7 @@ async fn process_upcoming_tasks( pg_pool: pg_pool.clone(), notifier: notifier.clone(), metrics: metrics.clone(), + maximum_import_file_size, }; let handle = 
spawn_local(async move { @@ -254,6 +261,7 @@ struct TaskContext { pg_pool: PgPool, notifier: Arc, metrics: Option>, + maximum_import_file_size: u64, } #[allow(clippy::too_many_arguments)] @@ -270,6 +278,26 @@ async fn consume_task( return process_and_ack_task(context, import_task, stream_name, group_name, &entry_id).await; } + match task.file_size { + None => { + return Err(ImportError::UpgradeToLatestVersion(format!( + "Missing file_size for task: {}", + task.task_id + ))) + }, + Some(file_size) => { + if file_size > context.maximum_import_file_size as i64 { + let file_size_in_mb = file_size as f64 / 1_048_576.0; + let max_size_in_mb = context.maximum_import_file_size as f64 / 1_048_576.0; + + return Err(ImportError::UploadFileTooLarge { + file_size_in_mb, + max_size_in_mb, + }); + } + }, + } + // Check if the task is expired if let Err(err) = is_task_expired(task.created_at.unwrap(), task.last_process_at) { if let Ok(import_record) = select_import_task(&context.pg_pool, &task.task_id).await { @@ -1395,10 +1423,11 @@ pub struct NotionImportTask { impl Display for NotionImportTask { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let file_size_mb = self.file_size.map(|size| size as f64 / 1_048_576.0); write!( f, - "NotionImportTask {{ task_id: {}, workspace_id: {}, file_size:{:?}, workspace_name: {}, user_name: {}, user_email: {} }}", - self.task_id, self.workspace_id, self.file_size, self.workspace_name, self.user_name, self.user_email + "NotionImportTask {{ task_id: {}, workspace_id: {}, file_size:{:?}MB, workspace_name: {}, user_name: {}, user_email: {} }}", + self.task_id, self.workspace_id, file_size_mb, self.workspace_name, self.user_name, self.user_email ) } } diff --git a/services/appflowy-worker/tests/import_test.rs b/services/appflowy-worker/tests/import_test.rs index ce40a9274..eabbad9c4 100644 --- a/services/appflowy-worker/tests/import_test.rs +++ b/services/appflowy-worker/tests/import_test.rs @@ -136,6 +136,7 @@ fn 
run_importer_worker( tick_interval_secs: u64, ) -> std::thread::JoinHandle<()> { setup_log(); + let max_import_file_size = 1_000_000_000; std::thread::spawn(move || { let runtime = Builder::new_current_thread().enable_all().build().unwrap(); @@ -148,6 +149,7 @@ fn run_importer_worker( notifier, &stream_name, tick_interval_secs, + max_import_file_size, )); runtime.block_on(import_worker_fut).unwrap(); }) diff --git a/tests/search/document_search.rs b/tests/search/document_search.rs index e696c4369..249cea130 100644 --- a/tests/search/document_search.rs +++ b/tests/search/document_search.rs @@ -131,7 +131,12 @@ The Five Dysfunctions of a Team by Patrick Lencioni The Five Dysfunctions of a T .unwrap() .score; - assert!(score > 0.9, "score: {}, input:{}", score, answer); + assert!( + score > 0.8, + "expected: 0.8, but got score: {}, input:{}", + score, + answer + ); } }