diff --git a/Cargo.lock b/Cargo.lock index 60891fc17..8c58098d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,6 +806,7 @@ dependencies = [ "aws-config", "aws-sdk-s3", "axum 0.7.5", + "base64 0.22.1", "bytes", "collab", "collab-database", @@ -818,6 +819,7 @@ dependencies = [ "futures", "infra", "mailer", + "md5", "mime_guess", "redis 0.25.4", "secrecy", diff --git a/services/appflowy-worker/Cargo.toml b/services/appflowy-worker/Cargo.toml index eabf7b565..80a076f26 100644 --- a/services/appflowy-worker/Cargo.toml +++ b/services/appflowy-worker/Cargo.toml @@ -46,5 +46,7 @@ mime_guess = "2.0" bytes.workspace = true uuid.workspace = true mailer.workspace = true +md5.workspace = true +base64.workspace = true diff --git a/services/appflowy-worker/src/import_worker/worker.rs b/services/appflowy-worker/src/import_worker/worker.rs index 77c9003d8..c67f92e11 100644 --- a/services/appflowy-worker/src/import_worker/worker.rs +++ b/services/appflowy-worker/src/import_worker/worker.rs @@ -397,7 +397,14 @@ async fn download_and_unzip_file( .map_err(|err| ImportError::Internal(err.into()))?; let buffer_size = buffer_size_from_content_length(content_length); - let zip_reader = get_zip_reader(storage_dir, stream, buffer_size, streaming).await?; + let zip_reader = get_zip_reader( + storage_dir, + stream, + buffer_size, + streaming, + &import_task.md5_base64, + ) + .await?; let unique_file_name = Uuid::new_v4().to_string(); let output_file_path = storage_dir.join(unique_file_name); fs::create_dir_all(&output_file_path) @@ -431,6 +438,7 @@ async fn get_zip_reader( stream: Box, buffer_size: usize, streaming: bool, + file_md5_base64: &Option, ) -> Result { let zip_reader = if streaming { // Occasionally, we encounter the error 'unable to locate the end of central directory record' @@ -444,7 +452,7 @@ async fn get_zip_reader( file: None, } } else { - let file = download_file(storage_dir, stream).await?; + let file = download_file(storage_dir, stream, file_md5_base64).await?; let handle = fs::File::open(&file) .await .map_err(|err| ImportError::Internal(err.into()))?; @@ -996,6 +1004,8 @@ pub struct NotionImportTask { pub workspace_name: String, pub s3_key: String, pub host: String, + #[serde(default)] + pub md5_base64: Option, } impl Display for NotionImportTask { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/services/appflowy-worker/src/s3_client.rs b/services/appflowy-worker/src/s3_client.rs index 3ec8ab742..0cd3d7f39 100644 --- a/services/appflowy-worker/src/s3_client.rs +++ b/services/appflowy-worker/src/s3_client.rs @@ -6,6 +6,8 @@ use anyhow::Result; use aws_sdk_s3::operation::get_object::GetObjectError; use aws_sdk_s3::primitives::ByteStream; use axum::async_trait; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; use futures::AsyncReadExt; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -170,16 +172,19 @@ impl Drop for AutoRemoveDownloadedFile { pub async fn download_file( storage_dir: &Path, stream: Box, + expected_md5_base64: &Option, ) -> Result { let zip_file_path = storage_dir.join(format!("{}.zip", Uuid::new_v4())); - write_stream_to_file(&zip_file_path, stream).await?; + write_stream_to_file(&zip_file_path, expected_md5_base64, stream).await?; Ok(AutoRemoveDownloadedFile(zip_file_path)) } pub async fn write_stream_to_file( file_path: &PathBuf, + expected_md5_base64: &Option, mut stream: Box, ) -> Result<(), anyhow::Error> { + let mut context = md5::Context::new(); let mut file = File::create(file_path).await?; let mut buffer = vec![0u8; 1_048_576]; loop { @@ -187,9 +192,22 @@ pub async fn write_stream_to_file( if bytes_read == 0 { break; } + context.consume(&buffer[..bytes_read]); file.write_all(&buffer[..bytes_read]).await?; } - file.flush().await?; + let digest = context.compute(); + let md5_base64 = STANDARD.encode(digest.as_ref()); + if let Some(expected_md5) = expected_md5_base64 { + if md5_base64 != *expected_md5 { + error!( + "[Import]: MD5 mismatch, expected: {}, current: {}", + expected_md5, md5_base64 + ); + return Err(anyhow!("MD5 mismatch")); + } + } + + file.flush().await?; Ok(()) } diff --git a/src/api/data_import.rs b/src/api/data_import.rs index f05fd35d7..8485617cf 100644 --- a/src/api/data_import.rs +++ b/src/api/data_import.rs @@ -14,6 +14,7 @@ use base64::Engine; use database::user::select_name_and_email_from_uuid; use database::workspace::select_import_task; use futures_util::StreamExt; +use serde_json::json; use shared_entity::dto::import_dto::{ImportTaskDetail, ImportTaskStatus, UserImportTask}; use shared_entity::response::{AppResponse, JsonAppResponse}; use std::env::temp_dir; @@ -76,7 +77,7 @@ async fn import_data_handler( .and_then(|s| s.parse::().ok()) .unwrap_or(0); - let md5 = req + let md5_base64 = req .headers() .get("X-Content-MD5") .and_then(|h| h.to_str().ok()) @@ -88,19 +89,19 @@ async fn import_data_handler( trace!( "[Import] content length: {}, content md5: {}", content_length, - md5 + md5_base64 ); - if file.md5_base64 != md5 { + if file.md5_base64 != md5_base64 { trace!( "Import file fail. The Content-MD5:{} doesn't match file md5:{}", - md5, + md5_base64, file.md5_base64 ); return Err( AppError::InvalidRequest(format!( "Content-MD5:{} doesn't match file md5:{}", - md5, file.md5_base64 + md5_base64, file.md5_base64 )) .into(), ); @@ -145,14 +146,29 @@ async fn import_data_handler( .put_blob_as_content_type(&workspace_id, stream, "application/zip") .await?; + // This task will be deserialized into ImportTask + let task_id = Uuid::new_v4(); + let task = json!({ + "notion": { + "uid": uid, + "user_name": user_name, + "user_email": user_email, + "task_id": task_id.to_string(), + "workspace_id": workspace_id, + "s3_key": workspace_id, + "host": host, + "workspace_name": &file.name, + "md5_base64": md5_base64, + } + }); + create_upload_task( uid, - &user_name, - &user_email, + task_id, + task, + &host, &workspace_id, - &file.name, file.size, - &host, &state.redis_connection_manager, &state.pg_pool, ) diff --git a/src/biz/workspace/ops.rs b/src/biz/workspace/ops.rs index 782b51fb1..42a5acaa7 100644 --- a/src/biz/workspace/ops.rs +++ b/src/biz/workspace/ops.rs @@ -678,17 +678,14 @@ async fn check_if_user_is_allowed_to_delete_comment( #[allow(clippy::too_many_arguments)] pub async fn create_upload_task( uid: i64, - user_name: &str, - user_email: &str, + task_id: Uuid, + task: serde_json::Value, + host: &str, workspace_id: &str, - workspace_name: &str, file_size: usize, - host: &str, redis_client: &RedisConnectionManager, pg_pool: &PgPool, ) -> Result<(), AppError> { - let task_id = Uuid::new_v4(); - // Insert the task into the database insert_import_task( task_id, @@ -700,19 +697,6 @@ pub async fn create_upload_task( ) .await?; - // This task will be deserialized into ImportTask - let task = json!({ - "notion": { - "uid": uid, - "user_name": user_name, - "user_email": user_email, - "task_id": task_id, - "workspace_id": workspace_id, - "s3_key": workspace_id, - "host": host, - "workspace_name": workspace_name, - } - }); let _: () = redis_client .clone() .xadd("import_task_stream", "*", &[("task", task.to_string())])