From 1daf4b7ec0c7601330f698f37b6dd5e2543249f2 Mon Sep 17 00:00:00 2001 From: Tyler Green Date: Thu, 27 Jun 2024 19:22:18 -0400 Subject: [PATCH] Small api cleanup --- example-app/src/api.rs | 92 ++++++++++++++++----------- example-app/src/deps/mockito/mock.rs | 4 +- example-app/tests/integration_test.rs | 2 +- 3 files changed, 59 insertions(+), 39 deletions(-) diff --git a/example-app/src/api.rs b/example-app/src/api.rs index 99e0c1c..a9295cc 100644 --- a/example-app/src/api.rs +++ b/example-app/src/api.rs @@ -14,14 +14,14 @@ use crate::scheduler_interface::ToScheduler; #[derive(Clone)] struct AppState { - feed_id: Arc>, + next_feed_id: Arc>, db: Arc>>, scheduler_interface: Arc, } pub fn app(scheduler_interface: Arc) -> Router { let state = AppState { - feed_id: Arc::new(RwLock::new(1)), + next_feed_id: Arc::new(RwLock::new(1)), db: Arc::new(RwLock::new(HashMap::new())), scheduler_interface, }; @@ -40,10 +40,16 @@ struct Status { status: String, } +impl Status { + fn new(status: &str) -> Self { + Self { + status: status.to_string(), + } + } +} + async fn status_handler() -> impl IntoResponse { - Json(Status { - status: "OK".to_string(), - }) + Json(Status::new("OK")) } #[derive(Clone, Deserialize, Serialize)] @@ -56,18 +62,23 @@ struct CreateFeed { async fn post_handler( state: State, - Json(payload): Json, + Json(CreateFeed { + name, + url, + frequency, + headers, + }): Json, ) -> impl IntoResponse { - let id = *(state.feed_id.read().unwrap()); + let id = *(state.next_feed_id.read().unwrap()); let feed = Feed { id, - name: payload.name, - url: payload.url, - frequency: payload.frequency, - headers: payload.headers, + name, + url, + frequency, + headers, }; - *(state.feed_id.write().unwrap()) += 1; + *(state.next_feed_id.write().unwrap()) += 1; state.db.write().unwrap().insert(id, feed.clone()); state.scheduler_interface.create(feed.clone()); @@ -75,10 +86,7 @@ async fn post_handler( } async fn get_handler(path: Path, state: State) -> impl IntoResponse { - let id: usize = match path.parse() { - Ok(i) => i, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; + let id = path.parse::().map_err(|_| StatusCode::BAD_REQUEST)?; match state.db.read().unwrap().get(&id).cloned() { Some(feed) => Ok(Json(feed)), @@ -89,22 +97,24 @@ async fn get_handler(path: Path, state: State) -> impl IntoRes async fn put_handler( path: Path, state: State, - Json(payload): Json, + Json(CreateFeed { + name, + url, + frequency, + headers, + }): Json, ) -> impl IntoResponse { - let id: usize = match path.parse() { - Ok(i) => i, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; + let id = path.parse::().map_err(|_| StatusCode::BAD_REQUEST)?; if state.db.read().unwrap().get(&id).is_none() { return Err(StatusCode::NOT_FOUND); } let feed = Feed { id, - name: payload.name, - url: payload.url, - frequency: payload.frequency, - headers: payload.headers, + name, + url, + frequency, + headers, }; state.db.write().unwrap().insert(id, feed.clone()); @@ -114,10 +124,7 @@ async fn put_handler( } async fn delete_handler(path: Path, state: State) -> impl IntoResponse { - let id: usize = match path.parse() { - Ok(i) => i, - Err(_) => return Err(StatusCode::BAD_REQUEST), - }; + let id = path.parse::().map_err(|_| StatusCode::BAD_REQUEST)?; let feed = match state.db.read().unwrap().get(&id).cloned() { Some(f) => f, None => return Err(StatusCode::NOT_FOUND), @@ -195,7 +202,9 @@ mod api_tests { assert_eq!(response.status(), StatusCode::OK); assert_eq!(sender.lock().unwrap().count(), 0); - let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); assert_eq!(&body[..], b"{\"status\":\"OK\"}"); } @@ -215,7 +224,9 @@ mod api_tests { assert_eq!(response.status(), StatusCode::BAD_REQUEST); - let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); assert_eq!(body.len(), 0); let sender = Arc::new(Mutex::new(MockSender::new())); @@ -232,7 +243,9 @@ mod api_tests { assert_eq!(response.status(), StatusCode::BAD_REQUEST); - let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); assert_eq!(body.len(), 0); } @@ -305,7 +318,9 @@ mod api_tests { assert_eq!(response.status(), StatusCode::CREATED); assert_eq!(sender.lock().unwrap().count(), 1); - let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); let f: Feed = serde_json::from_slice(&body).unwrap(); assert_eq!(f.id, 1); @@ -355,7 +370,10 @@ mod api_tests { let response = client .post(format!("http://localhost:3000/feed")) - .header(http::header::CONTENT_TYPE.as_str(), mime::APPLICATION_JSON.as_ref()) + .header( + http::header::CONTENT_TYPE.as_str(), + mime::APPLICATION_JSON.as_ref(), + ) .json(&serde_json::json!(input)) .send() .await @@ -377,7 +395,10 @@ mod api_tests { let response = client .put(format!("http://localhost:3000/feed/1")) - .header(http::header::CONTENT_TYPE.as_str(), mime::APPLICATION_JSON.as_ref()) + .header( + http::header::CONTENT_TYPE.as_str(), + mime::APPLICATION_JSON.as_ref(), + ) .json(&serde_json::json!(input_new)) .send() .await @@ -401,7 +422,6 @@ mod api_tests { assert_eq!(f.url, "http"); assert_eq!(f.frequency, 20); - let response = client .delete(format!("http://localhost:3000/feed/1")) .send() diff --git a/example-app/src/deps/mockito/mock.rs b/example-app/src/deps/mockito/mock.rs index b7ed6fc..3bd4196 100644 --- a/example-app/src/deps/mockito/mock.rs +++ b/example-app/src/deps/mockito/mock.rs @@ -1,6 +1,6 @@ -use hyper::StatusCode; -use hyper::Request; use hyper::body::Incoming; +use hyper::Request; +use hyper::StatusCode; use rand; use std::sync::{Arc, RwLock}; diff --git a/example-app/tests/integration_test.rs b/example-app/tests/integration_test.rs index 48f796e..52195e3 100644 --- a/example-app/tests/integration_test.rs +++ b/example-app/tests/integration_test.rs @@ -2,10 +2,10 @@ mod tests { use reqwest::blocking::Client; use serde_json::json; - use tokio::net::TcpListener; use std::net::SocketAddr; use std::thread; use std::time::Duration; + use tokio::net::TcpListener; use tokio::runtime::Builder; use gtfs_realtime_rust::api;