Skip to content

Commit

Permalink
feat: add main executable for POC
Browse files Browse the repository at this point in the history
  • Loading branch information
Zvicii committed Dec 27, 2023
1 parent f7e0ffc commit dcd1337
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 69 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ Cargo.lock
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
.vscode
exports
.DS_Store
11 changes: 10 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,19 @@ serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
tokio = { version = "1.35.0", features = ["full"] }
rustls-pemfile = "1.0.3"
log = "0.4.20"
flexi_logger = "0.27.3"
sysinfo = "0.30.1"
clap = { version = "4.4.11", features = ["derive"] }

[dependencies.hyper-rustls]
version = "0.24.2"
features = ["http2"]

[lib]
crate-type = ["cdylib"]
path = "src/lib.rs"
crate-type = ["cdylib", "rlib"]

[[bin]]
name = "ne-s3"
path = "src/main.rs"
15 changes: 14 additions & 1 deletion src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use aws_sdk_s3::{config::Credentials, Client};
use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
use bytes::Bytes;
use http_body::{Body, SizeHint};
use log::info;
use rustls::{Certificate, RootCertStore};
use rustls_pemfile::certs;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -61,13 +62,21 @@ struct ProgressTracker {
bytes_written: Arc<Mutex<u64>>,
content_length: u64,
progress_callback: ProgressCallback,
last_callback_time: std::time::Instant,
}
impl ProgressTracker {
fn track(&mut self, len: u64) {
let mut bytes_written = self.bytes_written.lock().unwrap();
*bytes_written += len;
let progress = *bytes_written as f64 / self.content_length as f64 * 100.0;
let progress_callback = self.progress_callback.lock().unwrap();
if std::time::Instant::now() - self.last_callback_time
< std::time::Duration::from_millis(500)
&& progress < 100.0
{
return;
}
self.last_callback_time = std::time::Instant::now();
progress_callback(progress);
}
}
Expand Down Expand Up @@ -98,6 +107,7 @@ where
bytes_written,
content_length,
progress_callback,
last_callback_time: std::time::Instant::now(),
},
}
}
Expand Down Expand Up @@ -184,7 +194,10 @@ pub fn create_s3_client(params: &S3Params) -> Result<Client, Box<dyn std::error:
// If tries is 1, there are no retries.
.retry_config(RetryConfig::standard().with_max_attempts(params.tries.unwrap_or(1)));
if params.ca_certs_path.is_some() {
println!("use custom ca certs, path: {}", params.ca_certs_path.as_ref().unwrap());
info!(
"use custom ca certs, path: {}",
params.ca_certs_path.as_ref().unwrap()
);
let root_store = load_ca_cert(params.ca_certs_path.as_ref().unwrap())?;
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
Expand Down
141 changes: 82 additions & 59 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,74 @@
//! simple s3 client with C interfaces
use std::sync::Mutex;
use flexi_logger::{with_thread, Age, Cleanup, Criterion, FileSpec, Logger, Naming, WriteMode};
use log::{error, info};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{sync::Mutex, path::Path};
use sysinfo::System;
pub use basic::S3Params;
mod basic;
mod upload;
mod download;
mod upload;

#[derive(Debug, Serialize, Deserialize)]
struct InitParams {
log_path: Option<String>,
}

static mut RUNTIME: Mutex<Option<tokio::runtime::Runtime>> = Mutex::new(None);
/// init tokio runtime
/// run this function before any other functions
pub fn init() {
/// # Arguments
/// - `params` - The params of init, json format
/// - `log_path` - The log path, use stdout if not set
pub fn init(params_str: String) {
let mut runtime = unsafe { RUNTIME.lock().unwrap() };
if !runtime.is_none() {
return;
}
let params = match serde_json::from_str::<InitParams>(&params_str) {
Ok(params) => params,
Err(err) => {
panic!("parse init params failed: {}", err);
}
};
let log_path = params.log_path.as_ref();
if log_path.is_some_and(|path| Path::new(path).exists()) {
let log_path = log_path.unwrap();
let _logger = Logger::try_with_str("info")
.unwrap()
.log_to_file(FileSpec::default().directory(log_path))
.write_mode(WriteMode::Direct)
.rotate(
Criterion::Age(Age::Day),
Naming::Timestamps,
Cleanup::KeepLogFiles(7),
)
.format(with_thread)
.start()
.unwrap();
} else {
let _logger = Logger::try_with_str("info")
.unwrap()
.format(with_thread)
.start()
.unwrap();
}
info!("init params: {}", params_str);
let mut system_info = System::new_all();
system_info.refresh_all();
let system_info_json = json!({
"system name": System::name(),
"system kernel version": System::kernel_version(),
"system os version": System::os_version(),
"system host name": System::host_name(),
"total memory": system_info.total_memory(),
"used memory": system_info.used_memory(),
});
info!("system info: {}", system_info_json);
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.worker_threads(8)
.build()
.unwrap();
*runtime = Some(rt);
Expand Down Expand Up @@ -53,27 +108,34 @@ pub fn upload(
result_callback: basic::ResultCallback,
progress_callback: basic::ProgressCallback,
) {
info!("upload params: {}", params);
let runtime = unsafe { RUNTIME.lock().unwrap() };
let runtime = match &*runtime {
Some(runtime) => runtime,
None => {
result_callback(false, "runtime not initialized".to_string());
error!("runtime not initialized");
return;
}
};
let params = match serde_json::from_str::<basic::S3Params>(&params) {
Ok(params) => params,
Err(err) => {
result_callback(false, format!("parse params failed: {}", err));
error!("parse params failed: {}", err);
return;
}
};
runtime.spawn(async move {
let result = upload::put_object(&params, progress_callback).await;
result_callback(
result.is_ok(),
result.err().map(|err| err.to_string()).unwrap_or_default(),
);
if result.is_ok() {
info!("upload finished");
result_callback(true, "".to_string());
} else {
let error_descrption = result.err().map(|err| err.to_string()).unwrap_or_default();
error!("upload failed: {}", error_descrption);
result_callback(false, error_descrption);
}
});
}

Expand All @@ -92,73 +154,34 @@ pub fn upload(
/// - `result_callback` - The callback function when download finished
/// - `success` - The download succeeded or not
/// - `message` - The error message if download failed
pub fn download(
params: String,
result_callback: basic::ResultCallback,
) {
pub fn download(params: String, result_callback: basic::ResultCallback) {
info!("download params: {}", params);
let runtime = unsafe { RUNTIME.lock().unwrap() };
let runtime = match &*runtime {
Some(runtime) => runtime,
None => {
error!("runtime not initialized");
result_callback(false, "runtime not initialized".to_string());
return;
}
};
let params = match serde_json::from_str::<basic::S3Params>(&params) {
Ok(params) => params,
Err(err) => {
error!("parse params failed: {}", err);
result_callback(false, format!("parse params failed: {}", err));
return;
}
};
runtime.spawn(async move {
let result = download::get_object(&params).await;
result_callback(
result.is_ok(),
result.err().map(|err| err.to_string()).unwrap_or_default(),
);
});
}

#[cfg(test)]
mod tests {
use std::{env, sync::Arc};
use super::*;

#[test]
fn test() {
init();
{
let rt = unsafe { RUNTIME.lock().unwrap() };
if let Some(rt) = &*rt {
rt.block_on(async {
let mut params = basic::S3Params {
bucket: env::var("AWS_BUCKET").unwrap(),
object: env::var("AWS_OBJECT_KEY").unwrap(),
access_key_id: env::var("AWS_ACCESS_KEY_ID").unwrap(),
secret_access_key: env::var("AWS_SECRET_ACCESS_KEY").unwrap(),
session_token: env::var("AWS_SESSION_TOKEN").unwrap(),
file_path: env::var("AWS_UPLOAD_FILE_PATH").unwrap(),
security_token: env::var("AWS_SECURITY_TOKEN").unwrap(),
region: Some(env::var("AWS_REGION").unwrap_or("ap-southeast-1".to_string())),
tries: Some(3),
endpoint: None,
ca_certs_path: env::var("AWS_CA_CERTS_PATH").ok(),
};
println!("uploading begin");
let progress_callback = |progress: f64| {
println!("put object progress: {:.2}%", progress);
};
let upload_size = upload::put_object(&params, Arc::new(Mutex::new(progress_callback))).await.unwrap();
println!("uploading finished");
params.file_path = env::var("AWS_DOWNLOAD_FILE_PATH").unwrap();
println!("downloading begin");
let download_size = download::get_object(&params).await.unwrap();
assert_eq!(download_size, upload_size);
println!("downloading finished");
});
}
if result.is_ok() {
info!("download finished");
result_callback(true, "".to_string());
} else {
let error_descrption = result.err().map(|err| err.to_string()).unwrap_or_default();
error!("download failed: {}", error_descrption);
result_callback(false, error_descrption);
}
uninit();
}
}
});
}
79 changes: 79 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use clap::Parser;
use log::info;
use ne_s3::{download, init, uninit, upload};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::{Arc, Mutex};

#[derive(Parser, Debug, Serialize, Deserialize)]
#[command(author, version, about, long_about = None)]
struct Args {
command: String,
#[arg(long)]
bucket: String,
#[arg(long)]
object: String,
#[arg(long)]
access_key_id: String,
#[arg(long)]
secret_access_key: String,
#[arg(long)]
session_token: String,
#[arg(long)]
file_path: String,
#[arg(long)]
security_token: String,
#[arg(long, default_value_t = String::new())]
ca_certs_path: String,
#[arg(long, default_value_t = String::new())]
region: String,
#[arg(long, default_value_t = 3)]
tries: u32,
#[arg(long, default_value_t = String::new())]
endpoint: String,
#[arg(long, default_value_t = String::new())]
log_path: String,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
init(
json!(
{
"log_path": args.log_path,
}
)
.to_string(),
);
match args.command.as_str() {
"upload" => {
let result_callback = |success: bool, message: String| {
info!("upload finished: {}", success);
info!("upload message: {}", message);
};
let progress_callback = |progress: f64| {
info!("put object progress: {:.2}%", progress);
};
upload(
serde_json::to_string(&args).unwrap(),
Box::new(result_callback),
Arc::new(Mutex::new(progress_callback)),
);
}
"download" => {
download(
serde_json::to_string(&args).unwrap(),
Box::new(|success: bool, message: String| {
info!("download finished: {}", success);
info!("download message: {}", message);
}),
);
}
_ => {
println!("unknown command: {}", args.command);
}
}
uninit();
Ok(())
}
Loading

0 comments on commit dcd1337

Please sign in to comment.