Skip to content

Commit

Permalink
Remove dynamic dispatch in app
Browse files Browse the repository at this point in the history
  • Loading branch information
tyleragreen committed Jul 19, 2024
1 parent 148168c commit e9de957
Show file tree
Hide file tree
Showing 13 changed files with 252 additions and 189 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ jobs:
run: cargo build --verbose
- name: Test
run: cargo test --verbose
- name: Test example-app with third-party libs
run: cd app && cargo test --features use_dependencies --verbose
- name: Test async mode
run: cargo test --features async_mode --verbose
- name: Test with third-party libs
run: cargo test --features use_dependencies --verbose
1 change: 1 addition & 0 deletions app/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ regex = "1.9.3" # for prost_build

[features]
use_dependencies = [ "mime", "mockito", "prost-build" ]
async_mode = []
161 changes: 110 additions & 51 deletions app/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,43 @@ use axum::{
routing::{get, post},
Json, Router,
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};

use crate::middleware::log_request;
use crate::models::{CreateFeed, Feed, Status};
use crate::scheduler_interface::ToScheduler;
use crate::{
middleware::log_request,
models::{CreateFeed, Feed, Status},
scheduler_interface::ToScheduler,
};

#[derive(Clone)]
struct AppState {
struct AppState<T>
where
T: ToScheduler + Send + Sync + 'static,
{
next_feed_id: Arc<RwLock<usize>>,
db: Arc<RwLock<HashMap<usize, Feed>>>,
scheduler_interface: Arc<dyn ToScheduler + Send + Sync>,
scheduler_interface: Arc<T>,
}

pub fn app(scheduler_interface: Arc<dyn ToScheduler + Send + Sync>) -> Router {
impl<T> Clone for AppState<T>
where
T: ToScheduler + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
next_feed_id: self.next_feed_id.clone(),
db: self.db.clone(),
scheduler_interface: self.scheduler_interface.clone(),
}
}
}

