diff --git a/Cargo.toml b/Cargo.toml index c6fe397..8e81da0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "flareon-orm", # Examples "examples/hello-world", + "examples/todo-list", ] resolver = "2" @@ -15,6 +16,7 @@ edition = "2021" license = "MIT OR Apache-2.0" [workspace.dependencies] +askama = "0.12.1" async-trait = "0.1.80" axum = "0.7.5" bytes = "1.6.1" @@ -22,9 +24,13 @@ chrono = { version = "0.4.38", features = ["serde"] } clap = { version = "4.5.8", features = ["derive", "env"] } derive_builder = "0.20.0" env_logger = "0.11.3" +flareon = { path = "flareon" } +flareon_macros = { path = "flareon-macros" } +form_urlencoded = "1.2.1" indexmap = "2.2.6" itertools = "0.13.0" log = "0.4.22" +num-traits = "0.2.19" regex = "1.10.5" serde = "1.0.203" slug = "0.1.5" diff --git a/examples/hello-world/src/main.rs b/examples/hello-world/src/main.rs index c0d5667..73ea398 100644 --- a/examples/hello-world/src/main.rs +++ b/examples/hello-world/src/main.rs @@ -1,8 +1,8 @@ use std::sync::Arc; -use flareon::prelude::{ - Body, Error, FlareonApp, FlareonProject, Request, Response, Route, StatusCode, -}; +use flareon::prelude::{Body, Error, FlareonApp, FlareonProject, Response, StatusCode}; +use flareon::request::Request; +use flareon::router::Route; fn return_hello(_request: Request) -> Result { Ok(Response::new_html( diff --git a/examples/todo-list/Cargo.toml b/examples/todo-list/Cargo.toml new file mode 100644 index 0000000..8db67d6 --- /dev/null +++ b/examples/todo-list/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "example-todo-list" +version = "0.1.0" +publish = false +description = "TODO List - Flareon example." +edition = "2021" + +[dependencies] +askama = "0.12.1" +flareon = { path = "../../flareon" } +tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } +env_logger = "0.11.5" diff --git a/examples/todo-list/src/main.rs b/examples/todo-list/src/main.rs new file mode 100644 index 0000000..b3f6808 --- /dev/null +++ b/examples/todo-list/src/main.rs @@ -0,0 +1,100 @@ +use std::cell::OnceCell; +use std::sync::Arc; + +use askama::Template; +use flareon::forms::{AsFormField, CharField, Form, FormFieldBase}; +use flareon::prelude::{Body, Error, FlareonApp, FlareonProject, Response, Route, StatusCode}; +use flareon::request::Request; +use flareon::reverse; +use tokio::sync::{Mutex, RwLock}; + +#[derive(Debug, Clone)] +struct TodoItem { + id: u32, + title: String, +} + +#[derive(Debug, Template)] +#[template(path = "index.html")] +struct IndexTemplate<'a> { + request: &'a Request, + todo_items: Vec, +} + +static TODOS: RwLock> = RwLock::const_new(Vec::new()); + +async fn index(request: Request) -> Result { + let todo_items = (*TODOS.read().await).clone(); + let index_template = IndexTemplate { + request: &request, + todo_items, + }; + let rendered = index_template.render().unwrap(); + + Ok(Response::new_html( + StatusCode::OK, + Body::fixed(rendered.as_bytes().to_vec()), + )) +} + +#[derive(Debug, Form)] +struct TodoForm { + #[form(opt(max_length = 100))] + title: String, +} + +#[derive(Debug, Form)] +struct RemoveTodoForm { + id: u32, +} + +async fn add_todo(mut request: Request) -> Result { + let todo_form = TodoForm::from_request(&mut request).await.unwrap(); + + { + let mut todos = TODOS.write().await; + let index = todos.len() as u32; + todos.push(TodoItem { + id: index + 1, + title: todo_form.title, + }); + } + + Ok(reverse!(request, "index")) +} + +async fn remove_todo(mut request: Request) -> Result { + let remove_todo_form = RemoveTodoForm::from_request(&mut request).await.unwrap(); + + // { + // let mut todos = TODOS.write().await; + // let index = todos.len() as u32; + // todos.push(TodoItem { + // id: index + 1, + // title: todo_form.title, + // }); + // } + + Ok(reverse!(request, "index")) +} + +#[tokio::main] +async fn main() { + env_logger::init(); + + let todo_app = FlareonApp::builder() + .urls([ + Route::with_handler_and_name("/", Arc::new(Box::new(index)), "index"), + Route::with_handler_and_name("/add", Arc::new(Box::new(add_todo)), "add-todo"), + Route::with_handler_and_name("/remove", Arc::new(Box::new(remove_todo)), "remove-todo"), + ]) + .build() + .unwrap(); + + let todo_project = FlareonProject::builder() + .register_app_with_views(todo_app, "") + .build() + .unwrap(); + + flareon::run(todo_project, "127.0.0.1:8000").await.unwrap(); +} diff --git a/examples/todo-list/templates/index.html b/examples/todo-list/templates/index.html new file mode 100644 index 0000000..3cf11cb --- /dev/null +++ b/examples/todo-list/templates/index.html @@ -0,0 +1,27 @@ +{% let request = request %} + + + + + + + TODO List + + +

TODO List

+
+ + +
+
    + {% for todo in todo_items %} +
  • + {{ todo.title }} +
    + +
    +
  • + {% endfor %} +
