From 14479a88c2a9ae61a9c9a18e6b9a2e41ed23012e Mon Sep 17 00:00:00 2001 From: nwrenger Date: Wed, 31 Jul 2024 16:33:05 +0200 Subject: [PATCH] :sparkles: Added support for generics && updated docs --- Cargo.toml | 2 +- README.md | 34 +++++++-- macros/Cargo.toml | 2 +- macros/src/lib.rs | 175 +++++++++++++++++++++++++++++++++++++--------- src/lib.rs | 98 +++++++++++++++++--------- tests/api.ts | 9 ++- tests/main.rs | 20 +++--- 7 files changed, 254 insertions(+), 86 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c4c32e0..5b8cf6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gluer" -version = "0.4.0" +version = "0.4.1" edition = "2021" authors = ["Nils Wrenger "] description = "A wrapper for Rust frameworks that eliminates redundant type and function definitions between the frontend and backend" diff --git a/README.md b/README.md index 818b910..ea256c2 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,16 @@ use gluer::metadata; #[metadata] #[derive(Default, serde::Serialize)] struct Book { - // imagine some fields here + name: String, + // Sometimes you don't have access to certain data types, so you can override them using `#[into(Type)]` + #[into(String)] + user: User, +} + +#[derive(Default, serde::Serialize)] +struct User { + name: String, + password: String, } // Define the functions with the metadata macro @@ -135,6 +144,7 @@ use axum::{ Json, }; use gluer::{extract, metadata, Api}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[metadata] @@ -142,15 +152,29 @@ async fn fetch_root(Query(test): Query>, Path(p): Path { name: String, + vec: Vec, +} + +#[metadata] +#[derive(Serialize, Deserialize, Default)] +struct Age { + #[into(String)] + age: AgeInner, +} + +#[derive(Serialize, Deserialize, Default)] +struct AgeInner { + age: u8, } #[metadata] -async fn add_root(Path(_): Path, Json(hello): Json) -> Json> { - vec![hello].into() +async fn add_root(Path(_): Path, Json(hello): Json>) -> Json { + Json(hello.name.to_string()) } #[tokio::main] diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 09047d8..8acacf3 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gluer-macros" -version = "0.4.0" +version = "0.4.1" edition = "2021" authors = ["Nils Wrenger "] description = "Procedural macros for the gluer framework" diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 9c111d6..32e80c6 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,8 +1,8 @@ -use proc_macro as pc; +use proc_macro::{self as pc}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use std::{collections::HashMap, fmt}; -use syn::{parenthesized, parse::Parse, spanned::Spanned}; +use std::{collections::HashMap, fmt, vec}; +use syn::{parenthesized, parse::Parse, spanned::Spanned, TypeParam}; fn s_err(span: proc_macro2::Span, msg: impl fmt::Display) -> syn::Error { syn::Error::new(span, msg) @@ -111,6 +111,18 @@ fn generate_struct_const( ) -> syn::Result<(TokenStream, TokenStream)> { let struct_name = item_struct.ident.to_string(); let vis = &item_struct.vis; + let generics: Vec = item_struct + .generics + .params + .iter() + .filter_map(|generic| { + if let syn::GenericParam::Type(TypeParam { ident, .. }) = generic { + return Some(ident.to_string()); + } + None + }) + .collect(); + let mut dependencies = HashMap::new(); let fields = item_struct @@ -120,16 +132,16 @@ fn generate_struct_const( let ident = field .ident .clone() - .ok_or_else(|| s_err(field.span(), "Unnamed field not supported"))? + .ok_or_else(|| syn::Error::new(field.span(), "Unnamed field not supported"))? .to_string(); let conversion_fn = parse_field_attr(&field.attrs)?; - // clean off all "info" attributes + // clean off all "into" attributes field.attrs = field .attrs .iter() - .filter(|attr| attr.path().is_ident("info")) + .filter(|attr| !attr.path().is_ident("into")) .cloned() .collect(); @@ -140,7 +152,10 @@ fn generate_struct_const( }; if conversion_fn.is_none() { - if let Some((is_basic, _, inner_ty)) = basic_rust_type(&field.ty)? { + if let Some(RustType { + is_basic, inner_ty, .. + }) = basic_rust_type(&field.ty, &generics)? + { if !is_basic { dependencies.insert( inner_ty.clone(), @@ -170,11 +185,20 @@ fn generate_struct_const( quote! { #struct_const } }); + let generics_quote = generics.iter().map(|generic| { + quote! { #generic } + }); + let item_struct = quote! { #item_struct }; Ok(( quote! { - #vis const #const_ident: gluer::StructInfo = gluer::StructInfo { name: #struct_name, fields: &[#(#const_value),*], dependencies: &[#(#dependencies_quote),*] }; + #vis const #const_ident: gluer::StructInfo = gluer::StructInfo { + name: #struct_name, + generics: &[#(#generics_quote),*], + fields: &[#(#const_value),*], + dependencies: &[#(#dependencies_quote),*] + }; }, item_struct, )) @@ -200,7 +224,7 @@ fn parse_field_attr(attrs: &[syn::Attribute]) -> syn::Result> } } - let message = "expected #[into = \"...\"]"; + let message = "expected #[into(...)]"; return Err(syn::Error::new_spanned(attr, message)); } Ok(None) @@ -212,6 +236,7 @@ fn generate_fn_const( let fn_name = item_fn.sig.ident.to_string(); let vis = &item_fn.vis; let mut structs = HashMap::new(); + let generics: Vec = vec![]; let params = item_fn .sig @@ -220,11 +245,24 @@ fn generate_fn_const( .filter_map(|param| match param { syn::FnArg::Typed(syn::PatType { pat, ty, .. }) => { let pat = pat.to_token_stream().to_string(); - if let Some((is_basic, outer_ty, inner_ty)) = basic_rust_type(ty).ok()? { + if let Some(RustType { + is_basic, + outer_ty, + inner_ty, + is_generic: (is_generic, is_basic_generic), + }) = basic_rust_type(ty, &generics).ok()? + { if !is_basic { let struct_const = format!("STRUCT_{}", inner_ty.to_uppercase()); structs.insert(inner_ty.clone(), struct_const); } + if is_generic && !is_basic_generic { + let ty = outer_ty.clone(); + let ty = ty.split("<").last().unwrap(); + let ty = ty.replace('>', "").replace(" ", ""); + let struct_const = format!("STRUCT_{}", ty.to_uppercase()); + structs.insert(ty.clone(), struct_const); + } Some(Ok((pat, outer_ty))) } else { None @@ -238,11 +276,24 @@ fn generate_fn_const( let response = match &item_fn.sig.output { syn::ReturnType::Type(_, ty) => { - if let Some((is_basic, outer_ty, inner_ty)) = basic_rust_type(ty)? { + if let Some(RustType { + is_basic, + outer_ty, + inner_ty, + is_generic: (is_generic, is_basic_generic), + }) = basic_rust_type(ty, &generics)? + { if !is_basic { let struct_const = format!("STRUCT_{}", inner_ty.to_uppercase()); structs.insert(inner_ty.clone(), struct_const); } + if is_generic && !is_basic_generic { + let ty = outer_ty.clone(); + let ty = ty.split("<").last().unwrap(); + let ty = ty.replace('>', "").replace(" ", ""); + let struct_const = format!("STRUCT_{}", ty.to_uppercase()); + structs.insert(ty.clone(), struct_const); + } outer_ty } else { return Err(s_err(ty.span(), "Unsupported return type")); @@ -290,14 +341,41 @@ impl syn::parse::Parse for NoArgs { } } -/// Returns a tuple (bool, outermost_type, innermost_type) -fn basic_rust_type(ty: &syn::Type) -> syn::Result> { +struct RustType { + is_basic: bool, + outer_ty: String, + inner_ty: String, + is_generic: (bool, bool), +} + +impl RustType { + fn new(is_basic: bool, outer_ty: String, inner_ty: String, is_generic: (bool, bool)) -> Self { + RustType { + is_basic, + outer_ty, + inner_ty, + is_generic, + } + } +} + +/// Returns a tuple (bool, outermost_type, innermost_type, (is_generic, is_basic_generic)) +fn basic_rust_type(ty: &syn::Type, generics: &Vec) -> syn::Result> { let ty_str = ty.to_token_stream().to_string(); match ty { syn::Type::Path(syn::TypePath { path, .. }) => { if let Some(segment) = path.segments.last() { let ty_name = segment.ident.to_string(); + if generics.contains(&ty_name) { + return Ok(Some(RustType::new( + true, + ty_name.clone(), + ty_name, + (true, true), + ))); + } + // Skip types like State<...> and more, see the `extract` section in axum's docs if matches!( ty_name.as_ref(), @@ -329,28 +407,44 @@ fn basic_rust_type(ty: &syn::Type) -> syn::Result | "f64" | "String" ); - return Ok(Some((is_basic, ty_name.clone(), ty_name))); + return Ok(Some(RustType::new( + is_basic, + ty_name.clone(), + ty_name, + (false, false), + ))); } syn::PathArguments::AngleBracketed(ref args) => { - let mut outer_type = ty_name.clone(); - let mut innermost_type = String::new(); - let mut is_basic = false; - - for arg in &args.args { - if let syn::GenericArgument::Type(ref inner_ty) = arg { - if let Ok(Some((inner_is_basic, outer_most, inner_innermost))) = - basic_rust_type(inner_ty) - { - outer_type = format!("{}<{}>", ty_name, outer_most); - innermost_type = inner_innermost; - is_basic = inner_is_basic; - } else { - return Ok(None); + if matches!( + ty_name.as_str(), + "Query" | "HashMap" | "Path" | "Vec" | "Json" | "Option" | "Result" + ) { + for arg in &args.args { + if let syn::GenericArgument::Type(ref inner_ty) = arg { + if let Ok(Some(RustType { + is_basic, + outer_ty, + inner_ty, + is_generic, + })) = basic_rust_type(inner_ty, generics) + { + return Ok(Some(RustType::new( + is_basic, + format!("{}<{}>", ty_name, outer_ty), + inner_ty, + is_generic, + ))); + } } } } - return Ok(Some((is_basic, outer_type, innermost_type))); + let mut outer_ty = ty_name.clone(); + if let Some(generic_type) = args.args.get(0) { + outer_ty = format!("{}<{}>", outer_ty, generic_type.to_token_stream()); + } + + return Ok(Some(RustType::new(false, outer_ty, ty_name, (true, false)))); } _ => {} } @@ -358,19 +452,32 @@ fn basic_rust_type(ty: &syn::Type) -> syn::Result } syn::Type::Reference(syn::TypeReference { elem, .. }) | syn::Type::Paren(syn::TypeParen { elem, .. }) - | syn::Type::Group(syn::TypeGroup { elem, .. }) => return basic_rust_type(elem), + | syn::Type::Group(syn::TypeGroup { elem, .. }) => return basic_rust_type(elem, generics), syn::Type::Tuple(elems) => { if elems.elems.len() == 1 { - return basic_rust_type(&elems.elems[0]); + return basic_rust_type(&elems.elems[0], generics); } else if elems.elems.is_empty() { - return Ok(Some((true, "()".to_string(), "()".to_string()))); + return Ok(Some(RustType::new( + true, + "()".to_string(), + "()".to_string(), + (false, false), + ))); } } syn::Type::Array(syn::TypeArray { elem, .. }) | syn::Type::Slice(syn::TypeSlice { elem, .. }) => { - if let Some((is_basic, outer_ty, inner_ty)) = basic_rust_type(elem)? { + if let Some(RustType { + is_basic, + outer_ty, + inner_ty, + is_generic, + }) = basic_rust_type(elem, generics)? + { let vec_type = format!("Vec<{}>", outer_ty); - return Ok(Some((is_basic, vec_type, inner_ty))); + return Ok(Some(RustType::new( + is_basic, vec_type, inner_ty, is_generic, + ))); } else { return Ok(None); } diff --git a/src/lib.rs b/src/lib.rs index 4f66ef8..0891624 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,26 +72,17 @@ where if !route.fn_info.structs.is_empty() { for struct_info in route.fn_info.structs { Self::deps(struct_info.dependencies, &mut ts_interfaces); - let ts_interfaces_clone = ts_interfaces.clone(); - ts_interfaces - .entry(struct_info.name.to_string()) - .or_insert_with(|| { - generate_ts_interface( - struct_info.name, - struct_info.fields, - &ts_interfaces_clone, - ) - }); } + Self::deps(route.fn_info.structs, &mut ts_interfaces); } let params_type = route .fn_info .params .iter() - .map(|Field { name: _, ty }| ty_to_ts(ty, &ts_interfaces).unwrap()) + .map(|Field { name: _, ty }| ty_to_ts(ty, &[], &ts_interfaces).unwrap()) .collect::>(); - let response_type = ty_to_ts(route.fn_info.response, &ts_interfaces).unwrap(); + let response_type = ty_to_ts(route.fn_info.response, &[], &ts_interfaces).unwrap(); if ts_functions.contains_key(route.fn_name) { return Err(format!( @@ -120,17 +111,19 @@ where fn deps(dependencies: &[StructInfo], ts_interfaces: &mut BTreeMap) { for StructInfo { name, + generics, fields, dependencies, } in dependencies { - if !dependencies.is_empty() { - Self::deps(dependencies, ts_interfaces); + if !ts_interfaces.contains_key(&name.to_string()) { + let ts_interfaces_clone = ts_interfaces.clone(); + ts_interfaces.insert( + name.to_string(), + generate_ts_interface(name, generics, fields, &ts_interfaces_clone), + ); } - let ts_interfaces_clone = ts_interfaces.clone(); - ts_interfaces - .entry(name.to_string()) - .or_insert_with(move || generate_ts_interface(name, fields, &ts_interfaces_clone)); + Self::deps(dependencies, ts_interfaces); } } @@ -170,6 +163,7 @@ pub struct FnInfo<'a> { #[derive(Clone, Debug)] pub struct StructInfo<'a> { pub name: &'a str, + pub generics: &'a [&'a str], pub fields: &'a [Field<'a>], pub dependencies: &'a [StructInfo<'a>], } @@ -183,12 +177,18 @@ pub struct Field<'a> { fn generate_ts_interface( struct_name: &str, + generics: &[&str], fields: &[Field], ts_interfaces: &BTreeMap, ) -> String { - let mut interface = format!("export interface {} {{\n", struct_name); - for Field { name, ty } in dbg!(fields) { - let ty = ty_to_ts(ty, ts_interfaces).unwrap(); + let generics_str = if generics.is_empty() { + "".to_string() + } else { + format!("<{}>", generics.join(", ")) + }; + let mut interface = format!("export interface {}{} {{\n", struct_name, generics_str); + for Field { name, ty } in fields { + let ty = ty_to_ts(ty, generics, ts_interfaces).unwrap(); interface.push_str(&format!(" {}: {};\n", name, ty.unwrap())); } interface.push_str("}\n\n"); @@ -257,7 +257,7 @@ fn generate_ts_function( ) } -#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)] enum Type { Unknown(T), Json(T), @@ -280,12 +280,14 @@ impl Type { fn ty_to_ts<'a>( ty: &'a str, + generics: &[&str], ts_interfaces: &'a BTreeMap, ) -> Result, String> { let ty = ty.trim().replace(" ", ""); if ts_interfaces.contains_key(ty.as_str()) { return Ok(Unknown(ty.to_string())); } + Ok(match ty.as_str() { "str" | "String" => Unknown(String::from("string")), "usize" | "isize" | "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64" | "f32" @@ -293,40 +295,70 @@ fn ty_to_ts<'a>( "bool" => Unknown(String::from("boolean")), "()" => Unknown(String::from("void")), t if t.starts_with("Vec<") => { - let ty = ty_to_ts(&t[4..t.len() - 1], ts_interfaces)?.unwrap(); + let inner_ty = &t[4..t.len() - 1]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); Unknown(format!("{}[]", ty)) } t if t.starts_with("Html<") => { - let ty = ty_to_ts(&t[5..t.len() - 1], ts_interfaces)?.unwrap(); + let inner_ty = &t[5..t.len() - 1]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); Unknown(ty) } t if t.starts_with("Json<") => { - let ty = ty_to_ts(&t[5..t.len() - 1], ts_interfaces)?.unwrap(); + let inner_ty = &t[5..t.len() - 1]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); Json(ty) } t if t.starts_with("Path<") => { - let ty = ty_to_ts(&t[5..t.len() - 1], ts_interfaces)?.unwrap(); + let inner_ty = &t[5..t.len() - 1]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); Path(ty) } t if t.starts_with("Query { - let ty = ty_to_ts(&t[14..t.len() - 2], ts_interfaces)?.unwrap(); + let inner_ty = &t[14..t.len() - 2]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); QueryMap(ty) } t if t.starts_with("Query<") => { - let ty = ty_to_ts(&t[6..t.len() - 1], ts_interfaces)?.unwrap(); + let inner_ty = &t[6..t.len() - 1]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); Query(ty) } t if t.starts_with("Result<") => { - let ty: String = ty_to_ts(&t[7..t.len() - 1], ts_interfaces)?.unwrap(); + let inner_ty = &t[7..t.len() - 1]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); Unknown(format!("{} | any", ty)) } t if t.starts_with("Option<") => { - let ty: String = ty_to_ts(&t[7..t.len() - 1], ts_interfaces)?.unwrap(); + let inner_ty = &t[7..t.len() - 1]; + let ty = ty_to_ts(inner_ty, generics, ts_interfaces)?.unwrap(); Unknown(format!("{} | null", ty)) } - t if t.starts_with("&") => ty_to_ts(&t[1..t.len()], ts_interfaces)?, - t if t.starts_with("'static") => ty_to_ts(&t[7..t.len()], ts_interfaces)?, - _ => return Err(format!("Type '{}' couldn't be converted to TypeScript", ty)), + t if t.starts_with("&") => ty_to_ts(&t[1..t.len()], generics, ts_interfaces)?, + t if t.starts_with("'static") => ty_to_ts(&t[7..t.len()], generics, ts_interfaces)?, + t if t.contains('<') && t.contains('>') => { + let split: Vec<&str> = t.split('<').collect(); + let base_ty = split[0]; + let generic_params = &split[1][..split[1].len() - 1]; + + let generic_ts = generic_params + .split(',') + .map(|param| ty_to_ts(param, generics, ts_interfaces)) + .collect::, _>>()? + .into_iter() + .map(|t| t.unwrap()) + .collect::>() + .join(", "); + + Unknown(format!("{}<{}>", base_ty, generic_ts)) + } + t => { + if let Some(t) = generics.iter().find(|p| **p == t) { + Unknown(t.to_string()) + } else { + return Err(format!("Type '{}' couldn't be converted to TypeScript", ty)); + } + } }) } diff --git a/tests/api.ts b/tests/api.ts index 12bd158..da4b214 100644 --- a/tests/api.ts +++ b/tests/api.ts @@ -1,8 +1,13 @@ -export interface Hello { +export interface Age { + age: string; +} + +export interface Hello { name: string; + vec: T[]; } -export async function add_root(path: number, data: Hello): Promise { +export async function add_root(path: number, data: Hello): Promise { const response = await fetch(`/${encodeURIComponent(path)}`, { method: "POST", headers: { diff --git a/tests/main.rs b/tests/main.rs index 4ff1eda..560933b 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -4,6 +4,7 @@ use axum::{ Json, }; use gluer::{extract, metadata, Api}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[metadata] @@ -11,30 +12,29 @@ async fn fetch_root(Query(test): Query>, Path(p): Path { name: String, - #[into(String)] - age: Age, - age2: Age, + vec: Vec, } #[metadata] -#[derive(serde::Serialize, serde::Deserialize, Default)] +#[derive(Serialize, Deserialize, Default)] struct Age { + #[into(String)] age: AgeInner, } -#[metadata] -#[derive(serde::Serialize, serde::Deserialize, Default)] +#[derive(Serialize, Deserialize, Default)] struct AgeInner { age: u8, } #[metadata] -async fn add_root(Path(_): Path, Json(hello): Json) -> Json> { - vec![hello].into() +async fn add_root(Path(_): Path, Json(hello): Json>) -> Json { + Json(hello.name.to_string()) } #[tokio::test]