pub fn app<T>(scheduler_interface: Arc<T>) -> Router
where
T: ToScheduler + Send + Sync + 'static,
{
let state = AppState {
next_feed_id: Arc::new(RwLock::new(1)),
db: Arc::new(RwLock::new(HashMap::new())),
Expand All @@ -42,15 +64,18 @@ async fn status_handler() -> impl IntoResponse {
Json(Status::new("OK"))
}

async fn post_handler(
state: State<AppState>,
async fn post_handler<T>(
state: State<AppState<T>>,
Json(CreateFeed {
name,
url,
frequency,
headers,
}): Json<CreateFeed>,
) -> Result<impl IntoResponse, StatusCode> {
) -> Result<impl IntoResponse, StatusCode>
where
T: ToScheduler + Send + Sync + 'static,
{
let id = *(state
.next_feed_id
.read()
Expand All @@ -72,15 +97,22 @@ async fn post_handler(
.write()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.insert(id, feed.clone());
state.scheduler_interface.create(feed.clone());

state
.scheduler_interface
.create(feed.clone())
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok((StatusCode::CREATED, Json(feed)))
}

async fn get_handler(
async fn get_handler<T>(
Path(id): Path<usize>,
state: State<AppState>,
) -> Result<impl IntoResponse, StatusCode> {
state: State<AppState<T>>,
) -> Result<impl IntoResponse, StatusCode>
where
T: ToScheduler + Send + Sync + 'static,
{
let feed = state
.db
.read()
Expand All @@ -92,16 +124,19 @@ async fn get_handler(
Ok(Json(feed))
}

async fn put_handler(
async fn put_handler<T>(
Path(id): Path<usize>,
state: State<AppState>,
state: State<AppState<T>>,
Json(CreateFeed {
name,
url,
frequency,
headers,
}): Json<CreateFeed>,
) -> impl IntoResponse {
) -> impl IntoResponse
where
T: ToScheduler + Send + Sync + 'static,
{
if state
.db
.read()
Expand All @@ -125,15 +160,21 @@ async fn put_handler(
.write()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.insert(id, feed.clone());
state.scheduler_interface.update(feed.clone());
state
.scheduler_interface
.update(feed.clone())
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(Json(feed))
}

async fn delete_handler(
async fn delete_handler<T>(
Path(id): Path<usize>,
state: State<AppState>,
) -> Result<impl IntoResponse, StatusCode> {
state: State<AppState<T>>,
) -> Result<impl IntoResponse, StatusCode>
where
T: ToScheduler + Send + Sync + 'static,
{
let feed = state
.db
.read()
Expand All @@ -147,12 +188,18 @@ async fn delete_handler(
.write()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.remove(&feed.id);
state.scheduler_interface.delete(feed);
state
.scheduler_interface
.delete(feed)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(StatusCode::NO_CONTENT)
}

async fn list_handler(state: State<AppState>) -> Result<impl IntoResponse, StatusCode> {
async fn list_handler<T>(state: State<AppState<T>>) -> Result<impl IntoResponse, StatusCode>
where
T: ToScheduler + Send + Sync + 'static,
{
let feeds: Vec<Feed> = state
.db
.read()
Expand All @@ -165,28 +212,40 @@ async fn list_handler(state: State<AppState>) -> Result<impl IntoResponse, Statu

#[cfg(test)]
mod api_tests {
#[cfg(not(feature = "use_dependencies"))]
use crate::deps::mime;

use crate::scheduler_interface::{SchedulerInterface, TaskSender};
use tokio::net::TcpListener;
use tulsa::AsyncTask;

use super::*;
use axum::{
body::Body,
http::{self, Request, StatusCode},
};
use std::net::SocketAddr;
use std::sync::mpsc::SendError;
use std::sync::Mutex;
use std::{
net::SocketAddr,
sync::{mpsc::SendError, Mutex},
};
use tokio::net::TcpListener;
use tower::ServiceExt; // for `oneshot`

#[cfg(not(feature = "use_dependencies"))]
use crate::deps::mime;
use crate::scheduler_interface::{SchedulerInterface, TaskSend};
use tulsa::{AsyncTask, Task};

use super::*;

struct MockSender<T> {
tasks: Arc<Mutex<Vec<T>>>,
}

impl<T> TaskSender<T> for MockSender<T> {
impl<T> Clone for MockSender<T> {
fn clone(&self) -> Self {
Self {
tasks: self.tasks.clone(),
}
}
}

impl<T> TaskSend<T> for MockSender<T>
where
T: Task,
{
fn send(&self, task: T) -> Result<(), SendError<T>> {
self.tasks.lock().unwrap().push(task);
Ok(())
Expand All @@ -213,15 +272,15 @@ mod api_tests {

#[tokio::test]
async fn status() {
let sender = Arc::new(Mutex::new(MockSender::new()));
let sender = MockSender::new();
let interface = Arc::new(SchedulerInterface::new(sender.clone()));
let response = app(interface)
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();

assert_eq!(response.status(), StatusCode::OK);
assert_eq!(sender.lock().unwrap().count(), 0);
assert_eq!(sender.count(), 0);

let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
Expand All @@ -231,8 +290,8 @@ mod api_tests {

#[tokio::test]
async fn invalid() {
let sender = Arc::new(Mutex::new(MockSender::new()));
let interface = Arc::new(SchedulerInterface::new(sender));
let sender = MockSender::new();
let interface = Arc::new(SchedulerInterface::new(sender.clone()));
let response = app(interface)
.oneshot(
Request::builder()
Expand All @@ -250,8 +309,8 @@ mod api_tests {
.unwrap();
assert_eq!(&body[..], b"Invalid URL: Cannot parse `\"abc\"` to a `u64`");

let sender = Arc::new(Mutex::new(MockSender::new()));
let interface = Arc::new(SchedulerInterface::new(sender));
let sender = MockSender::new();
let interface = Arc::new(SchedulerInterface::new(sender.clone()));
let response = app(interface)
.oneshot(
Request::builder()
Expand Down Expand Up @@ -279,8 +338,8 @@ mod api_tests {
frequency: 10,
headers,
};
let sender = Arc::new(Mutex::new(MockSender::new()));
let interface = Arc::new(SchedulerInterface::new(sender));
let sender = MockSender::new();
let interface = Arc::new(SchedulerInterface::new(sender.clone()));
let response = app(interface)
.oneshot(
Request::builder()
Expand All @@ -295,8 +354,8 @@ mod api_tests {

assert_eq!(response.status(), StatusCode::NOT_FOUND);

let sender = Arc::new(Mutex::new(MockSender::new()));
let interface = Arc::new(SchedulerInterface::new(sender));
let sender = MockSender::new();
let interface = Arc::new(SchedulerInterface::new(sender.clone()));
let response = app(interface)
.oneshot(
Request::builder()
Expand All @@ -321,8 +380,8 @@ mod api_tests {
frequency: 10,
headers,
};
let sender = Arc::new(Mutex::new(MockSender::new()));
assert_eq!(sender.lock().unwrap().count(), 0);
let sender = MockSender::new();
assert_eq!(sender.count(), 0);
let interface = Arc::new(SchedulerInterface::new(sender.clone()));
let response = app(interface)
.oneshot(
Expand All @@ -337,7 +396,7 @@ mod api_tests {
.unwrap();

assert_eq!(response.status(), StatusCode::CREATED);
assert_eq!(sender.lock().unwrap().count(), 1);
assert_eq!(sender.count(), 1);

let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
Expand Down Expand Up @@ -366,7 +425,7 @@ mod api_tests {
headers: HashMap::new(),
};

let sender = Arc::new(Mutex::new(MockSender::new()));
let sender = MockSender::new();
let interface = Arc::new(SchedulerInterface::new(sender));
let address = SocketAddr::from(([0, 0, 0, 0], 3000));
tokio::spawn(async move {
Expand Down
6 changes: 3 additions & 3 deletions app/src/fetcher.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::fetcher::transit::FeedMessage;
use prost::bytes::Bytes;
use prost::Message;
use prost::{bytes::Bytes, Message};
use reqwest::Client;
use tokio::time::{Duration, Interval};
use ureq;

use crate::fetcher::transit::FeedMessage;

mod transit {
include!(concat!(env!("OUT_DIR"), "/transit_realtime.rs"));
}
Expand Down
11 changes: 5 additions & 6 deletions app/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::runtime::Builder;
use tokio::{net::TcpListener, runtime::Builder};
use tracing::info;

use app::api;
use app::scheduler_interface::{build, Mode};
use app::{api, scheduler_interface::build};

fn main() {
// Initialize tracing subscriber for logging
tracing_subscriber::fmt::init();

let address = SocketAddr::from(([0, 0, 0, 0], 3000));
println!("Starting server on {}.", address);
info!("Starting server on {}.", address);

let interface = build(Mode::Async);
let interface = build();

// We use a runtime::Builder to specify the number of threads and
// their name.
Expand Down
Loading

0 comments on commit e9de957

Please sign in to comment.