diff --git a/crates/core/src/http/errors/parse_error.rs b/crates/core/src/http/errors/parse_error.rs index 1dc3820c6..85f22ce82 100644 --- a/crates/core/src/http/errors/parse_error.rs +++ b/crates/core/src/http/errors/parse_error.rs @@ -31,7 +31,7 @@ pub enum ParseError { #[error("Parse error when parse from str.")] ParseFromStr, - //// A possible error value when converting a `StatusCode` from a `u16` or `&str` + /// A possible error value when converting a `StatusCode` from a `u16` or `&str` /// This error indicates that the supplied input was not a valid number, was less /// than 100, or was greater than 999. #[error("invalid StatusCode: {0}")] diff --git a/crates/core/src/routing/mod.rs b/crates/core/src/routing/mod.rs index 8dca5fd30..b2384a641 100644 --- a/crates/core/src/routing/mod.rs +++ b/crates/core/src/routing/mod.rs @@ -455,7 +455,7 @@ pub struct PathState { pub(crate) cursor: (usize, usize), pub(crate) params: PathParams, pub(crate) end_slash: bool, // For rest match, we want include the last slash. - pub(crate) has_any_goal: bool, + pub(crate) once_ended: bool, // Once it has ended, used to determine whether the error code returned is 404 or 405. } impl PathState { /// Create new `PathState`. @@ -479,7 +479,7 @@ impl PathState { cursor: (0, 0), params: PathParams::new(), end_slash, - has_any_goal: false, + once_ended: false, } } diff --git a/crates/core/src/routing/router.rs b/crates/core/src/routing/router.rs index 0a61dc92f..e56966617 100644 --- a/crates/core/src/routing/router.rs +++ b/crates/core/src/routing/router.rs @@ -104,9 +104,9 @@ impl Router { } } } - if let Some(goal) = &self.goal.clone() { - if path_state.is_ended() { - path_state.has_any_goal = true; + if path_state.is_ended() { + path_state.once_ended = true; + if let Some(goal) = &self.goal { return Some(DetectMatched { hoops: self.hoops.clone(), goal: goal.clone(), diff --git a/crates/core/src/service.rs b/crates/core/src/service.rs index e0d503d22..cfebc1f85 100644 --- a/crates/core/src/service.rs +++ b/crates/core/src/service.rs @@ -246,7 +246,7 @@ impl HyperHandler { req.params = path_state.params; // Set default status code before service hoops executed. // We hope all hoops in service can get the correct status code. - if path_state.has_any_goal { + if path_state.once_ended { res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED); } else { res.status_code = Some(StatusCode::NOT_FOUND); @@ -254,10 +254,10 @@ impl HyperHandler { let mut ctrl = FlowCtrl::new(hoops); ctrl.call_next(&mut req, &mut depot, &mut res).await; // Set it to default status code again if any hoop set status code to None. - if res.status_code.is_none() && path_state.has_any_goal { + if res.status_code.is_none() && path_state.once_ended { res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED); } - } else if path_state.has_any_goal { + } else if path_state.once_ended { res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED); } @@ -468,4 +468,63 @@ mod tests { let content = access(&service, "3").await; assert_eq!(content, "before1before2before3"); } + + #[tokio::test] + async fn test_service_405_or_404_error() { + #[handler] + async fn login() -> &'static str { + "login" + } + #[handler] + async fn hello() -> &'static str { + "hello" + } + let router = Router::new() + .push(Router::with_path("hello").goal(hello)) + .push( + Router::with_path("login") + .post(login) + .push(Router::with_path("user").get(login)), + ); + let service = Service::new(router); + + let res = TestClient::get("http://127.0.0.1:5801/hello") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::OK); + let res = TestClient::put("http://127.0.0.1:5801/hello") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::OK); + + let res = TestClient::post("http://127.0.0.1:5801/login") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::OK); + + let res = TestClient::get("http://127.0.0.1:5801/login") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::METHOD_NOT_ALLOWED); + + let res = TestClient::get("http://127.0.0.1:5801/login2") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::NOT_FOUND); + + let res = TestClient::get("http://127.0.0.1:5801/login/user") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::OK); + + let res = TestClient::post("http://127.0.0.1:5801/login/user") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::METHOD_NOT_ALLOWED); + + let res = TestClient::post("http://127.0.0.1:5801/login/user1") + .send(&service) + .await; + assert_eq!(res.status_code.unwrap(), StatusCode::NOT_FOUND); + } }