diff --git a/README.md b/README.md index 1a23b5a..16e1854 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,7 @@ JSON Body的格式如下: "delay": 0, # 延迟执行的秒数,0为即时触发 "name": "命令名", # GET|POST的时候是请求地址,EXEC为命令名,MAIL_TO为邮件名 "params": "参数", # GET时为QueryString,POST时是JSON字符串,EXEC时为空格隔开的参数,MAIL_TO的时候为专门定义的JSON字符串 - "cc": 3, # 支持并行执行的数量,0为不限制 - "cf": "test" # 并行标识,相同标识的处于相同的并行队列 + "cc": "1 3", # 空格隔开的指定线程编号,这里指定两个线程,表示最多可以并行执行2个任务,留空表示不限制 } MAIL_TO 的params的JSON结构 diff --git a/src/main.rs b/src/main.rs index 26a5636..e053f2b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,10 +24,9 @@ use dirs; #[derive(Clone)] pub struct AppState { - /// 配置 pub config_path: String, pub redis_client: redis::Client, - pub queue: Arc>>, + pub queue: QueueGroup, pub config: AppConfig } @@ -38,8 +37,32 @@ struct WebTask { delay: i32, name: String, params: String, - cf: String, - cc: i32 + cc: String +} + +impl WebTask { + fn gen_task(&mut self, size: i32) -> Task { + let now: DateTime = Utc::now(); + let now_ts = now.timestamp(); + let task_key = format!("task-{}", self.id.clone()); + let cc = if self.cc.is_empty() { vec![] } else { + self.cc.split_whitespace().collect::>().iter() + .map(|s| if let Ok(n) = s.parse::() {n} else {0} ) + .filter(|s| *s < size).collect::>() + }; + Task{ + id:task_key.clone(), + tk_tp:self.method.clone(), + delay: self.delay as i64, + name: self.name.clone(), + params: self.params.clone(), + post_time: now_ts, + exec_time: 0, + retry: 0, + cc, + error: "".to_string() + } + } } struct LastModifyHolder { @@ -48,6 +71,90 @@ struct LastModifyHolder { avaliables: Vec } +#[derive(Clone)] +struct QueueGroup { + size: i32, + queues: Vec>>> +} + +impl QueueGroup { + fn init_by_number(size: i32) -> QueueGroup { + let mut ins = QueueGroup { size, queues: vec![]}; + for _i in 0..size { + let q = Arc::new(Mutex::new(VecDeque::new())); + ins.queues.push(q); + } + ins + } + fn push_to(&mut self, chn: i32, msg: &String) { + let q: &Arc>> = self.queues.get(chn as usize).unwrap(); + let mut queue = q.lock().unwrap(); + queue.push_back(msg.to_string()); + } + fn shorted_chn(&mut self) -> i32 { + let rg = 0..self.size; + self.shorted_chn_in(&rg.collect()) + } + fn get_queue_len(&mut self, chn: i32) -> usize { + if let Some(q) = self.queues.get(chn as usize) { + println!("获取到{}的信箱", chn); + if let Ok(queue) = q.lock() { + return queue.len(); + } + } + 0usize + } + fn shorted_chn_in(&mut self, selected: &Vec) -> i32 { + let mut counters: Vec<(i32, usize)> = vec![]; + for chn in selected { + let queue_len = self.get_queue_len(*chn); + counters.push((*chn, queue_len)); + } + counters.sort_by(|f, s| f.1.cmp(&s.1)); + counters[0].0 + } + + fn wait_for(&mut self, chn: i32) -> String { + loop { + let task_key = { + if let Some(q) = self.queues.get(chn as usize){ + if let Ok(mut queue) = q.lock(){ + if let Some(tk) = queue.pop_front(){ + tk + }else{ + "".to_string() + } + }else{ + "".to_string() + } + }else{ + "".to_string() + } + }; + if task_key.is_empty() { + thread::sleep(std::time::Duration::from_secs(1)); + continue; + } + return task_key; + } + } + fn dispatch(&mut self, msg: &String) { + let chn = self.shorted_chn(); + self.push_to(chn, msg); + } + fn dispatch_in(&mut self, msg: &String, channels: &Vec) { + let chn = self.shorted_chn_in(channels); + self.push_to(chn, msg); + } + fn dispatch_task(&mut self, task: &Task) { + if task.cc.is_empty() { + self.dispatch(&task.id); + }else{ + self.dispatch_in(&task.id, &task.cc); + } + } +} + const TASK_WRONG: &'static str = "task||wrong"; const TASK_WORKING: &'static str = "task||working"; @@ -143,17 +250,16 @@ fn config_from_matches(matches: &ArgMatches) -> AppConfig { AppConfig{ smtp_name: smtp_name.to_string(), smtp_pwd:smtp_pwd.to_string(), - retry_interval:retry_interval, - max_retry:max_retry, + retry_interval, + max_retry, smtp_server: smtp_server.to_string(), - smtp_port: smtp_port, - starttls: starttls + smtp_port, + starttls } } #[tokio::main] async fn main() { - // 读取配置 let app = App::new("Poulpe Task Management"); let matches = matches_config(app); @@ -162,47 +268,33 @@ async fn main() { let redis = matches.value_of("redis").unwrap(); let cron_path = matches.value_of("cron").unwrap(); let dead_base = matches.get_one::("dead").map(|s| s.to_string()).unwrap(); - // 将拥有自己数据的 String 转换为 Arc 类型,以便在其他线程中共享所有权 let dead_base_arc = Arc::new(dead_base); - // 在其他线程中使用 Arc 类型的字符串 let thread_arc_dead_base = Arc::clone(&dead_base_arc); let appconfig = config_from_matches(&matches); let workers = matches.value_of("workers").unwrap().parse().unwrap(); - let queue: Arc>> = Arc::new(Mutex::new(VecDeque::new())); + let mut queue_group = QueueGroup::init_by_number(workers); + let client = redis::Client::open(redis).unwrap(); for thread_id in 0..workers { - let queue_ref = queue.clone(); + let mut group = queue_group.clone(); let mut redis_connection = client.get_connection().unwrap(); let appconfig = appconfig.clone(); thread::spawn(move || { println!("worker:{} started", thread_id); loop { - let task_id = { - let mut queue = queue_ref.lock().unwrap(); - if queue.is_empty() { - "".to_string() - }else{ - queue.pop_front().unwrap() - } - }; - if task_id == "" { - thread::sleep(std::time::Duration::from_secs(1)); - continue; - } + let task_id = group.wait_for(thread_id); println!("worker:{} 捕获任务", task_id.clone()); if let Ok(task_str) = redis_connection.get::(task_id.clone()) { println!("线程[{}]获取到任务:{} 的数据:{}", thread_id, task_id.clone(), task_str); if let Ok(mut task) = serde_json::from_str::(task_str.as_str()){ - // 执行任务 let if_err = match task.execute(&appconfig) { Ok(())=>{"".to_string()}, Err(err_str)=>{err_str} }; if !if_err.is_empty() { - // 错误处理逻辑 task.error = if_err.to_string(); println!("执行错误:{}", task.error); if let Ok(save_payload) = serde_json::to_string(&task) { @@ -212,15 +304,7 @@ async fn main() { }else{ redis_connection.del::<&str, ()>(task.id.as_str()).expect("redis del error"); } - //从正在执行队列中去掉 redis_connection.srem::<&str, String, ()>(TASK_WORKING, task.id.clone()).expect("redis error"); - if !task.cf.is_empty() { - let cc_flag = format!("cc:{}", task.cf); - if let Ok(cc) = redis_connection.get::(cc_flag.clone()) { - let new_cc = if let Ok(int_cc) = cc.parse::() { if int_cc>0 { int_cc - 1}else { 0 } } else { 0 }; - redis_connection.set::(cc_flag.clone(), format!("{}", new_cc)).expect("设置并行标识失败"); - } - } } }else{ println!("任务:{} 不存在", task_id.clone()); @@ -230,7 +314,7 @@ async fn main() { } let mut conn_err = client.get_async_connection().await.unwrap(); - let queue_err = queue.clone(); + let mut queue_err = queue_group.clone(); tokio::spawn(async move { println!("错误补发线程启动"); @@ -246,7 +330,6 @@ async fn main() { conn_err.srem::<&str, String, ()>(TASK_WRONG, tk.id.clone()).await.expect("从出错队列删除错误"); if tk.retry > appconfig.max_retry - 1 { conn_err.del::(tk.id.clone()).await.expect("删除错误"); - // 超过重试总次数,从错误队列删除,写入死信箱 if let Ok(save_str) = serde_json::to_string_pretty(&tk) { write_dead_pool(thread_arc_dead_base.to_string(), &tk.id, &save_str).await; } @@ -255,8 +338,7 @@ async fn main() { if let Ok(save_str) = serde_json::to_string(&tk){ conn_err.set::(tk.id.clone(), save_str).await.expect("回写重试次数出错"); conn_err.sadd::(TASK_WORKING.to_string(), tk.id.clone()).await.expect("set list error"); - let mut queue4 = queue_err.lock().unwrap(); - queue4.push_back(tk.id.clone()); + queue_err.dispatch_task(&tk); } } } @@ -269,7 +351,7 @@ async fn main() { let str_cron_path = get_abs_path(cron_path.to_string()); let mut conn = client.get_async_connection().await.unwrap(); - let queue2 = queue.clone(); + let mut cron_queue = queue_group.clone(); tokio::spawn(async move { let mut holder = LastModifyHolder{last_modify:"".to_string(), tasks: vec![], avaliables: vec![]}; let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1)); @@ -312,24 +394,35 @@ async fn main() { } for task_key in will_exec{ conn.sadd::(TASK_WORKING.to_string(), task_key.clone()).await.expect("set list error"); - let mut queue3 = queue2.lock().unwrap(); - queue3.push_back(task_key.clone()); + cron_queue.dispatch(&task_key); } } }); + if let Ok(mut main_conn) = client.get_async_connection().await { + if let Ok(working_ids) = main_conn.smembers::<&str, Vec>(TASK_WORKING).await { + for tk_id in working_ids { + if let Ok(tk_str) = main_conn.get::(tk_id.clone()).await { + if let Ok(tk) = serde_json::from_str::(tk_str.as_str()) { + queue_group.dispatch_task(&tk); + } + }else{ + main_conn.srem::<&str, String, ()>(TASK_WORKING, tk_id.clone()).await.expect("删除不存在的key"); + } + } + } + } + let app_state = AppState{ config_path: cron_path.to_string(), redis_client: client.clone(), - queue: queue.clone(), + queue: queue_group.clone(), config: appconfig.clone() }; - // 创建HTTP路由 let app = Router::new() .route("/task_in_queue", post(handler)) .with_state(app_state); - // 启动HTTP服务器 println!("server will start at 0.0.0.0:{}", port); let serv = axum::Server::bind(& SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), int_port)) .serve(app.into_make_service()) @@ -344,50 +437,26 @@ async fn main() { } } -pub async fn handler(State(state): State, +pub async fn handler(State(mut state): State, Json(payload): Json ) -> impl IntoResponse { let now: DateTime = Utc::now(); let now_ts = now.timestamp(); + println!("In Queue in {}", now_ts); let mut conn = state.redis_client.get_async_connection().await.unwrap(); - if let Ok(web_task) = serde_json::from_value::(payload.clone()) { + if let Ok(mut web_task) = serde_json::from_value::(payload.clone()) { if state.config.smtp_name.is_empty() && web_task.method.to_lowercase() == "mail_to" { return Json(serde_json::json!({"result":"Fail", "reason": "未配置SMTP服务器"})); } - // concurrency 控制 只对即时任务有效 - if web_task.cc > 0 && web_task.delay == 0 { - let cc_flag = format!("cc:{}", web_task.cf); - let cc_now = if let Ok(v) = conn.get::(cc_flag.clone()).await { v.parse().unwrap() } else { 0 }; - println!("当前流控并发:{}", cc_now); - if cc_now >= web_task.cc { - return Json(serde_json::json!({"result":"Over", "reason": "超标限流"})); - }else{ - println!("流控+1"); - conn.set::(cc_flag.clone(), format!("{}", cc_now + 1)).await.expect("增量错误"); - } - } - - let task_key = format!("task-{}", web_task.id.clone()); - let task = Task{ - id:task_key.clone(), - tk_tp:web_task.method.clone(), - delay: web_task.delay as i64, - name: web_task.name.clone(), - params: web_task.params.clone(), - post_time: now_ts, - exec_time: 0, - retry: 0, - cf: web_task.cf, - error: "".to_string() - }; + let task = web_task.gen_task(state.queue.size); let redis_payload = serde_json::to_string(&task).unwrap(); - conn.set::(task_key.clone(), redis_payload).await.expect("set error"); + conn.set::(task.id.clone(), redis_payload).await.expect("set error"); if web_task.delay == 0 { - conn.sadd::(TASK_WORKING.to_string(), task_key.clone()).await.expect("set list error"); - let mut queue = state.queue.lock().unwrap(); - queue.push_back(task_key); + conn.sadd::(TASK_WORKING.to_string(), task.id.clone()).await.expect("set list error"); + println!("开始分配Worker线程"); + state.queue.dispatch_task(&task); } else { - let delay_key = format!("{} {}", task_key, (now_ts + web_task.delay as i64)); + let delay_key = format!("{} {}", task.id, now_ts + web_task.delay as i64); conn.sadd::(TASK_DELAY.to_string(), delay_key.clone()).await.expect("set list error"); } Json(serde_json::json!({"result":"OK", "reason": ""})) @@ -429,11 +498,11 @@ async fn crontab_file_changed(file_path: PathBuf, last_dt: &str) -> String { fn get_abs_path(address: String) -> String{ let absolute_path = if Path::new(&address).is_relative() { if address.starts_with("~"){ - if let Some(mut path) = dirs::home_dir() { - path.push( &address[2..].to_string()); - return path.to_str().unwrap().to_string(); - }else{ - return address; + return if let Some(mut path) = dirs::home_dir() { + path.push(&address[2..].to_string()); + path.to_str().unwrap().to_string() + } else { + address } }else{ let current_dir = env::current_dir().expect("Failed to get current directory"); @@ -460,7 +529,6 @@ fn get_tasks_avaliable(tasks: &Vec, avaliable_tasks: &mut Vec) { let next_ts = next.timestamp() - 1; let now_ts = now.timestamp(); if next_ts == now_ts { - // execute command let command_id = format!("ST:{}-{}", sequence, now_ts); let command_seqs = command.split_whitespace().collect::>(); let name = command_seqs[0]; @@ -474,13 +542,10 @@ fn get_tasks_avaliable(tasks: &Vec, avaliable_tasks: &mut Vec) { post_time:now_ts, exec_time:0, retry:0, - cf: "".to_string(), + cc: vec![], error: "".to_string() }; avaliable_tasks.insert(0, task); - //println!("cmd:{} methid:{} with:{}", name, method, params); - }else{ - //println!("cron:{} next:{}, current:{}", time_config, next_ts, now_ts); } } } diff --git a/src/runner.rs b/src/runner.rs index 697e03a..b17b278 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -34,7 +34,7 @@ pub struct Task { pub post_time: i64, pub exec_time: i64, pub retry: i32, - pub cf: String, + pub cc: Vec, pub error: String }