Skip to content

Commit

Permalink
Add more try_* fns to Request (#929)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrislearn authored Sep 25, 2024
1 parent e6f864f commit 81d3e86
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 45 deletions.
4 changes: 4 additions & 0 deletions crates/core/src/http/errors/parse_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ pub enum ParseError {
#[error("The request's body is empty.")]
EmptyBody,

/// The Hyper request's body is empty.
#[error("Data is not exist.")]
NotExist,

/// Parse error when parse from str.
#[error("Parse error when parse from str.")]
ParseFromStr,
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cfg_feature! {
#![feature = "cookie"]
pub use cookie;
}
pub use errors::{ParseError, StatusError};
pub use errors::{ParseError, ParseResult, StatusError, StatusResult};
pub use headers;
pub use http::method::Method;
pub use http::{header, method, uri, HeaderMap, HeaderName, HeaderValue, StatusCode};
Expand Down
156 changes: 112 additions & 44 deletions crates/core/src/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::extract::{Extractible, Metadata};
use crate::fuse::TransProto;
use crate::http::body::ReqBody;
use crate::http::form::{FilePart, FormData};
use crate::http::{Mime, ParseError, Response, Version};
use crate::http::{Mime, ParseError, ParseResult, Response, Version};
use crate::routing::PathParams;
use crate::serde::{
from_request, from_str_map, from_str_multi_map, from_str_multi_val, from_str_val,
Expand Down Expand Up @@ -416,9 +416,20 @@ impl Request {
&mut self.headers
}

/// Get header with supplied name and try to parse to a 'T', returns None if failed or not found.
/// Get header with supplied name and try to parse to a 'T'.
///
/// Returns `None` if failed or not found.
#[inline]
pub fn header<'de, T>(&'de self, key: impl AsHeaderName) -> Option<T>
where
T: Deserialize<'de>,
{
self.try_header(key).ok()
}

/// Try to get header with supplied name and try to parse to a 'T'.
#[inline]
pub fn try_header<'de, T>(&'de self, key: impl AsHeaderName) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand All @@ -428,7 +439,7 @@ impl Request {
.iter()
.filter_map(|v| v.to_str().ok())
.collect::<Vec<_>>();
from_str_multi_val(values).ok()
from_str_multi_val(values).map_err(Into::into)
}

/// Modify a header for this request.
Expand Down Expand Up @@ -641,7 +652,19 @@ impl Request {
where
T: Deserialize<'de>,
{
self.params.get(key).and_then(|v| from_str_val(v).ok())
self.try_param(key).ok()
}

/// Try to get param value from params.
#[inline]
pub fn try_param<'de, T>(&'de self, key: &str) -> ParseResult<T>
where
T: Deserialize<'de>,
{
self.params
.get(key)
.ok_or(ParseError::NotExist)
.and_then(|v| from_str_val(v).map_err(Into::into))
}

/// Get queries reference.
Expand All @@ -663,95 +686,146 @@ impl Request {
/// Get query value from queries.
#[inline]
pub fn query<'de, T>(&'de self, key: &str) -> Option<T>
where
T: Deserialize<'de>,
{
self.try_query(key).ok()
}

/// Try to get query value from queries.
#[inline]
pub fn try_query<'de, T>(&'de self, key: &str) -> ParseResult<T>
where
T: Deserialize<'de>,
{
self.queries()
.get_vec(key)
.and_then(|vs| from_str_multi_val(vs).ok())
.ok_or(ParseError::NotExist)
.and_then(|vs| from_str_multi_val(vs).map_err(Into::into))
}

/// Get field data from form.
#[inline]
pub async fn form<'de, T>(&'de mut self, key: &str) -> Option<T>
where
T: Deserialize<'de>,
{
self.try_form(key).await.ok()
}

/// Try to get field data from form.
#[inline]
pub async fn try_form<'de, T>(&'de mut self, key: &str) -> ParseResult<T>
where
T: Deserialize<'de>,
{
self.form_data()
.await
.ok()
.and_then(|ps| ps.fields.get_vec(key))
.and_then(|vs| from_str_multi_val(vs).ok())
.and_then(|ps| ps.fields.get_vec(key).ok_or(ParseError::NotExist))
.and_then(|vs| from_str_multi_val(vs).map_err(Into::into))
}