+ + diff --git a/flareon-macros/Cargo.toml b/flareon-macros/Cargo.toml index e33eb64..5dbb46e 100644 --- a/flareon-macros/Cargo.toml +++ b/flareon-macros/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "flareon-macros" +name = "flareon_macros" version = "0.1.0" edition.workspace = true license.workspace = true @@ -8,4 +8,19 @@ description = "Modern web framework focused on speed and ease of use - macros." [lib] proc-macro = true +[[test]] +name = "tests" +path = "tests/compile_tests.rs" + [dependencies] +darling = "0.20.10" +proc-macro-crate = "3.1.0" +proc-macro2 = "1.0.86" +proc-macro2-diagnostics = "0.10.1" +quote = "1.0.36" +syn = { version = "2.0.72", features = ["full"] } + +[dev-dependencies] +flareon.workspace = true +serde.workspace = true +trybuild = { version = "1.0.99", features = ["diff"] } diff --git a/flareon-macros/src/form.rs b/flareon-macros/src/form.rs new file mode 100644 index 0000000..7dcda16 --- /dev/null +++ b/flareon-macros/src/form.rs @@ -0,0 +1,302 @@ +use std::collections::HashMap; + +use darling::{FromDeriveInput, FromField, FromMeta}; +use proc_macro2::{Ident, TokenStream}; +use quote::{format_ident, quote, ToTokens, TokenStreamExt}; + +pub fn form_for_struct(ast: syn::DeriveInput) -> TokenStream { + let opts = match FormOpts::from_derive_input(&ast) { + Ok(val) => val, + Err(err) => { + return err.write_errors(); + } + }; + + let mut builder = opts.as_form_derive_builder(); + for field in opts.fields() { + builder.push_field(field); + } + + quote!(#builder) +} + +#[derive(Debug, FromDeriveInput)] +#[darling(forward_attrs(allow, doc, cfg), supports(struct_named))] +pub struct FormOpts { + ident: syn::Ident, + data: darling::ast::Data, +} + +impl FormOpts { + fn fields(&self) -> Vec<&Field> { + self.data + .as_ref() + .take_struct() + .expect("Only structs are supported") + .fields + } + + fn field_count(&self) -> usize { + self.fields().len() + } + + fn as_form_derive_builder(&self) -> FormDeriveBuilder { + FormDeriveBuilder { + name: self.ident.clone(), + context_struct_name: format_ident!("{}Context", self.ident), + context_struct_errors_name: format_ident!("{}ContextErrors", self.ident), + context_struct_field_iterator_name: format_ident!("{}ContextFieldIterator", self.ident), + fields_as_struct_fields: Vec::with_capacity(self.field_count()), + fields_as_struct_fields_new: Vec::with_capacity(self.field_count()), + fields_as_context_from_request: Vec::with_capacity(self.field_count()), + fields_as_from_context: Vec::with_capacity(self.field_count()), + fields_as_errors: Vec::with_capacity(self.field_count()), + fields_as_get_errors: Vec::with_capacity(self.field_count()), + fields_as_get_errors_mut: Vec::with_capacity(self.field_count()), + fields_as_iterator_next: Vec::with_capacity(self.field_count()), + } + } +} + +#[derive(Debug, Clone, FromField)] +#[darling(attributes(form))] +pub struct Field { + ident: Option, + ty: syn::Type, + opt: Option>, +} + +#[derive(Debug)] +struct FormDeriveBuilder { + name: Ident, + context_struct_name: Ident, + context_struct_errors_name: Ident, + context_struct_field_iterator_name: Ident, + fields_as_struct_fields: Vec, + fields_as_struct_fields_new: Vec, + fields_as_context_from_request: Vec, + fields_as_from_context: Vec, + fields_as_errors: Vec, + fields_as_get_errors: Vec, + fields_as_get_errors_mut: Vec, + fields_as_iterator_next: Vec, +} + +impl ToTokens for FormDeriveBuilder { + fn to_tokens(&self, tokens: &mut TokenStream) { + tokens.append_all(self.build_form_impl()); + tokens.append_all(self.build_form_context_impl()); + tokens.append_all(self.build_errors_struct()); + tokens.append_all(self.build_context_field_iterator_impl()); + } +} + +impl FormDeriveBuilder { + fn push_field(&mut self, field: &Field) { + let name = field.ident.as_ref().unwrap(); + let ty = &field.ty; + let index = self.fields_as_struct_fields.len(); + let opt = &field.opt; + + self.fields_as_struct_fields + .push(quote!(#name: <#ty as ::flareon::forms::AsFormField>::Type)); + + self.fields_as_struct_fields_new.push({ + let custom_options_setters: Vec<_> = if let Some(opt) = opt { + opt.iter() + .map(|(key, value)| quote!(custom_options.#key = Some(#value))) + .collect() + } else { + Vec::new() + }; + quote!(#name: { + let options = ::flareon::forms::FormFieldOptions { + id: stringify!(#name).to_owned(), + }; + type Field = <#ty as ::flareon::forms::AsFormField>::Type; + type CustomOptions = ::CustomOptions; + let mut custom_options: CustomOptions = ::core::default::Default::default(); + #( #custom_options_setters; )* + ::with_options(options, custom_options) + }) + }); + + self.fields_as_context_from_request + .push(quote!(stringify!(#name) => { + ::flareon::forms::FormFieldBase::set_value(&mut self.#name, value) + })); + + self.fields_as_from_context.push(quote!(#name: <#ty as ::flareon::forms::AsFormField>::clean_value(&context.#name).unwrap())); + + self.fields_as_errors + .push(quote!(#name: Vec<::flareon::forms::FormFieldValidationError>)); + + self.fields_as_get_errors + .push(quote!(stringify!(#name) => self.__errors.#name.as_slice())); + + self.fields_as_get_errors_mut + .push(quote!(stringify!(#name) => self.__errors.#name.as_mut())); + + self.fields_as_iterator_next.push( + quote!(#index => Some(&self.context.#name as &'a dyn ::flareon::forms::FormFieldBase)), + ); + } + + fn build_form_impl(&self) -> TokenStream { + let name = &self.name; + let context_struct_name = &self.context_struct_name; + let fields_as_from_context = &self.fields_as_from_context; + + quote! { + #[::flareon::private::async_trait] + #[automatically_derived] + impl ::flareon::forms::Form for #name { + type Context = #context_struct_name; + + async fn from_request( + request: &mut ::flareon::request::Request + ) -> Result> { + let mut context = ::build_context(request).await?; + + Ok(Self { + #( #fields_as_from_context, )* + }) + } + } + } + } + + fn build_form_context_impl(&self) -> TokenStream { + let context_struct_name = &self.context_struct_name; + let context_struct_errors_name = &self.context_struct_errors_name; + let context_struct_field_iterator_name = &self.context_struct_field_iterator_name; + + let fields_as_struct_fields = &self.fields_as_struct_fields; + let fields_as_struct_fields_new = &self.fields_as_struct_fields_new; + let fields_as_context_from_request = &self.fields_as_context_from_request; + let fields_as_get_errors = &self.fields_as_get_errors; + let fields_as_get_errors_mut = &self.fields_as_get_errors_mut; + + quote! { + #[derive(::core::fmt::Debug)] + struct #context_struct_name { + __errors: #context_struct_errors_name, + #( #fields_as_struct_fields, )* + } + + #[automatically_derived] + impl ::flareon::forms::FormContext for #context_struct_name { + fn new() -> Self { + Self { + __errors: ::core::default::Default::default(), + #( #fields_as_struct_fields_new, )* + } + } + + fn fields(&self) -> impl Iterator + '_ { + #context_struct_field_iterator_name { + context: self, + index: 0, + } + } + + fn set_value( + &mut self, + field_id: &str, + value: ::std::borrow::Cow, + ) -> Result<(), ::flareon::forms::FormFieldValidationError> { + match field_id { + #( #fields_as_context_from_request, )* + _ => {} + } + Ok(()) + } + + fn get_errors( + &self, + target: ::flareon::forms::FormErrorTarget + ) -> &[::flareon::forms::FormFieldValidationError] { + match target { + ::flareon::forms::FormErrorTarget::Field(field_id) => { + match field_id { + #( #fields_as_get_errors, )* + _ => { + panic!("Unknown field name passed to get_errors: `{}`", field_id); + } + } + } + ::flareon::forms::FormErrorTarget::Form => { + self.__errors.__form.as_slice() + } + } + } + + fn get_errors_mut( + &mut self, + target: ::flareon::forms::FormErrorTarget + ) -> &mut Vec<::flareon::forms::FormFieldValidationError> { + match target { + ::flareon::forms::FormErrorTarget::Field(field_id) => { + match field_id { + #( #fields_as_get_errors_mut, )* + _ => { + panic!("Unknown field name passed to get_errors_mut: `{}`", field_id); + } + } + } + ::flareon::forms::FormErrorTarget::Form => { + self.__errors.__form.as_mut() + } + } + } + } + } + } + + fn build_errors_struct(&self) -> TokenStream { + let context_struct_errors_name = &self.context_struct_errors_name; + let fields_as_errors = &self.fields_as_errors; + + quote! { + #[derive(::core::fmt::Debug, ::core::default::Default)] + struct #context_struct_errors_name { + __form: Vec<::flareon::forms::FormFieldValidationError>, + #( #fields_as_errors, )* + } + } + } + + fn build_context_field_iterator_impl(&self) -> TokenStream { + let context_struct_name = &self.context_struct_name; + let context_struct_field_iterator_name = &self.context_struct_field_iterator_name; + let fields_as_iterator_next = &self.fields_as_iterator_next; + + quote! { + #[derive(::core::fmt::Debug)] + struct #context_struct_field_iterator_name<'a> { + context: &'a #context_struct_name, + index: usize, + } + + #[automatically_derived] + impl<'a> Iterator for #context_struct_field_iterator_name<'a> { + type Item = &'a dyn ::flareon::forms::FormFieldBase; + + fn next(&mut self) -> Option { + let result = match self.index { + #( #fields_as_iterator_next, )* + _ => None, + }; + + if result.is_some() { + self.index += 1; + } else { + self.index = 0; + } + + result + } + } + } + } +} diff --git a/flareon-macros/src/lib.rs b/flareon-macros/src/lib.rs index 6136834..4248757 100644 --- a/flareon-macros/src/lib.rs +++ b/flareon-macros/src/lib.rs @@ -1,6 +1,14 @@ +mod form; + +use darling::FromDeriveInput; use proc_macro::TokenStream; +use syn::parse_macro_input; + +use crate::form::form_for_struct; -#[proc_macro] -pub fn flareon(_input: TokenStream) -> TokenStream { - unimplemented!() +#[proc_macro_derive(Form, attributes(form))] +pub fn derive_form(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::DeriveInput); + let token_stream = form_for_struct(ast); + token_stream.into() } diff --git a/flareon-macros/tests/compile_tests.rs b/flareon-macros/tests/compile_tests.rs new file mode 100644 index 0000000..d067819 --- /dev/null +++ b/flareon-macros/tests/compile_tests.rs @@ -0,0 +1,5 @@ +#[test] +fn test_derive_form() { + let t = trybuild::TestCases::new(); + t.pass("tests/ui/derive_form.rs"); +} diff --git a/flareon-macros/tests/ui/derive_form.rs b/flareon-macros/tests/ui/derive_form.rs new file mode 100644 index 0000000..a33c27a --- /dev/null +++ b/flareon-macros/tests/ui/derive_form.rs @@ -0,0 +1,10 @@ +use flareon_macros::Form; + +#[derive(Debug, Form)] +struct MyForm { + name: String, + name2: std::string::String, + age: u32, +} + +fn main() {} diff --git a/flareon/Cargo.toml b/flareon/Cargo.toml index 25d8064..afbf362 100644 --- a/flareon/Cargo.toml +++ b/flareon/Cargo.toml @@ -6,11 +6,16 @@ license.workspace = true description = "Modern web framework focused on speed and ease of use." [dependencies] +askama.workspace = true async-trait.workspace = true axum.workspace = true bytes.workspace = true derive_builder.workspace = true +flareon_macros.workspace = true +form_urlencoded.workspace = true indexmap.workspace = true log.workspace = true +num-traits.workspace = true +regex.workspace = true thiserror.workspace = true tokio.workspace = true diff --git a/flareon/src/error.rs b/flareon/src/error.rs new file mode 100644 index 0000000..7f22524 --- /dev/null +++ b/flareon/src/error.rs @@ -0,0 +1,28 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum Error { + #[error("Could not retrieve request body: {source}")] + ReadRequestBody { + #[from] + source: axum::Error, + }, + #[error("Invalid content type; expected {expected}, found {actual}")] + InvalidContentType { + expected: &'static str, + actual: String, + }, + #[error("Could not create a response object: {0}")] + ResponseBuilder(#[from] axum::http::Error), + #[error("Failed to reverse route `{view_name}`")] + ReverseFailed { view_name: String }, + #[error("Failed to render template: {0}")] + TemplateRender(#[from] askama::Error), +} + +impl From for askama::Error { + fn from(value: Error) -> Self { + askama::Error::Custom(Box::new(value)) + } +} diff --git a/flareon/src/forms.rs b/flareon/src/forms.rs new file mode 100644 index 0000000..9922b67 --- /dev/null +++ b/flareon/src/forms.rs @@ -0,0 +1,295 @@ +use std::borrow::Cow; + +use async_trait::async_trait; +pub use flareon_macros::Form; +use num_traits::Num; +use thiserror::Error; + +use crate::request::Request; + +#[derive(Debug, Error)] +pub enum FormError { + #[error("Request error: {error}")] + RequestError { + #[from] + error: crate::Error, + }, + // TODO better error message + #[error("...")] + ValidationError { context: T::Context }, +} + +const FORM_FIELD_REQUIRED: &str = "This field is required."; + +#[derive(Debug, Error)] +#[error("{message}")] +pub struct FormFieldValidationError { + message: Cow<'static, str>, +} + +#[derive(Debug)] +pub enum FormErrorTarget<'a> { + Field(&'a str), + Form, +} + +impl FormFieldValidationError { + #[must_use] + pub const fn from_string(message: String) -> Self { + Self { + message: Cow::Owned(message), + } + } + + #[must_use] + pub const fn from_static(message: &'static str) -> Self { + Self { + message: Cow::Borrowed(message), + } + } +} + +#[async_trait] +pub trait Form: Sized { + type Context: FormContext; + + async fn from_request(request: &mut Request) -> Result>; + + async fn build_context(request: &mut Request) -> Result> { + let form_data = request + .form_data() + .await + .map_err(|error| FormError::RequestError { error })?; + + let mut context = Self::Context::new(); + let mut has_errors = false; + + for (field_id, value) in Request::query_pairs(&form_data) { + let field_id = field_id.as_ref(); + + if let Err(err) = context.set_value(field_id, value) { + context.add_error(FormErrorTarget::Field(field_id), err); + has_errors = true; + } + } + + if has_errors { + Err(FormError::ValidationError { context }) + } else { + Ok(context) + } + } +} + +pub trait FormContext: Sized { + fn new() -> Self; + + fn fields(&self) -> impl Iterator + '_; + + fn set_value( + &mut self, + field_id: &str, + value: Cow, + ) -> Result<(), FormFieldValidationError>; + + fn add_error(&mut self, target: FormErrorTarget, error: FormFieldValidationError) { + self.get_errors_mut(target).push(error); + } + + fn get_errors(&self, target: FormErrorTarget) -> &[FormFieldValidationError]; + + fn get_errors_mut(&mut self, target: FormErrorTarget) -> &mut Vec; +} + +#[derive(Debug)] +pub struct FormFieldOptions { + pub id: String, +} + +pub trait FormFieldBase { + fn options(&self) -> &FormFieldOptions; + + fn id(&self) -> &str { + &self.options().id + } + + fn set_value(&mut self, value: Cow); + + fn render(&self) -> String; +} + +pub trait FormField: FormFieldBase { + type CustomOptions: Default; + + fn with_options(options: FormFieldOptions, custom_options: Self::CustomOptions) -> Self; +} + +pub trait AsFormField { + type Type: FormField; + + fn clean_value(field: &Self::Type) -> Result + where + Self: Sized; +} + +#[derive(Debug)] +pub struct CharField { + options: FormFieldOptions, + custom_options: CharFieldOptions, + value: Option, +} + +#[derive(Debug, Default)] +pub struct CharFieldOptions { + pub max_length: Option, +} + +impl CharFieldOptions { + pub fn set_max_length(&mut self, max_length: u32) { + self.max_length = Some(max_length); + } +} + +impl FormFieldBase for CharField { + fn options(&self) -> &FormFieldOptions { + &self.options + } + + fn set_value(&mut self, value: Cow) { + self.value = Some(value.into_owned()); + } + + fn render(&self) -> String { + let mut tag = HtmlTag::input("text"); + tag.attr("name", self.id()); + self.custom_options.max_length.map(|max_length| { + tag.attr("maxlength", &max_length.to_string()); + }); + tag.render() + } + + // TODO validate +} + +impl FormField for CharField { + type CustomOptions = CharFieldOptions; + + fn with_options(options: FormFieldOptions, custom_options: Self::CustomOptions) -> Self { + Self { + options, + custom_options, + value: None, + } + } +} + +impl AsFormField for String { + type Type = CharField; + + fn clean_value(field: &Self::Type) -> Result { + if let Some(value) = &field.value { + Ok(value.clone()) + } else { + Err(FormFieldValidationError::from_static(FORM_FIELD_REQUIRED)) + } + } +} + +#[derive(Debug)] +pub struct IntegerField { + options: FormFieldOptions, + custom_options: (), + value: Option, +} + +impl FormFieldBase for IntegerField { + fn options(&self) -> &FormFieldOptions { + &self.options + } + + fn set_value(&mut self, value: Cow) { + if let Ok(value) = T::from_str_radix(&value, 10) { + self.value = Some(value); + } else { + todo!("throw error"); + } + } + + fn render(&self) -> String { + todo!() + } +} + +impl FormField for IntegerField { + type CustomOptions = (); + + fn with_options(options: FormFieldOptions, custom_options: Self::CustomOptions) -> Self { + Self { + options, + custom_options, + value: None, + } + } +} + +impl AsFormField for u32 { + type Type = IntegerField; + + fn clean_value(field: &Self::Type) -> Result + where + Self: Sized, + { + if let Some(value) = &field.value { + Ok(*value) + } else { + Err(FormFieldValidationError::from_static(FORM_FIELD_REQUIRED)) + } + } +} + +#[derive(Debug)] +struct HtmlTag { + tag: String, + attributes: Vec<(String, String)>, +} + +impl HtmlTag { + #[must_use] + fn new(tag: &str) -> Self { + Self { + tag: tag.to_string(), + attributes: Vec::new(), + } + } + + #[must_use] + fn input(input_type: &str) -> Self { + let mut input = Self::new("input"); + input.attr("type", input_type); + input + } + + fn id(&mut self, id: &str) -> &mut Self { + self.attr("id", id) + } + + fn attr(&mut self, key: &str, value: &str) -> &mut Self { + if self.attributes.iter().any(|(k, _)| k == key) { + panic!("Attribute already exists: {}", key); + } + self.attributes.push((key.to_string(), value.to_string())); + self + } + + #[must_use] + fn render(&self) -> String { + let mut result = format!("<{} ", self.tag); + + for (key, value) in &self.attributes { + result.push_str(&format!("{}=\"{}\" ", key, value)); + } + + result.push_str(" />"); + result + } +} diff --git a/flareon/src/lib.rs b/flareon/src/lib.rs index 1b41d28..3792d2f 100644 --- a/flareon/src/lib.rs +++ b/flareon/src/lib.rs @@ -1,67 +1,62 @@ +mod error; +pub mod forms; pub mod prelude; +#[doc(hidden)] +pub mod private; +pub mod request; +pub mod router; +pub mod templates; +use std::borrow::Cow; use std::fmt::{Debug, Formatter}; +use std::future::Future; use std::io::Read; +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; +use axum::extract::RawForm; use axum::handler::HandlerWithoutStateExt; +use axum::RequestExt; use bytes::Bytes; use derive_builder::Builder; +pub use error::Error; use indexmap::IndexMap; use log::info; -use thiserror::Error; +use request::Request; +use router::{Route, Router}; -pub type StatusCode = axum::http::StatusCode; +pub type Result = std::result::Result; -#[async_trait] -pub trait RequestHandler { - async fn handle(&self, request: Request) -> Result; -} +pub type StatusCode = axum::http::StatusCode; -#[derive(Clone, Debug)] -pub struct Router { - urls: Vec, +#[macro_export] +macro_rules! reverse { + ( $request:expr, $view_name:literal ) => { + ::flareon::Response::new_redirect(::flareon::reverse_str!($request, $view_name)) + }; } -impl Router { - #[must_use] - pub fn with_urls>>(urls: T) -> Self { - Self { urls: urls.into() } - } - - async fn route(&self, request: Request, request_path: &str) -> Result { - for route in &self.urls { - if request_path.starts_with(&route.url) { - let request_path = &request_path[route.url.len()..]; - match &route.view { - RouteInner::Handler(handler) => return handler.handle(request).await, - RouteInner::Router(router) => { - return Box::pin(router.route(request, request_path)).await - } - } - } - } - - unimplemented!("404 handler is not implemented yet") - } +#[macro_export] +macro_rules! reverse_str { + ( $request:expr, $view_name:literal ) => { + $request.project().router().reverse($view_name)? + }; } #[async_trait] -impl RequestHandler for Router { - async fn handle(&self, request: Request) -> Result { - let path = request.uri().path().to_owned(); - self.route(request, &path).await - } +pub trait RequestHandler { + async fn handle(&self, request: Request) -> Result; } #[async_trait] -impl RequestHandler for T +impl RequestHandler for T where - T: Fn(Request) -> Result + Send + Sync, + T: Fn(Request) -> R + Clone + Send + Sync + 'static, + R: for<'a> Future> + Send, { - async fn handle(&self, request: Request) -> Result { - self(request) + async fn handle(&self, request: Request) -> Result { + self(request).await } } @@ -100,50 +95,6 @@ impl FlareonAppBuilder { } } -#[derive(Clone)] -pub struct Route { - url: String, - view: RouteInner, -} - -impl Route { - #[must_use] - pub fn with_handler>( - url: T, - view: Arc>, - ) -> Self { - Self { - url: url.into(), - view: RouteInner::Handler(view), - } - } - - #[must_use] - pub fn with_router>(url: T, router: Router) -> Self { - Self { - url: url.into(), - view: RouteInner::Router(router), - } - } -} - -#[derive(Clone)] -enum RouteInner { - Handler(Arc>), - Router(Router), -} - -impl Debug for Route { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match &self.view { - RouteInner::Handler(_) => f.debug_tuple("Handler").field(&"handler(...)").finish(), - RouteInner::Router(router) => f.debug_tuple("Router").field(router).finish(), - } - } -} - -pub type Request = axum::extract::Request; - type HeadersMap = IndexMap; #[derive(Debug)] @@ -155,6 +106,8 @@ pub struct Response { const CONTENT_TYPE_HEADER: &str = "Content-Type"; const HTML_CONTENT_TYPE: &str = "text/html"; +const FORM_CONTENT_TYPE: &str = "application/x-www-form-urlencoded"; +const LOCATION_HEADER: &str = "Location"; impl Response { #[must_use] @@ -166,6 +119,17 @@ impl Response { } } + #[must_use] + pub fn new_redirect>(location: T) -> Self { + let mut headers = HeadersMap::new(); + headers.insert(LOCATION_HEADER.to_owned(), location.into()); + Self { + status: StatusCode::SEE_OTHER, + headers, + body: Body::empty(), + } + } + #[must_use] fn html_headers() -> HeadersMap { let mut headers = HeadersMap::new(); @@ -200,12 +164,6 @@ impl Body { } } -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("Could not create a response object: {0}")] - ResponseBuilder(#[from] axum::http::Error), -} - #[derive(Clone, Debug)] pub struct FlareonProject { apps: Vec, @@ -230,15 +188,13 @@ impl FlareonProjectBuilder { #[must_use] pub fn register_app_with_views(&mut self, app: FlareonApp, url_prefix: &str) -> &mut Self { let new = self; - new.urls.push(Route::with_handler( - url_prefix, - Arc::new(Box::new(app.router.clone())), - )); + new.urls + .push(Route::with_router(url_prefix, app.router.clone())); new.apps.push(app); new } - pub fn build(&self) -> Result { + pub fn build(&self) -> Result { Ok(FlareonProject { apps: self.apps.clone(), router: Router::with_urls(self.urls.clone()), @@ -257,17 +213,23 @@ impl FlareonProject { pub fn builder() -> FlareonProjectBuilder { FlareonProjectBuilder::default() } + + #[must_use] + pub fn router(&self) -> &Router { + &self.router + } } -pub async fn run(mut project: FlareonProject, address_str: &str) -> Result<(), Error> { +pub async fn run(mut project: FlareonProject, address_str: &str) -> Result<()> { for app in &mut project.apps { info!("Initializing app: {:?}", app); } + let project = Arc::new(project); let listener = tokio::net::TcpListener::bind(address_str).await.unwrap(); let handler = |request: axum::extract::Request| async move { - pass_to_axum(&project, request) + pass_to_axum(&project, Request::new(request, project.clone())) .await .unwrap_or_else(handle_response_error) }; @@ -279,9 +241,9 @@ pub async fn run(mut project: FlareonProject, address_str: &str) -> Result<(), E } async fn pass_to_axum( - project: &FlareonProject, - request: axum::extract::Request, -) -> Result { + project: &Arc, + request: Request, +) -> Result { let response = project.router.handle(request).await?; let mut builder = axum::http::Response::builder().status(response.status); diff --git a/flareon/src/prelude.rs b/flareon/src/prelude.rs index 6c2ec5f..8b8afc4 100644 --- a/flareon/src/prelude.rs +++ b/flareon/src/prelude.rs @@ -1,3 +1,3 @@ -pub use crate::{ - Body, Error, FlareonApp, FlareonProject, Request, RequestHandler, Response, Route, StatusCode, -}; +pub use crate::request::Request; +pub use crate::router::Route; +pub use crate::{Body, Error, FlareonApp, FlareonProject, RequestHandler, Response, StatusCode}; diff --git a/flareon/src/private.rs b/flareon/src/private.rs new file mode 100644 index 0000000..623d168 --- /dev/null +++ b/flareon/src/private.rs @@ -0,0 +1,5 @@ +/// Re-exports of some of the Flareon dependencies that are used in the macros. +/// +/// This is to avoid the need to add them as dependencies to the crate that uses +/// the macros. +pub use async_trait::async_trait; diff --git a/flareon/src/request.rs b/flareon/src/request.rs new file mode 100644 index 0000000..39afecb --- /dev/null +++ b/flareon/src/request.rs @@ -0,0 +1,90 @@ +use std::borrow::Cow; +use std::sync::Arc; + +use bytes::Bytes; + +use crate::{Error, FlareonProject, FORM_CONTENT_TYPE}; + +#[derive(Debug)] +pub struct Request { + inner: axum::extract::Request, + project: Arc, +} + +impl Request { + #[must_use] + pub fn new(inner: axum::extract::Request, project: Arc) -> Self { + Self { inner, project } + } + + #[must_use] + pub fn inner(&self) -> &axum::extract::Request { + &self.inner + } + + #[must_use] + pub fn project(&self) -> &FlareonProject { + &self.project + } + + #[must_use] + pub fn uri(&self) -> &axum::http::Uri { + self.inner.uri() + } + + #[must_use] + pub fn method(&self) -> &axum::http::Method { + self.inner.method() + } + + #[must_use] + pub fn headers(&self) -> &axum::http::HeaderMap { + self.inner.headers() + } + + #[must_use] + pub fn content_type(&self) -> Option<&axum::http::HeaderValue> { + self.inner.headers().get(axum::http::header::CONTENT_TYPE) + } + + pub async fn form_data(&mut self) -> Result { + if self.method() == axum::http::Method::GET { + if let Some(query) = self.inner.uri().query() { + return Ok(Bytes::copy_from_slice(query.as_bytes())); + } + + Ok(Bytes::new()) + } else { + self.expect_content_type(FORM_CONTENT_TYPE)?; + + let body = std::mem::take(self.inner.body_mut()); + let bytes = axum::body::to_bytes(body, usize::MAX) + .await + .map_err(|err| Error::ReadRequestBody { source: err })?; + + Ok(bytes) + } + } + + fn expect_content_type(&mut self, expected: &'static str) -> Result<(), Error> { + let content_type = self + .content_type() + .map(|value| String::from_utf8_lossy(value.as_bytes())) + .unwrap_or("".into()); + if self.content_type() == Some(&axum::http::HeaderValue::from_static(expected)) { + Ok(()) + } else { + Err(Error::InvalidContentType { + expected, + actual: content_type.into_owned(), + }) + } + } + + #[must_use] + pub fn query_pairs<'data>( + bytes: &'data Bytes, + ) -> impl Iterator, Cow)> + 'data { + form_urlencoded::parse(bytes.as_ref()) + } +} diff --git a/flareon/src/router.rs b/flareon/src/router.rs new file mode 100644 index 0000000..930c3ce --- /dev/null +++ b/flareon/src/router.rs @@ -0,0 +1,149 @@ +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use async_trait::async_trait; +use axum::http::StatusCode; +use bytes::Bytes; +use log::debug; + +use crate::request::Request; +use crate::router::path::PathMatcher; +use crate::{Body, Error, RequestHandler, Response, Result}; + +mod path; + +#[derive(Clone, Debug)] +pub struct Router { + urls: Vec, + names: HashMap>, +} + +impl Router { + #[must_use] + pub fn with_urls>>(urls: T) -> Self { + let urls = urls.into(); + let mut names = HashMap::new(); + + for url in &urls { + if let Some(name) = &url.name { + names.insert(name.clone(), url.url.clone()); + } + } + + Self { urls, names } + } + + async fn route(&self, request: Request, request_path: &str) -> Result { + debug!("Routing request to {}", request_path); + + for route in &self.urls { + if let Some(matches) = route.url.capture(request_path) { + match &route.view { + RouteInner::Handler(handler) => { + if matches.matches_fully() { + return handler.handle(request).await; + } + } + RouteInner::Router(router) => { + return Box::pin(router.route(request, matches.remaining_path())).await + } + } + } + } + + debug!("Not found: {}", request_path); + Ok(handle_not_found()) + } + + pub async fn handle(&self, request: Request) -> Result { + let path = request.uri().path().to_owned(); + self.route(request, &path).await + } + + pub fn reverse(&self, name: &str) -> Result { + self.reverse_option(name) + .ok_or_else(|| Error::ReverseFailed { + view_name: name.to_owned(), + }) + } + + #[must_use] + pub fn reverse_option(&self, name: &str) -> Option { + let url = self.names.get(name).map(|matcher| matcher.reverse()); + if let Some(url) = url { + return Some(url); + } + + for route in &self.urls { + if let RouteInner::Router(router) = &route.view { + if let Some(url) = router.reverse_option(name) { + return Some(route.url.reverse() + &url); + } + } + } + None + } +} + +#[derive(Debug, Clone)] +pub struct Route { + url: Arc, + view: RouteInner, + name: Option, +} + +impl Route { + #[must_use] + pub fn with_handler(url: &str, view: Arc>) -> Self { + Self { + url: Arc::new(PathMatcher::new(url)), + view: RouteInner::Handler(view), + name: None, + } + } + + #[must_use] + pub fn with_handler_and_name>( + url: &str, + view: Arc>, + name: T, + ) -> Self { + Self { + url: Arc::new(PathMatcher::new(url)), + view: RouteInner::Handler(view), + name: Some(name.into()), + } + } + + #[must_use] + pub fn with_router(url: &str, router: Router) -> Self { + Self { + url: Arc::new(PathMatcher::new(url)), + view: RouteInner::Router(router), + name: None, + } + } +} + +fn handle_not_found() -> Response { + Response::new_html( + StatusCode::NOT_FOUND, + Body::Fixed(Bytes::from("404 Not Found")), + ) +} + +#[derive(Clone)] +enum RouteInner { + Handler(Arc>), + Router(Router), +} + +impl Debug for RouteInner { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match &self { + RouteInner::Handler(_) => f.debug_tuple("Handler").field(&"handler(...)").finish(), + RouteInner::Router(router) => f.debug_tuple("Router").field(router).finish(), + } + } +} diff --git a/flareon/src/router/path.rs b/flareon/src/router/path.rs new file mode 100644 index 0000000..055e2d3 --- /dev/null +++ b/flareon/src/router/path.rs @@ -0,0 +1,221 @@ +use std::fmt::Display; +use std::path::Path; +use std::thread::current; + +use log::debug; +use regex::Regex; + +#[derive(Debug, Clone)] +pub(super) struct PathMatcher { + parts: Vec, +} + +impl PathMatcher { + #[must_use] + pub fn new>(path_pattern: T) -> Self { + let path_pattern = path_pattern.into(); + + let mut last_end = 0; + let mut parts = Vec::new(); + let param_regex = Regex::new(":([^/]+)").expect("Invalid regex"); + for capture in param_regex.captures_iter(&path_pattern) { + let full_match = capture.get(0).expect("Could not get regex match"); + let start = full_match.start(); + if start > last_end { + parts.push(PathPart::Literal(path_pattern[last_end..start].to_string())); + } + + let name = capture.get(1).expect("Could not get regex capture"); + // TODO check if name is a valid identifier + parts.push(PathPart::Param { + name: name.as_str().to_owned(), + }); + last_end = start + full_match.len(); + } + if last_end < path_pattern.len() { + parts.push(PathPart::Literal(path_pattern[last_end..].to_string())); + } + + Self { parts } + } + + #[must_use] + pub fn capture<'matcher, 'path>( + &'matcher self, + path: &'path str, + ) -> Option> { + debug!("Matching path `{}` against pattern `{}`", path, self); + + let mut current_path = path; + let mut params = Vec::with_capacity(self.param_len()); + for part in &self.parts { + match part { + PathPart::Literal(s) => { + if !current_path.starts_with(s) { + return None; + } + current_path = ¤t_path[s.len()..]; + } + PathPart::Param { name } => { + let next_slash = current_path.find('/'); + let value = if let Some(next_slash) = next_slash { + ¤t_path[..next_slash] + } else { + ¤t_path + }; + params.push(PathParam::new(name, value)); + current_path = ¤t_path[value.len()..]; + } + } + } + + Some(CaptureResult::new(params, current_path)) + } + + #[must_use] + pub fn reverse(&self) -> String { + todo!(); + // if !self.param_names.is_empty() { + // unimplemented!("Reverse routing with parameters is not yet + // supported"); } + // self.path_pattern.clone() + } + + #[must_use] + fn param_len(&self) -> usize { + self.parts + .iter() + .map(|part| match part { + PathPart::Literal(s) => 0, + PathPart::Param { name } => 1, + }) + .sum() + } +} + +impl Display for PathMatcher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for part in &self.parts { + write!(f, "{}", part)?; + } + Ok(()) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub(super) struct CaptureResult<'matcher, 'path> { + params: Vec>, + remaining_path: &'path str, +} + +impl<'matcher, 'path> CaptureResult<'matcher, 'path> { + #[must_use] + fn new(params: Vec>, remaining_path: &'path str) -> Self { + Self { + params, + remaining_path, + } + } + + #[must_use] + pub fn matches_fully(&self) -> bool { + self.remaining_path.is_empty() + } + + #[must_use] + pub fn remaining_path(&self) -> &'path str { + self.remaining_path + } +} + +#[derive(Debug, Clone)] +enum PathPart { + Literal(String), + Param { name: String }, +} + +impl Display for PathPart { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PathPart::Literal(s) => write!(f, "{}", s), + PathPart::Param { name } => write!(f, ":{}", name), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct PathParam<'a> { + name: &'a str, + value: String, +} + +impl<'a> PathParam<'a> { + #[must_use] + pub fn new(name: &'a str, value: &str) -> Self { + Self { + name, + value: value.to_string(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_parser_no_params() { + let path_parser = PathMatcher::new("/users"); + assert_eq!( + path_parser.capture("/users"), + Some(CaptureResult::new(vec![], "")) + ); + assert_eq!(path_parser.capture("/test"), None); + } + + #[test] + fn test_path_parser_single_param() { + let path_parser = PathMatcher::new("/users/:id"); + assert_eq!( + path_parser.capture("/users/123"), + Some(CaptureResult::new(vec![PathParam::new("id", "123")], "")) + ); + assert_eq!( + path_parser.capture("/users/123/"), + Some(CaptureResult::new(vec![PathParam::new("id", "123")], "/")) + ); + assert_eq!( + path_parser.capture("/users/123/abc"), + Some(CaptureResult::new( + vec![PathParam::new("id", "123")], + "/abc" + )) + ); + assert_eq!(path_parser.capture("/users/"), None); + } + + #[test] + fn test_path_parser_multiple_params() { + let path_parser = PathMatcher::new("/users/:id/posts/:post_id"); + assert_eq!( + path_parser.capture("/users/123/posts/456"), + Some(CaptureResult::new( + vec![ + PathParam::new("id", "123"), + PathParam::new("post_id", "456"), + ], + "" + )) + ); + assert_eq!( + path_parser.capture("/users/123/posts/456/abc"), + Some(CaptureResult::new( + vec![ + PathParam::new("id", "123"), + PathParam::new("post_id", "456"), + ], + "/abc" + )) + ); + } +} diff --git a/flareon/src/templates.rs b/flareon/src/templates.rs new file mode 100644 index 0000000..4cc6697 --- /dev/null +++ b/flareon/src/templates.rs @@ -0,0 +1 @@ +pub use askama::Template;