diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8328f51..696ca02 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -22,5 +22,7 @@ jobs: run: cargo build --verbose - name: Test run: cargo test --verbose - - name: Test example-app with third-party libs - run: cd app && cargo test --features use_dependencies --verbose + - name: Test async mode + run: cargo test --features async_mode --verbose + - name: Test with third-party libs + run: cargo test --features use_dependencies --verbose diff --git a/app/Cargo.toml b/app/Cargo.toml index e9d3f21..9325d32 100644 --- a/app/Cargo.toml +++ b/app/Cargo.toml @@ -42,3 +42,4 @@ regex = "1.9.3" # for prost_build [features] use_dependencies = [ "mime", "mockito", "prost-build" ] +async_mode = [] diff --git a/app/src/api.rs b/app/src/api.rs index 8ee818e..b7ee209 100644 --- a/app/src/api.rs +++ b/app/src/api.rs @@ -6,21 +6,43 @@ use axum::{ routing::{get, post}, Json, Router, }; -use std::collections::HashMap; -use std::sync::{Arc, RwLock}; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, +}; -use crate::middleware::log_request; -use crate::models::{CreateFeed, Feed, Status}; -use crate::scheduler_interface::ToScheduler; +use crate::{ + middleware::log_request, + models::{CreateFeed, Feed, Status}, + scheduler_interface::ToScheduler, +}; -#[derive(Clone)] -struct AppState { +struct AppState +where + T: ToScheduler + Send + Sync + 'static, +{ next_feed_id: Arc>, db: Arc>>, - scheduler_interface: Arc, + scheduler_interface: Arc, } -pub fn app(scheduler_interface: Arc) -> Router { +impl Clone for AppState +where + T: ToScheduler + Send + Sync + 'static, +{ + fn clone(&self) -> Self { + Self { + next_feed_id: self.next_feed_id.clone(), + db: self.db.clone(), + scheduler_interface: self.scheduler_interface.clone(), + } + } +} + +pub fn app(scheduler_interface: Arc) -> Router +where + T: ToScheduler + Send + Sync + 'static, +{ let state = AppState { next_feed_id: Arc::new(RwLock::new(1)), db: Arc::new(RwLock::new(HashMap::new())), @@ -42,15 +64,18 @@ async fn status_handler() -> impl IntoResponse { Json(Status::new("OK")) } -async fn post_handler( - state: State, +async fn post_handler( + state: State>, Json(CreateFeed { name, url, frequency, headers, }): Json, -) -> Result { +) -> Result +where + T: ToScheduler + Send + Sync + 'static, +{ let id = *(state .next_feed_id .read() @@ -72,15 +97,22 @@ async fn post_handler( .write() .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .insert(id, feed.clone()); - state.scheduler_interface.create(feed.clone()); + + state + .scheduler_interface + .create(feed.clone()) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok((StatusCode::CREATED, Json(feed))) } -async fn get_handler( +async fn get_handler( Path(id): Path, - state: State, -) -> Result { + state: State>, +) -> Result +where + T: ToScheduler + Send + Sync + 'static, +{ let feed = state .db .read() @@ -92,16 +124,19 @@ async fn get_handler( Ok(Json(feed)) } -async fn put_handler( +async fn put_handler( Path(id): Path, - state: State, + state: State>, Json(CreateFeed { name, url, frequency, headers, }): Json, -) -> impl IntoResponse { +) -> impl IntoResponse +where + T: ToScheduler + Send + Sync + 'static, +{ if state .db .read() @@ -125,15 +160,21 @@ async fn put_handler( .write() .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .insert(id, feed.clone()); - state.scheduler_interface.update(feed.clone()); + state + .scheduler_interface + .update(feed.clone()) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(Json(feed)) } -async fn delete_handler( +async fn delete_handler( Path(id): Path, - state: State, -) -> Result { + state: State>, +) -> Result +where + T: ToScheduler + Send + Sync + 'static, +{ let feed = state .db .read() @@ -147,12 +188,18 @@ async fn delete_handler( .write() .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .remove(&feed.id); - state.scheduler_interface.delete(feed); + state + .scheduler_interface + .delete(feed) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; Ok(StatusCode::NO_CONTENT) } -async fn list_handler(state: State) -> Result { +async fn list_handler(state: State>) -> Result +where + T: ToScheduler + Send + Sync + 'static, +{ let feeds: Vec = state .db .read() @@ -165,28 +212,40 @@ async fn list_handler(state: State) -> Result { tasks: Arc>>, } - impl TaskSender for MockSender { + impl Clone for MockSender { + fn clone(&self) -> Self { + Self { + tasks: self.tasks.clone(), + } + } + } + + impl TaskSend for MockSender + where + T: Task, + { fn send(&self, task: T) -> Result<(), SendError> { self.tasks.lock().unwrap().push(task); Ok(()) @@ -213,7 +272,7 @@ mod api_tests { #[tokio::test] async fn status() { - let sender = Arc::new(Mutex::new(MockSender::new())); + let sender = MockSender::new(); let interface = Arc::new(SchedulerInterface::new(sender.clone())); let response = app(interface) .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) @@ -221,7 +280,7 @@ mod api_tests { .unwrap(); assert_eq!(response.status(), StatusCode::OK); - assert_eq!(sender.lock().unwrap().count(), 0); + assert_eq!(sender.count(), 0); let body = axum::body::to_bytes(response.into_body(), usize::MAX) .await @@ -231,8 +290,8 @@ mod api_tests { #[tokio::test] async fn invalid() { - let sender = Arc::new(Mutex::new(MockSender::new())); - let interface = Arc::new(SchedulerInterface::new(sender)); + let sender = MockSender::new(); + let interface = Arc::new(SchedulerInterface::new(sender.clone())); let response = app(interface) .oneshot( Request::builder() @@ -250,8 +309,8 @@ mod api_tests { .unwrap(); assert_eq!(&body[..], b"Invalid URL: Cannot parse `\"abc\"` to a `u64`"); - let sender = Arc::new(Mutex::new(MockSender::new())); - let interface = Arc::new(SchedulerInterface::new(sender)); + let sender = MockSender::new(); + let interface = Arc::new(SchedulerInterface::new(sender.clone())); let response = app(interface) .oneshot( Request::builder() @@ -279,8 +338,8 @@ mod api_tests { frequency: 10, headers, }; - let sender = Arc::new(Mutex::new(MockSender::new())); - let interface = Arc::new(SchedulerInterface::new(sender)); + let sender = MockSender::new(); + let interface = Arc::new(SchedulerInterface::new(sender.clone())); let response = app(interface) .oneshot( Request::builder() @@ -295,8 +354,8 @@ mod api_tests { assert_eq!(response.status(), StatusCode::NOT_FOUND); - let sender = Arc::new(Mutex::new(MockSender::new())); - let interface = Arc::new(SchedulerInterface::new(sender)); + let sender = MockSender::new(); + let interface = Arc::new(SchedulerInterface::new(sender.clone())); let response = app(interface) .oneshot( Request::builder() @@ -321,8 +380,8 @@ mod api_tests { frequency: 10, headers, }; - let sender = Arc::new(Mutex::new(MockSender::new())); - assert_eq!(sender.lock().unwrap().count(), 0); + let sender = MockSender::new(); + assert_eq!(sender.count(), 0); let interface = Arc::new(SchedulerInterface::new(sender.clone())); let response = app(interface) .oneshot( @@ -337,7 +396,7 @@ mod api_tests { .unwrap(); assert_eq!(response.status(), StatusCode::CREATED); - assert_eq!(sender.lock().unwrap().count(), 1); + assert_eq!(sender.count(), 1); let body = axum::body::to_bytes(response.into_body(), usize::MAX) .await @@ -366,7 +425,7 @@ mod api_tests { headers: HashMap::new(), }; - let sender = Arc::new(Mutex::new(MockSender::new())); + let sender = MockSender::new(); let interface = Arc::new(SchedulerInterface::new(sender)); let address = SocketAddr::from(([0, 0, 0, 0], 3000)); tokio::spawn(async move { diff --git a/app/src/fetcher.rs b/app/src/fetcher.rs index 9f99307..09a3f45 100644 --- a/app/src/fetcher.rs +++ b/app/src/fetcher.rs @@ -1,10 +1,10 @@ -use crate::fetcher::transit::FeedMessage; -use prost::bytes::Bytes; -use prost::Message; +use prost::{bytes::Bytes, Message}; use reqwest::Client; use tokio::time::{Duration, Interval}; use ureq; +use crate::fetcher::transit::FeedMessage; + mod transit { include!(concat!(env!("OUT_DIR"), "/transit_realtime.rs")); } diff --git a/app/src/main.rs b/app/src/main.rs index 3a984a8..4e227e9 100644 --- a/app/src/main.rs +++ b/app/src/main.rs @@ -1,18 +1,17 @@ use std::net::SocketAddr; -use tokio::net::TcpListener; -use tokio::runtime::Builder; +use tokio::{net::TcpListener, runtime::Builder}; +use tracing::info; -use app::api; -use app::scheduler_interface::{build, Mode}; +use app::{api, scheduler_interface::build}; fn main() { // Initialize tracing subscriber for logging tracing_subscriber::fmt::init(); let address = SocketAddr::from(([0, 0, 0, 0], 3000)); - println!("Starting server on {}.", address); + info!("Starting server on {}.", address); - let interface = build(Mode::Async); + let interface = build(); // We use a runtime::Builder to specify the number of threads and // their name. diff --git a/app/src/scheduler_interface.rs b/app/src/scheduler_interface.rs index c280b6c..fd5eb3f 100644 --- a/app/src/scheduler_interface.rs +++ b/app/src/scheduler_interface.rs @@ -1,116 +1,123 @@ -use std::sync::mpsc::{self, SendError, Sender}; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use tulsa::{AsyncTask, Scheduler, SyncTask}; - -use crate::fetcher::{fetch_sync, recurring_fetch}; -use crate::models::Feed; - -pub enum Mode { - Sync, - Async, -} +use std::{ + marker::PhantomData, + sync::{ + mpsc::{self, SendError, Sender}, + Arc, + }, + time::Duration, +}; +use tulsa::{AsyncTask, Scheduler, SyncTask, Task}; + +use crate::{ + fetcher::{fetch_sync, recurring_fetch}, + models::Feed, +}; + +/// Used to indicate an action by a `ToScheduler` was unsuccessful. +pub struct AppSendError; + +pub fn build() -> Arc { + #[cfg(feature = "async_mode")] + { + let (sender, receiver) = mpsc::channel(); + Scheduler::::new(receiver).run(); + Arc::new(SchedulerInterface::new(sender)) + } -pub fn build(mode: Mode) -> Arc { - match mode { - Mode::Async => { - let (sender, receiver) = mpsc::channel(); - Scheduler::::new(receiver).run(); - Arc::new(SchedulerInterface::new(Arc::new(Mutex::new(sender)))) - } - Mode::Sync => { - let (sender, receiver) = mpsc::channel(); - Scheduler::::new(receiver).run(); - Arc::new(SchedulerInterface::new(Arc::new(Mutex::new(sender)))) - } + #[cfg(not(feature = "async_mode"))] + { + let (sender, receiver) = mpsc::channel(); + Scheduler::::new(receiver).run(); + Arc::new(SchedulerInterface::new(sender)) } } -pub trait TaskSender { +/// An interface to send a `Task`. This allows clients to mock a `Sender` for unit tests. +pub trait TaskSend +where + T: Task, +{ fn send(&self, task: T) -> Result<(), SendError>; } +impl TaskSend for Sender +where + T: Task, +{ + fn send(&self, task: T) -> Result<(), SendError> { + self.send(task) + } +} + /// The [Feed] will be sent to another thread, so we require ownership. pub trait ToScheduler { - fn create(&self, feed: Feed); - fn update(&self, feed: Feed); - fn delete(&self, feed: Feed); + fn create(&self, feed: Feed) -> Result<(), AppSendError>; + fn update(&self, feed: Feed) -> Result<(), AppSendError>; + fn delete(&self, feed: Feed) -> Result<(), AppSendError>; } -pub struct SchedulerInterface { - sender: Arc + Send + 'static>>, +pub struct SchedulerInterface +where + R: TaskSend + Send + 'static, + T: Task, +{ + sender: R, + _marker: PhantomData, } -impl TaskSender for Sender { - fn send(&self, task: T) -> Result<(), SendError> { - self.send(task) - } -} - -impl SchedulerInterface { - pub fn new(sender: Arc + Send + 'static>>) -> Self { - Self { sender } +impl SchedulerInterface +where + R: TaskSend + Send + 'static, + T: Task, +{ + pub fn new(sender: R) -> Self { + Self { + sender, + _marker: PhantomData, + } } } -impl ToScheduler for SchedulerInterface { - fn create(&self, feed: Feed) { +impl ToScheduler for SchedulerInterface +where + R: TaskSend + Send + 'static, +{ + fn create(&self, feed: Feed) -> Result<(), AppSendError> { let action = SyncTask::new(feed.id, Duration::from_secs(feed.frequency), move || { fetch_sync(&feed); }); - let result = self.sender.lock().unwrap().send(action); - - if let Err(e) = result { - println!("{}", e); - } + self.sender.send(action).map_err(|_| AppSendError) } - fn update(&self, feed: Feed) { + fn update(&self, feed: Feed) -> Result<(), AppSendError> { let action = SyncTask::update(feed.id, Duration::from_secs(feed.frequency), move || { fetch_sync(&feed); }); - let result = self.sender.lock().unwrap().send(action); - - if let Err(e) = result { - println!("{}", e); - } + self.sender.send(action).map_err(|_| AppSendError) } - fn delete(&self, feed: Feed) { + fn delete(&self, feed: Feed) -> Result<(), AppSendError> { let action = SyncTask::stop(feed.id); - let result = self.sender.lock().unwrap().send(action); - - if let Err(e) = result { - println!("{}", e); - } + self.sender.send(action).map_err(|_| AppSendError) } } -impl ToScheduler for SchedulerInterface { - fn create(&self, feed: Feed) { +impl ToScheduler for SchedulerInterface +where + R: TaskSend + Send + 'static, +{ + fn create(&self, feed: Feed) -> Result<(), AppSendError> { let action = AsyncTask::new(feed.id, recurring_fetch(feed)); - let result = self.sender.lock().unwrap().send(action); - - if let Err(e) = result { - println!("{}", e); - } + self.sender.send(action).map_err(|_| AppSendError) } - fn update(&self, feed: Feed) { + fn update(&self, feed: Feed) -> Result<(), AppSendError> { let action = AsyncTask::update(feed.id, recurring_fetch(feed)); - let result = self.sender.lock().unwrap().send(action); - - if let Err(e) = result { - println!("{}", e); - } + self.sender.send(action).map_err(|_| AppSendError) } - fn delete(&self, feed: Feed) { + fn delete(&self, feed: Feed) -> Result<(), AppSendError> { let action = AsyncTask::stop(feed.id); - let result = self.sender.lock().unwrap().send(action); - - if let Err(e) = result { - println!("{}", e); - } + self.sender.send(action).map_err(|_| AppSendError) } } diff --git a/app/tests/integration_test.rs b/app/tests/integration_test.rs index 836c950..5ad58e2 100644 --- a/app/tests/integration_test.rs +++ b/app/tests/integration_test.rs @@ -2,17 +2,14 @@ mod tests { use reqwest::blocking::Client; use serde_json::json; - use std::net::SocketAddr; - use std::thread; - use std::time::Duration; - use tokio::net::TcpListener; - use tokio::runtime::Builder; + use std::{net::SocketAddr, thread, time::Duration}; + use tokio::{net::TcpListener, runtime::Builder}; - use app::api; - use app::scheduler_interface::{build, Mode}; + use app::{api, scheduler_interface::build}; - fn run(mode: Mode) { - let interface = build(mode); + #[test] + fn test_run() { + let interface = build(); thread::spawn(move || { let runtime = Builder::new_multi_thread().enable_io().build().unwrap(); @@ -49,14 +46,4 @@ mod tests { } } } - - #[test] - fn async_run() { - run(Mode::Async); - } - - #[test] - fn sync_run() { - run(Mode::Sync); - } } diff --git a/tulsa/src/async_scheduler.rs b/tulsa/src/async_scheduler.rs index 09ad5f4..5e488e5 100644 --- a/tulsa/src/async_scheduler.rs +++ b/tulsa/src/async_scheduler.rs @@ -1,8 +1,8 @@ -use std::collections::HashMap; -use std::sync::mpsc::Receiver; -use std::sync::{Arc, Mutex}; -use tokio::runtime::Builder as TokioBuilder; -use tokio::task::JoinHandle as TaskJoinHandle; +use std::{ + collections::HashMap, + sync::{mpsc::Receiver, Arc, Mutex}, +}; +use tokio::{runtime::Builder as TokioBuilder, task::JoinHandle as TaskJoinHandle}; use crate::model::{AsyncTask, Operation}; diff --git a/tulsa/src/lib.rs b/tulsa/src/lib.rs index 99e29d5..fdc9a41 100644 --- a/tulsa/src/lib.rs +++ b/tulsa/src/lib.rs @@ -3,5 +3,5 @@ mod model; mod scheduler; mod thread_scheduler; -pub use model::{AsyncTask, SyncTask}; +pub use model::{AsyncTask, SyncTask, Task}; pub use scheduler::Scheduler; diff --git a/tulsa/src/model.rs b/tulsa/src/model.rs index c20f54b..1179be2 100644 --- a/tulsa/src/model.rs +++ b/tulsa/src/model.rs @@ -1,8 +1,5 @@ -use std::future::Future; -use std::pin::Pin; -use std::time::Duration; +use std::{future::Future, pin::Pin, time::Duration}; -#[derive(Clone)] pub enum Operation { Create, Update, @@ -11,14 +8,14 @@ pub enum Operation { pub struct AsyncTask { pub id: usize, - pub func: Pin + Send>>, + pub func: Pin + Send + Sync>>, pub op: Operation, } impl AsyncTask { pub fn new(id: usize, func: F) -> Self where - F: Future + Send + 'static, + F: Future + Send + Sync + 'static, { Self { id, @@ -29,7 +26,7 @@ impl AsyncTask { pub fn update(id: usize, func: F) -> Self where - F: Future + Send + 'static, + F: Future + Send + Sync + 'static, { Self { id, @@ -88,3 +85,9 @@ impl SyncTask { } } } + +/// An empty trait which allows for trait bounds to only allow `AsyncTask` or `SyncTask`. +pub trait Task {} + +impl Task for AsyncTask {} +impl Task for SyncTask {} diff --git a/tulsa/src/scheduler.rs b/tulsa/src/scheduler.rs index 5b6cd38..baed9e6 100644 --- a/tulsa/src/scheduler.rs +++ b/tulsa/src/scheduler.rs @@ -1,10 +1,13 @@ -use std::sync::mpsc::Receiver; -use std::sync::{Arc, Mutex}; -use std::thread::Builder as ThreadBuilder; +use std::{ + sync::{mpsc::Receiver, Arc, Mutex}, + thread::Builder as ThreadBuilder, +}; -use crate::async_scheduler::AsyncScheduler; -use crate::model::{AsyncTask, SyncTask}; -use crate::thread_scheduler::ThreadScheduler; +use crate::{ + async_scheduler::AsyncScheduler, + model::{AsyncTask, SyncTask}, + thread_scheduler::ThreadScheduler, +}; pub struct Scheduler { receiver: Arc>>, diff --git a/tulsa/src/thread_scheduler.rs b/tulsa/src/thread_scheduler.rs index b117dfc..69be59c 100644 --- a/tulsa/src/thread_scheduler.rs +++ b/tulsa/src/thread_scheduler.rs @@ -1,8 +1,9 @@ -use std::pin::Pin; -use std::sync::mpsc::Receiver; -use std::sync::{Arc, Mutex}; -use std::thread::{sleep, Builder as ThreadBuilder, JoinHandle as ThreadJoinHandle}; -use std::time::Duration; +use std::{ + pin::Pin, + sync::{mpsc::Receiver, Arc, Mutex}, + thread::{sleep, Builder as ThreadBuilder, JoinHandle as ThreadJoinHandle}, + time::Duration, +}; use crate::model::{Operation, SyncTask}; diff --git a/tulsa/tests/scheduler_test.rs b/tulsa/tests/scheduler_test.rs index ee60c37..2b00c62 100644 --- a/tulsa/tests/scheduler_test.rs +++ b/tulsa/tests/scheduler_test.rs @@ -1,12 +1,13 @@ #[cfg(test)] mod tests { - use std::fs::File; - use std::fs::OpenOptions; - use std::io::prelude::*; - use std::process::Command; - use std::sync::{mpsc, Mutex}; - use std::thread; - use std::time::Duration; + use std::{ + fs::{File, OpenOptions}, + io::prelude::*, + process::Command, + sync::{mpsc, Mutex}, + thread, + time::Duration, + }; use tulsa::{AsyncTask, Scheduler, SyncTask};