/// Get field data from form, if key is not found in form data, then get from query.
#[inline]
pub async fn form_or_query<'de, T>(&'de mut self, key: &str) -> Option<T>
where
T: Deserialize<'de>,
{
self.try_form_or_query(key).await.ok()
}

/// Try to get field data from form, if key is not found in form data, then get from query.
#[inline]
pub async fn try_form_or_query<'de, T>(&'de mut self, key: &str) -> ParseResult<T>
where
T: Deserialize<'de>,
{
if let Ok(form_data) = self.form_data().await {
if form_data.fields.contains_key(key) {
return self.form(key).await;
return self.try_form(key).await;
}
}
self.query(key)
self.try_query(key)
}

/// Get value from query, if key is not found in queries, then get from form.
#[inline]
pub async fn query_or_form<'de, T>(&'de mut self, key: &str) -> Option<T>
where
T: Deserialize<'de>,
{
self.try_query_or_form(key).await.ok()
}

/// Try to get value from query, if key is not found in queries, then get from form.
#[inline]
pub async fn try_query_or_form<'de, T>(&'de mut self, key: &str) -> ParseResult<T>
where
T: Deserialize<'de>,
{
if self.queries().contains_key(key) {
self.query(key)
self.try_query(key)
} else {
self.form(key).await
self.try_form(key).await
}
}

