chore: check file md5 before import (#895)
appflowy authored Oct 17, 2024
1 parent 7d6d1fd commit 3623d9f
Showing 6 changed files with 64 additions and 32 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions services/appflowy-worker/Cargo.toml
@@ -46,5 +46,7 @@ mime_guess = "2.0"
bytes.workspace = true
uuid.workspace = true
mailer.workspace = true
md5.workspace = true
base64.workspace = true
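
The two new dependencies are exactly the pair the worker uses below: `md5` to hash the downloaded archive and `base64` to encode the digest for comparison against the client-supplied checksum. A minimal sketch of that pairing (not part of the commit; the helper name is illustrative):

```rust
// Sketch only: the digest-then-encode pattern these two dependencies enable,
// mirroring how the worker verifies downloads later in this diff.
use base64::engine::general_purpose::STANDARD;
use base64::Engine;

fn md5_base64(bytes: &[u8]) -> String {
    let mut context = md5::Context::new();
    context.consume(bytes);
    STANDARD.encode(context.compute().as_ref())
}

fn main() {
    // The base64-encoded MD5 of empty input is a well-known constant.
    assert_eq!(md5_base64(b""), "1B2M2Y8AsgTpgAmY7PhCfg==");
    println!("{}", md5_base64(b"hello world"));
}
```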


14 changes: 12 additions & 2 deletions services/appflowy-worker/src/import_worker/worker.rs
@@ -397,7 +397,14 @@ async fn download_and_unzip_file(
.map_err(|err| ImportError::Internal(err.into()))?;
let buffer_size = buffer_size_from_content_length(content_length);

let zip_reader = get_zip_reader(storage_dir, stream, buffer_size, streaming).await?;
let zip_reader = get_zip_reader(
storage_dir,
stream,
buffer_size,
streaming,
&import_task.md5_base64,
)
.await?;
let unique_file_name = Uuid::new_v4().to_string();
let output_file_path = storage_dir.join(unique_file_name);
fs::create_dir_all(&output_file_path)
@@ -431,6 +438,7 @@ async fn get_zip_reader(
stream: Box<dyn AsyncBufRead + Unpin + Send>,
buffer_size: usize,
streaming: bool,
file_md5_base64: &Option<String>,
) -> Result<ZipReader, ImportError> {
let zip_reader = if streaming {
// Occasionally, we encounter the error 'unable to locate the end of central directory record'
@@ -444,7 +452,7 @@
file: None,
}
} else {
let file = download_file(storage_dir, stream).await?;
let file = download_file(storage_dir, stream, file_md5_base64).await?;
let handle = fs::File::open(&file)
.await
.map_err(|err| ImportError::Internal(err.into()))?;
@@ -996,6 +1004,8 @@ pub struct NotionImportTask {
pub workspace_name: String,
pub s3_key: String,
pub host: String,
#[serde(default)]
pub md5_base64: Option<String>,
}
impl Display for NotionImportTask {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
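
A note on `#[serde(default)]`: these tasks round-trip through Redis as JSON, so payloads queued before this commit carry no `md5_base64` key; the attribute presumably makes such payloads deserialize to `None` instead of failing. A trimmed sketch of that behavior (struct reduced to two fields for illustration; assumes `serde` with the derive feature and `serde_json`):

```rust
// Sketch: an Option field marked #[serde(default)] tolerates old payloads
// that were serialized before the field existed.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct NotionImportTask {
    task_id: String,
    #[serde(default)]
    md5_base64: Option<String>,
}

fn main() -> Result<(), serde_json::Error> {
    // Old payload: no md5_base64 key; deserializes with md5_base64 == None.
    let old: NotionImportTask = serde_json::from_str(r#"{"task_id":"t1"}"#)?;
    assert!(old.md5_base64.is_none());

    // New payload produced by the updated handler.
    let new: NotionImportTask =
        serde_json::from_str(r#"{"task_id":"t2","md5_base64":"abc=="}"#)?;
    assert_eq!(new.md5_base64.as_deref(), Some("abc=="));
    Ok(())
}
```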
22 changes: 20 additions & 2 deletions services/appflowy-worker/src/s3_client.rs
@@ -6,6 +6,8 @@ use anyhow::Result;
use aws_sdk_s3::operation::get_object::GetObjectError;
use aws_sdk_s3::primitives::ByteStream;
use axum::async_trait;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use futures::AsyncReadExt;
use std::ops::Deref;
use std::path::{Path, PathBuf};
@@ -170,26 +172,42 @@ impl Drop for AutoRemoveDownloadedFile {
pub async fn download_file(
storage_dir: &Path,
stream: Box<dyn futures::AsyncBufRead + Unpin + Send>,
expected_md5_base64: &Option<String>,
) -> Result<AutoRemoveDownloadedFile, anyhow::Error> {
let zip_file_path = storage_dir.join(format!("{}.zip", Uuid::new_v4()));
write_stream_to_file(&zip_file_path, stream).await?;
write_stream_to_file(&zip_file_path, expected_md5_base64, stream).await?;
Ok(AutoRemoveDownloadedFile(zip_file_path))
}

pub async fn write_stream_to_file(
file_path: &PathBuf,
expected_md5_base64: &Option<String>,
mut stream: Box<dyn futures::AsyncBufRead + Unpin + Send>,
) -> Result<(), anyhow::Error> {
let mut context = md5::Context::new();
let mut file = File::create(file_path).await?;
let mut buffer = vec![0u8; 1_048_576];
loop {
let bytes_read = stream.read(&mut buffer).await?;
if bytes_read == 0 {
break;
}
context.consume(&buffer[..bytes_read]);
file.write_all(&buffer[..bytes_read]).await?;
}
file.flush().await?;

let digest = context.compute();
let md5_base64 = STANDARD.encode(digest.as_ref());
if let Some(expected_md5) = expected_md5_base64 {
if md5_base64 != *expected_md5 {
error!(
"[Import]: MD5 mismatch, expected: {}, current: {}",
expected_md5, md5_base64
);
return Err(anyhow!("MD5 mismatch"));
}
}

file.flush().await?;
Ok(())
}
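
Because the digest is fed chunk-by-chunk inside the same loop that writes the file, verification adds no second pass over the archive. A standalone sketch of that equivalence (assuming the `md5` crate's `Context`/`compute` API used above):

```rust
// Sketch: incremental hashing over buffered chunks yields the same digest
// as hashing the whole payload at once, so the download loop can verify
// without re-reading the file from disk.
fn main() {
    let data = vec![7u8; 3_000_000]; // stand-in for the streamed zip bytes

    // Incremental, as write_stream_to_file does with its 1 MiB buffer.
    let mut context = md5::Context::new();
    for chunk in data.chunks(1_048_576) {
        context.consume(chunk);
    }
    let incremental = context.compute();

    // One-shot reference digest.
    let one_shot = md5::compute(&data);
    assert_eq!(incremental, one_shot);
    println!("digest: {:x}", incremental);
}
```
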
34 changes: 25 additions & 9 deletions src/api/data_import.rs
@@ -14,6 +14,7 @@ use base64::Engine;
use database::user::select_name_and_email_from_uuid;
use database::workspace::select_import_task;
use futures_util::StreamExt;
use serde_json::json;
use shared_entity::dto::import_dto::{ImportTaskDetail, ImportTaskStatus, UserImportTask};
use shared_entity::response::{AppResponse, JsonAppResponse};
use std::env::temp_dir;
@@ -76,7 +77,7 @@ async fn import_data_handler(
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(0);

let md5 = req
let md5_base64 = req
.headers()
.get("X-Content-MD5")
.and_then(|h| h.to_str().ok())
@@ -88,19 +89,19 @@
trace!(
"[Import] content length: {}, content md5: {}",
content_length,
md5
md5_base64
);
if file.md5_base64 != md5 {
if file.md5_base64 != md5_base64 {
trace!(
"Import file fail. The Content-MD5:{} doesn't match file md5:{}",
md5,
md5_base64,
file.md5_base64
);

return Err(
AppError::InvalidRequest(format!(
"Content-MD5:{} doesn't match file md5:{}",
md5, file.md5_base64
md5_base64, file.md5_base64
))
.into(),
);
@@ -145,14 +146,29 @@
.put_blob_as_content_type(&workspace_id, stream, "application/zip")
.await?;

// This task will be deserialized into ImportTask
let task_id = Uuid::new_v4();
let task = json!({
"notion": {
"uid": uid,
"user_name": user_name,
"user_email": user_email,
"task_id": task_id.to_string(),
"workspace_id": workspace_id,
"s3_key": workspace_id,
"host": host,
"workspace_name": &file.name,
"md5_base64": md5_base64,
}
});

create_upload_task(
uid,
&user_name,
&user_email,
task_id,
task,
&host,
&workspace_id,
&file.name,
file.size,
&host,
&state.redis_connection_manager,
&state.pg_pool,
)
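
The handler only checks the multipart part's checksum against the `X-Content-MD5` header; producing that header is the client's job. A hypothetical client-side counterpart (function name and payload are illustrative, not from this commit):

```rust
// Sketch: compute the base64 MD5 of the zip before upload so the server-side
// comparison in import_data_handler can succeed.
use base64::engine::general_purpose::STANDARD;
use base64::Engine;

fn content_md5_header(zip_bytes: &[u8]) -> (&'static str, String) {
    // Header name matches what the handler reads above.
    ("X-Content-MD5", STANDARD.encode(md5::compute(zip_bytes).as_ref()))
}

fn main() {
    let (name, value) = content_md5_header(b"fake zip payload");
    println!("{name}: {value}");
}
```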
22 changes: 3 additions & 19 deletions src/biz/workspace/ops.rs
@@ -678,17 +678,14 @@ async fn check_if_user_is_allowed_to_delete_comment(
#[allow(clippy::too_many_arguments)]
pub async fn create_upload_task(
uid: i64,
user_name: &str,
user_email: &str,
task_id: Uuid,
task: serde_json::Value,
host: &str,
workspace_id: &str,
workspace_name: &str,
file_size: usize,
host: &str,
redis_client: &RedisConnectionManager,
pg_pool: &PgPool,
) -> Result<(), AppError> {
let task_id = Uuid::new_v4();

// Insert the task into the database
insert_import_task(
task_id,
@@ -700,19 +697,6 @@
)
.await?;

// This task will be deserialized into ImportTask
let task = json!({
"notion": {
"uid": uid,
"user_name": user_name,
"user_email": user_email,
"task_id": task_id,
"workspace_id": workspace_id,
"s3_key": workspace_id,
"host": host,
"workspace_name": workspace_name,
}
});
let _: () = redis_client
.clone()
.xadd("import_task_stream", "*", &[("task", task.to_string())])
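
With the JSON construction moved to the caller, `create_upload_task` is reduced to persisting the row and enqueuing the prebuilt value. A sketch of that enqueue step (assumes `tokio` plus the `redis` crate's `ConnectionManager`, mirroring the `xadd` call above):

```rust
// Sketch: push a prebuilt task JSON onto the Redis stream the worker reads.
use redis::AsyncCommands;
use serde_json::json;

async fn enqueue_task(
    mut conn: redis::aio::ConnectionManager,
    task: serde_json::Value,
) -> redis::RedisResult<String> {
    // "*" lets Redis assign the stream entry id.
    conn.xadd("import_task_stream", "*", &[("task", task.to_string())])
        .await
}

#[tokio::main]
async fn main() -> redis::RedisResult<()> {
    let client = redis::Client::open("redis://127.0.0.1/")?;
    let conn = redis::aio::ConnectionManager::new(client).await?;
    // Field names follow the handler's json! payload; values are dummies.
    let task = json!({ "notion": { "task_id": "demo", "md5_base64": null } });
    let id = enqueue_task(conn, task).await?;
    println!("queued as {id}");
    Ok(())
}
```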