/// Get [`FilePart`] reference from request.
#[inline]
pub async fn file<'a>(&'a mut self, key: &'a str) -> Option<&'a FilePart> {
self.form_data().await.ok().and_then(|ps| ps.files.get(key))
pub async fn file(&mut self, key: &str) -> Option<&FilePart> {
self.try_file(key).await.ok().flatten()
}
/// Try to get [`FilePart`] reference from request.
#[inline]
pub async fn try_file(&mut self, key: &str) -> ParseResult<Option<&FilePart>> {
self.form_data().await.map(|ps| ps.files.get(key))
}

/// Get [`FilePart`] reference from request.
#[inline]
pub async fn first_file(&mut self) -> Option<&FilePart> {
self.try_first_file().await.ok().flatten()
}

/// Try to get [`FilePart`] reference from request.
#[inline]
pub async fn try_first_file(&mut self) -> ParseResult<Option<&FilePart>> {
self.form_data()
.await
.ok()
.and_then(|ps| ps.files.iter().next())
.map(|(_, f)| f)
.map(|ps| ps.files.iter().next().map(|(_, f)| f))
}

/// Get [`FilePart`] list reference from request.
#[inline]
pub async fn files<'a>(&'a mut self, key: &'a str) -> Option<&'a Vec<FilePart>> {
self.form_data()
.await
.ok()
.and_then(|ps| ps.files.get_vec(key))
pub async fn files(&mut self, key: &str) -> Option<&Vec<FilePart>> {
self.try_files(key).await.ok().flatten()
}
/// Try to get [`FilePart`] list reference from request.
#[inline]
pub async fn try_files(&mut self, key: &str) -> ParseResult<Option<&Vec<FilePart>>> {
self.form_data().await.map(|ps| ps.files.get_vec(key))
}

/// Get [`FilePart`] list reference from request.
#[inline]
pub async fn all_files(&mut self) -> Vec<&FilePart> {
self.try_all_files().await.unwrap_or_default()
}

/// Try to get [`FilePart`] list reference from request.
#[inline]
pub async fn try_all_files(&mut self) -> ParseResult<Vec<&FilePart>> {
self.form_data()
.await
.ok()
.map(|ps| ps.files.iter().map(|(_, f)| f).collect())
.unwrap_or_default()
}

/// Get request payload with default max size limit(64KB).
///
/// <https://github.com/hyperium/hyper/issues/3111>
/// *Notice: This method takes body.
#[inline]
pub async fn payload(&mut self) -> Result<&Bytes, ParseError> {
pub async fn payload(&mut self) -> ParseResult<&Bytes> {
self.payload_with_max_size(self.secure_max_size()).await
}

Expand All @@ -760,7 +834,7 @@ impl Request {
/// <https://github.com/hyperium/hyper/issues/3111>
/// *Notice: This method takes body.
#[inline]
pub async fn payload_with_max_size(&mut self, max_size: usize) -> Result<&Bytes, ParseError> {
pub async fn payload_with_max_size(&mut self, max_size: usize) -> ParseResult<&Bytes> {
let body = self.take_body();
self.payload
.get_or_try_init(|| async {
Expand All @@ -777,7 +851,7 @@ impl Request {
///
/// *Notice: This method takes body and body's size is not limited.
#[inline]
pub async fn form_data(&mut self) -> Result<&FormData, ParseError> {
pub async fn form_data(&mut self) -> ParseResult<&FormData> {
if let Some(ctype) = self.content_type() {
if ctype.subtype() == mime::WWW_FORM_URLENCODED || ctype.type_() == mime::MULTIPART {
let body = self.take_body();
Expand All @@ -795,7 +869,7 @@ impl Request {

/// Extract request as type `T` from request's different parts.
#[inline]
pub async fn extract<'de, T>(&'de mut self) -> Result<T, ParseError>
pub async fn extract<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Extractible<'de> + Deserialize<'de> + Send,
{
Expand All @@ -807,7 +881,7 @@ impl Request {
pub async fn extract_with_metadata<'de, T>(
&'de mut self,
metadata: &'de Metadata,
) -> Result<T, ParseError>
) -> ParseResult<T>
where
T: Deserialize<'de> + Send,
{
Expand All @@ -816,7 +890,7 @@ impl Request {

/// Parse url params as type `T` from request.
#[inline]
pub fn parse_params<'de, T>(&'de mut self) -> Result<T, ParseError>
pub fn parse_params<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand All @@ -826,7 +900,7 @@ impl Request {

/// Parse queries as type `T` from request.
#[inline]
pub fn parse_queries<'de, T>(&'de mut self) -> Result<T, ParseError>
pub fn parse_queries<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand All @@ -836,7 +910,7 @@ impl Request {

/// Parse headers as type `T` from request.
#[inline]
pub fn parse_headers<'de, T>(&'de mut self) -> Result<T, ParseError>
pub fn parse_headers<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand All @@ -851,7 +925,7 @@ impl Request {
#![feature = "cookie"]
/// Parse cookies as type `T` from request.
#[inline]
pub fn parse_cookies<'de, T>(&'de mut self) -> Result<T, ParseError>
pub fn parse_cookies<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand All @@ -865,18 +939,15 @@ impl Request {

/// Parse json body as type `T` from request with default max size limit.
#[inline]
pub async fn parse_json<'de, T>(&'de mut self) -> Result<T, ParseError>
pub async fn parse_json<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Deserialize<'de>,
{
self.parse_json_with_max_size(self.secure_max_size()).await
}
/// Parse json body as type `T` from request with max size limit.
#[inline]
pub async fn parse_json_with_max_size<'de, T>(
&'de mut self,
max_size: usize,
) -> Result<T, ParseError>
pub async fn parse_json_with_max_size<'de, T>(&'de mut self, max_size: usize) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand All @@ -902,7 +973,7 @@ impl Request {

/// Parse form body as type `T` from request.
#[inline]
pub async fn parse_form<'de, T>(&'de mut self) -> Result<T, ParseError>
pub async fn parse_form<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand All @@ -917,18 +988,15 @@ impl Request {

/// Parse json body or form body as type `T` from request with default max size.
#[inline]
pub async fn parse_body<'de, T>(&'de mut self) -> Result<T, ParseError>
pub async fn parse_body<'de, T>(&'de mut self) -> ParseResult<T>
where
T: Deserialize<'de>,
{
self.parse_body_with_max_size(self.secure_max_size()).await
}

/// Parse json body or form body as type `T` from request with max size.
pub async fn parse_body_with_max_size<'de, T>(
&'de mut self,
max_size: usize,
) -> Result<T, ParseError>
pub async fn parse_body_with_max_size<'de, T>(&'de mut self, max_size: usize) -> ParseResult<T>
where
T: Deserialize<'de>,
{
Expand Down

0 comments on commit 81d3e86

Please sign in to comment.